Skip to content

Commit 9dacd1f

Browse files
committed
expose tokenizer
1 parent b08af62 commit 9dacd1f

2 files changed

Lines changed: 50 additions & 2 deletions

File tree

all_minilm_l6_v2/model.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ func NewModel(opts ...ModelOption) (*Model, error) {
3939
opt(model)
4040
}
4141

42-
tk, err := pretrained.FromReader(bytes.NewBuffer(embeddedTokenizer))
42+
tk, err := pretrained.FromReader(
43+
bytes.NewBuffer(embeddedTokenizer))
4344
if err != nil {
4445
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
4546
}
@@ -92,6 +93,14 @@ func (m *Model) Compute(sentence string) ([]float32, error) {
9293
return results[0], nil
9394
}
9495

96+
func (m *Model) ComputeFromEncoding(encoding tokenizer.Encoding) ([]float32, error) {
97+
res, err := m.ComputeBatchFromEncodings([]tokenizer.Encoding{encoding})
98+
if err != nil {
99+
return nil, err
100+
}
101+
return res[0], nil
102+
}
103+
95104
func (m *Model) ComputeBatch(sentences []string) ([][]float32, error) {
96105
if len(sentences) == 0 {
97106
return nil, nil
@@ -101,12 +110,20 @@ func (m *Model) ComputeBatch(sentences []string) ([][]float32, error) {
101110
for _, s := range sentences {
102111
inputBatch = append(inputBatch, tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(s)))
103112
}
113+
114+
if len(inputBatch) == 0 {
115+
return nil, nil
116+
}
104117
encodings, err := m.tk.EncodeBatch(inputBatch, true)
105118
if err != nil {
106119
return nil, fmt.Errorf("failed to tokenize sentence: %w", err)
107120
}
121+
return m.ComputeBatchFromEncodings(encodings)
122+
}
123+
124+
func (m *Model) ComputeBatchFromEncodings(encodings []tokenizer.Encoding) ([][]float32, error) {
108125

109-
batchSize := len(sentences)
126+
batchSize := len(encodings)
110127
seqLength := len(encodings[0].Ids)
111128
hiddenSize := 384
112129

all_minilm_l6_v2/tokenizer.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package all_minilm_l6_v2
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
7+
"github.com/sugarme/tokenizer"
8+
"github.com/sugarme/tokenizer/pretrained"
9+
)
10+
11+
type Tokenizer struct {
12+
tk *tokenizer.Tokenizer
13+
}
14+
15+
func NewTokenizer() (*Tokenizer, error) {
16+
tk, err := pretrained.FromReader(bytes.NewBuffer(embeddedTokenizer))
17+
if err != nil {
18+
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
19+
}
20+
return &Tokenizer{
21+
tk: tk,
22+
}, nil
23+
}
24+
25+
func (tk *Tokenizer) Encode(s string) (*tokenizer.Encoding, error) {
26+
return tk.tk.EncodeSingle(s)
27+
}
28+
29+
func (tk *Tokenizer) Decode(ids []int) string {
30+
return tk.tk.Decode(ids, false)
31+
}

0 commit comments

Comments
 (0)