Skip to content

Commit b32a3c1

Browse files
authored
[torchlib] Fix layer norm dtype
Fix layer norm dtype mismatch errors
1 parent 882a442 commit b32a3c1

1 file changed

Lines changed: 1 addition & 19 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4749,27 +4749,9 @@ def aten_layer_norm(
47494749
start_axis = -len(normalized_shape)
47504750

47514751
if weight is None:
4752-
one = op.Constant(value_float=1.0)
4752+
one = op.Constant(value=ir.tensor(1, dtype=input.dtype))
47534753
weight = op.Expand(one, op.Shape(input, start=start_axis))
47544754

4755-
if bias is None:
4756-
zero = op.Constant(value_float=0.0)
4757-
bias = op.Expand(zero, op.Shape(input, start=start_axis))
4758-
4759-
return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps)
4760-
4761-
4762-
@torch_op("aten::layer_norm", private=True)
4763-
def _aten_layer_norm_onnx(
4764-
input: TReal,
4765-
weight: TReal,
4766-
bias: TReal,
4767-
axis: int,
4768-
eps: float = 1e-05,
4769-
) -> TReal:
4770-
"""layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
4771-
4772-
# TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982
47734755
result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps)
47744756
return result
47754757

0 commit comments

Comments (0)