|
31 | 31 | UINT32, |
32 | 32 | UINT64, |
33 | 33 | graph, |
| 34 | + ir, |
34 | 35 | ) |
35 | 36 | from onnxscript.function_libs.torch_lib.ops import common as common_ops |
36 | 37 | from onnxscript.function_libs.torch_lib.registration import torch_op |
@@ -4749,28 +4750,10 @@ def aten_layer_norm( |
4749 | 4750 | start_axis = -len(normalized_shape) |
4750 | 4751 |
|
4751 | 4752 | if weight is None: |
4752 | | - one = op.Constant(value_float=1.0) |
| 4753 | + one = op.Constant(value=ir.tensor(1, dtype=input.dtype)) |
4753 | 4754 | weight = op.Expand(one, op.Shape(input, start=start_axis)) |
4754 | 4755 |
|
4755 | | - if bias is None: |
4756 | | - zero = op.Constant(value_float=0.0) |
4757 | | - bias = op.Expand(zero, op.Shape(input, start=start_axis)) |
4758 | | - |
4759 | | - return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps) |
4760 | | - |
4761 | | - |
4762 | | -@torch_op("aten::layer_norm", private=True) |
4763 | | -def _aten_layer_norm_onnx( |
4764 | | - input: TReal, |
4765 | | - weight: TReal, |
4766 | | - bias: TReal, |
4767 | | - axis: int, |
4768 | | - eps: float = 1e-05, |
4769 | | -) -> TReal: |
4770 | | - """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor""" |
4771 | | - |
4772 | | - # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982 |
4773 | | - result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps) |
| 4756 | + result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps) |
4774 | 4757 | return result |
4775 | 4758 |
|
4776 | 4759 |
|
|
0 commit comments