Skip to content

Commit 8e4d41d

Browse files
Copilot and justinchuby authored
[torchlib] Implement aten_bilinear function using Einsum (#2574)
This PR implements the `aten_bilinear` function that was previously raising `NotImplementedError`. The bilinear transformation computes `y = x1^T A x2 + b` where: - `input1` has shape `(..., in1_features)` - `input2` has shape `(..., in2_features)` - `weight` has shape `(out_features, in1_features, in2_features)` - `bias` has shape `(out_features)` (optional) - Output has shape `(..., out_features)` ## Implementation Details The implementation is done using einsum. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent a106bad commit 8e4d41d

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1195,6 +1195,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
11951195
return op.CastLike(sampled, self)
11961196

11971197

1198+
@torch_op("aten::bilinear", trace_only=True)
11981199
def aten_bilinear(
11991200
input1: TensorType,
12001201
input2: TensorType,
@@ -1203,7 +1204,23 @@ def aten_bilinear(
12031204
) -> TensorType:
12041205
"""bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
12051206

1206-
raise NotImplementedError()
1207+
# Bilinear transformation: y = x1^T A x2 + b
1208+
# input1 shape: (..., in1_features)
1209+
# input2 shape: (..., in2_features)
1210+
# weight shape: (out_features, in1_features, in2_features)
1211+
# bias shape: (out_features) - optional
1212+
# output shape: (..., out_features)
1213+
1214+
# Use Einsum to compute the bilinear transformation
1215+
# "...i,oij,...j->...o" means:
1216+
# - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
1217+
result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
1218+
1219+
# Add bias if provided
1220+
if bias is not None:
1221+
result = op.Add(result, bias)
1222+
1223+
return result
12071224

12081225

12091226
def aten_binary_cross_entropy_with_logits(

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs)
3737
yield opinfo_core.SampleInput(item, dtype=dtype)
3838

3939

40+
def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs):
41+
"""Sample inputs for bilinear operation."""
42+
del op_info
43+
del kwargs
44+
45+
make_arg = functools.partial(
46+
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
47+
)
48+
49+
# Test cases: (batch_size, in1_features, in2_features, out_features)
50+
cases = [
51+
(2, 3, 4, 5), # Basic case
52+
(1, 2, 2, 1), # Minimal case
53+
(3, 5, 7, 4), # Different dimensions
54+
(2, 1, 1, 3), # Single input features
55+
]
56+
57+
for batch_size, in1_features, in2_features, out_features in cases:
58+
input1 = make_arg((batch_size, in1_features))
59+
input2 = make_arg((batch_size, in2_features))
60+
weight = make_arg((out_features, in1_features, in2_features))
61+
bias = make_arg((out_features,))
62+
63+
# Test with bias
64+
yield opinfo_core.SampleInput(input1, args=(input2, weight, bias))
65+
66+
# Test without bias (only for first case to avoid too many tests)
67+
if batch_size == 2:
68+
yield opinfo_core.SampleInput(input1, args=(input2, weight, None))
69+
70+
4071
def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
4172
del op_info
4273

@@ -2283,6 +2314,13 @@ def __init__(self):
22832314
# To avoid name duplication, it is possible to rename the OpInfo and specify
22842315
# the `op` field explicitly.
22852316
OP_DB: List[opinfo_core.OpInfo] = [
2317+
opinfo_core.OpInfo(
2318+
"bilinear",
2319+
op=torch.nn.functional.bilinear,
2320+
dtypes=common_dtype.floating_types(),
2321+
sample_inputs_func=sample_inputs_bilinear,
2322+
supports_out=False,
2323+
),
22862324
opinfo_core.OpInfo(
22872325
"ops.aten.bernoulli.p",
22882326
aten_name="bernoulli.p",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,9 @@ def _where_input_wrangler(
657657
),
658658
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
659659
TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
660+
TorchLibOpInfo(
661+
"bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)}
662+
),
660663
TorchLibOpInfo(
661664
# This string is a unique ID. In extra_opinfo.py, we
662665
# also define test data for this ID with

0 commit comments

Comments (0)