@@ -620,22 +620,19 @@ auto aten_registrations TORCHTRT_UNUSED =
620620 {" aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)" })})
621621 .evaluator({c10::Symbol::fromQualString (" aten::arange" ),
622622 [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
623- int input_size = n->inputs ().size ();
624- int scalar_count = 0 ;
625- for (int i = 0 ; i < input_size; i++) {
626- if (args.at (n->input (i)).IValue ()->isScalar ()) {
627- scalar_count += 1 ;
628- }
629- }
630- if (scalar_count == 1 ) {
623+ auto schema = n->maybeSchema ();
624+ TORCHTRT_CHECK (schema, " Unable to get schema for node: " << *n);
625+ auto name = schema->operator_name ();
626+
627+ if (c10::toString (name) == " aten::arange" ) {
631628 if (args.at (n->input (0 )).IValue ()->isInt ()) {
632629 int end_scalar = args.at (n->input (0 )).unwrapToInt ();
633630 return torch::arange (end_scalar);
634631 } else if (args.at (n->input (0 )).IValue ()->isDouble ()) {
635632 float end_scalar = args.at (n->input (0 )).unwrapToScalar ().to <float >();
636633 return torch::arange (end_scalar);
637634 }
638- } else if (scalar_count == 2 ) {
635+ } else if (c10::toString (name) == " aten::arange.start " ) {
639636 if (args.at (n->input (0 )).IValue ()->isDouble () || args.at (n->input (1 )).IValue ()->isDouble ()) {
640637 float start_scalar = args.at (n->input (0 )).unwrapToScalar ().to <float >();
641638 float end_scalar = args.at (n->input (1 )).unwrapToScalar ().to <float >();
@@ -645,7 +642,7 @@ auto aten_registrations TORCHTRT_UNUSED =
645642 int end_scalar = args.at (n->input (1 )).unwrapToInt ();
646643 return torch::arange (start_scalar, end_scalar);
647644 }
648- } else if (scalar_count == 3 ) {
645+ } else if (c10::toString (name) == " aten::arange.start_step " ) {
649646 if (args.at (n->input (0 )).IValue ()->isDouble () || args.at (n->input (1 )).IValue ()->isDouble () ||
650647 args.at (n->input (2 )).IValue ()->isDouble ()) {
651648 float start_scalar = args.at (n->input (0 )).unwrapToScalar ().to <float >();
@@ -659,8 +656,7 @@ auto aten_registrations TORCHTRT_UNUSED =
659656 return torch::arange (start_scalar, end_scalar, step_scalar);
660657 }
661658 } else {
662- TORCHTRT_THROW_ERROR (
663- " Invalid input argument size for aten::arange, input argument size: " << input_size);
659+ TORCHTRT_THROW_ERROR (" Unsupported aten::arange variant: " << name);
664660 }
665661 return {};
666662 },
0 commit comments