7 changes: 6 additions & 1 deletion — tinker_cookbook/completers.py

```diff
@@ -54,6 +54,7 @@ class TinkerTokenCompleter(TokenCompleter):

     sampling_client: tinker.SamplingClient
     max_tokens: int
+    temperature: float = 1.0

     async def __call__(
         self, model_input: tinker.ModelInput, stop: StopCondition
@@ -63,7 +64,11 @@ async def __call__(
         sample_result = await self.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
-            sampling_params=tinker.SamplingParams(stop=stop, max_tokens=self.max_tokens),
+            sampling_params=tinker.SamplingParams(
+                stop=stop,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            ),
         )

         # Extract tokens and logprobs from the first (and only) sample
```
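With this change, callers can choose the sampling temperature instead of always getting the default of 1.0. A minimal usage sketch, not part of the PR: the `ServiceClient` setup and the base-model name are assumptions for illustration, and only the `TinkerTokenCompleter` signature is taken from this diff.

```python
# Sketch only: assumes a tinker ServiceClient is the usual way to obtain
# a SamplingClient; the base model below is a placeholder.
import tinker
from tinker_cookbook.completers import TinkerTokenCompleter

service_client = tinker.ServiceClient()
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")

# temperature=0.0 requests (near-)greedy decoding; omitting it keeps the
# previous behavior, since the new field defaults to 1.0.
policy = TinkerTokenCompleter(sampling_client, max_tokens=256, temperature=0.0)
```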
2 changes: 2 additions & 0 deletions — tinker_cookbook/recipes/math_rl/train.py

```diff
@@ -34,6 +34,7 @@ class CLIConfig:
     groups_per_batch: int = 100
     learning_rate: float = 1e-5
     max_tokens: int = 5
+    temperature: float = 1.0
     kl_penalty_coef: float = 0.0

     # Number of optimizer steps per training iteration.
@@ -124,6 +125,7 @@ async def cli_main(cli_config: CLIConfig):
         model_name=cli_config.model_name,
         lora_rank=cli_config.lora_rank,
         max_tokens=cli_config.max_tokens,
+        temperature=cli_config.temperature,
         wandb_project=cli_config.wandb_project,
         wandb_name=wandb_name,
         log_path=log_path,
```
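Since `CLIConfig` is a chz config (note the `chz.field` usage in this PR), the new field should be settable from the command line like the existing ones. A hypothetical programmatic sketch, assuming the remaining `CLIConfig` fields have workable defaults:

```python
# Hypothetical equivalent of a CLI run along the lines of:
#   python -m tinker_cookbook.recipes.math_rl.train temperature=0.7
# (key=value parsing is chz's convention; the exact flag form is an assumption).
import asyncio

from tinker_cookbook.recipes.math_rl.train import CLIConfig, cli_main

config = CLIConfig(temperature=0.7)  # all other fields keep their defaults
asyncio.run(cli_main(config))
```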
7 changes: 6 additions & 1 deletion — tinker_cookbook/rl/train.py

```diff
@@ -229,6 +229,7 @@ class Config:
     dataset_builder: RLDatasetBuilder  # also determines batch size
     model_name: str
     max_tokens: int
+    temperature: float = 1.0
     compute_post_kl: bool = False
     evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)
     lora_rank: int = 32
@@ -366,6 +367,7 @@ async def trajectory_group_worker_task(
             sampling_client,
             builder,
             max_tokens=cfg.max_tokens,
+            temperature=cfg.temperature,
             do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
             enable_logging=enable_logging,
         )
@@ -501,6 +503,7 @@ async def trajectory_group_worker_loop():
             sampling_client,
             env_group_builder,
             max_tokens=cfg.max_tokens,
+            temperature=cfg.temperature,
             do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
         )
         if trajectory_group is None:
@@ -659,10 +662,11 @@ async def do_group_rollout_and_filter_constant_reward(
     sampling_client: tinker.SamplingClient,
     env_group_builder: EnvGroupBuilder,
     max_tokens: int,
+    temperature: float,
     do_remove_constant_reward_groups: bool,
     enable_logging: bool = True,
 ) -> TrajectoryGroup | None:
-    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
+    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens, temperature=temperature)

     with logtree.optional_enable_logging(enable_logging):
         trajectory_group = await do_group_rollout(env_group_builder, policy)
@@ -988,6 +992,7 @@ async def do_sync_training(
                 sampling_client,
                 builder,
                 max_tokens=cfg.max_tokens,
+                temperature=cfg.temperature,
                 do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
                 enable_logging=i < cfg.num_groups_to_log,
             ),
```
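Taken together, the value now flows `Config.temperature` → `do_group_rollout_and_filter_constant_reward(...)` → `TinkerTokenCompleter`, so every rollout worker samples at the configured temperature. A condensed sketch of that plumbing; `make_policy` is a hypothetical helper for illustration, not a function added by this PR:

```python
# Hypothetical helper mirroring the updated rollout path: the policy is now
# built with the configured temperature rather than an implicit 1.0.
from tinker_cookbook.completers import TinkerTokenCompleter


def make_policy(sampling_client, cfg):
    # cfg is an rl.train.Config; both fields are threaded through unchanged.
    return TinkerTokenCompleter(
        sampling_client,
        max_tokens=cfg.max_tokens,
        temperature=cfg.temperature,
    )
```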