11from __future__ import annotations
22
33import json
4+ import os
45import signal
56import subprocess
67import sys
910from typing import cast
1011
1112import anyio
12- from anyio import TASK_STATUS_IGNORED , create_task_group , open_process
13+ from anyio import TASK_STATUS_IGNORED , create_task_group , open_file , open_process , run_process
1314from anyio .abc import TaskStatus
15+ from anyio .streams .text import TextReceiveStream
1416
1517from jupyverse_api .kernel import Kernel
1618
@@ -32,6 +34,7 @@ class KernelSubprocess(Kernel):
3234 kernel_cwd : str | None
3335 capture_output : bool
3436 connection_cfg : cfg_t | None = None
37+ kernelenv_path : str = ""
3538
3639 def __post_init__ (self ):
3740 super ().__init__ ()
@@ -45,6 +48,8 @@ def __post_init__(self):
4548 raise RuntimeError ("No connection_cfg" )
4649 self .key = cast (str , self .connection_cfg ["key" ])
4750 self .wait_for_ready = True
51+ self ._process = None
52+ self ._pid = None
4853
4954 async def start (self , * , task_status : TaskStatus [None ] = TASK_STATUS_IGNORED ) -> None :
5055 async with (
@@ -64,23 +69,48 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) ->
6469 self ._from_iopub_receive_stream ,
6570 create_task_group () as self .task_group ,
6671 ):
67- with open (self .kernelspec_path ) as f :
68- kernelspec = json .load (f )
69- cmd = [s .format (connection_file = self .connection_file ) for s in kernelspec ["argv" ]]
70- if cmd and cmd [0 ] in {
71- "python" ,
72- f"python{ sys .version_info [0 ]} " ,
73- "python" + "." .join (map (str , sys .version_info [:2 ])),
74- }:
75- cmd [0 ] = sys .executable
72+ async with await open_file (self .kernelspec_path ) as f :
73+ contents = await f .read ()
74+ kernelspec = json .loads (contents )
75+ launch_kernel_cmd = [
76+ s .format (connection_file = self .connection_file ) for s in kernelspec ["argv" ]
77+ ]
7678 if self .capture_output :
7779 stdout = subprocess .DEVNULL
7880 stderr = subprocess .STDOUT
7981 else :
8082 stdout = None
8183 stderr = None
8284 kernel_cwd = self .kernel_cwd if self .kernel_cwd else None
83- self ._process = await open_process (cmd , stdout = stdout , stderr = stderr , cwd = kernel_cwd )
85+ kernelenv = ""
86+ if self .kernelenv_path :
87+ path = anyio .Path (self .kernelenv_path )
88+ if await path .is_file ():
89+ kernelenv = await path .read_text ()
90+ if kernelenv :
91+ import yaml # type: ignore[import-untyped]
92+ env_name = yaml .load (kernelenv , Loader = yaml .CLoader )["name" ]
93+ cmd = f"micromamba create -f { self .kernelenv_path } --yes"
94+ result = await run_process (cmd )
95+ if result .returncode == 0 :
96+ cmd = """bash -c 'eval "$(micromamba shell hook --shell bash)";""" + \
97+ f"micromamba activate { env_name } ;" + \
98+ " " .join (launch_kernel_cmd ) + "' & echo $!"
99+ process = await open_process (cmd )
100+ assert process .stdout is not None
101+ async for text in TextReceiveStream (process .stdout ):
102+ self ._pid = int (text )
103+ break
104+ else :
105+ if launch_kernel_cmd and launch_kernel_cmd [0 ] in {
106+ "python" ,
107+ f"python{ sys .version_info [0 ]} " ,
108+ "python" + "." .join (map (str , sys .version_info [:2 ])),
109+ }:
110+ launch_kernel_cmd [0 ] = sys .executable
111+ self ._process = await open_process (
112+ launch_kernel_cmd , stdout = stdout , stderr = stderr , cwd = kernel_cwd
113+ )
84114
85115 assert self .connection_cfg is not None
86116 identity = uuid .uuid4 ().hex .encode ("ascii" )
@@ -108,14 +138,17 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) ->
108138 self .started .set ()
109139
110140 async def stop (self ) -> None :
111- try :
112- self ._process .terminate ()
113- except ProcessLookupError :
114- pass
115- await self ._process .wait ()
116- if self .write_connection_file :
117- path = anyio .Path (self .connection_file )
118- await path .unlink (missing_ok = True )
141+ if self ._process :
142+ try :
143+ self ._process .terminate ()
144+ except ProcessLookupError :
145+ pass
146+ await self ._process .wait ()
147+ if self .write_connection_file :
148+ path = anyio .Path (self .connection_file )
149+ await path .unlink (missing_ok = True )
150+ else :
151+ os .kill (self ._pid , signal .SIGTERM )
119152
120153 await self .shell_channel .stop ()
121154 await self .stdin_channel .stop ()
@@ -124,7 +157,10 @@ async def stop(self) -> None:
124157 self .task_group .cancel_scope .cancel ()
125158
126159 async def interrupt (self ) -> None :
127- self ._process .send_signal (signal .SIGINT )
160+ if self ._process :
161+ self ._process .send_signal (signal .SIGINT )
162+ else :
163+ os .kill (self ._pid , signal .SIGINT )
128164
129165 async def forward_messages_to_shell (self ) -> None :
130166 async for msg in self ._to_shell_receive_stream :
0 commit comments