Skip to content

Commit dd29563

Browse files
authored
fix nvfp4 serialization (#3140)
Update [ghstack-poisoned]
1 parent c63899b commit dd29563

File tree

2 files changed: 11 additions and 0 deletions.

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,16 @@ def test_inference_workflow_nvfp4(
191191
f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
192192
)
193193

194+
# serialization
195+
with tempfile.NamedTemporaryFile() as f:
196+
torch.save(m_mx.state_dict(), f)
197+
f.seek(0)
198+
199+
# temporary workaround for https://github.com/pytorch/ao/issues/3077
200+
torch.serialization.add_safe_globals([getattr])
201+
202+
_ = torch.load(f, weights_only=True)
203+
194204

195205
class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
196206
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,7 @@ def _nvfp4_inference_linear_transform(
211211
NVFP4MMConfig,
212212
MXGemmKernelChoice,
213213
QuantizeTensorToMXKwargs,
214+
QuantizeTensorToNVFP4Kwargs,
214215
ScaleCalculationMode,
215216
]
216217
)

0 commit comments

Comments
 (0)