Skip to content

Commit 6df2e1f

Browse files
author
John Welsh
committed
added gelu converter
1 parent 1fc31b7 commit 6df2e1f

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Master]
44

5+
- Added converter for ``torch.nn.functional.gelu``
6+
- Added converter for ``torch.nn.functional.linear``
7+
- Added converter for ``torch.nn.functional.silu``
8+
59
## [0.2.0] - 03/02/2021
610

711
- Added converter for ``torch.Tensor.flatten``

torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .div import *
2929
from .expand import *
3030
from .floordiv import *
31+
from .gelu import *
3132
from .getitem import *
3233
from .identity import *
3334
from .instance_norm import *

torch2trt/converters/gelu.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from torch2trt.torch2trt import *
2+
from torch2trt.module_test import add_module_test
3+
import math
4+
5+
6+
@tensorrt_converter('torch.nn.functional.gelu')
def convert_gelu_v1(ctx):
    """Convert ``torch.nn.functional.gelu`` to TensorRT layers.

    Implements the tanh approximation from the GELU paper:
        gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    built entirely from elementwise ops plus a TANH activation.
    """
    x_in = get_arg(ctx, 'input', 0, None)
    output = ctx.method_return

    # Promote the input and the scalar constants to TRT tensors, then
    # broadcast all of them to the output rank (batch dim excluded).
    tensors = add_missing_trt_tensors(
        ctx.network,
        [x_in, 0.5, 1.0, math.sqrt(2.0 / math.pi), 0.044715, 3.0]
    )
    x, half, one, sqrt_2_over_pi, coeff, three = broadcast_trt_tensors(
        ctx.network,
        tensors,
        len(output.shape) - 1
    )

    net = ctx.network
    # inner = sqrt(2/pi) * (x + 0.044715 * x**3)
    cube = net.add_elementwise(x, three, trt.ElementWiseOperation.POW).get_output(0)
    scaled_cube = net.add_elementwise(cube, coeff, trt.ElementWiseOperation.PROD).get_output(0)
    inner = net.add_elementwise(x, scaled_cube, trt.ElementWiseOperation.SUM).get_output(0)
    inner = net.add_elementwise(inner, sqrt_2_over_pi, trt.ElementWiseOperation.PROD).get_output(0)
    # gelu = 0.5 * x * (1 + tanh(inner))
    gate = net.add_activation(inner, trt.ActivationType.TANH).get_output(0)
    gate = net.add_elementwise(gate, one, trt.ElementWiseOperation.SUM).get_output(0)
    result = net.add_elementwise(x, gate, trt.ElementWiseOperation.PROD).get_output(0)
    result = net.add_elementwise(result, half, trt.ElementWiseOperation.PROD).get_output(0)

    output._trt = result
33+
34+
35+
# @tensorrt_converter('torch.nn.functional.gelu')
36+
# def convert_gelu_v2(ctx):
37+
# # approximate equation 1 from paper
38+
# input = get_arg(ctx, 'input', 0, None)
39+
# output = ctx.method_return
40+
41+
# x, c1702 = add_missing_trt_tensors(
42+
# ctx.network,
43+
# [input, 1.702]
44+
# )
45+
46+
# x, c1702 = broadcast_trt_tensors(
47+
# ctx.network,
48+
# [x, c1702],
49+
# len(output.shape) - 1
50+
# )
51+
52+
# y = ctx.network.add_elementwise(x, c1702, trt.ElementWiseOperation.PROD).get_output(0)
53+
# y = ctx.network.add_activation(y, trt.ActivationType.SIGMOID).get_output(0)
54+
# y = ctx.network.add_elementwise(x, y, trt.ElementWiseOperation.PROD).get_output(0)
55+
56+
# output._trt = y
57+
58+
59+
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)])
def test_gelu():
    """Register ``torch.nn.GELU`` for the converter module tests.

    Fixes a copy-paste leftover: the function was named ``test_silu``
    although it returns a GELU module, which mislabels the registered
    test (and would collide with the real silu test of the same name).
    """
    return torch.nn.GELU()

0 commit comments

Comments
 (0)