@@ -635,6 +635,112 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
635635 module .forward (tu .rand (1 , 2 , 1 , 3 ), tu .rand (2 ), tu .rand (2 ))
636636
637637
638+ # ==============================================================================
639+ class RMSNormModule (torch .nn .Module ):
640+ def __init__ (self ):
641+ super ().__init__ ()
642+
643+ @export
644+ @annotate_args (
645+ [
646+ None ,
647+ ([8 , 9 , 1 , 2 , 4 ], torch .float32 , True ),
648+ ([1 , 2 , 4 ], torch .float32 , True ),
649+ ]
650+ )
651+ def forward (self , x , weight ):
652+ list = [1 , 2 , 4 ]
653+ return torch .ops .aten .rms_norm (x , list , weight , eps = 0.5 )
654+
655+
656+ @register_test_case (module_factory = lambda : RMSNormModule ())
657+ def RMSNormModule_basic (module , tu : TestUtils ):
658+ module .forward (tu .rand (8 , 9 , 1 , 2 , 4 ), tu .rand (1 , 2 , 4 ))
659+
660+
661+ class RMSNormWithoutEpsModule (torch .nn .Module ):
662+ def __init__ (self ):
663+ super ().__init__ ()
664+
665+ @export
666+ @annotate_args (
667+ [
668+ None ,
669+ ([2 , 5 , 2 , 2 , 3 ], torch .float32 , True ),
670+ ([2 , 2 , 3 ], torch .float32 , True ),
671+ ]
672+ )
673+ def forward (self , x , weight ):
674+ list = [2 , 2 , 3 ]
675+ return torch .ops .aten .rms_norm (x , list , weight )
676+
677+
678+ @register_test_case (module_factory = lambda : RMSNormWithoutEpsModule ())
679+ def RMSNormWithoutEpsModule_basic (module , tu : TestUtils ):
680+ module .forward (tu .rand (2 , 5 , 2 , 2 , 3 ), tu .rand (2 , 2 , 3 ))
681+
682+
683+ class RMSNormWithoutWeightModule (torch .nn .Module ):
684+ def __init__ (self ):
685+ super ().__init__ ()
686+
687+ @export
688+ @annotate_args (
689+ [
690+ None ,
691+ ([1 , 2 , 3 , 4 ], torch .float32 , True ),
692+ ]
693+ )
694+ def forward (self , x ):
695+ list = [4 ]
696+ return torch .ops .aten .rms_norm (x , list , eps = 0.5 )
697+
698+
699+ @register_test_case (module_factory = lambda : RMSNormWithoutWeightModule ())
700+ def RMSNormWithoutWeightModule_basic (module , tu : TestUtils ):
701+ module .forward (tu .rand (1 , 2 , 3 , 4 ))
702+
703+
704+ class RMSNormAllNormalizeModule (torch .nn .Module ):
705+ def __init__ (self ):
706+ super ().__init__ ()
707+
708+ @export
709+ @annotate_args (
710+ [None , ([5 , 6 , 3 ], torch .float32 , True ), ([5 , 6 , 3 ], torch .float32 , True )]
711+ )
712+ def forward (self , x , weight ):
713+ list = [5 , 6 , 3 ]
714+ return torch .ops .aten .rms_norm (x , list , weight , eps = 0.7 )
715+
716+
717+ @register_test_case (module_factory = lambda : RMSNormAllNormalizeModule ())
718+ def RMSNormAllNormalizeModule_basic (module , tu : TestUtils ):
719+ module .forward (tu .rand (5 , 6 , 3 ), tu .rand (5 , 6 , 3 ))
720+
721+
722+ class RMSNormDynamicModule (torch .nn .Module ):
723+ def __init__ (self ):
724+ super ().__init__ ()
725+
726+ @export
727+ @annotate_args (
728+ [
729+ None ,
730+ ([- 1 , - 1 , - 1 , - 1 ], torch .float32 , True ),
731+ ([- 1 , - 1 , - 1 ], torch .float32 , True ),
732+ ]
733+ )
734+ def forward (self , x , weight ):
735+ list = [2 , 3 , 4 ]
736+ return torch .ops .aten .rms_norm (x , list , weight , eps = 0.8 )
737+
738+
739+ @register_test_case (module_factory = lambda : RMSNormDynamicModule ())
740+ def RMSNormDynamicModule_basic (module , tu : TestUtils ):
741+ module .forward (tu .rand (1 , 2 , 3 , 4 ), tu .rand (2 , 3 , 4 ))
742+
743+
638744# ==============================================================================
639745class RenormModuleFloat32 (torch .nn .Module ):
640746 def __init__ (self ):
0 commit comments