From db35980088e5f76f5712adfc11fc424b98099322 Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Mon, 27 Oct 2025 11:51:25 -0700 Subject: [PATCH 01/16] Add multi-dataset evaluation support - Add eval_utils.py with run_evaluation() function for multi-dataset evaluation - Update main.py to support multi-dataset configuration and evaluation - Add validation config settings (enabled, eval_interval, eval_steps) - Refactor setup() to support dataset_val.datasets structure - Add unified forward() method with compute_gradients flag - Add evaluate() method that calls run_evaluation() - Update llama3_8b.yaml with multi-dataset configuration --- apps/sft/eval_utils.py | 326 ++++++++++++++++++++++++++++++++++++++++ apps/sft/llama3_8b.yaml | 13 ++ apps/sft/main.py | 148 +++++++++++++----- 3 files changed, 452 insertions(+), 35 deletions(-) create mode 100644 apps/sft/eval_utils.py diff --git a/apps/sft/eval_utils.py b/apps/sft/eval_utils.py new file mode 100644 index 000000000..c3421d086 --- /dev/null +++ b/apps/sft/eval_utils.py @@ -0,0 +1,326 @@ +"""Utility functions for evaluation to make main.py more concise.""" + +import logging +from typing import Any, Callable, Iterator + +import torch +from torch import nn + +logger = logging.getLogger(__name__) + + +def move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]: + """Move all tensors in batch to specified device. + + Args: + batch: Dictionary containing batch data + device: Target device + + Returns: + Batch with tensors moved to device (modifies in-place and returns) + """ + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + return batch + + +def extract_epoch_from_batch(batch: dict) -> int | None: + """Extract epoch number from batch metrics. + + Args: + batch: Batch dictionary with 'metrics' field + + Returns: + Epoch number from metrics, or None if not found + """ + if "metrics" in batch: + for metric in batch["metrics"]: + if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs": + return metric.value + return None + + +def start_epoch_sync( + epoch_increment: int, + device: torch.device, + dp_process_group: Any = None, +) -> tuple[torch.Tensor | None, Any]: + """Start async all_reduce for epoch synchronization across ranks. + + Args: + epoch_increment: Difference between current and starting epoch + device: Device for tensor + dp_process_group: Data parallel process group (None = default group) + + Returns: + Tuple of (epoch_tensor, pending_work) for async operation, or (None, None) if not initialized + """ + if not torch.distributed.is_initialized(): + return None, None + + epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long, device=device) + pending_work = torch.distributed.all_reduce( + epoch_tensor, + op=torch.distributed.ReduceOp.MAX, + group=dp_process_group, + async_op=True, + ) + return epoch_tensor, pending_work + + +def check_epoch_complete( + pending_work: Any, + epoch_tensor: torch.Tensor | None, +) -> bool: + """Wait for async epoch sync and check if epoch completed. 
+ + Args: + pending_work: Pending async all_reduce work + epoch_tensor: Tensor containing epoch increment + + Returns: + True if any rank completed an epoch, False otherwise + """ + if pending_work is None: + return False + + pending_work.wait() + if epoch_tensor is not None: + return bool((epoch_tensor > 0).any().item()) + return False + + +def eval_loop( + dataloader_iter: Iterator, + forward_fn: Callable[[dict, torch.Tensor], torch.Tensor], + device: torch.device, + eval_steps: int, + dataset_name: str, + dp_process_group: Any = None, + extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch, + log_interval: int = 10, +) -> tuple[float, int]: + """Run evaluation loop with epoch synchronization. + + Args: + dataloader_iter: Iterator over validation data + forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss tensor + device: Device for computation + eval_steps: Maximum number of eval steps (0 = no limit) + dataset_name: Name for logging + dp_process_group: Data parallel process group for epoch sync + extract_epoch_fn: Function to extract epoch from batch + log_interval: Log every N batches + + Returns: + Tuple of (avg_loss, num_batches) + """ + total_loss = torch.tensor(0.0, device=device) + num_batches, starting_epoch = 0, None + + # Prefetch first batch + next_batch = next(dataloader_iter) + should_break, pending_work, epoch_tensor = False, None, None + + with torch.no_grad(): + while True: + # Check if previous epoch sync completed + if pending_work is not None: + should_break = check_epoch_complete(pending_work, epoch_tensor) + pending_work = None + + if should_break: + logger.info( + f"[{dataset_name}] Epoch completed across all ranks - stopping evaluation" + ) + break + + if eval_steps > 0 and num_batches >= eval_steps: + logger.info(f"[{dataset_name}] Reached eval_steps cap of {eval_steps}") + break + + batch = next_batch + + # Track starting epoch + current_epoch = extract_epoch_fn(batch) + if starting_epoch is None: + starting_epoch = current_epoch + + # Prefetch next batch and start async epoch check + try: + next_batch = next(dataloader_iter) + next_epoch = extract_epoch_fn(next_batch) + + # Only check epochs if both are available + if next_epoch is not None and starting_epoch is not None: + epoch_increment = next_epoch - starting_epoch + if torch.distributed.is_initialized(): + epoch_tensor, pending_work = start_epoch_sync( + epoch_increment, device, dp_process_group + ) + else: + should_break = epoch_increment > 0 + except StopIteration: + should_break = True + + # Process current batch (overlaps with async all_reduce) + move_batch_to_device(batch, device) + labels = batch.pop("labels") + loss = forward_fn(batch, labels) + total_loss += loss + num_batches += 1 + + if num_batches % log_interval == 0: + logger.info( + f" [{dataset_name}] Eval batch {num_batches} | Loss: {loss:.4f}" + ) + + avg_loss = (total_loss / max(num_batches, 1)).item() + logger.info( + f"[{dataset_name}] COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}" + ) + + return avg_loss, num_batches + + +async def evaluate_single_dataset( + val_dataloader: Any, + dataset_name: str, + forward_fn: Callable[[dict, torch.Tensor], torch.Tensor], + device: torch.device, + eval_steps: int, + dp_process_group: Any = None, + extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch, +) -> dict[str, float]: + """Evaluate on a single validation dataset with epoch synchronization. 
+ + Args: + val_dataloader: DataLoader for this validation dataset + dataset_name: Name of the dataset (for logging) + forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss + device: Device for computation + eval_steps: Maximum number of eval steps + dp_process_group: Data parallel process group + extract_epoch_fn: Function to extract epoch from batch + + Returns: + Dict with metrics: {"val_loss": float, "val_batches": int} + """ + avg_loss, num_batches = eval_loop( + dataloader_iter=iter(val_dataloader), + forward_fn=forward_fn, + device=device, + eval_steps=eval_steps, + dataset_name=dataset_name, + dp_process_group=dp_process_group, + extract_epoch_fn=extract_epoch_fn, + log_interval=10, + ) + + return {"val_loss": avg_loss, "val_batches": num_batches} + + +async def run_evaluation( + val_dataloaders: dict[str, Any], + model_parts: list[nn.Module], + forward_fn: Callable[[dict, torch.Tensor], torch.Tensor], + device: torch.device, + eval_steps: int, + dp_process_group: Any = None, + extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch, +) -> dict[str, dict[str, float]]: + """Run evaluation on multiple validation datasets. + + Evaluates on all configured validation datasets and returns per-dataset metrics. + Sets models to eval mode before evaluation and back to train mode after. + + Args: + val_dataloaders: Dict mapping dataset names to dataloaders + model_parts: List of model parts (for setting eval/train mode) + forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss + device: Device for computation + eval_steps: Maximum number of eval steps per dataset + dp_process_group: Data parallel process group + extract_epoch_fn: Function to extract epoch from batch + + Returns: + Dict mapping dataset name to metrics dict, e.g.: + { + "val_in_domain": {"val_loss": 2.5, "val_batches": 100}, + "val_out_domain": {"val_loss": 3.1, "val_batches": 100} + } + """ + logger.info("=" * 50) + logger.info("STARTING EVALUATION") + logger.info("=" * 50) + + # Set models to eval mode + for model_part in model_parts: + model_part.eval() + + all_metrics = {} + + # Evaluate on each dataset + for dataset_name, val_dataloader in val_dataloaders.items(): + logger.info(f"\n{'='*50}") + logger.info(f"Evaluating on dataset: {dataset_name}") + logger.info(f"{'='*50}") + + dataset_metrics = await evaluate_single_dataset( + val_dataloader=val_dataloader, + dataset_name=dataset_name, + forward_fn=forward_fn, + device=device, + eval_steps=eval_steps, + dp_process_group=dp_process_group, + extract_epoch_fn=extract_epoch_fn, + ) + all_metrics[dataset_name] = dataset_metrics + + # Set models back to train mode + for model_part in model_parts: + model_part.train() + + logger.info("\n" + "=" * 50) + logger.info("EVALUATION COMPLETE - Summary:") + for dataset_name, metrics in all_metrics.items(): + logger.info( + f" {dataset_name}: Loss={metrics['val_loss']:.4f}, Batches={metrics['val_batches']}" + ) + logger.info("=" * 50) + + return all_metrics + + +def get_dp_process_group(parallel_dims: Any) -> Any: + """Get the Data Parallel process group for epoch synchronization. + + Returns the DP process group if DP parallelism is enabled, otherwise None. + This ensures all_reduce only happens across ranks with different data. 
+ + Args: + parallel_dims: ParallelDims object containing parallel configuration + + Returns: + DP process group or None if not available/needed + """ + if not torch.distributed.is_initialized(): + return None + + if parallel_dims is None: + return None + + # Check if DP is enabled + if not parallel_dims.dp_enabled: + # No DP parallelism, use default group (all ranks) + return None + + try: + # Get the "dp" submesh which contains only DP dimensions (dp_replicate + dp_shard) + # This excludes TP and PP ranks which should already be synchronized + dp_mesh = parallel_dims.world_mesh.get_group("dp") + return dp_mesh + except Exception as e: + logger.warning(f"Could not get DP process group, using default: {e}") + return None diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index e9ddc625a..b21d14196 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -26,6 +26,18 @@ optimizer: lr_scheduler: warmup_steps: 200 +# Unified dataset configuration +# First dataset with split='train' is used for training +dataset_val: + datasets: + - name: "train" + path: "yahma/alpaca-cleaned" + split: "train[:95%]" + + - name: "val" + path: "yahma/alpaca-cleaned" + split: "train[95%:]" + training: local_batch_size: 1 seq_len: 2048 @@ -62,6 +74,7 @@ metric_logging: group: sft_exp_${oc.env:USER} logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + # profiling: # enable_profiling: false diff --git a/apps/sft/main.py b/apps/sft/main.py index edda0b49d..41a8eec0b 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -22,6 +22,7 @@ import torch import torchtitan.experiments.forge.train_spec as forge_train_spec +from apps.sft.eval_utils import get_dp_process_group, run_evaluation from forge.controller import ForgeActor from forge.data.collate import collate_packed from forge.data.datasets.packed import PackedDataset, TextPacker @@ -81,6 +82,18 @@ def __init__(self, config: DictConfig): self.gradient_accumulation_steps = 1 # Example value, adjust as needed self._rank = current_rank().rank self._size = math.prod(current_size().values()) + + # Evaluation settings - no defaults, must be explicit in config + validation_config = job_config.get("validation") + if validation_config is not None: + self.validation_enabled = validation_config.get("enabled") + self.eval_interval = validation_config.get("eval_interval") + self.eval_steps = validation_config.get("eval_steps") + else: + self.validation_enabled = False + self.eval_interval = None + self.eval_steps = None + self._init_dist() super().__init__(job_config) @@ -122,27 +135,54 @@ def record_batch_metrics(self, data_metrics: list): @endpoint async def setup(self): - self.train_dataloader = self.setup_data() - self.mlogger = await self.setup_metric_logger() - - # self.train_dataloader = self.setup_data( - # self.train_config.train_dataset_config, - # self.train_config.train_dataloader_config, - # self.train_config.packing_config, - # ) - # self.val_dataloader = self.setup_data( - # self.train_config.val_dataset_config, - # self.train_config.val_dataloader_config, - # self.train_config.packing_config, - # ) - - # TODO: confirm that this is working properly - # Should also use load, not dcp_load + # Always expect dataset_val.datasets configuration + dataset_val_config = self.job_config.get("dataset_val") + + datasets = dataset_val_config["datasets"] + + # Setup all datasets + self.val_dataloaders = {} + self.train_dataloader = None + + for i, dataset_spec in enumerate(datasets): + dataset_name = 
dataset_spec.get("name") + dataset_path = dataset_spec.get("path") + dataset_split = dataset_spec.get("split") + + if not dataset_name or not dataset_path or not dataset_split: + raise ValueError( + f"Each dataset must have 'name', 'path', and 'split'. " + f"Got dataset[{i}]: {dataset_spec}" + ) + + dataloader = self.setup_data( + dataset_path=dataset_path, + dataset_split=dataset_split, + ) + + # First dataset with split starting with 'train' is used for training + if i == 0 and dataset_split.startswith("train"): + self.train_dataloader = dataloader + logger.info( + f"Setup training dataset: {dataset_name} (split={dataset_split})" + ) + + # All datasets can be used for validation + self.val_dataloaders[dataset_name] = dataloader + logger.info(f"Setup dataset: {dataset_name} (split={dataset_split})") + + # If validation disabled, clear validation dataloaders (but keep training) + if not self.validation_enabled: + self.val_dataloaders = None + logger.info("Validation disabled - only training dataloader will be used") + + # Load checkpoint if resuming self.checkpointer.load(step=self.current_step) - # self.profiler = self.setup_profiler(self.train_config.profiler_config) - # self.logger = self.setup_logger(self.train_config.logger_config) - def setup_data(self): + def setup_data( + self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train" + ): + """Setup data with configurable dataset path and split.""" print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json")) tokenizer = HuggingFaceModelTokenizer( tokenizer_json_path=os.path.join( @@ -159,8 +199,8 @@ def setup_data(self): dataset = sft_iterable_dataset( model_transform=tokenizer, message_transform=AlpacaToMessages(), - path="yahma/alpaca-cleaned", - split="train", + path=dataset_path, + split=dataset_split, ) packer = TextPacker(padding_idx=0) dataset = PackedDataset( @@ -174,17 +214,28 @@ def setup_data(self): collate_fn=partial( collate_packed, mask_fn=packer.create_block_mask, device=self.device ), + drop_last=True, # Ensure consistent batch sizes across DP ranks ) - # Ultimately we probably want something like this - # packer = build_packing_strategy(packing_config) - # dataset = build_dataset(dataset_config) - # dataloader = build_dataloader(dataloader_config, dataset, packer) return dataloader - def forward_backward( - self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + def forward( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + compute_gradients: bool = True, ) -> torch.Tensor: + """Forward pass with optional gradient computation. + + Args: + input_dict: Input dictionary containing tokens + labels: Target labels + compute_gradients: If True, compute gradients (training mode). + If False, skip backward pass (evaluation mode). 
+ + Returns: + Loss tensor + """ model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -204,7 +255,7 @@ def forward_backward( ) if parallel_dims.pp_enabled: - # Pipeline Parallel forward / backward inside step() call + # Pipeline Parallel forward (with optional backward) with self.train_context(optional_context_parallel_ctx): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) @@ -226,7 +277,7 @@ def forward_backward( else torch.tensor([-1.0], device=self.device) ) else: - # Non-PP forward / backward + # Non-PP forward (with optional backward) with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -234,7 +285,8 @@ def forward_backward( loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - loss.backward() + if compute_gradients: + loss.backward() return loss @@ -246,7 +298,7 @@ def train_step(self, batch) -> None: # self.data_parallel_size, # ) as grad_acc: labels = batch.pop("labels") - loss = self.forward_backward(batch, labels) + loss = self.forward(batch, labels, compute_gradients=True) loss = loss.item() record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) @@ -256,6 +308,32 @@ def train_step(self, batch) -> None: self.optimizers.step() self.lr_schedulers.step() + async def evaluate(self) -> dict[str, dict[str, float]]: + """Run evaluation with async all_reduce for cross-rank epoch synchronization. + + Evaluates on all configured validation datasets and returns per-dataset metrics. + + Returns: + Dict mapping dataset name to metrics dict, e.g.: + { + "val_in_domain": {"val_loss": 2.5, "val_batches": 100}, + "val_out_domain": {"val_loss": 3.1, "val_batches": 100} + } + """ + + # Create a wrapper that calls forward with compute_gradients=False + def forward_eval(input_dict, labels): + return self.forward(input_dict, labels, compute_gradients=False) + + return await run_evaluation( + val_dataloaders=self.val_dataloaders, + model_parts=self.model_parts, + forward_fn=forward_eval, + device=self.device, + eval_steps=self.eval_steps, + dp_process_group=get_dp_process_group(self.parallel_dims), + ) + @endpoint async def train(self) -> None: dataloader = iter(self.train_dataloader) @@ -280,10 +358,10 @@ async def train(self) -> None: # self.profiler.step() self.current_step += 1 - # Flush metrics - if self._rank == 0: - logger.debug(f"Flushing metrics at step {self.current_step}") - await self.mlogger.flush.call_one(global_step=self.current_step) + # Run evaluation periodically if enabled + if self.validation_enabled and self.current_step % self.eval_interval == 0: + eval_metrics = await self.evaluate() + logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}") self.checkpointer.save( curr_step=self.current_step, From 801a4544ca11333f989f33a697c011d23ce1d0a6 Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Tue, 28 Oct 2025 11:56:04 -0700 Subject: [PATCH 02/16] Simplify epoch tracking and fix metric extraction - Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name' - Simplify epoch tracking: compare consecutive batches instead of tracking from start - Remove starting_epoch variable - no longer needed - Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment - Add better logging for epoch changes and tracking status - Epoch sync now works correctly with the actual metric structure --- apps/sft/eval_utils.py | 53 +++++++++++++++++++++++++++++++----------- 1 file 
changed, 39 insertions(+), 14 deletions(-) diff --git a/apps/sft/eval_utils.py b/apps/sft/eval_utils.py index c3421d086..0ce6dab11 100644 --- a/apps/sft/eval_utils.py +++ b/apps/sft/eval_utils.py @@ -35,21 +35,30 @@ def extract_epoch_from_batch(batch: dict) -> int | None: Epoch number from metrics, or None if not found """ if "metrics" in batch: + # Look for num_epochs in metric keys + for metric in batch["metrics"]: + # Metrics have a 'key' attribute with paths like: + # 'dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs' + if hasattr(metric, "key") and "num_epochs" in metric.key: + return int(metric.value) + + # Fallback: check for old-style metric_name attribute for metric in batch["metrics"]: if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs": - return metric.value + return int(metric.value) + return None def start_epoch_sync( - epoch_increment: int, + epoch_changed: bool, device: torch.device, dp_process_group: Any = None, ) -> tuple[torch.Tensor | None, Any]: """Start async all_reduce for epoch synchronization across ranks. Args: - epoch_increment: Difference between current and starting epoch + epoch_changed: Whether the epoch changed on this rank device: Device for tensor dp_process_group: Data parallel process group (None = default group) @@ -59,7 +68,8 @@ def start_epoch_sync( if not torch.distributed.is_initialized(): return None, None - epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long, device=device) + # Convert bool to tensor: 1 if epoch changed, 0 otherwise + epoch_tensor = torch.tensor([int(epoch_changed)], dtype=torch.long, device=device) pending_work = torch.distributed.all_reduce( epoch_tensor, op=torch.distributed.ReduceOp.MAX, @@ -117,7 +127,7 @@ def eval_loop( Tuple of (avg_loss, num_batches) """ total_loss = torch.tensor(0.0, device=device) - num_batches, starting_epoch = 0, None + num_batches = 0 # Prefetch first batch next_batch = next(dataloader_iter) @@ -142,26 +152,41 @@ def eval_loop( batch = next_batch - # Track starting epoch + # Get current batch epoch current_epoch = extract_epoch_fn(batch) - if starting_epoch is None: - starting_epoch = current_epoch - # Prefetch next batch and start async epoch check + # Prefetch next batch and check for epoch change try: next_batch = next(dataloader_iter) next_epoch = extract_epoch_fn(next_batch) - # Only check epochs if both are available - if next_epoch is not None and starting_epoch is not None: - epoch_increment = next_epoch - starting_epoch + # Simple check: did epoch change between consecutive batches? 
+ if next_epoch is not None and current_epoch is not None: + epoch_changed = next_epoch > current_epoch + + if epoch_changed: + logger.info( + f"[{dataset_name}] Epoch change detected: " + f"{current_epoch} → {next_epoch}" + ) + if torch.distributed.is_initialized(): + # All-reduce: if ANY rank's epoch changed, all ranks should stop epoch_tensor, pending_work = start_epoch_sync( - epoch_increment, device, dp_process_group + epoch_changed, device, dp_process_group ) else: - should_break = epoch_increment > 0 + # Single process: stop immediately if epoch changed + should_break = epoch_changed + else: + # No epoch tracking available - rely on eval_steps + if num_batches == 0: + logger.info( + f"[{dataset_name}] No epoch tracking available " + f"(current={current_epoch}, next={next_epoch})" + ) except StopIteration: + logger.info(f"[{dataset_name}] StopIteration - dataloader exhausted") should_break = True # Process current batch (overlaps with async all_reduce) From ebd4ac1d0d42dab2773559c0649099fe2b377629 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 6 Nov 2025 13:01:01 -0800 Subject: [PATCH 03/16] first commit --- apps/sft/eval_utils.py | 351 ------------------ apps/sft/llama3_8b.yaml | 28 +- apps/sft/main.py | 256 +++++++++---- src/forge/data/datasets/hf_dataset.py | 22 +- src/forge/data/datasets/packed.py | 19 +- src/forge/data/datasets/sft_dataset.py | 3 + src/forge/data/utils.py | 129 ++++++- tests/unit_tests/datasets/test_hf.py | 123 ++++++ tests/unit_tests/datasets/test_packed.py | 46 +++ .../datasets/test_stop_after_one_epoch.py | 163 ++++++++ 10 files changed, 677 insertions(+), 463 deletions(-) delete mode 100644 apps/sft/eval_utils.py create mode 100644 tests/unit_tests/datasets/test_stop_after_one_epoch.py diff --git a/apps/sft/eval_utils.py b/apps/sft/eval_utils.py deleted file mode 100644 index 0ce6dab11..000000000 --- a/apps/sft/eval_utils.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Utility functions for evaluation to make main.py more concise.""" - -import logging -from typing import Any, Callable, Iterator - -import torch -from torch import nn - -logger = logging.getLogger(__name__) - - -def move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]: - """Move all tensors in batch to specified device. - - Args: - batch: Dictionary containing batch data - device: Target device - - Returns: - Batch with tensors moved to device (modifies in-place and returns) - """ - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.to(device) - return batch - - -def extract_epoch_from_batch(batch: dict) -> int | None: - """Extract epoch number from batch metrics. - - Args: - batch: Batch dictionary with 'metrics' field - - Returns: - Epoch number from metrics, or None if not found - """ - if "metrics" in batch: - # Look for num_epochs in metric keys - for metric in batch["metrics"]: - # Metrics have a 'key' attribute with paths like: - # 'dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs' - if hasattr(metric, "key") and "num_epochs" in metric.key: - return int(metric.value) - - # Fallback: check for old-style metric_name attribute - for metric in batch["metrics"]: - if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs": - return int(metric.value) - - return None - - -def start_epoch_sync( - epoch_changed: bool, - device: torch.device, - dp_process_group: Any = None, -) -> tuple[torch.Tensor | None, Any]: - """Start async all_reduce for epoch synchronization across ranks. 
- - Args: - epoch_changed: Whether the epoch changed on this rank - device: Device for tensor - dp_process_group: Data parallel process group (None = default group) - - Returns: - Tuple of (epoch_tensor, pending_work) for async operation, or (None, None) if not initialized - """ - if not torch.distributed.is_initialized(): - return None, None - - # Convert bool to tensor: 1 if epoch changed, 0 otherwise - epoch_tensor = torch.tensor([int(epoch_changed)], dtype=torch.long, device=device) - pending_work = torch.distributed.all_reduce( - epoch_tensor, - op=torch.distributed.ReduceOp.MAX, - group=dp_process_group, - async_op=True, - ) - return epoch_tensor, pending_work - - -def check_epoch_complete( - pending_work: Any, - epoch_tensor: torch.Tensor | None, -) -> bool: - """Wait for async epoch sync and check if epoch completed. - - Args: - pending_work: Pending async all_reduce work - epoch_tensor: Tensor containing epoch increment - - Returns: - True if any rank completed an epoch, False otherwise - """ - if pending_work is None: - return False - - pending_work.wait() - if epoch_tensor is not None: - return bool((epoch_tensor > 0).any().item()) - return False - - -def eval_loop( - dataloader_iter: Iterator, - forward_fn: Callable[[dict, torch.Tensor], torch.Tensor], - device: torch.device, - eval_steps: int, - dataset_name: str, - dp_process_group: Any = None, - extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch, - log_interval: int = 10, -) -> tuple[float, int]: - """Run evaluation loop with epoch synchronization. - - Args: - dataloader_iter: Iterator over validation data - forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss tensor - device: Device for computation - eval_steps: Maximum number of eval steps (0 = no limit) - dataset_name: Name for logging - dp_process_group: Data parallel process group for epoch sync - extract_epoch_fn: Function to extract epoch from batch - log_interval: Log every N batches - - Returns: - Tuple of (avg_loss, num_batches) - """ - total_loss = torch.tensor(0.0, device=device) - num_batches = 0 - - # Prefetch first batch - next_batch = next(dataloader_iter) - should_break, pending_work, epoch_tensor = False, None, None - - with torch.no_grad(): - while True: - # Check if previous epoch sync completed - if pending_work is not None: - should_break = check_epoch_complete(pending_work, epoch_tensor) - pending_work = None - - if should_break: - logger.info( - f"[{dataset_name}] Epoch completed across all ranks - stopping evaluation" - ) - break - - if eval_steps > 0 and num_batches >= eval_steps: - logger.info(f"[{dataset_name}] Reached eval_steps cap of {eval_steps}") - break - - batch = next_batch - - # Get current batch epoch - current_epoch = extract_epoch_fn(batch) - - # Prefetch next batch and check for epoch change - try: - next_batch = next(dataloader_iter) - next_epoch = extract_epoch_fn(next_batch) - - # Simple check: did epoch change between consecutive batches? 
- if next_epoch is not None and current_epoch is not None: - epoch_changed = next_epoch > current_epoch - - if epoch_changed: - logger.info( - f"[{dataset_name}] Epoch change detected: " - f"{current_epoch} → {next_epoch}" - ) - - if torch.distributed.is_initialized(): - # All-reduce: if ANY rank's epoch changed, all ranks should stop - epoch_tensor, pending_work = start_epoch_sync( - epoch_changed, device, dp_process_group - ) - else: - # Single process: stop immediately if epoch changed - should_break = epoch_changed - else: - # No epoch tracking available - rely on eval_steps - if num_batches == 0: - logger.info( - f"[{dataset_name}] No epoch tracking available " - f"(current={current_epoch}, next={next_epoch})" - ) - except StopIteration: - logger.info(f"[{dataset_name}] StopIteration - dataloader exhausted") - should_break = True - - # Process current batch (overlaps with async all_reduce) - move_batch_to_device(batch, device) - labels = batch.pop("labels") - loss = forward_fn(batch, labels) - total_loss += loss - num_batches += 1 - - if num_batches % log_interval == 0: - logger.info( - f" [{dataset_name}] Eval batch {num_batches} | Loss: {loss:.4f}" - ) - - avg_loss = (total_loss / max(num_batches, 1)).item() - logger.info( - f"[{dataset_name}] COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}" - ) - - return avg_loss, num_batches - - -async def evaluate_single_dataset( - val_dataloader: Any, - dataset_name: str, - forward_fn: Callable[[dict, torch.Tensor], torch.Tensor], - device: torch.device, - eval_steps: int, - dp_process_group: Any = None, - extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch, -) -> dict[str, float]: - """Evaluate on a single validation dataset with epoch synchronization. - - Args: - val_dataloader: DataLoader for this validation dataset - dataset_name: Name of the dataset (for logging) - forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss - device: Device for computation - eval_steps: Maximum number of eval steps - dp_process_group: Data parallel process group - extract_epoch_fn: Function to extract epoch from batch - - Returns: - Dict with metrics: {"val_loss": float, "val_batches": int} - """ - avg_loss, num_batches = eval_loop( - dataloader_iter=iter(val_dataloader), - forward_fn=forward_fn, - device=device, - eval_steps=eval_steps, - dataset_name=dataset_name, - dp_process_group=dp_process_group, - extract_epoch_fn=extract_epoch_fn, - log_interval=10, - ) - - return {"val_loss": avg_loss, "val_batches": num_batches} - - -async def run_evaluation( - val_dataloaders: dict[str, Any], - model_parts: list[nn.Module], - forward_fn: Callable[[dict, torch.Tensor], torch.Tensor], - device: torch.device, - eval_steps: int, - dp_process_group: Any = None, - extract_epoch_fn: Callable[[dict], int | None] = extract_epoch_from_batch, -) -> dict[str, dict[str, float]]: - """Run evaluation on multiple validation datasets. - - Evaluates on all configured validation datasets and returns per-dataset metrics. - Sets models to eval mode before evaluation and back to train mode after. 
- - Args: - val_dataloaders: Dict mapping dataset names to dataloaders - model_parts: List of model parts (for setting eval/train mode) - forward_fn: Function that takes (batch_dict, labels_tensor) and returns loss - device: Device for computation - eval_steps: Maximum number of eval steps per dataset - dp_process_group: Data parallel process group - extract_epoch_fn: Function to extract epoch from batch - - Returns: - Dict mapping dataset name to metrics dict, e.g.: - { - "val_in_domain": {"val_loss": 2.5, "val_batches": 100}, - "val_out_domain": {"val_loss": 3.1, "val_batches": 100} - } - """ - logger.info("=" * 50) - logger.info("STARTING EVALUATION") - logger.info("=" * 50) - - # Set models to eval mode - for model_part in model_parts: - model_part.eval() - - all_metrics = {} - - # Evaluate on each dataset - for dataset_name, val_dataloader in val_dataloaders.items(): - logger.info(f"\n{'='*50}") - logger.info(f"Evaluating on dataset: {dataset_name}") - logger.info(f"{'='*50}") - - dataset_metrics = await evaluate_single_dataset( - val_dataloader=val_dataloader, - dataset_name=dataset_name, - forward_fn=forward_fn, - device=device, - eval_steps=eval_steps, - dp_process_group=dp_process_group, - extract_epoch_fn=extract_epoch_fn, - ) - all_metrics[dataset_name] = dataset_metrics - - # Set models back to train mode - for model_part in model_parts: - model_part.train() - - logger.info("\n" + "=" * 50) - logger.info("EVALUATION COMPLETE - Summary:") - for dataset_name, metrics in all_metrics.items(): - logger.info( - f" {dataset_name}: Loss={metrics['val_loss']:.4f}, Batches={metrics['val_batches']}" - ) - logger.info("=" * 50) - - return all_metrics - - -def get_dp_process_group(parallel_dims: Any) -> Any: - """Get the Data Parallel process group for epoch synchronization. - - Returns the DP process group if DP parallelism is enabled, otherwise None. - This ensures all_reduce only happens across ranks with different data. 
- - Args: - parallel_dims: ParallelDims object containing parallel configuration - - Returns: - DP process group or None if not available/needed - """ - if not torch.distributed.is_initialized(): - return None - - if parallel_dims is None: - return None - - # Check if DP is enabled - if not parallel_dims.dp_enabled: - # No DP parallelism, use default group (all ranks) - return None - - try: - # Get the "dp" submesh which contains only DP dimensions (dp_replicate + dp_shard) - # This excludes TP and PP ranks which should already be synchronized - dp_mesh = parallel_dims.world_mesh.get_group("dp") - return dp_mesh - except Exception as e: - logger.warning(f"Could not get DP process group, using default: {e}") - return None diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index b21d14196..960c96e6d 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -26,32 +26,30 @@ optimizer: lr_scheduler: warmup_steps: 200 -# Unified dataset configuration -# First dataset with split='train' is used for training -dataset_val: - datasets: - - name: "train" - path: "yahma/alpaca-cleaned" - split: "train[:95%]" - - - name: "val" - path: "yahma/alpaca-cleaned" - split: "train[95%:]" - training: local_batch_size: 1 seq_len: 2048 max_norm: 1.0 steps: 1000 compile: false - dataset: "c4" + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[:95%]" + +eval: + eval_every_n_steps: 5 # (null = disabled) + max_eval_steps: 0 # Max batches per eval dataset (null = run until epoch completes) + batch_size: ${training.local_batch_size} # Batch size for evaluation + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[95%:]" parallelism: data_parallel_replicate_degree: 1 data_parallel_shard_degree: -1 - tensor_parallel_degree: 1 + tensor_parallel_degree: 2 pipeline_parallel_degree: 1 - context_parallel_degree: 1 + context_parallel_degree: 2 expert_parallel_degree: 1 disable_loss_parallel: false diff --git a/apps/sft/main.py b/apps/sft/main.py index 41a8eec0b..b7a07237a 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -22,14 +22,15 @@ import torch import torchtitan.experiments.forge.train_spec as forge_train_spec -from apps.sft.eval_utils import get_dp_process_group, run_evaluation from forge.controller import ForgeActor from forge.data.collate import collate_packed from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.data.utils import StopAfterOneEpoch from forge.observability import get_or_create_metric_logger, record_metric, Reduce from forge.util.config import parse +from forge.util.logging import log_rank_zero from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf @@ -83,17 +84,6 @@ def __init__(self, config: DictConfig): self._rank = current_rank().rank self._size = math.prod(current_size().values()) - # Evaluation settings - no defaults, must be explicit in config - validation_config = job_config.get("validation") - if validation_config is not None: - self.validation_enabled = validation_config.get("enabled") - self.eval_interval = validation_config.get("eval_interval") - self.eval_steps = validation_config.get("eval_steps") - else: - self.validation_enabled = False - self.eval_interval = None - self.eval_steps = None - self._init_dist() super().__init__(job_config) @@ -135,54 +125,67 @@ def record_batch_metrics(self, data_metrics: list): @endpoint async 
def setup(self): - # Always expect dataset_val.datasets configuration - dataset_val_config = self.job_config.get("dataset_val") + """Setup datasets from config. - datasets = dataset_val_config["datasets"] + Loads training and evaluation datasets based on config structure. + """ + # Load training datasets + logger.info("Setting training datasets") + train_datasets_config = self.job_config.training.datasets + self.train_dataloader = self.setup_data(train_datasets_config) - # Setup all datasets + # Load eval config (might be None) + eval_config = self.job_config.get("eval", {}) self.val_dataloaders = {} - self.train_dataloader = None + self.validation_enabled = False + self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None) + max_eval_steps = eval_config.get("max_eval_steps", None) + self.max_eval_steps = ( + max_eval_steps if max_eval_steps and max_eval_steps > 0 else None + ) + self.validation_enabled = ( + self.eval_every_n_steps is not None and self.eval_every_n_steps > 0 + ) + if self.validation_enabled: + logger.info("Setting eval datasets") + self.eval_datasets_config = eval_config.datasets - for i, dataset_spec in enumerate(datasets): - dataset_name = dataset_spec.get("name") - dataset_path = dataset_spec.get("path") - dataset_split = dataset_spec.get("split") + for i, dataset_config in enumerate(self.eval_datasets_config): + ds_name = dataset_config.get("dataset_name", i) - if not dataset_name or not dataset_path or not dataset_split: - raise ValueError( - f"Each dataset must have 'name', 'path', and 'split'. " - f"Got dataset[{i}]: {dataset_spec}" - ) + dataloader = self.setup_data([dataset_config]) + self.val_dataloaders[ds_name] = dataloader - dataloader = self.setup_data( - dataset_path=dataset_path, - dataset_split=dataset_split, - ) + # Load checkpoint if resuming + self.checkpointer.load(step=self.current_step) - # First dataset with split starting with 'train' is used for training - if i == 0 and dataset_split.startswith("train"): - self.train_dataloader = dataloader - logger.info( - f"Setup training dataset: {dataset_name} (split={dataset_split})" - ) + def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: + """Setup data from dataset configs. - # All datasets can be used for validation - self.val_dataloaders[dataset_name] = dataloader - logger.info(f"Setup dataset: {dataset_name} (split={dataset_split})") + Currently only supports single dataset (first in list). + Multi-dataset training with InterleavedDataset is future work. - # If validation disabled, clear validation dataloaders (but keep training) - if not self.validation_enabled: - self.val_dataloaders = None - logger.info("Validation disabled - only training dataloader will be used") + Args: + dataset_configs: List of dataset config dicts with keys like 'path', 'split', etc. - # Load checkpoint if resuming - self.checkpointer.load(step=self.current_step) + Returns: + StatefulDataLoader for the dataset - def setup_data( - self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train" - ): - """Setup data with configurable dataset path and split.""" + Raises: + ValueError: If multiple datasets provided (not yet supported) + """ + # Currently only support single dataset + if len(dataset_configs) > 1: + raise ValueError( + f"Multiple training datasets not supported yet. " + f"Got {len(dataset_configs)} datasets. " + f"For dataset mixing, use InterleavedDataset (coming soon)." 
+ ) + + dataset_config = dataset_configs[0] + + # TODO: Evaluate if tokenizers should be created once and shared for every dataset + # Load tokenizer print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json")) tokenizer = HuggingFaceModelTokenizer( tokenizer_json_path=os.path.join( @@ -196,29 +199,38 @@ def setup_data( ), ) + # Store tokenizer for later use (e.g., decoding in debug logs) + self.tokenizer = tokenizer + + # Get DP mesh for data sharding + dp_mesh = None + if self.parallel_dims is not None and self.parallel_dims.dp_enabled: + dp_mesh = self.parallel_dims.world_mesh.get_group("dp") + + # Pass config directly to dataset constructor dataset = sft_iterable_dataset( model_transform=tokenizer, message_transform=AlpacaToMessages(), - path=dataset_path, - split=dataset_split, + dp_mesh=dp_mesh, + **dataset_config, # Unpack config (path, split, etc.) ) + packer = TextPacker(padding_idx=0) dataset = PackedDataset( dataset=dataset, packer=packer, - target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model + target_tokens_per_pack=self.job_config.training.seq_len, ) - dataloader = StatefulDataLoader( + + return StatefulDataLoader( dataset=dataset, batch_size=self.job_config.training.local_batch_size, collate_fn=partial( collate_packed, mask_fn=packer.create_block_mask, device=self.device ), - drop_last=True, # Ensure consistent batch sizes across DP ranks + drop_last=True, ) - return dataloader - def forward( self, input_dict: dict[str, torch.Tensor], @@ -308,31 +320,106 @@ def train_step(self, batch) -> None: self.optimizers.step() self.lr_schedulers.step() - async def evaluate(self) -> dict[str, dict[str, float]]: - """Run evaluation with async all_reduce for cross-rank epoch synchronization. + async def evaluate(self) -> None: + """Run evaluation on multiple datasets, one at a time. + + 1. Set models to eval mode + 2. For each eval dataset: + - Create fresh iterator (starts from epoch 0) + - Use StopAfterOneEpoch to iterate until epoch boundary. This utility + is necessary for infinite iterable dataset, since epoch boundaries are not known. + - Respect max_eval_steps cap if configured + - Record loss and step metrics (on dp rank only) + 3. Restore models to train mode + """ + logger.debug("==STARTING EVALUATION==") + + # Set models to eval mode + for model_part in self.model_parts: + model_part.eval() + + # Get DP process group for epoch synchronization + dp_process_group = None + if self.parallel_dims is not None and self.parallel_dims.dp_enabled: + dp_process_group = self.parallel_dims.world_mesh.get_group("dp") + + # Evaluate each dataset sequentially + for dataset_name, val_dataloader in self.val_dataloaders.items(): + logger.debug(f"=====Evaluating dataset: {dataset_name}=====") + + # Evaluation loop for this dataset + total_loss = torch.tensor(0.0, device=self.device) + num_steps = 0 + + # NOTE: Assumes batch contains batch["metrics"]["num_epochs"]: int + batch_iter = StopAfterOneEpoch( + iter(val_dataloader), # Fresh iterator from epoch 0 + self.device, + dataset_name, + dp_process_group, + ) - Evaluates on all configured validation datasets and returns per-dataset metrics. 
+ with torch.no_grad(): + for batch in batch_iter: + # Check max_eval_steps limit + if ( + self.max_eval_steps is not None + and num_steps >= self.max_eval_steps + ): + log_rank_zero( + logger, + f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}", + ) + break + + # Move tensors to device + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(self.device) + + # Process batch + labels = batch.pop("labels") + loss = self.forward(batch, labels, compute_gradients=False) + total_loss += loss + num_steps += 1 + + # Log progress (rank 0 only) + if num_steps % 100 == 0: + loss_val = loss.item() + log_rank_zero( + logger, + f" [{dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}", + ) + + # Compute average loss + avg_loss = (total_loss / max(num_steps, 1)).item() + log_rank_zero(logger, f" [{dataset_name}] avg_loss: {avg_loss:.4f}") + + # Record metrics only on DP rank 0 to avoid double counting + # record_metric aggregates across all processes via monarch + should_record = True + if dp_process_group is not None: + dp_rank = torch.distributed.get_rank(group=dp_process_group) + should_record = dp_rank == 0 + + if should_record: + record_metric( + f"ForgeSFTRecipe/evaluate/{dataset_name}_loss", + avg_loss, + Reduce.MEAN, + ) + record_metric( + f"ForgeSFTRecipe/evaluate/{dataset_name}_steps", + num_steps, + Reduce.MEAN, + ) - Returns: - Dict mapping dataset name to metrics dict, e.g.: - { - "val_in_domain": {"val_loss": 2.5, "val_batches": 100}, - "val_out_domain": {"val_loss": 3.1, "val_batches": 100} - } - """ + # Restore train mode + for model_part in self.model_parts: + model_part.train() - # Create a wrapper that calls forward with compute_gradients=False - def forward_eval(input_dict, labels): - return self.forward(input_dict, labels, compute_gradients=False) - - return await run_evaluation( - val_dataloaders=self.val_dataloaders, - model_parts=self.model_parts, - forward_fn=forward_eval, - device=self.device, - eval_steps=self.eval_steps, - dp_process_group=get_dp_process_group(self.parallel_dims), - ) + # Summary + logger.debug("==EVALUATION COMPLETE==") @endpoint async def train(self) -> None: @@ -359,9 +446,11 @@ async def train(self) -> None: self.current_step += 1 # Run evaluation periodically if enabled - if self.validation_enabled and self.current_step % self.eval_interval == 0: - eval_metrics = await self.evaluate() - logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}") + if ( + self.validation_enabled + and self.current_step % self.eval_every_n_steps == 0 + ): + await self.evaluate() self.checkpointer.save( curr_step=self.current_step, @@ -370,6 +459,11 @@ async def train(self) -> None: # self.pbar.close() + # Run final evaluation at end of training + if self.validation_enabled: + logger.info("Running final evaluation at end of training...") + await self.evaluate() + @endpoint async def cleanup(self) -> None: if self.checkpointer: diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index d7b36fe68..8399d68e5 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -70,6 +70,7 @@ def __init__( dataset_name: str | None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, + dp_mesh: Any = None, **load_dataset_kwargs, ): # Store configuration @@ -79,6 +80,8 @@ def __init__( self._model_transform = model_transform self._output_transform = output_transform self._weight = weight if weight is not 
None else 1.0 + self._dp_mesh = dp_mesh + self._is_resumed = False # Create default transform if not provided self._metric_transform = metric_transform or DefaultDatasetMetricTransform() @@ -138,11 +141,22 @@ def _setup_hf_dataset( shuffle configuration, and filtering. Called once during __init__. """ - # Distributed setup + # Extract rank/world_size from DP mesh world_size, rank = 1, 0 - if dist.is_initialized(): + if self._dp_mesh is not None: + world_size = dist.get_world_size(group=self._dp_mesh) + rank = dist.get_rank(group=self._dp_mesh) + logger.info( + f"Using DP mesh for sharding: rank={rank}, world_size={world_size}" + ) + elif dist.is_initialized(): + # Fallback to global rank (may not respect TP/PP) world_size = dist.get_world_size() rank = dist.get_rank() + logger.warning( + f"Using global rank for sharding: rank={rank}, world_size={world_size}. " + f"If using TP/PP, pass dp_mesh for correct sharding." + ) # Load and shard dataset ds = load_dataset(**load_dataset_kwargs) @@ -218,6 +232,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]: - Adds 'num_epochs' metric to track dataset progress - Yields samples indefinitely for continuous training """ + # Reset iter + if not self._is_resumed: + self._num_epochs = 0 while True: # Infinite iteration self._ds.set_epoch(self._num_epochs) @@ -282,3 +299,4 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: # HF is responsible for resuming the dataset state # where it last left off self._ds.load_state_dict(hf_state) + self._is_resumed = True diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 93a21b85e..1a22352ef 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -343,9 +343,6 @@ def _reset_packer_state(self) -> None: # exhausted: whether the dataset is exhausted self._exhausted: bool = False - # resuming: whether the packer is resuming from a checkpoint - self._resuming: bool = False - def _fill_buffer(self, iterator: Iterator[SampleType]) -> None: """ Fills the buffer with samples from the dataset. @@ -449,18 +446,15 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> SampleDict | None: return None def __iter__(self) -> Iterator[SampleDict]: + """Create a new iterator for the dataset. + + Always resets the packer state to ensure consistent iteration from the start. 
+ """ if not isinstance(self.dataset, Iterable): raise TypeError("Dataset is not an iterable") - if not self._resuming: - self._reset_packer_state() - self._iterator = iter(self.dataset) - - # If resuming, the iterator must be recreated from the loaded state - if self._iterator is None: - self._iterator = iter(self.dataset) - - self._resuming = False # Consume the resume flag + self._reset_packer_state() + self._iterator = iter(self.dataset) # Main packing loop while True: @@ -502,7 +496,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise ValueError("Dataset is not stateful.") self._reset_packer_state() - self._resuming = True class TextPacker(Packer[SampleDict]): diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 00278c1e5..6264b13ea 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -162,6 +162,7 @@ def sft_iterable_dataset( dataset_name: str | None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, + dp_mesh: Any = None, **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ @@ -177,6 +178,7 @@ def sft_iterable_dataset( dataset_name (str | None): Name for metrics namespacing filter_fn (Callable | None): Filter function filter_kwargs (dict[str, Any] | None): Filter function kwargs + dp_mesh (Any): Data parallel mesh for sharding (None for single process) **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset Returns: @@ -206,5 +208,6 @@ def sft_iterable_dataset( dataset_name=dataset_name, filter_fn=filter_fn, filter_kwargs=filter_kwargs, + dp_mesh=dp_mesh, **load_dataset_kwargs, ) diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index be8c13857..498d3e419 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -4,13 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging from enum import Enum -from typing import Any, Literal, Union +from typing import Any, Iterator, Literal, Union import torch from torch.nn.attention.flex_attention import BlockMask +logger = logging.getLogger(__name__) + CROSS_ENTROPY_IGNORE_IDX = -100 Role = Literal[ @@ -213,3 +216,127 @@ def batch_to_device(batch: dict, device: torch.device) -> None: f"Tensor, or BlockMask with flexattention enabled. " f'Got key "{k}" with value of type {type(v)}' ) + + +class StopAfterOneEpoch: + """Iterator that wraps a dataloader and stops after one epoch completes. + + Handles epoch detection and synchronization across DP ranks using async + all_reduce. Assumes dataset inherits from InfiniteTuneIterableDataset + which provides 'metrics' with 'num_epochs' metric. + + When any rank detects an epoch change, all ranks stop (synchronized). 
+ + Args: + dataloader_iter: Iterator over dataloader batches + device: Device for computation + dataset_name: Name for logging + dp_process_group: Data parallel process group (None for single process) + """ + + def __init__( + self, + dataloader_iter: Iterator, + device: torch.device, + dataset_name: str, + dp_process_group: Any = None, + ): + self.dataloader_iter = dataloader_iter + self.device = device + self.dataset_name = dataset_name + self.dp_process_group = dp_process_group + + # Prefetch first batch for pipeline-style execution + self._next_batch = next(dataloader_iter) + + # Track pending async epoch sync + self._epoch_tensor: torch.Tensor | None = None + self._pending_work: Any = None + self._should_stop = False + + def __iter__(self): + return self + + def __next__(self) -> dict: + """Get next batch from current epoch. + + Returns: + Batch dict guaranteed to be from current epoch + + Raises: + StopIteration: When epoch completes across all ranks + """ + # Check if previous epoch sync completed + if self._pending_work is not None: + self._pending_work.wait() + if self._epoch_tensor.item() > 0: + self._should_stop = True + self._pending_work = None + self._epoch_tensor = None + + if self._should_stop: + logger.debug( + f"[{self.dataset_name}] Eval epoch completed. Stopping data iterator." + ) + raise StopIteration + + # Get current batch + current_batch = self._next_batch + current_epoch = extract_epoch_from_batch(current_batch) + + # Prefetch next batch and check for epoch change + self._next_batch = next(self.dataloader_iter) + next_epoch = extract_epoch_from_batch(self._next_batch) + epoch_changed = next_epoch > current_epoch + + # Start async epoch sync + if torch.distributed.is_initialized(): + self._epoch_tensor = torch.tensor([int(epoch_changed)], device=self.device) + self._pending_work = torch.distributed.all_reduce( + self._epoch_tensor, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group, + async_op=True, + ) + elif epoch_changed: + # if not distributed, just update the flag directly + self._should_stop = True + + return current_batch + + +def extract_epoch_from_batch(batch: dict | list) -> int: + """Extract epoch number from batch metrics. + + Assumes datasets inherit from InfiniteTuneIterableDataset which always + adds num_epochs metric. Raises clear error if assumption is violated. + + Args: + batch: Batch dictionary with 'metrics' field OR list of sample dicts + + Returns: + Epoch number from metrics + + Raises: + ValueError: If metrics missing or no num_epochs found + """ + # Handle list of samples (uncollated batches) + if isinstance(batch, list): + if not batch: + raise ValueError("Empty batch provided") + batch = batch[0] # Extract first sample + + if "metrics" not in batch: + raise ValueError( + "Batch missing 'metrics' field. Ensure dataset inherits from " + "InfiniteTuneIterableDataset which adds this automatically." + ) + + for metric in batch["metrics"]: + if "num_epochs" in metric.key: + return int(metric.value) + + raise ValueError( + f"No 'num_epochs' metric found in batch. 
Got metrics: " + f"{[m.key for m in batch['metrics']]}" + ) diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index 8298bf1a8..9fd2ce464 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -272,6 +272,8 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): class TestDistributedHfIterableDataset(FSDPTest): + """Test HfIterableDataset with 2-GPU distributed setup.""" + @property def world_size(self) -> int: return 2 @@ -364,3 +366,124 @@ def create_loader(): finally: shutil.rmtree(temp_dir) + + +class TestDPShardingWithTP(FSDPTest): + """Test DP sharding with TP replication (4-GPU setup).""" + + @property + def world_size(self) -> int: + return 4 + + @gpu_test(gpu_count=4) + def test_dp_sharding_with_tp_replication(self): + """Verify DP sharding works correctly with TP/CP replication. + + This is a CRITICAL test that validates the core bug fix: + - Previously: Each rank got different batches (incorrect) + - Now: TP/CP ranks within same DP group get identical batches (correct) + + Setup: DP=2, TP=2 (4 GPUs total) + - DP group 0: ranks [0, 1] - should see SAME batches (TP replication) + - DP group 1: ranks [2, 3] - should see SAME batches (TP replication) + - DP group 0 vs 1: should see DIFFERENT batches (DP sharding) + + Mesh structure: + - TP rank 0 DP replicas: [0, 2] - shard across these + - TP rank 1 DP replicas: [1, 3] - shard across these + """ + import hashlib + + rank = dist.get_rank() + world_size = dist.get_world_size() + temp_dir = tempfile.mkdtemp(prefix=f"dp_tp_test_rank{rank}_") + + try: + data_file = Path(temp_dir) / "data.json" + # Create dataset with enough samples for clear sharding + # 40 samples / 2 DP groups = 20 samples per DP group + create_test_json_file(data_file, MEDIUM_DATASET_SIZE, offset=0) + + # Create DP mesh for sharding + # Key insight: Create groups across DP replicas for each TP rank + # TP rank = rank % 2, so: + # - TP rank 0: ranks [0, 2] (one from each DP group) + # - TP rank 1: ranks [1, 3] (one from each DP group) + tp_rank = rank % 2 + tp_world_size = 2 + dp_world_size = world_size // tp_world_size + + # Create DP groups for each TP rank + dp_groups = [] + for tp_r in range(tp_world_size): + # Ranks for this TP rank across DP groups + ranks = [tp_r + i * tp_world_size for i in range(dp_world_size)] + group = dist.new_group(ranks=ranks) + dp_groups.append(group) + + dp_mesh = dp_groups[tp_rank] + + # - Rank 0 (tp_rank=0) uses group [0, 2], gets rank=0 → shard 0 + # - Rank 1 (tp_rank=1) uses group [1, 3], gets rank=0 → shard 0 + # - Rank 2 (tp_rank=0) uses group [0, 2], gets rank=1 → shard 1 + # - Rank 3 (tp_rank=1) uses group [1, 3], gets rank=1 → shard 1 + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + dataset_name="dp_tp_test", + shuffle_buffer_size=0, + metric_transform=DefaultDatasetMetricTransform(), + num_shards_per_rank=2, + dp_mesh=dp_mesh, # CRITICAL: Pass dp_mesh for correct sharding + ) + + dataloader = StatefulDataLoader( + dataset, + batch_size=BATCH_SIZE, + collate_fn=collate_with_metrics, + num_workers=0, + ) + + # Collect batches and compute hashes + batches = list(islice(iter(dataloader), 5)) + batch_hashes = [] + for batch in batches: + # Hash the batch IDs to verify identity/difference + batch_ids = batch["id"].cpu().tolist() + batch_hash = hashlib.md5(str(batch_ids).encode()).hexdigest() + batch_hashes.append(batch_hash) + + # Gather hashes from all ranks for comparison + gathered_hashes = 
[None] * world_size + dist.all_gather_object(gathered_hashes, batch_hashes) + + if rank == 0: + # Verify TP replication within DP groups + # Ranks 0 and 1 should have identical hashes (same DP group) + assert gathered_hashes[0] == gathered_hashes[1], ( + f"Ranks 0 and 1 (same DP group) should see identical batches!\n" + f"Rank 0 hashes: {gathered_hashes[0][:3]}...\n" + f"Rank 1 hashes: {gathered_hashes[1][:3]}..." + ) + + # Ranks 2 and 3 should have identical hashes (same DP group) + assert gathered_hashes[2] == gathered_hashes[3], ( + f"Ranks 2 and 3 (same DP group) should see identical batches!\n" + f"Rank 2 hashes: {gathered_hashes[2][:3]}...\n" + f"Rank 3 hashes: {gathered_hashes[3][:3]}..." + ) + + # Verify DP sharding across groups + # Ranks 0/1 should see DIFFERENT batches from ranks 2/3 + assert gathered_hashes[0] != gathered_hashes[2], ( + f"Ranks 0 and 2 (different DP groups) should see different batches!\n" + f"DP group 0 hashes: {gathered_hashes[0][:3]}...\n" + f"DP group 1 hashes: {gathered_hashes[2][:3]}..." + ) + + dist.barrier() + + finally: + shutil.rmtree(temp_dir) diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 56cd5ff02..1c6c4906f 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ -949,3 +949,49 @@ def create_loader(): # Verify that checkpointing and resumption work assert len(result["post_checkpoint_batches"]) == steps_after_checkpoint assert len(result["resumed_batches"]) == steps_after_checkpoint + + def test_iter_restart_determinism(self, dataset_factory): + """Test that calling iter() multiple times produces deterministic results. + + This is critical for evaluation: each eval run should start from the + same state (epoch 0, step 0) regardless of previous iterations. + """ + samples = [ + {"tokens": [0] * 3}, + {"tokens": [1] * 2}, + {"tokens": [2] * 4}, + ] + target_tokens_per_pack = 6 + + # Create packed dataset + dataset = dataset_factory(samples) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = PackedDataset( + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens_per_pack, + buffer_size=1, + ) + + # First iteration - get first 2 packs + iter1 = iter(packed_dataset) + packs_iter1 = list(islice(iter1, 2)) + + # Second iteration - should get same first 2 packs + iter2 = iter(packed_dataset) + packs_iter2 = list(islice(iter2, 2)) + + # Verify both iterations produce identical packs + assert len(packs_iter1) == len(packs_iter2) == 2 + + for i, (pack1, pack2) in enumerate(zip(packs_iter1, packs_iter2)): + torch.testing.assert_close( + pack1["tokens"], + pack2["tokens"], + msg=f"Pack {i}: tokens mismatch between iterations", + ) + torch.testing.assert_close( + pack1["document_ids"], + pack2["document_ids"], + msg=f"Pack {i}: document_ids mismatch between iterations", + ) diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py new file mode 100644 index 000000000..0b0399cda --- /dev/null +++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for StopAfterOneEpoch iterator and extract_epoch_from_batch helper.""" +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from forge.data.datasets import HfIterableDataset + +from forge.data.utils import extract_epoch_from_batch, StopAfterOneEpoch +from forge.observability.metrics import Metric, Reduce +from torch.testing._internal.common_fsdp import FSDPTest +from torchdata.stateful_dataloader import StatefulDataLoader + +from tests.test_utils import gpu_test + + +def create_test_json_file(path: Path, num_samples: int) -> None: + """Create test data file with simple samples.""" + with open(path, "w") as f: + for i in range(num_samples): + f.write(f'{{"id": {i}, "tokens": [{i}, {i+1}]}}\n') + + +class TestExtractEpochFromBatch: + """Test extract_epoch_from_batch helper function.""" + + def test_extract_epoch_from_batch_success(self): + """Test extracting epoch from valid batch with metrics.""" + batch = { + "tokens": torch.tensor([1, 2, 3]), + "metrics": [ + Metric(key="dataset/test/num_epochs", value=2, reduction=Reduce.MAX), + Metric( + key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN + ), + ], + } + epoch = extract_epoch_from_batch(batch) + assert epoch == 2 + + def test_extract_epoch_missing_metrics_field(self): + """Test error when batch has no 'metrics' field.""" + batch = {"tokens": torch.tensor([1, 2, 3])} + with pytest.raises(ValueError, match="Batch missing 'metrics' field"): + extract_epoch_from_batch(batch) + + def test_extract_epoch_no_num_epochs_metric(self): + """Test error when no num_epochs metric found.""" + batch = { + "metrics": [ + Metric( + key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN + ), + ] + } + with pytest.raises(ValueError, match="No 'num_epochs' metric found"): + extract_epoch_from_batch(batch) + + +class TestStopAfterOneEpochSingleProcess: + """Test StopAfterOneEpoch in single-process mode (no distributed).""" + + def test_stop_after_one_epoch(self, tmp_path): + """Verify iterator stops after exactly one epoch completes.""" + # Create small dataset (10 samples) + data_file = tmp_path / "data.json" + create_test_json_file(data_file, num_samples=10) + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + shuffle_buffer_size=0, + num_shards_per_rank=1, + ) + + dataloader = StatefulDataLoader(dataset, batch_size=2, collate_fn=lambda x: x) + + # Wrap with StopAfterOneEpoch + device = torch.device("cuda") + batch_iter = StopAfterOneEpoch( + iter(dataloader), device, "test_dataset", dp_process_group=None + ) + + # Collect all batches until StopIteration + batches = [] + for batch in batch_iter: + batches.append(batch) + # Verify all batches are from epoch 0 + epoch = extract_epoch_from_batch(batch) + assert epoch == 0, f"Expected epoch 0, got {epoch}" + + # Should have consumed exactly one epoch (5 batches of size 2) + assert len(batches) == 5 + + +class TestStopAfterOneEpochDistributed(FSDPTest): + """Test StopAfterOneEpoch with distributed synchronization.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_epoch_sync_across_ranks(self): + """Verify all ranks stop when any rank detects epoch change.""" + import shutil + import tempfile + + rank = dist.get_rank() + temp_dir = tempfile.mkdtemp(prefix=f"stop_epoch_test_rank{rank}_") + + try: + data_file = Path(temp_dir) / "data.json" + # Create dataset with 20 samples, split across 2 ranks (10 each) + create_test_json_file(data_file, 
num_samples=20) + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + shuffle_buffer_size=0, + num_shards_per_rank=1, + ) + + dataloader = StatefulDataLoader( + dataset, batch_size=2, collate_fn=lambda x: x + ) + + # Get DP process group (use global group for this test) + dp_process_group = dist.group.WORLD + + batch_iter = StopAfterOneEpoch( + iter(dataloader), + torch.device("cuda"), + f"test_rank{rank}", + dp_process_group, + ) + + # Collect batches + batches = [] + for batch in batch_iter: + batches.append(batch) + # All should be epoch 0 + assert extract_epoch_from_batch(batch) == 0 + + # All ranks should have processed exactly one epoch + # Since dataset is split across ranks, each rank gets 10 samples = 5 batches + assert ( + len(batches) == 5 + ), f"Rank {rank} expected 5 batches, got {len(batches)}" + + # Synchronize to ensure both ranks completed + dist.barrier() + + finally: + shutil.rmtree(temp_dir) From d9ea30e84be504db8cfdb65cdabeac9e14e1940f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 6 Nov 2025 13:21:42 -0800 Subject: [PATCH 04/16] improve docstring --- src/forge/data/utils.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index 498d3e419..906e539f0 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -305,36 +305,30 @@ def __next__(self) -> dict: return current_batch -def extract_epoch_from_batch(batch: dict | list) -> int: - """Extract epoch number from batch metrics. +def extract_epoch_from_batch(batch: dict) -> int: + """Extract epoch number from batch metrics. Useful to detect epoch changes during validation, + where we want to run exactly one epoch. - Assumes datasets inherit from InfiniteTuneIterableDataset which always - adds num_epochs metric. Raises clear error if assumption is violated. + Assumes the dataset adds "num_epochs" Metric to teh sample, where one epoch is incremented on dataset exhaustion. + For an example, check forge.src.data.datasets.HfIterableDataset. Args: - batch: Batch dictionary with 'metrics' field OR list of sample dicts + batch (dict): Batch dictionary with 'metrics' field Returns: - Epoch number from metrics + int: Max epoch number from metrics Raises: - ValueError: If metrics missing or no num_epochs found + ValueError: If metrics key is missing or not metric `num_epochs` found """ - # Handle list of samples (uncollated batches) - if isinstance(batch, list): - if not batch: - raise ValueError("Empty batch provided") - batch = batch[0] # Extract first sample - if "metrics" not in batch: raise ValueError( - "Batch missing 'metrics' field. Ensure dataset inherits from " - "InfiniteTuneIterableDataset which adds this automatically." + "Batch missing 'metrics' field. Cannot extract epoch from batch." ) - for metric in batch["metrics"]: - if "num_epochs" in metric.key: - return int(metric.value) + epochs = [metric.value for metric in batch["metrics"] if metric.key == "num_epochs"] + if epochs: + return max(epochs) raise ValueError( f"No 'num_epochs' metric found in batch. 
Got metrics: " From 95539c5ecadbd1582e740531ab6a031863451805 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 6 Nov 2025 14:05:05 -0800 Subject: [PATCH 05/16] split forward and forward_backward --- apps/sft/main.py | 58 ++++++++++++++++++++---------------------------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index 0ce3a4d57..af9035c5e 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -125,19 +125,14 @@ def record_batch_metrics(self, data_metrics: list): @endpoint async def setup(self): - """Setup datasets from config. - - Loads training and evaluation datasets based on config structure. - """ # Load training datasets logger.info("Setting training datasets") train_datasets_config = self.job_config.training.datasets self.train_dataloader = self.setup_data(train_datasets_config) - # Load eval config (might be None) + # Load eval datasets eval_config = self.job_config.get("eval", {}) self.val_dataloaders = {} - self.validation_enabled = False self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None) max_eval_steps = eval_config.get("max_eval_steps", None) self.max_eval_steps = ( @@ -160,33 +155,28 @@ async def setup(self): self.checkpointer.load(step=self.current_step) def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: - """Setup data from dataset configs. - - Currently only supports single dataset (first in list). - Multi-dataset training with InterleavedDataset is future work. + """Instantiates datasets and returns a StatefulDataLoader. Args: - dataset_configs: List of dataset config dicts with keys like 'path', 'split', etc. + dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`. Returns: - StatefulDataLoader for the dataset + StatefulDataLoader Raises: ValueError: If multiple datasets provided (not yet supported) """ - # Currently only support single dataset + # TODO felipemello: Currently only support single dataset if len(dataset_configs) > 1: raise ValueError( f"Multiple training datasets not supported yet. " f"Got {len(dataset_configs)} datasets. " - f"For dataset mixing, use InterleavedDataset (coming soon)." ) dataset_config = dataset_configs[0] # TODO: Evaluate if tokenizers should be created once and shared for every dataset # Load tokenizer - print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json")) tokenizer = HuggingFaceModelTokenizer( tokenizer_json_path=os.path.join( self.job_config.model.hf_assets_path, "tokenizer.json" @@ -237,26 +227,14 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: collate_fn=partial( collate_packed, mask_fn=packer.create_block_mask, device=self.device ), - drop_last=True, ) def forward( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor, - compute_gradients: bool = True, ) -> torch.Tensor: - """Forward pass with optional gradient computation. - - Args: - input_dict: Input dictionary containing tokens - labels: Target labels - compute_gradients: If True, compute gradients (training mode). - If False, skip backward pass (evaluation mode). 
- - Returns: - Loss tensor - """ + """Forward pass only (no backward)""" model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -276,7 +254,8 @@ def forward( ) if parallel_dims.pp_enabled: - # Pipeline Parallel forward (with optional backward) + # Pipeline Parallel forward / backward inside step() call + # Note: backward only happens if not in torch.no_grad() context with self.train_context(optional_context_parallel_ctx): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) @@ -298,7 +277,7 @@ def forward( else torch.tensor([-1.0], device=self.device) ) else: - # Non-PP forward (with optional backward) + # Non-PP forward with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -306,8 +285,19 @@ def forward( loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - if compute_gradients: - loss.backward() + + return loss + + def forward_backward( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + ) -> torch.Tensor: + loss = self.forward(input_dict, labels) + + # For non-PP, explicitly call backward (PP does it inside step()) + if not self.parallel_dims.pp_enabled: + loss.backward() return loss @@ -319,7 +309,7 @@ def train_step(self, batch) -> None: # self.data_parallel_size, # ) as grad_acc: labels = batch.pop("labels") - loss = self.forward(batch, labels, compute_gradients=True) + loss = self.forward_backward(batch, labels) loss = loss.item() record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) @@ -388,7 +378,7 @@ async def evaluate(self) -> None: # Process batch labels = batch.pop("labels") - loss = self.forward(batch, labels, compute_gradients=False) + loss = self.forward(batch, labels) total_loss += loss num_steps += 1 From 0919f5be8345700cbcca2983aa09e91fe7aa6063 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 6 Nov 2025 14:30:08 -0800 Subject: [PATCH 06/16] better docstrings --- apps/sft/main.py | 12 +++++------- src/forge/data/datasets/hf_dataset.py | 9 +++++---- src/forge/data/datasets/packed.py | 4 ---- src/forge/data/utils.py | 26 +++++++++++--------------- 4 files changed, 21 insertions(+), 30 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index af9035c5e..c31fb4a94 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -338,9 +338,9 @@ async def evaluate(self) -> None: model_part.eval() # Get DP process group for epoch synchronization - dp_process_group = None + dp_mesh = None if self.parallel_dims is not None and self.parallel_dims.dp_enabled: - dp_process_group = self.parallel_dims.world_mesh.get_group("dp") + dp_mesh = self.parallel_dims.world_mesh.get_group("dp") # Evaluate each dataset sequentially for dataset_name, val_dataloader in self.val_dataloaders.items(): @@ -350,12 +350,10 @@ async def evaluate(self) -> None: total_loss = torch.tensor(0.0, device=self.device) num_steps = 0 - # NOTE: Assumes batch contains batch["metrics"]["num_epochs"]: int + # NOTE: Assumes batch contains samples with Metric("num_epochs", ...) 
field batch_iter = StopAfterOneEpoch( - iter(val_dataloader), # Fresh iterator from epoch 0 - self.device, - dataset_name, - dp_process_group, + dataloader_iter=iter(val_dataloader), # Fresh iterator from epoch 0, + dp_process_group=dp_mesh, ) with torch.no_grad(): diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index 8399d68e5..504b1336b 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -146,16 +146,18 @@ def _setup_hf_dataset( if self._dp_mesh is not None: world_size = dist.get_world_size(group=self._dp_mesh) rank = dist.get_rank(group=self._dp_mesh) - logger.info( + logger.debug( f"Using DP mesh for sharding: rank={rank}, world_size={world_size}" ) elif dist.is_initialized(): # Fallback to global rank (may not respect TP/PP) world_size = dist.get_world_size() rank = dist.get_rank() + + # TODO: is there a way to detect this and raise error instead? logger.warning( f"Using global rank for sharding: rank={rank}, world_size={world_size}. " - f"If using TP/PP, pass dp_mesh for correct sharding." + f"If using other types of parallelsim (CP/TP/PP), pass dp_mesh for correct sharding." ) # Load and shard dataset @@ -166,7 +168,6 @@ def _setup_hf_dataset( if is_streaming: logger.warning( f"Streaming datasets were not yet tested for distributed training. " - f"split_dataset_by_node is applied, but no resharding was done manually. " f"Dataset '{self.info.name}' has " f"{getattr(ds, 'num_shards', 'unknown')} shards, and your training has {world_size} ranks." f"See: https://huggingface.co/docs/datasets/en/package_reference/main_classes?#datasets.IterableDataset.shard" @@ -201,7 +202,7 @@ def _setup_hf_dataset( if num_shards > dataset_size: raise ValueError( f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})." - f"Please decrease one of {num_shards_per_rank=} or {num_dataloader_workers=} or {world_size=}." + f"Please decrease one of {num_shards_per_rank=} or dataloader.num_workers={num_dataloader_workers}" ) ds = ds.to_iterable_dataset(num_shards=num_shards) diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 1a22352ef..69c3aa3a5 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -446,10 +446,6 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> SampleDict | None: return None def __iter__(self) -> Iterator[SampleDict]: - """Create a new iterator for the dataset. - - Always resets the packer state to ensure consistent iteration from the start. - """ if not isinstance(self.dataset, Iterable): raise TypeError("Dataset is not an iterable") diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index 906e539f0..b9cd70010 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -219,31 +219,29 @@ def batch_to_device(batch: dict, device: torch.device) -> None: class StopAfterOneEpoch: - """Iterator that wraps a dataloader and stops after one epoch completes. + """Iterator that wraps a dataloader and stops iterating after a rank shows that an epoch has been completed. - Handles epoch detection and synchronization across DP ranks using async - all_reduce. Assumes dataset inherits from InfiniteTuneIterableDataset - which provides 'metrics' with 'num_epochs' metric. + In distributed eval, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0 + while others are already in epoch 1. To avoid hangs, all ranks *must* stop at the same time. 
+ This means that we need to do some sort of `all_reduce` to know if at least one rank has seen epoch==1, + introducing communication overhead and blocking the forward pass. - When any rank detects an epoch change, all ranks stop (synchronized). + This function minimzes this impact by fetching one batch in advance and perfoming async all_reduce, overlapping communications. + + Assumes batch contains samples with Metric("num_epochs", ...) field to detect epoch change, as it is done in + `forge.src.data.datasets.HfIterableDataset`. Args: dataloader_iter: Iterator over dataloader batches - device: Device for computation - dataset_name: Name for logging dp_process_group: Data parallel process group (None for single process) """ def __init__( self, dataloader_iter: Iterator, - device: torch.device, - dataset_name: str, dp_process_group: Any = None, ): self.dataloader_iter = dataloader_iter - self.device = device - self.dataset_name = dataset_name self.dp_process_group = dp_process_group # Prefetch first batch for pipeline-style execution @@ -275,9 +273,7 @@ def __next__(self) -> dict: self._epoch_tensor = None if self._should_stop: - logger.debug( - f"[{self.dataset_name}] Eval epoch completed. Stopping data iterator." - ) + logger.debug("Eval epoch completed. Stopping data iterator.") raise StopIteration # Get current batch @@ -291,7 +287,7 @@ def __next__(self) -> dict: # Start async epoch sync if torch.distributed.is_initialized(): - self._epoch_tensor = torch.tensor([int(epoch_changed)], device=self.device) + self._epoch_tensor = torch.tensor([int(epoch_changed)]) self._pending_work = torch.distributed.all_reduce( self._epoch_tensor, op=torch.distributed.ReduceOp.MAX, From 2b8cfbf7b2c35b6117173134a87ba012fd57f518 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 6 Nov 2025 14:51:22 -0800 Subject: [PATCH 07/16] add type to dp_mesh --- apps/sft/main.py | 6 +++--- src/forge/data/datasets/hf_dataset.py | 2 +- src/forge/data/datasets/sft_dataset.py | 5 +++-- src/forge/data/utils.py | 9 +++++---- .../unit_tests/datasets/test_stop_after_one_epoch.py | 12 +++++------- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index c31fb4a94..b07798b7d 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -353,7 +353,7 @@ async def evaluate(self) -> None: # NOTE: Assumes batch contains samples with Metric("num_epochs", ...) 
field batch_iter = StopAfterOneEpoch( dataloader_iter=iter(val_dataloader), # Fresh iterator from epoch 0, - dp_process_group=dp_mesh, + dp_mesh=dp_mesh, ) with torch.no_grad(): @@ -395,8 +395,8 @@ async def evaluate(self) -> None: # Record metrics only on DP rank 0 to avoid double counting # record_metric aggregates across all processes via monarch should_record = True - if dp_process_group is not None: - dp_rank = torch.distributed.get_rank(group=dp_process_group) + if dp_mesh is not None: + dp_rank = torch.distributed.get_rank(group=dp_mesh) should_record = dp_rank == 0 if should_record: diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index 504b1336b..5e371e6cc 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -70,7 +70,7 @@ def __init__( dataset_name: str | None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, - dp_mesh: Any = None, + dp_mesh: dist.ProcessGroup | None = None, **load_dataset_kwargs, ): # Store configuration diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 6264b13ea..3820ecc97 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -7,6 +7,7 @@ from typing import Any, Callable import torch +import torch.distributed as dist from forge.data import CROSS_ENTROPY_IGNORE_IDX from forge.data.metric_transform import DefaultDatasetMetricTransform @@ -162,7 +163,7 @@ def sft_iterable_dataset( dataset_name: str | None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, - dp_mesh: Any = None, + dp_mesh: dist.ProcessGroup | None = None, **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ @@ -178,7 +179,7 @@ def sft_iterable_dataset( dataset_name (str | None): Name for metrics namespacing filter_fn (Callable | None): Filter function filter_kwargs (dict[str, Any] | None): Filter function kwargs - dp_mesh (Any): Data parallel mesh for sharding (None for single process) + dp_mesh (dist.ProcessGroup | None): Data parallel process group for sharding (None for single process) **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset Returns: diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index b9cd70010..20507a267 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -9,6 +9,7 @@ from typing import Any, Iterator, Literal, Union import torch +import torch.distributed as dist from torch.nn.attention.flex_attention import BlockMask @@ -233,16 +234,16 @@ class StopAfterOneEpoch: Args: dataloader_iter: Iterator over dataloader batches - dp_process_group: Data parallel process group (None for single process) + dp_mesh: Data parallel process group (None for single process) """ def __init__( self, dataloader_iter: Iterator, - dp_process_group: Any = None, + dp_mesh: dist.ProcessGroup | None = None, ): self.dataloader_iter = dataloader_iter - self.dp_process_group = dp_process_group + self.dp_mesh = dp_mesh # Prefetch first batch for pipeline-style execution self._next_batch = next(dataloader_iter) @@ -291,7 +292,7 @@ def __next__(self) -> dict: self._pending_work = torch.distributed.all_reduce( self._epoch_tensor, op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group, + group=self.dp_mesh, async_op=True, ) elif epoch_changed: diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py index 0b0399cda..aa079bb94 
100644 --- a/tests/unit_tests/datasets/test_stop_after_one_epoch.py +++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py @@ -83,9 +83,9 @@ def test_stop_after_one_epoch(self, tmp_path): dataloader = StatefulDataLoader(dataset, batch_size=2, collate_fn=lambda x: x) # Wrap with StopAfterOneEpoch - device = torch.device("cuda") batch_iter = StopAfterOneEpoch( - iter(dataloader), device, "test_dataset", dp_process_group=None + dataloader_iter=iter(dataloader), + dp_mesh=None, ) # Collect all batches until StopIteration @@ -134,13 +134,11 @@ def test_epoch_sync_across_ranks(self): ) # Get DP process group (use global group for this test) - dp_process_group = dist.group.WORLD + dp_mesh = dist.group.WORLD batch_iter = StopAfterOneEpoch( - iter(dataloader), - torch.device("cuda"), - f"test_rank{rank}", - dp_process_group, + dataloader_iter=iter(dataloader), + dp_mesh=dp_mesh, ) # Collect batches From 63fabb7cee3aa8f4546839594ff170417b0a7e3b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 6 Nov 2025 15:22:35 -0800 Subject: [PATCH 08/16] fix unit tets --- apps/sft/main.py | 1 + src/forge/data/utils.py | 12 ++++++--- .../datasets/test_stop_after_one_epoch.py | 27 +++++++++++++++++-- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index b07798b7d..e9ef41653 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -353,6 +353,7 @@ async def evaluate(self) -> None: # NOTE: Assumes batch contains samples with Metric("num_epochs", ...) field batch_iter = StopAfterOneEpoch( dataloader_iter=iter(val_dataloader), # Fresh iterator from epoch 0, + device=self.device, dp_mesh=dp_mesh, ) diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index 20507a267..61c375e8b 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -234,15 +234,18 @@ class StopAfterOneEpoch: Args: dataloader_iter: Iterator over dataloader batches + device: Device for synchronization tensors (use cuda for NCCL backend) dp_mesh: Data parallel process group (None for single process) """ def __init__( self, dataloader_iter: Iterator, + device: torch.device, dp_mesh: dist.ProcessGroup | None = None, ): self.dataloader_iter = dataloader_iter + self.device = device self.dp_mesh = dp_mesh # Prefetch first batch for pipeline-style execution @@ -288,7 +291,7 @@ def __next__(self) -> dict: # Start async epoch sync if torch.distributed.is_initialized(): - self._epoch_tensor = torch.tensor([int(epoch_changed)]) + self._epoch_tensor = torch.tensor([int(epoch_changed)], device=self.device) self._pending_work = torch.distributed.all_reduce( self._epoch_tensor, op=torch.distributed.ReduceOp.MAX, @@ -306,7 +309,7 @@ def extract_epoch_from_batch(batch: dict) -> int: """Extract epoch number from batch metrics. Useful to detect epoch changes during validation, where we want to run exactly one epoch. - Assumes the dataset adds "num_epochs" Metric to teh sample, where one epoch is incremented on dataset exhaustion. + Assumes the dataset adds "num_epochs" Metric to the sample, where one epoch is incremented on dataset exhaustion. For an example, check forge.src.data.datasets.HfIterableDataset. Args: @@ -316,14 +319,15 @@ def extract_epoch_from_batch(batch: dict) -> int: int: Max epoch number from metrics Raises: - ValueError: If metrics key is missing or not metric `num_epochs` found + ValueError: If metrics key is missing or no metric with 'num_epochs' found """ if "metrics" not in batch: raise ValueError( "Batch missing 'metrics' field. Cannot extract epoch from batch." 
) - epochs = [metric.value for metric in batch["metrics"] if metric.key == "num_epochs"] + # Match metrics where 'num_epochs' appears in the key (handles prefixed keys like 'dataset/name/num_epochs') + epochs = [metric.value for metric in batch["metrics"] if "num_epochs" in metric.key] if epochs: return max(epochs) diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py index aa079bb94..9f0bcb87f 100644 --- a/tests/unit_tests/datasets/test_stop_after_one_epoch.py +++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py @@ -27,6 +27,25 @@ def create_test_json_file(path: Path, num_samples: int) -> None: f.write(f'{{"id": {i}, "tokens": [{i}, {i+1}]}}\n') +def simple_collate(batch): + """Simple collate function that mimics collate_packed behavior. + + Stacks tensors, extends metrics list, keeps other fields as lists. + """ + collated = {} + for key in batch[0].keys(): + if isinstance(batch[0][key], torch.Tensor): + collated[key] = torch.stack([sample[key] for sample in batch], dim=0) + elif key == "metrics": + # Extend all metrics into a single list + collated[key] = [] + for sample in batch: + collated[key].extend(sample[key]) + else: + collated[key] = [sample[key] for sample in batch] + return collated + + class TestExtractEpochFromBatch: """Test extract_epoch_from_batch helper function.""" @@ -80,11 +99,14 @@ def test_stop_after_one_epoch(self, tmp_path): num_shards_per_rank=1, ) - dataloader = StatefulDataLoader(dataset, batch_size=2, collate_fn=lambda x: x) + dataloader = StatefulDataLoader( + dataset, batch_size=2, collate_fn=simple_collate + ) # Wrap with StopAfterOneEpoch batch_iter = StopAfterOneEpoch( dataloader_iter=iter(dataloader), + device=torch.device("cpu"), dp_mesh=None, ) @@ -130,7 +152,7 @@ def test_epoch_sync_across_ranks(self): ) dataloader = StatefulDataLoader( - dataset, batch_size=2, collate_fn=lambda x: x + dataset, batch_size=2, collate_fn=simple_collate ) # Get DP process group (use global group for this test) @@ -138,6 +160,7 @@ def test_epoch_sync_across_ranks(self): batch_iter = StopAfterOneEpoch( dataloader_iter=iter(dataloader), + device=torch.device("cuda"), dp_mesh=dp_mesh, ) From dc4b37b96be8c8b3cfe90121c59d17b3b4b820ed Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 09:56:22 -0800 Subject: [PATCH 09/16] revert to forwrd_backward --- apps/sft/main.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index e9ef41653..d86cedf33 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -229,12 +229,13 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: ), ) - def forward( + def forward_backward( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor, + skip_backward: bool = False, ) -> torch.Tensor: - """Forward pass only (no backward)""" + """Forward pass with optional backward.""" model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -255,7 +256,7 @@ def forward( if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - # Note: backward only happens if not in torch.no_grad() context + # Note: PP backward only happens if not in torch.no_grad() context with self.train_context(optional_context_parallel_ctx): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) @@ -277,7 +278,7 @@ def forward( else torch.tensor([-1.0], device=self.device) ) else: - # Non-PP forward + # Non-PP forward / backward - 
must happen inside same context with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -286,18 +287,9 @@ def forward( # need to free to before bwd to avoid peaking memory del pred - return loss - - def forward_backward( - self, - input_dict: dict[str, torch.Tensor], - labels: torch.Tensor, - ) -> torch.Tensor: - loss = self.forward(input_dict, labels) - - # For non-PP, explicitly call backward (PP does it inside step()) - if not self.parallel_dims.pp_enabled: - loss.backward() + # Only run backward if requested. Useful for eval. + if not skip_backward: + loss.backward() return loss @@ -377,7 +369,7 @@ async def evaluate(self) -> None: # Process batch labels = batch.pop("labels") - loss = self.forward(batch, labels) + loss = self.forward_backward(batch, labels, skip_backward=True) total_loss += loss num_steps += 1 From eb6a3a1a4868e77004788835ccd372cee5bf66ae Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 11:09:37 -0800 Subject: [PATCH 10/16] improve logging --- apps/sft/llama3_8b.yaml | 9 ++++----- apps/sft/main.py | 45 +++++++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 960c96e6d..2026939e8 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -37,9 +37,8 @@ training: split: "train[:95%]" eval: - eval_every_n_steps: 5 # (null = disabled) - max_eval_steps: 0 # Max batches per eval dataset (null = run until epoch completes) - batch_size: ${training.local_batch_size} # Batch size for evaluation + eval_every_n_steps: 5 # (null = disabled) + max_eval_steps: null # Max batches per eval dataset (null = run until epoch completes) datasets: - path: "yahma/alpaca-cleaned" split: "train[95%:]" @@ -47,9 +46,9 @@ eval: parallelism: data_parallel_replicate_degree: 1 data_parallel_shard_degree: -1 - tensor_parallel_degree: 2 + tensor_parallel_degree: 1 pipeline_parallel_degree: 1 - context_parallel_degree: 2 + context_parallel_degree: 1 expert_parallel_degree: 1 disable_loss_parallel: false diff --git a/apps/sft/main.py b/apps/sft/main.py index d86cedf33..318e46740 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -30,7 +30,6 @@ from forge.data.utils import StopAfterOneEpoch from forge.observability import get_or_create_metric_logger, record_metric, Reduce from forge.util.config import parse -from forge.util.logging import log_rank_zero from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf @@ -125,6 +124,10 @@ def record_batch_metrics(self, data_metrics: list): @endpoint async def setup(self): + + # metric logger + self.mlogger = await self.setup_metric_logger() + # Load training datasets logger.info("Setting training datasets") train_datasets_config = self.job_config.training.datasets @@ -148,6 +151,7 @@ async def setup(self): for i, dataset_config in enumerate(self.eval_datasets_config): ds_name = dataset_config.get("dataset_name", i) + # TODO: Support separate eval batch size from config (eval.local_batch_size) dataloader = self.setup_data([dataset_config]) self.val_dataloaders[ds_name] = dataloader @@ -305,7 +309,11 @@ def train_step(self, batch) -> None: loss = loss.item() record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) - logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") + if self.current_step % 10 == 0: + logger.info( + f"step {self.current_step} / {self.num_training_steps} | Loss: {loss}" + 
) + # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) self.optimizers.step() @@ -323,7 +331,6 @@ async def evaluate(self) -> None: - Record loss and step metrics (on dp rank only) 3. Restore models to train mode """ - logger.debug("==STARTING EVALUATION==") # Set models to eval mode for model_part in self.model_parts: @@ -336,7 +343,7 @@ async def evaluate(self) -> None: # Evaluate each dataset sequentially for dataset_name, val_dataloader in self.val_dataloaders.items(): - logger.debug(f"=====Evaluating dataset: {dataset_name}=====") + logger.info(f"=====Evaluating dataset: {dataset_name}=====") # Evaluation loop for this dataset total_loss = torch.tensor(0.0, device=self.device) @@ -356,9 +363,8 @@ async def evaluate(self) -> None: self.max_eval_steps is not None and num_steps >= self.max_eval_steps ): - log_rank_zero( - logger, - f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}", + logger.info( + f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}" ) break @@ -374,17 +380,17 @@ async def evaluate(self) -> None: num_steps += 1 # Log progress (rank 0 only) - if num_steps % 100 == 0: + if num_steps % 50 == 0: loss_val = loss.item() - log_rank_zero( - logger, - f" [{dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}", + logger.info( + f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}" ) # Compute average loss avg_loss = (total_loss / max(num_steps, 1)).item() - log_rank_zero(logger, f" [{dataset_name}] avg_loss: {avg_loss:.4f}") - + logger.info( + f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}" + ) # Record metrics only on DP rank 0 to avoid double counting # record_metric aggregates across all processes via monarch should_record = True @@ -394,22 +400,17 @@ async def evaluate(self) -> None: if should_record: record_metric( - f"ForgeSFTRecipe/evaluate/{dataset_name}_loss", + f"evaluate/dataset_{dataset_name}_loss", avg_loss, Reduce.MEAN, ) - record_metric( - f"ForgeSFTRecipe/evaluate/{dataset_name}_steps", - num_steps, - Reduce.MEAN, - ) # Restore train mode for model_part in self.model_parts: model_part.train() # Summary - logger.debug("==EVALUATION COMPLETE==") + logger.info("==Evaluation complete==") @endpoint async def train(self) -> None: @@ -447,6 +448,10 @@ async def train(self) -> None: last_step=self.current_step == self.num_training_steps, ) + # Flush metrics + if self._rank == 0: + await self.mlogger.flush.call_one(global_step=self.current_step) + # self.pbar.close() # Run final evaluation at end of training From aadd15a589da0620465fd945dc93ae881b4dd0f1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 11:16:28 -0800 Subject: [PATCH 11/16] update configs --- apps/sft/llama3_8b.yaml | 2 +- apps/sft/qwen3_8b.yaml | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 2026939e8..57c73677e 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -37,7 +37,7 @@ training: split: "train[:95%]" eval: - eval_every_n_steps: 5 # (null = disabled) + eval_every_n_steps: 50 # (null = disabled) max_eval_steps: null # Max batches per eval dataset (null = run until epoch completes) datasets: - path: "yahma/alpaca-cleaned" diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml index f7c4999bb..b40129e3d 100644 --- a/apps/sft/qwen3_8b.yaml +++ b/apps/sft/qwen3_8b.yaml @@ -31,7 +31,16 @@ training: max_norm: 1.0 steps: 1000 compile: false - dataset: "c4" + 
datasets: + - path: "yahma/alpaca-cleaned" + split: "train[:95%]" + +eval: + eval_every_n_steps: 50 # (null = disabled) + max_eval_steps: null # Max batches per eval dataset (null = run until epoch completes) + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[95%:]" parallelism: data_parallel_replicate_degree: 1 From 520c9d377e4e732f6aaa2222d25e879660e03c3e Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 11:40:00 -0800 Subject: [PATCH 12/16] unit test hf iter --- src/forge/data/datasets/hf_dataset.py | 12 ++--- tests/unit_tests/datasets/test_hf.py | 66 +++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index 5e371e6cc..75b484607 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -81,7 +81,6 @@ def __init__( self._output_transform = output_transform self._weight = weight if weight is not None else 1.0 self._dp_mesh = dp_mesh - self._is_resumed = False # Create default transform if not provided self._metric_transform = metric_transform or DefaultDatasetMetricTransform() @@ -105,6 +104,10 @@ def __init__( self._metric_transform.set_source(dataset_name) # Internal state for resumption + # _start_epoch: The epoch to start from. Updated on resume from ckpt. + # useful when doing iter(ds), which restarts dataset from original state. + self._start_epoch = 0 + # _num_epochs: updated on every dataset exhaustion self._num_epochs = 0 # Load and setup HF dataset @@ -233,9 +236,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: - Adds 'num_epochs' metric to track dataset progress - Yields samples indefinitely for continuous training """ - # Reset iter - if not self._is_resumed: - self._num_epochs = 0 + self._num_epochs = self._start_epoch while True: # Infinite iteration self._ds.set_epoch(self._num_epochs) @@ -294,10 +295,9 @@ def state_dict(self) -> dict[str, Any]: return state def load_state_dict(self, state_dict: dict[str, Any]) -> None: - self._num_epochs = state_dict["num_epochs"] + self._start_epoch = state_dict["num_epochs"] hf_state = state_dict["hf_dataset_state"] # HF is responsible for resuming the dataset state # where it last left off self._ds.load_state_dict(hf_state) - self._is_resumed = True diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index 9fd2ce464..76aa79142 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -270,6 +270,72 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): epoch_value == 1 for epoch_value in epoch_values ), f"Epoch values should be 1, got {epoch_values}" + def test_multiple_iter_calls_after_resume( + self, dataset_factory, small_dataset_file + ): + """Test that calling iter() multiple times after resuming restarts from checkpoint epoch. + + 1. Resume from checkpoint at epoch 2 + 2. Consume one epoch (now at epoch 3) + 3. Call iter(ds) again to create a new iterator + 4. The new iterator should restart from epoch 2 (checkpoint epoch), not 0 or 3 + + This ensures datasets can be re-iterated from their checkpoint state. 
+ """ + dataset = dataset_factory(small_dataset_file, shuffle=False) + + # consume 2 epochs + it1 = iter(dataset) + samples = list(islice(it1, SMALL_DATASET_SIZE * 2)) + + # Save checkpoint after 2 epochs + state = dataset.state_dict() + + # Continue training for 1 more epoch on the same iterator + more_samples = list(islice(it1, SMALL_DATASET_SIZE)) + + # Create a new dataset instance and load the checkpoint + dataset2 = dataset_factory(small_dataset_file, shuffle=False) + dataset2.load_state_dict(state) + + # First iter() call should start from epoch 2 (the checkpoint epoch) + it2 = iter(dataset2) + first_iter_samples = list(islice(it2, SMALL_DATASET_SIZE)) + first_iter_epochs = [ + metric.value + for sample in first_iter_samples + for metric in sample["metrics"] + if "num_epochs" in metric.key + ] + assert all( + epoch == 2 for epoch in first_iter_epochs + ), f"First iter() should start at checkpoint epoch 2, got {set(first_iter_epochs)}" + + # Consume one more epoch from the same iterator (now at epoch 3) + second_epoch_samples = list(islice(it2, SMALL_DATASET_SIZE)) + second_epoch_epochs = [ + metric.value + for sample in second_epoch_samples + for metric in sample["metrics"] + if "num_epochs" in metric.key + ] + assert all( + epoch == 3 for epoch in second_epoch_epochs + ), f"Second epoch should be 3, got {set(second_epoch_epochs)}" + + # Call iter() again - it should restart from epoch 2, not continue from 4 + it3 = iter(dataset2) + new_iter_samples = list(islice(it3, SMALL_DATASET_SIZE)) + new_iter_epochs = [ + metric.value + for sample in new_iter_samples + for metric in sample["metrics"] + if "num_epochs" in metric.key + ] + assert all( + epoch == 2 for epoch in new_iter_epochs + ), f"New iter() should restart from checkpoint epoch 2, got {set(new_iter_epochs)}" + class TestDistributedHfIterableDataset(FSDPTest): """Test HfIterableDataset with 2-GPU distributed setup.""" From a0dcc9876159b8608406ecc9ac99c583ce049ec7 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 11:57:54 -0800 Subject: [PATCH 13/16] nits --- apps/sft/main.py | 7 ++--- src/forge/data/utils.py | 31 +++++++++---------- .../datasets/test_stop_after_one_epoch.py | 4 +-- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index 318e46740..8693be3d1 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -202,9 +202,6 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: ), ) - # Store tokenizer for later use (e.g., decoding in debug logs) - self.tokenizer = tokenizer - # Get DP mesh for data sharding dp_mesh = None if self.parallel_dims is not None and self.parallel_dims.dp_enabled: @@ -215,7 +212,7 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: model_transform=tokenizer, message_transform=AlpacaToMessages(), dp_mesh=dp_mesh, - **dataset_config, # Unpack config (path, split, etc.) + **dataset_config, ) packer = TextPacker(padding_idx=0) @@ -351,7 +348,7 @@ async def evaluate(self) -> None: # NOTE: Assumes batch contains samples with Metric("num_epochs", ...) 
field
             batch_iter = StopAfterOneEpoch(
-                dataloader_iter=iter(val_dataloader),  # Fresh iterator from epoch 0,
+                iter=iter(val_dataloader),  # Fresh iterator from epoch 0,
                 device=self.device,
                 dp_mesh=dp_mesh,
             )
diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py
index 61c375e8b..e335c23e4 100644
--- a/src/forge/data/utils.py
+++ b/src/forge/data/utils.py
@@ -220,36 +220,34 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
 class StopAfterOneEpoch:
-    """Iterator that wraps a dataloader and stops iterating after a rank shows that an epoch has been completed.
+    """Wraps an iterator, e.g. dataloader, and stops iterating after a rank shows that an epoch has been completed.
 
     In distributed eval, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0
-    while others are already in epoch 1. To avoid hangs, all ranks *must* stop at the same time.
-    This means that we need to do some sort of `all_reduce` to know if at least one rank has seen epoch==1,
-    introducing communication overhead and blocking the forward pass.
+    while others are already in epoch 1. To avoid hangs, all ranks *must* stop at the same time, requiring communication.
 
-    This function minimzes this impact by fetching one batch in advance and perfoming async all_reduce, overlapping communications.
+    This function minimizes this impact by fetching one batch in advance and performing overlapping async all_reduce.
 
-    Assumes batch contains samples with Metric("num_epochs", ...) field to detect epoch change, as it is done in
+    Assumes batch contains field "metrics" with at least one Metric containing "num_epochs" in its key, as it is done in
     `forge.src.data.datasets.HfIterableDataset`.
 
     Args:
-        dataloader_iter: Iterator over dataloader batches
-        device: Device for synchronization tensors (use cuda for NCCL backend)
-        dp_mesh: Data parallel process group (None for single process)
+        iter (Iterator): Iterator over dataloader batches
+        device (torch.device): Device for synchronizing tensors
+        dp_mesh (dist.ProcessGroup | None): Data parallel process group (None for single process)
     """
 
     def __init__(
         self,
-        dataloader_iter: Iterator,
+        iter: Iterator,
         device: torch.device,
         dp_mesh: dist.ProcessGroup | None = None,
     ):
-        self.dataloader_iter = dataloader_iter
+        self.iter = iter
         self.device = device
         self.dp_mesh = dp_mesh
 
         # Prefetch first batch for pipeline-style execution
-        self._next_batch = next(dataloader_iter)
+        self._next_batch = next(iter)
 
         # Track pending async epoch sync
         self._epoch_tensor: torch.Tensor | None = None
@@ -285,7 +283,7 @@ def __next__(self) -> dict:
         current_epoch = extract_epoch_from_batch(current_batch)
 
         # Prefetch next batch and check for epoch change
-        self._next_batch = next(self.dataloader_iter)
+        self._next_batch = next(self.iter)
         next_epoch = extract_epoch_from_batch(self._next_batch)
         epoch_changed = next_epoch > current_epoch
 
@@ -306,11 +304,10 @@ def __next__(self) -> dict:
 
 def extract_epoch_from_batch(batch: dict) -> int:
-    """Extract epoch number from batch metrics. Useful to detect epoch changes during validation,
-    where we want to run exactly one epoch.
+    """Extract epoch number from batch metrics. Useful to detect epoch changes during validation.
 
-    Assumes the dataset adds "num_epochs" Metric to the sample, where one epoch is incremented on dataset exhaustion.
-    For an example, check forge.src.data.datasets.HfIterableDataset.
+ Assumes batch contains field "metrics" with at least one Metric containing "num_epochs" in its key, as it is done in + `forge.src.data.datasets.HfIterableDataset`. Args: batch (dict): Batch dictionary with 'metrics' field diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py index 9f0bcb87f..d0deaf86d 100644 --- a/tests/unit_tests/datasets/test_stop_after_one_epoch.py +++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py @@ -105,7 +105,7 @@ def test_stop_after_one_epoch(self, tmp_path): # Wrap with StopAfterOneEpoch batch_iter = StopAfterOneEpoch( - dataloader_iter=iter(dataloader), + iter=iter(dataloader), device=torch.device("cpu"), dp_mesh=None, ) @@ -159,7 +159,7 @@ def test_epoch_sync_across_ranks(self): dp_mesh = dist.group.WORLD batch_iter = StopAfterOneEpoch( - dataloader_iter=iter(dataloader), + iter=iter(dataloader), device=torch.device("cuda"), dp_mesh=dp_mesh, ) From 47829e6a92d72656ec4e37dd1fe42dc31044b413 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 11:59:28 -0800 Subject: [PATCH 14/16] nits --- apps/sft/llama3_8b.yaml | 4 ++-- apps/sft/qwen3_8b.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 57c73677e..e626648db 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -37,8 +37,8 @@ training: split: "train[:95%]" eval: - eval_every_n_steps: 50 # (null = disabled) - max_eval_steps: null # Max batches per eval dataset (null = run until epoch completes) + eval_every_n_steps: 50 # null = disabled + max_eval_steps: null # null = run until epoch completes datasets: - path: "yahma/alpaca-cleaned" split: "train[95%:]" diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml index b40129e3d..fb39856ae 100644 --- a/apps/sft/qwen3_8b.yaml +++ b/apps/sft/qwen3_8b.yaml @@ -36,8 +36,8 @@ training: split: "train[:95%]" eval: - eval_every_n_steps: 50 # (null = disabled) - max_eval_steps: null # Max batches per eval dataset (null = run until epoch completes) + eval_every_n_steps: 50 # null = disabled + max_eval_steps: null # null = run until epoch completes datasets: - path: "yahma/alpaca-cleaned" split: "train[95%:]" From d2a35022e2a7243f2e6635dd5ed6575ec417982f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 12:06:51 -0800 Subject: [PATCH 15/16] nits --- apps/sft/main.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index 8693be3d1..c8a09af21 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -155,9 +155,13 @@ async def setup(self): dataloader = self.setup_data([dataset_config]) self.val_dataloaders[ds_name] = dataloader - # Load checkpoint if resuming + # TODO: confirm that this is working properly + # Should also use load, not dcp_load self.checkpointer.load(step=self.current_step) + # self.profiler = self.setup_profiler(self.train_config.profiler_config) + # self.logger = self.setup_logger(self.train_config.logger_config) + def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: """Instantiates datasets and returns a StatefulDataLoader. 
@@ -219,10 +223,10 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: dataset = PackedDataset( dataset=dataset, packer=packer, - target_tokens_per_pack=self.job_config.training.seq_len, + target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model ) - return StatefulDataLoader( + dataloader = StatefulDataLoader( dataset=dataset, batch_size=self.job_config.training.local_batch_size, collate_fn=partial( @@ -230,6 +234,12 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: ), ) + # Ultimately we probably want something like this + # packer = build_packing_strategy(packing_config) + # dataset = build_dataset(dataset_config) + # dataloader = build_dataloader(dataloader_config, dataset, packer) + returndataloader + def forward_backward( self, input_dict: dict[str, torch.Tensor], @@ -346,7 +356,7 @@ async def evaluate(self) -> None: total_loss = torch.tensor(0.0, device=self.device) num_steps = 0 - # NOTE: Assumes batch contains samples with Metric("num_epochs", ...) field + # NOTE: Assumes batch contains field "metrics" batch_iter = StopAfterOneEpoch( iter=iter(val_dataloader), # Fresh iterator from epoch 0, device=self.device, @@ -406,7 +416,6 @@ async def evaluate(self) -> None: for model_part in self.model_parts: model_part.train() - # Summary logger.info("==Evaluation complete==") @endpoint @@ -451,7 +460,6 @@ async def train(self) -> None: # self.pbar.close() - # Run final evaluation at end of training if self.validation_enabled: logger.info("Running final evaluation at end of training...") await self.evaluate() From 55637351487b71506bc27dcb375cf9f82616bf25 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 7 Nov 2025 12:27:12 -0800 Subject: [PATCH 16/16] nit --- apps/sft/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index c8a09af21..dc9d0e181 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -238,7 +238,7 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: # packer = build_packing_strategy(packing_config) # dataset = build_dataset(dataset_config) # dataloader = build_dataloader(dataloader_config, dataset, packer) - returndataloader + return dataloader def forward_backward( self,
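
For orientation, below is a condensed sketch of how the pieces introduced in this series fit together on the eval path: StopAfterOneEpoch bounds the validation loop to exactly one epoch across DP ranks, batch_to_device moves the prefetched batch, and the averaged loss is recorded once per dataset. This is an illustration, not part of the patches: the imports assume the modules land where the patches place them, and model_forward is a hypothetical stand-in for the recipe's loss computation (the actual recipe calls forward_backward(..., skip_backward=True) under torch.no_grad()).

import torch

from forge.data.utils import batch_to_device, StopAfterOneEpoch
from forge.observability import record_metric, Reduce


def run_validation(val_dataloader, model_forward, device, dp_mesh=None, max_eval_steps=None):
    """Run one epoch of validation and record the average loss."""
    total_loss = torch.tensor(0.0, device=device)
    num_steps = 0

    # Bound the loop to a single epoch on every DP rank; the wrapper prefetches
    # one batch and overlaps the epoch-completion all_reduce with compute.
    batch_iter = StopAfterOneEpoch(iter=iter(val_dataloader), device=device, dp_mesh=dp_mesh)

    with torch.no_grad():
        for batch in batch_iter:
            if max_eval_steps is not None and num_steps >= max_eval_steps:
                break
            batch_to_device(batch, device)
            labels = batch.pop("labels")
            total_loss += model_forward(batch, labels)  # scalar loss tensor
            num_steps += 1

    avg_loss = (total_loss / max(num_steps, 1)).item()
    record_metric("evaluate/val_loss", avg_loss, Reduce.MEAN)
    return avg_loss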