Skip to content

Commit dcdf8ef

Browse files
committed
Make compiler toolkit works with checkpoint
The current CompileModule will result in an "inner" prefix for everything. This PR fixes it by overloading the methods. ghstack-source-id: a16c514 Pull-Request: #2030
1 parent 3b401c1 commit dcdf8ef

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import contextlib
88
from pathlib import Path
9-
from typing import Callable, List, Optional
9+
from typing import Any, Callable, List, Optional
1010

1111
import torch
1212
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -168,6 +168,18 @@ def __delattr__(self, name: str) -> None:
168168
else:
169169
super().__delattr__(name)
170170

171+
def state_dict(self, *args, **kwargs) -> Any:
172+
return self.inner.state_dict(*args, **kwargs)
173+
174+
def load_state_dict(self, *args, **kwargs) -> Any:
175+
return self.inner.load_state_dict(*args, **kwargs)
176+
177+
def name_parameters(self, *args, **kwargs) -> Any:
178+
return self.inner.named_parameters(*args, **kwargs)
179+
180+
def parameters(self, *args, **kwargs) -> Any:
181+
return self.inner.parameters(*args, **kwargs)
182+
171183
def forward(self, *args, **kwargs):
172184
assert "forward" not in self._overrides, "forward cannot be overridden"
173185

0 commit comments

Comments
 (0)