From 498d3df28cfa332702bbddd7ef243864d642433c Mon Sep 17 00:00:00 2001 From: Angky William Date: Thu, 13 Nov 2025 13:11:03 -0800 Subject: [PATCH 01/18] SFT data iterator --- src/art/utils/iterate_dataset.py | 195 ++++++++++++++++++++++++++++++- 1 file changed, 194 insertions(+), 1 deletion(-) diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index fda51c41..c8a86b2a 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -1,7 +1,8 @@ +import json import math import random from dataclasses import dataclass -from typing import Generator, Generic, List, TypeVar +from typing import Any, Generator, Generic, Iterable, List, TypeVar from tqdm.auto import tqdm @@ -92,3 +93,195 @@ def iterate_dataset( if progress_bar: progress_bar.close() + + +def get_file_row_count(file_path: str) -> int: + """ + Count the number of non-empty rows in a JSONL file. + + Args: + file_path: Path to JSONL file + + Returns: + Number of non-empty lines in the file + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + count = get_file_row_count("data.jsonl") + print(f"Dataset has {count} items") + """ + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + count = 0 + with open(file_path, "r") as f: + for line in f: + if line.strip(): + count += 1 + return count + + +def iterate_trajectories( + trajectories: List["Trajectory"], epochs: int +) -> Generator["Trajectory", None, None]: + """ + Iterate over a list of trajectories for multiple epochs. 
+ + Args: + trajectories: List of Trajectory objects + epochs: Number of times to iterate over the list + + Yields: + Trajectory objects from the list + + Example: + # Load trajectories once + trajs = [traj1, traj2, traj3] + + # Iterate 3 times + for traj in iterate_trajectories(trajs, epochs=3): + # Process trajectory + pass + """ + for _ in range(epochs): + for trajectory in trajectories: + yield trajectory + + +def iterate_file(file_path: str, epochs: int) -> Generator["Trajectory", None, None]: + """ + Read JSONL file for each epoch, yielding Trajectory objects. + + Each line should contain a dict with: + - messages: List of chat messages + - tools: Optional list of tools + - reward: Optional reward (defaults to default_reward) + - split: Optional split name (stored in metadata) + - Any other fields will be stored in metadata + + Args: + file_path: Path to JSONL file (one JSON object per line) + epochs: Number of times to read through the file + default_reward: Default reward value if not specified in data + + Yields: + Trajectory objects parsed from the file + + Raises: + ValueError: If file_path does not end with .jsonl + """ + from art.trajectories import Trajectory + + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + for _ in range(epochs): + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + + # Extract messages and convert to messages_and_choices format + messages = data.get("messages", []) + tools = data.get("tools", None) + + # Create trajectory + yield Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + ) + + +def chunk_trajectories( + trajectories: Iterable["Trajectory"], + batch_size: int, + chunk_size: int, + shuffle_buffer_size: int = 10000, + seed: int | None = None, +) -> Generator[List["Trajectory"], None, None]: + """ + Chunk trajectories from an iterable into batches. 
+ + Args: + trajectories: Iterable of Trajectory objects (can be list, generator, etc.) + batch_size: Number of chunks per batch + chunk_size: Number of trajectories per chunk + shuffle_buffer_size: Size of shuffle buffer. Default: 10000 (~200MB-1GB). + Set to 0 for no shuffle (sequential order). + Recommended: 1000-50000 depending on available RAM. + Larger buffer = better shuffle quality but more memory. + seed: Random seed for deterministic shuffling. Default: None (non-deterministic). + Set to an integer for reproducible results. + + Yields: + List of trajectories (batch_size * chunk_size items) + + Example: + # Default shuffle (buffer_size=10000, random) + chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8) + + # Deterministic shuffle (reproducible) + chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8, seed=42) + + # No shuffle + chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8, shuffle_buffer_size=0) + + # Larger buffer for better shuffle + chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8, shuffle_buffer_size=50000, seed=42) + """ + items_per_batch = batch_size * chunk_size + + if shuffle_buffer_size > 0: + # Set seed for deterministic shuffling + if seed is not None: + random.seed(seed) + + # Buffer-based shuffle + shuffle_buffer: List["Trajectory"] = [] + batch_items = [] + + for trajectory in trajectories: + shuffle_buffer.append(trajectory) + + # Once buffer is full, start yielding + if len(shuffle_buffer) >= shuffle_buffer_size: + # Pop random item from buffer + idx = random.randint(0, len(shuffle_buffer) - 1) + traj = shuffle_buffer.pop(idx) + + batch_items.append(traj) + + if len(batch_items) == items_per_batch: + yield batch_items + batch_items = [] + + # Flush remaining items in shuffle buffer + random.shuffle(shuffle_buffer) + for traj in shuffle_buffer: + batch_items.append(traj) + + if len(batch_items) == items_per_batch: + yield batch_items + batch_items = [] + + # Yield any remaining items as a 
final batch + if batch_items: + yield batch_items + else: + # No shuffle - simple batching + batch_items = [] + for trajectory in trajectories: + batch_items.append(trajectory) + + if len(batch_items) == items_per_batch: + yield batch_items + batch_items = [] + + # Yield any remaining items as a final batch + if batch_items: + yield batch_items From 3bd818f44be6eb7617bd3b9ff9f7ffae1b4c84b6 Mon Sep 17 00:00:00 2001 From: Angky William Date: Thu, 13 Nov 2025 16:14:32 -0800 Subject: [PATCH 02/18] Add SFT LR utils --- src/art/utils/sft.py | 92 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 src/art/utils/sft.py diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py new file mode 100644 index 00000000..4ec39528 --- /dev/null +++ b/src/art/utils/sft.py @@ -0,0 +1,92 @@ +"""Utilities for supervised fine-tuning (SFT).""" + +import math +from typing import Generator, List, Literal + + +def create_lr_schedule( + total_steps: int, + peak_lr: float, + method: Literal["cosine", "linear", "constant"] = "cosine", + warmup_steps: int = 0, + min_lr: float = 0.0, +) -> List[float]: + """ + Create learning rate schedule for training with optional warmup. + + Args: + total_steps: Total number of training steps + peak_lr: Peak learning rate + method: Learning rate schedule method. 
Options: + - "cosine": Cosine annealing from peak_lr to min_lr + - "linear": Linear decay from peak_lr to min_lr + - "constant": Constant learning rate (peak_lr for all steps) + warmup_steps: Number of warmup steps (linear warmup from 0 to peak_lr) + min_lr: Minimum learning rate (floor for decay schedules) + + Returns: + List of learning rates for each step + + Example: + # Cosine schedule with warmup + lrs = create_lr_schedule(100, 1e-4, method="cosine", warmup_steps=10) + + # Use with training loop + for step, chunk in enumerate(chunk_trajectories(...)): + train_sft(chunk, learning_rate=lrs[step]) + """ + learning_rates = [] + + for step in range(total_steps): + # Warmup phase: linear warmup from 0 to peak_lr + if step < warmup_steps: + lr = peak_lr * (step / warmup_steps) + else: + # Main schedule phase + # Adjust step to be relative to post-warmup period + adjusted_step = step - warmup_steps + adjusted_total = total_steps - warmup_steps + + if method == "cosine": + # Cosine annealing: lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + cos(pi * t)) + lr = min_lr + (peak_lr - min_lr) * 0.5 * ( + 1 + math.cos(math.pi * adjusted_step / adjusted_total) + ) + elif method == "linear": + # Linear decay: lr = peak_lr - (peak_lr - min_lr) * (t / total) + lr = peak_lr - (peak_lr - min_lr) * (adjusted_step / adjusted_total) + elif method == "constant": + # Constant learning rate + lr = peak_lr + else: + raise ValueError( + f"Unknown method: {method}. Choose from: cosine, linear, constant" + ) + + learning_rates.append(lr) + + return learning_rates + + +def chunk_learning_rate( + learning_rates: List[float], + chunk_size: int, +) -> Generator[List[float], None, None]: + """ + Chunk a list of learning rates into groups. 
+ + Args: + learning_rates: List of learning rate values + chunk_size: Number of learning rates per chunk + + Yields: + List of learning rates (chunk_size items, last chunk may be smaller) + + Example: + lrs = create_lr_schedule(10, 1e-4) + for lr_chunk in chunk_learning_rate(lrs, chunk_size=3): + # lr_chunk has 3 learning rates (or fewer for last chunk) + print(lr_chunk) # [1e-5, 2e-5, 3e-5] + """ + for i in range(0, len(learning_rates), chunk_size): + yield learning_rates[i : i + chunk_size] From 66ec62074121b5035003b77983a7c90788078fa6 Mon Sep 17 00:00:00 2001 From: Angky William Date: Thu, 13 Nov 2025 18:11:53 -0800 Subject: [PATCH 03/18] train_sft skeleton --- src/art/backend.py | 21 ++++++++++++++++++--- src/art/dev/__init__.py | 3 ++- src/art/dev/train.py | 5 +++++ src/art/local/backend.py | 19 +++++++++++++++++-- src/art/model.py | 24 +++++++++++++++++++++++- src/art/serverless/backend.py | 21 ++++++++++++++++++--- src/art/types.py | 7 ++++++- 7 files changed, 89 insertions(+), 11 deletions(-) diff --git a/src/art/backend.py b/src/art/backend.py index 9fa95c0e..473681a0 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, AsyncIterator, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal import httpx from tqdm import auto as tqdm @@ -8,8 +8,8 @@ from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider from . 
import dev -from .trajectories import TrajectoryGroup -from .types import TrainConfig +from .trajectories import Trajectory, TrajectoryGroup +from .types import SFTConfig, TrainConfig if TYPE_CHECKING: from .model import Model, TrainableModel @@ -126,6 +126,21 @@ async def _train_model( if pbar is not None: pbar.close() + async def _train_sft( + self, + model: "TrainableModel", + trajectories: Iterable[Trajectory], + config: SFTConfig, + dev_config: dev.SFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError( + "SFT training is not yet implemented. " + "This method will be available in a future release." + ) + # This yield is unreachable but makes this an async generator + yield # type: ignore + # ------------------------------------------------------------------ # Experimental support for S3 # ------------------------------------------------------------------ diff --git a/src/art/dev/__init__.py b/src/art/dev/__init__.py index b60525d9..6257135f 100644 --- a/src/art/dev/__init__.py +++ b/src/art/dev/__init__.py @@ -7,7 +7,7 @@ ) from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config from .torchtune import TorchtuneArgs -from .train import TrainConfig +from .train import SFTConfig, TrainConfig __all__ = [ "EngineArgs", @@ -18,6 +18,7 @@ "get_openai_server_config", "OpenAIServerConfig", "ServerArgs", + "SFTConfig", "TorchtuneArgs", "TrainConfig", ] diff --git a/src/art/dev/train.py b/src/art/dev/train.py index f6491b15..6a540a9f 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -22,3 +22,8 @@ class TrainConfig(TypedDict, total=False): scale_learning_rate_by_reward_std_dev: bool scale_rewards: bool truncated_importance_sampling: float | None + + +class SFTConfig(TypedDict, total=False): + """Experimental SFT configuration options. 
Use at your own risk.""" + pass diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 13a906b4..ef1e2e3a 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -5,7 +5,7 @@ import subprocess from datetime import datetime from types import TracebackType -from typing import AsyncIterator, Literal, cast +from typing import AsyncIterator, Iterable, Literal, cast import aiohttp import numpy as np @@ -54,7 +54,7 @@ ) from ..preprocessing.tokenize import tokenize_trajectory_groups from ..trajectories import Trajectory, TrajectoryGroup -from ..types import Message, TrainConfig +from ..types import Message, SFTConfig, TrainConfig from ..utils import format_message, get_model_step from .checkpoints import ( delete_checkpoints, @@ -521,6 +521,21 @@ async def _train_model( if verbose: print("_train_model complete") + async def _train_sft( + self, + model: TrainableModel, + trajectories: Iterable[Trajectory], + config: SFTConfig, + dev_config: dev.SFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError( + "SFT training is not yet implemented for LocalBackend. " + "Please use the Backend HTTP API or implement this method." + ) + # This yield is unreachable but makes this an async generator + yield # type: ignore + def _get_reward_std_dev_learning_rate_multiplier( self, model: TrainableModel ) -> float: diff --git a/src/art/model.py b/src/art/model.py index 43c519b2..4593a8b6 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -7,7 +7,7 @@ from . 
import dev from .trajectories import Trajectory, TrajectoryGroup -from .types import TrainConfig +from .types import SFTConfig, TrainConfig if TYPE_CHECKING: from art.backend import Backend @@ -386,3 +386,25 @@ async def train( self, list(trajectory_groups), config, _config or {}, verbose ): pass + + async def train_sft( + self, + trajectories: Iterable[Trajectory], + config: SFTConfig, + _config: dev.SFTConfig | None = None, + verbose: bool = False, + ) -> None: + """ + Supervised fine-tune the model with trajectories and per-batch learning rates. + + Args: + trajectories: An iterable of Trajectory objects. + config: SFT configuration including learning_rates and batch_size. + _config: Additional experimental configuration that is subject to change and + not yet part of the public API. Use at your own risk. + verbose: Whether to print verbose output. + """ + async for _ in self.backend()._train_sft( + self, trajectories, config, _config or {}, verbose + ): + pass diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 604faea5..a07ae789 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, AsyncIterator, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal from openai._types import NOT_GIVEN from tqdm import auto as tqdm @@ -9,8 +9,8 @@ from .. 
import dev from ..backend import Backend -from ..trajectories import TrajectoryGroup -from ..types import TrainConfig +from ..trajectories import Trajectory, TrajectoryGroup +from ..types import SFTConfig, TrainConfig if TYPE_CHECKING: from ..model import Model, TrainableModel @@ -159,6 +159,21 @@ async def _train_model( raise RuntimeError(f"Training job failed: {error_message}") after = event.id + async def _train_sft( + self, + model: "TrainableModel", + trajectories: Iterable[Trajectory], + config: SFTConfig, + dev_config: dev.SFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError( + "SFT training is not yet implemented for ServerlessBackend. " + "Please use the Backend HTTP API or implement this method." + ) + # This yield is unreachable but makes this an async generator + yield # type: ignore + # ------------------------------------------------------------------ # Experimental support for S3 # ------------------------------------------------------------------ diff --git a/src/art/types.py b/src/art/types.py index fd1bb272..89d3ced3 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Iterable, Literal import pydantic from openai.types.chat.chat_completion import Choice @@ -17,4 +17,9 @@ class TrainConfig(pydantic.BaseModel): beta: float = 0.0 +class SFTConfig(pydantic.BaseModel): + learning_rates: Iterable[float] + batch_size: int + + Verbosity = Literal[0, 1, 2] From 4aeda2fcb41d38642a80f17d34573d4903af7d06 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 14 Nov 2025 15:05:51 -0800 Subject: [PATCH 04/18] SFT Shape 0.1 --- src/art/types.py | 2 +- src/art/utils/iterate_dataset.py | 177 ++++++++++++++----------------- src/art/utils/sft.py | 77 ++++++++++++-- 3 files changed, 149 insertions(+), 107 deletions(-) diff --git a/src/art/types.py b/src/art/types.py index 89d3ced3..f1e345aa 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -18,7 
+18,7 @@ class TrainConfig(pydantic.BaseModel): class SFTConfig(pydantic.BaseModel): - learning_rates: Iterable[float] + learning_rate: float | Iterable[float] batch_size: int diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index c8a86b2a..07dddfcf 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -123,31 +123,98 @@ def get_file_row_count(file_path: str) -> int: return count +def get_total_steps(traj_len: int, epochs: int, batch_size: int) -> int: + """ + Calculate total number of training steps given dataset size, epochs, and batch size. + + Args: + traj_len: Number of trajectories in the dataset + epochs: Number of epochs to train + batch_size: Number of trajectories per batch/step + + Returns: + Total number of training steps + + Example: + # 100 trajectories, 3 epochs, batch size of 10 + total_steps = get_total_steps(100, 3, 10) + # Returns 30 (10 steps per epoch * 3 epochs) + + # With partial batch at end + total_steps = get_total_steps(105, 3, 10) + # Returns 33 (11 steps per epoch * 3 epochs) + """ + steps_per_epoch = math.ceil(traj_len / batch_size) + return steps_per_epoch * epochs + + def iterate_trajectories( - trajectories: List["Trajectory"], epochs: int -) -> Generator["Trajectory", None, None]: + trajectories: List["Trajectory"], + epochs: int, + batch_size: int, + chunk_size: int = 1, + initial_step: int = 0, +) -> Generator[List["Trajectory"], None, None]: """ - Iterate over a list of trajectories for multiple epochs. + Iterate over a list of trajectories for multiple epochs, yielding batches. + Shuffles trajectories at the start of each epoch with a fixed seed for reproducibility. Args: trajectories: List of Trajectory objects epochs: Number of times to iterate over the list + batch_size: Number of chunks per batch + chunk_size: Number of trajectories per chunk. Defaults to 1. + initial_step: The global step number to start from. Defaults to 0. + Useful for resuming training. 
Yields: - Trajectory objects from the list + List of trajectories (batch_size * chunk_size items) Example: # Load trajectories once trajs = [traj1, traj2, traj3] - # Iterate 3 times - for traj in iterate_trajectories(trajs, epochs=3): - # Process trajectory + # Iterate 3 epochs, 2 trajectories per batch + for batch in iterate_trajectories(trajs, epochs=3, batch_size=2): + # batch is a list of 2 trajectories + train_sft(batch, ...) + + # With chunk_size + for batch in iterate_trajectories(trajs, epochs=3, batch_size=4, chunk_size=5): + # batch is a list of 20 trajectories (4 chunks * 5 per chunk) + pass + + # Resume from step 10 + for batch in iterate_trajectories(trajs, epochs=3, batch_size=2, initial_step=10): + # Skips first 10 batches, starts from step 10 pass """ - for _ in range(epochs): - for trajectory in trajectories: - yield trajectory + + dataset_size = len(trajectories) + if dataset_size == 0: + return + + items_per_step = batch_size * chunk_size + steps_per_epoch = math.ceil(dataset_size / items_per_step) + + for epoch in range(epochs): + # Create indices and shuffle deterministically based on epoch + indices = list(range(dataset_size)) + random.seed(epoch) + random.shuffle(indices) + + for i in range(0, dataset_size, items_per_step): + batch_index = i // items_per_step + # Calculate global step number + global_step = epoch * steps_per_epoch + batch_index + + # Skip if before initial_step + if global_step < initial_step: + continue + + batch_indices = indices[i : i + items_per_step] + batch_items = [trajectories[idx] for idx in batch_indices] + yield batch_items def iterate_file(file_path: str, epochs: int) -> Generator["Trajectory", None, None]: @@ -195,93 +262,3 @@ def iterate_file(file_path: str, epochs: int) -> Generator["Trajectory", None, N tools=tools if tools else None, reward=0.0 ) - - -def chunk_trajectories( - trajectories: Iterable["Trajectory"], - batch_size: int, - chunk_size: int, - shuffle_buffer_size: int = 10000, - seed: int | None = 
None, -) -> Generator[List["Trajectory"], None, None]: - """ - Chunk trajectories from an iterable into batches. - - Args: - trajectories: Iterable of Trajectory objects (can be list, generator, etc.) - batch_size: Number of chunks per batch - chunk_size: Number of trajectories per chunk - shuffle_buffer_size: Size of shuffle buffer. Default: 10000 (~200MB-1GB). - Set to 0 for no shuffle (sequential order). - Recommended: 1000-50000 depending on available RAM. - Larger buffer = better shuffle quality but more memory. - seed: Random seed for deterministic shuffling. Default: None (non-deterministic). - Set to an integer for reproducible results. - - Yields: - List of trajectories (batch_size * chunk_size items) - - Example: - # Default shuffle (buffer_size=10000, random) - chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8) - - # Deterministic shuffle (reproducible) - chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8, seed=42) - - # No shuffle - chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8, shuffle_buffer_size=0) - - # Larger buffer for better shuffle - chunk_trajectories(iterate_file("data.jsonl", epochs=1), 4, 8, shuffle_buffer_size=50000, seed=42) - """ - items_per_batch = batch_size * chunk_size - - if shuffle_buffer_size > 0: - # Set seed for deterministic shuffling - if seed is not None: - random.seed(seed) - - # Buffer-based shuffle - shuffle_buffer: List["Trajectory"] = [] - batch_items = [] - - for trajectory in trajectories: - shuffle_buffer.append(trajectory) - - # Once buffer is full, start yielding - if len(shuffle_buffer) >= shuffle_buffer_size: - # Pop random item from buffer - idx = random.randint(0, len(shuffle_buffer) - 1) - traj = shuffle_buffer.pop(idx) - - batch_items.append(traj) - - if len(batch_items) == items_per_batch: - yield batch_items - batch_items = [] - - # Flush remaining items in shuffle buffer - random.shuffle(shuffle_buffer) - for traj in shuffle_buffer: - batch_items.append(traj) - - if 
len(batch_items) == items_per_batch: - yield batch_items - batch_items = [] - - # Yield any remaining items as a final batch - if batch_items: - yield batch_items - else: - # No shuffle - simple batching - batch_items = [] - for trajectory in trajectories: - batch_items.append(trajectory) - - if len(batch_items) == items_per_batch: - yield batch_items - batch_items = [] - - # Yield any remaining items as a final batch - if batch_items: - yield batch_items diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index 4ec39528..bfdbea16 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -1,7 +1,10 @@ """Utilities for supervised fine-tuning (SFT).""" import math -from typing import Generator, List, Literal +from typing import TYPE_CHECKING, Generator, List, Literal + +if TYPE_CHECKING: + from art.model import TrainableModel def create_lr_schedule( @@ -68,25 +71,87 @@ def create_lr_schedule( return learning_rates -def chunk_learning_rate( +def iterate_learning_rates( learning_rates: List[float], chunk_size: int, + initial_step: int = 0, ) -> Generator[List[float], None, None]: """ - Chunk a list of learning rates into groups. + Iterate over learning rates in chunks, with support for resuming from a specific step. Args: learning_rates: List of learning rate values chunk_size: Number of learning rates per chunk + initial_step: The step number to start from. Defaults to 0. + Useful for resuming training. 
Yields: List of learning rates (chunk_size items, last chunk may be smaller) Example: lrs = create_lr_schedule(10, 1e-4) - for lr_chunk in chunk_learning_rate(lrs, chunk_size=3): + for lr_chunk in iterate_learning_rates(lrs, chunk_size=3): # lr_chunk has 3 learning rates (or fewer for last chunk) - print(lr_chunk) # [1e-5, 2e-5, 3e-5] + # Yields: [lr0, lr1, lr2], [lr3, lr4, lr5], [lr6, lr7, lr8], [lr9] + + # Resume from step 5 + for lr_chunk in iterate_learning_rates(lrs, chunk_size=3, initial_step=5): + # Starts from learning rate 5: yields [lr5, lr6, lr7], [lr8, lr9] + pass """ - for i in range(0, len(learning_rates), chunk_size): + for i in range(initial_step, len(learning_rates), chunk_size): yield learning_rates[i : i + chunk_size] + + +async def train_sft_from_file( + model: "TrainableModel", + file_path: str, + batch_size: int, + learning_rate: float, + epochs: int +) -> None: + """ + Convenience function to train a model with SFT from a JSONL file. + + Args: + model: TrainableModel to train + file_path: Path to JSONL file containing trajectories + batch_size: Number of trajectories per batch/step + learning_rate: Peak learning rate (uses cosine schedule) + epochs: Number of epochs to train + + Example: + await train_sft_from_file( + model=model, + file_path="data.jsonl", + batch_size=10, + learning_rate=1e-5, + epochs=3, + ) + """ + from art.types import SFTConfig + from art.utils.iterate_dataset import get_file_row_count, get_total_steps, iterate_file + + # Calculate total steps + num_trajectories = get_file_row_count(file_path) + total_steps = get_total_steps(num_trajectories, epochs, batch_size) + + # Set warmup steps: 10% of total steps, capped at 1000 + warmup_steps = min(total_steps // 10, 1000) + + # Create cosine learning rate schedule with warmup + learning_rates = create_lr_schedule( + total_steps=total_steps, + peak_lr=learning_rate, + method="cosine", + warmup_steps=warmup_steps, + ) + + # Create SFT config + config = 
SFTConfig(learning_rate=learning_rates, batch_size=batch_size) + + # Train the model + await model.train_sft( + trajectories=iterate_file(file_path, epochs=epochs), + config=config + ) From 4ff152b9727b6c3f7897f8810a87359a778af117 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 14 Nov 2025 15:10:27 -0800 Subject: [PATCH 05/18] Add shuffle to SFTConfig --- src/art/types.py | 1 + src/art/utils/sft.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/art/types.py b/src/art/types.py index f1e345aa..3d43535a 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -20,6 +20,7 @@ class TrainConfig(pydantic.BaseModel): class SFTConfig(pydantic.BaseModel): learning_rate: float | Iterable[float] batch_size: int + shuffle: bool = False Verbosity = Literal[0, 1, 2] diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index bfdbea16..563de1f6 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -147,8 +147,8 @@ async def train_sft_from_file( warmup_steps=warmup_steps, ) - # Create SFT config - config = SFTConfig(learning_rate=learning_rates, batch_size=batch_size) + # Create SFT config with shuffling enabled + config = SFTConfig(learning_rate=learning_rates, batch_size=batch_size, shuffle=True) # Train the model await model.train_sft( From b6f0380249261481dadce43e4be35c587177efa5 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 14 Nov 2025 15:32:39 -0800 Subject: [PATCH 06/18] change SFT args order --- src/art/utils/sft.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index 563de1f6..9ff04cbd 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -106,9 +106,9 @@ def iterate_learning_rates( async def train_sft_from_file( model: "TrainableModel", file_path: str, - batch_size: int, + epochs: int, learning_rate: float, - epochs: int + batch_size: int = 8, ) -> None: """ Convenience function to train a model with SFT from a JSONL file. 
@@ -116,17 +116,16 @@ async def train_sft_from_file( Args: model: TrainableModel to train file_path: Path to JSONL file containing trajectories - batch_size: Number of trajectories per batch/step - learning_rate: Peak learning rate (uses cosine schedule) epochs: Number of epochs to train + learning_rate: Peak learning rate (uses cosine schedule) + batch_size: Number of trajectories per batch/step. Defaults to 8. Example: await train_sft_from_file( model=model, file_path="data.jsonl", - batch_size=10, - learning_rate=1e-5, epochs=3, + learning_rate=1e-5, ) """ from art.types import SFTConfig From e32db378956643a2c56442349e61077b25b273ae Mon Sep 17 00:00:00 2001 From: Angky William Date: Mon, 17 Nov 2025 16:25:20 -0800 Subject: [PATCH 07/18] Refactor SFT to accept batched trajectories MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move batching and shuffling logic from SFTConfig into iterator functions. train_sft now accepts Iterable[List[Trajectory]] instead of individual trajectories, simplifying the API and making batch management more explicit. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/art/backend.py | 4 +- src/art/local/backend.py | 4 +- src/art/model.py | 8 +- src/art/serverless/backend.py | 4 +- src/art/types.py | 4 +- src/art/utils/iterate_dataset.py | 174 +++++++++++++++++++++-------- src/art/utils/sft.py | 10 +- tests/unit/test_sft.py | 182 +++++++++++++++++++++++++++++++ 8 files changed, 329 insertions(+), 61 deletions(-) create mode 100644 tests/unit/test_sft.py diff --git a/src/art/backend.py b/src/art/backend.py index 473681a0..07b01d12 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, List, Literal import httpx from tqdm import auto as tqdm @@ -129,7 +129,7 @@ async def _train_model( async def _train_sft( self, model: "TrainableModel", - trajectories: Iterable[Trajectory], + trajectories: Iterable[List[Trajectory]], config: SFTConfig, dev_config: dev.SFTConfig, verbose: bool = False, diff --git a/src/art/local/backend.py b/src/art/local/backend.py index ef1e2e3a..13c83fef 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -5,7 +5,7 @@ import subprocess from datetime import datetime from types import TracebackType -from typing import AsyncIterator, Iterable, Literal, cast +from typing import AsyncIterator, Iterable, List, Literal, cast import aiohttp import numpy as np @@ -524,7 +524,7 @@ async def _train_model( async def _train_sft( self, model: TrainableModel, - trajectories: Iterable[Trajectory], + trajectories: Iterable[List[Trajectory]], config: SFTConfig, dev_config: dev.SFTConfig, verbose: bool = False, diff --git a/src/art/model.py b/src/art/model.py index 4593a8b6..afd7f9ac 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload +from typing import 
TYPE_CHECKING, Generic, Iterable, List, Optional, TypeVar, cast, overload import httpx from openai import AsyncOpenAI, DefaultAsyncHttpxClient @@ -389,16 +389,16 @@ async def train( async def train_sft( self, - trajectories: Iterable[Trajectory], + trajectories: Iterable[List[Trajectory]], config: SFTConfig, _config: dev.SFTConfig | None = None, verbose: bool = False, ) -> None: """ - Supervised fine-tune the model with trajectories and per-batch learning rates. + Supervised fine-tune the model with batches of trajectories. Args: - trajectories: An iterable of Trajectory objects. + trajectories: An iterable of trajectory batches (lists of Trajectory objects). config: SFT configuration including learning_rates and batch_size. _config: Additional experimental configuration that is subject to change and not yet part of the public API. Use at your own risk. diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index a07ae789..c6f928b5 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, List, Literal from openai._types import NOT_GIVEN from tqdm import auto as tqdm @@ -162,7 +162,7 @@ async def _train_model( async def _train_sft( self, model: "TrainableModel", - trajectories: Iterable[Trajectory], + trajectories: Iterable[List[Trajectory]], config: SFTConfig, dev_config: dev.SFTConfig, verbose: bool = False, diff --git a/src/art/types.py b/src/art/types.py index 3d43535a..6e2073e4 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -18,9 +18,7 @@ class TrainConfig(pydantic.BaseModel): class SFTConfig(pydantic.BaseModel): - learning_rate: float | Iterable[float] - batch_size: int - shuffle: bool = False + learning_rate: Iterable[float] Verbosity = Literal[0, 1, 2] diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index 
07dddfcf..146845af 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -2,10 +2,13 @@ import math import random from dataclasses import dataclass -from typing import Any, Generator, Generic, Iterable, List, TypeVar +from typing import TYPE_CHECKING, Any, Generator, Generic, Iterable, List, TypeVar from tqdm.auto import tqdm +if TYPE_CHECKING: + from art.trajectories import Trajectory + T = TypeVar("T") @@ -154,39 +157,40 @@ def iterate_trajectories( batch_size: int, chunk_size: int = 1, initial_step: int = 0, -) -> Generator[List["Trajectory"], None, None]: +) -> Generator[List[List["Trajectory"]], None, None]: """ - Iterate over a list of trajectories for multiple epochs, yielding batches. + Iterate over a list of trajectories for multiple epochs, yielding chunks of batches. Shuffles trajectories at the start of each epoch with a fixed seed for reproducibility. Args: trajectories: List of Trajectory objects epochs: Number of times to iterate over the list - batch_size: Number of chunks per batch - chunk_size: Number of trajectories per chunk. Defaults to 1. + batch_size: Number of trajectories per batch (inner list size) + chunk_size: Number of batches per chunk (outer list size). Defaults to 1. initial_step: The global step number to start from. Defaults to 0. Useful for resuming training. Yields: - List of trajectories (batch_size * chunk_size items) + List of lists of trajectories (chunk_size batches, each with batch_size trajectories) Example: # Load trajectories once - trajs = [traj1, traj2, traj3] + trajs = [traj1, traj2, traj3, traj4] - # Iterate 3 epochs, 2 trajectories per batch - for batch in iterate_trajectories(trajs, epochs=3, batch_size=2): - # batch is a list of 2 trajectories - train_sft(batch, ...) 
+ # Iterate 3 epochs, 2 trajectories per batch, 1 batch per chunk + for chunk in iterate_trajectories(trajs, epochs=3, batch_size=2, chunk_size=1): + # chunk is [[traj1, traj2]] or [[traj3, traj4]] + train_sft(chunk, ...) - # With chunk_size - for batch in iterate_trajectories(trajs, epochs=3, batch_size=4, chunk_size=5): - # batch is a list of 20 trajectories (4 chunks * 5 per chunk) + # With chunk_size > 1 + for chunk in iterate_trajectories(trajs, epochs=3, batch_size=5, chunk_size=4): + # chunk is a list of 4 batches, each batch has 5 trajectories + # [[traj0-4], [traj5-9], [traj10-14], [traj15-19]] pass # Resume from step 10 - for batch in iterate_trajectories(trajs, epochs=3, batch_size=2, initial_step=10): - # Skips first 10 batches, starts from step 10 + for chunk in iterate_trajectories(trajs, epochs=3, batch_size=2, chunk_size=1, initial_step=10): + # Skips first 10 chunks, starts from step 10 pass """ @@ -204,61 +208,145 @@ def iterate_trajectories( random.shuffle(indices) for i in range(0, dataset_size, items_per_step): - batch_index = i // items_per_step + step_index = i // items_per_step # Calculate global step number - global_step = epoch * steps_per_epoch + batch_index + global_step = epoch * steps_per_epoch + step_index # Skip if before initial_step if global_step < initial_step: continue - batch_indices = indices[i : i + items_per_step] - batch_items = [trajectories[idx] for idx in batch_indices] - yield batch_items + step_indices = indices[i : i + items_per_step] + + # Structure as list of batches, where each batch has batch_size trajectories + chunk: List[List["Trajectory"]] = [] + for batch_idx in range(0, len(step_indices), batch_size): + batch_indices = step_indices[batch_idx : batch_idx + batch_size] + batch = [trajectories[idx] for idx in batch_indices] + chunk.append(batch) + yield chunk -def iterate_file(file_path: str, epochs: int) -> Generator["Trajectory", None, None]: + +def iterate_file( + file_path: str, + epochs: int, + batch_size: 
int, + shuffle: bool = True, + shuffle_buffer_size: int = 10000, + seed: int | None = 42, +) -> Generator[List["Trajectory"], None, None]: """ - Read JSONL file for each epoch, yielding Trajectory objects. + Read JSONL file for each epoch, yielding batches of Trajectory objects. Each line should contain a dict with: - messages: List of chat messages - tools: Optional list of tools - - reward: Optional reward (defaults to default_reward) + - reward: Optional reward (defaults to 0.0) - split: Optional split name (stored in metadata) - Any other fields will be stored in metadata Args: file_path: Path to JSONL file (one JSON object per line) epochs: Number of times to read through the file - default_reward: Default reward value if not specified in data + batch_size: Number of trajectories per batch (required; no default). + Batches carry over across epochs. + shuffle: Whether to shuffle trajectories. Defaults to True. + shuffle_buffer_size: Size of shuffle buffer. Default: 10000. + Only used if shuffle=True. + seed: Random seed for deterministic shuffling. Default: 42. + Only used if shuffle=True. Yields: - Trajectory objects parsed from the file + Batches of Trajectory objects (lists of size batch_size, last batch may be smaller) Raises: ValueError: If file_path does not end with .jsonl + + Example: + # With shuffle and batching + for batch in iterate_file("data.jsonl", epochs=3, batch_size=8): + # batch is a list of 8 trajectories (or fewer for the last batch) + process(batch) + + # No shuffle + for batch in iterate_file("data.jsonl", epochs=3, batch_size=8, shuffle=False): + process(batch) """ from art.trajectories import Trajectory if not file_path.endswith(".jsonl"): raise ValueError(f"Only JSONL files are supported. 
Got: {file_path}") - for _ in range(epochs): - with open(file_path, "r") as f: - for line in f: - if not line.strip(): - continue + # Batch accumulator that carries over across epochs + batch: List["Trajectory"] = [] - data = json.loads(line) - - # Extract messages and convert to messages_and_choices format - messages = data.get("messages", []) - tools = data.get("tools", None) - - # Create trajectory - yield Trajectory( - messages_and_choices=messages, - tools=tools if tools else None, - reward=0.0 - ) + for epoch in range(epochs): + if shuffle and seed is not None: + random.seed(seed + epoch) + + if shuffle: + # Streaming shuffle with buffer + shuffle_buffer: List["Trajectory"] = [] + + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + messages = data.get("messages", []) + tools = data.get("tools", None) + + traj = Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + ) + + shuffle_buffer.append(traj) + + # Once buffer is full, start yielding + if len(shuffle_buffer) >= shuffle_buffer_size: + idx = random.randint(0, len(shuffle_buffer) - 1) + batch.append(shuffle_buffer.pop(idx)) + + # Yield batch when it reaches batch_size + if len(batch) == batch_size: + yield batch + batch = [] + + # Flush remaining items in shuffle buffer + random.shuffle(shuffle_buffer) + for traj in shuffle_buffer: + batch.append(traj) + + # Yield batch when it reaches batch_size + if len(batch) == batch_size: + yield batch + batch = [] + else: + # No shuffle - sequential reading + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + messages = data.get("messages", []) + tools = data.get("tools", None) + + batch.append(Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + )) + + # Yield batch when it reaches batch_size + if len(batch) == batch_size: + yield batch + batch = [] + + # Yield any 
remaining trajectories in the final batch + if batch: + yield batch diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index 9ff04cbd..a7118406 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -129,11 +129,11 @@ async def train_sft_from_file( ) """ from art.types import SFTConfig - from art.utils.iterate_dataset import get_file_row_count, get_total_steps, iterate_file + from art.utils.iterate_dataset import get_file_row_count, iterate_file - # Calculate total steps + # Calculate total steps - batches carry over across epochs num_trajectories = get_file_row_count(file_path) - total_steps = get_total_steps(num_trajectories, epochs, batch_size) + total_steps = math.ceil((num_trajectories * epochs) / batch_size) # Set warmup steps: 10% of total steps, capped at 1000 warmup_steps = min(total_steps // 10, 1000) @@ -147,10 +147,10 @@ async def train_sft_from_file( ) # Create SFT config with shuffling enabled - config = SFTConfig(learning_rate=learning_rates, batch_size=batch_size, shuffle=True) + config = SFTConfig(learning_rate=learning_rates) # Train the model await model.train_sft( - trajectories=iterate_file(file_path, epochs=epochs), + trajectories=iterate_file(file_path, epochs=epochs, batch_size=batch_size), config=config ) diff --git a/tests/unit/test_sft.py b/tests/unit/test_sft.py new file mode 100644 index 00000000..43e0c66c --- /dev/null +++ b/tests/unit/test_sft.py @@ -0,0 +1,182 @@ +"""Unit tests for SFT utilities.""" + +import json +import math +import tempfile +from pathlib import Path +from typing import Iterable, List + +import pytest + +from art.trajectories import Trajectory +from art.types import SFTConfig +from art.utils.iterate_dataset import iterate_file, iterate_trajectories +from art.utils.sft import create_lr_schedule + + +# Helper to create dummy trajectories +def create_dummy_trajectory(idx: int) -> Trajectory: + """Create a dummy trajectory with a unique identifier.""" + return Trajectory( + messages_and_choices=[ + 
{"role": "user", "content": f"Message {idx}"}, + {"role": "assistant", "content": f"Response {idx}"}, + ], + reward=float(idx), + ) + + +# Helper to create a temporary JSONL file +def create_temp_jsonl(num_trajectories: int) -> Path: + """Create a temporary JSONL file with dummy trajectories.""" + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) + for i in range(num_trajectories): + data = { + "messages": [ + {"role": "user", "content": f"Message {i}"}, + {"role": "assistant", "content": f"Response {i}"}, + ], + } + temp_file.write(json.dumps(data) + "\n") + temp_file.close() + return Path(temp_file.name) + + +# Dummy train_sft for integration testing +def dummy_train_sft( + trajectories: Iterable[List[Trajectory]], + config: SFTConfig, +) -> dict: + """ + Dummy train_sft function that collects batches and learning rates. + + Args: + trajectories: Iterable of trajectory batches + config: SFT configuration with learning rates + + Returns: + dict with: + - num_batches: number of batches processed + - total_trajectories: total number of trajectories seen + - learning_rates_used: list of learning rates used + """ + num_batches = 0 + total_trajectories = 0 + + for batch in trajectories: + num_batches += 1 + total_trajectories += len(batch) + + return { + "num_batches": num_batches, + "total_trajectories": total_trajectories + } + + +# ============================================================================ +# Integration tests +# ============================================================================ + +def test_integration_iterate_trajectories_with_train_sft(): + """Test using iterate_trajectories chunks with train_sft.""" + trajectories = [create_dummy_trajectory(i) for i in range(20)] + + # batch_size=8, chunk_size=2 means each chunk has up to 2 batches of 8 trajectories + # With 20 trajectories per epoch: + # - Items per chunk: 8 * 2 = 16 + # - Chunks per epoch: ceil(20/16) = 2 (one with 16 trajs, one with 4 trajs) + # With 3 
epochs: 2 * 3 = 6 chunks total + + # Create LR schedule for up to 2 batches per chunk + lrs_per_chunk = create_lr_schedule(2, peak_lr=1e-4, method="linear") + + # Manually iterate over chunks and train on each + results = [] + for chunk in iterate_trajectories( + trajectories, + epochs=3, + batch_size=8, # 8 trajectories per batch + chunk_size=2, # 2 batches per chunk + ): + print(f"Chunk: {chunk}") + # chunk is List[List[Trajectory]] which is an Iterable[List[Trajectory]] + result = dummy_train_sft( + trajectories=chunk, + config=SFTConfig(learning_rate=lrs_per_chunk), + ) + results.append(result) + + # Should have 6 chunks total (2 per epoch * 3 epochs) + assert len(results) == 6 + # Pattern repeats for each epoch: full chunk (2 batches), partial chunk (1 batch) + assert results[0]["num_batches"] == 2 # Epoch 1, chunk 1 + assert results[0]["total_trajectories"] == 16 + assert results[1]["num_batches"] == 1 # Epoch 1, chunk 2 (partial) + assert results[1]["total_trajectories"] == 4 + assert results[2]["num_batches"] == 2 # Epoch 2, chunk 1 + assert results[2]["total_trajectories"] == 16 + assert results[3]["num_batches"] == 1 # Epoch 2, chunk 2 (partial) + assert results[3]["total_trajectories"] == 4 + assert results[4]["num_batches"] == 2 # Epoch 3, chunk 1 + assert results[4]["total_trajectories"] == 16 + assert results[5]["num_batches"] == 1 # Epoch 3, chunk 2 (partial) + assert results[5]["total_trajectories"] == 4 + +def test_integration_iterate_file_with_train_sft(): + """Test using iterate_file directly with train_sft.""" + jsonl_file = create_temp_jsonl(100) + + try: + # Create learning rate schedule + total_steps = math.ceil((100 * 2) / 3) # 100 trajectories, 2 epochs, batch_size=3 + lrs = create_lr_schedule(total_steps, peak_lr=1e-4, method="constant") + + config = SFTConfig(learning_rate=lrs) + + # Pass iterate_file directly to train_sft + result = dummy_train_sft( + trajectories=iterate_file( + str(jsonl_file), + epochs=2, + batch_size=3, +
shuffle=True, + ), + config=config, + ) + + # Should process 67 batches: 66 full batches of 3, then a final batch of 2 + assert result["num_batches"] == 67 + assert result["total_trajectories"] == 200 + finally: + jsonl_file.unlink() + +# def test_total_steps_calculation(): +# """Test that total steps calculation matches actual batches.""" +# num_trajectories = 105 +# epochs = 3 +# batch_size = 8 + +# # This is how train_sft_from_file calculates total_steps +# expected_total_steps = math.ceil((num_trajectories * epochs) / batch_size) + +# # Create file and count actual batches +# jsonl_file = create_temp_jsonl(num_trajectories) + +# try: +# batches = list(iterate_file( +# str(jsonl_file), +# epochs=epochs, +# batch_size=batch_size, +# shuffle=False, +# )) + +# actual_batches = len(batches) + +# # Should match +# assert actual_batches == expected_total_steps +# finally: +# jsonl_file.unlink() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 9138b0754594ffa4ecc0051bea71a086558e96a6 Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 18 Nov 2025 16:45:31 -0800 Subject: [PATCH 08/18] Tokenize SFT Batch --- src/art/preprocessing/tokenize_sft.py | 116 ++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/art/preprocessing/tokenize_sft.py diff --git a/src/art/preprocessing/tokenize_sft.py b/src/art/preprocessing/tokenize_sft.py new file mode 100644 index 00000000..a126f2f5 --- /dev/null +++ b/src/art/preprocessing/tokenize_sft.py @@ -0,0 +1,116 @@ +"""Tokenization utilities for Supervised Fine-Tuning (SFT).""" + +from dataclasses import dataclass +from typing import Generator + +import torch +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from ..trajectories import Trajectory + + +@dataclass +class SFTBatch: + """A batch of tokenized trajectories for supervised fine-tuning. + + Attributes: + trajectory_tensors: List of tensor dictionaries, one per trajectory. 
+ Each dict contains 'input_ids', 'attention_mask', and 'labels'. + learning_rate: Learning rate to use for this batch. + num_items_in_batch: Number of trajectories in this batch. + """ + trajectory_tensors: list[dict[str, torch.Tensor]] + learning_rate: float + num_items_in_batch: int + + +def tokenize_sft_batches( + trajectory_batches: list[list[Trajectory]], + learning_rates: list[float], + tokenizer: PreTrainedTokenizerBase, + instruction_part: str, + response_part: str, +) -> Generator[SFTBatch, None, None]: + """ + Tokenize trajectory batches for supervised fine-tuning. + + Args: + trajectory_batches: List of trajectory batches + learning_rates: Learning rate for each batch + tokenizer: Tokenizer to use for encoding + instruction_part: Instruction template part (e.g., "User:") + response_part: Response template part (e.g., "Assistant:") + + Yields: + SFTBatch object containing: + - trajectory_tensors: List of tensors for each trajectory + - learning_rate: Learning rate for this batch + - num_items_in_batch: Number of trajectories in this batch + """ + instruction_ids = tokenizer(instruction_part, add_special_tokens=False).input_ids + response_ids = tokenizer(response_part, add_special_tokens=False).input_ids + instruction_length = len(instruction_ids) + response_length = len(response_ids) + max_length = max(instruction_length, response_length) + + def _train_on_responses_only(input_ids: list[int]) -> list[int]: + labels = [-100] * len(input_ids) + m = len(input_ids) - max_length + first_response = response_ids[0] + first_instruction = instruction_ids[0] + j = 0 + + while j < m: + if input_ids[j] == first_response: + if input_ids[j : j + response_length] == response_ids: + j = j + response_length + start = j + while j < m: + if input_ids[j] == first_instruction and input_ids[j : j + instruction_length] == instruction_ids: + j = j + instruction_length + labels[start : j] = input_ids[start : j] + break + elif j == (m - 1): + j = m + labels[start:] = 
input_ids[start:] + break + j += 1 + j += 1 + + return labels + + for trajectory_batch, lr in zip(trajectory_batches, learning_rates): + trajectory_tensors = [] + + for trajectory in trajectory_batch: + messages = trajectory.messages_and_choices + tools = trajectory.tools + + formatted_text = tokenizer.apply_chat_template( + messages, + tools=tools, + tokenize=False, + add_generation_prompt=False + ) + + processed = tokenizer(formatted_text) + + input_ids = processed['input_ids'] + attention_mask = processed['attention_mask'] + + labels = _train_on_responses_only(input_ids) + + trajectory_tensor = { + 'input_ids': torch.tensor([input_ids], dtype=torch.long), + 'attention_mask': torch.tensor([attention_mask], dtype=torch.long), + 'labels': torch.tensor([labels], dtype=torch.long), + } + + trajectory_tensors.append(trajectory_tensor) + + yield SFTBatch( + trajectory_tensors=trajectory_tensors, + learning_rate=lr, + num_items_in_batch=len(trajectory_tensors), + ) + From 18a789792905ad90374fde06e99978eb70ec3fc6 Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 18 Nov 2025 17:34:57 -0800 Subject: [PATCH 09/18] Add num_trainable_tokens to SFTBatch --- src/art/preprocessing/tokenize_sft.py | 29 ++++++++++++++++++--------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/art/preprocessing/tokenize_sft.py b/src/art/preprocessing/tokenize_sft.py index a126f2f5..8e74d43c 100644 --- a/src/art/preprocessing/tokenize_sft.py +++ b/src/art/preprocessing/tokenize_sft.py @@ -17,11 +17,13 @@ class SFTBatch: trajectory_tensors: List of tensor dictionaries, one per trajectory. Each dict contains 'input_ids', 'attention_mask', and 'labels'. learning_rate: Learning rate to use for this batch. - num_items_in_batch: Number of trajectories in this batch. + num_trajectories: Number of trajectories in this batch. + num_trainable_tokens: Total number of tokens being trained on (labels != -100). 
""" trajectory_tensors: list[dict[str, torch.Tensor]] learning_rate: float - num_items_in_batch: int + num_trajectories: int + num_trainable_tokens: int def tokenize_sft_batches( @@ -45,7 +47,8 @@ def tokenize_sft_batches( SFTBatch object containing: - trajectory_tensors: List of tensors for each trajectory - learning_rate: Learning rate for this batch - - num_items_in_batch: Number of trajectories in this batch + - num_trajectories: Number of trajectories in this batch + - num_trainable_tokens: Total number of trainable tokens """ instruction_ids = tokenizer(instruction_part, add_special_tokens=False).input_ids response_ids = tokenizer(response_part, add_special_tokens=False).input_ids @@ -86,17 +89,16 @@ def _train_on_responses_only(input_ids: list[int]) -> list[int]: messages = trajectory.messages_and_choices tools = trajectory.tools - formatted_text = tokenizer.apply_chat_template( + # Single-step tokenization: apply_chat_template with tokenize=True + input_ids = tokenizer.apply_chat_template( messages, tools=tools, - tokenize=False, + tokenize=True, add_generation_prompt=False ) - processed = tokenizer(formatted_text) - - input_ids = processed['input_ids'] - attention_mask = processed['attention_mask'] + # Create attention mask (all 1s - no padding) + attention_mask = [1] * len(input_ids) labels = _train_on_responses_only(input_ids) @@ -108,9 +110,16 @@ def _train_on_responses_only(input_ids: list[int]) -> list[int]: trajectory_tensors.append(trajectory_tensor) + # Calculate total trainable tokens (labels != -100) + num_trainable_tokens = sum( + (tensor_dict['labels'] != -100).sum().item() + for tensor_dict in trajectory_tensors + ) + yield SFTBatch( trajectory_tensors=trajectory_tensors, learning_rate=lr, - num_items_in_batch=len(trajectory_tensors), + num_trajectories=len(trajectory_tensors), + num_trainable_tokens=num_trainable_tokens, ) From 90bf94bed0b50aaa4b91e39cd53c6e6a070de1c9 Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 18 Nov 2025 
18:10:30 -0800 Subject: [PATCH 10/18] draft train_sft --- src/art/unsloth/train_sft.py | 141 +++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 src/art/unsloth/train_sft.py diff --git a/src/art/unsloth/train_sft.py b/src/art/unsloth/train_sft.py new file mode 100644 index 00000000..6c5b175c --- /dev/null +++ b/src/art/unsloth/train_sft.py @@ -0,0 +1,141 @@ +"""Training utilities for Supervised Fine-Tuning (SFT).""" + +import asyncio +from collections import defaultdict +from typing import TYPE_CHECKING, Callable, Iterator + +import nest_asyncio +import torch +from trl import SFTTrainer + +if TYPE_CHECKING: + from ..preprocessing.tokenize_sft import SFTBatch + +nest_asyncio.apply() + + +async def train_sft( + trainer: SFTTrainer, + input_queue: asyncio.Queue["SFTBatch"], + results_queue: asyncio.Queue[dict[str, float]], +) -> None: + """ + Train an SFT model using batches from a queue. + + Args: + trainer: TRL SFTTrainer instance + input_queue: Queue containing SFTBatch objects + results_queue: Queue for training metrics/results + """ + _get_batch_samples = trainer.get_batch_samples + _log = trainer.log + + trainer.get_batch_samples = get_batch_samples_fn(trainer, input_queue) + trainer.log = get_log_fn(trainer, results_queue) + + # Ensure we have a metrics container in the expected format + try: + is_dict = isinstance(getattr(trainer, "_metrics", None), dict) + is_train_dict = is_dict and isinstance(trainer._metrics.get("train"), dict) + except Exception: + is_train_dict = False + if not is_train_dict: + trainer._metrics = {"train": defaultdict(list)} + + try: + trainer.train() + finally: + trainer.get_batch_samples = _get_batch_samples + trainer.log = _log + + +def get_batch_samples_fn( + trainer: SFTTrainer, + input_queue: asyncio.Queue["SFTBatch"], +) -> Callable[..., tuple[list[dict[str, torch.Tensor]], torch.Tensor]]: + """ + Create a get_batch_samples function that: + 1. Reads SFTBatch from queue + 2. 
Sets learning rate from batch + 3. Sets gradient accumulation steps + 4. Returns batch samples and num_items_in_batch as tensor + """ + + def get_batch_samples( + epoch_iterator: Iterator, + num_batches: int, + device: torch.device | str | None = None, + ) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor]: + """ + Override get_batch_samples to read from queue instead of epoch_iterator. + + Returns: + tuple of (batch_samples, num_items_in_batch as tensor int) + """ + # Read SFTBatch from queue asynchronously + async def get_sft_batch() -> "SFTBatch": + return await input_queue.get() + + # Get the batch from queue + sft_batch: "SFTBatch" = asyncio.run(get_sft_batch()) + + # Set learning rate for this batch + if optimizer := trainer.optimizer: + optimizer = getattr(optimizer, "optimizer", optimizer) + if param_groups := getattr(optimizer, "param_groups"): + for param_group in param_groups: + param_group["lr"] = sft_batch.learning_rate + + # Set gradient accumulation steps to number of trajectories + # We're doing micro-batch size 1, so accumulate across all trajectories + if hasattr(trainer.args, "gradient_accumulation_steps"): + trainer.args.gradient_accumulation_steps = sft_batch.num_trajectories + + # Convert each trajectory to a separate sample for micro-batching + # Trainer will process each sample individually and accumulate gradients + batch_samples = [] + for trajectory_tensor in sft_batch.trajectory_tensors: + # Move each trajectory's tensors to device + sample = { + key: tensor.to(device) + for key, tensor in trajectory_tensor.items() + } + batch_samples.append(sample) + + # Return batch samples and num_items_in_batch as tensor (on device) + num_items_in_batch = torch.tensor( + sft_batch.num_trajectories, + dtype=torch.long, + device=device + ) + + return batch_samples, num_items_in_batch + + return get_batch_samples + + +def get_log_fn( + trainer: SFTTrainer, + results_queue: asyncio.Queue[dict[str, float]], +) -> Callable[..., None]: + """ + Create a 
logging function that sends metrics to the results queue. + Same pattern as GRPO trainer. + """ + def log(logs: dict[str, float], start_time: float | None = None) -> None: + """Log metrics and send to results queue.""" + metrics = { + key: sum(val) / len(val) for key, val in trainer._metrics["train"].items() + } # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + logs.pop("learning_rate", None) + results_queue.put_nowait(logs) + trainer._metrics["train"].clear() + + return log \ No newline at end of file From 12e21420bfe31229a2313c23ebff6d600745ef73 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 14:19:57 -0800 Subject: [PATCH 11/18] Flatten trajectory for train_sft --- src/art/backend.py | 4 +- src/art/local/backend.py | 2 +- src/art/model.py | 8 +- src/art/serverless/backend.py | 4 +- src/art/types.py | 4 +- src/art/unsloth/service_sft.py | 280 +++++++++++++++++++++++ src/art/unsloth/train_sft_manual.py | 337 ++++++++++++++++++++++++++++ 7 files changed, 629 insertions(+), 10 deletions(-) create mode 100644 src/art/unsloth/service_sft.py create mode 100644 src/art/unsloth/train_sft_manual.py diff --git a/src/art/backend.py b/src/art/backend.py index 07b01d12..473681a0 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, AsyncIterator, Iterable, List, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal import httpx from tqdm import auto as tqdm @@ -129,7 +129,7 @@ async def _train_model( async def _train_sft( self, model: "TrainableModel", - trajectories: Iterable[List[Trajectory]], + trajectories: Iterable[Trajectory], config: 
SFTConfig, dev_config: dev.SFTConfig, verbose: bool = False, diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 13c83fef..938328a6 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -524,7 +524,7 @@ async def _train_model( async def _train_sft( self, model: TrainableModel, - trajectories: Iterable[List[Trajectory]], + trajectories: Iterable[Trajectory], config: SFTConfig, dev_config: dev.SFTConfig, verbose: bool = False, diff --git a/src/art/model.py b/src/art/model.py index afd7f9ac..ba88601d 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Generic, Iterable, List, Optional, TypeVar, cast, overload +from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload import httpx from openai import AsyncOpenAI, DefaultAsyncHttpxClient @@ -389,16 +389,16 @@ async def train( async def train_sft( self, - trajectories: Iterable[List[Trajectory]], + trajectories: Iterable[Trajectory], config: SFTConfig, _config: dev.SFTConfig | None = None, verbose: bool = False, ) -> None: """ - Supervised fine-tune the model with batches of trajectories. + Supervised fine-tune the model with an iterable of trajectories. Args: - trajectories: An iterable of trajectory batches (lists of Trajectory objects). + trajectories: An iterable of Trajectory objects. config: SFT configuration including learning_rates and batch_size. _config: Additional experimental configuration that is subject to change and not yet part of the public API. Use at your own risk. 
diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index c6f928b5..a07ae789 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, AsyncIterator, Iterable, List, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal from openai._types import NOT_GIVEN from tqdm import auto as tqdm @@ -162,7 +162,7 @@ async def _train_model( async def _train_sft( self, model: "TrainableModel", - trajectories: Iterable[List[Trajectory]], + trajectories: Iterable[Trajectory], config: SFTConfig, dev_config: dev.SFTConfig, verbose: bool = False, diff --git a/src/art/types.py b/src/art/types.py index 6e2073e4..6dbb9b24 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -18,7 +18,9 @@ class TrainConfig(pydantic.BaseModel): class SFTConfig(pydantic.BaseModel): - learning_rate: Iterable[float] + learning_rate: float = 5e-5 + batch_size: int | Literal["auto"] = "auto" + custom_lr_schedule: list[float] = [] Verbosity = Literal[0, 1, 2] diff --git a/src/art/unsloth/service_sft.py b/src/art/unsloth/service_sft.py new file mode 100644 index 00000000..da4cd5c5 --- /dev/null +++ b/src/art/unsloth/service_sft.py @@ -0,0 +1,280 @@ +"""Service for Supervised Fine-Tuning (SFT).""" + +import asyncio +import functools +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, AsyncIterator + +from datasets import Dataset +from trl import SFTConfig, SFTTrainer + +from .. import dev +from ..local.checkpoints import get_last_checkpoint_dir +from .train_sft import train_sft + +if TYPE_CHECKING: + from ..preprocessing.tokenize_sft import SFTBatch + + +@dataclass +class SFTService: + """ + Service for managing SFT training with queue-based batch processing. 
+ + Attributes: + model_name: Name of the model + base_model: Base model identifier + config: Internal model configuration + output_dir: Directory for saving checkpoints and logs + """ + model_name: str + base_model: str + config: dev.InternalModelConfig + output_dir: str + _train_task: asyncio.Task[None] | None = None + + @functools.cached_property + def input_queue(self) -> asyncio.Queue["SFTBatch"]: + """Queue for receiving SFTBatch objects.""" + return asyncio.Queue() + + @functools.cached_property + def results_queue(self) -> asyncio.Queue[dict[str, float]]: + """Queue for training metrics.""" + return asyncio.Queue() + + @functools.cached_property + def trainer(self) -> SFTTrainer: + """ + Initialize SFTTrainer with PEFT configuration. + """ + import peft + import unsloth + from transformers import PreTrainedTokenizerBase + + # Initialize model and tokenizer + model, tokenizer = unsloth.FastLanguageModel.from_pretrained( + **self.config.get("init_args", {}) + ) + + # Initialize PEFT model + if isinstance(model, peft.peft_model.PeftModelForCausalLM): + peft_model = model + else: + peft_model = unsloth.FastLanguageModel.get_peft_model( + model, **self.config.get("peft_args", {}) + ) + + # Create a large dummy dataset for the trainer + # The actual data comes from the input_queue + dummy_data = {"text": ""} + dataset = Dataset.from_list([dummy_data for _ in range(10_000_000)]) + + # Get trainer configuration + trainer_args = self.config.get("trainer_args", {}) + sft_config = SFTConfig( + output_dir=self.output_dir, + **trainer_args + ) + + # Initialize SFTTrainer + trainer = SFTTrainer( + model=peft_model, + args=sft_config, + train_dataset=dataset, + processing_class=tokenizer, + ) + + return trainer + + async def train( + self, + batches: AsyncIterator["SFTBatch"] | list["SFTBatch"], + ) -> AsyncIterator[dict[str, float]]: + """ + Train the model using batches from tokenize_sft_batches. 
+ + Args: + batches: AsyncIterator or list of SFTBatch objects from tokenize_sft_batches + + Yields: + Training metrics (loss, learning_rate, etc.) + + Example: + ```python + # Create batches from tokenizer + batches = tokenize_sft_batches( + trajectory_batches=trajectory_batches, + learning_rates=learning_rates, + tokenizer=tokenizer, + instruction_part="<|im_start|>user\\n", + response_part="<|im_start|>assistant\\n", + ) + + # Train + async for metrics in service.train(batches): + print(f"Loss: {metrics['loss']:.4f}") + ``` + """ + # Start the training task if not already started + if self._train_task is None: + self._train_task = asyncio.create_task( + train_sft( + trainer=self.trainer, + input_queue=self.input_queue, + results_queue=self.results_queue, + ) + ) + await asyncio.sleep(0.1) # Let trainer initialize + + # Producer: Feed batches to the input queue + async def feed_batches(): + if hasattr(batches, '__aiter__'): + # AsyncIterator + async for batch in batches: + await self.input_queue.put(batch) + else: + # Regular iterable (e.g., list, generator) + for batch in batches: + await self.input_queue.put(batch) + + # Start feeding batches in the background + feed_task = asyncio.create_task(feed_batches()) + + # Consumer: Yield metrics from results queue + try: + while not feed_task.done() or not self.results_queue.empty(): + try: + metrics = await asyncio.wait_for( + self.results_queue.get(), + timeout=0.1 + ) + yield metrics + except asyncio.TimeoutError: + continue + finally: + await feed_task + + def save_checkpoint(self, checkpoint_name: str | None = None) -> str: + """ + Save model checkpoint. + + Args: + checkpoint_name: Optional name for checkpoint. If None, uses step number. 
+ + Returns: + Path to saved checkpoint + """ + if checkpoint_name is None: + from ..utils.output_dirs import get_step_checkpoint_dir + checkpoint_path = get_step_checkpoint_dir( + self.output_dir, + self.trainer.state.global_step + ) + else: + checkpoint_path = os.path.join(self.output_dir, checkpoint_name) + + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + self.trainer.save_model(checkpoint_path) + return checkpoint_path + + def load_checkpoint(self, checkpoint_path: str | None = None) -> str: + """ + Load model checkpoint. + + Args: + checkpoint_path: Path to checkpoint. If None, loads last checkpoint. + + Returns: + Path to loaded checkpoint + """ + if checkpoint_path is None: + checkpoint_path = get_last_checkpoint_dir(self.output_dir) + if checkpoint_path is None: + raise ValueError(f"No checkpoint found in {self.output_dir}") + + # Reload the model with checkpoint + import peft + + self.trainer.model = peft.PeftModel.from_pretrained( + self.trainer.model.base_model, + checkpoint_path + ) + + return checkpoint_path + + +# Example usage function +async def example_sft_training(): + """ + Example of how to use SFTService for training. 
+ """ + from transformers import AutoTokenizer + from ..preprocessing.tokenize_sft import tokenize_sft_batches + from ..trajectories import Trajectory + + # Initialize service + service = SFTService( + model_name="my-sft-model", + base_model="Qwen/Qwen2.5-0.5B-Instruct", + config={ + "init_args": { + "model_name": "Qwen/Qwen2.5-0.5B-Instruct", + "max_seq_length": 2048, + "load_in_4bit": True, + }, + "peft_args": { + "r": 16, + "lora_alpha": 16, + "lora_dropout": 0, + "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], + "bias": "none", + "task_type": "CAUSAL_LM", + }, + "trainer_args": { + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 4, + "num_train_epochs": 1, + "learning_rate": 2e-4, + "logging_steps": 1, + "optim": "adamw_8bit", + }, + }, + output_dir="./output/sft-training", + ) + + # Prepare data + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") + + trajectory_batches = [ + [ + Trajectory( + messages_and_choices=[ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + ], + reward=1.0, + ), + ], + ] + + learning_rates = [2e-4] + + # Tokenize batches + batches = tokenize_sft_batches( + trajectory_batches=trajectory_batches, + learning_rates=learning_rates, + tokenizer=tokenizer, + instruction_part="<|im_start|>user\n", + response_part="<|im_start|>assistant\n", + ) + + # Train + async for metrics in service.train(batches): + print(f"Step {metrics.get('step')}: Loss={metrics.get('loss'):.4f}") + + # Save checkpoint + checkpoint_path = service.save_checkpoint() + print(f"Saved checkpoint to {checkpoint_path}") + diff --git a/src/art/unsloth/train_sft_manual.py b/src/art/unsloth/train_sft_manual.py new file mode 100644 index 00000000..c67caf7f --- /dev/null +++ b/src/art/unsloth/train_sft_manual.py @@ -0,0 +1,337 @@ +"""Manual training loop for Supervised Fine-Tuning (SFT) - simpler alternative to Trainer.""" + +import asyncio +from typing import TYPE_CHECKING + 
+import torch +from peft import PeftModel + +if TYPE_CHECKING: + from ..preprocessing.tokenize_sft import SFTBatch + + +async def train_sft_manual( + model: PeftModel, + optimizer: torch.optim.Optimizer, + input_queue: asyncio.Queue["SFTBatch"], + results_queue: asyncio.Queue[dict[str, float]], + device: torch.device | str = "cuda", +) -> None: + """ + Manual training loop for SFT - simpler alternative to Trainer. + + CausalLM models automatically compute cross-entropy loss when labels are provided, + so we don't need to compute loss manually. + + Args: + model: PEFT model to train + optimizer: Optimizer (e.g., AdamW) + input_queue: Queue containing SFTBatch objects + results_queue: Queue for training metrics + device: Device to train on + + Example: + ```python + import torch + from peft import get_peft_model, LoraConfig + + # Setup model + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B") + peft_config = LoraConfig(r=16, lora_alpha=16, ...) + model = get_peft_model(model, peft_config) + model = model.to("cuda") + + # Setup optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4) + + # Train + await train_sft_manual(model, optimizer, input_queue, results_queue) + ``` + """ + model.train() + global_step = 0 + + while True: + try: + # Get batch from queue + async def get_batch() -> "SFTBatch": + return await input_queue.get() + + sft_batch: "SFTBatch" = asyncio.run(get_batch()) + + # Set learning rate for this batch + for param_group in optimizer.param_groups: + param_group["lr"] = sft_batch.learning_rate + + # Track metrics for this batch + batch_loss = 0.0 + num_trajectories = sft_batch.num_trajectories + + # Process each trajectory with gradient accumulation + for idx, trajectory_tensor in enumerate(sft_batch.trajectory_tensors): + # Move tensors to device + inputs = { + key: tensor.to(device) + for key, tensor in trajectory_tensor.items() + } + + # Forward pass - CausalLM computes loss automatically when labels provided + outputs = 
model(**inputs) + loss = outputs.loss + + # Scale loss by number of trajectories (for gradient accumulation) + loss = loss / num_trajectories + + # Backward pass + loss.backward() + + # Accumulate loss for logging + batch_loss += loss.item() + + # Optimizer step after accumulating gradients from all trajectories + optimizer.step() + optimizer.zero_grad() + + global_step += 1 + + # Prepare metrics + metrics = { + "step": global_step, + "loss": batch_loss, + "learning_rate": sft_batch.learning_rate, + "num_trajectories": sft_batch.num_trajectories, + "num_trainable_tokens": sft_batch.num_trainable_tokens, + } + + # Send metrics to results queue + results_queue.put_nowait(metrics) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Error in training loop: {e}") + break + + +async def train_sft_manual_with_scheduler( + model: PeftModel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler | None, + input_queue: asyncio.Queue["SFTBatch"], + results_queue: asyncio.Queue[dict[str, float]], + device: torch.device | str = "cuda", + max_grad_norm: float | None = 1.0, +) -> None: + """ + Manual training loop with learning rate scheduler and gradient clipping. 
+ + Args: + model: PEFT model to train + optimizer: Optimizer + scheduler: Learning rate scheduler (optional) + input_queue: Queue containing SFTBatch objects + results_queue: Queue for training metrics + device: Device to train on + max_grad_norm: Max gradient norm for clipping (None to disable) + """ + model.train() + global_step = 0 + + while True: + try: + # Get batch from queue + async def get_batch() -> "SFTBatch": + return await input_queue.get() + + sft_batch: "SFTBatch" = asyncio.run(get_batch()) + + # Override learning rate if specified in batch + # (allows per-batch learning rate control) + for param_group in optimizer.param_groups: + param_group["lr"] = sft_batch.learning_rate + + # Track metrics + batch_loss = 0.0 + num_trajectories = sft_batch.num_trajectories + + # Process each trajectory with gradient accumulation + for trajectory_tensor in sft_batch.trajectory_tensors: + # Move to device + inputs = { + key: tensor.to(device) + for key, tensor in trajectory_tensor.items() + } + + # Forward pass - loss computed automatically + outputs = model(**inputs) + loss = outputs.loss / num_trajectories + + # Backward pass + loss.backward() + + batch_loss += loss.item() + + # Gradient clipping + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) + + # Optimizer step + optimizer.step() + optimizer.zero_grad() + + # Scheduler step (if provided) + if scheduler is not None: + scheduler.step() + + global_step += 1 + + # Prepare metrics + metrics = { + "step": global_step, + "loss": batch_loss, + "learning_rate": sft_batch.learning_rate, + "num_trajectories": num_trajectories, + "num_trainable_tokens": sft_batch.num_trainable_tokens, + "grad_norm": torch.nn.utils.clip_grad_norm_( + model.parameters(), float('inf') + ).item() if max_grad_norm else None, + } + + results_queue.put_nowait(metrics) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Error in training loop: {e}") + break + + +# 
Complete example with manual training loop +async def example_manual_training(): + """ + Complete example showing manual training loop usage. + """ + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + from peft import get_peft_model, LoraConfig + from ..preprocessing.tokenize_sft import tokenize_sft_batches + from ..trajectories import Trajectory + + # 1. Setup model + base_model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen2.5-0.5B-Instruct", + torch_dtype=torch.float16, + ) + + # 2. Apply PEFT + peft_config = LoraConfig( + r=16, + lora_alpha=16, + lora_dropout=0.0, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(base_model, peft_config) + model = model.to("cuda") + + # 3. Setup optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4) + + # 4. Setup queues + input_queue = asyncio.Queue() + results_queue = asyncio.Queue() + + # 5. Start training task + train_task = asyncio.create_task( + train_sft_manual( + model=model, + optimizer=optimizer, + input_queue=input_queue, + results_queue=results_queue, + device="cuda", + ) + ) + + # 6. Prepare and tokenize data + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") + + trajectory_batches = [ + [ + Trajectory( + messages_and_choices=[ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + ], + reward=1.0, + ), + Trajectory( + messages_and_choices=[ + {"role": "user", "content": "What is 3+3?"}, + {"role": "assistant", "content": "3+3 equals 6."}, + ], + reward=1.0, + ), + ], + ] + + batches = tokenize_sft_batches( + trajectory_batches=trajectory_batches, + learning_rates=[2e-4], + tokenizer=tokenizer, + instruction_part="<|im_start|>user\n", + response_part="<|im_start|>assistant\n", + ) + + # 7. Feed batches to queue + for batch in batches: + await input_queue.put(batch) + + # 8. 
Monitor training + num_batches = len(trajectory_batches) + for _ in range(num_batches): + metrics = await results_queue.get() + print(f"Step {metrics['step']}: Loss={metrics['loss']:.4f}, " + f"LR={metrics['learning_rate']:.2e}, " + f"Trainable tokens={metrics['num_trainable_tokens']}") + + # 9. Stop training + train_task.cancel() + + # 10. Save model + model.save_pretrained("./output/manual-sft-model") + print("Training complete!") + + +# Comparison: Manual vs Trainer +""" +MANUAL TRAINING LOOP: +Pros: + ✅ Simple and transparent - you see exactly what happens + ✅ Direct control over training loop + ✅ No need to override Trainer methods + ✅ Loss computed automatically by CausalLM + ✅ Easy to add custom logic + ✅ Fewer abstractions + +Cons: + ❌ No built-in features (logging, checkpointing, distributed training) + ❌ Need to implement gradient accumulation manually + ❌ No automatic mixed precision (need to add yourself) + +TRAINER API: +Pros: + ✅ Built-in features (logging, checkpointing, distributed) + ✅ Automatic mixed precision + ✅ Integrated with HuggingFace ecosystem + +Cons: + ❌ More complex - need to override get_batch_samples + ❌ Less transparent - harder to debug + ❌ More abstractions + +RECOMMENDATION: +- Use MANUAL for simple cases, prototyping, and full control +- Use TRAINER for production, distributed training, and HF integration +""" + From 4ea6c5e715fb8b9ebd70394034a444d18d9e12fb Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 14:50:41 -0800 Subject: [PATCH 12/18] Tokenize SFT Batches support flat list and add padding --- src/art/preprocessing/tokenize_sft.py | 79 +++++++++++++++++++++------ 1 file changed, 61 insertions(+), 18 deletions(-) diff --git a/src/art/preprocessing/tokenize_sft.py b/src/art/preprocessing/tokenize_sft.py index 8e74d43c..fcac8bd2 100644 --- a/src/art/preprocessing/tokenize_sft.py +++ b/src/art/preprocessing/tokenize_sft.py @@ -1,5 +1,6 @@ """Tokenization utilities for Supervised Fine-Tuning (SFT).""" +import 
math from dataclasses import dataclass from typing import Generator @@ -27,17 +28,19 @@ class SFTBatch: def tokenize_sft_batches( - trajectory_batches: list[list[Trajectory]], + trajectories: list[Trajectory], + batch_size: int, learning_rates: list[float], tokenizer: PreTrainedTokenizerBase, instruction_part: str, response_part: str, ) -> Generator[SFTBatch, None, None]: """ - Tokenize trajectory batches for supervised fine-tuning. + Tokenize trajectories into batches for supervised fine-tuning. Args: - trajectory_batches: List of trajectory batches + trajectories: Flat list of trajectories + batch_size: Number of trajectories per batch learning_rates: Learning rate for each batch tokenizer: Tokenizer to use for encoding instruction_part: Instruction template part (e.g., "User:") @@ -50,19 +53,31 @@ def tokenize_sft_batches( - num_trajectories: Number of trajectories in this batch - num_trainable_tokens: Total number of trainable tokens """ + # Validate inputs + num_trajectories = len(trajectories) + num_learning_rates = len(learning_rates) + expected_num_batches = math.ceil(num_trajectories / batch_size) + + if num_learning_rates != expected_num_batches: + raise ValueError( + f"Mismatch between trajectories and learning_rates: " + f"{num_trajectories} trajectories with batch_size={batch_size} " + f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates" + ) + instruction_ids = tokenizer(instruction_part, add_special_tokens=False).input_ids response_ids = tokenizer(response_part, add_special_tokens=False).input_ids instruction_length = len(instruction_ids) response_length = len(response_ids) max_length = max(instruction_length, response_length) - + def _train_on_responses_only(input_ids: list[int]) -> list[int]: labels = [-100] * len(input_ids) m = len(input_ids) - max_length first_response = response_ids[0] first_instruction = instruction_ids[0] j = 0 - + while j < m: if input_ids[j] == first_response: if input_ids[j : j + 
response_length] == response_ids: @@ -79,16 +94,21 @@ def _train_on_responses_only(input_ids: list[int]) -> list[int]: break j += 1 j += 1 - + return labels - - for trajectory_batch, lr in zip(trajectory_batches, learning_rates): - trajectory_tensors = [] - + + # Batch trajectories + for batch_idx, lr in enumerate(learning_rates): + start_idx = batch_idx * batch_size + end_idx = start_idx + batch_size + trajectory_batch = trajectories[start_idx:end_idx] + + # First pass: tokenize all trajectories + tokenized_trajectories = [] for trajectory in trajectory_batch: messages = trajectory.messages_and_choices tools = trajectory.tools - + # Single-step tokenization: apply_chat_template with tokenize=True input_ids = tokenizer.apply_chat_template( messages, @@ -96,26 +116,49 @@ def _train_on_responses_only(input_ids: list[int]) -> list[int]: tokenize=True, add_generation_prompt=False ) - - # Create attention mask (all 1s - no padding) + + # Create attention mask (all 1s - no padding yet) attention_mask = [1] * len(input_ids) - + labels = _train_on_responses_only(input_ids) - + + tokenized_trajectories.append({ + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'labels': labels, + }) + + # Find max length in this batch for padding + max_length = max(len(t['input_ids']) for t in tokenized_trajectories) + + # Second pass: pad all trajectories to max_length + trajectory_tensors = [] + for tokenized in tokenized_trajectories: + input_ids = tokenized['input_ids'] + attention_mask = tokenized['attention_mask'] + labels = tokenized['labels'] + + # Pad to max_length + padding_length = max_length - len(input_ids) + if padding_length > 0: + input_ids = input_ids + [tokenizer.pad_token_id] * padding_length + attention_mask = attention_mask + [0] * padding_length + labels = labels + [-100] * padding_length + trajectory_tensor = { 'input_ids': torch.tensor([input_ids], dtype=torch.long), 'attention_mask': torch.tensor([attention_mask], dtype=torch.long), 'labels': 
torch.tensor([labels], dtype=torch.long), } - + trajectory_tensors.append(trajectory_tensor) - + # Calculate total trainable tokens (labels != -100) num_trainable_tokens = sum( (tensor_dict['labels'] != -100).sum().item() for tensor_dict in trajectory_tensors ) - + yield SFTBatch( trajectory_tensors=trajectory_tensors, learning_rate=lr, From f7bb20336ec3b55076c8619c682710a586da0296 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 15:37:39 -0800 Subject: [PATCH 13/18] Fix max_length duplicate name issue --- src/art/preprocessing/tokenize_sft.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/art/preprocessing/tokenize_sft.py b/src/art/preprocessing/tokenize_sft.py index fcac8bd2..f7194219 100644 --- a/src/art/preprocessing/tokenize_sft.py +++ b/src/art/preprocessing/tokenize_sft.py @@ -69,11 +69,11 @@ def tokenize_sft_batches( response_ids = tokenizer(response_part, add_special_tokens=False).input_ids instruction_length = len(instruction_ids) response_length = len(response_ids) - max_length = max(instruction_length, response_length) + max_template_length = max(instruction_length, response_length) def _train_on_responses_only(input_ids: list[int]) -> list[int]: labels = [-100] * len(input_ids) - m = len(input_ids) - max_length + m = len(input_ids) - max_template_length first_response = response_ids[0] first_instruction = instruction_ids[0] j = 0 @@ -129,17 +129,17 @@ def _train_on_responses_only(input_ids: list[int]) -> list[int]: }) # Find max length in this batch for padding - max_length = max(len(t['input_ids']) for t in tokenized_trajectories) + max_seq_length = max(len(t['input_ids']) for t in tokenized_trajectories) - # Second pass: pad all trajectories to max_length + # Second pass: pad all trajectories to max_seq_length trajectory_tensors = [] for tokenized in tokenized_trajectories: input_ids = tokenized['input_ids'] attention_mask = tokenized['attention_mask'] labels = tokenized['labels'] - # Pad to 
max_length - padding_length = max_length - len(input_ids) + # Pad to max_seq_length + padding_length = max_seq_length - len(input_ids) if padding_length > 0: input_ids = input_ids + [tokenizer.pad_token_id] * padding_length attention_mask = attention_mask + [0] * padding_length From d59e52481ae3aee46ca7568e2c72b6d015c4fc2b Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 15:43:59 -0800 Subject: [PATCH 14/18] Remove unused file --- src/art/unsloth/service_sft.py | 280 ----------------------- src/art/unsloth/train_sft_manual.py | 337 ---------------------------- 2 files changed, 617 deletions(-) delete mode 100644 src/art/unsloth/service_sft.py delete mode 100644 src/art/unsloth/train_sft_manual.py diff --git a/src/art/unsloth/service_sft.py b/src/art/unsloth/service_sft.py deleted file mode 100644 index da4cd5c5..00000000 --- a/src/art/unsloth/service_sft.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Service for Supervised Fine-Tuning (SFT).""" - -import asyncio -import functools -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, AsyncIterator - -from datasets import Dataset -from trl import SFTConfig, SFTTrainer - -from .. import dev -from ..local.checkpoints import get_last_checkpoint_dir -from .train_sft import train_sft - -if TYPE_CHECKING: - from ..preprocessing.tokenize_sft import SFTBatch - - -@dataclass -class SFTService: - """ - Service for managing SFT training with queue-based batch processing. 
- - Attributes: - model_name: Name of the model - base_model: Base model identifier - config: Internal model configuration - output_dir: Directory for saving checkpoints and logs - """ - model_name: str - base_model: str - config: dev.InternalModelConfig - output_dir: str - _train_task: asyncio.Task[None] | None = None - - @functools.cached_property - def input_queue(self) -> asyncio.Queue["SFTBatch"]: - """Queue for receiving SFTBatch objects.""" - return asyncio.Queue() - - @functools.cached_property - def results_queue(self) -> asyncio.Queue[dict[str, float]]: - """Queue for training metrics.""" - return asyncio.Queue() - - @functools.cached_property - def trainer(self) -> SFTTrainer: - """ - Initialize SFTTrainer with PEFT configuration. - """ - import peft - import unsloth - from transformers import PreTrainedTokenizerBase - - # Initialize model and tokenizer - model, tokenizer = unsloth.FastLanguageModel.from_pretrained( - **self.config.get("init_args", {}) - ) - - # Initialize PEFT model - if isinstance(model, peft.peft_model.PeftModelForCausalLM): - peft_model = model - else: - peft_model = unsloth.FastLanguageModel.get_peft_model( - model, **self.config.get("peft_args", {}) - ) - - # Create a large dummy dataset for the trainer - # The actual data comes from the input_queue - dummy_data = {"text": ""} - dataset = Dataset.from_list([dummy_data for _ in range(10_000_000)]) - - # Get trainer configuration - trainer_args = self.config.get("trainer_args", {}) - sft_config = SFTConfig( - output_dir=self.output_dir, - **trainer_args - ) - - # Initialize SFTTrainer - trainer = SFTTrainer( - model=peft_model, - args=sft_config, - train_dataset=dataset, - processing_class=tokenizer, - ) - - return trainer - - async def train( - self, - batches: AsyncIterator["SFTBatch"] | list["SFTBatch"], - ) -> AsyncIterator[dict[str, float]]: - """ - Train the model using batches from tokenize_sft_batches. 
- - Args: - batches: AsyncIterator or list of SFTBatch objects from tokenize_sft_batches - - Yields: - Training metrics (loss, learning_rate, etc.) - - Example: - ```python - # Create batches from tokenizer - batches = tokenize_sft_batches( - trajectory_batches=trajectory_batches, - learning_rates=learning_rates, - tokenizer=tokenizer, - instruction_part="<|im_start|>user\\n", - response_part="<|im_start|>assistant\\n", - ) - - # Train - async for metrics in service.train(batches): - print(f"Loss: {metrics['loss']:.4f}") - ``` - """ - # Start the training task if not already started - if self._train_task is None: - self._train_task = asyncio.create_task( - train_sft( - trainer=self.trainer, - input_queue=self.input_queue, - results_queue=self.results_queue, - ) - ) - await asyncio.sleep(0.1) # Let trainer initialize - - # Producer: Feed batches to the input queue - async def feed_batches(): - if hasattr(batches, '__aiter__'): - # AsyncIterator - async for batch in batches: - await self.input_queue.put(batch) - else: - # Regular iterable (e.g., list, generator) - for batch in batches: - await self.input_queue.put(batch) - - # Start feeding batches in the background - feed_task = asyncio.create_task(feed_batches()) - - # Consumer: Yield metrics from results queue - try: - while not feed_task.done() or not self.results_queue.empty(): - try: - metrics = await asyncio.wait_for( - self.results_queue.get(), - timeout=0.1 - ) - yield metrics - except asyncio.TimeoutError: - continue - finally: - await feed_task - - def save_checkpoint(self, checkpoint_name: str | None = None) -> str: - """ - Save model checkpoint. - - Args: - checkpoint_name: Optional name for checkpoint. If None, uses step number. 
- - Returns: - Path to saved checkpoint - """ - if checkpoint_name is None: - from ..utils.output_dirs import get_step_checkpoint_dir - checkpoint_path = get_step_checkpoint_dir( - self.output_dir, - self.trainer.state.global_step - ) - else: - checkpoint_path = os.path.join(self.output_dir, checkpoint_name) - - os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) - self.trainer.save_model(checkpoint_path) - return checkpoint_path - - def load_checkpoint(self, checkpoint_path: str | None = None) -> str: - """ - Load model checkpoint. - - Args: - checkpoint_path: Path to checkpoint. If None, loads last checkpoint. - - Returns: - Path to loaded checkpoint - """ - if checkpoint_path is None: - checkpoint_path = get_last_checkpoint_dir(self.output_dir) - if checkpoint_path is None: - raise ValueError(f"No checkpoint found in {self.output_dir}") - - # Reload the model with checkpoint - import peft - - self.trainer.model = peft.PeftModel.from_pretrained( - self.trainer.model.base_model, - checkpoint_path - ) - - return checkpoint_path - - -# Example usage function -async def example_sft_training(): - """ - Example of how to use SFTService for training. 
- """ - from transformers import AutoTokenizer - from ..preprocessing.tokenize_sft import tokenize_sft_batches - from ..trajectories import Trajectory - - # Initialize service - service = SFTService( - model_name="my-sft-model", - base_model="Qwen/Qwen2.5-0.5B-Instruct", - config={ - "init_args": { - "model_name": "Qwen/Qwen2.5-0.5B-Instruct", - "max_seq_length": 2048, - "load_in_4bit": True, - }, - "peft_args": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0, - "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], - "bias": "none", - "task_type": "CAUSAL_LM", - }, - "trainer_args": { - "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 4, - "num_train_epochs": 1, - "learning_rate": 2e-4, - "logging_steps": 1, - "optim": "adamw_8bit", - }, - }, - output_dir="./output/sft-training", - ) - - # Prepare data - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") - - trajectory_batches = [ - [ - Trajectory( - messages_and_choices=[ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "2+2 equals 4."}, - ], - reward=1.0, - ), - ], - ] - - learning_rates = [2e-4] - - # Tokenize batches - batches = tokenize_sft_batches( - trajectory_batches=trajectory_batches, - learning_rates=learning_rates, - tokenizer=tokenizer, - instruction_part="<|im_start|>user\n", - response_part="<|im_start|>assistant\n", - ) - - # Train - async for metrics in service.train(batches): - print(f"Step {metrics.get('step')}: Loss={metrics.get('loss'):.4f}") - - # Save checkpoint - checkpoint_path = service.save_checkpoint() - print(f"Saved checkpoint to {checkpoint_path}") - diff --git a/src/art/unsloth/train_sft_manual.py b/src/art/unsloth/train_sft_manual.py deleted file mode 100644 index c67caf7f..00000000 --- a/src/art/unsloth/train_sft_manual.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Manual training loop for Supervised Fine-Tuning (SFT) - simpler alternative to Trainer.""" - -import asyncio -from typing import TYPE_CHECKING - 
-import torch -from peft import PeftModel - -if TYPE_CHECKING: - from ..preprocessing.tokenize_sft import SFTBatch - - -async def train_sft_manual( - model: PeftModel, - optimizer: torch.optim.Optimizer, - input_queue: asyncio.Queue["SFTBatch"], - results_queue: asyncio.Queue[dict[str, float]], - device: torch.device | str = "cuda", -) -> None: - """ - Manual training loop for SFT - simpler alternative to Trainer. - - CausalLM models automatically compute cross-entropy loss when labels are provided, - so we don't need to compute loss manually. - - Args: - model: PEFT model to train - optimizer: Optimizer (e.g., AdamW) - input_queue: Queue containing SFTBatch objects - results_queue: Queue for training metrics - device: Device to train on - - Example: - ```python - import torch - from peft import get_peft_model, LoraConfig - - # Setup model - model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B") - peft_config = LoraConfig(r=16, lora_alpha=16, ...) - model = get_peft_model(model, peft_config) - model = model.to("cuda") - - # Setup optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4) - - # Train - await train_sft_manual(model, optimizer, input_queue, results_queue) - ``` - """ - model.train() - global_step = 0 - - while True: - try: - # Get batch from queue - async def get_batch() -> "SFTBatch": - return await input_queue.get() - - sft_batch: "SFTBatch" = asyncio.run(get_batch()) - - # Set learning rate for this batch - for param_group in optimizer.param_groups: - param_group["lr"] = sft_batch.learning_rate - - # Track metrics for this batch - batch_loss = 0.0 - num_trajectories = sft_batch.num_trajectories - - # Process each trajectory with gradient accumulation - for idx, trajectory_tensor in enumerate(sft_batch.trajectory_tensors): - # Move tensors to device - inputs = { - key: tensor.to(device) - for key, tensor in trajectory_tensor.items() - } - - # Forward pass - CausalLM computes loss automatically when labels provided - outputs = 
model(**inputs) - loss = outputs.loss - - # Scale loss by number of trajectories (for gradient accumulation) - loss = loss / num_trajectories - - # Backward pass - loss.backward() - - # Accumulate loss for logging - batch_loss += loss.item() - - # Optimizer step after accumulating gradients from all trajectories - optimizer.step() - optimizer.zero_grad() - - global_step += 1 - - # Prepare metrics - metrics = { - "step": global_step, - "loss": batch_loss, - "learning_rate": sft_batch.learning_rate, - "num_trajectories": sft_batch.num_trajectories, - "num_trainable_tokens": sft_batch.num_trainable_tokens, - } - - # Send metrics to results queue - results_queue.put_nowait(metrics) - - except asyncio.CancelledError: - break - except Exception as e: - print(f"Error in training loop: {e}") - break - - -async def train_sft_manual_with_scheduler( - model: PeftModel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler | None, - input_queue: asyncio.Queue["SFTBatch"], - results_queue: asyncio.Queue[dict[str, float]], - device: torch.device | str = "cuda", - max_grad_norm: float | None = 1.0, -) -> None: - """ - Manual training loop with learning rate scheduler and gradient clipping. 
- - Args: - model: PEFT model to train - optimizer: Optimizer - scheduler: Learning rate scheduler (optional) - input_queue: Queue containing SFTBatch objects - results_queue: Queue for training metrics - device: Device to train on - max_grad_norm: Max gradient norm for clipping (None to disable) - """ - model.train() - global_step = 0 - - while True: - try: - # Get batch from queue - async def get_batch() -> "SFTBatch": - return await input_queue.get() - - sft_batch: "SFTBatch" = asyncio.run(get_batch()) - - # Override learning rate if specified in batch - # (allows per-batch learning rate control) - for param_group in optimizer.param_groups: - param_group["lr"] = sft_batch.learning_rate - - # Track metrics - batch_loss = 0.0 - num_trajectories = sft_batch.num_trajectories - - # Process each trajectory with gradient accumulation - for trajectory_tensor in sft_batch.trajectory_tensors: - # Move to device - inputs = { - key: tensor.to(device) - for key, tensor in trajectory_tensor.items() - } - - # Forward pass - loss computed automatically - outputs = model(**inputs) - loss = outputs.loss / num_trajectories - - # Backward pass - loss.backward() - - batch_loss += loss.item() - - # Gradient clipping - if max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) - - # Optimizer step - optimizer.step() - optimizer.zero_grad() - - # Scheduler step (if provided) - if scheduler is not None: - scheduler.step() - - global_step += 1 - - # Prepare metrics - metrics = { - "step": global_step, - "loss": batch_loss, - "learning_rate": sft_batch.learning_rate, - "num_trajectories": num_trajectories, - "num_trainable_tokens": sft_batch.num_trainable_tokens, - "grad_norm": torch.nn.utils.clip_grad_norm_( - model.parameters(), float('inf') - ).item() if max_grad_norm else None, - } - - results_queue.put_nowait(metrics) - - except asyncio.CancelledError: - break - except Exception as e: - print(f"Error in training loop: {e}") - break - - -# 
Complete example with manual training loop -async def example_manual_training(): - """ - Complete example showing manual training loop usage. - """ - import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - from peft import get_peft_model, LoraConfig - from ..preprocessing.tokenize_sft import tokenize_sft_batches - from ..trajectories import Trajectory - - # 1. Setup model - base_model = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen2.5-0.5B-Instruct", - torch_dtype=torch.float16, - ) - - # 2. Apply PEFT - peft_config = LoraConfig( - r=16, - lora_alpha=16, - lora_dropout=0.0, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], - bias="none", - task_type="CAUSAL_LM", - ) - model = get_peft_model(base_model, peft_config) - model = model.to("cuda") - - # 3. Setup optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4) - - # 4. Setup queues - input_queue = asyncio.Queue() - results_queue = asyncio.Queue() - - # 5. Start training task - train_task = asyncio.create_task( - train_sft_manual( - model=model, - optimizer=optimizer, - input_queue=input_queue, - results_queue=results_queue, - device="cuda", - ) - ) - - # 6. Prepare and tokenize data - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") - - trajectory_batches = [ - [ - Trajectory( - messages_and_choices=[ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "2+2 equals 4."}, - ], - reward=1.0, - ), - Trajectory( - messages_and_choices=[ - {"role": "user", "content": "What is 3+3?"}, - {"role": "assistant", "content": "3+3 equals 6."}, - ], - reward=1.0, - ), - ], - ] - - batches = tokenize_sft_batches( - trajectory_batches=trajectory_batches, - learning_rates=[2e-4], - tokenizer=tokenizer, - instruction_part="<|im_start|>user\n", - response_part="<|im_start|>assistant\n", - ) - - # 7. Feed batches to queue - for batch in batches: - await input_queue.put(batch) - - # 8. 
Monitor training - num_batches = len(trajectory_batches) - for _ in range(num_batches): - metrics = await results_queue.get() - print(f"Step {metrics['step']}: Loss={metrics['loss']:.4f}, " - f"LR={metrics['learning_rate']:.2e}, " - f"Trainable tokens={metrics['num_trainable_tokens']}") - - # 9. Stop training - train_task.cancel() - - # 10. Save model - model.save_pretrained("./output/manual-sft-model") - print("Training complete!") - - -# Comparison: Manual vs Trainer -""" -MANUAL TRAINING LOOP: -Pros: - ✅ Simple and transparent - you see exactly what happens - ✅ Direct control over training loop - ✅ No need to override Trainer methods - ✅ Loss computed automatically by CausalLM - ✅ Easy to add custom logic - ✅ Fewer abstractions - -Cons: - ❌ No built-in features (logging, checkpointing, distributed training) - ❌ Need to implement gradient accumulation manually - ❌ No automatic mixed precision (need to add yourself) - -TRAINER API: -Pros: - ✅ Built-in features (logging, checkpointing, distributed) - ✅ Automatic mixed precision - ✅ Integrated with HuggingFace ecosystem - -Cons: - ❌ More complex - need to override get_batch_samples - ❌ Less transparent - harder to debug - ❌ More abstractions - -RECOMMENDATION: -- Use MANUAL for simple cases, prototyping, and full control -- Use TRAINER for production, distributed training, and HF integration -""" - From 7f6309a346edab32024e48870a508839c1fda3e9 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 15:49:05 -0800 Subject: [PATCH 15/18] remove unused typing --- src/art/local/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 938328a6..ef1e2e3a 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -5,7 +5,7 @@ import subprocess from datetime import datetime from types import TracebackType -from typing import AsyncIterator, Iterable, List, Literal, cast +from typing import AsyncIterator, Iterable, Literal, 
cast import aiohttp import numpy as np From 5ec5575bf4b8aef46b034ca236682449716f393d Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 16:52:07 -0800 Subject: [PATCH 16/18] sft iterator --- src/art/utils/sft.py | 159 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 158 insertions(+), 1 deletion(-) diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index a7118406..4561fcf4 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -1,16 +1,31 @@ """Utilities for supervised fine-tuning (SFT).""" import math +import random +from dataclasses import dataclass from typing import TYPE_CHECKING, Generator, List, Literal if TYPE_CHECKING: from art.model import TrainableModel + from art.trajectories import Trajectory + from art.types import SFTConfig + + +@dataclass +class SFTDatasetChunk: + """Container for SFT dataset chunk with trajectories, config, and step information.""" + + trajectories: List["Trajectory"] + config: "SFTConfig" + step: int + epoch: int + epoch_step: int def create_lr_schedule( total_steps: int, peak_lr: float, - method: Literal["cosine", "linear", "constant"] = "cosine", + method: Literal["cosine", "linear", "constant"] = "linear", warmup_steps: int = 0, min_lr: float = 0.0, ) -> List[float]: @@ -103,6 +118,148 @@ def iterate_learning_rates( yield learning_rates[i : i + chunk_size] +def create_sft_dataset_iterator( + trajectories: List["Trajectory"], + epochs: int = 1, + batch_size: int = 1, + chunk_size: int = 50, + peak_lr: float = 2e-4, + schedule_type: Literal["cosine", "linear", "constant"] = "linear", + warmup_ratio: float = 0.1, + initial_step: int = 0, +) -> Generator[SFTDatasetChunk, None, None]: + """ + Create an iterator that yields SFT dataset chunks with trajectories, config, and step info. + + Combines trajectory batching with learning rate scheduling. Yields SFTDatasetChunk objects + containing flattened trajectories, SFTConfig with learning rates, and step tracking info. 
+ + Args: + trajectories: List of Trajectory objects to train on + epochs: Number of times to iterate over the trajectories. Default: 1 + batch_size: Number of trajectories per batch. Default: 1 + chunk_size: Number of batches per chunk. Default: 50 + peak_lr: Peak learning rate. Default: 2e-4 + schedule_type: Learning rate schedule type ("cosine", "linear", "constant"). Default: "linear" + warmup_ratio: Ratio of total steps to use for warmup (0.0 to 1.0). Default: 0.1 + initial_step: The global chunk step to start from. Default: 0. + Useful for resuming training. + + Yields: + SFTDatasetChunk containing: + - trajectories: Flattened list of trajectories (chunk_size * batch_size trajectories) + - config: SFTConfig with custom_lr_schedule containing learning rates for each batch + - step: Global step number across all epochs + - epoch: Current epoch number (0-indexed) + - epoch_step: Step number within current epoch (0-indexed) + + Example: + trajectories = [traj1, traj2, ..., traj100] + + # Create SFT dataset iterator with linear schedule + for chunk in create_sft_dataset_iterator( + trajectories=trajectories, + epochs=3, + batch_size=4, + chunk_size=10, + peak_lr=1e-4, + schedule_type="linear", + warmup_ratio=0.1, + ): + # chunk.trajectories is a flat list of 40 trajectories (10 batches * 4 per batch) + # chunk.config.custom_lr_schedule is a list of 10 learning rates (one per batch) + # chunk.config.batch_size is 4 + # chunk.step is global step number + # chunk.epoch is current epoch + # chunk.epoch_step is step within epoch + train_sft(chunk.trajectories, chunk.config) + + # Resume from chunk step 5 + for chunk in create_sft_dataset_iterator( + trajectories=trajectories, + epochs=3, + batch_size=4, + chunk_size=10, + initial_step=5, + ): + # Starts from chunk step 5 + pass + """ + from art.types import SFTConfig + + dataset_size = len(trajectories) + if dataset_size == 0: + return + + # Calculate total batch steps (one step per batch) + batches_per_epoch = 
math.ceil(dataset_size / batch_size) + total_batch_steps = batches_per_epoch * epochs + + # Calculate warmup steps + warmup_steps = int(total_batch_steps * warmup_ratio) + + # Create learning rate schedule (one LR per batch) + learning_rates = create_lr_schedule( + total_steps=total_batch_steps, + peak_lr=peak_lr, + method=schedule_type, + warmup_steps=warmup_steps, + min_lr=0.0, + ) + + # Calculate chunk iteration parameters + items_per_chunk = batch_size * chunk_size + chunks_per_epoch = math.ceil(dataset_size / items_per_chunk) + + for epoch in range(epochs): + # Create indices and shuffle deterministically based on epoch + indices = list(range(dataset_size)) + random.seed(epoch) + random.shuffle(indices) + + for chunk_idx in range(chunks_per_epoch): + # Calculate step numbers + epoch_step = chunk_idx + global_step = epoch * chunks_per_epoch + chunk_idx + + # Skip if before initial_step + if global_step < initial_step: + continue + + # Get indices for this chunk + chunk_start = chunk_idx * items_per_chunk + chunk_end = min(chunk_start + items_per_chunk, dataset_size) + step_indices = indices[chunk_start:chunk_end] + + # Flatten trajectories for this chunk + chunk_trajectories: List["Trajectory"] = [ + trajectories[idx] for idx in step_indices + ] + + # Calculate learning rates for each batch in this chunk + chunk_lrs: List[float] = [] + num_batches_in_chunk = math.ceil(len(step_indices) / batch_size) + + for batch_idx in range(num_batches_in_chunk): + # Calculate global batch step + global_batch_step = epoch * batches_per_epoch + (chunk_start // batch_size) + batch_idx + chunk_lrs.append(learning_rates[global_batch_step]) + + # Create SFTConfig with custom learning rate schedule + config = SFTConfig( + batch_size=batch_size, + custom_lr_schedule=chunk_lrs, + ) + + yield SFTDatasetChunk( + trajectories=chunk_trajectories, + config=config, + step=global_step, + epoch=epoch, + epoch_step=epoch_step, + ) + + async def train_sft_from_file( model: "TrainableModel", 
file_path: str, From d6688cf1422bed06a56aeacfd1bc2f19864bf751 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 21 Nov 2025 17:21:20 -0800 Subject: [PATCH 17/18] SFT Iterator --- src/art/utils/iterate_dataset.py | 260 +------------------------------ src/art/utils/sft.py | 195 ++++++++++++++++++----- 2 files changed, 157 insertions(+), 298 deletions(-) diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index 146845af..fda51c41 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -1,14 +1,10 @@ -import json import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Iterable, List, TypeVar +from typing import Generator, Generic, List, TypeVar from tqdm.auto import tqdm -if TYPE_CHECKING: - from art.trajectories import Trajectory - T = TypeVar("T") @@ -96,257 +92,3 @@ def iterate_dataset( if progress_bar: progress_bar.close() - - -def get_file_row_count(file_path: str) -> int: - """ - Count the number of non-empty rows in a JSONL file. - - Args: - file_path: Path to JSONL file - - Returns: - Number of non-empty lines in the file - - Raises: - ValueError: If file_path does not end with .jsonl - - Example: - count = get_file_row_count("data.jsonl") - print(f"Dataset has {count} items") - """ - if not file_path.endswith(".jsonl"): - raise ValueError(f"Only JSONL files are supported. Got: {file_path}") - - count = 0 - with open(file_path, "r") as f: - for line in f: - if line.strip(): - count += 1 - return count - - -def get_total_steps(traj_len: int, epochs: int, batch_size: int) -> int: - """ - Calculate total number of training steps given dataset size, epochs, and batch size. 
- - Args: - traj_len: Number of trajectories in the dataset - epochs: Number of epochs to train - batch_size: Number of trajectories per batch/step - - Returns: - Total number of training steps - - Example: - # 100 trajectories, 3 epochs, batch size of 10 - total_steps = get_total_steps(100, 3, 10) - # Returns 30 (10 steps per epoch * 3 epochs) - - # With partial batch at end - total_steps = get_total_steps(105, 3, 10) - # Returns 33 (11 steps per epoch * 3 epochs) - """ - steps_per_epoch = math.ceil(traj_len / batch_size) - return steps_per_epoch * epochs - - -def iterate_trajectories( - trajectories: List["Trajectory"], - epochs: int, - batch_size: int, - chunk_size: int = 1, - initial_step: int = 0, -) -> Generator[List[List["Trajectory"]], None, None]: - """ - Iterate over a list of trajectories for multiple epochs, yielding chunks of batches. - Shuffles trajectories at the start of each epoch with a fixed seed for reproducibility. - - Args: - trajectories: List of Trajectory objects - epochs: Number of times to iterate over the list - batch_size: Number of trajectories per batch (inner list size) - chunk_size: Number of batches per chunk (outer list size). Defaults to 1. - initial_step: The global step number to start from. Defaults to 0. - Useful for resuming training. - - Yields: - List of lists of trajectories (chunk_size batches, each with batch_size trajectories) - - Example: - # Load trajectories once - trajs = [traj1, traj2, traj3, traj4] - - # Iterate 3 epochs, 2 trajectories per batch, 1 batch per chunk - for chunk in iterate_trajectories(trajs, epochs=3, batch_size=2, chunk_size=1): - # chunk is [[traj1, traj2]] or [[traj3, traj4]] - train_sft(chunk, ...) 
- - # With chunk_size > 1 - for chunk in iterate_trajectories(trajs, epochs=3, batch_size=5, chunk_size=4): - # chunk is a list of 4 batches, each batch has 5 trajectories - # [[traj0-4], [traj5-9], [traj10-14], [traj15-19]] - pass - - # Resume from step 10 - for chunk in iterate_trajectories(trajs, epochs=3, batch_size=2, chunk_size=1, initial_step=10): - # Skips first 10 chunks, starts from step 10 - pass - """ - - dataset_size = len(trajectories) - if dataset_size == 0: - return - - items_per_step = batch_size * chunk_size - steps_per_epoch = math.ceil(dataset_size / items_per_step) - - for epoch in range(epochs): - # Create indices and shuffle deterministically based on epoch - indices = list(range(dataset_size)) - random.seed(epoch) - random.shuffle(indices) - - for i in range(0, dataset_size, items_per_step): - step_index = i // items_per_step - # Calculate global step number - global_step = epoch * steps_per_epoch + step_index - - # Skip if before initial_step - if global_step < initial_step: - continue - - step_indices = indices[i : i + items_per_step] - - # Structure as list of batches, where each batch has batch_size trajectories - chunk: List[List["Trajectory"]] = [] - for batch_idx in range(0, len(step_indices), batch_size): - batch_indices = step_indices[batch_idx : batch_idx + batch_size] - batch = [trajectories[idx] for idx in batch_indices] - chunk.append(batch) - - yield chunk - - -def iterate_file( - file_path: str, - epochs: int, - batch_size: int, - shuffle: bool = True, - shuffle_buffer_size: int = 10000, - seed: int | None = 42, -) -> Generator[List["Trajectory"], None, None]: - """ - Read JSONL file for each epoch, yielding batches of Trajectory objects. 
- - Each line should contain a dict with: - - messages: List of chat messages - - tools: Optional list of tools - - reward: Optional reward (defaults to 0.0) - - split: Optional split name (stored in metadata) - - Any other fields will be stored in metadata - - Args: - file_path: Path to JSONL file (one JSON object per line) - epochs: Number of times to read through the file - batch_size: Number of trajectories per batch. Defaults to 8. - Batches carry over across epochs. - shuffle: Whether to shuffle trajectories. Defaults to True. - shuffle_buffer_size: Size of shuffle buffer. Default: 10000. - Only used if shuffle=True. - seed: Random seed for deterministic shuffling. Default: 42. - Only used if shuffle=True. - - Yields: - Batches of Trajectory objects (lists of size batch_size, last batch may be smaller) - - Raises: - ValueError: If file_path does not end with .jsonl - - Example: - # With shuffle and batching - for batch in iterate_file("data.jsonl", epochs=3, batch_size=8): - # batch is a list of 8 trajectories (or fewer for the last batch) - process(batch) - - # No shuffle - for batch in iterate_file("data.jsonl", epochs=3, batch_size=8, shuffle=False): - process(batch) - """ - from art.trajectories import Trajectory - - if not file_path.endswith(".jsonl"): - raise ValueError(f"Only JSONL files are supported. 
Got: {file_path}") - - # Batch accumulator that carries over across epochs - batch: List["Trajectory"] = [] - - for epoch in range(epochs): - if shuffle and seed is not None: - random.seed(seed + epoch) - - if shuffle: - # Streaming shuffle with buffer - shuffle_buffer: List["Trajectory"] = [] - - with open(file_path, "r") as f: - for line in f: - if not line.strip(): - continue - - data = json.loads(line) - messages = data.get("messages", []) - tools = data.get("tools", None) - - traj = Trajectory( - messages_and_choices=messages, - tools=tools if tools else None, - reward=0.0 - ) - - shuffle_buffer.append(traj) - - # Once buffer is full, start yielding - if len(shuffle_buffer) >= shuffle_buffer_size: - idx = random.randint(0, len(shuffle_buffer) - 1) - batch.append(shuffle_buffer.pop(idx)) - - # Yield batch when it reaches batch_size - if len(batch) == batch_size: - yield batch - batch = [] - - # Flush remaining items in shuffle buffer - random.shuffle(shuffle_buffer) - for traj in shuffle_buffer: - batch.append(traj) - - # Yield batch when it reaches batch_size - if len(batch) == batch_size: - yield batch - batch = [] - else: - # No shuffle - sequential reading - with open(file_path, "r") as f: - for line in f: - if not line.strip(): - continue - - data = json.loads(line) - messages = data.get("messages", []) - tools = data.get("tools", None) - - batch.append(Trajectory( - messages_and_choices=messages, - tools=tools if tools else None, - reward=0.0 - )) - - # Yield batch when it reaches batch_size - if len(batch) == batch_size: - yield batch - batch = [] - - # Yield any remaining trajectories in the final batch - if batch: - yield batch diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index 4561fcf4..74e5e971 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -1,10 +1,13 @@ """Utilities for supervised fine-tuning (SFT).""" +import json import math import random from dataclasses import dataclass from typing import TYPE_CHECKING, Generator, 
List, Literal +from tqdm.auto import tqdm + if TYPE_CHECKING: from art.model import TrainableModel from art.trajectories import Trajectory @@ -21,6 +24,33 @@ class SFTDatasetChunk: epoch: int epoch_step: int +def get_file_row_count(file_path: str) -> int: + """ + Count the number of non-empty rows in a JSONL file. + + Args: + file_path: Path to JSONL file + + Returns: + Number of non-empty lines in the file + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + count = get_file_row_count("data.jsonl") + print(f"Dataset has {count} items") + """ + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + count = 0 + with open(file_path, "r") as f: + for line in f: + if line.strip(): + count += 1 + return count + def create_lr_schedule( total_steps: int, @@ -86,38 +116,6 @@ def create_lr_schedule( return learning_rates -def iterate_learning_rates( - learning_rates: List[float], - chunk_size: int, - initial_step: int = 0, -) -> Generator[List[float], None, None]: - """ - Iterate over learning rates in chunks, with support for resuming from a specific step. - - Args: - learning_rates: List of learning rate values - chunk_size: Number of learning rates per chunk - initial_step: The step number to start from. Defaults to 0. - Useful for resuming training. 
- - Yields: - List of learning rates (chunk_size items, last chunk may be smaller) - - Example: - lrs = create_lr_schedule(10, 1e-4) - for lr_chunk in iterate_learning_rates(lrs, chunk_size=3): - # lr_chunk has 3 learning rates (or fewer for last chunk) - # Yields: [lr0, lr1, lr2], [lr3, lr4, lr5], [lr6, lr7, lr8], [lr9] - - # Resume from step 5 - for lr_chunk in iterate_learning_rates(lrs, chunk_size=3, initial_step=5): - # Starts from learning rate 5: yields [lr5, lr6, lr7], [lr8, lr9] - pass - """ - for i in range(initial_step, len(learning_rates), chunk_size): - yield learning_rates[i : i + chunk_size] - - def create_sft_dataset_iterator( trajectories: List["Trajectory"], epochs: int = 1, @@ -127,6 +125,7 @@ def create_sft_dataset_iterator( schedule_type: Literal["cosine", "linear", "constant"] = "linear", warmup_ratio: float = 0.1, initial_step: int = 0, + use_tqdm: bool = True, ) -> Generator[SFTDatasetChunk, None, None]: """ Create an iterator that yields SFT dataset chunks with trajectories, config, and step info. @@ -144,6 +143,7 @@ def create_sft_dataset_iterator( warmup_ratio: Ratio of total steps to use for warmup (0.0 to 1.0). Default: 0.1 initial_step: The global chunk step to start from. Default: 0. Useful for resuming training. + use_tqdm: Whether to display a progress bar. 
Default: True Yields: SFTDatasetChunk containing: @@ -199,7 +199,7 @@ def create_sft_dataset_iterator( warmup_steps = int(total_batch_steps * warmup_ratio) # Create learning rate schedule (one LR per batch) - learning_rates = create_lr_schedule( + custom_lr_schedule = create_lr_schedule( total_steps=total_batch_steps, peak_lr=peak_lr, method=schedule_type, @@ -210,6 +210,16 @@ def create_sft_dataset_iterator( # Calculate chunk iteration parameters items_per_chunk = batch_size * chunk_size chunks_per_epoch = math.ceil(dataset_size / items_per_chunk) + total_steps = chunks_per_epoch * epochs + + progress_bar = None + if use_tqdm: + progress_bar = tqdm( + initial=initial_step, + total=total_steps, + desc="Training SFT", + unit="chunk", + ) for epoch in range(epochs): # Create indices and shuffle deterministically based on epoch @@ -243,7 +253,7 @@ def create_sft_dataset_iterator( for batch_idx in range(num_batches_in_chunk): # Calculate global batch step global_batch_step = epoch * batches_per_epoch + (chunk_start // batch_size) + batch_idx - chunk_lrs.append(learning_rates[global_batch_step]) + chunk_lrs.append(custom_lr_schedule[global_batch_step]) # Create SFTConfig with custom learning rate schedule config = SFTConfig( @@ -259,6 +269,114 @@ def create_sft_dataset_iterator( epoch_step=epoch_step, ) + # Update progress bar after yielding + if progress_bar: + progress_bar.update(1) + + if progress_bar: + progress_bar.close() + +def iterate_file( + file_path: str, + epochs: int, + shuffle: bool = True, + shuffle_buffer_size: int = 10000, + seed: int | None = 42, +) -> Generator["Trajectory", None, None]: + """ + Read JSONL file for each epoch, yielding individual Trajectory objects. + + Completes reading the entire file for one epoch before starting the next epoch. + This ensures all trajectories from epoch N are yielded before any from epoch N+1. 
+ + Each line should contain a dict with: + - messages: List of chat messages + - tools: Optional list of tools + - reward: Currently ignored; every yielded Trajectory is created with reward=0.0 + - split: Optional split name (currently ignored; not copied into the Trajectory) + - Any other fields are currently ignored + + Args: + file_path: Path to JSONL file (one JSON object per line) + epochs: Number of times to read through the file + shuffle: Whether to shuffle trajectories. Defaults to True. + shuffle_buffer_size: Size of shuffle buffer for streaming shuffle. Default: 10000. + Only used if shuffle=True. + seed: Random seed for deterministic shuffling. Default: 42. + Only used if shuffle=True. + + Yields: + Individual Trajectory objects + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + # With shuffle + for trajectory in iterate_file("data.jsonl", epochs=3, shuffle=True): + # trajectory is a single Trajectory object + process(trajectory) + + # No shuffle + for trajectory in iterate_file("data.jsonl", epochs=3, shuffle=False): + process(trajectory) + """ + from art.trajectories import Trajectory + + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. 
Got: {file_path}") + + for epoch in range(epochs): + if shuffle and seed is not None: + random.seed(seed + epoch) + + if shuffle: + # Streaming shuffle with buffer + shuffle_buffer: List["Trajectory"] = [] + + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + messages = data.get("messages", []) + tools = data.get("tools", None) + + traj = Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + ) + + shuffle_buffer.append(traj) + + # Once buffer is full, start yielding randomly + if len(shuffle_buffer) >= shuffle_buffer_size: + idx = random.randint(0, len(shuffle_buffer) - 1) + yield shuffle_buffer.pop(idx) + + # Flush remaining items in shuffle buffer at end of epoch + random.shuffle(shuffle_buffer) + for traj in shuffle_buffer: + yield traj + else: + # No shuffle - sequential reading + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + messages = data.get("messages", []) + tools = data.get("tools", None) + + yield Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + ) + async def train_sft_from_file( model: "TrainableModel", @@ -286,7 +404,6 @@ async def train_sft_from_file( ) """ from art.types import SFTConfig - from art.utils.iterate_dataset import get_file_row_count, iterate_file # Calculate total steps - batches carry over across epochs num_trajectories = get_file_row_count(file_path) @@ -296,18 +413,18 @@ async def train_sft_from_file( warmup_steps = min(total_steps // 10, 1000) # Create cosine learning rate schedule with warmup - learning_rates = create_lr_schedule( + custom_lr_schedule = create_lr_schedule( total_steps=total_steps, peak_lr=learning_rate, - method="cosine", + method="linear", warmup_steps=warmup_steps, ) # Create SFT config with shuffling enabled - config = SFTConfig(learning_rate=learning_rates) + config = 
SFTConfig(custom_lr_schedule=custom_lr_schedule, batch_size=batch_size) # Train the model await model.train_sft( - trajectories=iterate_file(file_path, epochs=epochs, batch_size=batch_size), + trajectories=iterate_file(file_path, epochs=epochs), config=config ) From 6c63af56b1b6fb302ce345cd1458eddcbcc79681 Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 25 Nov 2025 11:25:35 -0800 Subject: [PATCH 18/18] Use Unsloth for train on response --- src/art/preprocessing/tokenize_sft.py | 107 +++++++++++++++++++------- 1 file changed, 81 insertions(+), 26 deletions(-) diff --git a/src/art/preprocessing/tokenize_sft.py b/src/art/preprocessing/tokenize_sft.py index f7194219..87faaf78 100644 --- a/src/art/preprocessing/tokenize_sft.py +++ b/src/art/preprocessing/tokenize_sft.py @@ -9,6 +9,12 @@ from ..trajectories import Trajectory +# Import Unsloth Zoo utilities for robust token matching +# Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py +# These functions handle edge cases with tokenization (newlines, spaces, etc.) 
+import unsloth # Must import first to set UNSLOTH_IS_PRESENT env var +from unsloth_zoo.dataset_utils import _find_common_token_ids + @dataclass class SFTBatch: @@ -65,36 +71,85 @@ def tokenize_sft_batches( f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates" ) - instruction_ids = tokenizer(instruction_part, add_special_tokens=False).input_ids - response_ids = tokenizer(response_part, add_special_tokens=False).input_ids - instruction_length = len(instruction_ids) - response_length = len(response_ids) - max_template_length = max(instruction_length, response_length) + # Get most common tokens using Unsloth approach + Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer, force_match=False) + A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer, force_match=False) + + # Store temporary stuff + A_first = A_must[0] + len_A_must = len(A_must) + A_left_reversed = A_left[::-1] + A_right_forward = A_right + + Q_first = Q_must[0] + len_Q_must = len(Q_must) + Q_left_reversed = Q_left[::-1] + Q_right_forward = Q_right def _train_on_responses_only(input_ids: list[int]) -> list[int]: - labels = [-100] * len(input_ids) - m = len(input_ids) - max_template_length - first_response = response_ids[0] - first_instruction = instruction_ids[0] + """Unsloth-based implementation for marking trainable tokens.""" + n = len(input_ids) + labels = [-100] * n + n_minus_1 = n - 1 j = 0 - - while j < m: - if input_ids[j] == first_response: - if input_ids[j : j + response_length] == response_ids: - j = j + response_length - start = j - while j < m: - if input_ids[j] == first_instruction and input_ids[j : j + instruction_length] == instruction_ids: - j = j + instruction_length - labels[start : j] = input_ids[start : j] - break - elif j == (m - 1): - j = m - labels[start:] = input_ids[start:] - break - j += 1 + + while j < n: + # Find the assistant (response) marker + if (input_ids[j] == A_first) and \ + (input_ids[j : (k := j + len_A_must)] == A_must): + 
+ # Now backtrack to get previous optional tokens + for optional_left in A_left_reversed: + if j < 1: break + if optional_left == input_ids[j-1]: j -= 1 + else: break + + # And forwards look as well + for optional_right in A_right_forward: + if k >= n_minus_1: break + if optional_right == input_ids[k+1]: k += 1 + else: break + + assistant_k = k + j = assistant_k + + # Given the assistant marker, now find next user + while j < n: + # Find the user (instruction) marker + # Also accept last final item if assistant is the last turn + if (j == n_minus_1) or \ + ((input_ids[j] == Q_first) and \ + (input_ids[j : (k := j + len_Q_must)] == Q_must)): + + # Now backtrack to get previous optional tokens + for optional_left in Q_left_reversed: + if j < 1: break + if optional_left == input_ids[j-1]: j -= 1 + else: break + + # And forwards look as well + for optional_right in Q_right_forward: + if k >= n_minus_1: break + if optional_right == input_ids[k+1]: k += 1 + else: break + + user_j = j + + # Account for last item + if user_j != n_minus_1: + j = k + else: + user_j = n + k = n + + # Now copy input_ids to labels + labels[assistant_k : user_j] = input_ids[assistant_k : user_j] + break + + j += 1 + j += 1 - + return labels # Batch trajectories