Skip to content

Commit 8ff3234

Browse files
authored
Normalize tl.sqrt and libdevice.sqrt for tests (#866)
1 parent 5e11da4 commit 8ff3234

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

helion/_testing.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,31 @@ def normalize_device_name(code: str) -> str:
719719
)
720720
return re.sub(reg_pattern_for_torch_device, "device=DEVICE", normalized_code)
721721

722+
@staticmethod
723+
def normalize_codegen_variants(code: str) -> str:
724+
# TODO(oulgen): Remove when PyTorch 2.10 becomes stable
725+
726+
# Remove libdevice import line if present
727+
code = re.sub(
728+
r"^\s*from torch\._inductor\.runtime\.triton_compat import libdevice\s*\n?",
729+
"",
730+
code,
731+
flags=re.MULTILINE,
732+
)
733+
734+
# Normalize sqrt variants
735+
# libdevice.sqrt( -> tl.sqrt_rn(
736+
code = re.sub(r"\blibdevice\.sqrt\s*\(", "tl.sqrt_rn(", code)
737+
# tl.sqrt( -> tl.sqrt_rn(
738+
return re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code)
739+
740+
@classmethod
741+
def normalize_code(cls, code: str) -> str:
742+
code = cls.normalize_tensor_descriptors(code)
743+
code = cls.normalize_device_name(code)
744+
code = cls.normalize_codegen_variants(code)
745+
return code.strip()
746+
722747
def lookup(self, test_id: str, value: str) -> tuple[str, str]:
723748
test_id = self.normalize_id(test_id)
724749
if self._current_id != test_id:
@@ -733,9 +758,10 @@ def lookup(self, test_id: str, value: str) -> tuple[str, str]:
733758
expected_values.append("")
734759
expected = ""
735760

736-
value = self.normalize_tensor_descriptors(value)
737-
value = self.normalize_device_name(value)
738-
value = value.strip()
761+
# Normalize both actual and expected for robust comparisons
762+
value = self.normalize_code(value)
763+
expected = self.normalize_code(expected)
764+
739765
if value != expected and os.environ.get("EXPECTTEST_ACCEPT", "0") not in {
740766
"0",
741767
"false",

0 commit comments

Comments
 (0)