Commit 0226bb3

Copilot and justinchuby committed
fix: add explanatory comment for linear_before_reset=1 in aten_gru
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 170a407 commit 0226bb3

1 file changed: 5 additions & 0 deletions


onnxscript/function_libs/torch_lib/ops/core.py

```diff
@@ -4348,6 +4348,11 @@ def aten_gru(
     # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size]
     hidden_size_attr = hx.shape[2]

+    # linear_before_reset=1 matches PyTorch's GRU formulation where the linear
+    # transformation is applied before multiplying by the reset gate:
+    # ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh)
+    # The ONNX default (linear_before_reset=0) uses a different equation and
+    # would produce numerically incorrect results.
     if B is not None:
         Y, Y_h = op.GRU(
             current_input,
```
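The difference between the two `linear_before_reset` formulations can be checked numerically. A minimal NumPy sketch (hypothetical values; only the candidate hidden state `ht` is computed, with the input projection `Xt*(Wh^T) + Wbh` precomputed and the reset gate `rt` given) shows that applying the reset gate before versus after the recurrent linear transform generally yields different results:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 3
x_proj = rng.normal(size=hidden)        # Xt*(Wh^T) + Wbh, precomputed
h_prev = rng.normal(size=hidden)        # Ht-1
Rh = rng.normal(size=(hidden, hidden))  # recurrence weight for the candidate gate
Rbh = rng.normal(size=hidden)           # recurrence bias for the candidate gate
rt = 1 / (1 + np.exp(-rng.normal(size=hidden)))  # reset gate activations in (0, 1)

# linear_before_reset=0 (ONNX default): gate Ht-1 first, then apply the linear transform
ht_default = np.tanh(x_proj + (rt * h_prev) @ Rh.T + Rbh)

# linear_before_reset=1 (PyTorch semantics): linear transform first, then gate the result
ht_pytorch = np.tanh(x_proj + rt * (h_prev @ Rh.T + Rbh))

print(np.allclose(ht_default, ht_pytorch))  # the two variants generally disagree
```

Because `(rt * h_prev) @ Rh.T` is not equal to `rt * (h_prev @ Rh.T)` for a generic `Rh`, exporting PyTorch's GRU with the ONNX default would silently change the computed hidden states, which is why the exporter must pass `linear_before_reset=1`.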
