from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.fmod')
def convert_fmod(ctx):
    input_a = ctx.method_args[0]
    input_b = ctx.method_args[1]
    output = ctx.method_return
    input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
    # We cannot use ElementWiseOperation.FLOOR_DIV directly: torch.fmod truncates the
    # quotient toward 0, while TensorRT's FLOOR_DIV rounds it toward -Inf.
    # Instead compute trunc(a / b) as sign(a * b) * (|a| // |b|), with
    # sign(a * b) = (a * b) / |a * b|.
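    # e.g. a = -7, b = 2: sign(a*b) = -14 / 14 = -1 and |a| // |b| = 3, so trunc(a/b) = -3
    # and fmod(a, b) = -7 - (-3) * 2 = -1 (torch.fmod(-7., 2.) == -1.), whereas using
    # FLOOR_DIV directly would give -7 - (-4) * 2 = 1.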
    ab_layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.PROD)
    abs_ab_layer = ctx.network.add_unary(ab_layer.get_output(0), trt.UnaryOperation.ABS)
    sign_layer = ctx.network.add_elementwise(ab_layer.get_output(0), abs_ab_layer.get_output(0),
                                             trt.ElementWiseOperation.DIV)
    abs_a_layer = ctx.network.add_unary(input_a_trt, trt.UnaryOperation.ABS)
    abs_b_layer = ctx.network.add_unary(input_b_trt, trt.UnaryOperation.ABS)
    abs_floor_layer = ctx.network.add_elementwise(abs_a_layer.get_output(0), abs_b_layer.get_output(0),
                                                  trt.ElementWiseOperation.FLOOR_DIV)
    # fmod(a, b) = a - trunc(a / b) * b
    truncdiv_layer = ctx.network.add_elementwise(sign_layer.get_output(0), abs_floor_layer.get_output(0),
                                                 trt.ElementWiseOperation.PROD)
    prod_layer = ctx.network.add_elementwise(truncdiv_layer.get_output(0), input_b_trt, trt.ElementWiseOperation.PROD)
    sub_layer = ctx.network.add_elementwise(input_a_trt, prod_layer.get_output(0), trt.ElementWiseOperation.SUB)
    output._trt = sub_layer.get_output(0)


@tensorrt_converter('torch.Tensor.__mod__')
# A separate converter is needed for the % operator: PyTorch floors the quotient
# (rounds toward -Inf) for Tensor.__mod__, unlike torch.fmod, which truncates toward 0.
# An issue is filed at https://github.com/pytorch/pytorch/issues/52425, but for now the
# converter has to reproduce the operator's current behavior exactly.
def convert_mod(ctx):
    input_a = ctx.method_args[0]
    input_b = ctx.method_args[1]
    output = ctx.method_return
    input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
    # a % b = a - (a // b) * b, where // rounds toward -Inf (matching TensorRT FLOOR_DIV)
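    # e.g. a = -7, b = 2: a // b = -4, so a % b = -7 - (-4) * 2 = 1 (and (-7.) % 2. == 1.)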
    floordiv_layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.FLOOR_DIV)
    prod_layer = ctx.network.add_elementwise(floordiv_layer.get_output(0), input_b_trt, trt.ElementWiseOperation.PROD)
    mod_layer = ctx.network.add_elementwise(input_a_trt, prod_layer.get_output(0), trt.ElementWiseOperation.SUB)
    output._trt = mod_layer.get_output(0)


class Mod(torch.nn.Module):
    def __init__(self):
        super(Mod, self).__init__()

    def forward(self, x, y):
        return x % y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_mod_op():
    return Mod()


class ModAssign(torch.nn.Module):
    def __init__(self):
        super(ModAssign, self).__init__()

    def forward(self, x, y):
        x %= y
        return x


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_mod_op_assign():
    return ModAssign()


class ModConst(torch.nn.Module):
    def __init__(self):
        super(ModConst, self).__init__()

    def forward(self, x):
        return x % 2.


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_mod_op_const():
    return ModConst()


class TorchMod(torch.nn.Module):
    def __init__(self):
        super(TorchMod, self).__init__()

    def forward(self, x, y):
        return torch.fmod(x, y)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 40, 20)])
def test_mod_func():
    return TorchMod()
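

# A possible extra test (sketch): torch.fmod with a scalar divisor, mirroring ModConst
# above; the TorchModConst / test_mod_func_const names are illustrative additions.
class TorchModConst(torch.nn.Module):
    def __init__(self):
        super(TorchModConst, self).__init__()

    def forward(self, x):
        return torch.fmod(x, 2.)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_mod_func_const():
    return TorchModConst()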