
Commit fcc3d96

Author: John Welsh
Commit message: added layer_norm converter
1 parent 6df2e1f

File tree: 4 files changed, +117 −1 lines

CHANGELOG.md
torch2trt/converters/__init__.py
torch2trt/converters/layer_norm.py
torch2trt/torch2trt.py


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 ## [Master]
 
+- Added converter for ``torch.nn.functional.layer_norm``
 - Added converter for ``torch.nn.functional.gelu``
 - Added converter for ``torch.nn.functional.linear``
 - Added converter for ``torch.nn.functional.silu``
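For context, a minimal sketch of how the new converter gets exercised end to end, assuming the standard torch2trt conversion entry point (this snippet is not part of the commit):

import torch
from torch2trt import torch2trt

# module whose forward dispatches to torch.nn.functional.layer_norm
model = torch.nn.LayerNorm(3).cuda().eval()
x = torch.randn(1, 5, 3).cuda()

# torch2trt traces the module; the call to layer_norm is intercepted
# by the @tensorrt_converter registered in this commit
model_trt = torch2trt(model, [x])

# compare TensorRT output against the PyTorch reference
print(torch.max(torch.abs(model(x) - model_trt(x))))  # should be ~0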

torch2trt/converters/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -30,10 +30,11 @@
 from .floordiv import *
 from .gelu import *
 from .getitem import *
+from .group_norm import *
 from .identity import *
 from .instance_norm import *
 from .interpolate import *
-from .group_norm import *
+from .layer_norm import *
 from .max import *
 from .max_pool2d import *
 from .mean import *

torch2trt/converters/layer_norm.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.nn.functional.layer_norm')
def convert_layernorm(ctx):
    input = get_arg(ctx, 'input', 0, None)
    shape = get_arg(ctx, 'normalized_shape', 1, None)
    weight = get_arg(ctx, 'weight', 2, None)
    bias = get_arg(ctx, 'bias', 3, None)
    eps = get_arg(ctx, 'eps', 4, 1e-05)
    output = ctx.method_return

    input_trt, eps_trt = add_missing_trt_tensors(
        ctx.network,
        [input, eps]
    )

    input_trt, eps_trt = broadcast_trt_tensors(
        ctx.network,
        [input_trt, eps_trt],
        len(output.shape) - 1
    )

    if weight is not None:
        _, weight_trt = add_missing_trt_tensors(
            ctx.network,
            [input, weight]
        )
        _, weight_trt = broadcast_trt_tensors(
            ctx.network,
            [input_trt, weight_trt],
            len(output.shape) - 1
        )

    if bias is not None:
        _, bias_trt = add_missing_trt_tensors(
            ctx.network,
            [input, bias]
        )
        _, bias_trt = broadcast_trt_tensors(
            ctx.network,
            [input_trt, bias_trt],
            len(output.shape) - 1
        )

    # normalized_shape names the trailing dims of the input; build the
    # corresponding negative dim tuple, resolve it to positive dims, and
    # convert to a TensorRT reduction axes bitmask
    if isinstance(shape, int):
        shape = (shape,)
    dim = tuple([-i - 1 for i in range(len(shape))])
    dim = torch_dim_resolve_negative(dim, len(input.shape))
    axes = torch_dim_to_trt_axes(dim)

    # y = (x - E[x]) / sqrt(Var[x] + eps), with Var[x] the biased variance
    ux = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, axes, keep_dims=True).get_output(0)
    numerator = ctx.network.add_elementwise(input_trt, ux, trt.ElementWiseOperation.SUB).get_output(0)
    varx = ctx.network.add_elementwise(numerator, numerator, trt.ElementWiseOperation.PROD).get_output(0)
    varx = ctx.network.add_reduce(varx, trt.ReduceOperation.AVG, axes, keep_dims=True).get_output(0)
    denom = ctx.network.add_elementwise(varx, eps_trt, trt.ElementWiseOperation.SUM).get_output(0)
    denom = ctx.network.add_unary(denom, trt.UnaryOperation.SQRT).get_output(0)
    y = ctx.network.add_elementwise(numerator, denom, trt.ElementWiseOperation.DIV).get_output(0)

    # optional elementwise affine transform: y * weight + bias
    if weight is not None:
        y = ctx.network.add_elementwise(y, weight_trt, trt.ElementWiseOperation.PROD).get_output(0)

    if bias is not None:
        y = ctx.network.add_elementwise(y, bias_trt, trt.ElementWiseOperation.SUM).get_output(0)

    output._trt = y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_1d():
    return torch.nn.LayerNorm(3)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_2d():
    return torch.nn.LayerNorm((5, 3))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_3d():
    return torch.nn.LayerNorm((5, 5, 3))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_1d_nonaffine():
    return torch.nn.LayerNorm(3, elementwise_affine=False)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_2d_nonaffine():
    return torch.nn.LayerNorm((5, 3), elementwise_affine=False)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_3d_nonaffine():
    return torch.nn.LayerNorm((5, 5, 3), elementwise_affine=False)
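As a quick sanity check on the arithmetic the converter builds (a standalone sketch, not part of the commit): the reduce/elementwise graph above computes the biased variance as the mean of squared deviations, which is what torch.nn.functional.layer_norm uses:

import torch
import torch.nn.functional as F

x = torch.randn(1, 5, 3)
normalized_shape = (3,)
eps = 1e-5

# reduce over the trailing dims named by normalized_shape
dims = tuple(-i - 1 for i in range(len(normalized_shape)))

ux = x.mean(dim=dims, keepdim=True)              # E[x]
num = x - ux                                     # x - E[x]
varx = (num * num).mean(dim=dims, keepdim=True)  # biased Var[x]
y = num / torch.sqrt(varx + eps)                 # normalized output

ref = F.layer_norm(x, normalized_shape, eps=eps)
print(torch.allclose(y, ref, atol=1e-6))  # True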

torch2trt/torch2trt.py

Lines changed: 11 additions & 0 deletions
@@ -87,6 +87,17 @@ def trt_num_outputs(engine):
     return count
 
 
+def torch_dim_resolve_negative(dim, ndim):
+    if not isinstance(dim, tuple):
+        dim = (dim,)
+    pos = []
+    for d in dim:
+        if d < 0:
+            d = ndim + d
+        pos.append(d)
+    return tuple(pos)
+
+
 def torch_dim_to_trt_axes(dim):
     """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
     if not isinstance(dim, tuple):
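A quick illustration of the new helper's behavior (the function is copied verbatim below so the snippet runs without the package installed):

def torch_dim_resolve_negative(dim, ndim):
    if not isinstance(dim, tuple):
        dim = (dim,)
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)

# layer_norm over the last two dims of a 4D input:
# (-1, -2) resolves to positive dims (3, 2), which
# torch_dim_to_trt_axes then folds into a bitmask
print(torch_dim_resolve_negative((-1, -2), 4))  # (3, 2)

# a bare int is promoted to a 1-tuple
print(torch_dim_resolve_negative(-1, 3))        # (2,)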
