diff --git a/.gitignore b/.gitignore
index 62e7b92..f71d659 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ Untitled*
 .vscode/
 playground/
 
+# pixi environments
+.pixi/*
+!.pixi/config.toml
diff --git a/jupyter_server_documents/app.py b/jupyter_server_documents/app.py
index 0fd1981..d9bd5fb 100644
--- a/jupyter_server_documents/app.py
+++ b/jupyter_server_documents/app.py
@@ -1,7 +1,7 @@
 from jupyter_server.extension.application import ExtensionApp
 from traitlets.config import Config
-
 from traitlets import Instance, Type
+
 from .handlers import RouteHandler, FileIDIndexHandler
 from .websockets import YRoomWebsocket
 from .rooms.yroom_manager import YRoomManager
@@ -80,12 +80,43 @@
         )
 
     def _link_jupyter_server_extension(self, server_app):
-        """Setup custom config needed by this extension."""
+        """Set up custom config needed by this extension.
+
+        Only applies configuration if not already set by user config.
+        """
         c = Config()
-        c.ServerApp.kernel_websocket_connection_class = "jupyter_server_documents.kernels.websocket_connection.NextGenKernelWebsocketConnection"
-        c.ServerApp.kernel_manager_class = "jupyter_server_documents.kernels.multi_kernel_manager.NextGenMappingKernelManager"
-        c.MultiKernelManager.kernel_manager_class = "jupyter_server_documents.kernels.kernel_manager.NextGenKernelManager"
-        c.ServerApp.session_manager_class = "jupyter_server_documents.session_manager.YDocSessionManager"
+
+        # Configure kernel manager classes to use nextgen-kernels-api
+        if not server_app.config.ServerApp.get("kernel_manager_class"):
+            c.ServerApp.kernel_manager_class = "nextgen_kernels_api.services.kernels.kernelmanager.MultiKernelManager"
+
+        if not server_app.config.ServerApp.get("kernel_websocket_connection_class"):
+            c.ServerApp.kernel_websocket_connection_class = "nextgen_kernels_api.services.kernels.connection.kernel_client_connection.KernelClientWebsocketConnection"
+
+        if not server_app.config.ServerApp.get("session_manager_class"):
+            c.ServerApp.session_manager_class = "jupyter_server_documents.session_manager.YDocSessionManager"
+
+        # Configure kernel manager hierarchy
+        if not server_app.config.MultiKernelManager.get("kernel_manager_class"):
+            c.MultiKernelManager.kernel_manager_class = "nextgen_kernels_api.services.kernels.kernelmanager.KernelManager"
+
+        # Configure kernel client
+        if not server_app.config.KernelManager.get("client_class"):
+            c.KernelManager.client_class = "jupyter_server_documents.kernel_client.DocumentAwareKernelClient"
+            c.KernelManager.client_factory = "jupyter_server_documents.kernel_client.DocumentAwareKernelClient"
+
+        # Configure websocket message filtering
+        if not server_app.config.KernelClientWebsocketConnection.get("exclude_msg_types"):
+            c.KernelClientWebsocketConnection.exclude_msg_types = [
+                ("status", "iopub"),
+                ("stream", "iopub"),
+                ("display_data", "iopub"),
+                ("execute_result", "iopub"),
+                ("error", "iopub"),
+                ("update_display_data", "iopub"),
+                ("clear_output", "iopub"),
+            ]
         server_app.update_config(c)
         super()._link_jupyter_server_extension(server_app)
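The guards above make each setting a soft default: the extension only fills in a class when the user's own config has not already set one. A minimal sketch of what a user override could look like in a `jupyter_server_config.py` (the class path is hypothetical):

    # jupyter_server_config.py -- user-supplied values win over the extension's
    # defaults, because _link_jupyter_server_extension checks server_app.config
    # before applying its own classes.
    c.ServerApp.kernel_manager_class = "mypackage.kernels.CustomKernelManager"  # hypothetical class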
diff --git a/jupyter_server_documents/kernel_client.py b/jupyter_server_documents/kernel_client.py
new file mode 100644
index 0000000..c3993b6
--- /dev/null
+++ b/jupyter_server_documents/kernel_client.py
@@ -0,0 +1,219 @@
+"""Document-aware kernel client for collaborative notebook editing.
+
+This module extends nextgen-kernels-api's JupyterServerKernelClient to add
+notebook-specific functionality required for real-time collaboration:
+
+- Routes kernel messages to collaborative YRooms for document state synchronization
+- Processes and separates large outputs to optimize document size
+- Tracks cell execution states and updates awareness for real-time UI feedback
+- Manages notebook metadata updates from kernel info
+"""
+import asyncio
+import typing as t
+
+from nextgen_kernels_api.services.kernels.client import JupyterServerKernelClient
+from traitlets import Instance, Set, Type, default
+
+from jupyter_server_documents.outputs import OutputProcessor
+from jupyter_server_documents.rooms.yroom import YRoom
+
+
+class DocumentAwareKernelClient(JupyterServerKernelClient):
+    """Kernel client with collaborative document awareness and output processing.
+
+    Extends the base JupyterServerKernelClient to integrate with YRooms for
+    real-time collaboration, process outputs for optimization, and track cell
+    execution states across connected clients.
+    """
+
+    _yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set())
+
+    output_processor = Instance(OutputProcessor, allow_none=True)
+
+    output_processor_class = Type(
+        klass=OutputProcessor, default_value=OutputProcessor
+    ).tag(config=True)
+
+    @default("output_processor")
+    def _default_output_processor(self) -> OutputProcessor:
+        return self.output_processor_class(parent=self, config=self.config)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Register listener for document-related messages
+        # Combines state updates and outputs to share deserialization logic
+        self.add_listener(
+            self._handle_document_messages,
+            msg_types=[
+                ("kernel_info_reply", "shell"),
+                ("status", "iopub"),
+                ("execute_input", "iopub"),
+                ("stream", "iopub"),
+                ("display_data", "iopub"),
+                ("execute_result", "iopub"),
+                ("error", "iopub"),
+                ("update_display_data", "iopub"),
+                ("clear_output", "iopub"),
+            ],
+        )
+
+    async def _handle_document_messages(self, channel_name: str, msg: list[bytes]):
+        """Route kernel messages to document state and output handlers.
+
+        Deserializes kernel protocol messages and dispatches them to appropriate
+        handlers based on message type. Extracts parent message and cell ID context
+        needed by most handlers.
+ """ + if channel_name not in ("iopub", "shell"): + return + + # Deserialize message components + # Base client strips signature, leaving [header, parent_header, metadata, content, ...buffers] + try: + if len(msg) < 4: + self.log.debug(f"Message too short: {len(msg)} parts") + return + + header = self.session.unpack(msg[0]) + parent_header = self.session.unpack(msg[1]) + metadata = self.session.unpack(msg[2]) + + dmsg = { + "header": header, + "parent_header": parent_header, + "metadata": metadata, + "content": msg[3], # Keep as bytes, unpack in handlers + "buffers": msg[4:] if len(msg) > 4 else [], + "msg_id": header["msg_id"], + "msg_type": header["msg_type"], + } + except Exception as e: + self.log.debug(f"Skipping message that can't be deserialized: {e}") + return + + # Extract parent message context for cell ID lookup + parent_msg_id = dmsg.get("parent_header", {}).get("msg_id") + parent_msg_data = self.message_cache.get(parent_msg_id) if parent_msg_id else None + cell_id = parent_msg_data.get("cell_id") if parent_msg_data else None + + # Dispatch to appropriate handler + msg_type = dmsg.get("msg_type") + match msg_type: + case "kernel_info_reply": + await self._handle_kernel_info_reply(dmsg) + case "status": + await self._handle_status_message(dmsg, parent_msg_data, cell_id) + case "execute_input": + await self._handle_execute_input(dmsg, cell_id) + case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output": + await self._handle_output_message(dmsg, msg_type, cell_id) + + async def _handle_kernel_info_reply(self, msg: dict): + """Update notebook metadata with kernel language info.""" + content = self.session.unpack(msg["content"]) + language_info = content.get("language_info") + + if language_info: + for yroom in self._yrooms: + try: + notebook = await yroom.get_jupyter_ydoc() + metadata = notebook.ymeta + metadata["metadata"]["language_info"] = language_info + except Exception as e: + self.log.warning(f"Failed to update language info for yroom: {e}") + + async def _handle_status_message( + self, dmsg: dict, parent_msg_data: dict | None, cell_id: str | None + ): + """Update kernel and cell execution states from status messages. + + Updates both document-level kernel status and cell-specific execution states, + storing them persistently and in awareness for real-time UI updates. 
+ """ + content = self.session.unpack(dmsg["content"]) + execution_state = content.get("execution_state") + + for yroom in self._yrooms: + awareness = yroom.get_awareness() + if awareness is None: + continue + + # Update document-level kernel status if this is a top-level status message + if parent_msg_data and parent_msg_data.get("channel") == "shell": + awareness.set_local_state_field( + "kernel", {"execution_state": execution_state} + ) + + # Update cell execution state for persistence and awareness + if cell_id: + yroom.set_cell_execution_state(cell_id, execution_state) + yroom.set_cell_awareness_state(cell_id, execution_state) + break + + async def _handle_execute_input(self, dmsg: dict, cell_id: str | None): + """Update cell execution count when execution begins.""" + if not cell_id: + return + + content = self.session.unpack(dmsg["content"]) + execution_count = content.get("execution_count") + + if execution_count is not None: + for yroom in self._yrooms: + notebook = await yroom.get_jupyter_ydoc() + _, target_cell = notebook.find_cell(cell_id) + if target_cell: + target_cell["execution_count"] = execution_count + break + + async def _handle_output_message(self, dmsg: dict, msg_type: str, cell_id: str | None): + """Process output messages through output processor.""" + if not cell_id: + return + + if self.output_processor: + content = self.session.unpack(dmsg["content"]) + self.output_processor.process_output(msg_type, cell_id, content) + else: + self.log.warning("No output processor configured") + + async def add_yroom(self, yroom: YRoom): + """Register a YRoom to receive kernel messages.""" + self._yrooms.add(yroom) + + async def remove_yroom(self, yroom: YRoom): + """Unregister a YRoom from receiving kernel messages.""" + self._yrooms.discard(yroom) + + def handle_incoming_message(self, channel_name: str, msg: list[bytes]): + """Handle messages from WebSocket clients before routing to kernel. + + Extends base implementation to: + - Set cell awareness to 'busy' immediately on execute_request + - Clear outputs when cell is re-executed + + This ensures UI updates happen immediately rather than waiting for + kernel processing, providing better UX for queued executions. + """ + try: + header = self.session.unpack(msg[0]) + msg_id = header["msg_id"] + msg_type = header.get("msg_type") + metadata = self.session.unpack(msg[2]) + cell_id = metadata.get("cellId") + + if cell_id: + # Clear outputs if this is a re-execution of the same cell + existing = self.message_cache.get(cell_id=cell_id) + if existing and existing["msg_id"] != msg_id: + asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id)) + + # Set awareness state immediately for queued cells + if msg_type == "execute_request" and channel_name == "shell": + for yroom in self._yrooms: + yroom.set_cell_awareness_state(cell_id, "busy") + except Exception as e: + self.log.debug(f"Error handling awareness for incoming message: {e}") + + super().handle_incoming_message(channel_name, msg) diff --git a/jupyter_server_documents/kernels/__init__.py b/jupyter_server_documents/kernels/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/jupyter_server_documents/kernels/kernel_client.py b/jupyter_server_documents/kernels/kernel_client.py deleted file mode 100644 index ba2f9d9..0000000 --- a/jupyter_server_documents/kernels/kernel_client.py +++ /dev/null @@ -1,334 +0,0 @@ -""" -A new Kernel client that is aware of ydocuments. 
-""" -import anyio -import asyncio -import json -import typing as t - -from traitlets import Set, Instance, Any, Type, default -from jupyter_client.asynchronous.client import AsyncKernelClient - -from .message_cache import KernelMessageCache -from jupyter_server_documents.rooms.yroom import YRoom -from jupyter_server_documents.outputs import OutputProcessor -from jupyter_server.utils import ensure_async - -from .kernel_client_abc import AbstractDocumentAwareKernelClient - - -class DocumentAwareKernelClient(AsyncKernelClient): - """ - A kernel client that routes messages to registered ydocs. - """ - # Having this message cache is not ideal. - # Unfortunately, we don't include the parent channel - # in the messages that generate IOPub status messages, thus, - # we can't differential between the control channel vs. - # shell channel status. This message cache gives us - # the ability to map status message back to their source. - message_cache = Instance( - klass=KernelMessageCache - ) - - @default('message_cache') - def _default_message_cache(self): - return KernelMessageCache(parent=self) - - # A set of callables that are called when a kernel - # message is received. - _listeners = Set(allow_none=True) - - # A set of YRooms that will intercept output and kernel - # status messages. - _yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set()) - - - output_processor = Instance( - OutputProcessor, - allow_none=True - ) - - output_process_class = Type( - klass=OutputProcessor, - default_value=OutputProcessor - ).tag(config=True) - - @default("output_processor") - def _default_output_processor(self) -> OutputProcessor: - self.log.info("Creating output processor") - return self.output_process_class(parent=self, config=self.config) - - async def start_listening(self): - """Start listening to messages coming from the kernel. - - Use anyio to setup a task group for listening. - """ - # Wrap a taskgroup so that it can be backgrounded. - async def _listening(): - async with anyio.create_task_group() as tg: - for channel_name in ["shell", "control", "stdin", "iopub"]: - tg.start_soon( - self._listen_for_messages, channel_name - ) - - # Background this task. - self._listening_task = asyncio.create_task(_listening()) - - async def stop_listening(self): - """Stop listening to the kernel. - """ - # If the listening task isn't defined yet - # do nothing. - if not hasattr(self, '_listening_task') or not self._listening_task: - return - - # Attempt to cancel the task. - try: - self._listening_task.cancel() - # Await cancellation. - await self._listening_task - except asyncio.CancelledError: - self.log.info("Disconnected from client from the kernel.") - # Log any exceptions that were raised. - except Exception as err: - self.log.error(err) - finally: - # Clear the task reference - self._listening_task = None - - _listening_task: t.Optional[t.Awaitable] = Any(allow_none=True) - - def handle_incoming_message(self, channel_name: str, msg: list[bytes]): - """ - Handle incoming kernel messages and set up immediate cell execution state tracking. - - This method processes incoming kernel messages and caches them for response mapping. - Importantly, it detects execute_request messages and immediately sets the corresponding - cell state to 'busy' to provide real-time feedback for queued cell executions. - - This ensures that when multiple cells are executed simultaneously, all queued cells - show a '*' prompt immediately, not just the currently executing cell. 
-
-        Args:
-            channel_name: The kernel channel name (shell, iopub, etc.)
-            msg: The raw kernel message as bytes
-        """
-        # Cache the message ID and its socket name so that
-        # any response message can be mapped back to the
-        # source channel.
-        header = self.session.unpack(msg[0])
-        msg_id = header["msg_id"]
-        msg_type = header.get("msg_type")
-        metadata = self.session.unpack(msg[2])
-        cell_id = metadata.get("cellId")
-
-        # Clear cell outputs if cell is re-executed
-        if cell_id:
-            existing = self.message_cache.get(cell_id=cell_id)
-            if existing and existing['msg_id'] != msg_id:
-                asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id))
-
-        # IMPORTANT: Set cell to 'busy' immediately when execute_request is received
-        # This ensures queued cells show '*' prompt even before kernel starts processing them
-        if msg_type == "execute_request" and channel_name == "shell" and cell_id:
-            for yroom in self._yrooms:
-                yroom.set_cell_awareness_state(cell_id, "busy")
-
-        self.message_cache.add({
-            "msg_id": msg_id,
-            "channel": channel_name,
-            "cell_id": cell_id
-        })
-        channel = getattr(self, f"{channel_name}_channel")
-        if channel.socket is None:
-            self.log.error(f"Channel {channel_name} socket is None! Cannot send message. Channel alive: {channel.is_alive()}")
-            raise AttributeError(f"Channel {channel_name} socket is None")
-        channel.session.send_raw(channel.socket, msg)
-
-    def send_kernel_info(self):
-        """Sends a kernel info message on the shell channel. Useful
-        for determining if the kernel is busy or idle.
-        """
-        msg = self.session.msg("kernel_info_request")
-        # Send message, skipping the delimiter and signature
-        msg = self.session.serialize(msg)[2:]
-        self.handle_incoming_message("shell", msg)
-
-    def add_listener(self, callback: t.Callable[[str, list[bytes]], None]):
-        """Add a listener to the ZMQ Interface.
-
-        A listener is a callable function/method that takes
-        the deserialized (minus the content) ZMQ message.
-
-        If the listener is already registered, it won't be registered again.
-        """
-        self._listeners.add(callback)
-
-    def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
-        """Remove a listener. If the listener
-        is not found, this method does nothing.
-        """
-        self._listeners.discard(callback)
-
-    async def _listen_for_messages(self, channel_name: str):
-        """The basic polling loop for listened to kernel messages
-        on a ZMQ socket.
-        """
-        # Wire up the ZMQ sockets
-        # Setup up ZMQSocket broadcasting.
-        channel = getattr(self, f"{channel_name}_channel")
-        while True:
-            # Wait for a message
-            await channel.socket.poll(timeout=float("inf"))
-            raw_msg = await channel.socket.recv_multipart()
-            # Drop identities and delimit from the message parts.
-            _, fed_msg_list = self.session.feed_identities(raw_msg)
-            msg = fed_msg_list
-            try:
-                await self.handle_outgoing_message(channel_name, msg)
-            except Exception as err:
-                self.log.error(err)
-
-    async def send_message_to_listeners(self, channel_name: str, msg: list[bytes]):
-        """
-        Sends message to all registered listeners.
-        """
-        async with anyio.create_task_group() as tg:
-            # Broadcast the message to all listeners.
-            for listener in self._listeners:
-                async def _wrap_listener(listener_to_wrap, channel_name, msg):
-                    """
-                    Wrap the listener to ensure its async and
-                    logs (instead of raises) exceptions.
- """ - try: - await ensure_async(listener_to_wrap(channel_name, msg)) - except Exception as err: - self.log.error(err) - - tg.start_soon(_wrap_listener, listener, channel_name, msg) - - async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]): - """This is the main method that consumes every - message coming back from the kernel. It parses the header - (not the content, which might be large) and updates - the last_activity, execution_state, and lifecycsle_state - when appropriate. Then, it routes the message - to all listeners. - """ - if channel_name in ('iopub', 'shell'): - msg = await self.handle_document_related_message(msg) - # If msg has been cleared by the handler, escape this method. - if msg is None: - return - - await self.send_message_to_listeners(channel_name, msg) - - async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optional[t.List[bytes]]: - """ - Processes document-related messages received from a Jupyter kernel. - - Messages are deserialized and handled based on their type. Supported message types - include updating language info, kernel status, execution state, execution count, - and various output types. Some messages may be processed by an output processor - before deciding whether to forward them. - - Returns the original message if it is not processed further, otherwise None to indicate - that the message should not be forwarded. - """ - # Begin to deserialize the message safely within a try-except block - try: - dmsg = self.session.deserialize(msg, content=False) - except Exception as e: - self.log.error(f"Error deserializing message: {e}") - raise - - # Safely get parent message ID and data - parent_header = dmsg.get("parent_header", {}) - parent_msg_id = parent_header.get("msg_id") - - # Get parent message data from cache (may be None if not found) - parent_msg_data = self.message_cache.get(parent_msg_id) if parent_msg_id else None - - # Safely extract cell_id - cell_id = parent_msg_data.get('cell_id') if parent_msg_data else None - - # Handle different message types using pattern matching - match dmsg["msg_type"]: - case "kernel_info_reply": - # Unpack the content to extract language info - content = self.session.unpack(dmsg["content"]) - language_info = content["language_info"] - # Update the language info metadata for each collaborative room - for yroom in self._yrooms: - notebook = await yroom.get_jupyter_ydoc() - # The metadata ydoc is not exposed as a - # public property. - metadata = notebook.ymeta - metadata["metadata"]["language_info"] = language_info - - case "status": - # Handle kernel status messages and update cell execution states - # This provides real-time feedback about cell execution progress - content = self.session.unpack(dmsg["content"]) - execution_state = content.get("execution_state") - - # Update status across all collaborative rooms - for yroom in self._yrooms: - awareness = yroom.get_awareness() - if awareness is not None: - # If this status came from the shell channel, update - # the notebook kernel status. 
-                        if parent_msg_data and parent_msg_data.get("channel") == "shell":
-                            # Update the kernel execution state at the top document level
-                            awareness.set_local_state_field("kernel", {"execution_state": execution_state})
-
-                # Store cell execution state for persistence across client connections
-                # This ensures that cell execution states survive page refreshes
-                if cell_id:
-                    for yroom in self._yrooms:
-                        yroom.set_cell_execution_state(cell_id, execution_state)
-                        yroom.set_cell_awareness_state(cell_id, execution_state)
-                        break
-
-            case "execute_input":
-                if cell_id:
-                    # Extract execution count and update each collaborative room's notebook
-                    content = self.session.unpack(dmsg["content"])
-                    execution_count = content["execution_count"]
-                    for yroom in self._yrooms:
-                        notebook = await yroom.get_jupyter_ydoc()
-                        _, target_cell = notebook.find_cell(cell_id)
-                        if target_cell:
-                            target_cell["execution_count"] = execution_count
-                            break
-
-            case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output":
-                if cell_id:
-                    # Process specific output messages through an optional processor
-                    if self.output_processor:
-                        content = self.session.unpack(dmsg["content"])
-                        self.output_processor.process_output(dmsg['msg_type'], cell_id, content)
-
-                    # Suppress forwarding of processed messages by returning None
-                    return None
-
-        # Default return if message is processed and does not need forwarding
-        return msg
-
-    async def add_yroom(self, yroom: YRoom):
-        """
-        Register a YRoom with this kernel client. YRooms will
-        intercept display and kernel status messages.
-        """
-        self._yrooms.add(yroom)
-
-    async def remove_yroom(self, yroom: YRoom):
-        """
-        De-register a YRoom from handling kernel client messages.
-        """
-        self._yrooms.discard(yroom)
-
-
-AbstractDocumentAwareKernelClient.register(DocumentAwareKernelClient)
diff --git a/jupyter_server_documents/kernels/kernel_client_abc.py b/jupyter_server_documents/kernels/kernel_client_abc.py
deleted file mode 100644
index ecb705c..0000000
--- a/jupyter_server_documents/kernels/kernel_client_abc.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import typing as t
-from abc import ABC, abstractmethod
-
-from jupyter_server_documents.rooms.yroom import YRoom
-
-
-class AbstractKernelClient(ABC):
-
-    @abstractmethod
-    async def start_listening(self):
-        ...
-
-    @abstractmethod
-    async def stop_listening(self):
-        ...
-
-    @abstractmethod
-    def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
-        ...
-
-    @abstractmethod
-    async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
-        ...
-
-    @abstractmethod
-    def add_listener(self, callback: t.Callable[[str, list[bytes]], None]):
-        ...
-
-    @abstractmethod
-    def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
-        ...
-
-
-class AbstractDocumentAwareKernelClient(AbstractKernelClient):
-
-    @abstractmethod
-    async def add_yroom(self, yroom: YRoom):
-        ...
-
-    @abstractmethod
-    async def remove_yroom(self, yroom: YRoom):
-        ...
\ No newline at end of file
diff --git a/jupyter_server_documents/kernels/kernel_manager.py b/jupyter_server_documents/kernels/kernel_manager.py
deleted file mode 100644
index a59e443..0000000
--- a/jupyter_server_documents/kernels/kernel_manager.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import typing
-import asyncio
-from traitlets import default
-from traitlets import Instance
-from traitlets import Int
-from traitlets import Dict
-from traitlets import Type
-from traitlets import Unicode
-from traitlets import validate
-from traitlets import observe
-from traitlets import Set
-from traitlets import TraitError
-from traitlets import DottedObjectName
-from traitlets.utils.importstring import import_item
-
-from jupyter_client.manager import AsyncKernelManager
-
-# from . import types
-from .states import ExecutionStates, LifecycleStates
-from .kernel_client import AsyncKernelClient
-
-
-class NextGenKernelManager(AsyncKernelManager):
-
-    main_client = Instance(AsyncKernelClient, allow_none=True)
-
-    client_class = DottedObjectName(
-        "jupyter_server_documents.kernels.kernel_client.DocumentAwareKernelClient"
-    )
-
-    client_factory: Type = Type(klass="jupyter_server_documents.kernels.kernel_client.DocumentAwareKernelClient")
-
-    connection_attempts: int = Int(
-        default_value=10,
-        help="The number of initial heartbeat attempts once the kernel is alive. Each attempt is 1 second apart."
-    ).tag(config=True)
-
-    execution_state: ExecutionStates = Unicode()
-
-    @validate("execution_state")
-    def _validate_execution_state(self, proposal: dict):
-        value = proposal["value"]
-        if type(value) == ExecutionStates:
-            # Extract the enum value.
-            value = value.value
-        if not value in ExecutionStates:
-            raise TraitError(f"execution_state must be one of {ExecutionStates}")
-        return value
-
-    lifecycle_state: LifecycleStates = Unicode()
-
-    @validate("lifecycle_state")
-    def _validate_lifecycle_state(self, proposal: dict):
-        value = proposal["value"]
-        if type(value) == LifecycleStates:
-            # Extract the enum value.
-            value = value.value
-        if not value in LifecycleStates:
-            raise TraitError(f"lifecycle_state must be one of {LifecycleStates}")
-        return value
-
-    def set_state(
-        self,
-        lifecycle_state: LifecycleStates = None,
-        execution_state: ExecutionStates = None,
-    ):
-        if lifecycle_state:
-            self.lifecycle_state = lifecycle_state.value
-        if execution_state:
-            self.execution_state = execution_state.value
-
-    async def start_kernel(self, *args, **kwargs):
-        self.set_state(LifecycleStates.STARTING, ExecutionStates.STARTING)
-        out = await super().start_kernel(*args, **kwargs)
-        self.set_state(LifecycleStates.STARTED)
-        # Schedule the kernel to connect.
-        # Do not await here, since many clients expect
-        # the server to complete the start flow even
-        # if the kernel is not fully connected yet.
-        task = asyncio.create_task(self.connect())
-        return out
-
-    async def shutdown_kernel(self, *args, **kwargs):
-        self.set_state(LifecycleStates.TERMINATING)
-        await self.disconnect()
-        out = await super().shutdown_kernel(*args, **kwargs)
-        self.set_state(LifecycleStates.TERMINATED, ExecutionStates.DEAD)
-
-    async def restart_kernel(self, *args, **kwargs):
-        self.set_state(LifecycleStates.RESTARTING)
-        return await super().restart_kernel(*args, **kwargs)
-
-    async def connect(self):
-        """Open a single client interface to the kernel.
-
-        Ideally this method doesn't care if the kernel
-        is actually started. It will just try a ZMQ
-        connection anyways and wait. This is helpful for
-        handling 'pending' kernels, which might still
-        be in a starting phase. We can keep a connection
-        open regardless if the kernel is ready.
-        """
-        # Use the new API for getting a client.
-        self.main_client = self.client()
-        # Track execution state by watching all messages that come through
-        # the kernel client.
-        self.main_client.add_listener(self.execution_state_listener)
-        self.set_state(LifecycleStates.CONNECTING, ExecutionStates.STARTING)
-        await self.broadcast_state()
-        self.main_client.start_channels()
-        await self.main_client.start_listening()
-        # The Heartbeat channel is paused by default; unpause it here
-        self.main_client.hb_channel.unpause()
-        # Wait for a living heartbeat.
-        attempt = 0
-        while not self.main_client.hb_channel.is_alive():
-            attempt += 1
-            if attempt > self.connection_attempts:
-                # Set the state to unknown.
-                self.set_state(LifecycleStates.UNKNOWN, ExecutionStates.UNKNOWN)
-                raise Exception("The kernel took too long to connect to the ZMQ sockets.")
-            # Wait a second until the next time we try again.
-            await asyncio.sleep(0.5)
-        # Wait for the kernel to reach an idle state.
-        while self.execution_state != ExecutionStates.IDLE.value:
-            self.main_client.send_kernel_info()
-            await asyncio.sleep(0.1)
-
-    async def disconnect(self):
-        if self.main_client:
-            await self.main_client.stop_listening()
-            self.main_client.stop_channels()
-
-    async def broadcast_state(self):
-        """Broadcast state to all listeners"""
-        if not self.main_client:
-            return
-
-        # Manufacture an IOPub status message from the shell channel.
-        session = self.main_client.session
-        parent_header = session.msg_header("status")
-        parent_msg_id = parent_header["msg_id"]
-        self.main_client.message_cache.add({
-            "msg_id": parent_msg_id,
-            "channel": "shell",
-            "cellId": None
-        })
-        msg = session.msg("status", content={"execution_state": self.execution_state}, parent=parent_header)
-        smsg = session.serialize(msg)[1:]
-        await self.main_client.handle_outgoing_message("iopub", smsg)
-
-    def execution_state_listener(self, channel_name: str, msg: list[bytes]):
-        """Set the execution state by watching messages returned by the shell channel."""
-        # Only continue if we're on the IOPub where the status is published.
-        if channel_name != "iopub":
-            return
-
-        session = self.main_client.session
-        # Unpack the message
-        deserialized_msg = session.deserialize(msg, content=False)
-        if deserialized_msg["msg_type"] == "status":
-            content = session.unpack(deserialized_msg["content"])
-            execution_state = content["execution_state"]
-            if execution_state == "starting":
-                # Don't broadcast, since this message is already going out.
-                self.set_state(execution_state=ExecutionStates.STARTING)
-            else:
-                parent = deserialized_msg.get("parent_header", {})
-                msg_id = parent.get("msg_id", "")
-                message_data = self.main_client.message_cache.get(msg_id)
-                if message_data is None:
-                    return
-                parent_channel = message_data.get("channel")
-                if parent_channel and parent_channel == "shell":
-                    self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))
\ No newline at end of file
diff --git a/jupyter_server_documents/kernels/message_cache.py b/jupyter_server_documents/kernels/message_cache.py
deleted file mode 100644
index b31ba0b..0000000
--- a/jupyter_server_documents/kernels/message_cache.py
+++ /dev/null
@@ -1,226 +0,0 @@
-import json
-from collections import OrderedDict
-from traitlets import Dict, Instance, Int
-from traitlets.config import LoggingConfigurable
-
-
-class MissingKeyException(Exception):
-    """An exception when a dictionary is missing a required key."""
-
-class InvalidKeyException(Exception):
-    """An exception when the key doesn't match msg_id property in value"""
-
-
-class KernelMessageCache(LoggingConfigurable):
-    """
-    A cache for storing kernel messages, optimized for access by message ID and cell ID.
-
-    The cache uses an OrderedDict for message IDs to maintain insertion order and
-    implement LRU eviction. Messages are also indexed by cell ID for faster
-    retrieval when the cell ID is known.
-
-    Attributes:
-        _by_cell_id (dict): A dictionary mapping cell IDs to message data.
-        _by_msg_id (OrderedDict): An OrderedDict mapping message IDs to message data,
-            maintaining insertion order for LRU eviction.
-        maxsize (int): The maximum number of messages to store in the cache.
-    """
-
-    _by_cell_id = Dict({})
-    _by_msg_id = Instance(OrderedDict, default_value=OrderedDict())
-    maxsize = Int(default_value=10000).tag(config=True)
-
-
-    def __repr__(self):
-        """
-        Returns a JSON string representation of the message ID cache.
-        """
-        return json.dumps(self._by_msg_id, indent=2)
-
-    def __getitem__(self, msg_id):
-        """
-        Retrieves a message from the cache by message ID. Moves the accessed
-        message to the end of the OrderedDict to update its access time.
-
-        Args:
-            msg_id (str): The message ID.
-
-        Returns:
-            dict: The message data.
-
-        Raises:
-            KeyError: If the message ID is not found in the cache.
-        """
-        out = self._by_msg_id[msg_id]
-        self._by_msg_id.move_to_end(msg_id)
-        return out
-
-    def __setitem__(self, msg_id, value):
-        """
-        Adds a message to the cache. If the cache is full, the least recently
-        used message is evicted.
-
-        Args:
-            msg_id (str): The message ID.
-            value (dict): The message data.
-
-        Raises:
-            Exception: If the msg_id does not match the message ID in the value,
-                or if the message data is missing required fields
-                ("msg_id", "channel").
- """ - if "msg_id" not in value: - raise MissingKeyException("`msg_id` missing in message data") - - if "channel" not in value: - raise MissingKeyException("`channel` missing in message data") - - if value["msg_id"] != msg_id: - raise InvalidKeyException("Key must match `msg_id` in value") - - # Remove the existing msg_id if a new msg with same cell_id exists - if value["channel"] == "shell" and "cell_id" in value and value["cell_id"] in self._by_cell_id: - existing_msg_id = self._by_cell_id[value["cell_id"]]["msg_id"] - if msg_id != existing_msg_id: - del self._by_msg_id[existing_msg_id] - - if "cell_id" in value and value['cell_id'] is not None: - self._by_cell_id[value['cell_id']] = value - - self._by_msg_id[msg_id] = value - if len(self._by_msg_id) > self.maxsize: - self._remove_oldest() - - def _remove_oldest(self): - """ - Removes the least recently used message from the cache. - """ - try: - key, item = self._by_msg_id.popitem(last=False) - if 'cell_id' in item: # Check if 'cell_id' key exists - try: - del self._by_cell_id[item['cell_id']] - except KeyError: - pass # Handle the case where the cell_id is not present - except KeyError: - pass # Handle the case where the cache is empty - - def __delitem__(self, msg_id): - """ - Removes a message from the cache by message ID. - - Args: - msg_id (str): The message ID. - """ - msg_data = self._by_msg_id[msg_id] - try: - cell_id = msg_data["cell_id"] - del self._by_cell_id[cell_id] - except KeyError: - pass - del self._by_msg_id[msg_id] - - def __contains__(self, msg_id): - """ - Checks if a message with the given message ID is in the cache. - - Args: - msg_id (str): The message ID. - - Returns: - bool: True if the message is in the cache, False otherwise. - """ - return msg_id in self._by_msg_id - - def __iter__(self): - """ - Returns an iterator over the message IDs in the cache. - """ - for msg_id in self._by_msg_id: - yield msg_id - - def __len__(self): - """ - Returns the number of messages in the cache. - """ - return len(self._by_msg_id) - - def add(self, data): - """ - Adds a message to the cache using its message ID as the key. - - Args: - data (dict): The message data. - """ - self[data['msg_id']] = data - - def get(self, msg_id=None, cell_id=None): - """ - Retrieves a message from the cache, either by message ID or cell ID. - - Args: - msg_id (str, optional): The message ID. Defaults to None. - cell_id (str, optional): The cell ID. Defaults to None. - - Returns: - dict: The message data, or None if not found. - """ - try: - out = self._by_cell_id[cell_id] - msg_id = out['msg_id'] - self._by_msg_id.move_to_end(msg_id) - return out - except KeyError: - try: - out = self._by_msg_id[msg_id] - self._by_msg_id.move_to_end(msg_id) - return out - except KeyError: - return None - - def remove(self, msg_id=None, cell_id=None): - """ - Removes a message from the cache, either by message ID or cell ID. - - Args: - msg_id (str, optional): The message ID. Defaults to None. - cell_id (str, optional): The cell ID. Defaults to None. - """ - try: - out = self._by_cell_id[cell_id] - msg_id = out['msg_id'] - del self._by_msg_id[msg_id] - del self._by_cell_id[cell_id] - except KeyError: - try: - out = self._by_msg_id[msg_id] - try: - cell_id = out['cell_id'] - del self._by_cell_id[cell_id] - except KeyError: - pass - finally: - del self._by_msg_id[msg_id] - except KeyError: - return - - def pop(self, msg_id=None, cell_id=None): - """ - Removes and returns a message from the cache, either by message ID or cell ID. 
-
-        Args:
-            msg_id (str, optional): The message ID. Defaults to None.
-            cell_id (str, optional): The cell ID. Defaults to None.
-
-        Returns:
-            dict: The message data.
-
-        Raises:
-            KeyError: If the message ID or cell ID is not found.
-        """
-        try:
-            out = self._by_cell_id[cell_id]
-        except KeyError:
-            out = self._by_msg_id[msg_id]
-        self.remove(msg_id=out['msg_id'])
-        return out
\ No newline at end of file
diff --git a/jupyter_server_documents/kernels/multi_kernel_manager.py b/jupyter_server_documents/kernels/multi_kernel_manager.py
deleted file mode 100644
index bc26806..0000000
--- a/jupyter_server_documents/kernels/multi_kernel_manager.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
-
-
-class NextGenMappingKernelManager(AsyncMappingKernelManager):
-
-    def start_watching_activity(self, kernel_id):
-        pass
-
-    def stop_buffering(self, kernel_id):
-        pass
-
-    # NOTE: Since we disable watching activity and buffering here,
-    # this method needs to be forked and remove code related to these things.
-    async def restart_kernel(self, kernel_id, now=False):
-        """Restart a kernel by kernel_id"""
-        self._check_kernel_id(kernel_id)
-        await self.pinned_superclass._async_restart_kernel(self, kernel_id, now=now)
\ No newline at end of file
diff --git a/jupyter_server_documents/kernels/states.py b/jupyter_server_documents/kernels/states.py
deleted file mode 100644
index 5adf563..0000000
--- a/jupyter_server_documents/kernels/states.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from enum import Enum
-from enum import EnumMeta
-
-class StrContainerEnumMeta(EnumMeta):
-    def __contains__(cls, item):
-        for name, member in cls.__members__.items():
-            if item == name or item == member.value:
-                return True
-        return False
-class StrContainerEnum(str, Enum, metaclass=StrContainerEnumMeta):
-    """A Enum object that enables search for items
-    in a normal Enum object based on key and value.
-    """
-
-class LifecycleStates(StrContainerEnum):
-    UNKNOWN = "unknown"
-    STARTING = "starting"
-    STARTED = "started"
-    TERMINATING = "terminating"
-    CONNECTING = "connecting"
-    CONNECTED = "connected"
-    RESTARTING = "restarting"
-    RECONNECTING = "reconnecting"
-    CULLED = "culled"
-    DISCONNECTED = "disconnected"
-    TERMINATED = "terminated"
-    DEAD = "dead"
-
-
-class ExecutionStates(StrContainerEnum):
-    BUSY = "busy"
-    IDLE = "idle"
-    STARTING = "starting"
-    UNKNOWN = "unknown"
-    DEAD = "dead"
\ No newline at end of file
diff --git a/jupyter_server_documents/kernels/websocket_connection.py b/jupyter_server_documents/kernels/websocket_connection.py
deleted file mode 100644
index 075d2cb..0000000
--- a/jupyter_server_documents/kernels/websocket_connection.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from tornado.websocket import WebSocketClosedError
-from jupyter_server.services.kernels.connection.base import (
-    BaseKernelWebsocketConnection,
-)
-from .states import LifecycleStates
-from jupyter_server.services.kernels.connection.base import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1
-
-class NextGenKernelWebsocketConnection(BaseKernelWebsocketConnection):
-    """A websocket client that connects to a kernel manager.
-
-    NOTE: This connection only works with the (newer) v1 websocket protocol.
-    https://jupyter-server.readthedocs.io/en/latest/developers/websocket-protocols.html
-    """
-
-    kernel_ws_protocol = "v1.kernel.websocket.jupyter.org"
-
-    async def connect(self):
-        """A synchronous method for connecting to the kernel via a kernel session.
-        This connection might take a few minutes, so we turn this into an
-        asyncio task happening in parallel.
-        """
-        self.kernel_manager.main_client.add_listener(self.handle_outgoing_message)
-        await self.kernel_manager.broadcast_state()
-        self.log.info("Kernel websocket is now listening to kernel.")
-
-    def disconnect(self):
-        self.kernel_manager.main_client.remove_listener(self.handle_outgoing_message)
-
-    def handle_incoming_message(self, incoming_msg):
-        """Handle the incoming WS message"""
-        channel_name, msg_list = deserialize_msg_from_ws_v1(incoming_msg)
-        if self.kernel_manager.main_client:
-            self.kernel_manager.main_client.handle_incoming_message(channel_name, msg_list)
-
-    def handle_outgoing_message(self, channel_name, msg):
-        """Handle the ZMQ message."""
-        try:
-            # Remove signature from message to be compatible with Jupyter Server.
-            # See here: https://github.com/jupyter-server/jupyter_server/blob/4ee6e1ddc058f87b2c5699cd6b9caf780a013044/jupyter_server/services/kernels/connection/channels.py#L504
-            msg = msg[1:]
-            msg = serialize_msg_to_ws_v1(msg, channel_name)
-            self.websocket_handler.write_message(msg, binary=True)
-        except WebSocketClosedError:
-            self.log.warning("A ZMQ message arrived on a closed websocket channel.")
-        except Exception as err:
-            self.log.error(err)
\ No newline at end of file
""" + + def task_done_callback(task): + try: + task.result() # This will raise any exception that occurred + except Exception as e: + self.log.error(f"Error in output task: {e}", exc_info=True) + if msg_type == "clear_output": - asyncio.create_task(self.clear_output_task(cell_id, content)) + task = asyncio.create_task(self.clear_output_task(cell_id, content)) + task.add_done_callback(task_done_callback) else: - asyncio.create_task(self.output_task(msg_type, cell_id, content)) - + task = asyncio.create_task(self.output_task(msg_type, cell_id, content)) + task.add_done_callback(task_done_callback) + return None # Don't allow the original message to propagate to the frontend async def clear_output_task(self, cell_id, content): @@ -163,7 +170,6 @@ async def output_task(self, msg_type, cell_id, content): target_cell["outputs"][output_index] = output else: target_cell["outputs"].append(output) - self.log.info(f"Wrote output to ydoc: {path} {cell_id} {output}") def transform_output(self, msg_type, content, ydoc=False): diff --git a/jupyter_server_documents/rooms/yroom_file_api.py b/jupyter_server_documents/rooms/yroom_file_api.py index 6bd7c2d..06c39a5 100644 --- a/jupyter_server_documents/rooms/yroom_file_api.py +++ b/jupyter_server_documents/rooms/yroom_file_api.py @@ -546,7 +546,7 @@ async def save(self, jupyter_ydoc: YBaseDoc): # Set most recent `last_modified` timestamp if file_data['last_modified']: - self.log.info(f"Reseting last_modified to {file_data['last_modified']}") + self.log.debug(f"Resetting last_modified to {file_data['last_modified']}") self._last_modified = file_data['last_modified'] # Set `dirty` to `False` to hide the "unsaved changes" icon in the diff --git a/jupyter_server_documents/session_manager.py b/jupyter_server_documents/session_manager.py index 473a0cd..eb3672b 100644 --- a/jupyter_server_documents/session_manager.py +++ b/jupyter_server_documents/session_manager.py @@ -5,7 +5,7 @@ from jupyter_server_fileid.manager import BaseFileIdManager from jupyter_server_documents.rooms.yroom_manager import YRoomManager from jupyter_server_documents.rooms.yroom import YRoom -from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient +from jupyter_server_documents.kernel_client import DocumentAwareKernelClient class YDocSessionManager(SessionManager): @@ -39,12 +39,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._room_ids = {} - def get_kernel_client(self, kernel_id: str) -> DocumentAwareKernelClient: - """Get the kernel client for a running kernel.""" - kernel_manager = self.kernel_manager.get_kernel(kernel_id) - kernel_client = kernel_manager.main_client - return kernel_client - def get_yroom(self, session_id: str) -> YRoom: """ Get the `YRoom` for a session given its ID. The session must have @@ -70,9 +64,136 @@ def _init_session_yroom(self, session_id: str, path: str) -> YRoom: room_id = f"json:notebook:{file_id}" yroom = self.yroom_manager.get_room(room_id) self._room_ids[session_id] = room_id - return yroom + async def _ensure_yroom_connected(self, session_id: str, kernel_id: str) -> None: + """ + Ensures that a session's yroom is connected to its kernel client. + + This method is critical for maintaining the connection between collaborative + document state (yroom) and kernel execution state. 
diff --git a/jupyter_server_documents/rooms/yroom_file_api.py b/jupyter_server_documents/rooms/yroom_file_api.py
index 6bd7c2d..06c39a5 100644
--- a/jupyter_server_documents/rooms/yroom_file_api.py
+++ b/jupyter_server_documents/rooms/yroom_file_api.py
@@ -546,7 +546,7 @@ async def save(self, jupyter_ydoc: YBaseDoc):
 
         # Set most recent `last_modified` timestamp
         if file_data['last_modified']:
-            self.log.info(f"Reseting last_modified to {file_data['last_modified']}")
+            self.log.debug(f"Resetting last_modified to {file_data['last_modified']}")
             self._last_modified = file_data['last_modified']
 
         # Set `dirty` to `False` to hide the "unsaved changes" icon in the
diff --git a/jupyter_server_documents/session_manager.py b/jupyter_server_documents/session_manager.py
index 473a0cd..eb3672b 100644
--- a/jupyter_server_documents/session_manager.py
+++ b/jupyter_server_documents/session_manager.py
@@ -5,7 +5,7 @@
 from jupyter_server_fileid.manager import BaseFileIdManager
 from jupyter_server_documents.rooms.yroom_manager import YRoomManager
 from jupyter_server_documents.rooms.yroom import YRoom
-from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient
+from jupyter_server_documents.kernel_client import DocumentAwareKernelClient
 
 
 class YDocSessionManager(SessionManager):
@@ -39,12 +39,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._room_ids = {}
 
-    def get_kernel_client(self, kernel_id: str) -> DocumentAwareKernelClient:
-        """Get the kernel client for a running kernel."""
-        kernel_manager = self.kernel_manager.get_kernel(kernel_id)
-        kernel_client = kernel_manager.main_client
-        return kernel_client
-
     def get_yroom(self, session_id: str) -> YRoom:
        """
        Get the `YRoom` for a session given its ID. The session must have
@@ -70,9 +64,136 @@ def _init_session_yroom(self, session_id: str, path: str) -> YRoom:
         room_id = f"json:notebook:{file_id}"
         yroom = self.yroom_manager.get_room(room_id)
         self._room_ids[session_id] = room_id
-
         return yroom
 
+    async def _ensure_yroom_connected(self, session_id: str, kernel_id: str) -> None:
+        """
+        Ensures that a session's yroom is connected to its kernel client.
+
+        This method is critical for maintaining the connection between collaborative
+        document state (yroom) and kernel execution state. It handles scenarios where
+        the yroom-kernel connection may have been lost, such as:
+
+        - Server restarts where sessions persist but in-memory connections are lost
+        - Remote/persistent kernels that survive across server lifecycles
+        - Recovery from transient failures or race conditions during session setup
+
+        The method is idempotent - it checks if the yroom is already connected before
+        attempting to add it, preventing duplicate connections.
+
+        Args:
+            session_id: The unique identifier for the session
+            kernel_id: The unique identifier for the kernel
+
+        Note:
+            This method silently handles cases where the yroom or kernel don't exist,
+            or where the session has no associated yroom. Failures are logged but
+            don't raise exceptions.
+        """
+        # Check if this session has an associated yroom in the cache
+        room_id = self._room_ids.get(session_id)
+
+        # If not cached, populate it from the session's path
+        # This handles persistent sessions that survive server restarts
+        if not room_id:
+            try:
+                # Get the session from the database to find its path
+                # Use super() to avoid infinite recursion since we're called from get_session
+                session = await super().get_session(session_id=session_id)
+                if session and session.get("type") == "notebook":
+                    path = session.get("path")
+                    if path:
+                        # Use the same logic as _init_session_yroom to calculate room_id
+                        file_id = self.file_id_manager.index(path)
+                        room_id = f"json:notebook:{file_id}"
+                        # Cache it for future calls
+                        self._room_ids[session_id] = room_id
+                        self.log.debug(f"Populated room_id {room_id} from session path for session {session_id}")
+                    else:
+                        self.log.debug(f"Session {session_id} has no path")
+                        return
+                else:
+                    self.log.debug(f"Session {session_id} is not a notebook type")
+                    return
+            except Exception as e:
+                self.log.warning(f"Failed to lookup session {session_id}: {e}")
+                return
+
+        if not room_id:
+            # Session has no yroom (e.g., console session or non-notebook type)
+            return
+
+        # Get the yroom if it exists
+        yroom = self.yroom_manager.get_room(room_id)
+        if not yroom:
+            # Room doesn't exist yet or was cleaned up
+            return
+
+        # Ensure the yroom is added to the kernel client
+        try:
+            kernel_manager = self.serverapp.kernel_manager.get_kernel(kernel_id)
+            kernel_client = kernel_manager.kernel_client
+
+            # Check if yroom is already connected to avoid duplicate connections
+            if hasattr(kernel_client, '_yrooms') and yroom not in kernel_client._yrooms:
+                await kernel_client.add_yroom(yroom)
+                self.log.info(
+                    f"Reconnected yroom {room_id} to kernel_client for session {session_id}. "
+                    f"This ensures kernel messages are routed to the collaborative document."
+                )
+        except Exception as e:
+            # Log but don't fail - the session is still valid even if yroom connection fails
+            self.log.warning(
+                f"Failed to connect yroom to kernel_client for session {session_id}: {e}"
+            )
+
+    async def get_session(self, **kwargs) -> Optional[dict[str, Any]]:
+        """
+        Retrieves a session and ensures the yroom-kernel connection is established.
+
+        This override of the parent's get_session() adds a critical step: verifying
+        and restoring the connection between the session's yroom (collaborative state)
+        and its kernel client (execution engine).
+
+        Why this matters:
+        - When reconnecting to persistent/remote kernels, the in-memory yroom connection
+          may not exist even though both the session and kernel are valid
+        - Server restarts can break yroom-kernel connections while sessions persist
+        - This ensures that every time a session is retrieved, it's fully functional
+          for collaborative notebook editing and execution
+
+        Args:
+            **kwargs: Arguments passed to the parent's get_session() method
+                (e.g., session_id, path, kernel_id)
+
+        Returns:
+            The session model dict, or None if no session is found
+        """
+        session = await super().get_session(**kwargs)
+
+        # If no session found, return None
+        if session is None:
+            return None
+
+        # Extract session and kernel information
+        session_id = session.get("id")
+        kernel_info = session.get("kernel")
+
+        # Only process sessions with valid kernel and session ID
+        if not kernel_info or not session_id:
+            return session
+
+        kernel_id = kernel_info.get("id")
+        if not kernel_id:
+            return session
+
+        # Ensure the yroom is connected to the kernel client
+        # This is especially important for persistent kernels that survive server restarts
+        await self._ensure_yroom_connected(session_id, kernel_id)
+
+        return session
+
+
     async def create_session(
         self,
         path: Optional[str] = None,
@@ -83,7 +204,34 @@
     ) -> dict[str, Any]:
         """
         After creating a session, connects the yroom to the kernel client.
+        Sets kernel status to "starting" before kernel launch.
         """
+        # For notebooks, set up the YRoom and set initial status before starting kernel
+        should_setup_yroom = (
+            type == "notebook" and
+            name is not None and
+            path is not None
+        )
+
+        yroom = None
+        if should_setup_yroom:
+            # Calculate the real path
+            real_path = os.path.join(os.path.split(path)[0], name)
+
+            # Initialize the YRoom before starting the kernel
+            file_id = self.file_id_manager.index(real_path)
+            room_id = f"json:notebook:{file_id}"
+            yroom = self.yroom_manager.get_room(room_id)
+
+            # Set initial kernel status to "starting" in awareness
+            awareness = yroom.get_awareness()
+            if awareness is not None:
+                self.log.info("Setting kernel execution_state to 'starting' before kernel launch")
+                awareness.set_local_state_field(
+                    "kernel", {"execution_state": "starting"}
+                )
+
+        # Now create the session and start the kernel
         session_model = await super().create_session(
             path,
             name,
@@ -108,32 +256,23 @@
             self.log.warning(f"`name` or `path` was not given for new session at '{path}'.")
             return session_model
 
-        # Otherwise, get a `YRoom` and add it to this session's kernel client.
-
-        # When JupyterLab creates a session, it uses a fake path
-        # which is the relative path + UUID, i.e. the notebook
-        # name is incorrect temporarily. It later makes multiple
-        # updates to the session to correct the path.
-        #
-        # Here, we create the true path to store in the fileID service
-        # by dropping the UUID and appending the file name.
-        real_path = os.path.join(os.path.split(path)[0], name)
-
-        # Get YRoom for this session and store its ID in `self._room_ids`
-        yroom = self._init_session_yroom(session_id, real_path)
-
+        # Store the room ID for this session
+        if yroom:
+            self._room_ids[session_id] = yroom.room_id
+        else:
+            # Shouldn't happen, but handle it anyway
+            real_path = os.path.join(os.path.split(path)[0], name)
+            yroom = self._init_session_yroom(session_id, real_path)
+
         # Add YRoom to this session's kernel client
-        # TODO: we likely have a race condition here... need to
-        # think about it more. Currently, the kernel client gets
-        # created after the kernel starts fully. We need the
-        # kernel client instantiated _before_ trying to connect
-        # the yroom.
-        kernel_client = self.get_kernel_client(kernel_id)
+        # Ensure the kernel client is fully connected before proceeding
+        # to avoid queuing messages on first execution
+        kernel_manager = self.serverapp.kernel_manager.get_kernel(kernel_id)
+        kernel_client = kernel_manager.kernel_client
         await kernel_client.add_yroom(yroom)
         self.log.info(f"Connected yroom {yroom.room_id} to kernel {kernel_id}. yroom: {yroom}")
 
         return session_model
 
-
     async def update_session(self, session_id: str, **update) -> None:
         """
         Updates the session identified by `session_id` using the keyword
@@ -158,10 +297,12 @@
         )
         yroom = self.get_yroom(session_id)
         if old_kernel_id:
-            old_kernel_client = self.get_kernel_client(old_kernel_id)
+            old_kernel_manager = self.serverapp.kernel_manager.get_kernel(old_kernel_id)
+            old_kernel_client = old_kernel_manager.kernel_client
             await old_kernel_client.remove_yroom(yroom=yroom)
         if new_kernel_id:
-            new_kernel_client = self.get_kernel_client(new_kernel_id)
+            new_kernel_manager = self.serverapp.kernel_manager.get_kernel(new_kernel_id)
+            new_kernel_client = new_kernel_manager.kernel_client
             await new_kernel_client.add_yroom(yroom=yroom)
 
         # Apply update and return
@@ -177,7 +318,8 @@ async def delete_session(self, session_id):
 
         # Remove YRoom from session's kernel client
         yroom = self.get_yroom(session_id)
-        kernel_client = self.get_kernel_client(kernel_id)
+        kernel_manager = self.serverapp.kernel_manager.get_kernel(kernel_id)
+        kernel_client = kernel_manager.kernel_client
         await kernel_client.remove_yroom(yroom)
 
         # Remove room ID stored for the session
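Both `_init_session_yroom` and `_ensure_yroom_connected` above derive room IDs the same way, so the convention is worth spelling out. A sketch, assuming `file_id_manager` and `yroom_manager` are the instances held by the session manager and the notebook path is hypothetical:

    # Room IDs follow "<format>:<type>:<file_id>"; the file ID is stable across
    # renames, so the collaborative room survives a notebook being moved.
    file_id = file_id_manager.index("notebooks/analysis.ipynb")  # hypothetical path
    room_id = f"json:notebook:{file_id}"
    yroom = yroom_manager.get_room(room_id)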
"""Test cases for DocumentAwareKernelClient.""" - - def test_default_message_cache(self): - """Test that message cache is created by default.""" - client = DocumentAwareKernelClient() - assert isinstance(client.message_cache, KernelMessageCache) - - def test_default_output_processor(self): - """Test that output processor is created by default.""" - client = DocumentAwareKernelClient() - assert isinstance(client.output_processor, OutputProcessor) - - @pytest.mark.asyncio - async def test_stop_listening_no_task(self): - """Test that stop_listening does nothing when no task exists.""" - client = DocumentAwareKernelClient() - client._listening_task = None - - # Should not raise an exception - await client.stop_listening() - - def test_add_listener(self): - """Test adding a listener.""" - client = DocumentAwareKernelClient() - - def test_listener(channel, msg): - pass - - client.add_listener(test_listener) - - assert test_listener in client._listeners - - def test_remove_listener(self): - """Test removing a listener.""" - client = DocumentAwareKernelClient() - - def test_listener(channel, msg): - pass - - client.add_listener(test_listener) - client.remove_listener(test_listener) - - assert test_listener not in client._listeners - - @pytest.mark.asyncio - async def test_add_yroom(self): - """Test adding a YRoom.""" - client = DocumentAwareKernelClient() - - mock_yroom = MagicMock() - await client.add_yroom(mock_yroom) - - assert mock_yroom in client._yrooms - - @pytest.mark.asyncio - async def test_remove_yroom(self): - """Test removing a YRoom.""" - client = DocumentAwareKernelClient() - - mock_yroom = MagicMock() - client._yrooms.add(mock_yroom) - - await client.remove_yroom(mock_yroom) - - assert mock_yroom not in client._yrooms - - def test_send_kernel_info_creates_message(self): - """Test that send_kernel_info creates a kernel info message.""" - client = DocumentAwareKernelClient() - - # Mock session - from jupyter_client.session import Session - client.session = Session() - - with patch.object(client, 'handle_incoming_message') as mock_handle: - client.send_kernel_info() - - # Verify that handle_incoming_message was called with shell channel - mock_handle.assert_called_once() - args, kwargs = mock_handle.call_args - assert args[0] == "shell" # Channel name - assert isinstance(args[1], list) # Message list - - @pytest.mark.asyncio - async def test_handle_outgoing_message_control_channel(self): - """Test that control channel messages bypass document handling.""" - client = DocumentAwareKernelClient() - - msg = [b"test", b"message"] - - with patch.object(client, 'handle_document_related_message') as mock_handle_doc: - with patch.object(client, 'send_message_to_listeners') as mock_send: - await client.handle_outgoing_message("control", msg) - - mock_handle_doc.assert_not_called() - mock_send.assert_called_once_with("control", msg) \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_kernel_client_integration.py b/jupyter_server_documents/tests/kernels/test_kernel_client_integration.py deleted file mode 100644 index fb25bd2..0000000 --- a/jupyter_server_documents/tests/kernels/test_kernel_client_integration.py +++ /dev/null @@ -1,490 +0,0 @@ -import pytest -import asyncio -import json -from unittest.mock import MagicMock, AsyncMock, patch -from jupyter_client.session import Session -from jupyter_server_documents.ydocs import YNotebook -import pycrdt - -from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient -from 
jupyter_server_documents.rooms.yroom import YRoom -from jupyter_server_documents.outputs import OutputProcessor - - -class TestDocumentAwareKernelClientIntegration: - """Integration tests for DocumentAwareKernelClient with YDoc updates.""" - - @pytest.fixture - def mock_yroom_with_notebook(self): - """Create a mock YRoom with a real YNotebook.""" - # Create a real YDoc and YNotebook - ydoc = pycrdt.Doc() - awareness = MagicMock(spec=pycrdt.Awareness) # Mock awareness instead of using real one - - # Mock the local state to track changes - local_state = {} - awareness.get_local_state = MagicMock(return_value=local_state) - - # Mock set_local_state_field to actually update the local_state dict - def mock_set_local_state_field(field, value): - local_state[field] = value - - awareness.set_local_state_field = MagicMock(side_effect=mock_set_local_state_field) - - ynotebook = YNotebook(ydoc, awareness) - - # Add a simple notebook structure with one cell - ynotebook.set({ - "cells": [ - { - "cell_type": "code", - "id": "test-cell-1", - "source": "2 + 2", - "metadata": {}, - "outputs": [], - "execution_count": None - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.9.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 - }) - - # Create mock YRoom - yroom = MagicMock(spec=YRoom) - yroom.get_jupyter_ydoc = AsyncMock(return_value=ynotebook) - yroom.get_awareness = MagicMock(return_value=awareness) - - # Add persistent cell execution state storage - yroom._cell_execution_states = {} - - def mock_get_cell_execution_states(): - return yroom._cell_execution_states - - def mock_set_cell_execution_state(cell_id, execution_state): - yroom._cell_execution_states[cell_id] = execution_state - - yroom.get_cell_execution_states = MagicMock(side_effect=mock_get_cell_execution_states) - yroom.set_cell_execution_state = MagicMock(side_effect=mock_set_cell_execution_state) - - # Add awareness cell execution state management - def mock_set_cell_awareness_state(cell_id, execution_state): - current_local_state = awareness.get_local_state() - if current_local_state is None: - current_local_state = local_state - cell_states = current_local_state.get("cell_execution_states", {}) - cell_states[cell_id] = execution_state - awareness.set_local_state_field("cell_execution_states", cell_states) - - yroom.set_cell_awareness_state = MagicMock(side_effect=mock_set_cell_awareness_state) - - return yroom, ynotebook - - @pytest.fixture - def kernel_client_with_yroom(self, mock_yroom_with_notebook): - """Create a DocumentAwareKernelClient with a real YRoom and YNotebook.""" - yroom, ynotebook = mock_yroom_with_notebook - - client = DocumentAwareKernelClient() - client.session = Session() - client.log = MagicMock() - - # Add the YRoom to the client - client._yrooms = {yroom} - - # Mock output processor - client.output_processor = MagicMock(spec=OutputProcessor) - client.output_processor.process_output = MagicMock() - - return client, yroom, ynotebook - - def create_kernel_message(self, session, msg_type, content, parent_msg_id=None, cell_id=None): - """Helper to create properly formatted kernel messages.""" - parent_header = {"msg_id": parent_msg_id} if parent_msg_id else {} - metadata = {"cellId": cell_id} if cell_id else {} - - msg = session.msg(msg_type, content, parent=parent_header, metadata=metadata) - return session.serialize(msg) - - @pytest.mark.asyncio - async def 
test_execute_input_updates_execution_count(self, kernel_client_with_yroom): - """Test that execute_input messages update execution count in YDoc.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return cell_id - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) - - # Create execute_input message - content = {"code": "2 + 2", "execution_count": 1} - msg_parts = self.create_kernel_message( - client.session, "execute_input", content, parent_msg_id, cell_id - ) - - # Process the message - await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter - - # Verify the execution count was updated in the YDoc - cells = ynotebook.get_cell_list() - target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None) - assert target_cell is not None - assert target_cell.get("execution_count") == 1 - - @pytest.mark.asyncio - async def test_status_message_updates_cell_execution_state(self, kernel_client_with_yroom): - """Test that status messages update cell execution state in YRoom for persistence and awareness for real-time updates.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return cell_id and channel - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - client.message_cache.get = MagicMock(return_value={ - "cell_id": cell_id, - "channel": "shell" - }) - - # Create status message with 'busy' state - content = {"execution_state": "busy"} - msg_parts = self.create_kernel_message( - client.session, "status", content, parent_msg_id, cell_id - ) - - # Process the message - await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter - - # Verify the cell execution state was stored in YRoom for persistence - cell_states = yroom.get_cell_execution_states() - assert cell_states[cell_id] == "busy" - - # Verify it's also in awareness for real-time updates - awareness = yroom.get_awareness() - local_state = awareness.get_local_state() - assert local_state is not None - assert "cell_execution_states" in local_state - assert local_state["cell_execution_states"][cell_id] == "busy" - - @pytest.mark.asyncio - async def test_kernel_info_reply_updates_language_info(self, kernel_client_with_yroom): - """Test that kernel_info_reply updates language info in YDoc metadata.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache - parent_msg_id = "kernel-info-request-123" - client.message_cache.get = MagicMock(return_value={"cell_id": None}) - - # Create kernel_info_reply message - content = { - "language_info": { - "name": "python", - "version": "3.9.0", - "mimetype": "text/x-python", - "file_extension": ".py" - } - } - msg_parts = self.create_kernel_message( - client.session, "kernel_info_reply", content, parent_msg_id - ) - - # Process the message - await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter - - # Verify language info was updated in notebook metadata - metadata = ynotebook.get_meta() - assert "language_info" in metadata["metadata"] - assert metadata["metadata"]["language_info"]["name"] == "python" - assert metadata["metadata"]["language_info"]["version"] == "3.9.0" - - @pytest.mark.asyncio - async def test_output_message_processed_and_suppressed(self, kernel_client_with_yroom): - """Test that output messages are processed by output processor and suppressed.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to 
return cell_id - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) - - # Create execute_result message (output) - content = { - "data": {"text/plain": "4"}, - "metadata": {}, - "execution_count": 1 - } - msg_parts = self.create_kernel_message( - client.session, "execute_result", content, parent_msg_id, cell_id - ) - - # Process the message - result = await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter - - # Verify the output processor was called - client.output_processor.process_output.assert_called_once_with( - "execute_result", cell_id, content - ) - - # Verify the message was suppressed (returned None) - assert result is None - - @pytest.mark.asyncio - async def test_stream_output_message_processed(self, kernel_client_with_yroom): - """Test that stream output messages are processed correctly.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return cell_id - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) - - # Create stream message - content = { - "name": "stdout", - "text": "4\n" - } - msg_parts = self.create_kernel_message( - client.session, "stream", content, parent_msg_id, cell_id - ) - - # Process the message - result = await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter - - # Verify the output processor was called - client.output_processor.process_output.assert_called_once_with( - "stream", cell_id, content - ) - - # Verify the message was suppressed - assert result is None - - @pytest.mark.asyncio - async def test_error_output_message_processed(self, kernel_client_with_yroom): - """Test that error output messages are processed correctly.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return cell_id - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) - - # Create error message - content = { - "ename": "NameError", - "evalue": "name 'x' is not defined", - "traceback": ["Traceback (most recent call last):", "NameError: name 'x' is not defined"] - } - msg_parts = self.create_kernel_message( - client.session, "error", content, parent_msg_id, cell_id - ) - - # Process the message - result = await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter - - # Verify the output processor was called - client.output_processor.process_output.assert_called_once_with( - "error", cell_id, content - ) - - # Verify the message was suppressed - assert result is None - - @pytest.mark.asyncio - async def test_complete_execution_flow(self, kernel_client_with_yroom): - """Test complete execution flow: execute_input -> status -> output -> status.""" - client, yroom, ynotebook = kernel_client_with_yroom - - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - - # Mock message cache to return cell_id and channel - client.message_cache.get = MagicMock(return_value={ - "cell_id": cell_id, - "channel": "shell" - }) - - # Step 1: Execute input - execute_input_content = {"code": "2 + 2", "execution_count": 1} - msg_parts = self.create_kernel_message( - client.session, "execute_input", execute_input_content, parent_msg_id, cell_id - ) - await client.handle_document_related_message(msg_parts[1:]) - - # Step 2: Status busy - status_busy_content = {"execution_state": "busy"} - msg_parts = 
self.create_kernel_message( - client.session, "status", status_busy_content, parent_msg_id, cell_id - ) - await client.handle_document_related_message(msg_parts[1:]) - - # Step 3: Execute result - result_content = { - "data": {"text/plain": "4"}, - "metadata": {}, - "execution_count": 1 - } - msg_parts = self.create_kernel_message( - client.session, "execute_result", result_content, parent_msg_id, cell_id - ) - await client.handle_document_related_message(msg_parts[1:]) - - # Step 4: Status idle - status_idle_content = {"execution_state": "idle"} - msg_parts = self.create_kernel_message( - client.session, "status", status_idle_content, parent_msg_id, cell_id - ) - await client.handle_document_related_message(msg_parts[1:]) - - # Verify final state of the cell in YDoc and awareness - cells = ynotebook.get_cell_list() - target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None) - assert target_cell is not None - assert target_cell.get("execution_count") == 1 - - # Verify execution state is stored in awareness, not YDoc - awareness = yroom.get_awareness() - cell_execution_states = awareness.get_local_state().get("cell_execution_states", {}) - assert cell_execution_states.get(cell_id) == "idle" - - # Verify output processor was called for the result - client.output_processor.process_output.assert_called_with( - "execute_result", cell_id, result_content - ) - - @pytest.mark.asyncio - async def test_awareness_state_updates_for_kernel_status(self, kernel_client_with_yroom): - """Test that kernel status updates awareness state.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return shell channel (for notebook-level status) - parent_msg_id = "kernel-info-request-123" - client.message_cache.get = MagicMock(return_value={ - "cell_id": None, - "channel": "shell" - }) - - # Create status message for kernel-level state - content = {"execution_state": "busy"} - msg_parts = self.create_kernel_message( - client.session, "status", content, parent_msg_id - ) - - # Process the message - await client.handle_document_related_message(msg_parts[1:]) - - # Verify awareness was updated - awareness = yroom.get_awareness() - awareness.set_local_state_field.assert_called_once_with( - "kernel", {"execution_state": "busy"} - ) - - @pytest.mark.asyncio - async def test_multiple_cells_execution_states(self, kernel_client_with_yroom): - """Test that multiple cells can have different execution states.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Add another cell to the notebook - cells = ynotebook.get_cell_list() - ynotebook.append_cell({ - "cell_type": "code", - "id": "test-cell-2", - "source": "print('hello')", - "metadata": {}, - "outputs": [], - "execution_count": None - }) - - # Mock message cache to return different cell_ids - def mock_get(msg_id): - if msg_id == "execute-request-123": - return {"cell_id": "test-cell-1", "channel": "shell"} - elif msg_id == "execute-request-456": - return {"cell_id": "test-cell-2", "channel": "shell"} - return None - - client.message_cache.get = MagicMock(side_effect=mock_get) - - # Set first cell to busy - content1 = {"execution_state": "busy"} - msg_parts1 = self.create_kernel_message( - client.session, "status", content1, "execute-request-123", "test-cell-1" - ) - await client.handle_document_related_message(msg_parts1[1:]) - - # Set second cell to idle - content2 = {"execution_state": "idle"} - msg_parts2 = self.create_kernel_message( - client.session, "status", content2, "execute-request-456", 
"test-cell-2" - ) - await client.handle_document_related_message(msg_parts2[1:]) - - # Verify both cells have correct states in awareness - awareness = yroom.get_awareness() - cell_execution_states = awareness.get_local_state().get("cell_execution_states", {}) - - assert cell_execution_states.get("test-cell-1") == "busy" # 'busy' state - assert cell_execution_states.get("test-cell-2") == "idle" - - @pytest.mark.asyncio - async def test_message_without_cell_id_skips_cell_updates(self, kernel_client_with_yroom): - """Test that messages without cell_id don't update cell-specific data.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return no cell_id - parent_msg_id = "some-request-123" - client.message_cache.get = MagicMock(return_value={"cell_id": None}) - - # Create execute_input message without cell_id - content = {"code": "2 + 2", "execution_count": 1} - msg_parts = self.create_kernel_message( - client.session, "execute_input", content, parent_msg_id - ) - - # Process the message - await client.handle_document_related_message(msg_parts[1:]) - - # Verify no cell was updated (execution_count should remain None) - cells = ynotebook.get_cell_list() - for cell in cells: - assert cell.get("execution_count") is None - - @pytest.mark.asyncio - async def test_display_data_message_processing(self, kernel_client_with_yroom): - """Test that display_data messages are processed correctly.""" - client, yroom, ynotebook = kernel_client_with_yroom - - # Mock message cache to return cell_id - parent_msg_id = "execute-request-123" - cell_id = "test-cell-1" - client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) - - # Create display_data message - content = { - "data": { - "text/plain": "Hello World", - "text/html": "
<b>Hello World</b>
" - }, - "metadata": {} - } - msg_parts = self.create_kernel_message( - client.session, "display_data", content, parent_msg_id, cell_id - ) - - # Process the message - result = await client.handle_document_related_message(msg_parts[1:]) - - # Verify the output processor was called - client.output_processor.process_output.assert_called_once_with( - "display_data", cell_id, content - ) - - # Verify the message was suppressed - assert result is None \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_kernel_manager.py b/jupyter_server_documents/tests/kernels/test_kernel_manager.py deleted file mode 100644 index b6c7d48..0000000 --- a/jupyter_server_documents/tests/kernels/test_kernel_manager.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -from unittest.mock import patch - -from jupyter_server_documents.kernels.kernel_manager import NextGenKernelManager -from jupyter_server_documents.kernels.states import ExecutionStates, LifecycleStates - - -class TestNextGenKernelManager: - """Test cases for NextGenKernelManager.""" - - def test_set_state_lifecycle_only(self): - """Test setting only lifecycle state.""" - km = NextGenKernelManager() - km.set_state(LifecycleStates.STARTING) - assert km.lifecycle_state == LifecycleStates.STARTING.value - - def test_set_state_execution_only(self): - """Test setting only execution state.""" - km = NextGenKernelManager() - km.set_state(execution_state=ExecutionStates.IDLE) - assert km.execution_state == ExecutionStates.IDLE.value - - def test_set_state_both(self): - """Test setting both lifecycle and execution states.""" - km = NextGenKernelManager() - km.set_state(LifecycleStates.CONNECTED, ExecutionStates.BUSY) - assert km.lifecycle_state == LifecycleStates.CONNECTED.value - assert km.execution_state == ExecutionStates.BUSY.value - - def test_lifecycle_state_validation(self): - """Test lifecycle state validation.""" - km = NextGenKernelManager() - with pytest.raises(Exception): - km.lifecycle_state = "invalid_state" - - def test_execution_state_validation(self): - """Test execution state validation.""" - km = NextGenKernelManager() - with pytest.raises(Exception): - km.execution_state = "invalid_state" - - def test_execution_state_listener_non_iopub_channel(self): - """Test execution state listener ignores non-iopub channels.""" - km = NextGenKernelManager() - original_state = km.execution_state - - km.execution_state_listener("shell", [b"test", b"message"]) - - # State should remain unchanged - assert km.execution_state == original_state - - @pytest.mark.asyncio - async def test_disconnect_without_client(self): - """Test disconnecting when no client exists.""" - km = NextGenKernelManager() - km.main_client = None - - # Should not raise an exception - await km.disconnect() - - @pytest.mark.asyncio - async def test_restart_kernel_sets_state(self): - """Test that restart_kernel sets restarting state.""" - km = NextGenKernelManager() - - with patch('jupyter_client.manager.AsyncKernelManager.restart_kernel') as mock_restart: - mock_restart.return_value = None - await km.restart_kernel() - - assert km.lifecycle_state == LifecycleStates.RESTARTING.value - mock_restart.assert_called_once() \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_multi_kernel_manager.py b/jupyter_server_documents/tests/kernels/test_multi_kernel_manager.py deleted file mode 100644 index 6cfd872..0000000 --- a/jupyter_server_documents/tests/kernels/test_multi_kernel_manager.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest 
-from unittest.mock import AsyncMock, MagicMock, patch - -from jupyter_server_documents.kernels.multi_kernel_manager import NextGenMappingKernelManager - - -@pytest.fixture -def multi_kernel_manager(): - """Create a NextGenMappingKernelManager instance for testing.""" - mkm = NextGenMappingKernelManager() - mkm._check_kernel_id = MagicMock() - mkm.pinned_superclass = MagicMock() - mkm.pinned_superclass._async_restart_kernel = AsyncMock() - return mkm - - -class TestNextGenMappingKernelManager: - """Test cases for NextGenMappingKernelManager.""" - - def test_start_watching_activity_noop(self, multi_kernel_manager): - """Test that start_watching_activity does nothing.""" - # Should not raise an exception - multi_kernel_manager.start_watching_activity("test-kernel-id") - - def test_stop_buffering_noop(self, multi_kernel_manager): - """Test that stop_buffering does nothing.""" - # Should not raise an exception - multi_kernel_manager.stop_buffering("test-kernel-id") - - @pytest.mark.asyncio - async def test_restart_kernel_checks_id(self, multi_kernel_manager): - """Test that restart_kernel checks kernel ID.""" - kernel_id = "test-kernel-id" - - await multi_kernel_manager.restart_kernel(kernel_id) - - multi_kernel_manager._check_kernel_id.assert_called_once_with(kernel_id) - - @pytest.mark.asyncio - async def test_restart_kernel_calls_superclass(self, multi_kernel_manager): - """Test that restart_kernel calls the superclass method.""" - kernel_id = "test-kernel-id" - - await multi_kernel_manager.restart_kernel(kernel_id, now=True) - - multi_kernel_manager.pinned_superclass._async_restart_kernel.assert_called_once_with( - multi_kernel_manager, kernel_id, now=True - ) - - @pytest.mark.asyncio - async def test_restart_kernel_default_now_parameter(self, multi_kernel_manager): - """Test that restart_kernel uses default now=False.""" - kernel_id = "test-kernel-id" - - await multi_kernel_manager.restart_kernel(kernel_id) - - multi_kernel_manager.pinned_superclass._async_restart_kernel.assert_called_once_with( - multi_kernel_manager, kernel_id, now=False - ) - - @pytest.mark.asyncio - async def test_restart_kernel_propagates_exceptions(self, multi_kernel_manager): - """Test that restart_kernel propagates exceptions from superclass.""" - kernel_id = "test-kernel-id" - test_exception = Exception("Test restart error") - multi_kernel_manager.pinned_superclass._async_restart_kernel.side_effect = test_exception - - with pytest.raises(Exception, match="Test restart error"): - await multi_kernel_manager.restart_kernel(kernel_id) - - @pytest.mark.asyncio - async def test_restart_kernel_propagates_id_check_exceptions(self, multi_kernel_manager): - """Test that restart_kernel propagates exceptions from kernel ID check.""" - kernel_id = "invalid-kernel-id" - test_exception = ValueError("Invalid kernel ID") - multi_kernel_manager._check_kernel_id.side_effect = test_exception - - with pytest.raises(ValueError, match="Invalid kernel ID"): - await multi_kernel_manager.restart_kernel(kernel_id) - - # Superclass method should not be called if ID check fails - multi_kernel_manager.pinned_superclass._async_restart_kernel.assert_not_called() \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_states.py b/jupyter_server_documents/tests/kernels/test_states.py deleted file mode 100644 index 3ca60d3..0000000 --- a/jupyter_server_documents/tests/kernels/test_states.py +++ /dev/null @@ -1,175 +0,0 @@ -import pytest - -from jupyter_server_documents.kernels.states import LifecycleStates, 
ExecutionStates, StrContainerEnum, StrContainerEnumMeta - - -class TestStrContainerEnumMeta: - """Test cases for StrContainerEnumMeta.""" - - def test_contains_by_name(self): - """Test that enum names are found with 'in' operator.""" - assert "IDLE" in ExecutionStates - assert "STARTED" in LifecycleStates - - def test_contains_by_value(self): - """Test that enum values are found with 'in' operator.""" - assert "idle" in ExecutionStates - assert "started" in LifecycleStates - - def test_contains_missing(self): - """Test that missing items are not found.""" - assert "MISSING" not in ExecutionStates - assert "missing" not in LifecycleStates - - -class TestStrContainerEnum: - """Test cases for StrContainerEnum base class.""" - - def test_is_string_subclass(self): - """Test that StrContainerEnum is a string subclass.""" - assert issubclass(StrContainerEnum, str) - - def test_enum_value_is_string(self): - """Test that enum values can be used as strings.""" - idle_state = ExecutionStates.IDLE - assert isinstance(idle_state, str) - assert idle_state == "idle" - assert idle_state.upper() == "IDLE" - - -class TestLifecycleStates: - """Test cases for LifecycleStates enum.""" - - def test_all_states_defined(self): - """Test that all expected lifecycle states are defined.""" - expected_states = [ - "UNKNOWN", "STARTING", "STARTED", "TERMINATING", "CONNECTING", - "CONNECTED", "RESTARTING", "RECONNECTING", "CULLED", - "DISCONNECTED", "TERMINATED", "DEAD" - ] - - for state in expected_states: - assert hasattr(LifecycleStates, state) - - def test_state_values(self): - """Test that state values are lowercase versions of names.""" - assert LifecycleStates.UNKNOWN.value == "unknown" - assert LifecycleStates.STARTING.value == "starting" - assert LifecycleStates.STARTED.value == "started" - assert LifecycleStates.TERMINATING.value == "terminating" - assert LifecycleStates.CONNECTING.value == "connecting" - assert LifecycleStates.CONNECTED.value == "connected" - assert LifecycleStates.RESTARTING.value == "restarting" - assert LifecycleStates.RECONNECTING.value == "reconnecting" - assert LifecycleStates.CULLED.value == "culled" - assert LifecycleStates.DISCONNECTED.value == "disconnected" - assert LifecycleStates.TERMINATED.value == "terminated" - assert LifecycleStates.DEAD.value == "dead" - - def test_state_equality(self): - """Test that states can be compared by value.""" - assert LifecycleStates.UNKNOWN == "unknown" - assert LifecycleStates.STARTING == "starting" - assert LifecycleStates.CONNECTED == "connected" - - def test_state_membership(self): - """Test state membership using 'in' operator.""" - assert "starting" in LifecycleStates - assert "STARTING" in LifecycleStates - assert "connected" in LifecycleStates - assert "CONNECTED" in LifecycleStates - assert "invalid_state" not in LifecycleStates - - def test_state_iteration(self): - """Test iterating over lifecycle states.""" - states = list(LifecycleStates) - assert len(states) == 12 # Total number of defined states - assert LifecycleStates.UNKNOWN in states - assert LifecycleStates.DEAD in states - - -class TestExecutionStates: - """Test cases for ExecutionStates enum.""" - - def test_all_states_defined(self): - """Test that all expected execution states are defined.""" - expected_states = ["BUSY", "IDLE", "STARTING", "UNKNOWN", "DEAD"] - - for state in expected_states: - assert hasattr(ExecutionStates, state) - - def test_state_values(self): - """Test that state values are lowercase versions of names.""" - assert ExecutionStates.BUSY.value == "busy" - 
assert ExecutionStates.IDLE.value == "idle" - assert ExecutionStates.STARTING.value == "starting" - assert ExecutionStates.UNKNOWN.value == "unknown" - assert ExecutionStates.DEAD.value == "dead" - - def test_state_equality(self): - """Test that states can be compared by value.""" - assert ExecutionStates.BUSY == "busy" - assert ExecutionStates.IDLE == "idle" - assert ExecutionStates.STARTING == "starting" - assert ExecutionStates.UNKNOWN == "unknown" - assert ExecutionStates.DEAD == "dead" - - def test_state_membership(self): - """Test state membership using 'in' operator.""" - assert "busy" in ExecutionStates - assert "BUSY" in ExecutionStates - assert "idle" in ExecutionStates - assert "IDLE" in ExecutionStates - assert "invalid_state" not in ExecutionStates - - def test_state_iteration(self): - """Test iterating over execution states.""" - states = list(ExecutionStates) - assert len(states) == 5 # Total number of defined states - assert ExecutionStates.BUSY in states - assert ExecutionStates.IDLE in states - - def test_state_string_operations(self): - """Test that states can be used in string operations.""" - busy_state = ExecutionStates.BUSY - assert busy_state.upper() == "BUSY" - assert busy_state.capitalize() == "Busy" - assert len(busy_state) == 4 - assert busy_state.startswith("b") - - -class TestEnumIntegration: - """Integration tests for both enums.""" - - def test_enum_types_are_different(self): - """Test that the two enum types are distinct.""" - # Since both are StrContainerEnum subclasses, they compare as equal strings - # but they are different types - assert type(LifecycleStates.STARTING) != type(ExecutionStates.STARTING) - assert LifecycleStates.STARTING is not ExecutionStates.STARTING - - def test_enum_values_can_be_same(self): - """Test that enum values can be the same string.""" - # Both have "starting", "unknown", "dead" values - assert LifecycleStates.STARTING.value == ExecutionStates.STARTING.value == "starting" - assert LifecycleStates.UNKNOWN.value == ExecutionStates.UNKNOWN.value == "unknown" - assert LifecycleStates.DEAD.value == ExecutionStates.DEAD.value == "dead" - - def test_enum_members_are_unique_within_enum(self): - """Test that enum members are unique within their enum.""" - lifecycle_values = [state.value for state in LifecycleStates] - execution_values = [state.value for state in ExecutionStates] - - # Check for uniqueness within each enum - assert len(lifecycle_values) == len(set(lifecycle_values)) - assert len(execution_values) == len(set(execution_values)) - - def test_enum_membership_is_type_specific(self): - """Test that membership checks are type-specific.""" - # "idle" is in ExecutionStates but not in LifecycleStates - assert "idle" in ExecutionStates - assert "idle" not in LifecycleStates - - # "connected" is in LifecycleStates but not in ExecutionStates - assert "connected" in LifecycleStates - assert "connected" not in ExecutionStates \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_websocket_connection.py b/jupyter_server_documents/tests/kernels/test_websocket_connection.py deleted file mode 100644 index 3e73536..0000000 --- a/jupyter_server_documents/tests/kernels/test_websocket_connection.py +++ /dev/null @@ -1,124 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from tornado.websocket import WebSocketClosedError - -from jupyter_server_documents.kernels.websocket_connection import NextGenKernelWebsocketConnection - - -class TestNextGenKernelWebsocketConnection: - """Test cases for 
NextGenKernelWebsocketConnection.""" - - def test_kernel_ws_protocol(self): - """Test that the websocket protocol is set correctly.""" - assert NextGenKernelWebsocketConnection.kernel_ws_protocol == "v1.kernel.websocket.jupyter.org" - - def test_inheritance(self): - """Test that the class inherits from BaseKernelWebsocketConnection.""" - from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection - - assert issubclass(NextGenKernelWebsocketConnection, BaseKernelWebsocketConnection) - - # Test that required methods are implemented - conn = NextGenKernelWebsocketConnection() - assert hasattr(conn, 'connect') - assert hasattr(conn, 'disconnect') - assert hasattr(conn, 'handle_incoming_message') - assert hasattr(conn, 'handle_outgoing_message') - assert hasattr(conn, 'kernel_ws_protocol') - - @patch('jupyter_server_documents.kernels.websocket_connection.deserialize_msg_from_ws_v1') - def test_handle_incoming_message_deserializes(self, mock_deserialize): - """Test that incoming messages are deserialized correctly.""" - conn = NextGenKernelWebsocketConnection() - - # Mock the kernel_manager property - mock_kernel_manager = MagicMock() - mock_kernel_manager.main_client = MagicMock() - - with patch.object(type(conn), 'kernel_manager', mock_kernel_manager): - mock_deserialize.return_value = ("shell", [b"test", b"message"]) - - incoming_msg = b"test_websocket_message" - conn.handle_incoming_message(incoming_msg) - - mock_deserialize.assert_called_once_with(incoming_msg) - - @patch('jupyter_server_documents.kernels.websocket_connection.deserialize_msg_from_ws_v1') - def test_handle_incoming_message_no_client(self, mock_deserialize): - """Test that incoming messages are ignored when no client exists.""" - conn = NextGenKernelWebsocketConnection() - - # Mock the kernel_manager property with no client - mock_kernel_manager = MagicMock() - mock_kernel_manager.main_client = None - - with patch.object(type(conn), 'kernel_manager', mock_kernel_manager): - mock_deserialize.return_value = ("shell", [b"test", b"message"]) - - incoming_msg = b"test_websocket_message" - - # Should not raise an exception - conn.handle_incoming_message(incoming_msg) - - @patch('jupyter_server_documents.kernels.websocket_connection.serialize_msg_to_ws_v1') - def test_handle_outgoing_message_removes_signature(self, mock_serialize): - """Test that the signature is properly removed from outgoing messages.""" - conn = NextGenKernelWebsocketConnection() - - # Mock websocket_handler and log to avoid traitlet validation - mock_handler = MagicMock() - mock_log = MagicMock() - - with patch.object(type(conn), 'websocket_handler', mock_handler): - with patch.object(type(conn), 'log', mock_log): - mock_serialize.return_value = b"serialized_message" - - # Message with signature at index 0 - msg = [b"signature", b"header", b"parent", b"metadata", b"content"] - conn.handle_outgoing_message("iopub", msg) - - # Should call serialize with msg[1:] (signature removed) - mock_serialize.assert_called_once_with( - [b"header", b"parent", b"metadata", b"content"], "iopub" - ) - - @patch('jupyter_server_documents.kernels.websocket_connection.serialize_msg_to_ws_v1') - def test_handle_outgoing_message_websocket_closed(self, mock_serialize): - """Test that closed websocket errors are handled gracefully.""" - conn = NextGenKernelWebsocketConnection() - - mock_serialize.return_value = b"serialized_message" - - # Mock websocket_handler to raise WebSocketClosedError - mock_handler = MagicMock() - 
mock_handler.write_message.side_effect = WebSocketClosedError() - mock_log = MagicMock() - - with patch.object(type(conn), 'websocket_handler', mock_handler): - with patch.object(type(conn), 'log', mock_log): - msg = [b"signature", b"header", b"parent", b"metadata", b"content"] - conn.handle_outgoing_message("iopub", msg) - - mock_log.warning.assert_called_once_with( - "A ZMQ message arrived on a closed websocket channel." - ) - - @patch('jupyter_server_documents.kernels.websocket_connection.serialize_msg_to_ws_v1') - def test_handle_outgoing_message_general_exception(self, mock_serialize): - """Test that general exceptions are handled gracefully.""" - conn = NextGenKernelWebsocketConnection() - - mock_serialize.return_value = b"serialized_message" - test_exception = Exception("Test exception") - - # Mock websocket_handler to raise exception - mock_handler = MagicMock() - mock_handler.write_message.side_effect = test_exception - mock_log = MagicMock() - - with patch.object(type(conn), 'websocket_handler', mock_handler): - with patch.object(type(conn), 'log', mock_log): - msg = [b"signature", b"header", b"parent", b"metadata", b"content"] - conn.handle_outgoing_message("iopub", msg) - - mock_log.error.assert_called_once_with(test_exception) \ No newline at end of file diff --git a/jupyter_server_documents/tests/test_kernel_message_cache.py b/jupyter_server_documents/tests/test_kernel_message_cache.py index f89aadc..20f21bc 100644 --- a/jupyter_server_documents/tests/test_kernel_message_cache.py +++ b/jupyter_server_documents/tests/test_kernel_message_cache.py @@ -1,6 +1,6 @@ import pytest from collections import OrderedDict -from jupyter_server_documents.kernels.message_cache import InvalidKeyException, KernelMessageCache, MissingKeyException # Replace your_module +from nextgen_kernels_api.services.kernels.cache import InvalidKeyException, KernelMessageCache, MissingKeyException def create_cache(maxsize=None): diff --git a/jupyter_server_documents/tests/test_session_manager.py b/jupyter_server_documents/tests/test_session_manager.py new file mode 100644 index 0000000..4d42fe9 --- /dev/null +++ b/jupyter_server_documents/tests/test_session_manager.py @@ -0,0 +1,293 @@ +"""Tests for YDocSessionManager yroom-kernel connection logic. + +These tests verify that the session manager properly maintains connections +between yrooms (collaborative document state) and kernel clients, especially +for persistent kernels that survive server restarts. 
+""" +import pytest +from unittest.mock import AsyncMock, Mock, MagicMock, patch +from traitlets.config import LoggingConfigurable +from jupyter_server_documents.session_manager import YDocSessionManager + + +@pytest.fixture +def session_manager(): + """Create a mock session manager for testing.""" + # Create mock dependencies + mock_file_id_manager = Mock() + mock_yroom_manager = Mock() + mock_kernel_manager = Mock() + + # Create a Configurable parent with the proper structure + class MockServerApp(LoggingConfigurable): + @property + def kernel_manager(self): + return mock_kernel_manager + + @property + def web_app(self): + mock_web_app = Mock() + mock_web_app.settings = { + "file_id_manager": mock_file_id_manager, + "yroom_manager": mock_yroom_manager + } + return mock_web_app + + # Create the session manager with mock parent + manager = YDocSessionManager(parent=MockServerApp()) + + # Initialize the _room_ids dict + manager._room_ids = {} + + return manager + + +@pytest.fixture +def mock_kernel_client(): + """Create a mock kernel client with _yrooms attribute.""" + client = Mock() + client._yrooms = set() + return client + + +@pytest.fixture +def mock_yroom(): + """Create a mock YRoom.""" + yroom = Mock() + yroom.room_id = "json:notebook:test-file-id" + return yroom + + +class TestEnsureYRoomConnected: + """Tests for _ensure_yroom_connected method.""" + + @pytest.mark.asyncio + async def test_uses_cached_room_id(self, session_manager, mock_yroom, mock_kernel_client): + """Test that cached room_id is used when available.""" + session_id = "session-123" + kernel_id = "kernel-456" + room_id = "json:notebook:cached-file-id" + + # Set up cached room_id + session_manager._room_ids[session_id] = room_id + + # Mock yroom manager + session_manager.yroom_manager.get_room.return_value = mock_yroom + mock_yroom.room_id = room_id + + # Mock kernel client's add_yroom method as async + mock_kernel_client.add_yroom = AsyncMock() + + # Mock kernel manager + mock_kernel_manager = Mock() + mock_kernel_manager.kernel_client = mock_kernel_client + session_manager.serverapp.kernel_manager.get_kernel.return_value = mock_kernel_manager + + await session_manager._ensure_yroom_connected(session_id, kernel_id) + + # Verify cached room_id was used + session_manager.yroom_manager.get_room.assert_called_once_with(room_id) + + # Verify add_yroom was called with the mock yroom + mock_kernel_client.add_yroom.assert_called_once_with(mock_yroom) + + @pytest.mark.asyncio + async def test_reconstructs_room_id_from_session_path(self, session_manager, mock_yroom, mock_kernel_client): + """Test that room_id is reconstructed from session path when not cached.""" + session_id = "session-123" + kernel_id = "kernel-456" + path = "/path/to/notebook.ipynb" + file_id = "reconstructed-file-id" + room_id = f"json:notebook:{file_id}" + + # Mock get_session from parent (SessionManager) to return session with path + mock_session = { + "id": session_id, + "type": "notebook", + "path": path + } + + # Patch the parent class's get_session method + with patch('jupyter_server.services.sessions.sessionmanager.SessionManager.get_session', new_callable=AsyncMock) as mock_parent_get_session: + mock_parent_get_session.return_value = mock_session + + # Mock file_id_manager + session_manager.file_id_manager.index.return_value = file_id + + # Mock yroom manager + session_manager.yroom_manager.get_room.return_value = mock_yroom + mock_yroom.room_id = room_id + + # Mock kernel client's add_yroom as async + mock_kernel_client.add_yroom = AsyncMock() + + 
# Mock kernel manager + mock_kernel_manager = Mock() + mock_kernel_manager.kernel_client = mock_kernel_client + session_manager.serverapp.kernel_manager.get_kernel.return_value = mock_kernel_manager + + await session_manager._ensure_yroom_connected(session_id, kernel_id) + + # Verify room_id was reconstructed + session_manager.file_id_manager.index.assert_called_once_with(path) + + # Verify room_id was cached + assert session_manager._room_ids[session_id] == room_id + + # Verify add_yroom was called + mock_kernel_client.add_yroom.assert_called_once_with(mock_yroom) + + @pytest.mark.asyncio + async def test_skips_non_notebook_sessions(self, session_manager): + """Test that non-notebook sessions are skipped.""" + session_id = "session-123" + kernel_id = "kernel-456" + + # Mock get_session to return console session + mock_session = { + "id": session_id, + "type": "console", + "path": "/path/to/console" + } + + with patch('jupyter_server.services.sessions.sessionmanager.SessionManager.get_session', new_callable=AsyncMock) as mock_parent_get_session: + mock_parent_get_session.return_value = mock_session + + await session_manager._ensure_yroom_connected(session_id, kernel_id) + + # Verify no room_id was created + assert session_id not in session_manager._room_ids + + @pytest.mark.asyncio + async def test_skips_when_yroom_already_connected(self, session_manager, mock_yroom, mock_kernel_client): + """Test that already-connected yrooms are not re-added.""" + session_id = "session-123" + kernel_id = "kernel-456" + room_id = "json:notebook:test-file-id" + + # Set up cached room_id + session_manager._room_ids[session_id] = room_id + + # Mock yroom manager + session_manager.yroom_manager.get_room.return_value = mock_yroom + + # Yroom already in kernel client's _yrooms + mock_kernel_client._yrooms.add(mock_yroom) + + # Mock kernel manager + mock_kernel_manager = Mock() + mock_kernel_manager.kernel_client = mock_kernel_client + session_manager.serverapp.kernel_manager.get_kernel.return_value = mock_kernel_manager + + # Track initial state + initial_yrooms_count = len(mock_kernel_client._yrooms) + + await session_manager._ensure_yroom_connected(session_id, kernel_id) + + # Verify yroom was not added again (count unchanged) + assert len(mock_kernel_client._yrooms) == initial_yrooms_count + + @pytest.mark.asyncio + async def test_handles_missing_yroom_gracefully(self, session_manager): + """Test that missing yroom is handled gracefully without errors.""" + session_id = "session-123" + kernel_id = "kernel-456" + room_id = "json:notebook:missing-file-id" + + # Set up cached room_id + session_manager._room_ids[session_id] = room_id + + # Mock yroom manager to return None (yroom doesn't exist) + session_manager.yroom_manager.get_room.return_value = None + + # Should not raise an error + await session_manager._ensure_yroom_connected(session_id, kernel_id) + + @pytest.mark.asyncio + async def test_handles_kernel_client_without_yrooms_attribute(self, session_manager, mock_yroom): + """Test graceful handling when kernel client doesn't have _yrooms attribute.""" + session_id = "session-123" + kernel_id = "kernel-456" + room_id = "json:notebook:test-file-id" + + # Set up cached room_id + session_manager._room_ids[session_id] = room_id + + # Mock yroom manager + session_manager.yroom_manager.get_room.return_value = mock_yroom + + # Mock kernel client WITHOUT _yrooms attribute + mock_kernel_client = Mock(spec=[]) # Empty spec, no _yrooms + + # Mock kernel manager + mock_kernel_manager = Mock() + 
mock_kernel_manager.kernel_client = mock_kernel_client + session_manager.serverapp.kernel_manager.get_kernel.return_value = mock_kernel_manager + + # Should not raise an error + await session_manager._ensure_yroom_connected(session_id, kernel_id) + + +class TestGetSession: + """Tests for get_session method override.""" + + @pytest.mark.asyncio + async def test_calls_ensure_yroom_connected(self, session_manager): + """Test that get_session calls _ensure_yroom_connected for notebook sessions.""" + session_id = "session-123" + kernel_id = "kernel-456" + + mock_session = { + "id": session_id, + "kernel": {"id": kernel_id}, + "type": "notebook" + } + + # Patch the parent SessionManager's get_session + with patch('jupyter_server.services.sessions.sessionmanager.SessionManager.get_session', new_callable=AsyncMock) as mock_parent_get_session: + mock_parent_get_session.return_value = mock_session + + with patch.object(session_manager, '_ensure_yroom_connected', new_callable=AsyncMock) as mock_ensure: + result = await session_manager.get_session(session_id=session_id) + + # Verify _ensure_yroom_connected was called + mock_ensure.assert_called_once_with(session_id, kernel_id) + + # Verify session was returned + assert result == mock_session + + @pytest.mark.asyncio + async def test_handles_none_session(self, session_manager): + """Test that get_session handles None session gracefully.""" + with patch('jupyter_server.services.sessions.sessionmanager.SessionManager.get_session', new_callable=AsyncMock) as mock_parent_get_session: + mock_parent_get_session.return_value = None + + with patch.object(session_manager, '_ensure_yroom_connected', new_callable=AsyncMock) as mock_ensure: + result = await session_manager.get_session(session_id="missing-session") + + # Verify _ensure_yroom_connected was NOT called + mock_ensure.assert_not_called() + + # Verify None was returned + assert result is None + + @pytest.mark.asyncio + async def test_handles_session_without_kernel(self, session_manager): + """Test that get_session handles sessions without kernel gracefully.""" + mock_session = { + "id": "session-123", + "kernel": None, + "type": "notebook" + } + + with patch('jupyter_server.services.sessions.sessionmanager.SessionManager.get_session', new_callable=AsyncMock) as mock_parent_get_session: + mock_parent_get_session.return_value = mock_session + + with patch.object(session_manager, '_ensure_yroom_connected', new_callable=AsyncMock) as mock_ensure: + result = await session_manager.get_session(session_id="session-123") + + # Verify _ensure_yroom_connected was NOT called + mock_ensure.assert_not_called() + + # Verify session was returned + assert result == mock_session diff --git a/pyproject.toml b/pyproject.toml index ab8f111..9caa4c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "jupyter_server_fileid>=0.9.0,<0.10.0", "pycrdt>=0.12.0,<0.13.0", "jupyter_ydoc>=3.0.0,<4.0.0", + "nextgen-kernels-api>=0.9.0", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -96,3 +97,7 @@ before-build-python = ["jlpm clean:all"] [tool.check-wheel-contents] ignore = ["W002"] + +[tool.pytest.ini_options] +testpaths = ["jupyter_server_documents/tests"] +norecursedirs = ["repos", ".git", ".pixi", "node_modules"]
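
A note on the helper these new tests exercise: YDocSessionManager._ensure_yroom_connected is referenced throughout test_session_manager.py, but its body is not part of this patch. The following is a minimal sketch of the behavior the assertions above imply -- the SessionManager base class, the attribute names, and the control flow are inferred from the tests, not copied from the PR's actual implementation:

from jupyter_server.services.sessions.sessionmanager import SessionManager

class YDocSessionManager(SessionManager):  # excerpt for illustration only
    async def _ensure_yroom_connected(self, session_id: str, kernel_id: str) -> None:
        """Reconnect a session's YRoom to its kernel client if the link is missing."""
        # Prefer the room ID cached by create_session; after a server restart
        # the cache is empty, so rebuild it from the session path via the
        # file ID service.
        room_id = self._room_ids.get(session_id)
        if room_id is None:
            session = await super().get_session(session_id=session_id)
            if not session or session.get("type") != "notebook":
                return  # only notebook sessions carry collaborative rooms
            file_id = self.file_id_manager.index(session["path"])
            room_id = f"json:notebook:{file_id}"
            self._room_ids[session_id] = room_id

        yroom = self.yroom_manager.get_room(room_id)
        if yroom is None:
            return  # room never created or already disposed; nothing to connect

        kernel_client = self.serverapp.kernel_manager.get_kernel(kernel_id).kernel_client

        # Tolerate kernel clients that don't track yrooms, and never double-add.
        connected = getattr(kernel_client, "_yrooms", None)
        if connected is None or yroom in connected:
            return

        await kernel_client.add_yroom(yroom)

Because get_session awaits this helper whenever a retrieved session carries both a session ID and a kernel ID, the yroom-kernel link is re-established transparently on the next session lookup, which is exactly the persistent-kernel scenario the module docstring describes.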