From 414dfbfe5d12f4e9c7b30977c49c48871abe6ba2 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 9 Nov 2025 02:11:32 +0100 Subject: [PATCH 1/3] init --- docs/book/how-to/deployment/deployment.md | 34 +- .../steps-pipelines/advanced_features.md | 22 + .../weather_agent/pipelines/weather_agent.py | 7 +- examples/weather_agent/run.py | 2 + examples/weather_agent/steps/__init__.py | 4 +- examples/weather_agent/steps/comparison.py | 35 ++ examples/weather_agent/steps/weather_agent.py | 15 +- examples/weather_agent/ui/index.html | 151 ++++- src/zenml/config/deployment_settings.py | 39 ++ src/zenml/deployers/server/models.py | 13 + src/zenml/deployers/server/runtime.py | 46 ++ src/zenml/deployers/server/service.py | 161 ++++- src/zenml/deployers/server/sessions.py | 548 ++++++++++++++++++ src/zenml/steps/step_context.py | 33 ++ tests/unit/deployers/server/test_runtime.py | 52 ++ tests/unit/deployers/server/test_service.py | 103 +++- tests/unit/deployers/server/test_sessions.py | 157 +++++ 17 files changed, 1387 insertions(+), 35 deletions(-) create mode 100644 examples/weather_agent/steps/comparison.py create mode 100644 src/zenml/deployers/server/sessions.py create mode 100644 tests/unit/deployers/server/test_sessions.py diff --git a/docs/book/how-to/deployment/deployment.md b/docs/book/how-to/deployment/deployment.md index 84e773b85b6..33f8f4ec59d 100644 --- a/docs/book/how-to/deployment/deployment.md +++ b/docs/book/how-to/deployment/deployment.md @@ -148,6 +148,38 @@ curl -X POST http://localhost:8000/invoke \ -d '{"parameters": {"city": "London", "temperature": 20}}' ``` +### Session-aware invocations + +Deployments remember information across `/invoke` calls because **sessions are enabled by default** (using the in-memory backend). Define the block below only if you want to opt-out or customize TTL/size limits: + +```yaml +deployment_settings: + sessions: + enabled: false # disable sessions for stateless deployments + ttl_seconds: 3600 # override session expiration + max_state_bytes: 32768 # tighten state payload guardrail +``` + +Key things to know when sessions are enabled: + +* Clients may supply an optional `session_id` flag (CLI) or JSON field. If omitted, the deployment generates a new ID and echoes it back in `metadata.session_id`. Reuse that ID to resume the same conversation. +* The session identifier is included in deployment logs/metrics so multi-turn traces remain easy to follow. +* Steps can read/write a mutable `session_state` dict via `get_step_context().session_state`. Mutations are persisted automatically after the run. See [Managing conversational session state](../steps-pipelines/advanced_features.md#managing-conversational-session-state) for concrete step examples. + +Example invocation flow: + +```bash +# First turn – server generates the session_id +zenml deployment invoke my_deployment --city="London" --temperature=20 + +# Follow-up turn reusing the echoed session id +zenml deployment invoke my_deployment \ + --city="Berlin" --temperature=18 \ + --session-id="session-12345" +``` + +The [weather agent example](https://github.com/zenml-io/zenml/tree/develop/examples/weather_agent) shows how a real pipeline uses `session_state` to keep a running history of weather analyses across turns. + ## Deployment Lifecycle Once a Deployment is created, it is tied to the specific **Deployer** stack component that was used to provision it and can be managed independently of the active stack as a standalone entity with its own lifecycle. @@ -519,4 +551,4 @@ Pipeline deployment transforms ZenML pipelines from batch processing workflows i See also: - [Steps & Pipelines](../steps-pipelines/steps_and_pipelines.md) - Core building blocks -- [Deployer Stack Component](../../component-guide/deployers/README.md) - The stack component that manages the deployment of pipelines as long-running HTTP servers \ No newline at end of file +- [Deployer Stack Component](../../component-guide/deployers/README.md) - The stack component that manages the deployment of pipelines as long-running HTTP servers diff --git a/docs/book/how-to/steps-pipelines/advanced_features.md b/docs/book/how-to/steps-pipelines/advanced_features.md index 8590d089343..058508f985b 100644 --- a/docs/book/how-to/steps-pipelines/advanced_features.md +++ b/docs/book/how-to/steps-pipelines/advanced_features.md @@ -674,6 +674,28 @@ def my_step(some_parameter: int = 1): raise ValueError("My exception") ``` +### Managing conversational session state + +When a deployment is invoked with sessions enabled (the default behavior for deployer-based services), each step can access a per-session dictionary through the step context. This is useful for LLM workflows or any pipeline that needs to remember information across `/invoke` calls. + +```python +from zenml import step, get_step_context + +@step +def chat_step(message: str) -> str: + ctx = get_step_context() + session_state = ctx.session_state # Live dict persisted after the run + history = session_state.setdefault("history", []) + history.append({"role": "user", "content": message}) + + reply = f"Echoing turn {len(history)}: {message}" + history.append({"role": "assistant", "content": reply}) + session_state["last_reply"] = reply + return reply +``` + +If sessions are disabled for a deployment, `ctx.session_state` simply returns an empty dict, so the same code works without extra guards. + ### Using Alerter in Hooks You can use the [Alerter stack component](https://docs.zenml.io/component-guide/alerters) to send notifications when steps fail or succeed: diff --git a/examples/weather_agent/pipelines/weather_agent.py b/examples/weather_agent/pipelines/weather_agent.py index ceedadb8096..517be106782 100644 --- a/examples/weather_agent/pipelines/weather_agent.py +++ b/examples/weather_agent/pipelines/weather_agent.py @@ -19,7 +19,7 @@ init_hook, ) from starlette.middleware.gzip import GZipMiddleware -from steps import analyze_weather_with_llm, get_weather +from steps import analyze_weather_with_llm, compare_city_trends, get_weather from zenml import pipeline from zenml.config import ( @@ -266,7 +266,7 @@ def on_shutdown( ) def weather_agent( city: str = "London", -) -> tuple[Dict[str, float], str]: +) -> tuple[Dict[str, float], str, str]: """Weather agent pipeline. Args: @@ -277,4 +277,5 @@ def weather_agent( """ weather_data = get_weather(city=city) result = analyze_weather_with_llm(weather_data=weather_data, city=city) - return weather_data, result + comparison = compare_city_trends(analysis=result) + return weather_data, result, comparison diff --git a/examples/weather_agent/run.py b/examples/weather_agent/run.py index 77733e85a44..28fe6a86e72 100644 --- a/examples/weather_agent/run.py +++ b/examples/weather_agent/run.py @@ -14,4 +14,6 @@ # run = client.get_pipeline_run(run.id) if run: result = run.steps["analyze_weather_with_llm"].output.load() + comparison = run.steps["compare_city_trends"].output.load() print(result) + print("\n" + comparison) diff --git a/examples/weather_agent/steps/__init__.py b/examples/weather_agent/steps/__init__.py index fe6cb024d25..a6288673154 100644 --- a/examples/weather_agent/steps/__init__.py +++ b/examples/weather_agent/steps/__init__.py @@ -4,9 +4,11 @@ pipeline. """ +from .comparison import compare_city_trends from .weather_agent import analyze_weather_with_llm, get_weather __all__ = [ "analyze_weather_with_llm", + "compare_city_trends", "get_weather", -] \ No newline at end of file +] diff --git a/examples/weather_agent/steps/comparison.py b/examples/weather_agent/steps/comparison.py new file mode 100644 index 00000000000..2d142ea1ce0 --- /dev/null +++ b/examples/weather_agent/steps/comparison.py @@ -0,0 +1,35 @@ +"""Comparison steps for the weather agent pipeline.""" + +from typing import Annotated + +from zenml import step +from zenml.steps import get_step_context + + +@step +def compare_city_trends(analysis: str) -> Annotated[str, "city_comparison"]: + """Return how the current city compares to the previous turn. + + Args: + analysis: The analysis of the current city. + + Returns: + A string comparing the current city to the previous city. + """ + session_state = get_step_context().session_state + history = session_state.get("history", []) + if len(history) < 2: + return "Not enough history to compare cities yet." + + current = history[-1] + previous = history[-2] + delta_temp = current["temperature"] - previous["temperature"] + delta_humidity = current["humidity"] - previous["humidity"] + delta_wind = current["wind_speed"] - previous["wind_speed"] + + return ( + f"Comparing {current['city']} to {previous['city']}:\n" + f"• Temperature change: {delta_temp:+.1f}°C\n" + f"• Humidity change: {delta_humidity:+.0f}%\n" + f"• Wind speed change: {delta_wind:+.1f} km/h" + ) diff --git a/examples/weather_agent/steps/weather_agent.py b/examples/weather_agent/steps/weather_agent.py index 347f63b471b..55e2eac52a1 100644 --- a/examples/weather_agent/steps/weather_agent.py +++ b/examples/weather_agent/steps/weather_agent.py @@ -40,6 +40,17 @@ def analyze_weather_with_llm( wind = weather_data["wind_speed"] step_context = get_step_context() + session_state = step_context.session_state + history = session_state.setdefault("history", []) + history.append( + { + "city": city, + "temperature": round(temp, 2), + "humidity": humidity, + "wind_speed": round(wind, 2), + } + ) + session_state["turn_count"] = len(history) pipeline_state = step_context.pipeline_state client = None @@ -83,7 +94,7 @@ def analyze_weather_with_llm( llm_analysis = response.choices[0].message.content - return f"""🤖 LLM Weather Analysis for {city}: + return f"""🤖 LLM Weather Analysis for {city} (turn {len(history)}): {llm_analysis} @@ -138,7 +149,7 @@ def analyze_weather_with_llm( if wind > 20: warning += " Strong winds - secure loose items." - return f"""🤖 Weather Analysis for {city}: + return f"""🤖 Weather Analysis for {city} (turn {len(history)}): Assessment: {temp_desc.title()} weather with {humidity}% humidity Comfort Level: {comfort}/10 diff --git a/examples/weather_agent/ui/index.html b/examples/weather_agent/ui/index.html index e9540900eda..96ccde7d51c 100644 --- a/examples/weather_agent/ui/index.html +++ b/examples/weather_agent/ui/index.html @@ -154,9 +154,14 @@ /* ===== Form ===== */ .form-row { display: flex; - gap: var(--spacing-sm); + gap: var(--spacing-md); align-items: center; flex-wrap: wrap; + margin-bottom: var(--spacing-md); + } + + .form-row:last-of-type { + margin-bottom: 0; } .input { @@ -230,13 +235,16 @@ /* ===== Status Pill ===== */ .status-pill { - display: inline-block; - padding: 6px 12px; + display: inline-flex; + align-items: center; + padding: 8px 16px; border-radius: var(--radius-full); - font-size: 12px; + font-size: 13px; font-weight: 600; line-height: 18px; - margin-top: var(--spacing-sm); + margin-top: var(--spacing-md); + font-family: var(--font-mono); + border: 1px solid transparent; } .status-pill.auth-set { @@ -249,6 +257,12 @@ color: var(--warning); } + .status-pill.info { + background: var(--zenml-purple-lighter); + color: var(--zenml-purple-dark); + border-color: var(--zenml-purple-light); + } + /* ===== Messages ===== */ .message { margin-top: var(--spacing-md); @@ -273,17 +287,40 @@ /* ===== Result Display ===== */ .result { margin-top: var(--spacing-lg); - padding: var(--spacing-lg); + padding: var(--spacing-xl); border: 1px solid var(--border-moderate); border-radius: var(--radius-md); - background: var(--surface-secondary); - white-space: pre-wrap; + background: var(--surface-primary); + box-shadow: var(--shadow-subtle); } .result:empty { display: none; } + .result-header { + display: flex; + align-items: center; + gap: var(--spacing-sm); + padding: var(--spacing-md) var(--spacing-lg); + background: var(--zenml-purple-lighter); + border-left: 4px solid var(--zenml-purple); + border-radius: var(--radius-md); + font-size: 14px; + line-height: 20px; + color: var(--zenml-purple-dark); + } + + .result-header::before { + content: '🤖'; + font-size: 20px; + } + + .result-header strong { + font-weight: 600; + color: var(--zenml-purple-dark); + } + /* Markdown content styling */ .result h1, .result h2, @@ -315,13 +352,31 @@ } .result ul { - padding-left: 20px; - margin: var(--spacing-sm) 0; + padding-left: 24px; + margin: var(--spacing-md) 0; + list-style-type: disc; } .result li { - margin: var(--spacing-xs) 0; + margin: var(--spacing-sm) 0; + color: var(--text-primary); + line-height: 1.6; + } + + .result ol { + padding-left: 24px; + margin: var(--spacing-md) 0; + } + + .result ol li { + margin: var(--spacing-md) 0; color: var(--text-primary); + line-height: 1.6; + } + + .result strong { + color: var(--zenml-purple-dark); + font-weight: 600; } .result code { @@ -546,6 +601,17 @@

Try it

Set API Key +
+ + +
Try it Auth required
+
-
+ + @@ -586,6 +664,14 @@

Try it

const authStatus = document.getElementById('auth-status'); const API_KEY_STORAGE = 'weather_agent_api_key'; const resultEl = document.getElementById('result'); + const resultContainer = document.getElementById('result-container'); + const cityNameEl = document.getElementById('city-name'); + const turnNumberEl = document.getElementById('turn-number'); + const comparisonEl = document.getElementById('comparison'); + const comparisonContainer = document.getElementById('comparison-container'); + const sessionInput = document.getElementById('session-input'); + const sessionInfo = document.getElementById('session-info'); + const resetSessionBtn = document.getElementById('reset-session-btn'); function setLoading(isLoading) { fetchBtn.disabled = isLoading; @@ -623,15 +709,24 @@

Try it

} setLoading(true); resultEl.innerHTML = ''; + resultContainer.style.display = 'none'; metricsEl.innerHTML = ''; + comparisonEl.innerHTML = ''; + comparisonContainer.style.display = 'none'; + sessionInfo.style.display = 'none'; try { const headers = { 'Content-Type': 'application/json' }; const key = getApiKey(); if (AUTH_ENABLED && key) headers['Authorization'] = `Bearer ${key}`; + const sessionId = sessionInput.value.trim(); + const body = { parameters: { city: city } }; + if (sessionId) { + body.session_id = sessionId; + } const res = await fetch(INVOKE_URL, { method: 'POST', headers, - body: JSON.stringify({ parameters: { city: city } }) + body: JSON.stringify(body) }); if (!res.ok) { const text = await res.text(); @@ -645,6 +740,8 @@

Try it

const outputs = data.outputs || {}; const weather = outputs.weather_data || null; const md = outputs.weather_analysis; + const comparison = outputs.city_comparison; + const metadata = data.metadata || {}; if (!md) { messageEl.textContent = 'No analysis returned.'; messageEl.className = 'message info'; @@ -655,6 +752,21 @@

Try it

} const html = marked.parse(md); resultEl.innerHTML = DOMPurify.sanitize(html); + cityNameEl.textContent = city; + const turnNum = metadata.turn_number || 1; + turnNumberEl.textContent = turnNum; + resultContainer.style.display = 'block'; + if (comparison) { + const comparisonHtml = marked.parse(comparison); + comparisonEl.innerHTML = DOMPurify.sanitize(comparisonHtml); + comparisonContainer.style.display = 'block'; + } + if (metadata.session_id) { + sessionInput.value = metadata.session_id; + sessionInfo.textContent = `Session ID: ${metadata.session_id}`; + sessionInfo.className = 'status-pill info'; + sessionInfo.style.display = 'inline-flex'; + } clearMessage(); } catch (e) { showError(e.message || String(e)); @@ -663,6 +775,17 @@

Try it

} } + resetSessionBtn.addEventListener('click', () => { + sessionInput.value = ''; + sessionInfo.style.display = 'none'; + resultContainer.style.display = 'none'; + resultEl.innerHTML = ''; + comparisonEl.innerHTML = ''; + comparisonContainer.style.display = 'none'; + metricsEl.innerHTML = ''; + clearMessage(); + }); + fetchBtn.addEventListener('click', fetchSuggestions); cityInput.addEventListener('keydown', function (e) { if (e.key === 'Enter') { fetchSuggestions(); } @@ -798,4 +921,4 @@

Try it

`; } - \ No newline at end of file + diff --git a/src/zenml/config/deployment_settings.py b/src/zenml/config/deployment_settings.py index ba40d74d609..9f6f1af1dfa 100644 --- a/src/zenml/config/deployment_settings.py +++ b/src/zenml/config/deployment_settings.py @@ -526,6 +526,40 @@ class DeploymentDefaultMiddleware(IntFlag): ALL = CORS | SECURE_HEADERS +class SessionBackendType(str, Enum): + """Session storage backend types.""" + + INMEMORY = "inmemory" + LOCAL = "local" + REDIS = "redis" + + +class SessionSettings(BaseModel): + """Configuration for deployment session management. + + Sessions enable stateful interactions across multiple deployment + invocations. Phase 1 supports only in-memory storage; the schema + is forward-compatible with persistent backends (local, redis). + + Attributes: + enabled: Whether session management is enabled for this deployment. + backend: Storage backend type (only 'inmemory' supported in Phase 1). + ttl_seconds: Default session TTL in seconds (None = no expiry). + max_state_bytes: Maximum size for session state in bytes. + max_sessions: Maximum number of sessions to store (LRU eviction). + backend_config: Backend-specific configuration (reserved for future use). + """ + + model_config = ConfigDict(extra="forbid") + + enabled: bool = True + backend: SessionBackendType = SessionBackendType.INMEMORY + ttl_seconds: Optional[int] = 24 * 60 * 60 # 24 hours default + max_state_bytes: Optional[int] = 64 * 1024 # 64 KB default + max_sessions: Optional[int] = 10_000 + backend_config: Dict[str, Any] = Field(default_factory=dict) + + class DeploymentSettings(BaseSettings): """Settings for the pipeline deployment. @@ -693,6 +727,11 @@ class DeploymentSettings(BaseSettings): # Pluggable app extensions for advanced features app_extensions: Optional[List[AppExtensionSpec]] = None + sessions: SessionSettings = Field( + default_factory=SessionSettings, + title="Session management configuration.", + ) + uvicorn_host: str = "0.0.0.0" # nosec uvicorn_port: int = 8000 uvicorn_workers: int = 1 diff --git a/src/zenml/deployers/server/models.py b/src/zenml/deployers/server/models.py index b2afd395530..088aed9dd38 100644 --- a/src/zenml/deployers/server/models.py +++ b/src/zenml/deployers/server/models.py @@ -44,6 +44,11 @@ class DeploymentInvocationResponseMetadata(BaseModel): title="The parameters used for the pipeline execution." ) + session_id: Optional[str] = Field( + default=None, + title="The session ID used for this invocation (if sessions enabled).", + ) + class BaseDeploymentInvocationRequest(BaseModel): """Base pipeline invoke request model.""" @@ -63,6 +68,14 @@ class BaseDeploymentInvocationRequest(BaseModel): "storing them as artifacts.", ) + session_id: Optional[str] = Field( + default=None, + title="Optional session ID to resume existing session state. " + "If provided and sessions are enabled, the deployment will attempt " + "to load the session. If not found or expired, a new session with " + "this ID will be created.", + ) + class BaseDeploymentInvocationResponse(BaseModel): """Base pipeline invoke response model.""" diff --git a/src/zenml/deployers/server/runtime.py b/src/zenml/deployers/server/runtime.py index 2660136298d..f14f33c6591 100644 --- a/src/zenml/deployers/server/runtime.py +++ b/src/zenml/deployers/server/runtime.py @@ -43,6 +43,10 @@ class _DeploymentState(BaseModel): # In-memory data storage for artifacts in_memory_data: Dict[str, Any] = Field(default_factory=dict) + # Session management + session_id: Optional[str] = None + session_state: Dict[str, Any] = Field(default_factory=dict) + def reset(self) -> None: """Reset the deployment state.""" self.active = False @@ -52,6 +56,8 @@ def reset(self) -> None: self.outputs = {} self.skip_artifact_materialization = False self.in_memory_data = {} + self.session_id = None + self.session_state = {} _deployment_context: contextvars.ContextVar[_DeploymentState] = ( @@ -73,6 +79,8 @@ def start( snapshot: PipelineSnapshotResponse, parameters: Dict[str, Any], skip_artifact_materialization: bool = False, + session_id: Optional[str] = None, + session_state: Optional[Dict[str, Any]] = None, ) -> None: """Initialize deployment state for the current request context. @@ -81,6 +89,8 @@ def start( snapshot: The snapshot to deploy. parameters: The parameters to deploy. skip_artifact_materialization: Whether to skip artifact materialization. + session_id: Optional session ID for stateful deployments. + session_state: Optional session state dictionary for stateful deployments. """ state = _DeploymentState() state.active = True @@ -89,6 +99,8 @@ def start( state.pipeline_parameters = parameters state.outputs = {} state.skip_artifact_materialization = skip_artifact_materialization + state.session_id = session_id + state.session_state = dict(session_state or {}) _deployment_context.set(state) @@ -168,3 +180,37 @@ def get_in_memory_data(uri: str) -> Any: state = _get_context() return state.in_memory_data[uri] return None + + +def get_session_id() -> Optional[str]: + """Get the current session ID. + + Returns: + Session ID if active, None otherwise. + """ + if is_active(): + return _get_context().session_id + return None + + +def get_session_state() -> Dict[str, Any]: + """Get the current session state. + + Returns: + Live session state dictionary. Mutations to this dict will be + reflected in the runtime context. Empty dict if no session is active. + """ + if is_active(): + return _get_context().session_state + return {} + + +def set_session_state(state: Dict[str, Any]) -> None: + """Replace the current session state. + + Args: + state: New session state dictionary to store (will be copied). + """ + if is_active(): + context = _get_context() + context.session_state = dict(state) diff --git a/src/zenml/deployers/server/service.py b/src/zenml/deployers/server/service.py index 97a1ef0f47d..d94c08227e9 100644 --- a/src/zenml/deployers/server/service.py +++ b/src/zenml/deployers/server/service.py @@ -32,6 +32,7 @@ import zenml.pipelines.run_utils as run_utils from zenml.client import Client +from zenml.config.deployment_settings import SessionBackendType from zenml.deployers.server import runtime from zenml.deployers.server.models import ( AppInfo, @@ -44,6 +45,11 @@ ServiceInfo, SnapshotInfo, ) +from zenml.deployers.server.sessions import ( + InMemorySessionBackend, + Session, + SessionManager, +) from zenml.deployers.utils import ( deployment_snapshot_request_from_source_snapshot, ) @@ -313,6 +319,9 @@ def initialize(self) -> None: ) self._client.zen_store.reinitialize_session() + self.session_manager: Optional[SessionManager] = None + self._configure_sessions() + # Execution tracking self.service_start_time = time.time() self.last_execution_time: Optional[datetime] = None @@ -335,6 +344,43 @@ def cleanup(self) -> None: """Execute cleanup hook if present.""" BaseOrchestrator.run_cleanup_hook(self.snapshot) + def _configure_sessions(self) -> None: + """Configure session management based on deployment settings. + + Raises: + ValueError: If an unsupported backend is configured. + """ + session_settings = self.app_runner.settings.sessions + + if not session_settings.enabled: + logger.debug("Session management is disabled") + return + + # Only in-memory backend is supported + if session_settings.backend != SessionBackendType.INMEMORY: + raise ValueError( + f"Unsupported session backend: {session_settings.backend}. " + f"Only '{SessionBackendType.INMEMORY}' backend is supported currently." + ) + + # Initialize in-memory backend with configured limits + backend = InMemorySessionBackend( + max_sessions=session_settings.max_sessions + ) + + self.session_manager = SessionManager( + backend=backend, + ttl_seconds=session_settings.ttl_seconds, + max_state_bytes=session_settings.max_state_bytes, + ) + + logger.info( + f"Session management enabled [backend={session_settings.backend}] " + f"[ttl_seconds={session_settings.ttl_seconds}] " + f"[max_state_bytes={session_settings.max_state_bytes}] " + f"[max_sessions={session_settings.max_sessions}]" + ) + def execute_pipeline( self, request: BaseDeploymentInvocationRequest, @@ -354,7 +400,18 @@ def execute_pipeline( logger.info("Starting pipeline execution") placeholder_run: Optional[PipelineRunResponse] = None + session: Optional[Session] = None + session_state_snapshot: Dict[str, Any] = {} + try: + session = self._resolve_session(request.session_id) + if session: + session_state_snapshot = dict(session.state) + logger.debug( + f"Using session [session_id={session.id}] " + f"[deployment_id={session.deployment_id}]" + ) + # Create a placeholder run separately from the actual execution, # so that we have a run ID to include in the response even if the # pipeline execution fails. @@ -364,11 +421,14 @@ def execute_pipeline( ) ) - captured_outputs = self._execute_with_orchestrator( - placeholder_run=placeholder_run, - deployment_snapshot=deployment_snapshot, - resolved_params=parameters, - skip_artifact_materialization=request.skip_artifact_materialization, + captured_outputs, session_state_snapshot = ( + self._execute_with_orchestrator( + placeholder_run=placeholder_run, + deployment_snapshot=deployment_snapshot, + resolved_params=parameters, + skip_artifact_materialization=request.skip_artifact_materialization, + session=session, + ) ) # Map outputs using fast (in-memory) or slow (artifact) path @@ -379,6 +439,7 @@ def execute_pipeline( mapped_outputs=mapped_outputs, start_time=start_time, resolved_params=parameters, + session_id=session.id if session else request.session_id, ) except Exception as e: @@ -389,7 +450,10 @@ def execute_pipeline( start_time=start_time, resolved_params=parameters, error=e, + session_id=session.id if session else request.session_id, ) + finally: + self._persist_session_state(session, session_state_snapshot) def get_service_info(self) -> ServiceInfo: """Get service information. @@ -534,8 +598,9 @@ def _execute_with_orchestrator( deployment_snapshot: PipelineSnapshotResponse, resolved_params: Dict[str, Any], skip_artifact_materialization: bool, - ) -> Optional[Dict[str, Dict[str, Any]]]: - """Run the snapshot via the orchestrator and return the concrete run. + session: Optional[Session] = None, + ) -> Tuple[Optional[Dict[str, Dict[str, Any]]], Dict[str, Any]]: + """Run the snapshot via the orchestrator and return outputs and session state. Args: placeholder_run: The placeholder run to execute the pipeline on. @@ -544,9 +609,10 @@ def _execute_with_orchestrator( resolved_params: Normalized pipeline parameters. skip_artifact_materialization: Whether runtime should skip artifact materialization. + session: Optional session for stateful execution. Returns: - The in-memory outputs of the execution. + Tuple of (in-memory outputs, final session state snapshot). Raises: RuntimeError: If the orchestrator has not been initialized. @@ -566,16 +632,18 @@ def _execute_with_orchestrator( updated=datetime.now(), ) - # Start deployment runtime context with parameters (still needed for - # in-memory materializer) runtime.start( request_id=str(uuid4()), snapshot=deployment_snapshot, parameters=resolved_params, skip_artifact_materialization=skip_artifact_materialization, + session_id=session.id if session else None, + session_state=session.state if session else None, ) captured_outputs: Optional[Dict[str, Dict[str, Any]]] = None + session_state_snapshot: Dict[str, Any] = {} + try: # Use the new deployment snapshot with pre-configured settings orchestrator.run( @@ -584,9 +652,9 @@ def _execute_with_orchestrator( placeholder_run=placeholder_run, ) - # Capture in-memory outputs before stopping the runtime context if runtime.is_active(): captured_outputs = runtime.get_outputs() + session_state_snapshot = dict(runtime.get_session_state()) except Exception as e: logger.exception(f"Failed to execute pipeline: {e}") raise RuntimeError(f"Failed to execute pipeline: {e}") @@ -594,7 +662,7 @@ def _execute_with_orchestrator( # Always stop deployment runtime context runtime.stop() - return captured_outputs + return captured_outputs, session_state_snapshot def _execute_init_hook(self) -> None: """Execute init hook if present. @@ -645,6 +713,7 @@ def _build_response( mapped_outputs: Optional[Dict[str, Any]] = None, placeholder_run: Optional[PipelineRunResponse] = None, error: Optional[Exception] = None, + session_id: Optional[str] = None, ) -> BaseDeploymentInvocationResponse: """Build success response with execution tracking. @@ -654,6 +723,7 @@ def _build_response( mapped_outputs: The mapped outputs. placeholder_run: The placeholder run that was executed. error: The error that occurred. + session_id: The session ID used for this invocation. Returns: A BaseDeploymentInvocationResponse describing the execution. @@ -691,5 +761,72 @@ def _build_response( parameters_used=resolved_params, snapshot_id=self.snapshot.id, snapshot_name=self.snapshot.name, + session_id=session_id, ), ) + + def _resolve_session( + self, requested_session_id: Optional[str] + ) -> Optional[Session]: + """Resolve or create a session for the current invocation. + + Args: + requested_session_id: Optional session ID from the request. + + Returns: + Resolved session if sessions are enabled, None otherwise. + """ + if not self.session_manager: + return None + + try: + session = self.session_manager.resolve( + requested_id=requested_session_id, + deployment_id=str(self.deployment.id), + pipeline_id=str(self.snapshot.pipeline.id), + ) + return session + except Exception as e: + logger.warning( + f"Failed to resolve session [session_id={requested_session_id}] " + f"[deployment_id={self.deployment.id}]: {e}" + ) + return None + + def _persist_session_state( + self, session: Optional[Session], state_snapshot: Dict[str, Any] + ) -> None: + """Persist updated session state to the backend. + + Args: + session: The session to update. + state_snapshot: The state snapshot to persist. + """ + if not session or not self.session_manager: + return + + # Skip persistence if state hasn't changed + if state_snapshot == session.state: + logger.debug( + f"Session state unchanged, skipping persistence " + f"[session_id={session.id}] [deployment_id={session.deployment_id}]" + ) + return + + try: + self.session_manager.persist_state(session, state_snapshot) + logger.debug( + f"Persisted session state [session_id={session.id}] " + f"[deployment_id={session.deployment_id}]" + ) + except ValueError as e: + # Size limit exceeded or serialization error + logger.warning( + f"Failed to persist session state [session_id={session.id}] " + f"[deployment_id={session.deployment_id}]: {e}" + ) + except Exception as e: + logger.error( + f"Unexpected error persisting session state " + f"[session_id={session.id}] [deployment_id={session.deployment_id}]: {e}" + ) diff --git a/src/zenml/deployers/server/sessions.py b/src/zenml/deployers/server/sessions.py new file mode 100644 index 00000000000..0cfc20b624a --- /dev/null +++ b/src/zenml/deployers/server/sessions.py @@ -0,0 +1,548 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Deployment-scoped session infrastructure. + +This module provides server-managed session storage to enable stateful +interactions across multiple deployment invocations. Sessions are scoped +to deployments and support TTL-based expiration, thread-safe concurrent +access, and optional size limits. + +Key components: +- Session: Pydantic model representing session data and metadata +- SessionBackend: Abstract interface for session storage +- InMemorySessionBackend: Thread-safe in-memory implementation with LRU eviction +- SessionManager: High-level orchestrator for session lifecycle management + +Assumptions: +- Last-write-wins semantics for concurrent updates to the same session +- Sessions are deployment-scoped; different deployments have isolated namespaces +- Expiration is lazy (checked on access) plus periodic cleanup via backend.cleanup() +""" + +import json +import threading +from abc import ABC, abstractmethod +from collections import OrderedDict +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional, Tuple +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field + +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +class Session(BaseModel): + """Represents a deployment session with state and metadata. + + Attributes: + id: Unique session identifier (hex string). + deployment_id: ID of the deployment this session belongs to. + pipeline_id: Optional ID of the pipeline associated with this session. + state: Arbitrary JSON-serializable state dictionary. + created_at: Timestamp when the session was created (UTC). + updated_at: Timestamp when the session was last accessed/modified (UTC). + expires_at: Optional expiration timestamp (UTC); None means no expiry. + """ + + model_config = ConfigDict(extra="forbid") + + id: str + deployment_id: str + pipeline_id: Optional[str] = None + state: Dict[str, Any] = Field(default_factory=dict) + created_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + expires_at: Optional[datetime] = None + + def touch(self, ttl_seconds: Optional[int] = None) -> None: + """Update access timestamp and optionally extend expiration. + + Args: + ttl_seconds: If provided, set expires_at to now + ttl_seconds. + If None and expires_at is already set, leave it unchanged. + """ + now = datetime.now(timezone.utc) + self.updated_at = now + + if ttl_seconds is not None: + self.expires_at = now + timedelta(seconds=ttl_seconds) + + def is_expired(self) -> bool: + """Check if the session has expired. + + Returns: + True if expires_at is set and in the past, False otherwise. + """ + if self.expires_at is None: + return False + return datetime.now(timezone.utc) > self.expires_at + + +class SessionBackend(ABC): + """Abstract interface for session storage backends.""" + + @abstractmethod + def load(self, session_id: str, deployment_id: str) -> Optional[Session]: + """Load a session by ID within a deployment scope. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + + Returns: + The session if found and not expired, None otherwise. + """ + + @abstractmethod + def create( + self, + session_id: str, + deployment_id: str, + pipeline_id: Optional[str] = None, + initial_state: Optional[Dict[str, Any]] = None, + ttl_seconds: Optional[int] = None, + ) -> Session: + """Create a new session. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + initial_state: Optional initial state dictionary. + ttl_seconds: Optional TTL in seconds; if provided, sets expires_at. + + Returns: + The created session. + + Raises: + ValueError: If a session with the same ID already exists. + """ + + @abstractmethod + def update( + self, + session_id: str, + deployment_id: str, + state: Dict[str, Any], + ttl_seconds: Optional[int] = None, + ) -> Session: + """Update an existing session's state. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + state: New state dictionary (replaces existing state). + ttl_seconds: Optional TTL to refresh expiration. + + Returns: + The updated session. + + Raises: + KeyError: If the session does not exist. + """ + + @abstractmethod + def delete(self, session_id: str, deployment_id: str) -> None: + """Delete a session. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + """ + + @abstractmethod + def cleanup(self) -> int: + """Remove all expired sessions across all deployments. + + Returns: + Number of sessions removed. + """ + + +class InMemorySessionBackend(SessionBackend): + """Thread-safe in-memory session storage with LRU eviction. + + Uses an OrderedDict to track access order for LRU eviction when + max_sessions is exceeded. All operations are guarded by a reentrant lock. + + Attributes: + max_sessions: Optional maximum number of sessions to store. When + exceeded, least-recently-used sessions are evicted. + """ + + def __init__(self, max_sessions: Optional[int] = None) -> None: + """Initialize the in-memory backend. + + Args: + max_sessions: Optional capacity limit; None means unlimited. + """ + self._sessions: OrderedDict[Tuple[str, str], Session] = OrderedDict() + self._lock = threading.RLock() + self.max_sessions = max_sessions + + def load(self, session_id: str, deployment_id: str) -> Optional[Session]: + """Load a session, performing lazy expiration removal. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + + Returns: + Deep copy of the session if found and valid, None otherwise. + """ + with self._lock: + key = (deployment_id, session_id) + session = self._sessions.get(key) + + if session is None: + return None + + # Lazy expiration check + if session.is_expired(): + logger.debug( + f"Session expired on load [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + del self._sessions[key] + return None + + # Move to end (mark as recently used) + self._sessions.move_to_end(key) + + return session.model_copy(deep=True) + + def create( + self, + session_id: str, + deployment_id: str, + pipeline_id: Optional[str] = None, + initial_state: Optional[Dict[str, Any]] = None, + ttl_seconds: Optional[int] = None, + ) -> Session: + """Create a new session with optional TTL. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + initial_state: Optional initial state dictionary. + ttl_seconds: Optional TTL in seconds. + + Returns: + Deep copy of the created session. + + Raises: + ValueError: If a session with the same ID already exists. + """ + with self._lock: + key = (deployment_id, session_id) + + if key in self._sessions: + raise ValueError( + f"Session already exists [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + # Create new session + session = Session( + id=session_id, + deployment_id=deployment_id, + pipeline_id=pipeline_id, + state=initial_state or {}, + ) + + # Set expiration if TTL provided + if ttl_seconds is not None: + session.touch(ttl_seconds) + + # Store and mark as recently used + self._sessions[key] = session + self._sessions.move_to_end(key) + + # Enforce capacity limit via LRU eviction + self._evict_if_needed() + + logger.info( + f"Created session [session_id={session_id}] " + f"[deployment_id={deployment_id}] " + f"[ttl_seconds={ttl_seconds}]" + ) + + return session.model_copy(deep=True) + + def update( + self, + session_id: str, + deployment_id: str, + state: Dict[str, Any], + ttl_seconds: Optional[int] = None, + ) -> Session: + """Update session state and refresh timestamps. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + state: New state dictionary (replaces existing). + ttl_seconds: Optional TTL to refresh expiration. + + Returns: + Deep copy of the updated session. + + Raises: + KeyError: If the session does not exist. + """ + with self._lock: + key = (deployment_id, session_id) + + if key not in self._sessions: + raise KeyError( + f"Session not found [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + session = self._sessions[key] + + # Replace state with deep copy + session.state = dict(state) + + # Refresh timestamps and optionally extend expiration + session.touch(ttl_seconds) + + # Mark as recently used + self._sessions.move_to_end(key) + + logger.debug( + f"Updated session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + return session.model_copy(deep=True) + + def delete(self, session_id: str, deployment_id: str) -> None: + """Delete a session (silent if not found). + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + """ + with self._lock: + key = (deployment_id, session_id) + if key in self._sessions: + del self._sessions[key] + logger.info( + f"Deleted session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + def cleanup(self) -> int: + """Remove all expired sessions. + + Returns: + Number of sessions removed. + """ + with self._lock: + expired_keys = [ + key + for key, session in self._sessions.items() + if session.is_expired() + ] + + for key in expired_keys: + del self._sessions[key] + + if expired_keys: + logger.info(f"Cleaned up {len(expired_keys)} expired sessions") + + return len(expired_keys) + + def _evict_if_needed(self) -> None: + """Evict least-recently-used sessions if capacity exceeded. + + Must be called while holding self._lock. + """ + if self.max_sessions is None: + return + + while len(self._sessions) > self.max_sessions: + # Remove oldest (first) entry + evicted_key, evicted_session = self._sessions.popitem(last=False) + logger.warning( + f"Evicted LRU session [session_id={evicted_session.id}] " + f"[deployment_id={evicted_session.deployment_id}] " + f"due to capacity limit ({self.max_sessions})" + ) + + +class SessionManager: + """High-level orchestrator for session lifecycle management. + + Handles session resolution (get-or-create), state persistence with + size limits, and cleanup coordination. + + Attributes: + backend: The storage backend for sessions. + ttl_seconds: Default TTL for new sessions (None = no expiry). + max_state_bytes: Optional maximum size for session state in bytes. + """ + + def __init__( + self, + backend: SessionBackend, + ttl_seconds: Optional[int] = None, + max_state_bytes: Optional[int] = None, + ) -> None: + """Initialize the session manager. + + Args: + backend: Storage backend for sessions. + ttl_seconds: Default session TTL in seconds (None = no expiry). + max_state_bytes: Optional maximum state size in bytes. + """ + self.backend = backend + self.ttl_seconds = ttl_seconds + self.max_state_bytes = max_state_bytes + self._logger = logger + + def resolve( + self, + requested_id: Optional[str], + deployment_id: str, + pipeline_id: Optional[str] = None, + ) -> Session: + """Resolve a session by ID or create a new one. + + If requested_id is provided, attempts to load the existing session. + If not found or expired, creates a new session with that ID. + If requested_id is None, generates a new ID and creates a session. + + Args: + requested_id: Optional session ID to resume. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + + Returns: + The resolved or newly created session. + """ + session_id = requested_id or uuid4().hex + + # Attempt to load existing session + if requested_id: + session = self.backend.load(session_id, deployment_id) + if session: + # Refresh TTL on access + session.touch(self.ttl_seconds) + self.backend.update( + session_id, + deployment_id, + session.state, + ttl_seconds=self.ttl_seconds, + ) + self._logger.debug( + f"Resolved existing session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + return session + + # Create new session + session = self.backend.create( + session_id=session_id, + deployment_id=deployment_id, + pipeline_id=pipeline_id, + initial_state={}, + ttl_seconds=self.ttl_seconds, + ) + + self._logger.info( + f"Created new session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + return session + + def persist_state( + self, + session: Session, + new_state: Dict[str, Any], + ) -> Session: + """Persist updated state to the backend with size validation. + + Args: + session: The session to update. + new_state: New state dictionary to persist. + + Returns: + The updated session from the backend. + + Raises: + ValueError: If new_state exceeds max_state_bytes. + """ + # Enforce size limit if configured + if self.max_state_bytes is not None: + self._ensure_size_within_limit(new_state) + + updated_session = self.backend.update( + session_id=session.id, + deployment_id=session.deployment_id, + state=new_state, + ttl_seconds=self.ttl_seconds, + ) + + self._logger.debug( + f"Persisted state [session_id={session.id}] " + f"[deployment_id={session.deployment_id}]" + ) + + return updated_session + + def delete_session(self, session: Session) -> None: + """Delete a session from the backend. + + Args: + session: The session to delete. + """ + self.backend.delete(session.id, session.deployment_id) + + def _ensure_size_within_limit(self, state: Dict[str, Any]) -> None: + """Validate that state size is within configured limit. + + Args: + state: The state dictionary to validate. + + Raises: + ValueError: If state exceeds max_state_bytes. + """ + if self.max_state_bytes is None: + return + + # Serialize to JSON and measure byte size + try: + state_json = json.dumps(state, ensure_ascii=False) + state_bytes = len(state_json.encode("utf-8")) + except (TypeError, ValueError) as e: + raise ValueError( + f"Session state must be JSON-serializable: {e}" + ) from e + + if state_bytes > self.max_state_bytes: + raise ValueError( + f"Session state size ({state_bytes} bytes) exceeds " + f"maximum allowed ({self.max_state_bytes} bytes)" + ) diff --git a/src/zenml/steps/step_context.py b/src/zenml/steps/step_context.py index a60eeef5b76..f4ec35a2cc7 100644 --- a/src/zenml/steps/step_context.py +++ b/src/zenml/steps/step_context.py @@ -231,6 +231,39 @@ def pipeline_state(self) -> Optional[Any]: """ return get_or_create_run_context().state + @property + def session_state(self) -> Dict[str, Any]: + """Returns the current session state. + + This property provides access to deployment-scoped session data + that persists across multiple invocations of the same deployment. + The returned dictionary is live - modifications to it will update + the session state for the current invocation. + + Returns: + Dictionary containing session state data. Empty dict if sessions + are not enabled or no session is active. + + Example: + ```python + from zenml import step, get_step_context + + @step + def my_step() -> None: + context = get_step_context() + + # Read session state + counter = context.session_state.get("counter", 0) + + # Modify session state (changes persist to session) + context.session_state["counter"] = counter + 1 + context.session_state["last_run"] = str(datetime.now()) + ``` + """ + from zenml.deployers.server import runtime + + return runtime.get_session_state() + @property def model(self) -> "Model": """Returns configured Model. diff --git a/tests/unit/deployers/server/test_runtime.py b/tests/unit/deployers/server/test_runtime.py index 9707dd729b9..3bbd7982600 100644 --- a/tests/unit/deployers/server/test_runtime.py +++ b/tests/unit/deployers/server/test_runtime.py @@ -291,3 +291,55 @@ def test_context_reset_clears_all_data(self): with pytest.raises(KeyError): runtime.get_in_memory_data("memory://artifact/1") assert runtime.should_skip_artifact_materialization() is False + + def test_session_context_lifecycle(self): + """Test session state management during active runtime.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + # Start runtime with session context + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + session_id="session-123", + session_state={"history": []}, + ) + + # Verify session ID and state are accessible + assert runtime.get_session_id() == "session-123" + session_state = runtime.get_session_state() + assert session_state == {"history": []} + + # Mutate the state via the returned dict (should be live) + session_state["history"].append("event1") + session_state["counter"] = 1 + + # Verify mutations are reflected + updated_state = runtime.get_session_state() + assert updated_state["history"] == ["event1"] + assert updated_state["counter"] == 1 + + # Replace state entirely + runtime.set_session_state({"foo": "bar"}) + replaced_state = runtime.get_session_state() + assert replaced_state == {"foo": "bar"} + assert "history" not in replaced_state + + # Stop runtime and verify reset + runtime.stop() + assert runtime.get_session_id() is None + assert runtime.get_session_state() == {} + + def test_session_helpers_when_inactive(self): + """Test session helpers behave safely when runtime is inactive.""" + # Without starting runtime, verify safe defaults + assert runtime.get_session_id() is None + assert runtime.get_session_state() == {} + + # Attempt to set state while inactive (should be no-op) + runtime.set_session_state({"ignored": True}) + + # Verify state remains empty + assert runtime.get_session_state() == {} + assert runtime.get_session_id() is None diff --git a/tests/unit/deployers/server/test_service.py b/tests/unit/deployers/server/test_service.py index 84f04ad5bb2..5792401cd08 100644 --- a/tests/unit/deployers/server/test_service.py +++ b/tests/unit/deployers/server/test_service.py @@ -16,7 +16,7 @@ from __future__ import annotations from types import SimpleNamespace -from typing import Dict, List, Type +from typing import Any, Dict, List, Type from uuid import uuid4 import pytest @@ -64,6 +64,7 @@ def _make_snapshot() -> SimpleNamespace: source="test.module.pipeline", ) stack = SimpleNamespace(name="test_stack") + pipeline = SimpleNamespace(id=uuid4(), name="test_pipeline") return SimpleNamespace( id=uuid4(), @@ -72,6 +73,7 @@ def _make_snapshot() -> SimpleNamespace: pipeline_spec=pipeline_spec, step_configurations={}, stack=stack, + pipeline=pipeline, ) @@ -129,6 +131,7 @@ def _make_service_stub(mocker: MockerFixture) -> PipelineDeploymentService: service = PipelineDeploymentService(app_runner) service._client = mocker.MagicMock() service._orchestrator = mocker.MagicMock() + service.session_manager = None mocker.patch.object( type(service), "input_model", @@ -150,13 +153,14 @@ def test_execute_pipeline_calls_subroutines(mocker: MockerFixture) -> None: captured_outputs: Dict[str, Dict[str, object]] = { "step1": {"result": "value"} } + session_state_snapshot: Dict[str, Any] = {} mapped_outputs = {"result": "value"} service._prepare_execute_with_orchestrator = mocker.MagicMock( return_value=(placeholder_run, deployment_snapshot) ) service._execute_with_orchestrator = mocker.MagicMock( - return_value=captured_outputs + return_value=(captured_outputs, session_state_snapshot) ) service._map_outputs = mocker.MagicMock(return_value=mapped_outputs) service._build_response = mocker.MagicMock(return_value="response") @@ -175,6 +179,7 @@ def test_execute_pipeline_calls_subroutines(mocker: MockerFixture) -> None: deployment_snapshot=deployment_snapshot, resolved_params={"city": "Berlin", "temperature": 20}, skip_artifact_materialization=False, + session=None, ) service._map_outputs.assert_called_once_with(captured_outputs) service._build_response.assert_called_once() @@ -299,3 +304,97 @@ def test_input_output_schema_properties(mocker: MockerFixture) -> None: assert service.input_schema == {"type": "object"} assert service.output_schema == {"type": "object"} + + +def test_resolve_session_returns_none_without_manager( + mocker: MockerFixture, +) -> None: + """_resolve_session should return None when no session manager is configured.""" + service = _make_service_stub(mocker) + + # Ensure session_manager is None + assert service.session_manager is None + + # Should return None without error + session = service._resolve_session("some-session-id") + assert session is None + + +def test_resolve_session_invokes_manager_with_ids( + mocker: MockerFixture, +) -> None: + """_resolve_session should forward request to session manager with correct IDs.""" + service = _make_service_stub(mocker) + + # Configure mock session manager + mock_manager = mocker.MagicMock() + mock_session = SimpleNamespace( + id="session-123", + deployment_id=str(service.deployment.id), + state={"counter": 1}, + ) + mock_manager.resolve.return_value = mock_session + service.session_manager = mock_manager + + # Call _resolve_session + result = service._resolve_session("session-123") + + # Verify manager was called with correct arguments + mock_manager.resolve.assert_called_once_with( + requested_id="session-123", + deployment_id=str(service.deployment.id), + pipeline_id=str(service.snapshot.pipeline.id), + ) + + # Verify result matches mock return + assert result == mock_session + + +def test_persist_session_state_noop_when_state_unchanged( + mocker: MockerFixture, +) -> None: + """_persist_session_state should skip persistence when state hasn't changed.""" + service = _make_service_stub(mocker) + + # Configure mock session manager + mock_manager = mocker.MagicMock() + service.session_manager = mock_manager + + # Create session with initial state + session = SimpleNamespace( + id="session-123", + deployment_id=str(service.deployment.id), + state={"counter": 1, "data": "value"}, + ) + + # Call with identical state snapshot + state_snapshot = {"counter": 1, "data": "value"} + service._persist_session_state(session, state_snapshot) + + # Verify manager.persist_state was NOT called + mock_manager.persist_state.assert_not_called() + + +def test_persist_session_state_persists_changes( + mocker: MockerFixture, +) -> None: + """_persist_session_state should delegate to manager when state differs.""" + service = _make_service_stub(mocker) + + # Configure mock session manager + mock_manager = mocker.MagicMock() + service.session_manager = mock_manager + + # Create session with initial state + session = SimpleNamespace( + id="session-123", + deployment_id=str(service.deployment.id), + state={"counter": 1}, + ) + + # Call with modified state snapshot + state_snapshot = {"counter": 2, "new_field": "added"} + service._persist_session_state(session, state_snapshot) + + # Verify manager.persist_state was called with correct arguments + mock_manager.persist_state.assert_called_once_with(session, state_snapshot) diff --git a/tests/unit/deployers/server/test_sessions.py b/tests/unit/deployers/server/test_sessions.py new file mode 100644 index 00000000000..660283643df --- /dev/null +++ b/tests/unit/deployers/server/test_sessions.py @@ -0,0 +1,157 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Unit tests for session infrastructure.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from zenml.deployers.server.sessions import ( + InMemorySessionBackend, + SessionManager, +) + + +class TestInMemorySessionBackend: + """Test InMemorySessionBackend storage and eviction.""" + + def test_inmemory_backend_performs_lru_eviction(self): + """Test that LRU eviction removes least-recently-used sessions.""" + backend = InMemorySessionBackend(max_sessions=1) + deployment_id = "deployment-1" + + # Create first session + s1 = backend.create( + session_id="session-1", + deployment_id=deployment_id, + initial_state={"data": "first"}, + ) + assert s1.id == "session-1" + + # Verify first session is loadable + loaded_s1 = backend.load("session-1", deployment_id) + assert loaded_s1 is not None + assert loaded_s1.id == "session-1" + + # Create second session (should evict first due to capacity=1) + s2 = backend.create( + session_id="session-2", + deployment_id=deployment_id, + initial_state={"data": "second"}, + ) + assert s2.id == "session-2" + + # First session should be evicted + evicted_s1 = backend.load("session-1", deployment_id) + assert evicted_s1 is None + + # Second session should still be present + loaded_s2 = backend.load("session-2", deployment_id) + assert loaded_s2 is not None + assert loaded_s2.id == "session-2" + + def test_inmemory_backend_lazy_expiration_removes_stale_sessions(self): + """Test that expired sessions are removed on load.""" + backend = InMemorySessionBackend() + deployment_id = "deployment-1" + session_id = "session-expired" + + # Create session with TTL + session = backend.create( + session_id=session_id, + deployment_id=deployment_id, + initial_state={"data": "value"}, + ttl_seconds=1, + ) + assert session.expires_at is not None + + # Manually expire the session by setting expires_at to the past + key = (deployment_id, session_id) + stored_session = backend._sessions[key] + stored_session.expires_at = datetime.now(timezone.utc) - timedelta( + seconds=1 + ) + + # Attempt to load - should return None and remove the session + loaded = backend.load(session_id, deployment_id) + assert loaded is None + + # Verify session was removed from storage + assert key not in backend._sessions + + +class TestSessionManager: + """Test SessionManager orchestration logic.""" + + def test_session_manager_resolves_existing_session(self): + """Test that SessionManager can resolve existing sessions by ID.""" + backend = InMemorySessionBackend() + manager = SessionManager(backend=backend, ttl_seconds=3600) + deployment_id = "deployment-1" + + # Resolve without providing ID (should generate new session) + session1 = manager.resolve( + requested_id=None, + deployment_id=deployment_id, + pipeline_id="pipeline-1", + ) + assert session1.id is not None + assert session1.deployment_id == deployment_id + + # Resolve again with the same ID (should return existing session) + session2 = manager.resolve( + requested_id=session1.id, + deployment_id=deployment_id, + pipeline_id="pipeline-1", + ) + assert session2.id == session1.id + assert session2.deployment_id == session1.deployment_id + + # Verify it's the same session (state should be preserved) + session1_updated = manager.persist_state(session1, {"counter": 42}) + session2_reloaded = manager.resolve( + requested_id=session1.id, + deployment_id=deployment_id, + ) + assert ( + session2_reloaded.state["counter"] + == session1_updated.state["counter"] + ) + + def test_session_manager_enforces_state_size_limit(self): + """Test that SessionManager rejects oversized state payloads.""" + backend = InMemorySessionBackend() + manager = SessionManager( + backend=backend, + ttl_seconds=3600, + max_state_bytes=100, # Very small limit for testing + ) + deployment_id = "deployment-1" + + # Resolve a session + session = manager.resolve( + requested_id=None, + deployment_id=deployment_id, + ) + + # Attempt to persist state that exceeds the limit + large_state = {"data": "x" * 100000} # Much larger than 10 bytes + + with pytest.raises(ValueError, match="exceeds maximum allowed"): + manager.persist_state(session, large_state) + + # Verify small state works fine + small_state = {"ok": "y"} # Should be under 10 bytes + updated = manager.persist_state(session, small_state) + assert updated.state == small_state From 8ad81b7f27cd52199f0b2ca5b6307899a64a49c2 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 9 Nov 2025 15:50:46 +0100 Subject: [PATCH 2/3] updates and fixes --- docs/book/how-to/deployment/deployment.md | 92 ++- .../how-to/deployment/deployment_settings.md | 43 ++ .../steps-pipelines/advanced_features.md | 30 +- src/zenml/config/deployment_settings.py | 24 +- src/zenml/deployers/server/service.py | 64 +- src/zenml/deployers/server/session.py | 162 ++++++ src/zenml/deployers/server/sessions.py | 548 ------------------ .../deployers/server/sessions/__init__.py | 41 ++ .../sessions/inmemory_session_backend.py | 245 ++++++++ .../server/sessions/local_session_backend.py | 459 +++++++++++++++ .../server/sessions/session_manager.py | 172 ++++++ tests/unit/deployers/server/test_sessions.py | 338 +++++++++++ 12 files changed, 1627 insertions(+), 591 deletions(-) create mode 100644 src/zenml/deployers/server/session.py delete mode 100644 src/zenml/deployers/server/sessions.py create mode 100644 src/zenml/deployers/server/sessions/__init__.py create mode 100644 src/zenml/deployers/server/sessions/inmemory_session_backend.py create mode 100644 src/zenml/deployers/server/sessions/local_session_backend.py create mode 100644 src/zenml/deployers/server/sessions/session_manager.py diff --git a/docs/book/how-to/deployment/deployment.md b/docs/book/how-to/deployment/deployment.md index 33f8f4ec59d..c6e3d568cb6 100644 --- a/docs/book/how-to/deployment/deployment.md +++ b/docs/book/how-to/deployment/deployment.md @@ -148,33 +148,79 @@ curl -X POST http://localhost:8000/invoke \ -d '{"parameters": {"city": "London", "temperature": 20}}' ``` -### Session-aware invocations +### Session-aware Invocations -Deployments remember information across `/invoke` calls because **sessions are enabled by default** (using the in-memory backend). Define the block below only if you want to opt-out or customize TTL/size limits: +Deployments support session-aware execution so that multiple `/invoke` calls can share context, which is critical for: +- LLM chat and tool-using agents +- Multi-step decision flows +- User-specific short-term state + +#### Session Basics + +- Each request may include a `session_id`. +- If no `session_id` is provided, the deployment can generate one and echo it back in `metadata.session_id`. +- Reusing the same `session_id` lets your steps read/write session-scoped state across invocations. +- Session state is stored separately from the LLM context window: it's your authoritative server-side memory. + +**Key separation**: LLMs have a model context window (tokens in a single prompt). Session state is your durable, structured memory you choose to feed into prompts. Keeping this state compact and curated is standard best practice across frameworks like LangChain/LangGraph. + +#### Configuring Sessions + +Sessions are controlled via the `deployment_settings.sessions` block: ```yaml deployment_settings: sessions: - enabled: false # disable sessions for stateless deployments - ttl_seconds: 3600 # override session expiration - max_state_bytes: 32768 # tighten state payload guardrail + enabled: true # enable/disable session support + ttl_seconds: 3600 # inactivity timeout before a session expires + max_state_bytes: 65536 # soft limit for serialized session state ``` -Key things to know when sessions are enabled: +**Fields:** + +- `enabled` + - `true` (default recommended for LLM/agent-style deployments) + - `false` for fully stateless APIs +- `ttl_seconds` + - Per-session inactivity timeout. After this, the session is expired and its state can be garbage-collected. + - Choose based on your use case (e.g. 15–30 min for chats, hours for workflows) +- `max_state_bytes` + - A guardrail on how large the serialized `session_state` can get + - Prevents unbounded growth, DoS-style misuse, and huge payloads when reloading/saving state + +#### Using Session State in Steps + +Within a deployed pipeline, steps access session state via `get_step_context()`: + +```python +from zenml import step, get_step_context + +@step +def chat_step(message: str) -> str: + ctx = get_step_context() + session_state = ctx.session_state # session dict; persisted if sessions enabled -* Clients may supply an optional `session_id` flag (CLI) or JSON field. If omitted, the deployment generates a new ID and echoes it back in `metadata.session_id`. Reuse that ID to resume the same conversation. -* The session identifier is included in deployment logs/metrics so multi-turn traces remain easy to follow. -* Steps can read/write a mutable `session_state` dict via `get_step_context().session_state`. Mutations are persisted automatically after the run. See [Managing conversational session state](../steps-pipelines/advanced_features.md#managing-conversational-session-state) for concrete step examples. + history = session_state.setdefault("history", []) + history.append({"role": "user", "content": message}) -Example invocation flow: + reply = f"Echo {len(history)}: {message}" + history.append({"role": "assistant", "content": reply}) + + session_state["last_reply"] = reply + return reply +``` + +If sessions are disabled, `session_state` behaves as an empty, non-persisted dict so the same code is safe to run. + +**Example flow:** ```bash -# First turn – server generates the session_id -zenml deployment invoke my_deployment --city="London" --temperature=20 +# 1st call – server may generate a session_id +zenml deployment invoke my_deployment --city="London" -# Follow-up turn reusing the echoed session id +# 2nd call – reusing the session zenml deployment invoke my_deployment \ - --city="Berlin" --temperature=18 \ + --city="Berlin" \ --session-id="session-12345" ``` @@ -447,7 +493,7 @@ curl -X POST http://localhost:8000/invoke \ ## Deployment Initialization, Cleanup and State -It often happens that the HTTP requests made to the same deployment share some type of initialization or cleanup or need to share the same global state or. For example: +It often happens that the HTTP requests made to the same deployment share some type of initialization or cleanup or need to share the same global state. For example: * a machine learning model needs to be loaded in memory, initialized and then shared between all the HTTP requests made to the deployment in order to be used by the deployed pipeline to make predictions @@ -493,6 +539,17 @@ The following happens when the pipeline is deployed and then later invoked: This mechanism can be used to initialize and share global state between all the HTTP requests made to the deployment or to execute long-running initialization or cleanup operations when the deployment is started or stopped rather than on each HTTP request. +{% hint style="info" %} +**Deployment State vs Session State** + +ZenML deployments support two types of state: + +- **`pipeline_state`**: Deployment-global state shared across all invocations (e.g., loaded models, DB clients, caches). Set via `on_init` hook, accessed via `get_step_context().pipeline_state`. +- **`session_state`**: Per-session state that persists across multiple invocations with the same `session_id` (e.g., conversation history, user context). Accessed via `get_step_context().session_state`. + +Use `pipeline_state` for expensive resources you want to load once and reuse, and `session_state` for conversational or multi-turn workflows where each session needs its own memory. +{% endhint %} + ## Deployment Configuration The deployer settings cover aspects of the pipeline deployment process and specific back-end infrastructure used to provision and manage the resources required to run the deployment servers. Independently of that, `DeploymentSettings` can be used to fully customize all aspects pertaining to the deployment ASGI application itself, including: @@ -542,8 +599,9 @@ For more detailed information on deployment options, see the [deployment setting 3. **Return Useful Data**: Design pipeline outputs to provide meaningful responses 4. **Use Type Annotations**: Leverage Pydantic models for complex parameter types 5. **Use Global Initialization and State**: Use the `on_init` and `on_cleanup` hooks along with the `pipeline_state` step context property to initialize and share global state between all the HTTP requests made to the deployment. Also use these hooks to execute long-running initialization or cleanup operations when the deployment is started or stopped rather than on each HTTP request. -5. **Handle Errors Gracefully**: Implement proper error handling in your steps -6. **Test Locally First**: Validate your deployable pipeline locally before deploying to production +6. **Keep Session State Small**: For session-aware deployments, store only compact summaries, IDs, and essential context in `session_state`. Move large artifacts (documents, embeddings, full histories) to external storage (vector stores, databases, object storage) and keep only references in session state. This matches best practices from frameworks like LangChain/LangGraph. +7. **Handle Errors Gracefully**: Implement proper error handling in your steps +8. **Test Locally First**: Validate your deployable pipeline locally before deploying to production ## Conclusion diff --git a/docs/book/how-to/deployment/deployment_settings.md b/docs/book/how-to/deployment/deployment_settings.md index c43698d2691..7841ea07803 100644 --- a/docs/book/how-to/deployment/deployment_settings.md +++ b/docs/book/how-to/deployment/deployment_settings.md @@ -127,6 +127,7 @@ Check out [this page](https://docs.zenml.io/concepts/steps_and_pipelines/configu `DeploymentSettings` expose the following basic customization options. The sections below provide short examples and guidance. +- sessions (session-aware invocations) - application metadata and paths - built-in endpoints and middleware toggles - static files (SPAs) and dashboards @@ -135,6 +136,48 @@ short examples and guidance. - startup and shutdown hooks - uvicorn server options, logging level, and thread pool size +### Sessions + +Configure session support for deployments that need to maintain state across multiple invocations (e.g., LLM agents, chatbots, multi-turn workflows): + +```python +from zenml.config import DeploymentSettings + +settings = DeploymentSettings( + sessions={ + "enabled": True, # Enable session support + "ttl_seconds": 1800, # 30 minute session timeout + "max_state_bytes": 32768, # 32KB state size limit + } +) +``` + +Or in YAML: + +```yaml +settings: + deployment: + sessions: + enabled: true + ttl_seconds: 1800 + max_state_bytes: 32768 +``` + +**Session Configuration Fields:** + +- `enabled` (default: `True`): Enable or disable session support. When enabled, each invocation can include a `session_id` to maintain state across calls. +- `ttl_seconds` (default: `86400` / 24 hours): Inactivity timeout before a session expires. Choose based on your use case (e.g., 15-30 min for chats, hours for workflows). +- `max_state_bytes` (default: `65536` / 64 KB): Maximum size for serialized session state. This prevents unbounded growth and potential abuse. + +**When to Use Sessions:** + +- ✅ LLM chat and conversational agents +- ✅ Multi-step workflows requiring context +- ✅ User-specific short-term state +- ❌ Fully stateless REST APIs + +For more details on using session state in your pipeline steps, see [Managing conversational session state](../steps-pipelines/advanced_features.md#managing-conversational-session-state). + ### Application metadata You can set `app_title`, `app_description`, and `app_version` to be reflected in the ASGI application's metadata: diff --git a/docs/book/how-to/steps-pipelines/advanced_features.md b/docs/book/how-to/steps-pipelines/advanced_features.md index 058508f985b..4b6c00671f5 100644 --- a/docs/book/how-to/steps-pipelines/advanced_features.md +++ b/docs/book/how-to/steps-pipelines/advanced_features.md @@ -676,24 +676,48 @@ def my_step(some_parameter: int = 1): ### Managing conversational session state -When a deployment is invoked with sessions enabled (the default behavior for deployer-based services), each step can access a per-session dictionary through the step context. This is useful for LLM workflows or any pipeline that needs to remember information across `/invoke` calls. +When a deployment is invoked with sessions enabled, each step can access a per-session dictionary through the step context. This is useful for LLM workflows, agents, or any pipeline that needs to remember information across `/invoke` calls. + +#### Understanding Deployment State + +ZenML deployments support two types of state: + +- **`pipeline_state`**: Deployment-global state shared across all invocations (e.g., loaded models, DB clients, caches). Set via `on_init` hook, accessed via `get_step_context().pipeline_state`. +- **`session_state`**: Per-session state that persists across multiple invocations with the same `session_id` (e.g., conversation history, user context). Accessed via `get_step_context().session_state`. + +This mirrors common LLM/agent designs: small short-term memory (session state) + external long-term memory (vector stores, databases). + +#### Using Session State ```python from zenml import step, get_step_context @step -def chat_step(message: str) -> str: +def agent_turn(message: str) -> str: ctx = get_step_context() session_state = ctx.session_state # Live dict persisted after the run + history = session_state.setdefault("history", []) history.append({"role": "user", "content": message}) - reply = f"Echoing turn {len(history)}: {message}" + # Use external tools/vector DB for heavy context; keep session state light + reply = plan_and_call_llm(history=history[-10:], message=message) + history.append({"role": "assistant", "content": reply}) session_state["last_reply"] = reply return reply ``` +#### Best Practices for Session State + +- **Keep it compact**: Store summaries, pointers, IDs, and essential context only +- **Push large artifacts elsewhere**: Documents, embeddings, and full histories belong in databases, vector stores, or object storage +- **Use size guardrails**: The `deployment_settings.sessions.max_state_bytes` setting (default 64 KB) prevents unbounded growth +- **Configure TTL appropriately**: Set `ttl_seconds` based on your use case (e.g., 15-30 min for chats, hours for workflows) +- **Store references, not content**: Keep file paths, document IDs, and embedding keys in session state rather than the actual data + +This approach matches best practices from frameworks like LangChain and LangGraph, where short-term working memory is kept small and structured. + If sessions are disabled for a deployment, `ctx.session_state` simply returns an empty dict, so the same code works without extra guards. ### Using Alerter in Hooks diff --git a/src/zenml/config/deployment_settings.py b/src/zenml/config/deployment_settings.py index 9f6f1af1dfa..73a88ec965f 100644 --- a/src/zenml/config/deployment_settings.py +++ b/src/zenml/config/deployment_settings.py @@ -531,23 +531,33 @@ class SessionBackendType(str, Enum): INMEMORY = "inmemory" LOCAL = "local" - REDIS = "redis" class SessionSettings(BaseModel): """Configuration for deployment session management. Sessions enable stateful interactions across multiple deployment - invocations. Phase 1 supports only in-memory storage; the schema - is forward-compatible with persistent backends (local, redis). + invocations. Two backend types are supported: + + - **inmemory**: Process-local storage. Fast and simple, but each uvicorn + worker maintains its own isolated session store. When `uvicorn_workers > 1` + or the deployment is scaled horizontally, sessions won't be shared across + workers. Best for single-worker deployments or dev/testing environments. + + - **local**: SQLite-backed storage. Sessions are persisted to disk and + shared across all uvicorn workers on the same host/VM. Survives worker + restarts. Use this for multi-worker deployments on a single machine. + Not suitable for multi-node deployments (use Redis/DB backends for that). + + **Important:** Session state must contain only JSON-serializable values. + Non-serializable values will cause persistence failures. Attributes: enabled: Whether session management is enabled for this deployment. - backend: Storage backend type (only 'inmemory' supported in Phase 1). + backend: Storage backend type ('inmemory' or 'local'). ttl_seconds: Default session TTL in seconds (None = no expiry). max_state_bytes: Maximum size for session state in bytes. - max_sessions: Maximum number of sessions to store (LRU eviction). - backend_config: Backend-specific configuration (reserved for future use). + backend_config: Configuration for the selected backend. """ model_config = ConfigDict(extra="forbid") @@ -556,7 +566,6 @@ class SessionSettings(BaseModel): backend: SessionBackendType = SessionBackendType.INMEMORY ttl_seconds: Optional[int] = 24 * 60 * 60 # 24 hours default max_state_bytes: Optional[int] = 64 * 1024 # 64 KB default - max_sessions: Optional[int] = 10_000 backend_config: Dict[str, Any] = Field(default_factory=dict) @@ -727,6 +736,7 @@ class DeploymentSettings(BaseSettings): # Pluggable app extensions for advanced features app_extensions: Optional[List[AppExtensionSpec]] = None + # Session management configuration sessions: SessionSettings = Field( default_factory=SessionSettings, title="Session management configuration.", diff --git a/src/zenml/deployers/server/service.py b/src/zenml/deployers/server/service.py index d94c08227e9..cfefc7e1e2f 100644 --- a/src/zenml/deployers/server/service.py +++ b/src/zenml/deployers/server/service.py @@ -17,6 +17,7 @@ import traceback from abc import ABC, abstractmethod from datetime import datetime, timezone +from pathlib import Path from typing import ( TYPE_CHECKING, Annotated, @@ -33,6 +34,7 @@ import zenml.pipelines.run_utils as run_utils from zenml.client import Client from zenml.config.deployment_settings import SessionBackendType +from zenml.config.global_config import GlobalConfiguration from zenml.deployers.server import runtime from zenml.deployers.server.models import ( AppInfo, @@ -45,11 +47,8 @@ ServiceInfo, SnapshotInfo, ) -from zenml.deployers.server.sessions import ( - InMemorySessionBackend, - Session, - SessionManager, -) +from zenml.deployers.server.session import Session +from zenml.deployers.server.sessions import SessionManager from zenml.deployers.utils import ( deployment_snapshot_request_from_source_snapshot, ) @@ -350,23 +349,57 @@ def _configure_sessions(self) -> None: Raises: ValueError: If an unsupported backend is configured. """ + from zenml.deployers.server.session import SessionBackend + from zenml.deployers.server.sessions import ( + InMemoryBackendConfig, + InMemorySessionBackend, + LocalBackendConfig, + LocalSessionBackend, + SessionManager, + ) + session_settings = self.app_runner.settings.sessions if not session_settings.enabled: logger.debug("Session management is disabled") return - # Only in-memory backend is supported - if session_settings.backend != SessionBackendType.INMEMORY: - raise ValueError( - f"Unsupported session backend: {session_settings.backend}. " - f"Only '{SessionBackendType.INMEMORY}' backend is supported currently." + backend_type = session_settings.backend + backend: SessionBackend + + if backend_type == SessionBackendType.INMEMORY: + if self.app_runner.settings.uvicorn_workers > 1: + logger.warning( + "In-memory sessions are worker-local and not shared across " + "uvicorn workers. Set `uvicorn_workers=1` or switch to a " + "shared backend like 'local'." + ) + + inmemory_cfg = InMemoryBackendConfig( + **session_settings.backend_config ) + backend = InMemorySessionBackend(config=inmemory_cfg) + + elif backend_type == SessionBackendType.LOCAL: + local_cfg = LocalBackendConfig(**session_settings.backend_config) + db_path = local_cfg.database_path + if not db_path: + local_root = GlobalConfiguration().local_stores_path + db_dir = ( + Path(local_root) / "deployments" / str(self.deployment.id) + ) + db_path = str(db_dir / "zenml_deployment_sessions.db") - # Initialize in-memory backend with configured limits - backend = InMemorySessionBackend( - max_sessions=session_settings.max_sessions - ) + backend = LocalSessionBackend( + config=local_cfg, + db_path=db_path, + ) + + else: + raise ValueError( + f"Unsupported session backend: {backend_type}. " + "Supported: 'inmemory', 'local'." + ) self.session_manager = SessionManager( backend=backend, @@ -377,8 +410,7 @@ def _configure_sessions(self) -> None: logger.info( f"Session management enabled [backend={session_settings.backend}] " f"[ttl_seconds={session_settings.ttl_seconds}] " - f"[max_state_bytes={session_settings.max_state_bytes}] " - f"[max_sessions={session_settings.max_sessions}]" + f"[max_state_bytes={session_settings.max_state_bytes}]" ) def execute_pipeline( diff --git a/src/zenml/deployers/server/session.py b/src/zenml/deployers/server/session.py new file mode 100644 index 00000000000..4810b377ccc --- /dev/null +++ b/src/zenml/deployers/server/session.py @@ -0,0 +1,162 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Session models and abstractions for deployment session management.""" + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class BaseBackendConfig(BaseModel): + """Base configuration shared by all session backends. + + This class doesn't define any fields itself but serves as the parent + for all backend-specific configurations, enabling polymorphic config handling. + """ + + model_config = ConfigDict(extra="forbid") + + +class Session(BaseModel): + """Represents a deployment session with state and metadata. + + Attributes: + id: Unique session identifier (hex string). + deployment_id: ID of the deployment this session belongs to. + pipeline_id: Optional ID of the pipeline associated with this session. + state: Arbitrary JSON-serializable state dictionary. + created_at: Timestamp when the session was created (UTC). + updated_at: Timestamp when the session was last accessed/modified (UTC). + expires_at: Optional expiration timestamp (UTC); None means no expiry. + """ + + model_config = ConfigDict(extra="forbid") + + id: str + deployment_id: str + pipeline_id: Optional[str] = None + state: Dict[str, Any] = Field(default_factory=dict) + created_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + expires_at: Optional[datetime] = None + + def touch(self, ttl_seconds: Optional[int] = None) -> None: + """Update access timestamp and optionally extend expiration. + + Args: + ttl_seconds: If provided, set expires_at to now + ttl_seconds. + If None and expires_at is already set, leave it unchanged. + """ + now = datetime.now(timezone.utc) + self.updated_at = now + + if ttl_seconds is not None: + self.expires_at = now + timedelta(seconds=ttl_seconds) + + def is_expired(self) -> bool: + """Check if the session has expired. + + Returns: + True if expires_at is set and in the past, False otherwise. + """ + if self.expires_at is None: + return False + return datetime.now(timezone.utc) > self.expires_at + + +class SessionBackend(ABC): + """Abstract interface for session storage backends.""" + + @abstractmethod + def load(self, session_id: str, deployment_id: str) -> Optional[Session]: + """Load a session by ID within a deployment scope. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + + Returns: + The session if found and not expired, None otherwise. + """ + + @abstractmethod + def create( + self, + session_id: str, + deployment_id: str, + pipeline_id: Optional[str] = None, + initial_state: Optional[Dict[str, Any]] = None, + ttl_seconds: Optional[int] = None, + ) -> Session: + """Create a new session. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + initial_state: Optional initial state dictionary. + ttl_seconds: Optional TTL in seconds; if provided, sets expires_at. + + Returns: + The created session. + + Raises: + ValueError: If a session with the same ID already exists. + """ + + @abstractmethod + def update( + self, + session_id: str, + deployment_id: str, + state: Dict[str, Any], + ttl_seconds: Optional[int] = None, + ) -> Session: + """Update an existing session's state. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + state: New state dictionary (replaces existing state). + ttl_seconds: Optional TTL to refresh expiration. + + Returns: + The updated session. + + Raises: + KeyError: If the session does not exist. + """ + + @abstractmethod + def delete(self, session_id: str, deployment_id: str) -> None: + """Delete a session. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + """ + + @abstractmethod + def cleanup(self) -> int: + """Remove all expired sessions across all deployments. + + Returns: + Number of sessions removed. + """ diff --git a/src/zenml/deployers/server/sessions.py b/src/zenml/deployers/server/sessions.py deleted file mode 100644 index 0cfc20b624a..00000000000 --- a/src/zenml/deployers/server/sessions.py +++ /dev/null @@ -1,548 +0,0 @@ -# Copyright (c) ZenML GmbH 2025. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Deployment-scoped session infrastructure. - -This module provides server-managed session storage to enable stateful -interactions across multiple deployment invocations. Sessions are scoped -to deployments and support TTL-based expiration, thread-safe concurrent -access, and optional size limits. - -Key components: -- Session: Pydantic model representing session data and metadata -- SessionBackend: Abstract interface for session storage -- InMemorySessionBackend: Thread-safe in-memory implementation with LRU eviction -- SessionManager: High-level orchestrator for session lifecycle management - -Assumptions: -- Last-write-wins semantics for concurrent updates to the same session -- Sessions are deployment-scoped; different deployments have isolated namespaces -- Expiration is lazy (checked on access) plus periodic cleanup via backend.cleanup() -""" - -import json -import threading -from abc import ABC, abstractmethod -from collections import OrderedDict -from datetime import datetime, timedelta, timezone -from typing import Any, Dict, Optional, Tuple -from uuid import uuid4 - -from pydantic import BaseModel, ConfigDict, Field - -from zenml.logger import get_logger - -logger = get_logger(__name__) - - -class Session(BaseModel): - """Represents a deployment session with state and metadata. - - Attributes: - id: Unique session identifier (hex string). - deployment_id: ID of the deployment this session belongs to. - pipeline_id: Optional ID of the pipeline associated with this session. - state: Arbitrary JSON-serializable state dictionary. - created_at: Timestamp when the session was created (UTC). - updated_at: Timestamp when the session was last accessed/modified (UTC). - expires_at: Optional expiration timestamp (UTC); None means no expiry. - """ - - model_config = ConfigDict(extra="forbid") - - id: str - deployment_id: str - pipeline_id: Optional[str] = None - state: Dict[str, Any] = Field(default_factory=dict) - created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) - updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) - expires_at: Optional[datetime] = None - - def touch(self, ttl_seconds: Optional[int] = None) -> None: - """Update access timestamp and optionally extend expiration. - - Args: - ttl_seconds: If provided, set expires_at to now + ttl_seconds. - If None and expires_at is already set, leave it unchanged. - """ - now = datetime.now(timezone.utc) - self.updated_at = now - - if ttl_seconds is not None: - self.expires_at = now + timedelta(seconds=ttl_seconds) - - def is_expired(self) -> bool: - """Check if the session has expired. - - Returns: - True if expires_at is set and in the past, False otherwise. - """ - if self.expires_at is None: - return False - return datetime.now(timezone.utc) > self.expires_at - - -class SessionBackend(ABC): - """Abstract interface for session storage backends.""" - - @abstractmethod - def load(self, session_id: str, deployment_id: str) -> Optional[Session]: - """Load a session by ID within a deployment scope. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - - Returns: - The session if found and not expired, None otherwise. - """ - - @abstractmethod - def create( - self, - session_id: str, - deployment_id: str, - pipeline_id: Optional[str] = None, - initial_state: Optional[Dict[str, Any]] = None, - ttl_seconds: Optional[int] = None, - ) -> Session: - """Create a new session. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - pipeline_id: Optional pipeline identifier. - initial_state: Optional initial state dictionary. - ttl_seconds: Optional TTL in seconds; if provided, sets expires_at. - - Returns: - The created session. - - Raises: - ValueError: If a session with the same ID already exists. - """ - - @abstractmethod - def update( - self, - session_id: str, - deployment_id: str, - state: Dict[str, Any], - ttl_seconds: Optional[int] = None, - ) -> Session: - """Update an existing session's state. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - state: New state dictionary (replaces existing state). - ttl_seconds: Optional TTL to refresh expiration. - - Returns: - The updated session. - - Raises: - KeyError: If the session does not exist. - """ - - @abstractmethod - def delete(self, session_id: str, deployment_id: str) -> None: - """Delete a session. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - """ - - @abstractmethod - def cleanup(self) -> int: - """Remove all expired sessions across all deployments. - - Returns: - Number of sessions removed. - """ - - -class InMemorySessionBackend(SessionBackend): - """Thread-safe in-memory session storage with LRU eviction. - - Uses an OrderedDict to track access order for LRU eviction when - max_sessions is exceeded. All operations are guarded by a reentrant lock. - - Attributes: - max_sessions: Optional maximum number of sessions to store. When - exceeded, least-recently-used sessions are evicted. - """ - - def __init__(self, max_sessions: Optional[int] = None) -> None: - """Initialize the in-memory backend. - - Args: - max_sessions: Optional capacity limit; None means unlimited. - """ - self._sessions: OrderedDict[Tuple[str, str], Session] = OrderedDict() - self._lock = threading.RLock() - self.max_sessions = max_sessions - - def load(self, session_id: str, deployment_id: str) -> Optional[Session]: - """Load a session, performing lazy expiration removal. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - - Returns: - Deep copy of the session if found and valid, None otherwise. - """ - with self._lock: - key = (deployment_id, session_id) - session = self._sessions.get(key) - - if session is None: - return None - - # Lazy expiration check - if session.is_expired(): - logger.debug( - f"Session expired on load [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - del self._sessions[key] - return None - - # Move to end (mark as recently used) - self._sessions.move_to_end(key) - - return session.model_copy(deep=True) - - def create( - self, - session_id: str, - deployment_id: str, - pipeline_id: Optional[str] = None, - initial_state: Optional[Dict[str, Any]] = None, - ttl_seconds: Optional[int] = None, - ) -> Session: - """Create a new session with optional TTL. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - pipeline_id: Optional pipeline identifier. - initial_state: Optional initial state dictionary. - ttl_seconds: Optional TTL in seconds. - - Returns: - Deep copy of the created session. - - Raises: - ValueError: If a session with the same ID already exists. - """ - with self._lock: - key = (deployment_id, session_id) - - if key in self._sessions: - raise ValueError( - f"Session already exists [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - - # Create new session - session = Session( - id=session_id, - deployment_id=deployment_id, - pipeline_id=pipeline_id, - state=initial_state or {}, - ) - - # Set expiration if TTL provided - if ttl_seconds is not None: - session.touch(ttl_seconds) - - # Store and mark as recently used - self._sessions[key] = session - self._sessions.move_to_end(key) - - # Enforce capacity limit via LRU eviction - self._evict_if_needed() - - logger.info( - f"Created session [session_id={session_id}] " - f"[deployment_id={deployment_id}] " - f"[ttl_seconds={ttl_seconds}]" - ) - - return session.model_copy(deep=True) - - def update( - self, - session_id: str, - deployment_id: str, - state: Dict[str, Any], - ttl_seconds: Optional[int] = None, - ) -> Session: - """Update session state and refresh timestamps. - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - state: New state dictionary (replaces existing). - ttl_seconds: Optional TTL to refresh expiration. - - Returns: - Deep copy of the updated session. - - Raises: - KeyError: If the session does not exist. - """ - with self._lock: - key = (deployment_id, session_id) - - if key not in self._sessions: - raise KeyError( - f"Session not found [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - - session = self._sessions[key] - - # Replace state with deep copy - session.state = dict(state) - - # Refresh timestamps and optionally extend expiration - session.touch(ttl_seconds) - - # Mark as recently used - self._sessions.move_to_end(key) - - logger.debug( - f"Updated session [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - - return session.model_copy(deep=True) - - def delete(self, session_id: str, deployment_id: str) -> None: - """Delete a session (silent if not found). - - Args: - session_id: The session identifier. - deployment_id: The deployment identifier. - """ - with self._lock: - key = (deployment_id, session_id) - if key in self._sessions: - del self._sessions[key] - logger.info( - f"Deleted session [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - - def cleanup(self) -> int: - """Remove all expired sessions. - - Returns: - Number of sessions removed. - """ - with self._lock: - expired_keys = [ - key - for key, session in self._sessions.items() - if session.is_expired() - ] - - for key in expired_keys: - del self._sessions[key] - - if expired_keys: - logger.info(f"Cleaned up {len(expired_keys)} expired sessions") - - return len(expired_keys) - - def _evict_if_needed(self) -> None: - """Evict least-recently-used sessions if capacity exceeded. - - Must be called while holding self._lock. - """ - if self.max_sessions is None: - return - - while len(self._sessions) > self.max_sessions: - # Remove oldest (first) entry - evicted_key, evicted_session = self._sessions.popitem(last=False) - logger.warning( - f"Evicted LRU session [session_id={evicted_session.id}] " - f"[deployment_id={evicted_session.deployment_id}] " - f"due to capacity limit ({self.max_sessions})" - ) - - -class SessionManager: - """High-level orchestrator for session lifecycle management. - - Handles session resolution (get-or-create), state persistence with - size limits, and cleanup coordination. - - Attributes: - backend: The storage backend for sessions. - ttl_seconds: Default TTL for new sessions (None = no expiry). - max_state_bytes: Optional maximum size for session state in bytes. - """ - - def __init__( - self, - backend: SessionBackend, - ttl_seconds: Optional[int] = None, - max_state_bytes: Optional[int] = None, - ) -> None: - """Initialize the session manager. - - Args: - backend: Storage backend for sessions. - ttl_seconds: Default session TTL in seconds (None = no expiry). - max_state_bytes: Optional maximum state size in bytes. - """ - self.backend = backend - self.ttl_seconds = ttl_seconds - self.max_state_bytes = max_state_bytes - self._logger = logger - - def resolve( - self, - requested_id: Optional[str], - deployment_id: str, - pipeline_id: Optional[str] = None, - ) -> Session: - """Resolve a session by ID or create a new one. - - If requested_id is provided, attempts to load the existing session. - If not found or expired, creates a new session with that ID. - If requested_id is None, generates a new ID and creates a session. - - Args: - requested_id: Optional session ID to resume. - deployment_id: The deployment identifier. - pipeline_id: Optional pipeline identifier. - - Returns: - The resolved or newly created session. - """ - session_id = requested_id or uuid4().hex - - # Attempt to load existing session - if requested_id: - session = self.backend.load(session_id, deployment_id) - if session: - # Refresh TTL on access - session.touch(self.ttl_seconds) - self.backend.update( - session_id, - deployment_id, - session.state, - ttl_seconds=self.ttl_seconds, - ) - self._logger.debug( - f"Resolved existing session [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - return session - - # Create new session - session = self.backend.create( - session_id=session_id, - deployment_id=deployment_id, - pipeline_id=pipeline_id, - initial_state={}, - ttl_seconds=self.ttl_seconds, - ) - - self._logger.info( - f"Created new session [session_id={session_id}] " - f"[deployment_id={deployment_id}]" - ) - - return session - - def persist_state( - self, - session: Session, - new_state: Dict[str, Any], - ) -> Session: - """Persist updated state to the backend with size validation. - - Args: - session: The session to update. - new_state: New state dictionary to persist. - - Returns: - The updated session from the backend. - - Raises: - ValueError: If new_state exceeds max_state_bytes. - """ - # Enforce size limit if configured - if self.max_state_bytes is not None: - self._ensure_size_within_limit(new_state) - - updated_session = self.backend.update( - session_id=session.id, - deployment_id=session.deployment_id, - state=new_state, - ttl_seconds=self.ttl_seconds, - ) - - self._logger.debug( - f"Persisted state [session_id={session.id}] " - f"[deployment_id={session.deployment_id}]" - ) - - return updated_session - - def delete_session(self, session: Session) -> None: - """Delete a session from the backend. - - Args: - session: The session to delete. - """ - self.backend.delete(session.id, session.deployment_id) - - def _ensure_size_within_limit(self, state: Dict[str, Any]) -> None: - """Validate that state size is within configured limit. - - Args: - state: The state dictionary to validate. - - Raises: - ValueError: If state exceeds max_state_bytes. - """ - if self.max_state_bytes is None: - return - - # Serialize to JSON and measure byte size - try: - state_json = json.dumps(state, ensure_ascii=False) - state_bytes = len(state_json.encode("utf-8")) - except (TypeError, ValueError) as e: - raise ValueError( - f"Session state must be JSON-serializable: {e}" - ) from e - - if state_bytes > self.max_state_bytes: - raise ValueError( - f"Session state size ({state_bytes} bytes) exceeds " - f"maximum allowed ({self.max_state_bytes} bytes)" - ) diff --git a/src/zenml/deployers/server/sessions/__init__.py b/src/zenml/deployers/server/sessions/__init__.py new file mode 100644 index 00000000000..b396791bdbc --- /dev/null +++ b/src/zenml/deployers/server/sessions/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Session backend implementations. + +This package contains concrete implementations of the SessionBackend interface: +- InMemorySessionBackend: Thread-safe in-memory implementation with LRU eviction +- LocalSessionBackend: SQLite-backed implementation for multi-worker single-host deployments +- SessionManager: High-level orchestrator for session lifecycle management + +The core abstractions (Session, SessionBackend) are defined in the parent +sessions.py module. +""" +from zenml.deployers.server.sessions.inmemory_session_backend import ( + InMemoryBackendConfig, + InMemorySessionBackend, +) +from zenml.deployers.server.sessions.local_session_backend import ( + LocalBackendConfig, + LocalSessionBackend, +) +from zenml.deployers.server.sessions.session_manager import SessionManager + +__all__ = [ + "InMemoryBackendConfig", + "LocalBackendConfig", + "InMemorySessionBackend", + "LocalSessionBackend", + "SessionManager", +] + diff --git a/src/zenml/deployers/server/sessions/inmemory_session_backend.py b/src/zenml/deployers/server/sessions/inmemory_session_backend.py new file mode 100644 index 00000000000..77ba35ad014 --- /dev/null +++ b/src/zenml/deployers/server/sessions/inmemory_session_backend.py @@ -0,0 +1,245 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""In-memory session backend implementation.""" + +import threading +from collections import OrderedDict +from typing import Any, Dict, Optional, Tuple + +from pydantic import Field + +from zenml.deployers.server.session import ( + BaseBackendConfig, + Session, + SessionBackend, +) +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +class InMemoryBackendConfig(BaseBackendConfig): + """Configuration for in-memory session backend. + + Attributes: + max_sessions: Maximum number of sessions to store. When exceeded, + least-recently-used sessions are evicted. None means unlimited. + """ + + max_sessions: Optional[int] = Field( + default=10_000, + description="Maximum number of sessions (LRU eviction when exceeded)", + ) + + +class InMemorySessionBackend(SessionBackend): + """Thread-safe in-memory session storage with LRU eviction. + + Uses an OrderedDict to track access order for LRU eviction when + max_sessions is exceeded. All operations are guarded by a reentrant lock. + + Attributes: + config: Backend configuration. + """ + + def __init__(self, config: InMemoryBackendConfig) -> None: + """Initialize the in-memory backend. + + Args: + config: Configuration for the backend. + """ + self.config = config + self._sessions: OrderedDict[Tuple[str, str], Session] = OrderedDict() + self._lock = threading.RLock() + + def load(self, session_id: str, deployment_id: str) -> Optional[Session]: + """Load a session, performing lazy expiration removal. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + + Returns: + Deep copy of the session if found and valid, None otherwise. + """ + with self._lock: + key = (deployment_id, session_id) + session = self._sessions.get(key) + + if session is None: + return None + + if session.is_expired(): + logger.debug( + f"Session expired on load [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + del self._sessions[key] + return None + + self._sessions.move_to_end(key) + return session.model_copy(deep=True) + + def create( + self, + session_id: str, + deployment_id: str, + pipeline_id: Optional[str] = None, + initial_state: Optional[Dict[str, Any]] = None, + ttl_seconds: Optional[int] = None, + ) -> Session: + """Create a new session with optional TTL. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + initial_state: Optional initial state dictionary. + ttl_seconds: Optional TTL in seconds. + + Returns: + Deep copy of the created session. + + Raises: + ValueError: If a session with the same ID already exists. + """ + with self._lock: + key = (deployment_id, session_id) + + if key in self._sessions: + raise ValueError( + f"Session already exists [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + session = Session( + id=session_id, + deployment_id=deployment_id, + pipeline_id=pipeline_id, + state=initial_state or {}, + ) + + if ttl_seconds is not None: + session.touch(ttl_seconds) + + self._sessions[key] = session + self._sessions.move_to_end(key) + + self._evict_if_needed() + + logger.info( + f"Created session [session_id={session_id}] " + f"[deployment_id={deployment_id}] " + f"[ttl_seconds={ttl_seconds}]" + ) + + return session.model_copy(deep=True) + + def update( + self, + session_id: str, + deployment_id: str, + state: Dict[str, Any], + ttl_seconds: Optional[int] = None, + ) -> Session: + """Update session state and refresh timestamps. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + state: New state dictionary (replaces existing). + ttl_seconds: Optional TTL to refresh expiration. + + Returns: + Deep copy of the updated session. + + Raises: + KeyError: If the session does not exist. + """ + with self._lock: + key = (deployment_id, session_id) + + if key not in self._sessions: + raise KeyError( + f"Session not found [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + session = self._sessions[key] + + session.state = dict(state) + + session.touch(ttl_seconds) + + self._sessions.move_to_end(key) + + logger.debug( + f"Updated session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + return session.model_copy(deep=True) + + def delete(self, session_id: str, deployment_id: str) -> None: + """Delete a session (silent if not found). + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + """ + with self._lock: + key = (deployment_id, session_id) + if key in self._sessions: + del self._sessions[key] + logger.info( + f"Deleted session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + def cleanup(self) -> int: + """Remove all expired sessions. + + Returns: + Number of sessions removed. + """ + with self._lock: + expired_keys = [ + key + for key, session in self._sessions.items() + if session.is_expired() + ] + + for key in expired_keys: + del self._sessions[key] + + if expired_keys: + logger.info(f"Cleaned up {len(expired_keys)} expired sessions") + + return len(expired_keys) + + def _evict_if_needed(self) -> None: + """Evict least-recently-used sessions if capacity exceeded. + + Must be called while holding self._lock. + """ + if self.config.max_sessions is None: + return + + while len(self._sessions) > self.config.max_sessions: + evicted_key, evicted_session = self._sessions.popitem(last=False) + logger.warning( + f"Evicted LRU session [session_id={evicted_session.id}] " + f"[deployment_id={evicted_session.deployment_id}] " + f"due to capacity limit ({self.config.max_sessions})" + ) diff --git a/src/zenml/deployers/server/sessions/local_session_backend.py b/src/zenml/deployers/server/sessions/local_session_backend.py new file mode 100644 index 00000000000..5f5f2f1b795 --- /dev/null +++ b/src/zenml/deployers/server/sessions/local_session_backend.py @@ -0,0 +1,459 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Local SQLite-based session backend implementation.""" + +import json +import sqlite3 +import threading +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Tuple, Union + +from pydantic import Field + +from zenml.deployers.server.session import ( + BaseBackendConfig, + Session, + SessionBackend, +) +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +class LocalBackendConfig(BaseBackendConfig): + """Configuration for local SQLite session backend. + + Attributes: + database_path: Custom path to SQLite database file. If not provided, + defaults to /deployments//zenml_deployment_sessions.db + journal_mode: SQLite journal mode (WAL recommended for concurrent access). + synchronous: SQLite synchronous mode (NORMAL balances durability and performance). + timeout: Connection timeout in seconds. + max_retry_attempts: Maximum number of retry attempts for locked database. + retry_base_delay: Base delay in seconds for exponential backoff retries. + """ + + database_path: Optional[str] = Field( + default=None, + description="Custom DB file path (optional)", + ) + journal_mode: Literal["WAL", "DELETE"] = Field( + default="WAL", + description="SQLite journal mode", + ) + synchronous: Literal["OFF", "NORMAL", "FULL"] = Field( + default="NORMAL", + description="SQLite synchronous mode", + ) + timeout: float = Field( + default=5.0, + ge=0.1, + description="Connection timeout in seconds", + ) + max_retry_attempts: int = Field( + default=3, + ge=1, + description="Max retries for locked DB", + ) + retry_base_delay: float = Field( + default=0.05, + ge=0.01, + description="Base delay for retry backoff", + ) + + +# SQLite DDL for local backend +_CREATE_SESSIONS_TABLE = """ +CREATE TABLE IF NOT EXISTS sessions ( + deployment_id TEXT NOT NULL, + session_id TEXT NOT NULL, + pipeline_id TEXT, + state_json TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + expires_at TEXT, + PRIMARY KEY (deployment_id, session_id) +) +""" + +_CREATE_SESSIONS_INDEX = """ +CREATE INDEX IF NOT EXISTS idx_sessions_expires + ON sessions (deployment_id, expires_at) +""" + + +class LocalSessionBackend(SessionBackend): + """SQLite-based session storage shared by uvicorn workers on one host. + + This backend uses a local SQLite database file to persist sessions, + enabling multiple uvicorn workers to share session state on the same + host/VM. It uses WAL mode for concurrent access and provides the same + semantics as the in-memory backend. + + Attributes: + config: Backend configuration. + db_path: Path to the SQLite database file (derived from config). + """ + + def __init__( + self, + config: LocalBackendConfig, + db_path: Union[str, Path], + ) -> None: + """Initialize the local session backend. + + Args: + config: Configuration for the backend. + db_path: Path to the SQLite database file (overrides config.database_path if provided). + """ + self.config = config + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + + self._lock = threading.RLock() + self._conn = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + isolation_level=None, # Manual transaction control + timeout=self.config.timeout, + detect_types=sqlite3.PARSE_DECLTYPES, + ) + self._conn.row_factory = sqlite3.Row + + with self._lock, self._conn: + self._conn.execute( + f"PRAGMA journal_mode={self.config.journal_mode}" + ) + self._conn.execute(f"PRAGMA synchronous={self.config.synchronous}") + self._conn.execute("PRAGMA foreign_keys=ON") + self._conn.execute( + f"PRAGMA busy_timeout={int(self.config.timeout * 1000)}" + ) + self._conn.execute(_CREATE_SESSIONS_TABLE) + self._conn.execute(_CREATE_SESSIONS_INDEX) + + logger.info( + f"Initialized local session backend [db_path={self.db_path}] " + f"[journal_mode={self.config.journal_mode}] [synchronous={self.config.synchronous}]" + ) + + def __del__(self) -> None: + """Close connection on cleanup.""" + try: + self._conn.close() + except Exception as e: + logger.debug("Error closing connection: %s", e) + + def load(self, session_id: str, deployment_id: str) -> Optional[Session]: + """Load a session, performing lazy expiration removal. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + + Returns: + The session if found and not expired, None otherwise. + """ + with self._lock, self._conn: + row = self._conn.execute( + """ + SELECT session_id, pipeline_id, state_json, + created_at, updated_at, expires_at + FROM sessions + WHERE deployment_id = ? AND session_id = ? + """, + (deployment_id, session_id), + ).fetchone() + + if not row: + return None + + session = self._row_to_session(row, deployment_id) + + # Lazy expiration check + if session.is_expired(): + logger.debug( + f"Session expired on load [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + self.delete(session_id, deployment_id) + return None + + return session + + def create( + self, + session_id: str, + deployment_id: str, + pipeline_id: Optional[str] = None, + initial_state: Optional[Dict[str, Any]] = None, + ttl_seconds: Optional[int] = None, + ) -> Session: + """Create a new session with optional TTL. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + initial_state: Optional initial state dictionary. + ttl_seconds: Optional TTL in seconds. + + Returns: + The created session. + + Raises: + ValueError: If a session with the same ID already exists. + """ + now = datetime.now(timezone.utc) + state = initial_state or {} + state_json = json.dumps(state, ensure_ascii=False) + + created_at = now.isoformat() + updated_at = now.isoformat() + expires_at = None + if ttl_seconds is not None: + expires_at = (now + timedelta(seconds=ttl_seconds)).isoformat() + + try: + self._execute_with_retry( + """ + INSERT INTO sessions ( + deployment_id, session_id, pipeline_id, + state_json, created_at, updated_at, expires_at + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + deployment_id, + session_id, + pipeline_id, + state_json, + created_at, + updated_at, + expires_at, + ), + ) + except sqlite3.IntegrityError as e: + raise ValueError( + f"Session already exists [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) from e + + logger.info( + f"Created session [session_id={session_id}] " + f"[deployment_id={deployment_id}] " + f"[ttl_seconds={ttl_seconds}]" + ) + + return Session( + id=session_id, + deployment_id=deployment_id, + pipeline_id=pipeline_id, + state=state, + created_at=now, + updated_at=now, + expires_at=(now + timedelta(seconds=ttl_seconds)) + if ttl_seconds is not None + else None, + ) + + def update( + self, + session_id: str, + deployment_id: str, + state: Dict[str, Any], + ttl_seconds: Optional[int] = None, + ) -> Session: + """Update session state and refresh timestamps. + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + state: New state dictionary (replaces existing). + ttl_seconds: Optional TTL to refresh expiration. + + Returns: + The updated session. + + Raises: + KeyError: If the session does not exist. + """ + now = datetime.now(timezone.utc) + state_json = json.dumps(state, ensure_ascii=False) + updated_at = now.isoformat() + + expires_at = None + if ttl_seconds is not None: + expires_at = (now + timedelta(seconds=ttl_seconds)).isoformat() + + cursor = self._execute_with_retry( + """ + UPDATE sessions + SET state_json = ?, + updated_at = ?, + expires_at = COALESCE(?, expires_at) + WHERE deployment_id = ? AND session_id = ? + """, + (state_json, updated_at, expires_at, deployment_id, session_id), + ) + + if cursor.rowcount == 0: + raise KeyError( + f"Session not found [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + logger.debug( + f"Updated session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + with self._lock, self._conn: + row = self._conn.execute( + """ + SELECT session_id, pipeline_id, state_json, + created_at, updated_at, expires_at + FROM sessions + WHERE deployment_id = ? AND session_id = ? + """, + (deployment_id, session_id), + ).fetchone() + + if not row: + raise KeyError( + f"Session not found after update [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + return self._row_to_session(row, deployment_id) + + def delete(self, session_id: str, deployment_id: str) -> None: + """Delete a session (silent if not found). + + Args: + session_id: The session identifier. + deployment_id: The deployment identifier. + """ + cursor = self._execute_with_retry( + """ + DELETE FROM sessions + WHERE deployment_id = ? AND session_id = ? + """, + (deployment_id, session_id), + ) + + if cursor.rowcount > 0: + logger.info( + f"Deleted session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + def cleanup(self) -> int: + """Remove all expired sessions. + + Returns: + Number of sessions removed. + """ + now = datetime.now(timezone.utc).isoformat() + + cursor = self._execute_with_retry( + """ + DELETE FROM sessions + WHERE expires_at IS NOT NULL AND expires_at <= ? + """, + (now,), + ) + + removed_count = cursor.rowcount + if removed_count > 0: + logger.info(f"Cleaned up {removed_count} expired sessions") + + return removed_count + + def _execute_with_retry( + self, + sql: str, + params: Tuple[Any, ...] = (), + ) -> sqlite3.Cursor: + """Execute SQL with retry logic for locked database. + + Args: + sql: SQL statement to execute. + params: Parameters for the SQL statement. + + Returns: + The cursor from the execution. + + Raises: + sqlite3.OperationalError: If database remains locked after retries. + Exception: If any other exception occurs during execution. + OperationalError: If database remains locked after retries. + """ + delay = self.config.retry_base_delay + + for attempt in range(self.config.max_retry_attempts): + try: + with self._lock: + self._conn.execute("BEGIN IMMEDIATE") + try: + cursor = self._conn.execute(sql, params) + self._conn.commit() + return cursor + except Exception: + self._conn.rollback() + raise + except sqlite3.OperationalError as e: + if "database is locked" not in str(e).lower(): + raise + if attempt == self.config.max_retry_attempts - 1: + logger.error( + f"Database locked after {self.config.max_retry_attempts} attempts" + ) + raise + + time.sleep(delay) + delay *= 2 + + raise sqlite3.OperationalError("Unexpected retry loop exit") + + def _row_to_session(self, row: sqlite3.Row, deployment_id: str) -> Session: + """Convert a database row to a Session object. + + Args: + row: SQLite row from sessions table. + deployment_id: The deployment identifier. + + Returns: + Session object constructed from the row. + """ + state = json.loads(row["state_json"]) + + created_at = datetime.fromisoformat(row["created_at"]) + updated_at = datetime.fromisoformat(row["updated_at"]) + expires_at = ( + datetime.fromisoformat(row["expires_at"]) + if row["expires_at"] + else None + ) + + return Session( + id=row["session_id"], + deployment_id=deployment_id, + pipeline_id=row["pipeline_id"], + state=state, + created_at=created_at, + updated_at=updated_at, + expires_at=expires_at, + ) diff --git a/src/zenml/deployers/server/sessions/session_manager.py b/src/zenml/deployers/server/sessions/session_manager.py new file mode 100644 index 00000000000..0e2f9049293 --- /dev/null +++ b/src/zenml/deployers/server/sessions/session_manager.py @@ -0,0 +1,172 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Session manager for orchestrating session lifecycle.""" + +import json +from typing import Any, Dict, Optional +from uuid import uuid4 + +from zenml.deployers.server.session import Session, SessionBackend +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +class SessionManager: + """High-level orchestrator for session lifecycle management. + + Handles session resolution (get-or-create), state persistence with + size limits, and cleanup coordination. + + Attributes: + backend: The storage backend for sessions. + ttl_seconds: Default TTL for new sessions (None = no expiry). + max_state_bytes: Optional maximum size for session state in bytes. + """ + + def __init__( + self, + backend: SessionBackend, + ttl_seconds: Optional[int] = None, + max_state_bytes: Optional[int] = None, + ) -> None: + """Initialize the session manager. + + Args: + backend: Storage backend for sessions. + ttl_seconds: Default session TTL in seconds (None = no expiry). + max_state_bytes: Optional maximum state size in bytes. + """ + self.backend = backend + self.ttl_seconds = ttl_seconds + self.max_state_bytes = max_state_bytes + self._logger = logger + + def resolve( + self, + requested_id: Optional[str], + deployment_id: str, + pipeline_id: Optional[str] = None, + ) -> Session: + """Resolve a session by ID or create a new one. + + If requested_id is provided, attempts to load the existing session. + If not found or expired, creates a new session with that ID. + If requested_id is None, generates a new ID and creates a session. + + Args: + requested_id: Optional session ID to resume. + deployment_id: The deployment identifier. + pipeline_id: Optional pipeline identifier. + + Returns: + The resolved or newly created session. + """ + session_id = requested_id or uuid4().hex + + if requested_id: + session = self.backend.load(session_id, deployment_id) + if session: + session.touch(self.ttl_seconds) + self.backend.update( + session_id, + deployment_id, + session.state, + ttl_seconds=self.ttl_seconds, + ) + self._logger.debug( + f"Resolved existing session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + return session + + session = self.backend.create( + session_id=session_id, + deployment_id=deployment_id, + pipeline_id=pipeline_id, + initial_state={}, + ttl_seconds=self.ttl_seconds, + ) + + self._logger.info( + f"Created new session [session_id={session_id}] " + f"[deployment_id={deployment_id}]" + ) + + return session + + def persist_state( + self, + session: Session, + new_state: Dict[str, Any], + ) -> Session: + """Persist updated state to the backend with size validation. + + Args: + session: The session to update. + new_state: New state dictionary to persist. + + Returns: + The updated session from the backend. + """ + if self.max_state_bytes is not None: + self._ensure_size_within_limit(new_state) + + updated_session = self.backend.update( + session_id=session.id, + deployment_id=session.deployment_id, + state=new_state, + ttl_seconds=self.ttl_seconds, + ) + + self._logger.debug( + f"Persisted state [session_id={session.id}] " + f"[deployment_id={session.deployment_id}]" + ) + + return updated_session + + def delete_session(self, session: Session) -> None: + """Delete a session from the backend. + + Args: + session: The session to delete. + """ + self.backend.delete(session.id, session.deployment_id) + + def _ensure_size_within_limit(self, state: Dict[str, Any]) -> None: + """Validate that state size is within configured limit. + + Args: + state: The state dictionary to validate. + + Raises: + ValueError: If state exceeds max_state_bytes. + """ + if self.max_state_bytes is None: + return + + try: + state_json = json.dumps(state, ensure_ascii=False) + state_bytes = len(state_json.encode("utf-8")) + except (TypeError, ValueError) as e: + raise ValueError( + f"Session state must be JSON-serializable: {e}" + ) from e + + if state_bytes > self.max_state_bytes: + raise ValueError( + f"Session state size ({state_bytes} bytes) exceeds " + f"maximum allowed ({self.max_state_bytes} bytes)" + ) diff --git a/tests/unit/deployers/server/test_sessions.py b/tests/unit/deployers/server/test_sessions.py index 660283643df..616809a8047 100644 --- a/tests/unit/deployers/server/test_sessions.py +++ b/tests/unit/deployers/server/test_sessions.py @@ -13,12 +13,16 @@ # permissions and limitations under the License. """Unit tests for session infrastructure.""" +import multiprocessing +import tempfile from datetime import datetime, timedelta, timezone +from pathlib import Path import pytest from zenml.deployers.server.sessions import ( InMemorySessionBackend, + LocalSessionBackend, SessionManager, ) @@ -155,3 +159,337 @@ def test_session_manager_enforces_state_size_limit(self): small_state = {"ok": "y"} # Should be under 10 bytes updated = manager.persist_state(session, small_state) assert updated.state == small_state + + +class TestLocalSessionBackend: + """Test LocalSessionBackend SQLite-based storage.""" + + @pytest.fixture + def temp_db_path(self): + """Provide a temporary database path for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "sessions.db" + + def test_local_backend_crud_roundtrip(self, temp_db_path): + """Test basic CRUD operations on local backend.""" + backend = LocalSessionBackend(db_path=temp_db_path) + deployment_id = "deployment-1" + session_id = "session-1" + + # Create session + session = backend.create( + session_id=session_id, + deployment_id=deployment_id, + pipeline_id="pipeline-1", + initial_state={"counter": 0}, + ttl_seconds=3600, + ) + assert session.id == session_id + assert session.deployment_id == deployment_id + assert session.pipeline_id == "pipeline-1" + assert session.state == {"counter": 0} + assert session.expires_at is not None + + # Load session + loaded = backend.load(session_id, deployment_id) + assert loaded is not None + assert loaded.id == session_id + assert loaded.state == {"counter": 0} + + # Update session + updated = backend.update( + session_id, + deployment_id, + {"counter": 42}, + ttl_seconds=7200, + ) + assert updated.state == {"counter": 42} + + # Verify update persisted + reloaded = backend.load(session_id, deployment_id) + assert reloaded is not None + assert reloaded.state == {"counter": 42} + + # Delete session + backend.delete(session_id, deployment_id) + deleted = backend.load(session_id, deployment_id) + assert deleted is None + + def test_local_backend_duplicate_create_raises(self, temp_db_path): + """Test that creating a duplicate session raises ValueError.""" + backend = LocalSessionBackend(db_path=temp_db_path) + deployment_id = "deployment-1" + session_id = "session-1" + + backend.create( + session_id=session_id, + deployment_id=deployment_id, + initial_state={}, + ) + + with pytest.raises(ValueError, match="already exists"): + backend.create( + session_id=session_id, + deployment_id=deployment_id, + initial_state={}, + ) + + def test_local_backend_update_missing_raises(self, temp_db_path): + """Test that updating a non-existent session raises KeyError.""" + backend = LocalSessionBackend(db_path=temp_db_path) + + with pytest.raises(KeyError, match="not found"): + backend.update( + session_id="nonexistent", + deployment_id="deployment-1", + state={}, + ) + + def test_local_backend_delete_is_idempotent(self, temp_db_path): + """Test that deleting a non-existent session is silent.""" + backend = LocalSessionBackend(db_path=temp_db_path) + # Should not raise + backend.delete("nonexistent", "deployment-1") + + def test_local_backend_lazy_expiration(self, temp_db_path): + """Test that expired sessions are removed on load.""" + backend = LocalSessionBackend(db_path=temp_db_path) + deployment_id = "deployment-1" + session_id = "session-expired" + + # Create session with very short TTL + backend.create( + session_id=session_id, + deployment_id=deployment_id, + initial_state={"data": "value"}, + ttl_seconds=1, + ) + + # Manually set expiration to the past + + with backend._lock, backend._conn: + past_time = ( + datetime.now(timezone.utc) - timedelta(seconds=10) + ).isoformat() + backend._conn.execute( + "UPDATE sessions SET expires_at = ? WHERE session_id = ?", + (past_time, session_id), + ) + backend._conn.commit() + + # Load should detect expiration and delete + loaded = backend.load(session_id, deployment_id) + assert loaded is None + + # Verify session was deleted + with backend._lock, backend._conn: + row = backend._conn.execute( + "SELECT * FROM sessions WHERE session_id = ?", + (session_id,), + ).fetchone() + assert row is None + + def test_local_backend_cleanup_removes_expired(self, temp_db_path): + """Test that cleanup removes all expired sessions.""" + backend = LocalSessionBackend(db_path=temp_db_path) + deployment_id = "deployment-1" + + # Create valid session + backend.create( + session_id="session-valid", + deployment_id=deployment_id, + initial_state={}, + ttl_seconds=3600, + ) + + # Create expired session + backend.create( + session_id="session-expired", + deployment_id=deployment_id, + initial_state={}, + ttl_seconds=1, + ) + + # Manually expire the second session + + with backend._lock, backend._conn: + past_time = ( + datetime.now(timezone.utc) - timedelta(seconds=10) + ).isoformat() + backend._conn.execute( + "UPDATE sessions SET expires_at = ? WHERE session_id = ?", + (past_time, "session-expired"), + ) + backend._conn.commit() + + # Run cleanup + removed_count = backend.cleanup() + assert removed_count == 1 + + # Valid session should still exist + valid = backend.load("session-valid", deployment_id) + assert valid is not None + + # Expired session should be gone + expired = backend.load("session-expired", deployment_id) + assert expired is None + + def test_local_backend_deployment_isolation(self, temp_db_path): + """Test that sessions are isolated by deployment_id.""" + backend = LocalSessionBackend(db_path=temp_db_path) + + # Create sessions in different deployments + backend.create( + session_id="session-1", + deployment_id="deployment-A", + initial_state={"deploy": "A"}, + ) + backend.create( + session_id="session-1", + deployment_id="deployment-B", + initial_state={"deploy": "B"}, + ) + + # Both should be loadable with correct data + session_a = backend.load("session-1", "deployment-A") + session_b = backend.load("session-1", "deployment-B") + + assert session_a is not None + assert session_b is not None + assert session_a.state == {"deploy": "A"} + assert session_b.state == {"deploy": "B"} + + def test_local_backend_survives_reconnection(self, temp_db_path): + """Test that sessions persist across backend instances.""" + deployment_id = "deployment-1" + session_id = "session-persistent" + + # Create session with first backend instance + backend1 = LocalSessionBackend(db_path=temp_db_path) + backend1.create( + session_id=session_id, + deployment_id=deployment_id, + initial_state={"value": 123}, + ) + del backend1 + + # Load session with second backend instance + backend2 = LocalSessionBackend(db_path=temp_db_path) + loaded = backend2.load(session_id, deployment_id) + assert loaded is not None + assert loaded.state == {"value": 123} + + def test_local_backend_concurrent_writes(self, temp_db_path): + """Test that concurrent writes are handled correctly.""" + backend = LocalSessionBackend(db_path=temp_db_path) + deployment_id = "deployment-1" + + # Create multiple sessions rapidly + import threading + + def create_session(session_id): + try: + backend.create( + session_id=session_id, + deployment_id=deployment_id, + initial_state={"id": session_id}, + ) + except Exception: + pass # Ignore errors for this test + + threads = [] + for i in range(10): + t = threading.Thread(target=create_session, args=(f"session-{i}",)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Verify sessions were created + for i in range(10): + loaded = backend.load(f"session-{i}", deployment_id) + assert loaded is not None + + +class TestLocalSessionBackendMultiprocess: + """Test LocalSessionBackend multi-process scenarios.""" + + @pytest.fixture + def temp_db_path(self): + """Provide a temporary database path for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "sessions.db" + + def test_multiprocess_shared_state(self, temp_db_path): + """Test that multiple processes can share session state.""" + + def process_writer(db_path, session_id, value): + """Process that writes a session.""" + backend = LocalSessionBackend(db_path=db_path) + backend.create( + session_id=session_id, + deployment_id="deployment-1", + initial_state={"value": value}, + ) + + def process_reader(db_path, session_id, expected_value): + """Process that reads a session.""" + backend = LocalSessionBackend(db_path=db_path) + session = backend.load(session_id, "deployment-1") + assert session is not None + assert session.state["value"] == expected_value + + # Process A writes + p1 = multiprocessing.Process( + target=process_writer, args=(temp_db_path, "session-1", 42) + ) + p1.start() + p1.join() + assert p1.exitcode == 0 + + # Process B reads + p2 = multiprocessing.Process( + target=process_reader, args=(temp_db_path, "session-1", 42) + ) + p2.start() + p2.join() + assert p2.exitcode == 0 + + +class TestSessionManagerWithLocalBackend: + """Test SessionManager with LocalSessionBackend.""" + + @pytest.fixture + def temp_db_path(self): + """Provide a temporary database path for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "sessions.db" + + def test_session_manager_with_local_backend(self, temp_db_path): + """Test SessionManager works correctly with local backend.""" + backend = LocalSessionBackend(db_path=temp_db_path) + manager = SessionManager( + backend=backend, ttl_seconds=3600, max_state_bytes=1024 + ) + deployment_id = "deployment-1" + + # Resolve new session + session = manager.resolve( + requested_id=None, + deployment_id=deployment_id, + pipeline_id="pipeline-1", + ) + assert session.id is not None + + # Persist state + updated = manager.persist_state(session, {"counter": 1}) + assert updated.state == {"counter": 1} + + # Resolve existing session + resolved = manager.resolve( + requested_id=session.id, + deployment_id=deployment_id, + ) + assert resolved.id == session.id + assert resolved.state == {"counter": 1} From c57ab711e695b0a509a7bf2d4fc96fa74887793f Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 9 Nov 2025 16:06:50 +0100 Subject: [PATCH 3/3] fix tests --- tests/unit/deployers/server/test_sessions.py | 90 ++++++++++++-------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/tests/unit/deployers/server/test_sessions.py b/tests/unit/deployers/server/test_sessions.py index 616809a8047..c4adcd6ec0d 100644 --- a/tests/unit/deployers/server/test_sessions.py +++ b/tests/unit/deployers/server/test_sessions.py @@ -21,18 +21,43 @@ import pytest from zenml.deployers.server.sessions import ( + InMemoryBackendConfig, InMemorySessionBackend, + LocalBackendConfig, LocalSessionBackend, SessionManager, ) +# Module-level helper functions for multiprocessing tests +# These need to be at module level to be picklable +def _process_writer(db_path, session_id, value): + """Process that writes a session.""" + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=db_path) + backend.create( + session_id=session_id, + deployment_id="deployment-1", + initial_state={"value": value}, + ) + + +def _process_reader(db_path, session_id, expected_value): + """Process that reads a session.""" + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=db_path) + session = backend.load(session_id, "deployment-1") + assert session is not None + assert session.state["value"] == expected_value + + class TestInMemorySessionBackend: """Test InMemorySessionBackend storage and eviction.""" def test_inmemory_backend_performs_lru_eviction(self): """Test that LRU eviction removes least-recently-used sessions.""" - backend = InMemorySessionBackend(max_sessions=1) + config = InMemoryBackendConfig(max_sessions=1) + backend = InMemorySessionBackend(config=config) deployment_id = "deployment-1" # Create first session @@ -67,7 +92,8 @@ def test_inmemory_backend_performs_lru_eviction(self): def test_inmemory_backend_lazy_expiration_removes_stale_sessions(self): """Test that expired sessions are removed on load.""" - backend = InMemorySessionBackend() + config = InMemoryBackendConfig() + backend = InMemorySessionBackend(config=config) deployment_id = "deployment-1" session_id = "session-expired" @@ -100,7 +126,8 @@ class TestSessionManager: def test_session_manager_resolves_existing_session(self): """Test that SessionManager can resolve existing sessions by ID.""" - backend = InMemorySessionBackend() + config = InMemoryBackendConfig() + backend = InMemorySessionBackend(config=config) manager = SessionManager(backend=backend, ttl_seconds=3600) deployment_id = "deployment-1" @@ -135,7 +162,8 @@ def test_session_manager_resolves_existing_session(self): def test_session_manager_enforces_state_size_limit(self): """Test that SessionManager rejects oversized state payloads.""" - backend = InMemorySessionBackend() + config = InMemoryBackendConfig() + backend = InMemorySessionBackend(config=config) manager = SessionManager( backend=backend, ttl_seconds=3600, @@ -172,7 +200,8 @@ def temp_db_path(self): def test_local_backend_crud_roundtrip(self, temp_db_path): """Test basic CRUD operations on local backend.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) deployment_id = "deployment-1" session_id = "session-1" @@ -217,7 +246,8 @@ def test_local_backend_crud_roundtrip(self, temp_db_path): def test_local_backend_duplicate_create_raises(self, temp_db_path): """Test that creating a duplicate session raises ValueError.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) deployment_id = "deployment-1" session_id = "session-1" @@ -236,7 +266,8 @@ def test_local_backend_duplicate_create_raises(self, temp_db_path): def test_local_backend_update_missing_raises(self, temp_db_path): """Test that updating a non-existent session raises KeyError.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) with pytest.raises(KeyError, match="not found"): backend.update( @@ -247,13 +278,15 @@ def test_local_backend_update_missing_raises(self, temp_db_path): def test_local_backend_delete_is_idempotent(self, temp_db_path): """Test that deleting a non-existent session is silent.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) # Should not raise backend.delete("nonexistent", "deployment-1") def test_local_backend_lazy_expiration(self, temp_db_path): """Test that expired sessions are removed on load.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) deployment_id = "deployment-1" session_id = "session-expired" @@ -291,7 +324,8 @@ def test_local_backend_lazy_expiration(self, temp_db_path): def test_local_backend_cleanup_removes_expired(self, temp_db_path): """Test that cleanup removes all expired sessions.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) deployment_id = "deployment-1" # Create valid session @@ -336,7 +370,8 @@ def test_local_backend_cleanup_removes_expired(self, temp_db_path): def test_local_backend_deployment_isolation(self, temp_db_path): """Test that sessions are isolated by deployment_id.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) # Create sessions in different deployments backend.create( @@ -365,7 +400,8 @@ def test_local_backend_survives_reconnection(self, temp_db_path): session_id = "session-persistent" # Create session with first backend instance - backend1 = LocalSessionBackend(db_path=temp_db_path) + config1 = LocalBackendConfig() + backend1 = LocalSessionBackend(config=config1, db_path=temp_db_path) backend1.create( session_id=session_id, deployment_id=deployment_id, @@ -374,14 +410,16 @@ def test_local_backend_survives_reconnection(self, temp_db_path): del backend1 # Load session with second backend instance - backend2 = LocalSessionBackend(db_path=temp_db_path) + config2 = LocalBackendConfig() + backend2 = LocalSessionBackend(config=config2, db_path=temp_db_path) loaded = backend2.load(session_id, deployment_id) assert loaded is not None assert loaded.state == {"value": 123} def test_local_backend_concurrent_writes(self, temp_db_path): """Test that concurrent writes are handled correctly.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) deployment_id = "deployment-1" # Create multiple sessions rapidly @@ -423,26 +461,9 @@ def temp_db_path(self): def test_multiprocess_shared_state(self, temp_db_path): """Test that multiple processes can share session state.""" - - def process_writer(db_path, session_id, value): - """Process that writes a session.""" - backend = LocalSessionBackend(db_path=db_path) - backend.create( - session_id=session_id, - deployment_id="deployment-1", - initial_state={"value": value}, - ) - - def process_reader(db_path, session_id, expected_value): - """Process that reads a session.""" - backend = LocalSessionBackend(db_path=db_path) - session = backend.load(session_id, "deployment-1") - assert session is not None - assert session.state["value"] == expected_value - # Process A writes p1 = multiprocessing.Process( - target=process_writer, args=(temp_db_path, "session-1", 42) + target=_process_writer, args=(temp_db_path, "session-1", 42) ) p1.start() p1.join() @@ -450,7 +471,7 @@ def process_reader(db_path, session_id, expected_value): # Process B reads p2 = multiprocessing.Process( - target=process_reader, args=(temp_db_path, "session-1", 42) + target=_process_reader, args=(temp_db_path, "session-1", 42) ) p2.start() p2.join() @@ -468,7 +489,8 @@ def temp_db_path(self): def test_session_manager_with_local_backend(self, temp_db_path): """Test SessionManager works correctly with local backend.""" - backend = LocalSessionBackend(db_path=temp_db_path) + config = LocalBackendConfig() + backend = LocalSessionBackend(config=config, db_path=temp_db_path) manager = SessionManager( backend=backend, ttl_seconds=3600, max_state_bytes=1024 )