Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions onnxscript/rewriter/redundant_scatter_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
from onnxscript.rewriter import pattern as orp


def fail(*args):
return onnxscript.rewriter.MatchResult().fail(*args)


class ScatterAllDynamic(orp.RewriteRuleClassBase):
def pattern(self, op, data, axis, transposed_data, updates):
# Construct update-indices spanning an entire axis:
Expand All @@ -41,24 +37,26 @@
def check(self, context, data, axis, transposed_data, **_):
# Check that updated-indices represent the full range of the first dimension of the transposed data.
# That is: check that the data.shape[axis] matches transposed_data.shape[0].
result = onnxscript.rewriter.MatchResult()
axis_value = ir_utils.get_singleton_value(axis)
if not isinstance(axis_value, int):
return fail("Axis value must be a constant integer.", axis)
return result.fail("Axis value must be a constant integer.", axis)

Check warning on line 43 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L43

Added line #L43 was not covered by tests
shape: ir.Shape | None = data.shape
if shape is None:
return fail("Data shape is not statically known.", data)
return result.fail("Data shape is not statically known.", data)

Check warning on line 46 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L46

Added line #L46 was not covered by tests
updated_dim_value = shape[axis_value]
transposed_data_shape: ir.Shape | None = transposed_data.shape
if transposed_data_shape is None:
return fail("Transposed data shape is not statically known.", transposed_data)
return result.fail(

Check warning on line 50 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L50

Added line #L50 was not covered by tests
"Transposed data shape is not statically known.", transposed_data
)
actual_dim_value = transposed_data_shape[0]
if updated_dim_value != actual_dim_value:
# The first dimension of the transposed data does not match the updated dimension,
# so we cannot apply this rule.
return fail(
return result.fail(

Check warning on line 57 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L57

Added line #L57 was not covered by tests
"The first dimension of the transposed data does not match the updated dimension.",
data,
transposed_data,
[data, transposed_data],
)
return True

Expand All @@ -81,20 +79,23 @@
"""Check if the ScatterND is redundant due to static indices covering entire tensor."""
# To validate data can be replaced directly by updates, we need to check the following:
# 1. they have the same shape
result = onnxscript.rewriter.MatchResult()
if data.shape is None:
return fail("The value 'data' shape is not statically known.", data)
return result.fail("The value 'data' shape is not statically known.", data)

Check warning on line 84 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L84

Added line #L84 was not covered by tests
if updates.shape is None:
return fail("The value 'updates' shape is not statically known.", updates)
return result.fail("The value 'updates' shape is not statically known.", updates)

Check warning on line 86 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L86

Added line #L86 was not covered by tests
if data.shape != updates.shape:
return fail("The shape of 'data' and 'updates' are different.", data, updates)
return result.fail(

Check warning on line 88 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L88

Added line #L88 was not covered by tests
"The shape of 'data' and 'updates' are different.", [data, updates]
)

# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
if indices.const_value is None:
return fail("The value 'indices' is not statically known.", indices)
return result.fail("The value 'indices' is not statically known.", indices)

Check warning on line 94 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L94

Added line #L94 was not covered by tests
expected_indices = [[i] for i in range(data.shape[0])]
actual_indices = indices.const_value.numpy().tolist()
if actual_indices != expected_indices:
return fail("The 'indices' is not referring to the whole data.", indices)
return result.fail("The 'indices' is not referring to the whole data.", indices)

Check warning on line 98 in onnxscript/rewriter/redundant_scatter_nd.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/redundant_scatter_nd.py#L98

Added line #L98 was not covered by tests

return True

Expand Down
Loading