@@ -256,9 +256,12 @@ def _maybe_map_weights(
256256 # Extract weights from policy module using merge_and_unload for LLMs
257257 if not hasattr (server_weights , "model" ):
258258 raise ValueError ("TensorDictModuleBase must have a 'model' attribute" )
259- if not hasattr (server_weights .model , "merge_and_unload" ):
260- raise ValueError ("Model must have a 'merge_and_unload' method" )
261- return TensorDict (server_weights .model .merge_and_unload ().state_dict (), [])
259+ # Check if it's a LoRA model
260+ if hasattr (server_weights .model , "merge_and_unload" ):
261+ state_dict = server_weights .model .merge_and_unload ().state_dict ()
262+ else :
263+ state_dict = server_weights .model .state_dict ()
264+ return TensorDict (state_dict , [])
262265 elif isinstance (server_weights , TensorDictBase ):
263266 return server_weights
264267 elif isinstance (server_weights , dict ):
@@ -281,7 +284,11 @@ def get_model_metadata(
281284 Returns:
282285 dict[str, tuple[torch.dtype, torch.Size]]: The model metadata.
283286 """
284- sd = model .model .merge_and_unload ().state_dict ()
287+ # Check if the model has a LoRA adapter
288+ if hasattr (model .model , "merge_and_unload" ):
289+ sd = model .model .merge_and_unload ().state_dict ()
290+ else :
291+ sd = model .model .state_dict ()
285292 model_metadata = {k : (v .dtype , v .shape ) for k , v in sd .items ()}
286293 return model_metadata
287294
0 commit comments