Skip to content

Commit 96ce554

Browse files
committed
Make compiler toolkit works with checkpoint
The current CompileModule will result in an "inner" prefix for everything. This PR fixes it by overloading the methods. ghstack-source-id: 3be8074 Pull-Request: #2030
1 parent 2d4561b commit 96ce554

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import contextlib
88
from pathlib import Path
9-
from typing import Callable, List, Optional
9+
from typing import Any, Callable, List, Optional
1010

1111
import torch
1212
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -168,6 +168,18 @@ def __delattr__(self, name: str) -> None:
168168
else:
169169
super().__delattr__(name)
170170

171+
def state_dict(self, *args, **kwargs) -> Any:
172+
return self.inner.state_dict(*args, **kwargs)
173+
174+
def load_state_dict(self, *args, **kwargs) -> Any:
175+
return self.inner.load_state_dict(*args, **kwargs)
176+
177+
def name_parameters(self, *args, **kwargs) -> Any:
178+
return self.inner.named_parameters(*args, **kwargs)
179+
180+
def parameters(self, *args, **kwargs) -> Any:
181+
return self.inner.parameters(*args, **kwargs)
182+
171183
def forward(self, *args, **kwargs):
172184
assert "forward" not in self._overrides, "forward cannot be overridden"
173185

0 commit comments

Comments
 (0)