# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Dry run trainer for fast configuration validation without GPU/distributed setup.

This module provides a lightweight trainer that validates job configurations,
model architecture, and dataloader setup without requiring GPU resources or
distributed initialization. Useful for rapid iteration on configuration files
and CI/CD validation pipelines.
"""
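
# Example invocation (the script location and config path are illustrative;
# any torchtitan job config file should work):
#
#   python scripts/dry_run.py --job.config_file ./train_configs/debug_model.toml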

import os
import sys

# Add parent directory to path to import torchtitan
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.train import main, Trainer


class DryRunTrainer(Trainer):
    """
    A lightweight trainer that validates configurations without GPU allocation.

    This trainer performs comprehensive validation of the training configuration
    without allocating GPU resources or initializing distributed setup. It validates:

    - Configuration file parsing and structure
    - Model architecture (constructed on meta device)
    - Tokenizer initialization
    - Dataloader configuration
    - Parallelism settings
    - Model converters (if specified)

    Unlike the regular Trainer, this does not:
    - Allocate GPU memory
    - Initialize distributed process groups
    - Create optimizers or learning rate schedulers
    - Set up checkpointing or metrics
    - Run any actual training

    Args:
        job_config: JobConfig containing all training configuration parameters

    Note:
        Validation completes immediately after initialization. No training loop
        is executed. All operations use CPU and meta devices for zero-cost
        validation.
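
    Example:
        A minimal sketch; assumes ``ConfigManager`` from ``torchtitan.config``,
        and the exact config-construction API may differ::

            from torchtitan.config import ConfigManager

            config = ConfigManager().parse_args(
                ["--job.config_file", "./train_configs/debug_model.toml"]
            )
            DryRunTrainer(config)  # validation runs during construction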
    """

    def __init__(self, job_config: JobConfig):
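        # Record that the dry-run entry point was used, via PyTorch's internal
        # API-usage telemetry hook.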
        torch._C._log_api_usage_once("torchtitan.dry_run")

        self.job_config = job_config

        logger.info(f"Starting job: {job_config.job.description}")
        logger.info("DRY RUN MODE - Configuration validation only")

        # Use CPU device (no GPU required)
        self.device = torch.device("cpu")

        # Log and validate config
        job_config.maybe_log()
        logger.info("Configuration parsed successfully")

        # Get train spec
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
        logger.info(f"Train spec loaded for model: {job_config.model.name}")

        # Build tokenizer
        self.tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )
        if self.tokenizer:
            logger.info("Tokenizer built successfully")

        # Validate model configuration
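        # The flavor key (e.g., "debugmodel" or "8B") selects a predefined set
        # of model args; update_from_config then applies job-config overrides
        # on top of those defaults (our reading of the API; details may vary).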
        model_args = self.train_spec.model_args[job_config.model.flavor]
        model_args.update_from_config(job_config)
        self.model_args = model_args

        logger.info(
            f"Model args validated: {job_config.model.name} {job_config.model.flavor}"
        )

        # Build model on meta device (validates architecture without memory allocation)
        logger.info("Validating model architecture...")
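        # Tensors created under torch.device("meta") hold only shape and dtype
        # metadata, so even very large models can be instantiated here without
        # allocating any parameter storage.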
        with (
            torch.device("meta"),
            utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
        ):
            model = self.train_spec.model_cls(model_args)

        # Calculate and log model size
        model_param_count, _ = model_args.get_nparams_and_flops(
            model, job_config.training.seq_len
        )
        logger.info(
            f"Model architecture validated: {job_config.model.name} "
            f"with {model_param_count:,} parameters"
        )

        # Validate dataloader configuration (build with minimal params)
        logger.info("Validating dataloader configuration...")
        try:
            # Use dp_world_size=1 and dp_rank=0 for the dry run; the dataloader
            # is built only to validate the config, then discarded.
            self.train_spec.build_dataloader_fn(
                dp_world_size=1,
                dp_rank=0,
                tokenizer=self.tokenizer,
                job_config=job_config,
            )
            logger.info("Dataloader configuration validated successfully")
        except Exception as e:
            logger.warning(f"Dataloader validation encountered an issue: {e}")
            logger.info(
                "Note: some dataloader issues may only appear with actual data paths"
            )

        # Validate model converters if specified
        if job_config.model.converters:
            logger.info(f"Model converters specified: {job_config.model.converters}")

        # Validate parallelism configuration
        parallelism_config = job_config.parallelism
        logger.info(
            f"Parallelism config: "
            f"DP-shard={parallelism_config.data_parallel_shard_degree}, "
            f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, "
            f"TP={parallelism_config.tensor_parallel_degree}, "
            f"PP={parallelism_config.pipeline_parallel_degree}, "
            f"CP={parallelism_config.context_parallel_degree}"
        )

        # Summary
        logger.info("=" * 80)
        logger.info("DRY RUN VALIDATION COMPLETE")
        logger.info("=" * 80)
        logger.info("All configurations validated successfully!")
        logger.info("Configuration is ready for training execution.")
        logger.info("=" * 80)
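
    def train(self) -> None:
        # Nothing to do for a dry run: all validation happens in __init__.
        # Overriding train() makes the "no training loop" guarantee explicit,
        # in case the shared entry point invokes it (an assumption on our part;
        # the override is harmless either way).
        pass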
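
# Entry point: torchtitan's shared main() parses the CLI/TOML job config and
# drives the supplied trainer class; with DryRunTrainer, constructing the
# trainer performs all of the validation.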
if __name__ == "__main__":
    main(DryRunTrainer)