
Commit f5d2b18

Add dry run mode (#2012)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* __->__ #2012
* #2011

Summary: The current configuration validation requires torchx and GPUs. It can waste time, resources, and energy. Polar bears are crying. Let's fix this by providing a dry run mode. This PR doesn't verify everything; in theory, we should be able to verify parallelism settings as well. This PR is just a start, but it at least lets us catch typos quickly.
1 parent 11d73a2 commit f5d2b18

File tree: 3 files changed (+172, -7 lines)

run_train.sh

Lines changed: 14 additions & 5 deletions
@@ -10,15 +10,24 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
+# DRY_RUN=1 ./run_train.sh # for config validation without GPU
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
+DRY_RUN=${DRY_RUN:-0}

 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

-PYTORCH_ALLOC_CONF="expandable_segments:True" \
-TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
--m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
+if [ "$DRY_RUN" = "1" ]; then
+    # Dry run mode: validate configuration without GPU/distributed setup
+    echo "Running in DRY RUN mode - configuration validation only"
+    python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
+else
+    # Normal training with torchrun
+    PYTORCH_ALLOC_CONF="expandable_segments:True" \
+    TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
+    torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+    --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+    -m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
+fi
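
With the DRY_RUN switch above, a config can be validated on a CPU-only machine with DRY_RUN=1 ./run_train.sh, or by calling the entry point directly. Below is a minimal CI-style sketch in Python, assuming it is run from the repository root; the glob over train_configs/*.toml is illustrative, not an official list of shipped configs.

# CI-style sketch (assumption: executed from the repo root on a CPU-only runner).
# It drives the same scripts/dry_run.py entry point that run_train.sh uses when DRY_RUN=1.
import glob
import subprocess

for cfg in sorted(glob.glob("torchtitan/models/llama3/train_configs/*.toml")):
    print(f"Validating {cfg}")
    subprocess.run(
        ["python", "scripts/dry_run.py", "--job.config_file", cfg],
        check=True,  # stop on the first config that fails validation
    )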

scripts/dry_run.py (new file)

Lines changed: 156 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Dry run trainer for fast configuration validation without GPU/distributed setup.

This module provides a lightweight trainer that validates job configurations,
model architecture, and dataloader setup without requiring GPU resources or
distributed initialization. Useful for rapid iteration on configuration files
and CI/CD validation pipelines.
"""

import os
import sys

# Add parent directory to path to import torchtitan
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.train import main, Trainer


class DryRunTrainer(Trainer):
    """
    A lightweight trainer that validates configurations without GPU allocation.

    This trainer performs comprehensive validation of the training configuration
    without allocating GPU resources or initializing distributed setup. It validates:

    - Configuration file parsing and structure
    - Model architecture (constructed on meta device)
    - Tokenizer initialization
    - Dataloader configuration
    - Parallelism settings
    - Model converters (if specified)

    Unlike the regular Trainer, this does not:
    - Allocate GPU memory
    - Initialize distributed process groups
    - Create optimizers or learning rate schedulers
    - Set up checkpointing or metrics
    - Run any actual training

    Args:
        job_config: JobConfig containing all training configuration parameters

    Note:
        Validation completes immediately after initialization. No training loop is executed.
        All operations use CPU and meta devices for zero-cost validation.
    """

    def __init__(self, job_config: JobConfig):
        torch._C._log_api_usage_once("torchtitan.dry_run")

        self.job_config = job_config

        logger.info(f"Starting job: {job_config.job.description}")
        logger.info("DRY RUN MODE - Configuration validation only")

        # Use CPU device (no GPU required)
        self.device = torch.device("cpu")

        # Log and validate config
        job_config.maybe_log()
        logger.info("Configuration parsed successfully")

        # Get train spec
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
        logger.info(f"Train spec loaded for model: {job_config.model.name}")

        # Build tokenizer
        self.tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )
        if self.tokenizer:
            logger.info("Tokenizer built successfully")

        # Validate model configuration
        model_args = self.train_spec.model_args[job_config.model.flavor]
        model_args.update_from_config(job_config)
        self.model_args = model_args

        logger.info(
            f"Model args validated: {job_config.model.name} {job_config.model.flavor}"
        )

        # Build model on meta device (validates architecture without memory allocation)
        logger.info("Validating model architecture...")
        with (
            torch.device("meta"),
            utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
        ):
            model = self.train_spec.model_cls(model_args)

        # Calculate and log model size
        model_param_count, _ = model_args.get_nparams_and_flops(
            model, job_config.training.seq_len
        )
        logger.info(
            f"Model architecture validated: {job_config.model.name} "
            f"with {model_param_count:,} parameters"
        )

        # Validate dataloader configuration (build with minimal params)
        logger.info("Validating dataloader configuration...")
        try:
            # Use dp_world_size=1 and dp_rank=0 for dry run
            dataloader = self.train_spec.build_dataloader_fn(
                dp_world_size=1,
                dp_rank=0,
                tokenizer=self.tokenizer,
                job_config=job_config,
            )
            logger.info("Dataloader configuration validated successfully")
        except Exception as e:
            logger.warning(f"Dataloader validation encountered issue: {e}")
            logger.info(
                "Note: Some dataloader issues may only appear with actual data paths"
            )

        # Validate model converters if specified
        if job_config.model.converters:
            logger.info(f"Model converters specified: {job_config.model.converters}")

        # Validate parallelism configuration
        parallelism_config = job_config.parallelism
        logger.info(
            f"Parallelism config: "
            f"DP-shard={parallelism_config.data_parallel_shard_degree}, "
            f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, "
            f"TP={parallelism_config.tensor_parallel_degree}, "
            f"PP={parallelism_config.pipeline_parallel_degree}, "
            f"CP={parallelism_config.context_parallel_degree}"
        )

        # Summary
        logger.info("=" * 80)
        logger.info("DRY RUN VALIDATION COMPLETE")
        logger.info("=" * 80)
        logger.info("All configurations validated successfully!")
        logger.info("Configuration is ready for training execution.")
        logger.info("=" * 80)


if __name__ == "__main__":
    main(DryRunTrainer)
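
The key mechanism above is building the model under torch.device("meta"): parameters get shapes and dtypes but no storage, so architecture and config mismatches surface without touching GPU or host memory. Below is a standalone sketch of that mechanism in plain PyTorch; the layer sizes are arbitrary and for illustration only.

# Standalone illustration of meta-device construction, independent of torchtitan.
import torch
import torch.nn as nn

with torch.device("meta"):
    # Module wiring and parameter shapes are validated, but no storage is allocated.
    model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))

n_params = sum(p.numel() for p in model.parameters())
print(model[0].weight.device)  # meta
print(f"{n_params:,} parameters, zero bytes of real memory")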

torchtitan/train.py

Lines changed: 2 additions & 2 deletions
@@ -698,9 +698,9 @@ def load_state_dict(self, state_dict: dict[str, Any]):
         self.ntokens_seen = state_dict["ntokens_seen"]

     def close(self) -> None:
-        if self.checkpointer:
+        if hasattr(self, "checkpointer") and self.checkpointer:
             self.checkpointer.close()
-        if self.metrics_processor:
+        if hasattr(self, "metrics_processor") and self.metrics_processor:
             self.metrics_processor.close()

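The hasattr guards are what let close() tolerate DryRunTrainer, which inherits Trainer.close() but never creates a checkpointer or metrics processor. A minimal sketch of the failure mode the guard prevents, using illustrative classes rather than torchtitan code:

# Illustrative only: a base class whose close() assumes attributes that a
# lightweight subclass never creates. Without the hasattr guard, close()
# would raise AttributeError for the subclass.
class BaseTrainerLike:
    def close(self) -> None:
        if hasattr(self, "checkpointer") and self.checkpointer:
            self.checkpointer.close()
        if hasattr(self, "metrics_processor") and self.metrics_processor:
            self.metrics_processor.close()


class DryRunLike(BaseTrainerLike):
    def __init__(self) -> None:
        # Deliberately skips creating checkpointer / metrics_processor,
        # mirroring DryRunTrainer's minimal __init__.
        pass


DryRunLike().close()  # safe: hasattr short-circuits before attribute access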