Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,42 +98,6 @@ init_command = [
]
is_formatter = true

[[linter]]
code = 'PYLINT'
include_patterns = [
'**/*.py',
]
exclude_patterns = [
'docs/**',
'examples/**',
'onnxscript/_internal/converter_test.py',
'onnxscript/optimizer/**', # FIXME
'onnxscript/rewriter/**', # FIXME
'tests/functions/**',
'tests/models/**',
'tests/onnx_backend_test_code/**',
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pylint_linter',
'--rcfile=pyproject_pylint.toml',
'--show-disable',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements/lintrunner/requirements.txt',
]

[[linter]]
code = 'EDITORCONFIG-CHECKER'
include_patterns = ['**']
Expand Down
9 changes: 6 additions & 3 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Sequence[float],
Sequence[bool],
Sequence[str],
None,
]

# Mapping from Python scalar types to their default ONNX DataType,
Expand Down Expand Up @@ -258,7 +259,7 @@ def initializer(

def _input_to_ir_value(
self, value: VALUE_LIKE, like_type: ir.Value | None = None
) -> ir.Value:
) -> ir.Value | None:
"""Convert a permissible input (for a call to an op) into an ir.Value.

Permissible values include ir.Value as well as python constants that can be converted
Expand All @@ -267,6 +268,8 @@ def _input_to_ir_value(
"""
if isinstance(value, ir.Value):
return value
if value is None:
return value
Comment thread
justinchuby marked this conversation as resolved.
dtype = (
like_type.type.dtype
if like_type is not None and like_type.type is not None
Expand Down Expand Up @@ -356,7 +359,7 @@ def _get_schema(
def _partition_inputs_attributes(
self,
schema: onnx.defs.OpSchema | None,
inputs: Sequence[ir.Value | ir.TensorProtocol],
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
kwargs: dict[str, Any],
) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
if schema is None:
Expand Down Expand Up @@ -504,7 +507,7 @@ def subgraph(
def call_op(
self,
op_type: str,
inputs: Sequence[ir.Value | ir.TensorProtocol],
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
Comment thread
justinchuby marked this conversation as resolved.
kwargs: dict[str, Any],
):
"""Create an ONNX node and add it to the graph, returning its output value(s)."""
Expand Down
33 changes: 33 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,39 @@ def add_mul(X, Y):

self.assertIn("does not match", str(cm.exception))

def test_none_input_is_passed_through(self):
"""Test that None inputs are preserved as None in the node's inputs."""
op, x, y = _create_builder_with_inputs()

# Gemm's third input (C) is optional; passing None should work
result = op.Gemm(x, y, None, alpha=1.0)

nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)
node = nodes[0]
self.assertEqual(node.op_type, "Gemm")
# The third input should be None (optional, omitted)
self.assertEqual(len(list(node.inputs)), 3)
self.assertIs(node.inputs[0], x)
self.assertIs(node.inputs[1], y)
self.assertIsNone(node.inputs[2])
self.assertIsNotNone(result)

def test_none_input_with_custom_domain(self):
"""Test that None inputs work with custom domain ops."""
op, x, y = _create_builder_with_inputs()

result = op.CustomOp(x, None, y, _domain="com.custom")

nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)
node = nodes[0]
self.assertEqual(node.op_type, "CustomOp")
self.assertIs(node.inputs[0], x)
self.assertIsNone(node.inputs[1])
self.assertIs(node.inputs[2], y)
self.assertIsNotNone(result)


class BuildSubgraphTest(unittest.TestCase):
"""Tests for GraphBuilder.subgraph()."""
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ module = [
ignore_errors = true

[tool.pylint.messages_control]
# NOTE: This list is for vscode. Add new disables in pyproject_pylint.toml for lintrunner
# Exclude patterns should be modified in .lintrunner.toml
# NOTE: This list is for vscode. Pylint is removed from CI
disable = [
"consider-using-from-import",
"format",
Expand Down
34 changes: 0 additions & 34 deletions pyproject_pylint.toml

This file was deleted.

2 changes: 0 additions & 2 deletions requirements/lintrunner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@ ruff==0.15.1
# MYPY
mypy==1.10.1
types-PyYAML==6.0.12.20250915
# PYLINT
pylint==3.3.9
# EDITORCONFIG-CHECKER
editorconfig-checker==3.4.1
Loading