Skip to content

Commit 3389cac

Browse files
authored
Fix ImportError when using zenml login without SQL dependencies (#4252)
Fixes #4251 Fixes tests that were broken on Python 3.10 since PR #4224.
1 parent 18d335e commit 3389cac

File tree

3 files changed

+70
-30
lines changed

3 files changed

+70
-30
lines changed

src/zenml/cli/login.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -952,8 +952,8 @@ def login(
952952
# Get the server that the client is currently connected to, if any
953953
current_non_local_server: Optional[str] = None
954954
gc = GlobalConfiguration()
955-
store_cfg = gc.store_configuration
956-
if store_cfg.type == StoreType.REST:
955+
store_cfg = gc.get_store_configuration(allow_default=False)
956+
if store_cfg and store_cfg.type == StoreType.REST:
957957
if not connected_to_local_server():
958958
current_non_local_server = store_cfg.url
959959

src/zenml/config/global_config.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import uuid
1818
from pathlib import Path
19-
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
19+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, cast, overload
2020
from uuid import UUID
2121

2222
from packaging import version
@@ -488,9 +488,25 @@ def get_config_environment_vars(self) -> Dict[str, str]:
488488

489489
return environment_vars
490490

491-
def _get_store_configuration(
492-
self, baseline: Optional[StoreConfiguration] = None
493-
) -> StoreConfiguration:
491+
@overload
492+
def get_store_configuration(
493+
self,
494+
baseline: Optional[StoreConfiguration] = ...,
495+
allow_default: Literal[True] = ...,
496+
) -> StoreConfiguration: ...
497+
498+
@overload
499+
def get_store_configuration(
500+
self,
501+
baseline: Optional[StoreConfiguration] = ...,
502+
allow_default: Literal[False] = ...,
503+
) -> Optional[StoreConfiguration]: ...
504+
505+
def get_store_configuration(
506+
self,
507+
baseline: Optional[StoreConfiguration] = None,
508+
allow_default: bool = True,
509+
) -> Optional[StoreConfiguration]:
494510
"""Get the store configuration.
495511
496512
This method computes a store configuration starting from a baseline and
@@ -504,9 +520,12 @@ def _get_store_configuration(
504520
505521
Args:
506522
baseline: Optional baseline store configuration to use.
523+
allow_default: Whether to fall back to the default store
524+
configuration if none is set.
507525
508526
Returns:
509-
The store configuration.
527+
The store configuration or `None` if defaults are disallowed and no
528+
configuration is available.
510529
"""
511530
from zenml.zen_stores.base_zen_store import BaseZenStore
512531

@@ -555,15 +574,20 @@ def _get_store_configuration(
555574
logger.debug(
556575
"Using environment variables to update store config"
557576
)
558-
if not store:
577+
if not store and allow_default:
559578
store = self.get_default_store()
560-
store = store.model_copy(update=env_store_config, deep=True)
579+
if store:
580+
store = store.model_copy(
581+
update=env_store_config, deep=True
582+
)
561583

562584
# Step 2: Only after we've applied the environment variables, we
563585
# fallback to the default store if no store configuration is set. This
564586
# is to avoid importing the SQL store config in cases where a rest store
565587
# is configured with environment variables.
566588
if not store:
589+
if not allow_default:
590+
return None
567591
store = self.get_default_store()
568592

569593
# Step 3: Replace or update the baseline secrets store configuration
@@ -631,7 +655,7 @@ def store_configuration(self) -> StoreConfiguration:
631655
# configuration from there and disregard the global configuration.
632656
if self._zen_store is not None:
633657
return self._zen_store.config
634-
return self._get_store_configuration()
658+
return self.get_store_configuration()
635659

636660
def get_default_store(self) -> StoreConfiguration:
637661
"""Get the default SQLite store configuration.
@@ -655,7 +679,7 @@ def set_default_store(self) -> None:
655679
default store.
656680
"""
657681
# Apply the environment variables to the default store configuration
658-
default_store_cfg = self._get_store_configuration(
682+
default_store_cfg = self.get_store_configuration(
659683
baseline=self.get_default_store()
660684
)
661685
self._configure_store(default_store_cfg)
@@ -697,8 +721,10 @@ def set_store(
697721
constructor.
698722
"""
699723
# Apply the environment variables to the custom store configuration
700-
config = self._get_store_configuration(baseline=config)
701-
self._configure_store(config, skip_default_registrations, **kwargs)
724+
resolved_config = self.get_store_configuration(baseline=config)
725+
self._configure_store(
726+
resolved_config, skip_default_registrations, **kwargs
727+
)
702728
logger.info("Updated the global store configuration.")
703729

704730
@property

tests/integration/functional/cli/test_cli.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -84,31 +84,43 @@ def test_cli_does_not_set_custom_source_root_if_inside_repository(
8484
mock_set_custom_source_root.assert_not_called()
8585

8686

87-
def _mock_rest_store_environment(mocker):
88-
"""Mock dependencies needed for connect_to_server REST flow."""
87+
def _mock_rest_store_environment(mocker, login_module):
88+
"""Mock dependencies needed for connect_to_server REST flow.
89+
90+
Args:
91+
mocker: pytest-mock fixture.
92+
login_module: The imported zenml.cli.login module. We pass this
93+
explicitly and use patch.object() because string-based patches
94+
like "zenml.cli.login.web_login" fail on Python 3.10 due to
95+
the star import in zenml/cli/__init__.py shadowing the submodule.
96+
"""
8997
credentials_store = mocker.Mock()
9098
credentials_store.has_valid_credentials.return_value = True
9199
mocker.patch(
92-
"zenml.cli.login.get_credentials_store", return_value=credentials_store
100+
"zenml.login.credentials_store.get_credentials_store",
101+
return_value=credentials_store,
93102
)
94-
mocker.patch(
95-
"zenml.cli.login.BaseZenStore.get_store_type",
103+
mocker.patch.object(
104+
login_module.BaseZenStore,
105+
"get_store_type",
96106
return_value=StoreType.REST,
97107
)
98-
mocker.patch(
99-
"zenml.cli.login.RestZenStoreConfiguration", return_value="rest-config"
108+
mocker.patch.object(
109+
login_module,
110+
"RestZenStoreConfiguration",
111+
return_value="rest-config",
100112
)
101-
mocker.patch("zenml.cli.login.cli_utils.declare")
102-
mocker.patch("zenml.cli.login.web_login")
113+
mocker.patch.object(login_module.cli_utils, "declare")
114+
mocker.patch.object(login_module, "web_login")
103115

104116

105117
def test_connect_to_server_sets_project_after_success(mocker):
106118
"""Project flag should set the active project after connecting."""
107-
_mock_rest_store_environment(mocker)
108119
login_module = importlib.import_module("zenml.cli.login")
109-
mock_gc = mocker.patch("zenml.cli.login.GlobalConfiguration")
120+
_mock_rest_store_environment(mocker, login_module)
121+
mock_gc = mocker.patch.object(login_module, "GlobalConfiguration")
110122
mock_gc.return_value.set_store.return_value = None
111-
mock_set_project = mocker.patch("zenml.cli.login._set_active_project")
123+
mock_set_project = mocker.patch.object(login_module, "_set_active_project")
112124

113125
login_module.connect_to_server(
114126
url="https://example.com",
@@ -120,13 +132,15 @@ def test_connect_to_server_sets_project_after_success(mocker):
120132

121133
def test_connect_to_server_does_not_set_project_on_failure(mocker):
122134
"""Project change should be skipped if connecting to the store fails."""
123-
_mock_rest_store_environment(mocker)
124135
login_module = importlib.import_module("zenml.cli.login")
125-
mock_gc = mocker.patch("zenml.cli.login.GlobalConfiguration")
136+
_mock_rest_store_environment(mocker, login_module)
137+
mock_gc = mocker.patch.object(login_module, "GlobalConfiguration")
126138
mock_gc.return_value.set_store.side_effect = IllegalOperationError("boom")
127-
mock_set_project = mocker.patch("zenml.cli.login._set_active_project")
128-
mocker.patch(
129-
"zenml.cli.login.cli_utils.error", side_effect=RuntimeError("exit")
139+
mock_set_project = mocker.patch.object(login_module, "_set_active_project")
140+
mocker.patch.object(
141+
login_module.cli_utils,
142+
"error",
143+
side_effect=RuntimeError("exit"),
130144
)
131145

132146
with pytest.raises(RuntimeError):

0 commit comments

Comments
 (0)