diff --git a/examples/fireworks_math/train_fireworks_math.sh b/examples/fireworks_math/train_fireworks_math.sh index 856710aa3..e89a04b46 100644 --- a/examples/fireworks_math/train_fireworks_math.sh +++ b/examples/fireworks_math/train_fireworks_math.sh @@ -9,7 +9,8 @@ export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000 # Find the directory where rllm package is located RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))") -MODEL_PATH=Qwen/Qwen3-4B +MODEL_PATH=Qwen/Qwen3-8B +export CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m examples.fireworks_math.train_fireworks_math \ algorithm.adv_estimator=grpo \ @@ -77,5 +78,6 @@ python3 -m examples.fireworks_math.train_fireworks_math \ trainer.default_hdfs_dir=null \ trainer.total_epochs=100 \ rllm.workflow.use_workflow=True \ - fireworks.deployment_id=wtk15cs9 \ - fireworks.model_id_prefix=qwen3-4b-math \ No newline at end of file + fireworks.deployment_id=rllm-qwen3-8b-3 \ + fireworks.model_id_prefix=rllm-qwen3-8b-math \ + fireworks.return_token_ids=True \ No newline at end of file diff --git a/rllm/engine/rollout/fireworks_engine.py b/rllm/engine/rollout/fireworks_engine.py index 20f91fc26..78eb96059 100644 --- a/rllm/engine/rollout/fireworks_engine.py +++ b/rllm/engine/rollout/fireworks_engine.py @@ -2,7 +2,10 @@ import json import os import time +from urllib.parse import urljoin +import openai +import requests from fireworks.control_plane.generated.protos_grpcio.gateway.deployed_model_pb2 import ( DeployedModel as SyncDeployedModel, ) @@ -12,6 +15,8 @@ from fireworks.gateway import Gateway from rllm.engine.rollout.openai_engine import OpenAIEngine +from rllm.engine.rollout.rollout_engine import ModelOutput +from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START class FireworksEngine(OpenAIEngine): @@ -103,3 +108,76 @@ async def _probe_deployment(self, model_name) -> bool: continue else: return False + + async def chat_completion(self, messages: list[dict], 
**kwargs) -> ModelOutput: + kwargs.pop("application_id", None) + kwargs.pop("validate", None) + kwargs.pop("model", None) + kwargs.pop("enforce_max_prompt_length", None) + + sampling_params = self.sampling_params.copy() + sampling_params.update(kwargs) + + create_params = self._prepare_max_tokens_param(sampling_params) + + retries = self.api_retries + while retries > 0: + try: + merged_sampling_params = {**create_params, **sampling_params} + response = self._fireworks_chat_completion(messages=messages, sampling_params=merged_sampling_params) + content = response["choices"][0]["message"]["content"] + reasoning = response["choices"][0]["message"].get("reasoning", "") + tool_calls = response["choices"][0]["message"].get("tool_calls", []) + + # Build text with reasoning if available, otherwise use content + if reasoning: + text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}" + else: + text = content + + prompt_length = response["usage"]["prompt_tokens"] + completion_length = response["usage"]["completion_tokens"] + finish_reason = response["choices"][0]["finish_reason"] + + # Token ids are only present when fireworks.return_token_ids is enabled (default False) + prompt_token_ids = response.get("prompt_token_ids") + completion_token_ids = response["choices"][0].get("token_ids") + return ModelOutput( + text=text, + content=content, + reasoning=reasoning, + tool_calls=tool_calls, + prompt_ids=prompt_token_ids, + completion_ids=completion_token_ids, + prompt_length=prompt_length, + completion_length=completion_length, + finish_reason=finish_reason, + ) + + except openai.RateLimitError: + retries -= 1 + if retries == 0: + raise Exception("Rate limit reached and retries exhausted.") from None + print("Sleep for 5 seconds for API limit.") + await asyncio.sleep(5) + + except Exception as e: + retries -= 1 + if retries == 0: + raise Exception(f"Error processing content after retries: {e}") from e + print(f"Error: {e}, retrying...") + await asyncio.sleep(1) + + def _fireworks_chat_completion(self, messages, sampling_params): + url = 
urljoin(str(self.client.base_url), "chat/completions") + payload = { + "model": self.model, + "messages": messages, + **sampling_params, + } + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {self.client.api_key}", + } + response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=3600) + return response.json() diff --git a/rllm/trainer/config/agent_ppo_trainer.yaml b/rllm/trainer/config/agent_ppo_trainer.yaml index 47f40aee6..019849ed3 100644 --- a/rllm/trainer/config/agent_ppo_trainer.yaml +++ b/rllm/trainer/config/agent_ppo_trainer.yaml @@ -65,7 +65,8 @@ fireworks: deployment_id: null model_id_prefix: test-model concurrency: 32 + return_token_ids: false trainer: log_episodes: false - episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} \ No newline at end of file + episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} diff --git a/rllm/trainer/verl/agent_workflow_trainer_fireworks.py b/rllm/trainer/verl/agent_workflow_trainer_fireworks.py index 7c2f69fc4..6c2831fe9 100644 --- a/rllm/trainer/verl/agent_workflow_trainer_fireworks.py +++ b/rllm/trainer/verl/agent_workflow_trainer_fireworks.py @@ -97,14 +97,18 @@ def init_workers(self): self.actor_wg.init_model() self.actor_rollout_wg = self.actor_wg # for compatibility + sampling_params = { + "temperature": self.config.actor_rollout_ref.rollout.temperature, + "top_p": self.config.actor_rollout_ref.rollout.top_p, + "max_tokens": self.config.data.max_prompt_length + self.config.data.max_response_length, + } + if self.config.fireworks.return_token_ids: + sampling_params["return_token_ids"] = True + fireworks_engine = FireworksEngine( tokenizer=self.tokenizer, deployment_id=self.config.fireworks.deployment_id, - sampling_params={ - "temperature": 0.6, - "top_p": 0.95, - "max_tokens": self.config.data.max_prompt_length + self.config.data.max_response_length, - }, + sampling_params=sampling_params, ) 
self.fireworks_engine = fireworks_engine self.agent_execution_engine = AgentWorkflowEngine(