
Commit 919f3ea (parent fc7b333), authored by aneesh-dbylwu-amzn and Yaliang Wu

Add NONE pooling mode to support pre-pooled model outputs (#4710)

The documentation states that sentence-transformer models traced with the sentence-transformers library include post-processing, implying no additional pooling is needed. However, the code always applies MEAN pooling by default, even for models that already provide a pre-pooled sentence_embedding output. This change adds NONE as a pooling option so that models with pre-computed sentence embeddings can use those outputs directly, without redundant pooling computation.

Changes:
- Add NONE to the PoolingMode enum in BaseModelConfig
- Update ONNXSentenceTransformerTextEmbeddingTranslator to use the second output (sentence_embedding) when pooling_mode is NONE
- Update HuggingfaceTextEmbeddingTranslator to support NONE pooling, with fallback logic for various output formats
- Add unit tests for NONE pooling in both ONNX and TorchScript
- Update documentation with a NONE pooling description
- Add a release notes entry

Resolves #4708

Signed-off-by: Aneesh Nema <aneesh.nema@databricks.com>
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Co-authored-by: Yaliang Wu <ylwu@amazon.com>
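The redundancy the commit fixes is easiest to see with a toy example. The sketch below is plain Java, not the ml-commons code (names like `meanPool` are illustrative): it shows what MEAN pooling computes from a token-embedding matrix, and why a model that already emits a pooled `sentence_embedding` vector needs no further pooling under the new `none` mode.

```java
import java.util.Arrays;

public class PoolingSketch {

    // Mean pooling: average the token embeddings [tokens][dim],
    // skipping padding positions flagged by the attention mask.
    static float[] meanPool(float[][] tokenEmbeddings, int[] attentionMask) {
        int dim = tokenEmbeddings[0].length;
        float[] pooled = new float[dim];
        int count = 0;
        for (int t = 0; t < tokenEmbeddings.length; t++) {
            if (attentionMask[t] == 0) {
                continue; // padding token, ignore
            }
            for (int d = 0; d < dim; d++) {
                pooled[d] += tokenEmbeddings[t][d];
            }
            count++;
        }
        for (int d = 0; d < dim; d++) {
            pooled[d] /= count;
        }
        return pooled;
    }

    public static void main(String[] args) {
        // token_embeddings: 3 tokens of dimension 2; the last token is padding.
        float[][] tokens = { { 1f, 2f }, { 3f, 4f }, { 9f, 9f } };
        int[] mask = { 1, 1, 0 };
        System.out.println(Arrays.toString(meanPool(tokens, mask))); // [2.0, 3.0]

        // A traced sentence-transformers model may also emit this already-pooled
        // vector; `none` returns it untouched instead of re-pooling it.
        float[] sentenceEmbedding = { 2f, 3f };
        System.out.println(Arrays.toString(sentenceEmbedding)); // [2.0, 3.0]
    }
}
```

Applying `meanPool` a second time to the 1-D `sentenceEmbedding` would not even type-check here, which is the shape mismatch the NONE mode sidesteps.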

File tree: 6 files changed, +70 −9 lines

common/src/main/java/org/opensearch/ml/common/model/BaseModelConfig.java (2 additions, 1 deletion)

@@ -247,7 +247,8 @@ public enum PoolingMode {
         MAX("max"),
         WEIGHTED_MEAN("weightedmean"),
         CLS("cls"),
-        LAST_TOKEN("lasttoken");
+        LAST_TOKEN("lasttoken"),
+        NONE("none");

         private String name;

docs/model_serving_framework/text_embedding_model_examples.md (3 additions, 2 deletions)

@@ -295,7 +295,7 @@ POST /_plugins/_ml/models/zwla5YUB1qmVrJFlwzXJ/_unload
 Without [`sentence-transformers`](https://pypi.org/project/sentence-transformers/) installed, you can trace this model `AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')`.
 But model traced this way doesn't include post-processing. So user have to specify post-process logic with `pooling_mode` and `normalize_result`.

-Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`.
+Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`, `none`.

 The only difference is the uploading model input, for load/predict/profile/unload model, you can refer to ["1.1 trace sentence transformers model"](#11-trace-sentence-transformers-model).

@@ -322,7 +322,7 @@ POST /_plugins/_ml/models/_upload
 User can export Pytorch model to ONNX, then upload and run it with ml-commons APIs.
 This example ONNX model also needs to specify post-process logic with `pooling_mode` and `normalize_result`.

-Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`.
+Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`, `none`.

 ### Pooling Methods

@@ -334,6 +334,7 @@ Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`,
 | `weightedmean` | Weighted average where later tokens have higher weights |
 | `cls` | Uses the first token (CLS token) embedding |
 | `lasttoken` | Uses the last non-padding token's embedding. Useful for decoder-only models where the final token captures cumulative context |
+| `none` | Uses pre-pooled output from model directly without additional pooling computation. Use when model already provides pooled embeddings (e.g., `sentence_embedding` or `pooler_output`). Avoids redundant pooling and matches original model behavior |

 The only difference is the uploading model input, for load/predict/profile/unload model, you can refer to ["1.1 trace sentence transformers model"](#11-trace-sentence-transformers-model).
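As a rough illustration of some rows in the pooling table above, here is a sketch on plain float arrays (the real translators operate on DJL `NDArray`s; these helpers are illustrative, not the actual implementation):

```java
import java.util.Arrays;

public class PoolingModes {

    // `cls`: uses the first token (CLS token) embedding.
    static float[] clsPool(float[][] tokens) {
        return tokens[0];
    }

    // `lasttoken`: uses the last non-padding token's embedding.
    static float[] lastTokenPool(float[][] tokens, int[] attentionMask) {
        int last = 0;
        for (int t = 0; t < attentionMask.length; t++) {
            if (attentionMask[t] == 1) {
                last = t;
            }
        }
        return tokens[last];
    }

    // `none`: pass the model's pre-pooled output through unchanged.
    static float[] nonePool(float[] prePooled) {
        return prePooled;
    }

    public static void main(String[] args) {
        float[][] tokens = { { 1f, 0f }, { 2f, 2f }, { 9f, 9f } }; // last row is padding
        int[] mask = { 1, 1, 0 };
        System.out.println(Arrays.toString(clsPool(tokens)));             // [1.0, 0.0]
        System.out.println(Arrays.toString(lastTokenPool(tokens, mask))); // [2.0, 2.0]
        System.out.println(Arrays.toString(nonePool(new float[] { 5f, 5f }))); // [5.0, 5.0]
    }
}
```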

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java (29 additions, 5 deletions)

@@ -78,9 +78,29 @@ public NDList processInput(TranslatorContext ctx, String input) {
     /** {@inheritDoc} */
     @Override
     public float[] processOutput(TranslatorContext ctx, NDList list) {
-        NDArray embeddings = list.get("last_hidden_state");
-        if (embeddings == null) {
-            embeddings = list.get(0);
+        NDArray embeddings;
+
+        // NONE pooling mode uses pre-pooled output directly if available
+        if ("none".equals(pooling)) {
+            // Try to get pre-pooled output (sentence_embedding, pooler_output, etc.)
+            embeddings = list.get("sentence_embedding");
+            if (embeddings == null) {
+                embeddings = list.get("pooler_output");
+            }
+            if (embeddings == null && list.size() > 1) {
+                // Use second output if available
+                embeddings = list.get(1);
+            }
+            if (embeddings == null) {
+                // Fallback to first output
+                embeddings = list.get(0);
+            }
+        } else {
+            // For other pooling modes, use last_hidden_state or first output
+            embeddings = list.get("last_hidden_state");
+            if (embeddings == null) {
+                embeddings = list.get(0);
+            }
         }
         Encoding encoding = (Encoding) ctx.getAttachment("encoding");
         long[] attentionMask = encoding.getAttentionMask();

@@ -105,6 +125,9 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
             case "lasttoken":
                 embeddings = lastTokenPool(embeddings, inputAttentionMask);
                 break;
+            case "none":
+                // No pooling - use pre-pooled output as-is
+                break;
             default:
                 throw new AssertionError("Unexpected pooling model: " + pooling);
         }

@@ -232,9 +255,10 @@ public HuggingfaceTextEmbeddingTranslator.Builder optPoolingMode(String poolingM
             && !"cls".equals(poolingMode)
             && !"mean_sqrt_len".equals(poolingMode)
             && !"weightedmean".equals(poolingMode)
-            && !"lasttoken".equals(poolingMode)) {
+            && !"lasttoken".equals(poolingMode)
+            && !"none".equals(poolingMode)) {
             throw new IllegalArgumentException(
-                "Invalid pooling model, must be one of [mean, max, cls, mean_sqrt_len, weightedmean, lasttoken]."
+                "Invalid pooling model, must be one of [mean, max, cls, mean_sqrt_len, weightedmean, lasttoken, none]."
             );
         }
         this.pooling = poolingMode;
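The fallback chain this translator gains can be mirrored outside DJL with a plain map standing in for the named `NDList` (a sketch; `selectEmbeddings` and the `Map` representation are illustrative, not the real API):

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class OutputSelection {

    // Stand-in for the translator's NDList lookup: model outputs by name,
    // in declaration order. Mirrors the precedence used under NONE pooling:
    // sentence_embedding -> pooler_output -> second output -> first output.
    static float[] selectEmbeddings(Map<String, float[]> outputs, String pooling) {
        float[][] byIndex = outputs.values().toArray(new float[0][]);
        if ("none".equals(pooling)) {
            if (outputs.containsKey("sentence_embedding")) {
                return outputs.get("sentence_embedding");
            }
            if (outputs.containsKey("pooler_output")) {
                return outputs.get("pooler_output");
            }
            if (byIndex.length > 1) {
                return byIndex[1]; // second output if available
            }
            return byIndex[0]; // fallback to first output
        }
        // Other pooling modes: last_hidden_state or first output
        if (outputs.containsKey("last_hidden_state")) {
            return outputs.get("last_hidden_state");
        }
        return byIndex[0];
    }

    public static void main(String[] args) {
        Map<String, float[]> outputs = new LinkedHashMap<>();
        outputs.put("token_embeddings", new float[] { 1f, 2f });
        outputs.put("sentence_embedding", new float[] { 0.5f, 0.5f });
        // "none" picks the pre-pooled sentence_embedding, not token_embeddings.
        System.out.println(selectEmbeddings(outputs, "none")[0]); // 0.5
    }
}
```

The ordering matters: named lookups are tried before positional ones, so models that label their pooled output are preferred over the positional "second output" heuristic.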

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java (13 additions, 1 deletion)

@@ -87,7 +87,16 @@ public NDList processInput(TranslatorContext ctx, Input input) {
     /** {@inheritDoc} */
     @Override
    public Output processOutput(TranslatorContext ctx, NDList list) {
-        NDArray embeddings = list.get(0);
+        NDArray embeddings;
+
+        // NONE pooling mode uses pre-pooled output directly if available
+        if (this.poolingMode == TextEmbeddingModelConfig.PoolingMode.NONE && list.size() > 1) {
+            // Use the second output (sentence_embedding) which is pre-pooled
+            embeddings = list.get(1);
+        } else {
+            // Use first output (token_embeddings) for explicit pooling
+            embeddings = list.get(0);
+        }
         int shapeLength = embeddings.getShape().getShape().length;
         if (shapeLength == 3) {
             embeddings = embeddings.get(0);

@@ -115,6 +124,9 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
             case LAST_TOKEN:
                 embeddings = lastTokenPool(embeddings, inputAttentionMask);
                 break;
+            case NONE:
+                // No pooling - use pre-pooled output as-is
+                break;
             default:
                 throw new IllegalArgumentException("Unsupported pooling method");
         }
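A quick sanity check for which output is which: the token-level output keeps a token axis, while the pre-pooled one does not. This is a hedged sketch under assumed shape conventions (a traced sentence-transformers model typically emits `token_embeddings` as `[batch, tokens, dim]` and `sentence_embedding` as `[batch, dim]`); the helper name is illustrative and not part of the translator.

```java
public class ShapeCheck {

    // token_embeddings is typically [batch, tokens, dim]; a pre-pooled
    // sentence_embedding collapses the token axis down to [batch, dim].
    static boolean looksPrePooled(long[] shape) {
        return shape.length == 2;
    }

    public static void main(String[] args) {
        System.out.println(looksPrePooled(new long[] { 1, 384 }));      // true
        System.out.println(looksPrePooled(new long[] { 1, 128, 384 })); // false
    }
}
```

This also explains the translator's `shapeLength == 3` branch above: a 3-D output still carries a batch axis that must be stripped before (or instead of) pooling.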

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java (22 additions, 0 deletions)

@@ -225,6 +225,17 @@ public void initModel_predict_ONNX_LastTokenPooling() throws URISyntaxException
         initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
     }

+    @Test
+    public void initModel_predict_ONNX_NonePooling() throws URISyntaxException {
+        String modelFile = "all-MiniLM-L6-v2_onnx.zip";
+        String modelType = "bert";
+        TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.NONE;
+        boolean normalize = true;
+        int modelMaxLength = 512;
+        MLModelFormat modelFormat = MLModelFormat.ONNX;
+        initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
+    }
+
     @Test
     public void initModel_predict_TorchScript_Huggingface_LastTokenPooling() throws URISyntaxException {
         String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip";

@@ -236,6 +247,17 @@ public void initModel_predict_TorchScript_Huggingface_LastTokenPooling() throws
         initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
     }

+    @Test
+    public void initModel_predict_TorchScript_Huggingface_NonePooling() throws URISyntaxException {
+        String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip";
+        String modelType = "bert";
+        TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.NONE;
+        boolean normalize = true;
+        int modelMaxLength = 512;
+        MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
+        initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
+    }
+
     private void initModel_predict_HuggingfaceModel(
         String modelFile,
         String modelType,
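Beyond these end-to-end tests, the builder-side check extended in the Huggingface translator amounts to membership in a fixed set of names. A set-based sketch of the same rule (illustrative; not the actual builder code, which chains `!equals` comparisons):

```java
import java.util.Set;

public class PoolingModeValidation {

    // The accepted pooling names after this change.
    static final Set<String> ALLOWED =
        Set.of("mean", "max", "cls", "mean_sqrt_len", "weightedmean", "lasttoken", "none");

    static void validate(String poolingMode) {
        if (!ALLOWED.contains(poolingMode)) {
            throw new IllegalArgumentException(
                "Invalid pooling model, must be one of [mean, max, cls, mean_sqrt_len, weightedmean, lasttoken, none].");
        }
    }

    public static void main(String[] args) {
        validate("none"); // accepted after this change
        try {
            validate("median");
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: median");
        }
    }
}
```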

release-notes/opensearch-ml-commons.release-notes-3.4.0.0.md (1 addition, 0 deletions)

@@ -9,6 +9,7 @@ Compatible with OpenSearch and OpenSearch Dashboards version 3.4.0
 * allow higher maximum number of batch inference job tasks ([#4474](https://github.com/opensearch-project/ml-commons/pull/4474))

 ### Bug Fixes
+* Add NONE pooling mode to support pre-pooled model outputs, fixing bug where MEAN pooling was applied by default ([#4708](https://github.com/opensearch-project/ml-commons/issues/4708))
 * Fix agent type update ([#4341](https://github.com/opensearch-project/ml-commons/pull/4341))
 * Handle edge case of empty values of tool configs ([#4479](https://github.com/opensearch-project/ml-commons/pull/4479))
 * Fix OpenAI RAG integration tests: Replace Wikimedia image URL with Unsplash ([#4472](https://github.com/opensearch-project/ml-commons/pull/4472))
