Skip to content

Commit 6e91205

Browse files
authored
feat: modify aten_bilinear from einsum to matmul (#2746)
#### What this does
- Replaces Bilinear Transformation's einsum implementation with matmul

#### Why
- Performance increase from MatMul over Einsum

#### Testing
- Existing tests pass for function implementation

Resolves #2573
1 parent 72321e5 commit 6e91205

File tree

1 file changed

+27
-5
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+27
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,12 +1208,34 @@ def aten_bilinear(
1208 1208
# bias shape: (out_features) - optional
1209 1209
# output shape: (..., out_features)
1210 1210

1211-
# Use Einsum to compute the bilinear transformation
1212-
# "...i,oij,...j->...o" means:
1213-
# - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
1214-
result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
1211+
# input1 and input2 must have identical batch dimensions
1212+
# Use MatMul to compute the bilinear transformation
1213+
batch_size = op.Shape(input1, start=0, end=-1)
1214+
input1_shape = op.Shape(input1, start=-1)
1215+
input2_shape = op.Shape(input2, start=-1)
1216+
output_shape = op.Shape(weight, start=0, end=1)
1217+
neg_1 = op.Constant(value_ints=[-1])
1218+
1219+
# (out_features, in1_features, in2_features) -> (in1_features, out_features, in2_features)
1220+
W_permute = op.Transpose(weight, perm=[1, 0, 2])
1221+
1222+
# (in1_features, out_features, in2_features) -> (in1_features, out_features * in2_features)
1223+
W_flat = op.Reshape(
1224+
W_permute,
1225+
op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
1226+
)
1227+
1228+
# (..., in1_features) @ (in1_features, out_features * in2_features) -> (..., out_features * in2_features)
1229+
tmp = op.MatMul(input1, W_flat)
1230+
1231+
# (..., out_features * in2_features) -> (..., out_features, in2_features)
1232+
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
1233+
1234+
# (..., in2_features) -> (..., in2_features, 1)
1235+
# -> (..., out_features, in2_features) @ (..., in2_features, 1)
1236+
# -> (..., out_features, 1) -> (..., out_features)
1237+
result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
1215 1238

1216-
# Add bias if provided
1217 1239
if bias is not None:
1218 1240
result = op.Add(result, bias)
1219 1241

0 commit comments

Comments
 (0)