22from torch2trt .module_test import add_module_test
33
44
5- @tensorrt_converter ('torch.nn.Linear.forward ' )
5+ @tensorrt_converter ('torch.nn.functional.linear ' )
66def convert_Linear (ctx ):
7- module = ctx .method_args [0 ]
8- input = ctx .method_args [1 ]
7+ input = ctx .method_args [0 ]
8+ weight = get_arg (ctx , 'weight' , 1 , None )
9+ bias = get_arg (ctx , 'bias' , 2 , None )
910 input_trt = add_missing_trt_tensors (ctx .network , [input ])[0 ]
1011 output = ctx .method_return
1112
1213 # reshape to ...xNx1x1
1314 layer = ctx .network .add_shuffle (input_trt )
1415 layer .reshape_dims = tuple (input_trt .shape ) + (1 , 1 )
1516
16- bias = trt .Weights (torch_dtype_to_trt (module . weight .dtype ))
17- if module . bias is not None :
18- bias = module . bias .detach ().cpu ().numpy ()
17+ bias_trt = trt .Weights (torch_dtype_to_trt (weight .dtype ))
18+ if bias is not None :
19+ bias_trt = bias .detach ().cpu ().numpy ()
1920
2021 # add fully connected
2122 layer = ctx .network .add_fully_connected (
2223 input = layer .get_output (0 ),
23- num_outputs = module . out_features ,
24- kernel = module . weight .detach ().cpu ().numpy (),
25- bias = bias )
24+ num_outputs = int ( weight . shape [ 0 ]) ,
25+ kernel = weight .detach ().cpu ().numpy (),
26+ bias = bias_trt )
2627
2728 # reshape back to N
2829 layer = ctx .network .add_shuffle (layer .get_output (0 ))
2930 layer .reshape_dims = tuple (output .shape [1 :])
3031
3132 output ._trt = layer .get_output (0 )
32-
33+
3334
3435@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 10 )])
3536@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 3 , 10 )])
@@ -42,4 +43,4 @@ def test_Linear_basic():
4243@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 3 , 10 )])
4344@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 3 , 4 , 10 )])
4445def test_Linear_no_bias ():
45- return torch .nn .Linear (10 , 5 , bias = False )
46+ return torch .nn .Linear (10 , 5 , bias = False )
0 commit comments