import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
-from lightning.fabric.strategies import FSDPStrategy
+from lightning.fabric.strategies import ModelParallelStrategy
from lightning.fabric.utilities import ThroughputMonitor
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import ConcatDataset, DataLoader
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
-from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
+from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
@@ -70,6 +70,7 @@ def setup(
        lr_warmup_steps=100,
        epochs=5,
        max_seq_length=None,
+       max_time=None,
    ),
    log: LogArgs = LogArgs(),
    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
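
The new `max_time` knob enables time-boxed finetuning runs; the loop changes further down treat it as a number of seconds (the log line divides it by 60 to report minutes). A hypothetical usage sketch, assuming the field lands on `TrainArgs` exactly as added in this diff:

```python
from litgpt.args import TrainArgs

# Stop after roughly one hour of wall-clock time, or after 5 epochs,
# whichever comes first; max_time=None keeps the previous behaviour.
train = TrainArgs(epochs=5, max_seq_length=None, max_time=3600)

# CLI equivalent (assuming litgpt's jsonargparse-style flags):
#   litgpt finetune_lora <checkpoint_dir> --train.max_time 3600
```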
@@ -105,6 +106,7 @@ def setup(
        seed: The random seed to use for reproducibility.
        access_token: Optional API token to access models with restrictions.
    """
+
    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
    pprint(locals())
    data = Alpaca() if data is None else data
@@ -152,12 +154,10 @@ def setup(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
                " when using the --quantize flag."
            )
-       strategy = FSDPStrategy(
-           auto_wrap_policy={torch.nn.Linear},
-           activation_checkpointing_policy={Block},
-           state_dict_type="full",
-           limit_all_gathers=True,
-           cpu_offload=False,
+       strategy = ModelParallelStrategy(
+           parallelize_fn=parallelize_fn,
+           data_parallel_size=devices * num_nodes,
+           tensor_parallel_size=1,
        )
    else:
        strategy = "auto"
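
For context, a minimal self-contained sketch of how the replacement strategy is wired (not part of this diff): Fabric builds a device mesh with `data_parallel` and `tensor_parallel` dimensions and calls the `parallelize_fn` when the model is set up. `my_parallelize` here is a hypothetical stand-in for the `parallelize_fn` added near the end of this file; it assumes Lightning and PyTorch >= 2.4 and two visible GPUs.

```python
import lightning as L
import torch
from lightning.fabric.strategies import ModelParallelStrategy


def my_parallelize(model: torch.nn.Module, device_mesh) -> torch.nn.Module:
    from torch.distributed._composable.fsdp import fully_shard

    # FSDP2-style sharding over the data-parallel mesh dimension only.
    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


strategy = ModelParallelStrategy(
    parallelize_fn=my_parallelize,
    data_parallel_size=2,    # shard/replicate across 2 processes
    tensor_parallel_size=1,  # no tensor parallelism, matching the change above
)
fabric = L.Fabric(devices=2, strategy=strategy)
fabric.launch()
model = fabric.setup(torch.nn.Linear(128, 128))  # parallelize_fn is applied here
```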
@@ -174,7 +174,9 @@ def setup(
    if torch.cuda.is_available() and devices > 1:
        check_nvlink_connectivity(fabric)

-   fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
+   fabric.launch(
+       main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision
+   )


def main(
@@ -189,6 +191,7 @@ def main(
    eval: EvalArgs,
    optimizer: Union[str, Dict],
    num_nodes: int = 1,
+   precision: Optional[str] = None,
) -> None:
    validate_args(train, eval)
@@ -229,7 +232,6 @@ def main(
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

-   # strict=False because missing keys due to LoRA weights not contained in state dict
    load_checkpoint(fabric, model, checkpoint_path, strict=False)

    train_time = time.perf_counter()
@@ -264,12 +266,19 @@ def main(
    save_path = out_dir / "final" / "lit_model.pth.lora"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    save_lora_checkpoint(fabric, model, save_path)
+
+   fabric.barrier()
    if fabric.global_rank == 0:
        # Copy checkpoint files from original checkpoint dir
        copy_config_files(checkpoint_dir, save_path.parent)
        save_hyperparameters(setup, save_path.parent)
        save_prompt_style(data.prompt_style, save_path.parent)
-       merge_lora(checkpoint_dir=save_path.parent)
+       merge_lora(
+           checkpoint_dir=save_path.parent,
+           pretrained_checkpoint_dir=checkpoint_dir,
+           precision=precision,
+       )
+   fabric.barrier()


def fit(
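
The pair of `fabric.barrier()` calls brackets the rank-0-only post-processing: the first makes sure every rank has finished writing the LoRA checkpoint before rank 0 starts copying files and merging, and the second keeps the remaining ranks from racing ahead (or exiting) while rank 0 is still working. A minimal sketch of that pattern, assuming two GPUs and Fabric's default strategy:

```python
import lightning as L

fabric = L.Fabric(devices=2)
fabric.launch()

# ... every rank finishes writing its part of the checkpoint here ...
fabric.barrier()                  # wait until all ranks are done saving
if fabric.global_rank == 0:
    print("rank 0: copy config files and merge the LoRA weights")  # single-process work
fabric.barrier()                  # the other ranks wait here for rank 0
```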
@@ -316,6 +325,8 @@ def fit(
    total_lengths = 0
    total_t0 = time.perf_counter()

+   max_time = train.max_time or float("inf")
+
    token_counts = {
        "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
        "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
@@ -327,6 +338,12 @@ def fit(
        iter_t0 = time.perf_counter()
        batch = next(train_iterator)
        if train_iterator.epoch >= train.epochs:
+           generate_example(fabric, model, tokenizer, eval, data)
+           fabric.print(f"Number of epochs {train.epochs} reached, stopping training...")
+           break
+       if iter_t0 - total_t0 > max_time:
+           generate_example(fabric, model, tokenizer, eval, data)
+           fabric.print(f"Max time ({max_time / 60.0:.2f}m) reached, stopping training...")
            break
        input_ids, targets = batch["input_ids"], batch["labels"]
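
The stopping rule reads cleanly in isolation: `max_time=None` disables the limit via `or float("inf")`, and the elapsed wall-clock time is checked once per iteration before the step runs. A runnable toy version of the same logic, with the optimization step simulated by a short sleep:

```python
import time


def run(max_time=None, max_steps=10_000):
    budget = max_time or float("inf")   # None means "no wall-clock limit"
    start = time.perf_counter()
    for step in range(max_steps):
        now = time.perf_counter()
        if now - start > budget:
            print(f"Max time ({budget / 60.0:.2f}m) reached, stopping training...")
            break
        time.sleep(0.01)  # stand-in for one forward/backward/optimizer step


run(max_time=0.5)  # stops after roughly half a second
```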
@@ -497,9 +514,45 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
    return longest_seq_length, longest_seq_ix


+def parallelize_fn(model, device_mesh, activation_checkpointing=True):
+   from torch.distributed._composable.fsdp.fully_shard import fully_shard
+   from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper
+
+   if activation_checkpointing:
+       model.transformer.h = torch.nn.ModuleList(
+           [checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h]
+       )
+
+   dp_mesh = device_mesh["data_parallel"]
+
+   for m in reversed(list(model.modules())):
+       if (
+           (isinstance(m, torch.nn.Linear) and m.weight.requires_grad)
+           or isinstance(m, CheckpointWrapper)
+           or isinstance(m, Block)
+       ):
+           fully_shard(m, mesh=dp_mesh)
+
+   fully_shard(model, mesh=dp_mesh)
+
+   return model
+
+
def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
-   fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
-   fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
+   cpu_state_dict = {}
+   sharded_sd = model.state_dict()
+   for param_name, param in sharded_sd.items():
+       if "lora_" not in param_name:
+           continue
+       if param.is_cpu:
+           param = param.to(fabric.device)
+       if hasattr(param, "_local_tensor"):
+           param = param.full_tensor()
+       if fabric.is_global_zero:
+           cpu_state_dict[param_name] = param.cpu()
+   fabric.barrier()
+   if fabric.is_global_zero:
+       torch.save({"model": cpu_state_dict}, file_path)


def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
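
With `fully_shard`, every trainable parameter becomes a `DTensor`, which is what the `hasattr(param, "_local_tensor")` check above detects; `full_tensor()` all-gathers the shards so that rank 0 can write an ordinary tensor to disk. A minimal sketch of that gather step, assuming PyTorch >= 2.5 (for the public `torch.distributed.tensor` namespace) and a `torchrun --nproc_per_node=2` launch:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
device_type = "cuda" if torch.cuda.is_available() else "cpu"
if device_type == "cuda":
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh(device_type, (dist.get_world_size(),))

# Stand-in for one sharded LoRA matrix; in the real code fully_shard() does the sharding.
lora_a = distribute_tensor(torch.randn(1024, 16), mesh, [Shard(0)])
full = lora_a.full_tensor()                 # all-gather: every rank now holds the whole tensor
if dist.get_rank() == 0:
    torch.save({"lora_A": full.cpu()}, "lora_A.pt")  # only rank 0 writes the file
dist.barrier()
dist.destroy_process_group()
```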