Skip to content

Commit 68e0458

Browse files
authored
Fix attention mask to use float_lowest instead of -inf and add NaN-safe softmax handling (#2654)
#2561 - Use the lowest representable float value instead of -inf for attention masks. - Add NaN-safe handling and a unit test for softmax with all masked positions. Please let me know if my approach or fix needs any improvements. I’m open to feedback and happy to make changes based on suggestions. Thank you!
1 parent a9cb429 commit 68e0458

File tree

1 file changed

+1
-1
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+1
-1
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2073,7 +2073,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
20732073
key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
20742074
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, lowest representable value of dtype)
20752075
zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
2076-
neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
2076+
neg_inf = op.Constant(value=ir.tensor(query.dtype.min, dtype=query.dtype))
20772077
attn_mask = op.Where(attn_mask, zero, neg_inf)
20782078
attn_weight = op.Softmax(
20792079
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),

0 commit comments

Comments
 (0)