1- import json
21import threading
3- import time
4- import uuid
52from concurrent .futures import Future
6- from typing import Any , Callable , List , Optional
3+ from typing import Any , Callable , List , Optional , Dict
74
85import requests
96from e2b import EnvVars , ProcessMessage , Sandbox
107from e2b .constants import TIMEOUT
11- from websocket import create_connection
128
13- from e2b_code_interpreter .models import Error , KernelException , Result
9+ from e2b_code_interpreter .messaging import JupyterKernelWebSocket
10+ from e2b_code_interpreter .models import KernelException , Result
1411
1512
1613class CodeInterpreter (Sandbox ):
@@ -40,15 +37,18 @@ def __init__(
4037 ** kwargs ,
4138 )
4239 self .notebook = JupyterExtension (self )
40+ # Close all the websocket connections when the interpreter is closed
41+ self ._process_cleanup .append (self .notebook .close )
4342
4443
4544class JupyterExtension :
4645 _default_kernel_id : Optional [str ] = None
46+ _connected_kernels : Dict [str , JupyterKernelWebSocket ] = {}
4747
4848 def __init__ (self , sandbox : CodeInterpreter ):
4949 self ._sandbox = sandbox
5050 self ._kernel_id_set = Future ()
51- self ._set_default_kernel_id ()
51+ self ._start_connectiong_to_default_kernel ()
5252
5353 def exec_cell (
5454 self ,
@@ -58,12 +58,18 @@ def exec_cell(
5858 on_stderr : Optional [Callable [[ProcessMessage ], Any ]] = None ,
5959 ) -> Result :
6060 kernel_id = kernel_id or self .default_kernel_id
61- ws = self ._connect_kernel (kernel_id )
62- ws .send (json .dumps (self ._send_execute_request (code )))
63- result = self ._wait_for_result (ws , on_stdout , on_stderr )
61+ ws = self ._connected_kernels .get (kernel_id )
6462
65- ws .close ()
63+ if not ws :
64+ ws = JupyterKernelWebSocket (
65+ url = f"{ self ._sandbox .get_protocol ('ws' )} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /channels" ,
66+ )
67+ self ._connected_kernels [kernel_id ] = ws
68+ ws .connect ()
69+
70+ session_id = ws .send_execution_message (code , on_stdout , on_stderr )
6671
72+ result = ws .get_result (session_id )
6773 return result
6874
6975 @property
@@ -73,31 +79,42 @@ def default_kernel_id(self) -> str:
7379
7480 return self ._default_kernel_id
7581
76- def create_kernel (self , timeout : Optional [float ] = TIMEOUT ) -> str :
82+ def create_kernel (self , cwd : Optional [str ] = None ,timeout : Optional [float ] = TIMEOUT ) -> str :
83+ data = {"cwd" : cwd } if cwd else None
7784 response = requests .post (
7885 f"{ self ._sandbox .get_protocol ()} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels" ,
86+ json = data ,
7987 timeout = timeout ,
8088 )
8189 if not response .ok :
8290 raise KernelException (f"Failed to create kernel: { response .text } " )
83- return response .json ()["id" ]
91+
92+ kernel_id = response .json ()["id" ]
93+
94+ threading .Thread (target = self ._connect_to_kernel_ws , args = kernel_id ).start ()
95+ return kernel_id
8496
8597 def restart_kernel (
8698 self , kernel_id : Optional [str ] = None , timeout : Optional [float ] = TIMEOUT
8799 ) -> None :
88100 kernel_id = kernel_id or self .default_kernel_id
101+
102+ self ._connected_kernels [kernel_id ].close ()
89103 response = requests .post (
90104 f"{ self ._sandbox .get_protocol ()} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /restart" ,
91105 timeout = timeout ,
92106 )
93107 if not response .ok :
94108 raise KernelException (f"Failed to restart kernel { kernel_id } " )
95109
110+ threading .Thread (target = self ._connect_to_kernel_ws , args = kernel_id ).start ()
111+
96112 def shutdown_kernel (
97113 self , kernel_id : Optional [str ] = None , timeout : Optional [float ] = TIMEOUT
98114 ) -> None :
99115 kernel_id = kernel_id or self .default_kernel_id
100116
117+ self ._connected_kernels [kernel_id ].close ()
101118 response = requests .delete (
102119 f"{ self ._sandbox .get_protocol ()} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } " ,
103120 timeout = timeout ,
@@ -114,114 +131,21 @@ def list_kernels(self, timeout: Optional[float] = TIMEOUT) -> List[str]:
114131 raise KernelException (f"Failed to list kernels: { response .text } " )
115132 return [kernel ["id" ] for kernel in response .json ()]
116133
117- def _set_default_kernel_id (self , timeout : Optional [float ] = TIMEOUT ) -> None :
118- def set_kernel_id ():
119- self ._kernel_id_set .set_result (
120- self ._sandbox .filesystem .read ("/root/.jupyter/kernel_id" , timeout = timeout ).strip ()
121- )
122-
123- threading .Thread (target = set_kernel_id ).start ()
134+ def close (self ):
135+ for ws in self ._connected_kernels .values ():
136+ ws .close ()
124137
125- def _connect_kernel (self , kernel_id : str , timeout : Optional [float ] = TIMEOUT ):
126- return create_connection (
127- f"{ self ._sandbox .get_protocol ('ws' )} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /channels" ,
128- timeout = timeout ,
138+ def _connect_to_kernel_ws (self , kernel_id : str ) -> None :
139+ ws = JupyterKernelWebSocket (
140+ url = f"{ self ._sandbox .get_protocol ('ws' )} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /channels" ,
129141 )
142+ ws .connect ()
143+ self ._connected_kernels [kernel_id ] = ws
130144
131- @staticmethod
132- def _send_execute_request (code : str ) -> dict :
133- msg_id = str (uuid .uuid4 ())
134- session = str (uuid .uuid4 ())
135-
136- return {
137- "header" : {
138- "msg_id" : msg_id ,
139- "username" : "e2b" ,
140- "session" : session ,
141- "msg_type" : "execute_request" ,
142- "version" : "5.3" ,
143- },
144- "parent_header" : {},
145- "metadata" : {},
146- "content" : {
147- "code" : code ,
148- "silent" : False ,
149- "store_history" : False ,
150- "user_expressions" : {},
151- "allow_stdin" : False ,
152- },
153- }
154-
155- @staticmethod
156- def _wait_for_result (
157- ws ,
158- on_stdout : Optional [Callable [[ProcessMessage ], Any ]],
159- on_stderr : Optional [Callable [[ProcessMessage ], Any ]],
160- ) -> Result :
161- result = Result ()
162- input_accepted = False
163-
164- while True :
165- response = json .loads (ws .recv ())
166- if response ["msg_type" ] == "error" :
167- result .error = Error (
168- name = response ["content" ]["ename" ],
169- value = response ["content" ]["evalue" ],
170- traceback = response ["content" ]["traceback" ],
171- )
172-
173- elif response ["msg_type" ] == "stream" :
174- if response ["content" ]["name" ] == "stdout" :
175- result .stdout .append (response ["content" ]["text" ])
176- if on_stdout :
177- on_stdout (
178- ProcessMessage (
179- line = response ["content" ]["text" ],
180- timestamp = time .time_ns (),
181- )
182- )
183-
184- elif response ["content" ]["name" ] == "stderr" :
185- result .stderr .append (response ["content" ]["text" ])
186- if on_stderr :
187- on_stderr (
188- ProcessMessage (
189- line = response ["content" ]["text" ],
190- error = True ,
191- timestamp = time .time_ns (),
192- )
193- )
194-
195- elif response ["msg_type" ] == "display_data" :
196- result .display_data .append (response ["content" ]["data" ])
197-
198- elif response ["msg_type" ] == "execute_result" :
199- result .output = response ["content" ]["data" ]["text/plain" ]
200-
201- elif response ["msg_type" ] == "status" :
202- if response ["content" ]["execution_state" ] == "idle" :
203- if input_accepted :
204- break
205- elif response ["content" ]["execution_state" ] == "error" :
206- result .error = Error (
207- name = response ["content" ]["ename" ],
208- value = response ["content" ]["evalue" ],
209- traceback = response ["content" ]["traceback" ],
210- )
211- break
212-
213- elif response ["msg_type" ] == "execute_reply" :
214- if response ["content" ]["status" ] == "error" :
215- result .error = Error (
216- name = response ["content" ]["ename" ],
217- value = response ["content" ]["evalue" ],
218- traceback = response ["content" ]["traceback" ],
219- )
220- elif response ["content" ]["status" ] == "ok" :
221- pass
222-
223- elif response ["msg_type" ] == "execute_input" :
224- input_accepted = True
225- else :
226- print ("[UNHANDLED MESSAGE TYPE]:" , response ["msg_type" ])
227- return result
145+ def _start_connectiong_to_default_kernel (self , timeout : Optional [float ] = TIMEOUT ) -> None :
146+ def setup_default_kernel ():
147+ kernel_id = self ._sandbox .filesystem .read ("/root/.jupyter/kernel_id" , timeout = timeout ).strip ()
148+ self ._connect_to_kernel_ws (kernel_id )
149+ self ._kernel_id_set .set_result (kernel_id )
150+
151+ threading .Thread (target = setup_default_kernel ).start ()
0 commit comments