Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/fireworks_math/train_fireworks_math.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
# Find the directory where rllm package is located
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=Qwen/Qwen3-4B
MODEL_PATH=Qwen/Qwen3-8B
export CUDA_VISIBLE_DEVICES=4,5,6,7

python3 -m examples.fireworks_math.train_fireworks_math \
algorithm.adv_estimator=grpo \
Expand Down Expand Up @@ -77,5 +78,6 @@ python3 -m examples.fireworks_math.train_fireworks_math \
trainer.default_hdfs_dir=null \
trainer.total_epochs=100 \
rllm.workflow.use_workflow=True \
fireworks.deployment_id=wtk15cs9 \
fireworks.model_id_prefix=qwen3-4b-math
fireworks.deployment_id=rllm-qwen3-8b-3 \
fireworks.model_id_prefix=rllm-qwen3-8b-math \
fireworks.return_token_ids=True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set return_token_ids to True by default so no need to specify it here.

78 changes: 78 additions & 0 deletions rllm/engine/rollout/fireworks_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import json
import os
import time
from urllib.parse import urljoin

import openai
import requests
from fireworks.control_plane.generated.protos_grpcio.gateway.deployed_model_pb2 import (
DeployedModel as SyncDeployedModel,
)
Expand All @@ -12,6 +15,8 @@
from fireworks.gateway import Gateway

from rllm.engine.rollout.openai_engine import OpenAIEngine
from rllm.engine.rollout.rollout_engine import ModelOutput
from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START


class FireworksEngine(OpenAIEngine):
Expand Down Expand Up @@ -103,3 +108,76 @@ async def _probe_deployment(self, model_name) -> bool:
continue
else:
return False

async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
    """Run one chat completion against the Fireworks deployment, with retries.

    Engine-level kwargs that are not part of the API payload are stripped,
    the remaining kwargs are merged over the engine's default sampling
    parameters, and the request is retried on rate limits or transient errors.

    Args:
        messages: Chat messages in OpenAI format ({"role": ..., "content": ...}).
        **kwargs: Per-call sampling-parameter overrides; the engine-only keys
            ``application_id``, ``validate``, ``model`` and
            ``enforce_max_prompt_length`` are ignored here.

    Returns:
        ModelOutput with the response text (reasoning wrapped in thought
        delimiters when present), token ids, token counts and finish reason.

    Raises:
        Exception: when retries are exhausted, either from rate limiting or
            from repeated request/parsing failures (original cause chained).
    """
    # These kwargs are consumed by other engine layers, not by the API payload.
    for engine_only_key in ("application_id", "validate", "model", "enforce_max_prompt_length"):
        kwargs.pop(engine_only_key, None)

    sampling_params = self.sampling_params.copy()
    sampling_params.update(kwargs)

    # The max-tokens key and the merged payload are loop-invariant, so build
    # them once instead of re-merging on every retry iteration.
    create_params = self._prepare_max_tokens_param(sampling_params)
    merged_sampling_params = {**create_params, **sampling_params}

    retries = self.api_retries
    while retries > 0:
        try:
            response = self._fireworks_chat_completion(messages=messages, sampling_params=merged_sampling_params)
            choice = response["choices"][0]
            message = choice["message"]
            content = message["content"]
            reasoning = message.get("reasoning", "")
            tool_calls = message.get("tool_calls", [])

            # Wrap the model's reasoning in thought delimiters when it is
            # returned separately; otherwise the text is just the content.
            if reasoning:
                text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}"
            else:
                text = content

            usage = response["usage"]
            # NOTE(review): token ids are presumably only present when the
            # deployment sets return_token_ids=True; a missing key raises
            # KeyError and is retried below — confirm against the API config.
            return ModelOutput(
                text=text,
                content=content,
                reasoning=reasoning,
                tool_calls=tool_calls,
                prompt_ids=response["prompt_token_ids"],
                completion_ids=choice["token_ids"],
                prompt_length=usage["prompt_tokens"],
                completion_length=usage["completion_tokens"],
                finish_reason=choice["finish_reason"],
            )

        except openai.RateLimitError:
            retries -= 1
            if retries == 0:
                raise Exception("Rate limit reached and retries exhausted.") from None
            print("Sleep for 5 seconds for API limit.")
            await asyncio.sleep(5)

        except Exception as e:
            retries -= 1
            if retries == 0:
                raise Exception(f"Error processing content after retries: {e}") from e
            print(f"Error: {e}, retrying...")
            await asyncio.sleep(1)

def _fireworks_chat_completion(self, messages, sampling_params):
    """POST a raw chat-completion request to the Fireworks API.

    Goes through ``requests`` instead of the OpenAI client so that
    Fireworks-specific response fields (e.g. ``prompt_token_ids`` and
    per-choice ``token_ids``) survive; the typed SDK response objects
    would drop them.

    Args:
        messages: Chat messages in OpenAI format.
        sampling_params: Extra payload fields merged into the request
            (temperature, max_tokens, return_token_ids, ...).

    Returns:
        The decoded JSON response body as a dict.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the request exceeds the timeout.
    """
    url = urljoin(str(self.client.base_url), "chat/completions")
    payload = {
        "model": self.model,
        "messages": messages,
        **sampling_params,
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.client.api_key}",
    }
    # json= handles serialization; the timeout prevents an unbounded hang on
    # a stuck connection (generous, since long generations are expected), and
    # raise_for_status surfaces HTTP errors to the caller's retry loop instead
    # of failing later with a confusing KeyError while parsing the error body.
    response = requests.post(url, headers=headers, json=payload, timeout=600)
    response.raise_for_status()
    return response.json()
3 changes: 2 additions & 1 deletion rllm/trainer/config/agent_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ fireworks:
deployment_id: null
model_id_prefix: test-model
concurrency: 32
return_token_ids: False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just set it directly to True? And then up there we don't need to specify it in train_fireworks_math.sh explicitly any more.


trainer:
log_episodes: false
episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
14 changes: 9 additions & 5 deletions rllm/trainer/verl/agent_workflow_trainer_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,18 @@ def init_workers(self):
self.actor_wg.init_model()
self.actor_rollout_wg = self.actor_wg # for compatibility

sampling_params = {
"temperature": self.config.actor_rollout_ref.rollout.temperature,
"top_p": self.config.actor_rollout_ref.rollout.top_p,
"max_tokens": self.config.data.max_prompt_length + self.config.data.max_response_length,
}
if self.config.fireworks.return_token_ids:
sampling_params["return_token_ids"] = True

fireworks_engine = FireworksEngine(
tokenizer=self.tokenizer,
deployment_id=self.config.fireworks.deployment_id,
sampling_params={
"temperature": 0.6,
"top_p": 0.95,
"max_tokens": self.config.data.max_prompt_length + self.config.data.max_response_length,
},
sampling_params=sampling_params,
)
self.fireworks_engine = fireworks_engine
self.agent_execution_engine = AgentWorkflowEngine(
Expand Down