Skip to content

Commit e1a96e8

Browse files
committed
fix recursion error when setting tp_wrapped_module #122
1 parent a7d1939 commit e1a96e8

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/tensor_parallel/wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,8 @@ def forward(self, *args, **kwargs):
7373

7474
def __getattr__(self, attr):
7575
return getattr(self.tp_wrapped_module, attr)
76+
77+
def __setattr__(self, attr, value):
78+
super().__setattr__(attr, value)
79+
if attr == "tp_wrapped_module":
80+
self.__dict__["tp_wrapped_module"] = value # to be accessible without getattr after nn.Module removed it from __dict__

0 commit comments

Comments
 (0)