You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
A wrapper class that allows users to easily train custom agents with custom environments
13
15
without having to directly interact with the underlying training infrastructure.
16
+
17
+
Supports two backends:
18
+
- 'verl' (default): Standard training backend supporting both workflow and agent/env classes
19
+
- 'fireworks': Pipeline-based training backend optimized for workflow-based training
14
20
"""
15
21
16
22
def__init__(
@@ -24,23 +30,39 @@ def __init__(
24
30
config: dict[str, Any] |list[str] |None=None,
25
31
train_dataset: Dataset|None=None,
26
32
val_dataset: Dataset|None=None,
33
+
backend: Literal["verl", "fireworks"] ="verl",
27
34
):
28
35
"""
29
36
Initialize the AgentTrainer.
30
37
31
38
Args:
39
+
workflow_class: The workflow class to use for training
40
+
workflow_args: Optional arguments to pass to the workflow class
32
41
agent_class: The custom agent class to use for training
33
42
env_class: The custom environment class to use for training
43
+
agent_args: Optional arguments to pass to the agent class
44
+
env_args: Optional arguments to pass to the environment class
34
45
config: Configuration overrides to apply to the default config
35
46
Can be a dictionary with dot notation keys (e.g., {"data.train_batch_size": 8})
36
47
or a list of strings in the format "key=value" (e.g., ["data.train_batch_size=8"])
37
48
train_dataset: Optional train dataset to use
38
49
val_dataset: Optional validation dataset to use
39
-
agent_args: Optional arguments to pass to the agent class
40
-
env_args: Optional arguments to pass to the environment class
50
+
backend: Training backend to use ('verl' or 'fireworks'). Default is 'verl'
41
51
"""
52
+
# Validate backend
53
+
ifbackendnotin ["verl", "fireworks"]:
54
+
raiseValueError(f"backend must be either 'verl' or 'fireworks', got '{backend}'")
55
+
56
+
self.backend=backend
57
+
58
+
# Validate backend-specific requirements
59
+
ifbackend=="fireworks":
60
+
ifagent_classisnotNoneorenv_classisnotNone:
61
+
raiseValueError("The 'fireworks' backend only supports workflow_class. agent_class and env_class are not supported. Use workflow_args to configure agent and environment.")
62
+
ifagent_argsisnotNoneorenv_argsisnotNone:
63
+
raiseValueError("The 'fireworks' backend does not support agent_args or env_args. Use workflow_args to configure the workflow.")
0 commit comments