From 5f8b59c0f7797993df4875b08cb1284bf82fe93d Mon Sep 17 00:00:00 2001
From: Xiuyu Li
Date: Tue, 11 Nov 2025 08:18:41 +0000
Subject: [PATCH] Add configurable temperature parameter for RL rollout sampling

---
 tinker_cookbook/completers.py            | 7 ++++++-
 tinker_cookbook/recipes/math_rl/train.py | 2 ++
 tinker_cookbook/rl/train.py              | 7 ++++++-
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/tinker_cookbook/completers.py b/tinker_cookbook/completers.py
index df305ec..6ed9e21 100644
--- a/tinker_cookbook/completers.py
+++ b/tinker_cookbook/completers.py
@@ -54,6 +54,7 @@ class TinkerTokenCompleter(TokenCompleter):
 
     sampling_client: tinker.SamplingClient
     max_tokens: int
+    temperature: float = 1.0
 
     async def __call__(
         self, model_input: tinker.ModelInput, stop: StopCondition
@@ -63,7 +64,11 @@ async def __call__(
         sample_result = await self.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
-            sampling_params=tinker.SamplingParams(stop=stop, max_tokens=self.max_tokens),
+            sampling_params=tinker.SamplingParams(
+                stop=stop,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            ),
         )
 
         # Extract tokens and logprobs from the first (and only) sample
diff --git a/tinker_cookbook/recipes/math_rl/train.py b/tinker_cookbook/recipes/math_rl/train.py
index 2ec1ed1..fa5afe1 100644
--- a/tinker_cookbook/recipes/math_rl/train.py
+++ b/tinker_cookbook/recipes/math_rl/train.py
@@ -34,6 +34,7 @@ class CLIConfig:
     groups_per_batch: int = 100
     learning_rate: float = 1e-5
     max_tokens: int = 5
+    temperature: float = 1.0
     kl_penalty_coef: float = 0.0
 
     # Number of optimizer steps per training iteration.
@@ -124,6 +125,7 @@ async def cli_main(cli_config: CLIConfig):
         model_name=cli_config.model_name,
         lora_rank=cli_config.lora_rank,
         max_tokens=cli_config.max_tokens,
+        temperature=cli_config.temperature,
         wandb_project=cli_config.wandb_project,
         wandb_name=wandb_name,
         log_path=log_path,
diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index 9d81b5e..0f70fa4 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -229,6 +229,7 @@ class Config:
     dataset_builder: RLDatasetBuilder  # also determines batch size
     model_name: str
     max_tokens: int
+    temperature: float = 1.0
     compute_post_kl: bool = False
     evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)
     lora_rank: int = 32
@@ -366,6 +367,7 @@ async def trajectory_group_worker_task(
             sampling_client,
             builder,
             max_tokens=cfg.max_tokens,
+            temperature=cfg.temperature,
             do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
             enable_logging=enable_logging,
         )
@@ -501,6 +503,7 @@ async def trajectory_group_worker_loop():
                     sampling_client,
                     env_group_builder,
                     max_tokens=cfg.max_tokens,
+                    temperature=cfg.temperature,
                     do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
                 )
                 if trajectory_group is None:
@@ -659,10 +662,11 @@ async def do_group_rollout_and_filter_constant_reward(
     sampling_client: tinker.SamplingClient,
     env_group_builder: EnvGroupBuilder,
     max_tokens: int,
+    temperature: float,
     do_remove_constant_reward_groups: bool,
     enable_logging: bool = True,
 ) -> TrajectoryGroup | None:
-    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
+    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens, temperature=temperature)
     with logtree.optional_enable_logging(enable_logging):
         trajectory_group = await do_group_rollout(env_group_builder, policy)
 
@@ -988,6 +992,7 @@ async def do_sync_training(
                     sampling_client,
                     builder,
                     max_tokens=cfg.max_tokens,
+                    temperature=cfg.temperature,
                     do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
                     enable_logging=i < cfg.num_groups_to_log,
                 ),
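
Usage sketch (editorial note, not part of the patch): with this change, TinkerTokenCompleter stores a temperature field and forwards it to tinker.SamplingParams on every sample call, so leaving it unset keeps the previous behavior (temperature=1.0). A minimal illustration follows, assuming a tinker.SamplingClient has already been created elsewhere via the tinker SDK; the make_rollout_policy helper and the 512 / 0.7 values are illustrative, not from the patch.

    import tinker
    from tinker_cookbook.completers import TinkerTokenCompleter

    def make_rollout_policy(sampling_client: tinker.SamplingClient) -> TinkerTokenCompleter:
        # temperature is stored on the completer and passed to
        # tinker.SamplingParams inside __call__; omitting it defaults to 1.0.
        return TinkerTokenCompleter(
            sampling_client,
            max_tokens=512,    # illustrative value
            temperature=0.7,   # illustrative value
        )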