1+ from torch2trt .torch2trt import *
2+ from torch2trt .module_test import add_module_test
3+ import math
4+
5+
6+ @tensorrt_converter ('torch.nn.functional.gelu' )
7+ def convert_gelu_v1 (ctx ):
8+ # approximate equation 1 from paper
9+ input = get_arg (ctx , 'input' , 0 , None )
10+ output = ctx .method_return
11+
12+ x , c05 , c1 , cs2pi , c044 , c3 = add_missing_trt_tensors (
13+ ctx .network ,
14+ [input , 0.5 , 1.0 , math .sqrt (2.0 / math .pi ), 0.044715 , 3.0 ]
15+ )
16+
17+ x , c05 , c1 , cs2pi , c044 , c3 = broadcast_trt_tensors (
18+ ctx .network ,
19+ [x , c05 , c1 , cs2pi , c044 , c3 ],
20+ len (output .shape ) - 1
21+ )
22+
23+ y = ctx .network .add_elementwise (x , c3 , trt .ElementWiseOperation .POW ).get_output (0 )
24+ y = ctx .network .add_elementwise (y , c044 , trt .ElementWiseOperation .PROD ).get_output (0 )
25+ y = ctx .network .add_elementwise (x , y , trt .ElementWiseOperation .SUM ).get_output (0 )
26+ y = ctx .network .add_elementwise (y , cs2pi , trt .ElementWiseOperation .PROD ).get_output (0 )
27+ y = ctx .network .add_activation (y , trt .ActivationType .TANH ).get_output (0 )
28+ y = ctx .network .add_elementwise (y , c1 , trt .ElementWiseOperation .SUM ).get_output (0 )
29+ y = ctx .network .add_elementwise (x , y , trt .ElementWiseOperation .PROD ).get_output (0 )
30+ y = ctx .network .add_elementwise (y , c05 , trt .ElementWiseOperation .PROD ).get_output (0 )
31+
32+ output ._trt = y
33+
34+
35+ # @tensorrt_converter('torch.nn.functional.gelu')
36+ # def convert_gelu_v2(ctx):
37+ # # approximate equation 1 from paper
38+ # input = get_arg(ctx, 'input', 0, None)
39+ # output = ctx.method_return
40+
41+ # x, c1702 = add_missing_trt_tensors(
42+ # ctx.network,
43+ # [input, 1.702]
44+ # )
45+
46+ # x, c1702 = broadcast_trt_tensors(
47+ # ctx.network,
48+ # [x, c1702],
49+ # len(output.shape) - 1
50+ # )
51+
52+ # y = ctx.network.add_elementwise(x, c1702, trt.ElementWiseOperation.PROD).get_output(0)
53+ # y = ctx.network.add_activation(y, trt.ActivationType.SIGMOID).get_output(0)
54+ # y = ctx.network.add_elementwise(x, y, trt.ElementWiseOperation.PROD).get_output(0)
55+
56+ # output._trt = y
57+
58+
59+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 5 )])
60+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 5 , 3 )])
61+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 5 , 3 , 3 )])
62+ def test_silu ():
63+ return torch .nn .GELU ()
0 commit comments