Skip to content

Commit 2956f86

Browse files
authored
fix: Gracefully skip overlong prompts during training to prevent crashes (#281)
1 parent 4762c2d commit 2956f86

File tree

1 file changed

+39
-6
lines changed

1 file changed

+39
-6
lines changed

rllm/data/dataset.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
from typing import Any
55

6+
import numpy as np
67
import pandas as pd
78
import polars as pl
89
import torch
@@ -373,6 +374,27 @@ def remove_dataset(cls, name: str) -> bool:
373374
logger.info(f"Removed dataset '{name}' from registry.")
374375
return True
375376

377+
@staticmethod
378+
def _convert_to_json_serializable(obj: Any) -> Any:
379+
"""Convert numpy arrays and other non-serializable objects to JSON-serializable types.
380+
381+
Args:
382+
obj: Object to convert
383+
384+
Returns:
385+
JSON-serializable version of the object
386+
"""
387+
if isinstance(obj, np.ndarray):
388+
return obj.tolist()
389+
elif isinstance(obj, np.integer | np.floating):
390+
return obj.item()
391+
elif isinstance(obj, dict):
392+
return {key: DatasetRegistry._convert_to_json_serializable(value) for key, value in obj.items()}
393+
elif isinstance(obj, list | tuple):
394+
return [DatasetRegistry._convert_to_json_serializable(item) for item in obj]
395+
else:
396+
return obj
397+
376398
@classmethod
377399
def apply_verl_postprocessing(cls, data: list[dict[str, Any]]) -> list[dict[str, Any]]:
378400
"""Apply Verl postprocessing to the dataset.
@@ -382,16 +404,27 @@ def apply_verl_postprocessing(cls, data: list[dict[str, Any]]) -> list[dict[str,
382404
383405
Returns:
384406
List of dictionaries with Verl-compatible format
407+
408+
Note:
409+
All nested structures (lists, dicts) are JSON-serialized to avoid
410+
PyArrow "Nested data conversions not implemented for chunked array outputs"
411+
error when loading from Parquet in distributed contexts.
385412
"""
386413
processed_data = []
387414
for entry in data:
415+
# Convert numpy arrays to lists before JSON serialization
416+
serializable_entry = cls._convert_to_json_serializable(entry)
417+
388418
processed_entry = {
389-
"prompt": [{"role": "user", "content": "placeholder"}],
390-
"reward_model": {
391-
"style": "rule",
392-
"ground_truth": None,
393-
},
394-
"extra_info": entry,
419+
# Serialize nested structures as JSON strings to avoid PyArrow chunked array issues
420+
"prompt": json.dumps([{"role": "user", "content": "placeholder"}]),
421+
"reward_model": json.dumps(
422+
{
423+
"style": "rule",
424+
"ground_truth": None,
425+
}
426+
),
427+
"extra_info": json.dumps(serializable_entry),
395428
}
396429
processed_data.append(processed_entry)
397430
return processed_data

0 commit comments

Comments
 (0)