Skip to content

Commit 83df7df

Browse files
authored
Fix {lstm,gru}: slice hidden/cell state with sequence (#2803)
Per the spec, these two fields should be 1D-tensors: > - starts (heterogeneous) - Tind: > 1-D tensor of starting indices of corresponding axis in axes > > - ends (heterogeneous) - Tind: > 1-D tensor of ending indices (exclusive) of corresponding axis in axes With the original code; the output is like this when printing the graph: ``` <snip> int64 val_1 = {1}, int64 val_3 = {0}, <snip> [node_Slice_5] val_5 = Slice (transpose_1, val_3, val_1, val_4) [node_Slice_6] val_6 = Slice (transpose_2, val_3, val_1, val_4) ``` With the changes: ``` <snip> int64[1] val_3 = {0}, int64[1] val_4 = {1}, <snip> [node_Slice_5] val_5 = Slice (transpose_1, val_3, val_4, val_3) [node_Slice_6] val_6 = Slice (transpose_2, val_3, val_4, val_3) ``` This being scalar was seemingly accepted by the ONNX reference evaluator and potentially other runtimes, but not in [tract](https://github.com/sonos/tract).
1 parent 5ae26d9 commit 83df7df

File tree

1 file changed

+3
-3
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+3
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4254,7 +4254,7 @@ def aten_gru(
42544254
# Extract hidden state for this layer
42554255
layer_start = layer_idx * num_directions
42564256
layer_end = (layer_idx + 1) * num_directions
4257-
layer_h = op.Slice(hx, layer_start, layer_end, axes=[0])
4257+
layer_h = op.Slice(hx, [layer_start], [layer_end], axes=[0])
42584258

42594259
# Extract parameters for this layer
42604260
# Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction
@@ -5770,8 +5770,8 @@ def aten_lstm(
57705770
# Extract hidden and cell states for this layer
57715771
layer_start = layer_idx * num_directions
57725772
layer_end = (layer_idx + 1) * num_directions
5773-
layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0])
5774-
layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0])
5773+
layer_h = op.Slice(initial_h, [layer_start], [layer_end], axes=[0])
5774+
layer_c = op.Slice(initial_c, [layer_start], [layer_end], axes=[0])
57755775

57765776
# Extract parameters for this layer
57775777
# Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction

0 commit comments

Comments
 (0)