Skip to content

Commit 922e854

Browse files
authored
feat(kernels): add opt-out flag to disable kernels hub usage through the lib (#41990)
* feat(kernels): add opt-out flag to disable kernels hub usage through the lib * misc(quality): style * test(kernels): add some opt-out test logic * chore(quality): style here we go again * chore(quality): style here we go again ... again * chore(quality): STYLE
1 parent f9e668a commit 922e854

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed

src/transformers/integrations/hub_kernels.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
1415
import re
1516
from collections.abc import Callable
1617
from functools import partial
1718
from types import ModuleType
1819
from typing import Optional, Union
1920

2021
from ..modeling_flash_attention_utils import lazy_import_flash_attention
21-
from ..utils import logging
22+
from ..utils import ENV_VARS_TRUE_VALUES, logging
2223
from ..utils.import_utils import is_kernels_available
2324
from .flash_attention import flash_attention_forward
2425

@@ -33,10 +34,22 @@
3334
get_kernel,
3435
register_kernel_mapping,
3536
replace_kernel_forward_from_hub,
36-
use_kernel_forward_from_hub,
3737
)
3838

39+
_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
3940
_kernels_available = True
41+
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
42+
43+
def use_kernel_forward_from_hub(layer_name: str):
44+
if _kernels_enabled:
45+
from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub
46+
47+
return _kernels_use_kernel_forward_from_hub(layer_name)
48+
else:
49+
logger.warning_once(
50+
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
51+
)
52+
return lambda cls: cls
4053

4154
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
4255
"MultiScaleDeformableAttention": {
@@ -167,6 +180,7 @@ def register_kernel_mapping_transformers(mapping=None):
167180

168181
except ImportError:
169182
_kernels_available = False
183+
_kernels_enabled = False
170184

171185
# Stub to make decorators int transformers work when `kernels`
172186
# is not installed.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import os
2+
import unittest
3+
from unittest.mock import patch
4+
5+
from transformers.testing_utils import require_kernels
6+
7+
8+
@require_kernels
9+
class HubKernelsTests(unittest.TestCase):
10+
def test_disable_hub_kernels(self):
11+
"""
12+
Test that _kernels_enabled is False when USE_HUB_KERNELS when USE_HUB_KERNELS=OFF
13+
"""
14+
with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
15+
# Re-import to ensure the environment variable takes effect
16+
import importlib
17+
18+
from transformers.integrations import hub_kernels
19+
20+
importlib.reload(hub_kernels)
21+
22+
# Verify that kernels are disabled
23+
self.assertFalse(hub_kernels._kernels_enabled)
24+
25+
def test_enable_hub_kernels_default(self):
26+
"""
27+
Test that _kernels_enabled is True when USE_HUB_KERNELS is not provided (default behavior)
28+
"""
29+
# Remove USE_HUB_KERNELS from the environment if it exists
30+
env_without_hub_kernels = {k: v for k, v in os.environ.items() if k != "USE_HUB_KERNELS"}
31+
with patch.dict(os.environ, env_without_hub_kernels, clear=True):
32+
# Re-import to ensure the environment variable change takes effect
33+
import importlib
34+
35+
from transformers.integrations import hub_kernels
36+
37+
importlib.reload(hub_kernels)
38+
39+
# Verify that kernels are enabled by default
40+
self.assertTrue(hub_kernels._kernels_enabled)
41+
42+
def test_enable_hub_kernels_on(self):
43+
"""
44+
Test that _kernels_enabled is True when USE_HUB_KERNELS=ON
45+
"""
46+
with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
47+
# Re-import to ensure the environment variable takes effect
48+
import importlib
49+
50+
from transformers.integrations import hub_kernels
51+
52+
importlib.reload(hub_kernels)
53+
54+
# Verify that kernels are enabled
55+
self.assertTrue(hub_kernels._kernels_enabled)
56+
57+
@patch("kernels.use_kernel_forward_from_hub")
58+
def test_use_kernel_forward_from_hub_not_called_when_disabled(self, mocked_use_kernel_forward):
59+
"""
60+
Test that kernels.use_kernel_forward_from_hub is not called when USE_HUB_KERNELS is disabled
61+
"""
62+
# Set environment variable to disable hub kernels
63+
with patch.dict(os.environ, {"USE_HUB_KERNELS": "OFF"}):
64+
# Re-import to ensure the environment variable takes effect
65+
import importlib
66+
67+
from transformers.integrations import hub_kernels
68+
69+
importlib.reload(hub_kernels)
70+
71+
# Call the function with a test layer name
72+
decorator = hub_kernels.use_kernel_forward_from_hub("DummyLayer")
73+
74+
# Verify that the kernels function was never called
75+
mocked_use_kernel_forward.assert_not_called()
76+
77+
# Verify that we get a no-op decorator
78+
class FooClass:
79+
pass
80+
81+
result = decorator(FooClass)
82+
self.assertIs(result, FooClass)
83+
84+
@patch("kernels.use_kernel_forward_from_hub")
85+
def test_use_kernel_forward_from_hub_called_when_enabled_default(self, mocked_use_kernel_forward):
86+
"""
87+
Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS is not set (default)
88+
"""
89+
# Remove USE_HUB_KERNELS from the environment if it exists
90+
env_without_hub_kernels = {k: v for k, v in os.environ.items() if k != "USE_HUB_KERNELS"}
91+
with patch.dict(os.environ, env_without_hub_kernels, clear=True):
92+
# Re-import to ensure the environment variable change takes effect
93+
import importlib
94+
95+
from transformers.integrations import hub_kernels
96+
97+
importlib.reload(hub_kernels)
98+
99+
# Call the function with a test layer name
100+
hub_kernels.use_kernel_forward_from_hub("FooLayer")
101+
102+
# Verify that the kernels function was called once with the correct argument
103+
mocked_use_kernel_forward.assert_called_once_with("FooLayer")
104+
105+
@patch("kernels.use_kernel_forward_from_hub")
106+
def test_use_kernel_forward_from_hub_called_when_enabled_on(self, mocked_use_kernel_forward):
107+
"""
108+
Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS=ON
109+
"""
110+
with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
111+
# Re-import to ensure the environment variable change takes effect
112+
import importlib
113+
114+
from transformers.integrations import hub_kernels
115+
116+
importlib.reload(hub_kernels)
117+
118+
# Call the function with a test layer name
119+
hub_kernels.use_kernel_forward_from_hub("FooLayer")
120+
121+
# Verify that the kernels function was called once with the correct argument
122+
mocked_use_kernel_forward.assert_called_once_with("FooLayer")

0 commit comments

Comments
 (0)