@@ -171,3 +171,238 @@ def forward(self, x):
171171 )
172172 print (m )
173173 m .operation .verify ()
174+
175+
176+ @run
177+ # CHECK-LABEL: test_single_input_const_argument
178+ # CHECK: %[[int2:.+]] = torch.constant.int 2
179+ # CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[int2]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
180+ # CHECK: return %[[buffer]] : !torch.vtensor<[3,4],f32>
181+ def test_single_input_const_argument ():
182+ class SingleConstantInputModule (torch .nn .Module ):
183+ def __init__ (self ):
184+ super ().__init__ ()
185+
186+ def forward (self , x , y = 2 ): # Single constant input
187+ return x * y
188+
189+ m = fx .export_and_import (
190+ SingleConstantInputModule (),
191+ torch .randn (3 , 4 ),
192+ experimental_support_mutation = True ,
193+ )
194+ print (m )
195+ m .operation .verify ()
196+
197+
198+ @run
199+ # CHECK-LABEL: test_single_output_const_argument
200+ # CHECK: %[[float1:.+]] = torch.constant.float 5.000000e-01
201+ # CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float1]]
202+ # CHECK: %[[float2:.+]] = torch.constant.float 5.000000e-01
203+ # CHECK: return %[[buffer]], %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float
204+ def test_single_output_const_argument ():
205+ class SingleConstantOutputModule (torch .nn .Module ):
206+ def __init__ (self ):
207+ super ().__init__ ()
208+ self .scale = 0.5 # Single constant output
209+
210+ def forward (self , x ):
211+ scaled = x * self .scale
212+ return scaled , self .scale # Return tensor + constant
213+
214+ m = fx .export_and_import (
215+ SingleConstantOutputModule (),
216+ torch .randn (3 , 4 ),
217+ experimental_support_mutation = True ,
218+ )
219+ print (m )
220+ m .operation .verify ()
221+
222+
223+ @run
224+ # CHECK-LABEL: test_multiple_input_const_argument
225+ # CHECK: %[[float2:.+]] = torch.constant.float 2.000000e+00
226+ # CHECK: %[[buffer0:.+]] = torch.aten.mul.Scalar %arg0, %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
227+ # CHECK: %[[float3:.+]] = torch.constant.float 3.000000e+00
228+ # CHECK: %[[int1:.+]] = torch.constant.int 1
229+ # CHECK: %[[buffer1:.+]] = torch.aten.add.Scalar %[[buffer0]], %[[float3]], %[[int1]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,4],f32>
230+ # CHECK: return %[[buffer1]] : !torch.vtensor<[3,4],f32>
231+ def test_multiple_input_const_argument ():
232+ class MultipleConstantInputModule (torch .nn .Module ):
233+ def __init__ (self ):
234+ super ().__init__ ()
235+
236+ def forward (
237+ self , x , scale = 2.0 , offset = 1.0 , multiplier = 3
238+ ): # Multiple constant inputs
239+ return x * scale + offset * multiplier
240+
241+ m = fx .export_and_import (
242+ MultipleConstantInputModule (),
243+ torch .randn (3 , 4 ),
244+ experimental_support_mutation = True ,
245+ )
246+ print (m )
247+ m .operation .verify ()
248+
249+
250+ @run
251+ # CHECK-LABEL: test_multiple_output_const_argument
252+ # CHECK: %[[float5:.+]] = torch.constant.float 5.000000e-01
253+ # CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float5]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
254+ # CHECK: %[[str:.+]] = torch.constant.str "model"
255+ # CHECK: %[[int42:.+]] = torch.constant.int 42
256+ # CHECK: %[[true:.+]] = torch.constant.bool true
257+ # CHECK: %[[none:.+]] = torch.constant.none
258+ # CHECK: return %[[buffer]], %[[float5]]
259+ # CHECK-SAME: %[[str]], %[[int42]], %[[true]], %[[none]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.str, !torch.int, !torch.bool, !torch.none
260+ def test_multiple_output_const_argument ():
261+ class MultipleConstantOutputModule (torch .nn .Module ):
262+ def __init__ (self ):
263+ super ().__init__ ()
264+ self .scale = 0.5
265+ self .name = "model"
266+ self .version = 42
267+
268+ def forward (self , x ):
269+ result = x * self .scale
270+ # Return tensor + multiple constants
271+ return result , self .scale , self .name , self .version , True , None
272+
273+ m = fx .export_and_import (
274+ MultipleConstantOutputModule (),
275+ torch .randn (3 , 4 ),
276+ experimental_support_mutation = True ,
277+ )
278+ print (m )
279+ m .operation .verify ()
280+
281+
282+ @run
283+ # CHECK-LABEL: test_input_output_const_argument
284+ # CHECK: %[[float5:.+]] = torch.constant.float 5.000000e-01
285+ # CHECK: %[[buffer0:.+]] = torch.aten.mul.Scalar %arg0, %[[float5]]
286+ # CHECK: %[[float2:.+]] = torch.constant.float 2.000000e+00
287+ # CHECK: %[[buffer1:.+]] = torch.aten.mul.Scalar %[[buffer0]], %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
288+ # CHECK: %[[float1:.+]] = torch.constant.float 1.000000e+00
289+ # CHECK: %[[int1:.+]] = torch.constant.int 1
290+ # CHECK: %[[buffer2:.+]] = torch.aten.add.Scalar %[[buffer1]], %[[float1]], %[[int1]]
291+ # CHECK: %[[str:.+]] = torch.constant.str "combined_model"
292+ # CHECK: %[[true:.+]] = torch.constant.bool true
293+ # CHECK: %[[none:.+]] = torch.constant.none
294+ # CHECK: return %[[buffer2]], %[[float5]]
295+ # CHECK-SAME: %[[str]]
296+ def test_input_output_const_argument ():
297+ class CombinedConstantModule (torch .nn .Module ):
298+ def __init__ (self ):
299+ super ().__init__ ()
300+ self .base_scale = 0.5
301+ self .model_name = "combined_model"
302+
303+ def forward (self , x , user_scale = 2.0 , add_bias = True , bias_value = 1.0 ):
304+ if add_bias :
305+ result = (x * self .base_scale * user_scale ) + bias_value
306+ else :
307+ result = x * self .base_scale * user_scale
308+
309+ # Return mix of tensors and constants (both output and input)
310+ return (
311+ result , # tensor
312+ self .base_scale , # constantArgument output
313+ self .model_name , # constantArgument output
314+ user_scale , # constantArgument input
315+ add_bias , # constantArgument input
316+ bias_value , # constantArgument input
317+ None , # constantArgument literal (output)
318+ )
319+
320+ m = fx .export_and_import (
321+ CombinedConstantModule (), torch .randn (3 , 4 ), experimental_support_mutation = True
322+ )
323+ print (m )
324+ m .operation .verify ()
325+
326+
327+ @run
328+ # CHECK-LABEL: test_const_argument_edge_cases
329+ # CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>) ->
330+ # CHECK-SAME: (!torch.vtensor<[3,4],f32>, !torch.float, !torch.int, !torch.str, !torch.bool, !torch.none, !torch.none, !torch.str, !torch.int, !torch.bool)
331+ # CHECK: %[[float314:.+]] = torch.constant.float 3.140000e+00
332+ # CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float314]]
333+ # CHECK: %[[int42:.+]] = torch.constant.int 42
334+ # CHECK: %[[string1:.+]] = torch.constant.str "test"
335+ # CHECK: %[[true:.+]] = torch.constant.bool true
336+ # CHECK: %[[none:.+]] = torch.constant.none
337+ # CHECK: %[[string2:.+]] = torch.constant.str "default"
338+ # CHECK: %[[int0:.+]] = torch.constant.int 0
339+ # CHECK: %[[false:.+]] = torch.constant.bool false
340+ # CHECK: return %[[buffer]], %[[float314]]
341+ # CHECK-SAME: %[[int42]], %[[string1]], %[[true]], %[[none]], %[[none]]
342+ # CHECK-SAME: %[[string2]], %[[int0]], %[[false]]
343+ def test_const_argument_edge_cases ():
344+ class EdgeCaseConstantModule (torch .nn .Module ):
345+ def __init__ (self ):
346+ super ().__init__ ()
347+ self .float_val = 3.14
348+ self .int_val = 42
349+ self .str_val = "test"
350+ self .bool_val = True
351+ self .none_val = None
352+
353+ def forward (self , x , input_none = None , input_str = "default" ):
354+ result = x * self .float_val
355+
356+ # Return all different ConstantArgument types
357+ return (
358+ result , # tensor
359+ self .float_val , # float output constantArgument
360+ self .int_val , # int output constantArgument
361+ self .str_val , # string output constantArgument
362+ self .bool_val , # bool output constantArgument
363+ self .none_val , # None output constantArgument
364+ input_none , # None input constantArgument
365+ input_str , # string input constantArgument
366+ 0 , # literal int
367+ False , # literal bool
368+ )
369+
370+ m = fx .export_and_import (
371+ EdgeCaseConstantModule (), torch .randn (3 , 4 ), experimental_support_mutation = True
372+ )
373+ print (m )
374+ m .operation .verify ()
375+
376+
377+ @run
378+ # CHECK-LABEL: test_const_argument_from_multiheadattention_layer
379+ # CHECK: func.func @main(%arg0: !torch.vtensor<[1,10,64],f32>, %arg1: !torch.vtensor<[1,10,64],f32>, %arg2: !torch.vtensor<[1,10,64],f32>) ->
380+ # CHECK-SAME: (!torch.vtensor<[1,10,64],f32>, !torch.none)
381+ # CHECK: %[[int1:.+]] = torch.constant.int 1
382+ # CHECK: %[[int0:.+]] = torch.constant.int 0
383+ # CHECK-DAG: %[[buffer:.+]] = torch.aten.transpose.int %arg0, %[[int1]], %[[int0]] : !torch.vtensor<[1,10,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,1,64],f32>
384+ def test_const_argument_from_multiheadattention_layer ():
385+ """
386+ Test case using actual MultiheadAttention where a constantArgument appears automatically
387+ due to returning the attention layer without the weights (need_weights=False)
388+ """
389+
390+ class AttentionLikeConstantModule (torch .nn .Module ):
391+ def __init__ (self ):
392+ super ().__init__ ()
393+ self .attn = torch .nn .MultiheadAttention (
394+ embed_dim = 64 , num_heads = 1 , dropout = 0.1 , batch_first = True
395+ )
396+
397+ def forward (self , query , key , value , need_weights = False ):
398+ return self .attn (query , key , value , need_weights = need_weights )
399+
400+ m = fx .export_and_import (
401+ AttentionLikeConstantModule (),
402+ torch .randn (1 , 10 , 64 ), # query
403+ torch .randn (1 , 10 , 64 ), # key
404+ torch .randn (1 , 10 , 64 ), # value
405+ experimental_support_mutation = True ,
406+ )
407+ print (m )
408+ m .operation .verify ()
0 commit comments