Skip to content

Commit fcb89a4

Browse files
committed
deprecate use_workflow
1 parent 9c5b66c commit fcb89a4

File tree

4 files changed

+5
-15
lines changed

4 files changed

+5
-15
lines changed

rllm/trainer/agent_trainer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def __init__(
4040
agent_args: Optional arguments to pass to the agent class
4141
env_args: Optional arguments to pass to the environment class
4242
"""
43-
44-
if workflow_class is not None and config.rllm.workflow.use_workflow:
43+
if workflow_class is not None:
4544
if agent_class is not None:
4645
raise ValueError("agent_class is not supported when using workflow, instead use workflow_args['agent_cls']")
4746
if agent_args is not None:
@@ -81,7 +80,7 @@ def _train_tinker(self):
8180
from rllm.trainer.tinker.tinker_agent_trainer import TinkerAgentTrainer
8281
from rllm.trainer.tinker.tinker_workflow_trainer import TinkerWorkflowTrainer
8382

84-
if self.config.rllm.workflow.use_workflow:
83+
if self.workflow_class is not None:
8584
trainer = TinkerWorkflowTrainer(
8685
config=self.config,
8786
workflow_class=self.workflow_class,

rllm/trainer/config/tinker_agent_trainer.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,6 @@ trainer:
6363
val_before_train: true
6464
default_local_dir: '/tmp/rllm-tinker-checkpoints'
6565

66-
rllm:
67-
workflow:
68-
use_workflow: false
69-
7066
# Hydra configuration
7167
hydra:
7268
run:

rllm/trainer/config/tinker_workflow_trainer.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ trainer:
6060
val_before_train: true
6161
default_local_dir: '/tmp/rllm-tinker-checkpoints'
6262

63-
rllm:
64-
workflow:
65-
use_workflow: true
66-
6763
# Hydra configuration
6864
hydra:
6965
run:

rllm/trainer/verl/train_agent_ppo.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import ray
1111
from omegaconf import OmegaConf
1212

13-
from rllm.trainer.env_agent_mappings import AGENT_CLASS_MAPPING, ENV_CLASS_MAPPING, WORKFLOW_CLASS_MAPPING
13+
from rllm.trainer.env_agent_mappings import AGENT_CLASS_MAPPING, ENV_CLASS_MAPPING
1414
from rllm.trainer.verl.agent_ppo_trainer import AgentPPOTrainer
1515

1616
# Local application imports
@@ -155,9 +155,8 @@ def run(self, config, workflow_class=None, workflow_args=None, agent_class=None,
155155
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}))
156156
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
157157

158-
if config.rllm.workflow.use_workflow:
159-
if workflow_class is None:
160-
workflow_class = WORKFLOW_CLASS_MAPPING[config.rllm.workflow.name]
158+
if workflow_class is not None:
159+
# Should provide workflow_class if want to use workflow trainer
161160
workflow_args = workflow_args or {}
162161
if config.rllm.workflow.get("workflow_args") is not None:
163162
for key, value in config.rllm.workflow.get("workflow_args").items():

0 commit comments

Comments
 (0)