We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 170a407 commit 0226bb3Copy full SHA for 0226bb3
1 file changed
onnxscript/function_libs/torch_lib/ops/core.py
@@ -4348,6 +4348,11 @@ def aten_gru(
4348
# Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size]
4349
hidden_size_attr = hx.shape[2]
4350
4351
+ # linear_before_reset=1 matches PyTorch's GRU formulation where the linear
4352
+ # transformation is applied before multiplying by the reset gate:
4353
+ # ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh)
4354
+ # The ONNX default (linear_before_reset=0) uses a different equation and
4355
+ # would produce numerically incorrect results.
4356
if B is not None:
4357
Y, Y_h = op.GRU(
4358
current_input,
0 commit comments