@@ -39,7 +39,8 @@ func NewModel(opts ...ModelOption) (*Model, error) {
3939 opt (model )
4040 }
4141
42- tk , err := pretrained .FromReader (bytes .NewBuffer (embeddedTokenizer ))
42+ tk , err := pretrained .FromReader (
43+ bytes .NewBuffer (embeddedTokenizer ))
4344 if err != nil {
4445 return nil , fmt .Errorf ("failed to load tokenizer: %w" , err )
4546 }
@@ -92,6 +93,14 @@ func (m *Model) Compute(sentence string) ([]float32, error) {
9293 return results [0 ], nil
9394}
9495
96+ func (m * Model ) ComputeFromEncoding (encoding tokenizer.Encoding ) ([]float32 , error ) {
97+ res , err := m .ComputeBatchFromEncodings ([]tokenizer.Encoding {encoding })
98+ if err != nil {
99+ return nil , err
100+ }
101+ return res [0 ], nil
102+ }
103+
95104func (m * Model ) ComputeBatch (sentences []string ) ([][]float32 , error ) {
96105 if len (sentences ) == 0 {
97106 return nil , nil
@@ -101,12 +110,20 @@ func (m *Model) ComputeBatch(sentences []string) ([][]float32, error) {
101110 for _ , s := range sentences {
102111 inputBatch = append (inputBatch , tokenizer .NewSingleEncodeInput (tokenizer .NewRawInputSequence (s )))
103112 }
113+
114+ if len (inputBatch ) == 0 {
115+ return nil , nil
116+ }
104117 encodings , err := m .tk .EncodeBatch (inputBatch , true )
105118 if err != nil {
106119 return nil , fmt .Errorf ("failed to tokenize sentence: %w" , err )
107120 }
121+ return m .ComputeBatchFromEncodings (encodings )
122+ }
123+
124+ func (m * Model ) ComputeBatchFromEncodings (encodings []tokenizer.Encoding ) ([][]float32 , error ) {
108125
109- batchSize := len (sentences )
126+ batchSize := len (encodings )
110127 seqLength := len (encodings [0 ].Ids )
111128 hiddenSize := 384
112129
0 commit comments