@@ -183,7 +183,6 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
183183 return rtol , atol
184184
185185
186-
187186def op_assert_ref (test_case , op , test_dtype , i , orig , decomp , ref , args , kwargs ):
188187 assert orig .dtype == decomp .dtype , f"{ i } Operation: { op } "
189188 if orig .numel () == 0 or decomp .numel () == 0 :
@@ -432,7 +431,7 @@ def normalize_op_input_output(f, sample, requires_grad=True):
432431 "bernoulli" ,
433432 ), # bernoulli is a function of randomness, so couldn't do cross-reference.
434433 # XPU specific exclude cases
435- # ("xpu", None, "some_xpu_specific_op"), # 根据需要添加 XPU 特定的排除项
434+ # ("xpu", None, "some_xpu_specific_op"),
436435}
437436
438437CROSS_REF_BACKWARD_EXCLUDE_SET = {
@@ -445,7 +444,7 @@ def normalize_op_input_output(f, sample, requires_grad=True):
445444 "bernoulli" ,
446445 ), # bernoulli is a function of randomness, so couldn't do cross-reference.
447446 # XPU specific backward exclude cases
448- # ("xpu", torch.float16, "nn.functional.some_op"), # 根据需要添加
447+ # ("xpu", torch.float16, "nn.functional.some_op"),
449448}
450449
451450all_decomposed = set ()
0 commit comments