Skip to content

Commit c3b5f3e

Browse files
committed
add the post_init
1 parent 49cea07 commit c3b5f3e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/utils/test_modeling_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(self, config):
134134
super().__init__(config)
135135
self.linear = nn.Linear(5, 5)
136136
self.linear_2 = nn.Linear(5, 5)
137+
self.post_init()
137138

138139
def forward(self, x):
139140
return self.linear_2(self.linear(x))
@@ -147,6 +148,7 @@ def __init__(self, config):
147148
super().__init__(config)
148149
self.linear = nn.Linear(50, 50)
149150
self.linear_2 = nn.Linear(50, 50)
151+
self.post_init()
150152

151153
def forward(self, x):
152154
return self.linear_2(self.linear(x))
@@ -160,6 +162,7 @@ def __init__(self, config):
160162
super().__init__(config)
161163
self.linear = nn.Linear(50, 50)
162164
self.linear_2 = nn.Linear(50, 50)
165+
self.post_init()
163166

164167
def forward(self, x):
165168
return self.linear_2(self.linear(x))
@@ -171,6 +174,7 @@ def __init__(self, config):
171174
super().__init__(config)
172175
self.linear = nn.Linear(5, 5)
173176
self.linear_2 = nn.Linear(5, 5)
177+
self.post_init()
174178

175179
def forward(self, x):
176180
return self.linear_2(self.linear(x))
@@ -193,6 +197,7 @@ def __init__(self, config):
193197
# linear is a common name between Base and Head on purpose.
194198
self.linear = nn.Linear(5, 5)
195199
self.linear2 = nn.Linear(5, 5)
200+
self.post_init()
196201

197202
def forward(self, x):
198203
return self.linear2(self.linear(self.base(x)))
@@ -209,6 +214,7 @@ def __init__(self, config):
209214
# direct params and submodules is helpful for testing offloading logic
210215
self.weight = nn.Parameter(torch.rand((5, 5)))
211216
self.base = BaseModel(config)
217+
self.post_init()
212218

213219
def forward(self, x):
214220
return self.base(x @ self.weight.T)
@@ -225,6 +231,7 @@ def __init__(self, config):
225231
self.submodule = ModelWithDirectParam(config)
226232
# needed so model can have at least one module on accelerator
227233
self.linear = nn.Linear(5, 5)
234+
self.post_init()
228235

229236
def forward(self, x):
230237
return self.linear(self.submodule(x))
@@ -240,6 +247,7 @@ def __init__(self, config):
240247
super().__init__(config)
241248
self.base = BaseModel(config)
242249
self.decoder = nn.Linear(5, 5)
250+
self.post_init()
243251

244252
def forward(self, x):
245253
return self.decoder(self.base(x))

0 commit comments

Comments
 (0)