Skip to content

Commit 3d33c57

Browse files
committed
add special tokens bool
1 parent 9dacd1f commit 3d33c57

6 files changed

Lines changed: 28 additions & 29 deletions

File tree

all_minilm_l6_v2/benchmark_test.go

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ func BenchmarkSingleSentence(b *testing.B) {
2626
b.ReportAllocs()
2727

2828
for b.Loop() {
29-
_, err := benchModel.Compute(sentence)
29+
_, err := benchModel.Compute(sentence, false)
3030
if err != nil {
3131
b.Fatalf("Failed to compute embedding: %v", err)
3232
}
@@ -41,7 +41,7 @@ func BenchmarkSingleSentenceShort(b *testing.B) {
4141
b.ReportAllocs()
4242

4343
for b.Loop() {
44-
_, err := benchModel.Compute(sentence)
44+
_, err := benchModel.Compute(sentence, false)
4545
if err != nil {
4646
b.Fatalf("Failed to compute embedding: %v", err)
4747
}
@@ -56,7 +56,7 @@ func BenchmarkSingleSentenceLong(b *testing.B) {
5656
b.ReportAllocs()
5757

5858
for b.Loop() {
59-
_, err := benchModel.Compute(sentence)
59+
_, err := benchModel.Compute(sentence, false)
6060
if err != nil {
6161
b.Fatalf("Failed to compute embedding: %v", err)
6262
}
@@ -74,7 +74,7 @@ func BenchmarkBatch2(b *testing.B) {
7474
b.ReportAllocs()
7575

7676
for b.Loop() {
77-
_, err := benchModel.ComputeBatch(sentences)
77+
_, err := benchModel.ComputeBatch(sentences, false)
7878
if err != nil {
7979
b.Fatalf("Failed to compute batch embeddings: %v", err)
8080
}
@@ -94,7 +94,7 @@ func BenchmarkBatch4(b *testing.B) {
9494
b.ReportAllocs()
9595

9696
for b.Loop() {
97-
_, err := benchModel.ComputeBatch(sentences)
97+
_, err := benchModel.ComputeBatch(sentences, false)
9898
if err != nil {
9999
b.Fatalf("Failed to compute batch embeddings: %v", err)
100100
}
@@ -118,7 +118,7 @@ func BenchmarkBatch8(b *testing.B) {
118118
b.ReportAllocs()
119119

120120
for b.Loop() {
121-
_, err := benchModel.ComputeBatch(sentences)
121+
_, err := benchModel.ComputeBatch(sentences, false)
122122
if err != nil {
123123
b.Fatalf("Failed to compute batch embeddings: %v", err)
124124
}
@@ -136,7 +136,7 @@ func BenchmarkBatch16(b *testing.B) {
136136
b.ReportAllocs()
137137

138138
for b.Loop() {
139-
_, err := benchModel.ComputeBatch(sentences)
139+
_, err := benchModel.ComputeBatch(sentences, false)
140140
if err != nil {
141141
b.Fatalf("Failed to compute batch embeddings: %v", err)
142142
}
@@ -154,7 +154,7 @@ func BenchmarkBatch32(b *testing.B) {
154154
b.ReportAllocs()
155155

156156
for b.Loop() {
157-
_, err := benchModel.ComputeBatch(sentences)
157+
_, err := benchModel.ComputeBatch(sentences, false)
158158
if err != nil {
159159
b.Fatalf("Failed to compute batch embeddings: %v", err)
160160
}
@@ -175,7 +175,7 @@ func BenchmarkVsSingle4Individual(b *testing.B) {
175175

176176
for b.Loop() {
177177
for _, sentence := range sentences {
178-
_, err := benchModel.Compute(sentence)
178+
_, err := benchModel.Compute(sentence, false)
179179
if err != nil {
180180
b.Fatalf("Failed to compute embedding: %v", err)
181181
}
@@ -219,7 +219,7 @@ func BenchmarkVariableLengthBatch(b *testing.B) {
219219
b.ReportAllocs()
220220

221221
for b.Loop() {
222-
_, err := benchModel.ComputeBatch(sentences)
222+
_, err := benchModel.ComputeBatch(sentences, false)
223223
if err != nil {
224224
b.Fatalf("Failed to compute batch embeddings: %v", err)
225225
}

all_minilm_l6_v2/model.go

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -82,8 +82,8 @@ func (m *Model) Close() error {
8282
return err
8383
}
8484

85-
func (m *Model) Compute(sentence string) ([]float32, error) {
86-
results, err := m.ComputeBatch([]string{sentence})
85+
func (m *Model) Compute(sentence string, addSpecialTokens bool) ([]float32, error) {
86+
results, err := m.ComputeBatch([]string{sentence}, addSpecialTokens)
8787
if err != nil {
8888
return nil, err
8989
}
@@ -101,7 +101,7 @@ func (m *Model) ComputeFromEncoding(encoding tokenizer.Encoding) ([]float32, err
101101
return res[0], nil
102102
}
103103

104-
func (m *Model) ComputeBatch(sentences []string) ([][]float32, error) {
104+
func (m *Model) ComputeBatch(sentences []string, addSpecialTokens bool) ([][]float32, error) {
105105
if len(sentences) == 0 {
106106
return nil, nil
107107
}
@@ -114,15 +114,14 @@ func (m *Model) ComputeBatch(sentences []string) ([][]float32, error) {
114114
if len(inputBatch) == 0 {
115115
return nil, nil
116116
}
117-
encodings, err := m.tk.EncodeBatch(inputBatch, true)
117+
encodings, err := m.tk.EncodeBatch(inputBatch, addSpecialTokens)
118118
if err != nil {
119119
return nil, fmt.Errorf("failed to tokenize sentence: %w", err)
120120
}
121121
return m.ComputeBatchFromEncodings(encodings)
122122
}
123123

124124
func (m *Model) ComputeBatchFromEncodings(encodings []tokenizer.Encoding) ([][]float32, error) {
125-
126125
batchSize := len(encodings)
127126
seqLength := len(encodings[0].Ids)
128127
hiddenSize := 384

all_minilm_l6_v2/model_test.go

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ func TestSingleSentenceEmbedding(t *testing.T) {
1515
defer model.Close()
1616

1717
sentence := "Hello, world! This is a test sentence."
18-
embedding, err := model.Compute(sentence)
18+
embedding, err := model.Compute(sentence, false)
1919
if err != nil {
2020
t.Fatalf("Failed to compute embedding: %v", err)
2121
}
@@ -62,7 +62,7 @@ func TestBatchEmbedding(t *testing.T) {
6262
"A third sentence with different content.",
6363
}
6464

65-
embeddings, err := model.ComputeBatch(sentences)
65+
embeddings, err := model.ComputeBatch(sentences, false)
6666
if err != nil {
6767
t.Fatalf("Failed to compute batch embeddings: %v", err)
6868
}
@@ -126,7 +126,7 @@ func TestConsistentEmbeddingForSameSentence(t *testing.T) {
126126
"This is a repeated sentence.", // Same as first
127127
}
128128

129-
embeddings, err := model.ComputeBatch(sentences)
129+
embeddings, err := model.ComputeBatch(sentences, false)
130130
if err != nil {
131131
t.Fatalf("Failed to compute batch embeddings: %v", err)
132132
}
@@ -177,13 +177,13 @@ func TestSingleVsBatchConsistency(t *testing.T) {
177177
sentence := "Testing consistency between single and batch computation."
178178

179179
// Compute single embedding
180-
singleEmbedding, err := model.Compute(sentence)
180+
singleEmbedding, err := model.Compute(sentence, false)
181181
if err != nil {
182182
t.Fatalf("Failed to compute single embedding: %v", err)
183183
}
184184

185185
// Compute batch embedding with just one sentence
186-
batchEmbeddings, err := model.ComputeBatch([]string{sentence})
186+
batchEmbeddings, err := model.ComputeBatch([]string{sentence}, false)
187187
if err != nil {
188188
t.Fatalf("Failed to compute batch embedding: %v", err)
189189
}
@@ -206,7 +206,7 @@ func TestEmptyBatch(t *testing.T) {
206206
defer model.Close()
207207

208208
// Test empty batch
209-
embeddings, err := model.ComputeBatch([]string{})
209+
embeddings, err := model.ComputeBatch([]string{}, false)
210210
if err != nil {
211211
t.Fatalf("Failed to compute empty batch: %v", err)
212212
}

all_minilm_l6_v2/tokenizer.go

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,10 @@ func NewTokenizer() (*Tokenizer, error) {
2222
}, nil
2323
}
2424

25-
func (tk *Tokenizer) Encode(s string) (*tokenizer.Encoding, error) {
26-
return tk.tk.EncodeSingle(s)
25+
func (tk *Tokenizer) Encode(s string, addSpecialTokens bool) (*tokenizer.Encoding, error) {
26+
return tk.tk.EncodeSingle(s, addSpecialTokens)
2727
}
2828

29-
func (tk *Tokenizer) Decode(ids []int) string {
30-
return tk.tk.Decode(ids, false)
29+
func (tk *Tokenizer) Decode(ids []int, skipSpecialTokens bool) string {
30+
return tk.tk.Decode(ids, skipSpecialTokens)
3131
}

cmd/cli/main.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ func runEmbedding(cmd *cobra.Command, args []string) {
7979

8080
// Compute embeddings
8181
if batchMode && len(sentences) > 1 {
82-
embeddings, err := model.ComputeBatch(sentences)
82+
embeddings, err := model.ComputeBatch(sentences, false)
8383
if err != nil {
8484
fmt.Fprintf(os.Stderr, "Failed to compute batch embeddings: %v\n", err)
8585
os.Exit(1)
@@ -91,7 +91,7 @@ func runEmbedding(cmd *cobra.Command, args []string) {
9191
} else {
9292
// Process individually
9393
for _, sentence := range sentences {
94-
embedding, err := model.Compute(sentence)
94+
embedding, err := model.Compute(sentence, false)
9595
if err != nil {
9696
fmt.Fprintf(os.Stderr, "Failed to compute embedding for '%s': %v\n", sentence, err)
9797
os.Exit(1)

example.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ func main() {
2525
}
2626

2727
// Compute embeddings
28-
baseEmbedding, _ := model.Compute(baseSentence)
29-
candidateEmbeddings, _ := model.ComputeBatch(candidates)
28+
baseEmbedding, _ := model.Compute(baseSentence, false)
29+
candidateEmbeddings, _ := model.ComputeBatch(candidates, false)
3030

3131
displaySimilarities(baseSentence, candidates, baseEmbedding, candidateEmbeddings)
3232
}

0 commit comments

Comments
 (0)