From 8141cef5042a99f63b6c2c584627cbde6636f11e Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 11:05:23 -0800 Subject: [PATCH 1/5] add trainer protocol --- src/forge/api/__init__.py | 7 ++ src/forge/api/trainer.py | 135 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 src/forge/api/__init__.py create mode 100644 src/forge/api/trainer.py diff --git a/src/forge/api/__init__.py b/src/forge/api/__init__.py new file mode 100644 index 000000000..78c794923 --- /dev/null +++ b/src/forge/api/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Forge library modules diff --git a/src/forge/api/trainer.py b/src/forge/api/trainer.py new file mode 100644 index 000000000..548d9a878 --- /dev/null +++ b/src/forge/api/trainer.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Trainer protocol. + +This file defines the unified training interface compatible +with all supported torchforge trainers. +""" + +from typing import Any, Protocol, runtime_checkable + +import torch + + +@runtime_checkable +class Trainer(Protocol): + """Protocol for all trainers in torchforge.""" + + async def accumulate_gradients( + self, microbatch: dict[str, torch.Tensor] + ) -> dict[str, Any]: + """Accumulate gradients from one microbatch. + + Does NOT clear gradients - they accumulate on top of existing. + Can be called multiple times before optim_step(). + + Returns: + dict with keys: + - loss: float + - metrics: dict[str, float] + """ + ... + + async def optim_step(self, params: dict[str, Any] | None = None) -> dict[str, Any]: + """Apply optimizer step and clear gradients after. + + Returns: + dict with keys: + - step: int + - learning_rate: float + - accumulated_microbatches: int + """ + ... + + async def clear_gradients(self) -> None: + """Clear accumulated gradients without applying.""" + ... + + async def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Run forward pass, no backward. + + Returns: + dict with key: + - logits: torch.Tensor + """ + ... + + async def forward_backward( + self, data: list[dict[str, torch.Tensor]] + ) -> dict[str, Any]: + """Clear first, then forward+backward on all items in data. + + Convenience wrapper equivalent to: + clear_gradients() + accumulate_gradients() for each item + + Does NOT call optim_step() - you must call it separately. + + Returns: + dict with keys: + - loss: float + - metrics: dict[str, float] + """ + ... + + async def save_state(self, name: str) -> dict[str, Any]: + """Save the checkpoint. + + Returns: + dict with keys: + - path: str + - step: int + """ + ... + + async def load_state(self, path: str) -> dict[str, Any]: + """Load checkpoint. + + Returns: + dict with keys: + - step: int + - learning_rate: float + """ + ... + + async def save_weights_for_sampler(self, name: str) -> dict[str, Any]: + """Export weights for inference. + + Returns: + dict with keys: + - path: str + - version: str or int + """ + ... + + async def get_info(self) -> dict[str, Any]: + """Get static model metadata. + + Returns: + dict with keys like: + - model_name: str + - step: int + - config: dict + """ + ... 
+ + async def get_status(self) -> dict[str, Any]: + """Get runtime status. + + Returns: + dict with keys like: + - step: int + - accumulated_microbatches: int + """ + ... + + def get_tokenizer(self): + """Get the tokenizer. + + Returns: + PreTrainedTokenizer + """ + ... From 189b24247a6cd6f5cc741fe9e37cd5aab0fe6e4b Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 11:30:37 -0800 Subject: [PATCH 2/5] remove some get_* i'm not convinced on yet --- src/forge/api/trainer.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/forge/api/trainer.py b/src/forge/api/trainer.py index 548d9a878..30e749363 100644 --- a/src/forge/api/trainer.py +++ b/src/forge/api/trainer.py @@ -105,27 +105,6 @@ async def save_weights_for_sampler(self, name: str) -> dict[str, Any]: """ ... - async def get_info(self) -> dict[str, Any]: - """Get static model metadata. - - Returns: - dict with keys like: - - model_name: str - - step: int - - config: dict - """ - ... - - async def get_status(self) -> dict[str, Any]: - """Get runtime status. - - Returns: - dict with keys like: - - step: int - - accumulated_microbatches: int - """ - ... - def get_tokenizer(self): """Get the tokenizer. From d89bc1bcd047dc5497e75b1b0a26ef31cd0c7875 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 14:23:08 -0800 Subject: [PATCH 3/5] bulk changes --- src/forge/api/__init__.py | 25 +++- src/forge/api/trainer.py | 306 +++++++++++++++++++++++++++++++------- src/forge/api/types.py | 172 +++++++++++++++++++++ 3 files changed, 452 insertions(+), 51 deletions(-) create mode 100644 src/forge/api/types.py diff --git a/src/forge/api/__init__.py b/src/forge/api/__init__.py index 78c794923..aafac4a54 100644 --- a/src/forge/api/__init__.py +++ b/src/forge/api/__init__.py @@ -4,4 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Forge library modules +"""Forge public API module. + +This module defines the public interfaces that all Forge implementations conform to. +""" + +from forge.api.trainer import Trainer +from forge.api.types import ( + ForwardResult, + OptimStepResult, + TextTrainBatch, + TrainerInfo, + TrainerStatus, + TrainResult, +) + +__all__ = [ + "Trainer", + "TextTrainBatch", + "TrainResult", + "OptimStepResult", + "ForwardResult", + "TrainerInfo", + "TrainerStatus", +] diff --git a/src/forge/api/trainer.py b/src/forge/api/trainer.py index 30e749363..cb8061ba7 100644 --- a/src/forge/api/trainer.py +++ b/src/forge/api/trainer.py @@ -4,111 +4,317 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Trainer protocol. +"""Trainer protocol for Forge. + +This module defines the unified training interface that all trainer implementations +must conform to. -This file defines the unified training interface compatible -with all supported torchforge trainers. """ from typing import Any, Protocol, runtime_checkable import torch +from forge.api.types import ( + ForwardResult, + OptimStepResult, + TextTrainBatch, + TrainerInfo, + TrainerStatus, + TrainResult, +) + @runtime_checkable class Trainer(Protocol): - """Protocol for all trainers in torchforge.""" + """Protocol defining the standard interface for all Forge trainers.""" - async def accumulate_gradients( - self, microbatch: dict[str, torch.Tensor] - ) -> dict[str, Any]: - """Accumulate gradients from one microbatch. 
+ async def forward_backward(self, batch: TextTrainBatch) -> TrainResult: + """Execute forward pass and backward pass for one batch of data. - Does NOT clear gradients - they accumulate on top of existing. - Can be called multiple times before optim_step(). + Basic usage - single batch per optimizer step: + >>> batch = TextTrainBatch( + >>> input_ids=torch.tensor([[1, 2, 3, 4, 5]]), + >>> target_ids=torch.tensor([[2, 3, 4, 5, 6]]), + >>> ) + >>> result = await trainer.forward_backward(batch) + >>> await trainer.optim_step() # Apply gradients + + To accumulate gradients over multiple batches before optimizer step: + >>> await trainer.forward_backward(batch1) # Accumulates + >>> await trainer.forward_backward(batch2) # Accumulates another batch + >>> await trainer.optim_step() # Apply all accumulated gradients + + Args: + batch: TextTrainBatch containing input_ids, target_ids, and optional + target_mask/target_weights. See forge.api.types.TextTrainBatch for details. Returns: - dict with keys: - - loss: float - - metrics: dict[str, float] + TrainResult containing loss and metrics + + Note: + The loss function is configured at trainer creation time via the + `loss` parameter, not passed to this method. """ ... - async def optim_step(self, params: dict[str, Any] | None = None) -> dict[str, Any]: - """Apply optimizer step and clear gradients after. + async def optim_step(self, params: dict[str, Any] | None = None) -> OptimStepResult: + """Apply optimizer step using accumulated gradients, then clear gradients. + + This method: + 1. Applies accumulated gradients via the optimizer + 2. Steps the learning rate scheduler + 3. Clears all gradients (zero_grad) + 4. Increments the training step counter + 5. May trigger automatic checkpointing (implementation-dependent) + + Gradients must have been accumulated via forward_backward() calls before + calling this method. + + Args: + params: Optional optimizer parameters. Currently reserved for future use. + Most implementations ignore this and use the optimizer config from + trainer initialization. Returns: - dict with keys: - - step: int - - learning_rate: float - - accumulated_microbatches: int + OptimStepResult containing step number, learning rate, and accumulated batch count + + Example: + >>> # Accumulate over 4 batches + >>> for batch in batches[:4]: + >>> await trainer.forward_backward(batch) + >>> result = await trainer.optim_step() + >>> print(f"Step {result.step}, LR {result.learning_rate:.2e}") + >>> print(f"Accumulated {result.accumulated_microbatches} batches") """ ... async def clear_gradients(self) -> None: - """Clear accumulated gradients without applying.""" + """Clear accumulated gradients without applying them. + + Use this when you need to discard accumulated gradients without performing + an optimizer step. Common scenarios: + - Exception during gradient accumulation + - Skipping a training step due to some condition + - Recovering from OOM or other errors + + This is equivalent to calling optimizer.zero_grad() and resetting internal + accumulation counters. 
+ + Example - Error recovery: + >>> try: + >>> for batch in batches: + >>> await trainer.forward_backward(batch) + >>> await trainer.optim_step() + >>> except torch.cuda.OutOfMemoryError: + >>> await trainer.clear_gradients() # Discard partial gradients + >>> # Retry with smaller batches + + Example - Conditional skip: + >>> await trainer.forward_backward(batch) + >>> if should_skip_step(): + >>> await trainer.clear_gradients() # Don't apply these gradients + >>> else: + >>> await trainer.optim_step() + """ ... - async def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """Run forward pass, no backward. + async def forward(self, inputs: dict[str, torch.Tensor]) -> ForwardResult: + """Run forward pass only, without backward pass (for evaluation/inference). + + This method executes the model's forward pass without computing gradients. + Useful for: + - Evaluation on validation/test data + - Getting model predictions/logits + - Debugging model outputs + + Args: + inputs: Dictionary containing model inputs. Typically includes: + - input_ids: torch.Tensor [batch_size, seq_len] + Other keys depend on the model architecture. Returns: - dict with key: - - logits: torch.Tensor + ForwardResult containing model logits + + Note: + This runs in torch.no_grad() context - no gradients are computed. + + Example: + >>> eval_batch = {"input_ids": torch.tensor([[1, 2, 3, 4]])} + >>> output = await trainer.forward(eval_batch) + >>> logits = output.logits # [1, 4, vocab_size] + >>> predictions = logits.argmax(dim=-1) # [1, 4] """ ... - async def forward_backward( - self, data: list[dict[str, torch.Tensor]] + async def save_state( + self, name: str | None = None, path: str | None = None ) -> dict[str, Any]: - """Clear first, then forward+backward on all items in data. + """Save a checkpoint of the current trainer state. - Convenience wrapper equivalent to: - clear_gradients() + accumulate_gradients() for each item + Saves the complete training state including model weights, optimizer state, + learning rate scheduler state, and current step counter. This checkpoint + can be loaded later to resume training from this exact point. - Does NOT call optim_step() - you must call it separately. + Args: + name: Optional checkpoint name/identifier. If None, uses the current + step number (e.g., "step-1000"). + path: Optional base directory or URI where checkpoint should be saved. + If None, uses the default checkpoint directory configured at trainer + creation. Supports different backends via URI schemes: + - `/local/path` - local filesystem + - `ts://key` - TorchStore + - `s3://bucket/key` - S3 + + Location resolution: + - Both provided: path/name (e.g., "/checkpoints" + "best" = "/checkpoints/best") + - Only path: use path directly + - Only name: default_dir/name + - Neither: default_dir/step-{current_step} Returns: - dict with keys: - - loss: float - - metrics: dict[str, float] + dict containing: + - path: str - Full path where checkpoint was saved + - step: int - Training step at which checkpoint was saved + + Example: + >>> # Save to default location with step number + >>> result = await trainer.save_state() # => /default/step-1000 + >>> + >>> # Save with custom name to default location + >>> result = await trainer.save_state("best-model") # => /default/best-model + >>> + >>> # Save to custom base directory + >>> result = await trainer.save_state("final", "/custom/checkpoints") + >>> # => /custom/checkpoints/final """ ... 
- async def save_state(self, name: str) -> dict[str, Any]: - """Save the checkpoint. + async def load_state(self, path: str | None = None) -> dict[str, Any]: + """Load a previously saved checkpoint. + + Restores the complete training state from a checkpoint, including model + weights, optimizer state, learning rate scheduler state, and step counter. + + Args: + path: Optional path or URI to the checkpoint to load. If None, loads + the most recent checkpoint from the default directory. Can be: + - `/local/path/checkpoint` - local filesystem + - `ts://key` - TorchStore + - `s3://bucket/key` - S3 + + Returns: + dict containing: + - step: int - Training step from the loaded checkpoint + - learning_rate: float - Learning rate from the loaded checkpoint + + Example: + >>> # Load latest checkpoint from default location + >>> result = await trainer.load_state() + >>> print(f"Resumed from step {result['step']}") + >>> + >>> # Load specific checkpoint by path + >>> result = await trainer.load_state("/checkpoints/step-5000") + >>> + >>> # Load from TorchStore + >>> result = await trainer.load_state("ts://checkpoint-key") + """ + ... + + async def save_weights( + self, name: str | None = None, path: str | None = None + ) -> dict[str, Any]: + """Save model weights only (without optimizer/scheduler state). + + Saves only the model weights in a format suitable for inference/sampling. + This is lighter weight than save_state() since it excludes training state + like optimizer and scheduler. + + Args: + name: Optional checkpoint name/identifier. If None, uses the current + step number (e.g., "weights-step-1000"). + path: Optional base directory or URI where weights should be saved. + If None, uses the default location configured at trainer creation. + Supports different backends via URI schemes: + - `/local/path` - local filesystem + - `ts://key` - TorchStore + - `s3://bucket/key` - S3 + + Location resolution: + - Both provided: path/name + - Only path: use path directly + - Only name: default_dir/name + - Neither: default_dir/step-{current_step} Returns: - dict with keys: - - path: str - - step: int + dict containing: + - path: str - Full URI where weights were saved + - version: str | int - The name/version that was saved + + Example: + >>> # Save to default location with step number + >>> result = await trainer.save_weights() + >>> + >>> # Save to TorchStore for inference server + >>> result = await trainer.save_weights("policy-v1", "ts://policy-weights") + >>> # → ts://policy-weights/policy-v1 + >>> + >>> # Save to S3 + >>> result = await trainer.save_weights(path="s3://bucket/models/final") """ ... - async def load_state(self, path: str) -> dict[str, Any]: - """Load checkpoint. + async def get_info(self) -> TrainerInfo: + """Get static trainer and model metadata. + + Returns information about the trainer configuration and model architecture + that doesn't change during training. Returns: - dict with keys: - - step: int - - learning_rate: float + TrainerInfo containing model name, step, config, and parallelism settings + + Example: + >>> info = await trainer.get_info() + >>> print(f"Training {info.model_name} at step {info.step}") + >>> print(f"Vocab size: {info.config['vocab_size']}") + >>> print(f"Data parallel degree: {info.parallelism['dp_degree']}") """ ... - async def save_weights_for_sampler(self, name: str) -> dict[str, Any]: - """Export weights for inference. + async def get_status(self) -> TrainerStatus: + """Get current runtime status of the trainer. 
+ + Returns dynamic information about the trainer's current state that changes + during training. Returns: - dict with keys: - - path: str - - version: str or int + TrainerStatus containing current step and accumulated batch count + + Example: + >>> status = await trainer.get_status() + >>> print(f"Current step: {status.step}") + >>> if status.accumulated_microbatches > 0: + >>> print(f"Warning: {status.accumulated_microbatches} " + >>> f"batches accumulated without optimizer step") """ ... def get_tokenizer(self): - """Get the tokenizer. + """Get the tokenizer associated with this model. + + Returns the tokenizer used for encoding/decoding text with this model. + Useful for preprocessing inputs or decoding model outputs. Returns: - PreTrainedTokenizer + PreTrainedTokenizer: The HuggingFace tokenizer for this model + + Note: + This is a synchronous method (not async) since tokenizer access is + typically fast and doesn't require remote calls. + + Example: + >>> tokenizer = trainer.get_tokenizer() + >>> tokens = tokenizer.encode("Hello world") + >>> text = tokenizer.decode([1, 2, 3, 4]) """ ... diff --git a/src/forge/api/types.py b/src/forge/api/types.py new file mode 100644 index 000000000..31bdaf4e5 --- /dev/null +++ b/src/forge/api/types.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Type definitions for the Forge API.""" + +from dataclasses import dataclass +from typing import Any + +import torch + + +@dataclass +class TextTrainBatch: + """A batch of text training data for forward_backward. + + This dataclass defines the standard format for text training batches across all + Forge text trainers. + + Attributes: + input_ids: Input token IDs. Shape: [batch_size, seq_len] + target_ids: Target token IDs for loss computation. Shape: [batch_size, seq_len] + target_mask: Mask indicating which tokens to compute loss on. + Shape: [batch_size, seq_len]. Values are 0 (ignore) or 1 (compute loss). + If None, computes loss on all tokens. + target_weights: Per-token weights for loss computation. + Shape: [batch_size, seq_len]. Used for importance weighting, such as + advantages in RL (GRPO, PPO) or custom loss weighting schemes. + If None, all tokens have weight 1.0. + + Example: + >>> batch = TextTrainBatch( + >>> input_ids=torch.tensor([[1, 2, 3, 4, 5]]), + >>> target_ids=torch.tensor([[2, 3, 4, 5, 6]]), + >>> target_mask=torch.tensor([[0, 0, 1, 1, 1]]), # Only predict last 3 tokens + >>> target_weights=torch.tensor([[0, 0, 1.0, 0.8, 1.2]]), # Weight by advantage + >>> ) + >>> result = await trainer.forward_backward(batch) + """ + + input_ids: torch.Tensor + target_ids: torch.Tensor + target_mask: torch.Tensor | None = None + target_weights: torch.Tensor | None = None + + +@dataclass +class TrainResult: + """Result from a forward_backward pass. + + Attributes: + loss: Loss value computed for the batch + metrics: Additional metrics computed during training (e.g., perplexity, + accuracy). May be empty if no additional metrics are tracked. + + Example: + >>> result = await trainer.forward_backward(batch) + >>> print(f"Loss: {result.loss:.4f}") + >>> if result.metrics: + >>> print(f"Metrics: {result.metrics}") + """ + + loss: float + metrics: dict[str, float] + + +@dataclass +class OptimStepResult: + """Result from an optimizer step. 
+ + Attributes: + step: Training step number after this optimizer step + learning_rate: Current learning rate used for this step + accumulated_microbatches: Number of forward_backward calls that were + accumulated before this optimizer step. Useful for tracking gradient + accumulation behavior. + + Example: + >>> result = await trainer.optim_step() + >>> print(f"Step {result.step}, LR={result.learning_rate:.2e}") + >>> print(f"Accumulated {result.accumulated_microbatches} batches") + """ + + step: int + learning_rate: float + accumulated_microbatches: int + + +@dataclass +class ForwardResult: + """Result from a forward pass (evaluation/inference). + + Attributes: + logits: Model output logits (pre-softmax). Shape: [batch_size, seq_len, vocab_size] + + Example: + >>> result = await trainer.forward(eval_batch) + >>> predictions = result.logits.argmax(dim=-1) # [batch_size, seq_len] + """ + + logits: torch.Tensor + + +@dataclass +class TrainerInfo: + """Static trainer and model metadata. + + This contains information about the trainer configuration and model architecture + that doesn't change during training. + + Note: + The exact format of `config` and `parallelism` dicts depends on the underlying + trainer implementation (TorchTitan, HuggingFace, etc.). The fields below + document common keys, but implementations may include additional fields. + + Attributes: + model_name: Name or path of the model being trained + step: Current training step + config: Model configuration. Common keys include: + - vocab_size: int - Size of the vocabulary + - hidden_size: int - Hidden dimension size + - num_layers: int - Number of transformer layers + - num_attention_heads: int - Number of attention heads + - max_seq_len: int - Maximum sequence length + parallelism: Parallelism configuration. Common keys include: + - dp_degree: int - Data parallel degree + - tp_degree: int - Tensor parallel degree + - pp_degree: int - Pipeline parallel degree + - dp_rank: int - Current data parallel rank + - tp_rank: int - Current tensor parallel rank + - device: str - Device identifier (e.g., "cuda:0") + - gradient_accumulation_steps: int - Number of microbatches per step + + Example: + >>> info = await trainer.get_info() + >>> print(f"Training {info.model_name} at step {info.step}") + >>> print(f"Vocab size: {info.config['vocab_size']}") + >>> print(f"DP={info.parallelism['dp_degree']}, " + >>> f"TP={info.parallelism['tp_degree']}") + >>> print(f"Device: {info.parallelism['device']}") + """ + + model_name: str + step: int + config: dict[str, Any] + parallelism: dict[str, Any] + + +@dataclass +class TrainerStatus: + """Runtime status of the trainer. + + This contains dynamic information about the trainer's current state that + changes during training. + + Attributes: + step: Current training step + accumulated_microbatches: Number of batches accumulated since the last + optim_step. Will be 0 if gradients were just applied/cleared. 
+ + Example: + >>> status = await trainer.get_status() + >>> print(f"Current step: {status.step}") + >>> if status.accumulated_microbatches > 0: + >>> print(f"Warning: {status.accumulated_microbatches} batches " + >>> f"accumulated without optimizer step") + """ + + step: int + accumulated_microbatches: int From 83f3a8f1a32b06cea7aad47b203192b528522504 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 14:31:06 -0800 Subject: [PATCH 4/5] forwardbackwardresult --- src/forge/api/__init__.py | 4 ++-- src/forge/api/trainer.py | 6 +++--- src/forge/api/types.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/forge/api/__init__.py b/src/forge/api/__init__.py index aafac4a54..777f68c61 100644 --- a/src/forge/api/__init__.py +++ b/src/forge/api/__init__.py @@ -11,18 +11,18 @@ from forge.api.trainer import Trainer from forge.api.types import ( + ForwardBackwardResult, ForwardResult, OptimStepResult, TextTrainBatch, TrainerInfo, TrainerStatus, - TrainResult, ) __all__ = [ "Trainer", "TextTrainBatch", - "TrainResult", + "ForwardBackwardResult", "OptimStepResult", "ForwardResult", "TrainerInfo", diff --git a/src/forge/api/trainer.py b/src/forge/api/trainer.py index cb8061ba7..3d7952f52 100644 --- a/src/forge/api/trainer.py +++ b/src/forge/api/trainer.py @@ -16,12 +16,12 @@ import torch from forge.api.types import ( + ForwardBackwardResult, ForwardResult, OptimStepResult, TextTrainBatch, TrainerInfo, TrainerStatus, - TrainResult, ) @@ -29,7 +29,7 @@ class Trainer(Protocol): """Protocol defining the standard interface for all Forge trainers.""" - async def forward_backward(self, batch: TextTrainBatch) -> TrainResult: + async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult: """Execute forward pass and backward pass for one batch of data. Basic usage - single batch per optimizer step: @@ -50,7 +50,7 @@ async def forward_backward(self, batch: TextTrainBatch) -> TrainResult: target_mask/target_weights. See forge.api.types.TextTrainBatch for details. Returns: - TrainResult containing loss and metrics + ForwardBackwardResult containing loss and metrics Note: The loss function is configured at trainer creation time via the diff --git a/src/forge/api/types.py b/src/forge/api/types.py index 31bdaf4e5..f8c814b6c 100644 --- a/src/forge/api/types.py +++ b/src/forge/api/types.py @@ -47,7 +47,7 @@ class TextTrainBatch: @dataclass -class TrainResult: +class ForwardBackwardResult: """Result from a forward_backward pass. 
Attributes: From 4d11e4eacba2193375b93868f10cb093fa774635 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 14:40:07 -0800 Subject: [PATCH 5/5] add custom loss --- src/forge/api/__init__.py | 2 ++ src/forge/api/trainer.py | 19 ++++++++++++++++--- src/forge/api/types.py | 6 +++++- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/forge/api/__init__.py b/src/forge/api/__init__.py index 777f68c61..1b04d01a9 100644 --- a/src/forge/api/__init__.py +++ b/src/forge/api/__init__.py @@ -13,6 +13,7 @@ from forge.api.types import ( ForwardBackwardResult, ForwardResult, + LossFn, OptimStepResult, TextTrainBatch, TrainerInfo, @@ -27,4 +28,5 @@ "ForwardResult", "TrainerInfo", "TrainerStatus", + "LossFn", ] diff --git a/src/forge/api/trainer.py b/src/forge/api/trainer.py index 3d7952f52..34dd13aec 100644 --- a/src/forge/api/trainer.py +++ b/src/forge/api/trainer.py @@ -18,6 +18,7 @@ from forge.api.types import ( ForwardBackwardResult, ForwardResult, + LossFn, OptimStepResult, TextTrainBatch, TrainerInfo, @@ -29,7 +30,9 @@ class Trainer(Protocol): """Protocol defining the standard interface for all Forge trainers.""" - async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult: + async def forward_backward( + self, batch: TextTrainBatch, loss_fn: LossFn | None = None + ) -> ForwardBackwardResult: """Execute forward pass and backward pass for one batch of data. Basic usage - single batch per optimizer step: @@ -45,16 +48,26 @@ async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult >>> await trainer.forward_backward(batch2) # Accumulates another batch >>> await trainer.optim_step() # Apply all accumulated gradients + Custom loss function for specific batches: + >>> def custom_loss(logits: torch.Tensor, batch: TextTrainBatch) -> torch.Tensor: + >>> # Custom loss computation (e.g., PPO clip, DPO, etc.) + >>> return loss + >>> + >>> result = await trainer.forward_backward(batch, loss_fn=custom_loss) + Args: batch: TextTrainBatch containing input_ids, target_ids, and optional target_mask/target_weights. See forge.api.types.TextTrainBatch for details. + loss_fn: Optional custom loss function. If None, uses the loss function + configured at trainer creation. Signature: (logits, batch) -> loss. + Useful for mixed training objectives or experimentation. Returns: ForwardBackwardResult containing loss and metrics Note: - The loss function is configured at trainer creation time via the - `loss` parameter, not passed to this method. + The default loss function is configured at trainer creation time via the + `loss` parameter. The `loss_fn` parameter here allows per-batch override. """ ... diff --git a/src/forge/api/types.py b/src/forge/api/types.py index f8c814b6c..8ae24e4d4 100644 --- a/src/forge/api/types.py +++ b/src/forge/api/types.py @@ -7,11 +7,15 @@ """Type definitions for the Forge API.""" from dataclasses import dataclass -from typing import Any +from typing import Any, Callable, TypeAlias import torch +# Loss function signature: takes logits and batch, returns scalar loss +LossFn: TypeAlias = Callable[[torch.Tensor, "TextTrainBatch"], torch.Tensor] + + @dataclass class TextTrainBatch: """A batch of text training data for forward_backward.
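
The patches above define only the protocol and its result types; they do not include a concrete trainer or an end-to-end caller. As a point of reference, the sketch below shows how the final API (after PATCH 5/5) is intended to be driven: a custom LossFn computing masked, advantage-weighted cross-entropy over the TextTrainBatch format, gradient accumulation across several forward_backward() calls, an optim_step() every few microbatches, and a weight export at the end. This is a minimal illustration written against the protocol as defined in this series, not code from the patches themselves: the trainer instance is assumed to come from some concrete implementation elsewhere in torchforge, and the names weighted_ce_loss, train_epoch, and "policy-v1" are illustrative only.

import torch
import torch.nn.functional as F

from forge.api import TextTrainBatch, Trainer


def weighted_ce_loss(logits: torch.Tensor, batch: TextTrainBatch) -> torch.Tensor:
    # Per-token cross-entropy over [batch_size, seq_len, vocab_size] logits,
    # respecting the optional target_mask / target_weights fields of TextTrainBatch.
    vocab_size = logits.size(-1)
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        batch.target_ids.reshape(-1),
        reduction="none",
    ).reshape(batch.target_ids.shape)
    mask = (
        batch.target_mask.to(per_token.dtype)
        if batch.target_mask is not None
        else torch.ones_like(per_token)
    )
    weights = (
        batch.target_weights.to(per_token.dtype)
        if batch.target_weights is not None
        else torch.ones_like(per_token)
    )
    # Normalize by the number of unmasked tokens so the loss scale stays stable
    # across variable-length batches (an illustrative choice, not mandated by the API).
    return (per_token * mask * weights).sum() / mask.sum().clamp(min=1.0)


async def train_epoch(
    trainer: Trainer,
    batches: list[TextTrainBatch],
    accumulation_steps: int = 4,
) -> None:
    # Trainer is @runtime_checkable, so a structural isinstance check is possible.
    assert isinstance(trainer, Trainer)

    for i, batch in enumerate(batches, start=1):
        # Gradients accumulate across calls until optim_step() is invoked.
        result = await trainer.forward_backward(batch, loss_fn=weighted_ce_loss)
        print(f"microbatch {i}: loss={result.loss:.4f}")

        if i % accumulation_steps == 0:
            step = await trainer.optim_step()
            print(
                f"step {step.step}: lr={step.learning_rate:.2e}, "
                f"accumulated={step.accumulated_microbatches}"
            )

    # Discard any leftover partial accumulation rather than carrying it over.
    await trainer.clear_gradients()

    # Export inference-ready weights; the name is illustrative.
    saved = await trainer.save_weights(name="policy-v1")
    print(f"saved weights to {saved['path']}")

Whether leftover microbatches at the end of an epoch should be applied or discarded is a policy decision for the caller; clear_gradients() is shown here only to mirror the protocol's own guidance on dropping unapplied gradients.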