import numpy as np
import tensorrt as trt
import torch
+from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
@@ -32,21 +33,22 @@ def batch_norm(
    source_ir: Optional[SourceIR],
    name: str,
    input: trt.ITensor,
-    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    training: bool,
    momentum: float,
    eps: float,
-    cudnn_enabled: bool,
    return_mean_rstd: bool,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    training: bool = False,
+    cudnn_enabled: bool = False,
) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

    # Save the original output shape for later use
    output_shape = input.shape
+    feature_num = output_shape[1]
    # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
    # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
    # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
@@ -59,26 +61,41 @@ def batch_norm(
        ]
    ):
        # We name the weight here according to the state_dict name
-        weight = (
-            get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
-            if weight is None
-            else get_trt_tensor(ctx, weight, f"{name}_weight")
-        )
-        bias = (
-            get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
-            if bias is None
-            else get_trt_tensor(ctx, bias, f"{name}_bias")
-        )
-        running_mean = (
-            get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
-            if running_mean is None
-            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-        )
-        running_var = (
-            get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
-            if running_var is None
-            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-        )
+        with unset_fake_temporarily():
+            weight = (
+                get_trt_tensor(
+                    ctx, torch.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+                )
+                if weight is None
+                else get_trt_tensor(ctx, weight, f"{name}_weight")
+            )
+            bias = (
+                get_trt_tensor(
+                    ctx, torch.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+                )
+                if bias is None
+                else get_trt_tensor(ctx, bias, f"{name}_bias")
+            )
+            running_mean = (
+                get_trt_tensor(
+                    ctx,
+                    torch.zeros((feature_num,)),
+                    f"{name}_running_mean",
+                    dtype=input.dtype,
+                )
+                if running_mean is None
+                else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+            )
+            running_var = (
+                get_trt_tensor(
+                    ctx,
+                    torch.ones((feature_num,)),
+                    f"{name}_running_var",
+                    dtype=input.dtype,
+                )
+                if running_var is None
+                else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+            )

        # eps_tensor for numerical stability
        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
@@ -110,8 +127,7 @@ def batch_norm(

        # Reshape scale and bias_adjusted to match input shape for broadcasting
        expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
-
+        expanded_shape[1] = feature_num  # Set channel dimension
        scale_reshape = impl.shuffle.reshape(
            ctx,
            target,
@@ -143,21 +159,24 @@ def batch_norm(
        )

    else:
-        if weight is None:
-            weight = 1.0
+        with unset_fake_temporarily():
+            if weight is None:
+                weight = torch.ones((feature_num,))

-        if bias is None:
-            bias = 0.0
+            if bias is None:
+                bias = torch.zeros((feature_num,))

-        if running_mean is None:
-            running_mean = 0.0
+            if running_mean is None:
+                running_mean = torch.zeros((feature_num,))

-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
-            weight, bias, running_mean, running_var, eps
-        )
-        power = torch.ones_like(adjusted_scale)
+            if running_var is None:
+                running_var = torch.ones((feature_num,))
+
+            power = torch.ones_like(weight)
+
+            adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+                weight, bias, running_mean, running_var, eps
+            )

        adjusted_scale = to_trt_weights(
            ctx,
@@ -188,9 +207,7 @@ def batch_norm(
            source_ir=source_ir,
        )

-        output_shape = input.shape
        if len(input.shape) < 4:
-
            new_shape = (
                (input.shape[0], input.shape[1], 1, 1)
                if len(input.shape) == 2
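
Note on the new `unset_fake_temporarily()` blocks: this converter can run while a `FakeTensorMode` is active (e.g. during Dynamo tracing), in which case factory calls like `torch.ones((feature_num,))` would return fake tensors with no real storage, which cannot be materialized into TensorRT weights. That is presumably why the default weight/bias/running stats are now created inside `unset_fake_temporarily()`. A minimal standalone sketch of the behavior, assuming standard `FakeTensorMode` semantics (not part of the converter itself):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones((4,))  # FakeTensor: carries shape/dtype/device, but no real data
    with unset_fake_temporarily():
        real = torch.ones((4,))  # plain torch.Tensor with real storage

print(type(fake).__name__, type(real).__name__)  # FakeTensor Tensor
```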
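For reference, `batch_norm_constant_folding` (defined elsewhere in this module) folds the four per-channel constants into a single scale/shift pair, which is what lets the op map onto a single TensorRT scale layer. A rough sketch of the standard identity it presumably implements, using a hypothetical `fold_batch_norm` helper standing in for the real function:

```python
import torch
import torch.nn.functional as F

def fold_batch_norm(weight, bias, running_mean, running_var, eps):
    # y = (x - mean) / sqrt(var + eps) * weight + bias  ==  x * scale + shift
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

# Sanity check against the eager op on a random NCHW input.
x = torch.randn(2, 3, 4, 4)
w, b = torch.randn(3), torch.randn(3)
mean, var = torch.randn(3), torch.rand(3) + 0.5
scale, shift = fold_batch_norm(w, b, mean, var, eps=1e-5)
folded = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
eager = F.batch_norm(x, mean, var, w, b, training=False, eps=1e-5)
assert torch.allclose(folded, eager, atol=1e-5)
```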