Commit fb67b42

Author: John Welsh · committed
Commit message: functional linear converter
1 parent 2653a71 · commit fb67b42

File tree: 1 file changed (+12, −11 lines)


torch2trt/converters/Linear.py

Lines changed: 12 additions & 11 deletions
@@ -2,34 +2,35 @@
 from torch2trt.module_test import add_module_test
 
 
-@tensorrt_converter('torch.nn.Linear.forward')
+@tensorrt_converter('torch.nn.functional.linear')
 def convert_Linear(ctx):
-    module = ctx.method_args[0]
-    input = ctx.method_args[1]
+    input = ctx.method_args[0]
+    weight = get_arg(ctx, 'weight', 1, None)
+    bias = get_arg(ctx, 'bias', 2, None)
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
 
     # reshape to ...xNx1x1
     layer = ctx.network.add_shuffle(input_trt)
     layer.reshape_dims = tuple(input_trt.shape) + (1, 1)
 
-    bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype))
-    if module.bias is not None:
-        bias = module.bias.detach().cpu().numpy()
+    bias_trt = trt.Weights(torch_dtype_to_trt(weight.dtype))
+    if bias is not None:
+        bias_trt = bias.detach().cpu().numpy()
 
     # add fully connected
     layer = ctx.network.add_fully_connected(
         input=layer.get_output(0),
-        num_outputs=module.out_features,
-        kernel=module.weight.detach().cpu().numpy(),
-        bias=bias)
+        num_outputs=int(weight.shape[0]),
+        kernel=weight.detach().cpu().numpy(),
+        bias=bias_trt)
 
     # reshape back to N
     layer = ctx.network.add_shuffle(layer.get_output(0))
     layer.reshape_dims = tuple(output.shape[1:])
 
     output._trt = layer.get_output(0)
-
+
 
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)])
@@ -42,4 +43,4 @@ def test_Linear_basic():
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 10)])
 def test_Linear_no_bias():
-    return torch.nn.Linear(10, 5, bias=False)
+    return torch.nn.Linear(10, 5, bias=False)
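
The practical effect of this change: the converter is now registered against torch.nn.functional.linear instead of torch.nn.Linear.forward, so it reads the input, weight, and bias from the call arguments (via get_arg) rather than from a module object. That means models that call the functional API directly are also captured. Below is a minimal sketch of such a model being converted; the FunctionalLinear module, the tensor shapes, and the tolerance are illustrative assumptions, not part of this commit.

    # Illustrative sketch (not from this commit): a module that calls
    # torch.nn.functional.linear directly, which the old
    # 'torch.nn.Linear.forward' converter would not have intercepted.
    import torch
    import torch.nn.functional as F
    from torch2trt import torch2trt  # assumes torch2trt + TensorRT are installed

    class FunctionalLinear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            # Same (out_features, in_features) weight layout that the converter
            # relies on when it takes weight.shape[0] as num_outputs.
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
            self.bias = torch.nn.Parameter(torch.zeros(out_features))

        def forward(self, x):
            # Matched by the 'torch.nn.functional.linear' converter above.
            return F.linear(x, self.weight, self.bias)

    model = FunctionalLinear(10, 5).cuda().eval()
    x = torch.randn(1, 10).cuda()
    model_trt = torch2trt(model, [x])  # convert_Linear handles the F.linear call
    assert torch.allclose(model(x), model_trt(x), atol=1e-2)

Because torch.nn.Linear.forward itself calls F.linear under the hood, the single functional converter covers both the module and functional code paths, which is presumably why the old registration could be dropped.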
