@@ -134,6 +134,7 @@ def __init__(self, config):
134134 super ().__init__ (config )
135135 self .linear = nn .Linear (5 , 5 )
136136 self .linear_2 = nn .Linear (5 , 5 )
137+ self .post_init ()
137138
138139 def forward (self , x ):
139140 return self .linear_2 (self .linear (x ))
@@ -147,6 +148,7 @@ def __init__(self, config):
147148 super ().__init__ (config )
148149 self .linear = nn .Linear (50 , 50 )
149150 self .linear_2 = nn .Linear (50 , 50 )
151+ self .post_init ()
150152
151153 def forward (self , x ):
152154 return self .linear_2 (self .linear (x ))
@@ -160,6 +162,7 @@ def __init__(self, config):
160162 super ().__init__ (config )
161163 self .linear = nn .Linear (50 , 50 )
162164 self .linear_2 = nn .Linear (50 , 50 )
165+ self .post_init ()
163166
164167 def forward (self , x ):
165168 return self .linear_2 (self .linear (x ))
@@ -171,6 +174,7 @@ def __init__(self, config):
171174 super ().__init__ (config )
172175 self .linear = nn .Linear (5 , 5 )
173176 self .linear_2 = nn .Linear (5 , 5 )
177+ self .post_init ()
174178
175179 def forward (self , x ):
176180 return self .linear_2 (self .linear (x ))
@@ -193,6 +197,7 @@ def __init__(self, config):
193197 # linear is a common name between Base and Head on purpose.
194198 self .linear = nn .Linear (5 , 5 )
195199 self .linear2 = nn .Linear (5 , 5 )
200+ self .post_init ()
196201
197202 def forward (self , x ):
198203 return self .linear2 (self .linear (self .base (x )))
@@ -209,6 +214,7 @@ def __init__(self, config):
209214 # direct params and submodules is helpful for testing offloading logic
210215 self .weight = nn .Parameter (torch .rand ((5 , 5 )))
211216 self .base = BaseModel (config )
217+ self .post_init ()
212218
213219 def forward (self , x ):
214220 return self .base (x @ self .weight .T )
@@ -225,6 +231,7 @@ def __init__(self, config):
225231 self .submodule = ModelWithDirectParam (config )
226232 # needed so model can have at least one module on accelerator
227233 self .linear = nn .Linear (5 , 5 )
234+ self .post_init ()
228235
229236 def forward (self , x ):
230237 return self .linear (self .submodule (x ))
@@ -240,6 +247,7 @@ def __init__(self, config):
240247 super ().__init__ (config )
241248 self .base = BaseModel (config )
242249 self .decoder = nn .Linear (5 , 5 )
250+ self .post_init ()
243251
244252 def forward (self , x ):
245253 return self .decoder (self .base (x ))
0 commit comments