Skip to content

Commit 71b2e31

Browse files
authored
Moving Cell Execution State to Document Awareness (#146)
* Add unit and integration tests for the kernel API * Move cell execution state to the document's awareness object * Address PR feedback: simplify conditional logic and add YRoom helper method - Combine three similar else blocks in _updatePrompt method into one - Use optional chaining to simplify null checking in _getCellExecutionStateFromAwareness - Add set_cell_awareness_state method to YRoom for better encapsulation - Update kernel client to use new YRoom method for setting cell states
1 parent 8fe77a1 commit 71b2e31

File tree

14 files changed

+1324
-42
lines changed

14 files changed

+1324
-42
lines changed

conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import pytest
22

3-
pytest_plugins = ("pytest_jupyter.jupyter_server", "jupyter_server.pytest_plugin")
3+
pytest_plugins = ("pytest_jupyter.jupyter_server", "jupyter_server.pytest_plugin", "pytest_asyncio")
4+
5+
6+
def pytest_configure(config):
7+
"""Configure pytest settings."""
8+
# Set asyncio fixture loop scope to function to avoid warnings
9+
config.option.asyncio_default_fixture_loop_scope = "function"
410

511

612
@pytest.fixture

jupyter_server_documents/kernels/kernel_client.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,41 @@ async def stop_listening(self):
100100
_listening_task: t.Optional[t.Awaitable] = Any(allow_none=True)
101101

102102
def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
103-
"""Use the given session to send the message."""
103+
"""
104+
Handle incoming kernel messages and set up immediate cell execution state tracking.
105+
106+
This method processes incoming kernel messages and caches them for response mapping.
107+
Importantly, it detects execute_request messages and immediately sets the corresponding
108+
cell state to 'busy' to provide real-time feedback for queued cell executions.
109+
110+
This ensures that when multiple cells are executed simultaneously, all queued cells
111+
show a '*' prompt immediately, not just the currently executing cell.
112+
113+
Args:
114+
channel_name: The kernel channel name (shell, iopub, etc.)
115+
msg: The raw kernel message as bytes
116+
"""
104117
# Cache the message ID and its socket name so that
105118
# any response message can be mapped back to the
106119
# source channel.
107120
header = self.session.unpack(msg[0])
108-
msg_id = header["msg_id"]
121+
msg_id = header["msg_id"]
122+
msg_type = header.get("msg_type")
109123
metadata = self.session.unpack(msg[2])
110124
cell_id = metadata.get("cellId")
111125

112-
# Clear cell outputs if cell is re-executedq
126+
# Clear cell outputs if cell is re-executed
113127
if cell_id:
114128
existing = self.message_cache.get(cell_id=cell_id)
115129
if existing and existing['msg_id'] != msg_id:
116130
asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id))
117131

132+
# IMPORTANT: Set cell to 'busy' immediately when execute_request is received
133+
# This ensures queued cells show '*' prompt even before kernel starts processing them
134+
if msg_type == "execute_request" and channel_name == "shell" and cell_id:
135+
for yroom in self._yrooms:
136+
yroom.set_cell_awareness_state(cell_id, "busy")
137+
118138
self.message_cache.add({
119139
"msg_id": msg_id,
120140
"channel": channel_name,
@@ -240,27 +260,27 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
240260
metadata["metadata"]["language_info"] = language_info
241261

242262
case "status":
243-
# Unpack cell-specific information and determine execution state
263+
# Handle kernel status messages and update cell execution states
264+
# This provides real-time feedback about cell execution progress
244265
content = self.session.unpack(dmsg["content"])
245266
execution_state = content.get("execution_state")
267+
246268
# Update status across all collaborative rooms
247269
for yroom in self._yrooms:
248-
# If this status came from the shell channel, update
249-
# the notebook status.
250-
if parent_msg_data["channel"] == "shell":
251-
awareness = yroom.get_awareness()
252-
if awareness is not None:
270+
awareness = yroom.get_awareness()
271+
if awareness is not None:
272+
# If this status came from the shell channel, update
273+
# the notebook kernel status.
274+
if parent_msg_data and parent_msg_data.get("channel") == "shell":
253275
# Update the kernel execution state at the top document level
254276
awareness.set_local_state_field("kernel", {"execution_state": execution_state})
255-
# Specifically update the running cell's execution state if cell_id is provided
256-
if cell_id:
257-
notebook = await yroom.get_jupyter_ydoc()
258-
_, target_cell = notebook.find_cell(cell_id)
259-
if target_cell:
260-
# Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation
261-
# https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678
262-
state = 'running' if execution_state == 'busy' else execution_state
263-
target_cell["execution_state"] = state
277+
278+
# Store cell execution state for persistence across client connections
279+
# This ensures that cell execution states survive page refreshes
280+
if cell_id:
281+
for yroom in self._yrooms:
282+
yroom.set_cell_execution_state(cell_id, execution_state)
283+
yroom.set_cell_awareness_state(cell_id, execution_state)
264284
break
265285

266286
case "execute_input":
@@ -278,8 +298,7 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
278298
case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output":
279299
if cell_id:
280300
# Process specific output messages through an optional processor
281-
if self.output_processor and cell_id:
282-
cell_id = parent_msg_data.get('cell_id')
301+
if self.output_processor:
283302
content = self.session.unpack(dmsg["content"])
284303
self.output_processor.process_output(dmsg['msg_type'], cell_id, content)
285304

jupyter_server_documents/rooms/yroom.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations # see PEP-563 for motivation behind this
2-
from typing import TYPE_CHECKING, cast
2+
from typing import TYPE_CHECKING, cast, Any
33
import asyncio
44
import uuid
55
import pycrdt
@@ -369,6 +369,38 @@ def get_awareness(self) -> pycrdt.Awareness:
369369
"""
370370
return self._awareness
371371

372+
def get_cell_execution_states(self) -> dict:
373+
"""
374+
Returns the persistent cell execution states for this room.
375+
These states survive client disconnections but are not saved to disk.
376+
"""
377+
if not hasattr(self, '_cell_execution_states'):
378+
self._cell_execution_states: dict[str, str] = {}
379+
return self._cell_execution_states
380+
381+
def set_cell_execution_state(self, cell_id: str, execution_state: str) -> None:
382+
"""
383+
Sets the execution state for a specific cell.
384+
This state persists across client disconnections.
385+
"""
386+
if not hasattr(self, '_cell_execution_states'):
387+
self._cell_execution_states = {}
388+
self._cell_execution_states[cell_id] = execution_state
389+
390+
def set_cell_awareness_state(self, cell_id: str, execution_state: str) -> None:
391+
"""
392+
Sets the execution state for a specific cell in the awareness system.
393+
This provides real-time updates to all connected clients.
394+
"""
395+
awareness = self.get_awareness()
396+
if awareness is not None:
397+
local_state = awareness.get_local_state()
398+
if local_state is not None:
399+
cell_states = local_state.get("cell_execution_states", {})
400+
else:
401+
cell_states = {}
402+
cell_states[cell_id] = execution_state
403+
awareness.set_local_state_field("cell_execution_states", cell_states)
372404

373405
def add_message(self, client_id: str, message: bytes) -> None:
374406
"""
@@ -512,7 +544,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
512544
return
513545

514546
self.clients.mark_synced(client_id)
515-
547+
516548
# Send SyncStep1 message
517549
try:
518550
assert isinstance(new_client.websocket, WebSocketHandler)

jupyter_server_documents/tests/kernels/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Configuration for kernel tests."""
2+
3+
import pytest
4+
from unittest.mock import MagicMock
5+
6+
7+
@pytest.fixture
8+
def mock_logger():
9+
"""Create a mock logger for testing."""
10+
return MagicMock()
11+
12+
13+
@pytest.fixture
14+
def mock_session():
15+
"""Create a mock session for testing."""
16+
session = MagicMock()
17+
session.msg_header.return_value = {"msg_id": "test-msg-id"}
18+
session.msg.return_value = {"test": "message"}
19+
session.serialize.return_value = ["", "serialized", "msg"]
20+
session.deserialize.return_value = {"msg_type": "test", "content": b"test"}
21+
session.unpack.return_value = {"test": "data"}
22+
session.feed_identities.return_value = ([], [b"test", b"message"])
23+
return session
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import pytest
2+
from unittest.mock import MagicMock, patch
3+
4+
from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient
5+
from jupyter_server_documents.kernels.message_cache import KernelMessageCache
6+
from jupyter_server_documents.outputs import OutputProcessor
7+
8+
9+
class TestDocumentAwareKernelClient:
10+
"""Test cases for DocumentAwareKernelClient."""
11+
12+
def test_default_message_cache(self):
13+
"""Test that message cache is created by default."""
14+
client = DocumentAwareKernelClient()
15+
assert isinstance(client.message_cache, KernelMessageCache)
16+
17+
def test_default_output_processor(self):
18+
"""Test that output processor is created by default."""
19+
client = DocumentAwareKernelClient()
20+
assert isinstance(client.output_processor, OutputProcessor)
21+
22+
@pytest.mark.asyncio
23+
async def test_stop_listening_no_task(self):
24+
"""Test that stop_listening does nothing when no task exists."""
25+
client = DocumentAwareKernelClient()
26+
client._listening_task = None
27+
28+
# Should not raise an exception
29+
await client.stop_listening()
30+
31+
def test_add_listener(self):
32+
"""Test adding a listener."""
33+
client = DocumentAwareKernelClient()
34+
35+
def test_listener(channel, msg):
36+
pass
37+
38+
client.add_listener(test_listener)
39+
40+
assert test_listener in client._listeners
41+
42+
def test_remove_listener(self):
43+
"""Test removing a listener."""
44+
client = DocumentAwareKernelClient()
45+
46+
def test_listener(channel, msg):
47+
pass
48+
49+
client.add_listener(test_listener)
50+
client.remove_listener(test_listener)
51+
52+
assert test_listener not in client._listeners
53+
54+
@pytest.mark.asyncio
55+
async def test_add_yroom(self):
56+
"""Test adding a YRoom."""
57+
client = DocumentAwareKernelClient()
58+
59+
mock_yroom = MagicMock()
60+
await client.add_yroom(mock_yroom)
61+
62+
assert mock_yroom in client._yrooms
63+
64+
@pytest.mark.asyncio
65+
async def test_remove_yroom(self):
66+
"""Test removing a YRoom."""
67+
client = DocumentAwareKernelClient()
68+
69+
mock_yroom = MagicMock()
70+
client._yrooms.add(mock_yroom)
71+
72+
await client.remove_yroom(mock_yroom)
73+
74+
assert mock_yroom not in client._yrooms
75+
76+
def test_send_kernel_info_creates_message(self):
77+
"""Test that send_kernel_info creates a kernel info message."""
78+
client = DocumentAwareKernelClient()
79+
80+
# Mock session
81+
from jupyter_client.session import Session
82+
client.session = Session()
83+
84+
with patch.object(client, 'handle_incoming_message') as mock_handle:
85+
client.send_kernel_info()
86+
87+
# Verify that handle_incoming_message was called with shell channel
88+
mock_handle.assert_called_once()
89+
args, kwargs = mock_handle.call_args
90+
assert args[0] == "shell" # Channel name
91+
assert isinstance(args[1], list) # Message list
92+
93+
@pytest.mark.asyncio
94+
async def test_handle_outgoing_message_control_channel(self):
95+
"""Test that control channel messages bypass document handling."""
96+
client = DocumentAwareKernelClient()
97+
98+
msg = [b"test", b"message"]
99+
100+
with patch.object(client, 'handle_document_related_message') as mock_handle_doc:
101+
with patch.object(client, 'send_message_to_listeners') as mock_send:
102+
await client.handle_outgoing_message("control", msg)
103+
104+
mock_handle_doc.assert_not_called()
105+
mock_send.assert_called_once_with("control", msg)

0 commit comments

Comments
 (0)