7 changes: 6 additions & 1 deletion tinker_cookbook/completers.py
@@ -54,6 +54,7 @@ class TinkerTokenCompleter(TokenCompleter):
 
     sampling_client: tinker.SamplingClient
     max_tokens: int
+    temperature: float = 1.0
 
     async def __call__(
         self, model_input: tinker.ModelInput, stop: StopCondition
@@ -63,7 +64,11 @@ async def __call__(
         sample_result = await self.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
-            sampling_params=tinker.SamplingParams(stop=stop, max_tokens=self.max_tokens),
+            sampling_params=tinker.SamplingParams(
+                stop=stop,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            ),
         )
 
         # Extract tokens and logprobs from the first (and only) sample
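
This makes the sampling temperature a per-completer field that is forwarded into tinker.SamplingParams; the 1.0 default reproduces the previously hard-coded behavior. A minimal usage sketch, where the service-client setup and model name are illustrative assumptions rather than part of this diff:

import tinker
from tinker_cookbook.completers import TinkerTokenCompleter

# Hypothetical client setup; the base model is a placeholder.
service_client = tinker.ServiceClient()
sampling_client = service_client.create_sampling_client(
    base_model="meta-llama/Llama-3.1-8B"
)

# Lower temperatures sample more greedily; 1.0 matches the old behavior.
policy = TinkerTokenCompleter(sampling_client, max_tokens=256, temperature=0.7)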
2 changes: 2 additions & 0 deletions tinker_cookbook/recipes/math_rl/train.py
@@ -34,6 +34,7 @@ class CLIConfig:
     groups_per_batch: int = 100
     learning_rate: float = 1e-5
     max_tokens: int = 5
+    temperature: float = 1.0
     kl_penalty_coef: float = 0.0
 
     # Number of optimizer steps per training iteration.
@@ -124,6 +125,7 @@ async def cli_main(cli_config: CLIConfig):
         model_name=cli_config.model_name,
         lora_rank=cli_config.lora_rank,
         max_tokens=cli_config.max_tokens,
+        temperature=cli_config.temperature,
         wandb_project=cli_config.wandb_project,
         wandb_name=wandb_name,
         log_path=log_path,
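
Because cli_main (shown above) takes the chz CLIConfig directly, the new field can also be exercised programmatically. A sketch, assuming the fields not shown in this diff have usable defaults:

import asyncio

from tinker_cookbook.recipes.math_rl.train import CLIConfig, cli_main

# Hypothetical invocation; values are placeholders, and any required fields
# omitted from this diff (e.g. a model name) would need to be supplied too.
config = CLIConfig(
    learning_rate=1e-5,
    max_tokens=256,
    temperature=0.7,  # new knob; the 1.0 default reproduces prior behavior
)
asyncio.run(cli_main(config))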
7 changes: 6 additions & 1 deletion tinker_cookbook/rl/train.py
@@ -229,6 +229,7 @@ class Config:
     dataset_builder: RLDatasetBuilder  # also determines batch size
     model_name: str
     max_tokens: int
+    temperature: float = 1.0
     compute_post_kl: bool = False
     evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)
     lora_rank: int = 32
@@ -366,6 +367,7 @@ async def trajectory_group_worker_task(
         sampling_client,
         builder,
         max_tokens=cfg.max_tokens,
+        temperature=cfg.temperature,
         do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
         enable_logging=enable_logging,
     )
@@ -501,6 +503,7 @@ async def trajectory_group_worker_loop():
             sampling_client,
             env_group_builder,
             max_tokens=cfg.max_tokens,
+            temperature=cfg.temperature,
             do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
         )
         if trajectory_group is None:
@@ -659,10 +662,11 @@ async def do_group_rollout_and_filter_constant_reward(
     sampling_client: tinker.SamplingClient,
     env_group_builder: EnvGroupBuilder,
     max_tokens: int,
+    temperature: float,
     do_remove_constant_reward_groups: bool,
     enable_logging: bool = True,
 ) -> TrajectoryGroup | None:
-    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
+    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens, temperature=temperature)
 
     with logtree.optional_enable_logging(enable_logging):
         trajectory_group = await do_group_rollout(env_group_builder, policy)
@@ -988,6 +992,7 @@ async def do_sync_training(
             sampling_client,
             builder,
             max_tokens=cfg.max_tokens,
+            temperature=cfg.temperature,
             do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
             enable_logging=i < cfg.num_groups_to_log,
         ),
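
For reference on what the threaded-through knob does downstream: temperature divides the logits before softmax sampling, so 1.0 leaves the distribution unchanged and values below 1.0 sharpen it toward greedy decoding. The sketch below is the standard formulation, not Tinker's server-side implementation:

import math
import random

def sample_with_temperature(logits: list[float], temperature: float) -> int:
    """Sample a token index from softmax(logits / temperature)."""
    assert temperature > 0, "temperature must be positive"
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]
    # temperature -> 0 approaches argmax; temperature = 1.0 is plain sampling
    return random.choices(range(len(probs)), weights=probs, k=1)[0]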