@@ -719,6 +719,31 @@ def normalize_device_name(code: str) -> str:
719719 )
720720 return re .sub (reg_pattern_for_torch_device , "device=DEVICE" , normalized_code )
721721
722+ @staticmethod
723+ def normalize_codegen_variants (code : str ) -> str :
724+ # TODO(oulgen): Remove when PyTorch 2.10 becomes stable
725+
726+ # Remove libdevice import line if present
727+ code = re .sub (
728+ r"^\s*from torch\._inductor\.runtime\.triton_compat import libdevice\s*\n?" ,
729+ "" ,
730+ code ,
731+ flags = re .MULTILINE ,
732+ )
733+
734+ # Normalize sqrt variants
735+ # libdevice.sqrt( -> tl.sqrt_rn(
736+ code = re .sub (r"\blibdevice\.sqrt\s*\(" , "tl.sqrt_rn(" , code )
737+ # tl.sqrt( -> tl.sqrt_rn(
738+ return re .sub (r"\btl\.sqrt\s*\(" , "tl.sqrt_rn(" , code )
739+
740+ @classmethod
741+ def normalize_code (cls , code : str ) -> str :
742+ code = cls .normalize_tensor_descriptors (code )
743+ code = cls .normalize_device_name (code )
744+ code = cls .normalize_codegen_variants (code )
745+ return code .strip ()
746+
722747 def lookup (self , test_id : str , value : str ) -> tuple [str , str ]:
723748 test_id = self .normalize_id (test_id )
724749 if self ._current_id != test_id :
@@ -733,9 +758,10 @@ def lookup(self, test_id: str, value: str) -> tuple[str, str]:
733758 expected_values .append ("" )
734759 expected = ""
735760
736- value = self .normalize_tensor_descriptors (value )
737- value = self .normalize_device_name (value )
738- value = value .strip ()
761+ # Normalize both actual and expected for robust comparisons
762+ value = self .normalize_code (value )
763+ expected = self .normalize_code (expected )
764+
739765 if value != expected and os .environ .get ("EXPECTTEST_ACCEPT" , "0" ) not in {
740766 "0" ,
741767 "false" ,
0 commit comments