Skip to content

Commit ab9d168

Browse files
authored
Add guards against loading pkl files (#9048)
* Add guards against loading pkl files * remove extra comment * add tests and modify old tests * add import and fix tests * Update test to allow loading of pickled models with a safety flag * Change from dangersouly_allow_pickle to allow_pickle, remove env var, and suggest saving with module.save(x.json) * fix extra whitespace * fix test warning
1 parent fe03ead commit ab9d168

File tree

7 files changed

+93
-18
lines changed

7 files changed

+93
-18
lines changed

docs/docs/tutorials/games/index.ipynb

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,14 +746,21 @@
746746
"If you want to load and use the agent program, you can do that as follows."
747747
]
748748
},
749+
{
750+
"cell_type": "markdown",
751+
"metadata": {},
752+
"source": [
753+
"> **⚠️ Security Warning:** Loading `.pkl` files can execute arbitrary code and may be dangerous. Only save and load pickle files from trusted sources in secure environments. Consider using JSON format when possible for safer serialization."
754+
]
755+
},
749756
{
750757
"cell_type": "code",
751758
"execution_count": 16,
752759
"metadata": {},
753760
"outputs": [],
754761
"source": [
755762
"loaded = Agent()\n",
756-
"loaded.load('finetuned_4o_mini_001.pkl')"
763+
"loaded.load('finetuned_4o_mini_001.pkl', allow_pickle=True)"
757764
]
758765
}
759766
],

docs/docs/tutorials/saving/index.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ compiled_dspy_program.save("./dspy_program/program.json", save_program=False)
3838

3939
To save the state of your program to a pickle file:
4040

41+
!!! danger "Security Warning: Pickle Files Can Execute Arbitrary Code"
42+
Loading `.pkl` files can execute arbitrary code and may be dangerous. Only load pickle files from trusted sources in secure environments. **Prefer using `.json` files whenever possible**. If you must use pickle files, ensure you trust the source and use the `allow_pickle=True` parameter when loading.
43+
4144
```python
4245
compiled_dspy_program.save("./dspy_program/program.pkl", save_program=False)
4346
```
@@ -57,9 +60,12 @@ assert str(compiled_dspy_program.signature) == str(loaded_dspy_program.signature
5760

5861
Or load the state from a pickle file:
5962

63+
!!! danger "Security Warning"
64+
Remember to use `allow_pickle=True` when loading pickle files, and only load from trusted sources.
65+
6066
```python
6167
loaded_dspy_program = dspy.ChainOfThought("question -> answer") # Recreate the same program.
62-
loaded_dspy_program.load("./dspy_program/program.pkl")
68+
loaded_dspy_program.load("./dspy_program/program.pkl", allow_pickle=True)
6369

6470
assert len(compiled_dspy_program.demos) == len(loaded_dspy_program.demos)
6571
for original_demo, loaded_demo in zip(compiled_dspy_program.demos, loaded_dspy_program.demos):
@@ -70,6 +76,9 @@ assert str(compiled_dspy_program.signature) == str(loaded_dspy_program.signature
7076

7177
## Whole Program Saving
7278

79+
!!! warning "Security Notice: Whole Program Saving Uses Pickle"
80+
Whole program saving uses `cloudpickle` for serialization, which has the same security risks as pickle files. Only load programs from trusted sources in secure environments.
81+
7382
Starting from `dspy>=2.6.0`, DSPy supports saving the whole program, including the architecture and the state. This feature
7483
is powered by `cloudpickle`, which is a library for serializing and deserializing Python objects.
7584

dspy/primitives/base_module.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def save(self, path, save_program=False, modules_to_serialize=None):
202202
if not path.exists():
203203
# Create the directory (and any parent directories)
204204
path.mkdir(parents=True)
205-
205+
logger.warning("Loading untrusted .pkl files can run arbitrary code, which may be dangerous. To avoid "
206+
'this, prefer saving using json format using module.save("module.json").')
206207
try:
207208
modules_to_serialize = modules_to_serialize or []
208209
for module in modules_to_serialize:
@@ -233,26 +234,34 @@ def save(self, path, save_program=False, modules_to_serialize=None):
233234
"with `.pkl`, or saving the whole program by setting `save_program=True`."
234235
)
235236
elif path.suffix == ".pkl":
237+
logger.warning("Loading untrusted .pkl files can run arbitrary code, which may be dangerous. To avoid "
238+
'this, prefer saving using json format using module.save("module.json").')
236239
state = self.dump_state(json_mode=False)
237240
state["metadata"] = metadata
238241
with open(path, "wb") as f:
239242
cloudpickle.dump(state, f)
240243
else:
241244
raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")
242245

243-
def load(self, path):
246+
def load(self, path, allow_pickle=False):
244247
"""Load the saved module. You may also want to check out dspy.load, if you want to
245248
load an entire program, not just the state for an existing program.
246249
247250
Args:
248251
path (str): Path to the saved state file, which should be a .json or a .pkl file
252+
allow_pickle (bool): If True, allow loading .pkl files, which can run arbitrary code.
253+
This is dangerous and should only be used if you are sure about the source of the file and in a trusted environment.
249254
"""
250255
path = Path(path)
251256

252257
if path.suffix == ".json":
253258
with open(path, "rb") as f:
254259
state = orjson.loads(f.read())
255260
elif path.suffix == ".pkl":
261+
if not allow_pickle:
262+
raise ValueError("Loading .pkl files can run arbitrary code, which may be dangerous. Prefer "
263+
"saving with .json files if possible. Set `allow_pickle=True` "
264+
"if you are sure about the source of the file and in a trusted environment.")
256265
with open(path, "rb") as f:
257266
state = cloudpickle.load(f)
258267
else:

dspy/utils/saving.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,21 @@ def get_dependency_versions():
2424
}
2525

2626

27-
def load(path: str) -> "Module":
27+
def load(path: str, allow_pickle: bool = False) -> "Module":
2828
"""Load saved DSPy model.
2929
3030
This method is used to load a saved DSPy model with `save_program=True`, i.e., the model is saved with cloudpickle.
3131
3232
Args:
3333
path (str): Path to the saved model.
34+
allow_pickle (bool): Whether to allow loading the model with pickle. This is dangerous and should only be used if you are sure you trust the source of the model.
3435
3536
Returns:
3637
The loaded model, a `dspy.Module` instance.
3738
"""
39+
if not allow_pickle:
40+
raise ValueError("Loading with pickle is not allowed. Please set `allow_pickle=True` if you are sure you trust the source of the model.")
41+
3842
path = Path(path)
3943
if not path.exists():
4044
raise FileNotFoundError(f"The path '{path}' does not exist.")

tests/predict/test_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def test_lm_field_after_dump_and_load_state(tmp_path, filename):
253253
assert file_path.exists()
254254

255255
loaded_predict = dspy.Predict("q->a")
256-
loaded_predict.load(file_path)
256+
loaded_predict.load(file_path, allow_pickle=True)
257257

258258
assert original_predict.dump_state() == loaded_predict.dump_state()
259259

tests/primitives/test_base_module.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def dummy_metric(example, pred, trace=None):
123123
compiled_cot.save(save_path)
124124

125125
new_cot = dspy.ChainOfThought(MySignature)
126-
new_cot.load(save_path)
126+
new_cot.load(save_path, allow_pickle=True)
127127

128128
assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature)
129129
assert new_cot.predict.demos == compiled_cot.predict.demos
@@ -162,7 +162,7 @@ def forward(self, q):
162162

163163
# Test the loading fails without using `modules_to_serialize`
164164
with pytest.raises(ModuleNotFoundError):
165-
dspy.load(tmp_path)
165+
dspy.load(tmp_path, allow_pickle=True)
166166

167167
sys.path.insert(0, str(tmp_path))
168168
import custom_module
@@ -179,7 +179,7 @@ def forward(self, q):
179179
sys.path.remove(str(tmp_path))
180180
del custom_module
181181

182-
loaded_module = dspy.load(tmp_path)
182+
loaded_module = dspy.load(tmp_path, allow_pickle=True)
183183
assert loaded_module.cot.predict.signature == cot.cot.predict.signature
184184

185185
finally:
@@ -223,12 +223,16 @@ def emit(self, record):
223223
# Mock version during load
224224
with patch("dspy.primitives.base_module.get_dependency_versions", return_value=load_versions):
225225
loaded_predict = dspy.Predict("question->answer")
226-
loaded_predict.load(save_path)
226+
loaded_predict.load(save_path, allow_pickle=True)
227227

228-
# Assert warnings were logged, and one warning for each mismatched dependency.
229-
assert len(handler.messages) == 3
228+
# Assert warnings were logged: 1 for pickle loading + 3 for version mismatches
229+
assert len(handler.messages) == 4
230230

231-
for msg in handler.messages:
231+
# First message is about pickle loading
232+
assert ".pkl" in handler.messages[0]
233+
234+
# Rest are version mismatch warnings
235+
for msg in handler.messages[1:]:
232236
assert "There is a mismatch of" in msg
233237

234238
# Verify the model still loads correctly despite version mismatches

tests/utils/test_saving.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_save_predict(tmp_path):
1414
assert (tmp_path / "metadata.json").exists()
1515
assert (tmp_path / "program.pkl").exists()
1616

17-
loaded_predict = dspy.load(tmp_path)
17+
loaded_predict = dspy.load(tmp_path, allow_pickle=True)
1818
assert isinstance(loaded_predict, dspy.Predict)
1919

2020
assert predict.signature == loaded_predict.signature
@@ -29,7 +29,7 @@ def __init__(self):
2929
model = CustomModel()
3030
model.save(tmp_path, save_program=True)
3131

32-
loaded_model = dspy.load(tmp_path)
32+
loaded_model = dspy.load(tmp_path, allow_pickle=True)
3333
assert isinstance(loaded_model, CustomModel)
3434

3535
assert len(model.predictors()) == len(loaded_model.predictors())
@@ -51,7 +51,7 @@ class MySignature(dspy.Signature):
5151
predict.signature = predict.signature.with_instructions("You are a helpful assistant.")
5252
predict.save(tmp_path, save_program=True)
5353

54-
loaded_predict = dspy.load(tmp_path)
54+
loaded_predict = dspy.load(tmp_path, allow_pickle=True)
5555
assert isinstance(loaded_predict, dspy.Predict)
5656

5757
assert predict.signature == loaded_predict.signature
@@ -77,7 +77,7 @@ def dummy_metric(example, pred, trace=None):
7777
compiled_predict = optimizer.compile(predict, trainset=trainset)
7878
compiled_predict.save(tmp_path, save_program=True)
7979

80-
loaded_predict = dspy.load(tmp_path)
80+
loaded_predict = dspy.load(tmp_path, allow_pickle=True)
8181
assert compiled_predict.demos == loaded_predict.demos
8282
assert compiled_predict.signature == loaded_predict.signature
8383

@@ -115,7 +115,7 @@ def emit(self, record):
115115

116116
# Mock version during load
117117
with patch("dspy.utils.saving.get_dependency_versions", return_value=load_versions):
118-
loaded_predict = dspy.load(tmp_path)
118+
loaded_predict = dspy.load(tmp_path, allow_pickle=True)
119119

120120
# Assert warnings were logged, and one warning for each mismatched dependency.
121121
assert len(handler.messages) == 3
@@ -131,3 +131,45 @@ def emit(self, record):
131131
# Clean up: restore original level and remove handler
132132
logger.setLevel(original_level)
133133
logger.removeHandler(handler)
134+
135+
136+
def test_pickle_loading_requires_explicit_permission(tmp_path):
137+
"""Test that loading pickle files requires explicit permission."""
138+
predict = dspy.Predict("question->answer")
139+
predict.save(tmp_path, save_program=True)
140+
141+
# Should fail without dangerously_allow_pickle
142+
with pytest.raises(ValueError, match="Loading with pickle is not allowed"):
143+
dspy.load(tmp_path)
144+
145+
# Should succeed with dangerously_allow_pickle
146+
loaded_predict = dspy.load(tmp_path, allow_pickle=True)
147+
assert isinstance(loaded_predict, dspy.Predict)
148+
149+
150+
def test_pkl_file_loading_requires_explicit_permission(tmp_path):
151+
"""Test that loading .pkl files requires explicit permission."""
152+
predict = dspy.Predict("question->answer")
153+
pkl_path = tmp_path / "model.pkl"
154+
predict.save(pkl_path)
155+
156+
# Should fail without allow_pickle
157+
new_predict = dspy.Predict("question->answer")
158+
with pytest.raises(ValueError, match="Loading .pkl files can run arbitrary code"):
159+
new_predict.load(pkl_path)
160+
161+
# Should succeed with allow_pickle
162+
new_predict.load(pkl_path, allow_pickle=True)
163+
assert new_predict.dump_state() == predict.dump_state()
164+
165+
166+
def test_json_file_loading_works_without_permission(tmp_path):
167+
"""Test that loading .json files works without explicit permission."""
168+
predict = dspy.Predict("question->answer")
169+
json_path = tmp_path / "model.json"
170+
predict.save(json_path)
171+
172+
# Should succeed without allow_pickle
173+
new_predict = dspy.Predict("question->answer")
174+
new_predict.load(json_path)
175+
assert new_predict.dump_state() == predict.dump_state()

0 commit comments

Comments
 (0)