From 2b78b0dbf48816b9c4978c2e97e5474af90cb690 Mon Sep 17 00:00:00 2001 From: Sijun <31901639+jeffreysijuntan@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:32:19 -0800 Subject: [PATCH 1/5] squash revert n_training_gpus_per_node=2 token_ids rllm-qwen3-8b-3 fireworks.return_token_ids=True rllm-qwen3-8b-2 ["chioces"] formatting revert Update README.md --- .../fireworks_math/train_fireworks_math.sh | 6 +- rllm/engine/rollout/fireworks_engine.py | 78 +++++++++++++++++++ rllm/trainer/config/agent_ppo_trainer.yaml | 3 +- .../verl/agent_workflow_trainer_fireworks.py | 14 ++-- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/examples/fireworks_math/train_fireworks_math.sh b/examples/fireworks_math/train_fireworks_math.sh index 856710aa3..b7ffb78ad 100644 --- a/examples/fireworks_math/train_fireworks_math.sh +++ b/examples/fireworks_math/train_fireworks_math.sh @@ -10,6 +10,7 @@ export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000 RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))") MODEL_PATH=Qwen/Qwen3-4B +export CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m examples.fireworks_math.train_fireworks_math \ algorithm.adv_estimator=grpo \ @@ -77,5 +78,6 @@ python3 -m examples.fireworks_math.train_fireworks_math \ trainer.default_hdfs_dir=null \ trainer.total_epochs=100 \ rllm.workflow.use_workflow=True \ - fireworks.deployment_id=wtk15cs9 \ - fireworks.model_id_prefix=qwen3-4b-math \ No newline at end of file + fireworks.deployment_id=rllm-qwen3-8b-3 \ + fireworks.model_id_prefix=rllm-qwen3-8b-math \ + fireworks.return_token_ids=True \ No newline at end of file diff --git a/rllm/engine/rollout/fireworks_engine.py b/rllm/engine/rollout/fireworks_engine.py index 20f91fc26..061048bed 100644 --- a/rllm/engine/rollout/fireworks_engine.py +++ b/rllm/engine/rollout/fireworks_engine.py @@ -2,7 +2,10 @@ import json import os import time +from urllib.parse import urljoin +import openai +import requests from fireworks.control_plane.generated.protos_grpcio.gateway.deployed_model_pb2 import ( DeployedModel as SyncDeployedModel, ) @@ -12,6 +15,8 @@ from fireworks.gateway import Gateway from rllm.engine.rollout.openai_engine import OpenAIEngine +from rllm.engine.rollout.rollout_engine import ModelOutput +from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START class FireworksEngine(OpenAIEngine): @@ -103,3 +108,76 @@ async def _probe_deployment(self, model_name) -> bool: continue else: return False + + async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: + kwargs.pop("application_id", None) + kwargs.pop("validate", None) + kwargs.pop("model", None) + kwargs.pop("enforce_max_prompt_length", None) + + sampling_params = self.sampling_params.copy() + sampling_params.update(kwargs) + + create_params = self._prepare_max_tokens_param(sampling_params) + + retries = self.api_retries + while retries > 0: + try: + merged_sampling_params = {**create_params, **sampling_params} + response = self._fireworks_chat_completion(messages=messages, sampling_params=merged_sampling_params) + content = response["choices"][0]["message"]["content"] + reasoning = response["choices"][0]["message"].get("reasoning", "") + tool_calls = response["choices"][0]["message"].get("tool_calls", []) + + # Build text with reasoning if available, otherwise use content + if reasoning: + text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}" + else: + text = content + + prompt_length = response["usage"]["prompt_tokens"] + completion_length = response["usage"]["completion_tokens"] + finish_reason = response["choices"][0]["finish_reason"] + + prompt_token_ids = response["prompt_token_ids"] + completion_token_ids = response.json()['choices'][0]['token_ids'] + return ModelOutput( + text=text, + content=content, + reasoning=reasoning, + tool_calls=tool_calls, + prompt_ids=prompt_token_ids, + completion_ids=completion_token_ids, + prompt_length=prompt_length, + completion_length=completion_length, + finish_reason=finish_reason, + ) + + except openai.RateLimitError: + retries -= 1 + if retries == 0: + raise Exception("Rate limit reached and retries exhausted.") from None + print("Sleep for 5 seconds for API limit.") + await asyncio.sleep(5) + + except Exception as e: + retries -= 1 + if retries == 0: + raise Exception(f"Error processing content after retries: {e}") from e + print(f"Error: {e}, retrying...") + await asyncio.sleep(1) + + def _fireworks_chat_completion(self, messages, sampling_params): + url = urljoin(str(self.client.base_url), "/chat/completions") + payload = { + "model": self.model, + "messages": messages, + **sampling_params, + } + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {self.client.api_key}", + } + response = requests.request("POST", url, headers=headers, data=json.dumps(payload)) + return response.json() diff --git a/rllm/trainer/config/agent_ppo_trainer.yaml b/rllm/trainer/config/agent_ppo_trainer.yaml index 47f40aee6..019849ed3 100644 --- a/rllm/trainer/config/agent_ppo_trainer.yaml +++ b/rllm/trainer/config/agent_ppo_trainer.yaml @@ -65,7 +65,8 @@ fireworks: deployment_id: null model_id_prefix: test-model concurrency: 32 + return_token_ids: False trainer: log_episodes: false - episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} \ No newline at end of file + episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} diff --git a/rllm/trainer/verl/agent_workflow_trainer_fireworks.py b/rllm/trainer/verl/agent_workflow_trainer_fireworks.py index 7c2f69fc4..6c2831fe9 100644 --- a/rllm/trainer/verl/agent_workflow_trainer_fireworks.py +++ b/rllm/trainer/verl/agent_workflow_trainer_fireworks.py @@ -97,14 +97,18 @@ def init_workers(self): self.actor_wg.init_model() self.actor_rollout_wg = self.actor_wg # for compatibility + sampling_params = { + "temperature": self.config.actor_rollout_ref.rollout.temperature, + "top_p": self.config.actor_rollout_ref.rollout.top_p, + "max_tokens": self.config.data.max_prompt_length + self.config.data.max_response_length, + } + if self.config.fireworks.return_token_ids: + sampling_params["return_token_ids"] = True + fireworks_engine = FireworksEngine( tokenizer=self.tokenizer, deployment_id=self.config.fireworks.deployment_id, - sampling_params={ - "temperature": 0.6, - "top_p": 0.95, - "max_tokens": self.config.data.max_prompt_length + self.config.data.max_response_length, - }, + sampling_params=sampling_params, ) self.fireworks_engine = fireworks_engine self.agent_execution_engine = AgentWorkflowEngine( From 3cafb8ba669271debdbe10630f8fbf33aca80a1b Mon Sep 17 00:00:00 2001 From: 1stprinciple Date: Mon, 10 Nov 2025 22:28:13 +0100 Subject: [PATCH 2/5] formatting --- rllm/engine/rollout/fireworks_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllm/engine/rollout/fireworks_engine.py b/rllm/engine/rollout/fireworks_engine.py index 061048bed..455c08069 100644 --- a/rllm/engine/rollout/fireworks_engine.py +++ b/rllm/engine/rollout/fireworks_engine.py @@ -140,7 +140,7 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: finish_reason = response["choices"][0]["finish_reason"] prompt_token_ids = response["prompt_token_ids"] - completion_token_ids = response.json()['choices'][0]['token_ids'] + completion_token_ids = response.json()["choices"][0]["token_ids"] return ModelOutput( text=text, content=content, From 18dc2437554ceac8a35824f1d7fb69e7eb07cf5f Mon Sep 17 00:00:00 2001 From: 1stprinciple Date: Tue, 11 Nov 2025 00:01:08 +0100 Subject: [PATCH 3/5] remove / --- rllm/engine/rollout/fireworks_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllm/engine/rollout/fireworks_engine.py b/rllm/engine/rollout/fireworks_engine.py index 455c08069..5454ea604 100644 --- a/rllm/engine/rollout/fireworks_engine.py +++ b/rllm/engine/rollout/fireworks_engine.py @@ -168,7 +168,7 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: await asyncio.sleep(1) def _fireworks_chat_completion(self, messages, sampling_params): - url = urljoin(str(self.client.base_url), "/chat/completions") + url = urljoin(str(self.client.base_url), "chat/completions") payload = { "model": self.model, "messages": messages, From b95f7e50691acb349e9e4c08b194a98f4e35c33e Mon Sep 17 00:00:00 2001 From: 1stprinciple Date: Tue, 11 Nov 2025 00:34:56 +0100 Subject: [PATCH 4/5] remove .json() --- rllm/engine/rollout/fireworks_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllm/engine/rollout/fireworks_engine.py b/rllm/engine/rollout/fireworks_engine.py index 5454ea604..78eb96059 100644 --- a/rllm/engine/rollout/fireworks_engine.py +++ b/rllm/engine/rollout/fireworks_engine.py @@ -140,7 +140,7 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: finish_reason = response["choices"][0]["finish_reason"] prompt_token_ids = response["prompt_token_ids"] - completion_token_ids = response.json()["choices"][0]["token_ids"] + completion_token_ids = response["choices"][0]["token_ids"] return ModelOutput( text=text, content=content, From f537ec43b923be38b8690df84a4f5730fa153207 Mon Sep 17 00:00:00 2001 From: 1stprinciple Date: Tue, 11 Nov 2025 01:34:30 +0100 Subject: [PATCH 5/5] Qwen/Qwen3-8B --- examples/fireworks_math/train_fireworks_math.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fireworks_math/train_fireworks_math.sh b/examples/fireworks_math/train_fireworks_math.sh index b7ffb78ad..e89a04b46 100644 --- a/examples/fireworks_math/train_fireworks_math.sh +++ b/examples/fireworks_math/train_fireworks_math.sh @@ -9,7 +9,7 @@ export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000 # Find the directory where rllm package is located RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))") -MODEL_PATH=Qwen/Qwen3-4B +MODEL_PATH=Qwen/Qwen3-8B export CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m examples.fireworks_math.train_fireworks_math \