2323 UnexpectedEndOfExecution ,
2424)
2525from errors import ExecutionError
26+ from envs import get_envs
2627
2728logger = logging .getLogger (__name__ )
2829
@@ -47,7 +48,8 @@ def __init__(self, in_background: bool = False):
4748class ContextWebSocket :
4849 _ws : Optional [WebSocketClientProtocol ] = None
4950 _receive_task : Optional [asyncio .Task ] = None
50- global_env_vars : Optional [Dict [StrictStr , str ]] = None
51+ _global_env_vars : Optional [Dict [StrictStr , str ]] = None
52+ _cleanup_task : Optional [asyncio .Task ] = None
5153
5254 def __init__ (
5355 self ,
@@ -114,6 +116,113 @@ def _get_execute_request(
114116 }
115117 )
116118
119+ def _set_env_var_snippet (self , key : str , value : str ) -> str :
120+ """Get environment variable set command for the current language."""
121+ if self .language == "python" :
122+ return f"import os; os.environ['{ key } '] = '{ value } '"
123+ elif self .language in ["javascript" , "typescript" ]:
124+ return f"process.env['{ key } '] = '{ value } '"
125+ elif self .language == "deno" :
126+ return f"Deno.env.set('{ key } ', '{ value } ')"
127+ elif self .language == "r" :
128+ return f'Sys.setenv({ key } = "{ value } ")'
129+ elif self .language == "java" :
130+ return f'System.setProperty("{ key } ", "{ value } ");'
131+ elif self .language == "bash" :
132+ return f"export { key } ='{ value } '"
133+ return ""
134+
135+ def _delete_env_var_snippet (self , key : str ) -> str :
136+ """Get environment variable delete command for the current language."""
137+ if self .language == "python" :
138+ return f"import os; del os.environ['{ key } ']"
139+ elif self .language in ["javascript" , "typescript" ]:
140+ return f"delete process.env['{ key } ']"
141+ elif self .language == "deno" :
142+ return f"Deno.env.delete('{ key } ')"
143+ elif self .language == "r" :
144+ return f"Sys.unsetenv('{ key } ')"
145+ elif self .language == "java" :
146+ return f'System.clearProperty("{ key } ");'
147+ elif self .language == "bash" :
148+ return f"unset { key } "
149+ return ""
150+
151+ def _set_env_vars_code (self , env_vars : Dict [StrictStr , str ]) -> str :
152+ """Build environment variable code for the current language."""
153+ env_commands = []
154+ for k , v in env_vars .items ():
155+ command = self ._set_env_var_snippet (k , v )
156+ if command :
157+ env_commands .append (command )
158+
159+ return "\n " .join (env_commands )
160+
161+ def _reset_env_vars_code (self , env_vars : Dict [StrictStr , str ]) -> str :
162+ """Build environment variable cleanup code for the current language."""
163+ cleanup_commands = []
164+
165+ for key in env_vars :
166+ # Check if this var exists in global env vars
167+ if self ._global_env_vars and key in self ._global_env_vars :
168+ # Reset to global value
169+ value = self ._global_env_vars [key ]
170+ command = self ._set_env_var_snippet (key , value )
171+ else :
172+ # Remove the variable
173+ command = self ._delete_env_var_snippet (key )
174+
175+ if command :
176+ cleanup_commands .append (command )
177+
178+ return "\n " .join (cleanup_commands )
179+
180+ def _get_code_indentation (self , code : str ) -> str :
181+ """Get the indentation from the first non-empty line of code."""
182+ if not code or not code .strip ():
183+ return ""
184+
185+ lines = code .split ('\n ' )
186+ for line in lines :
187+ if line .strip (): # First non-empty line
188+ return line [:len (line ) - len (line .lstrip ())]
189+
190+ return ""
191+
192+ def _indent_code_with_level (self , code : str , indent_level : str ) -> str :
193+ """Apply the given indentation level to each line of code."""
194+ if not code or not indent_level :
195+ return code
196+
197+ lines = code .split ('\n ' )
198+ indented_lines = []
199+
200+ for line in lines :
201+ if line .strip (): # Non-empty lines
202+ indented_lines .append (indent_level + line )
203+ else :
204+ indented_lines .append (line )
205+
206+ return '\n ' .join (indented_lines )
207+
208+ async def _cleanup_env_vars (self , env_vars : Dict [StrictStr , str ]):
209+ """Clean up environment variables in a separate execution request."""
210+ message_id = str (uuid .uuid4 ())
211+ self ._executions [message_id ] = Execution (in_background = True )
212+
213+ try :
214+ cleanup_code = self ._reset_env_vars_code (env_vars )
215+ if cleanup_code :
216+ logger .info (f"Cleaning up env vars: { cleanup_code } " )
217+ request = self ._get_execute_request (message_id , cleanup_code , True )
218+ await self ._ws .send (request )
219+
220+ async for item in self ._wait_for_result (message_id ):
221+ if item ["type" ] == "error" :
222+ logger .error (f"Error during env var cleanup: { item } " )
223+ finally :
224+ del self ._executions [message_id ]
225+
117226 async def _wait_for_result (self , message_id : str ):
118227 queue = self ._executions [message_id ].queue
119228
@@ -133,84 +242,6 @@ async def _wait_for_result(self, message_id: str):
133242
134243 yield output .model_dump (exclude_none = True )
135244
136- async def set_env_vars (self , env_vars : Dict [StrictStr , str ]):
137- message_id = str (uuid .uuid4 ())
138- self ._executions [message_id ] = Execution (in_background = True )
139-
140- env_commands = []
141- for k , v in env_vars .items ():
142- if self .language == "python" :
143- env_commands .append (f"import os; os.environ['{ k } '] = '{ v } '" )
144- elif self .language in ["javascript" , "typescript" ]:
145- env_commands .append (f"process.env['{ k } '] = '{ v } '" )
146- elif self .language == "deno" :
147- env_commands .append (f"Deno.env.set('{ k } ', '{ v } ')" )
148- elif self .language == "r" :
149- env_commands .append (f'Sys.setenv({ k } = "{ v } ")' )
150- elif self .language == "java" :
151- env_commands .append (f'System.setProperty("{ k } ", "{ v } ");' )
152- elif self .language == "bash" :
153- env_commands .append (f"export { k } ='{ v } '" )
154- else :
155- return
156-
157- if env_commands :
158- env_vars_snippet = "\n " .join (env_commands )
159- logger .info (f"Setting env vars: { env_vars_snippet } for { self .language } " )
160- request = self ._get_execute_request (message_id , env_vars_snippet , True )
161- await self ._ws .send (request )
162-
163- async for item in self ._wait_for_result (message_id ):
164- if item ["type" ] == "error" :
165- raise ExecutionError (f"Error during execution: { item } " )
166-
167- async def reset_env_vars (self , env_vars : Dict [StrictStr , str ]):
168- # Create a dict of vars to reset and a list of vars to remove
169- vars_to_reset = {}
170- vars_to_remove = []
171-
172- for key in env_vars :
173- if self .global_env_vars and key in self .global_env_vars :
174- vars_to_reset [key ] = self .global_env_vars [key ]
175- else :
176- vars_to_remove .append (key )
177-
178- # Reset vars that exist in global env vars
179- if vars_to_reset :
180- await self .set_env_vars (vars_to_reset )
181-
182- # Remove vars that don't exist in global env vars
183- if vars_to_remove :
184- message_id = str (uuid .uuid4 ())
185- self ._executions [message_id ] = Execution (in_background = True )
186-
187- remove_commands = []
188- for key in vars_to_remove :
189- if self .language == "python" :
190- remove_commands .append (f"import os; del os.environ['{ key } ']" )
191- elif self .language in ["javascript" , "typescript" ]:
192- remove_commands .append (f"delete process.env['{ key } ']" )
193- elif self .language == "deno" :
194- remove_commands .append (f"Deno.env.delete('{ key } ')" )
195- elif self .language == "r" :
196- remove_commands .append (f"Sys.unsetenv('{ key } ')" )
197- elif self .language == "java" :
198- remove_commands .append (f'System.clearProperty("{ key } ");' )
199- elif self .language == "bash" :
200- remove_commands .append (f"unset { key } " )
201- else :
202- return
203-
204- if remove_commands :
205- remove_snippet = "\n " .join (remove_commands )
206- logger .info (f"Removing env vars: { remove_snippet } for { self .language } " )
207- request = self ._get_execute_request (message_id , remove_snippet , True )
208- await self ._ws .send (request )
209-
210- async for item in self ._wait_for_result (message_id ):
211- if item ["type" ] == "error" :
212- raise ExecutionError (f"Error during execution: { item } " )
213-
214245 async def change_current_directory (
215246 self , path : Union [str , StrictStr ], language : str
216247 ):
@@ -248,20 +279,44 @@ async def execute(
248279 env_vars : Dict [StrictStr , str ] = None ,
249280 ):
250281 message_id = str (uuid .uuid4 ())
251- logger .debug (f"Sending code for the execution ({ message_id } ): { code } " )
252-
253282 self ._executions [message_id ] = Execution ()
254283
255284 if self ._ws is None :
256285 raise Exception ("WebSocket not connected" )
257286
258287 async with self ._lock :
259- # set env vars (will override global env vars)
288+ # Wait for any pending cleanup task to complete
289+ if self ._cleanup_task and not self ._cleanup_task .done ():
290+ logger .debug ("Waiting for pending cleanup task to complete" )
291+ try :
292+ await self ._cleanup_task
293+ except Exception as e :
294+ logger .warning (f"Cleanup task failed: { e } " )
295+ finally :
296+ self ._cleanup_task = None
297+
298+ # Get the indentation level from the code
299+ code_indent = self ._get_code_indentation (code )
300+
301+ # Build the complete code snippet with env vars
302+ complete_code = code
303+
304+ global_env_vars_snippet = ""
305+ env_vars_snippet = ""
306+
307+ if self ._global_env_vars is None :
308+ self ._global_env_vars = await get_envs ()
309+ global_env_vars_snippet = self ._set_env_vars_code (self ._global_env_vars )
310+
260311 if env_vars :
261- await self .set_env_vars (env_vars )
312+ env_vars_snippet = self ._set_env_vars_code (env_vars )
262313
263- logger .info (code )
264- request = self ._get_execute_request (message_id , code , False )
314+ if global_env_vars_snippet or env_vars_snippet :
315+ indented_env_code = self ._indent_code_with_level (f"{ global_env_vars_snippet } \n { env_vars_snippet } " , code_indent )
316+ complete_code = f"{ indented_env_code } \n { complete_code } "
317+
318+ logger .info (f"Sending code for the execution ({ message_id } ): { complete_code } " )
319+ request = self ._get_execute_request (message_id , complete_code , False )
265320
266321 # Send the code for execution
267322 await self ._ws .send (request )
@@ -272,9 +327,9 @@ async def execute(
272327
273328 del self ._executions [message_id ]
274329
275- # reset env vars to their previous values, if they were set globally or remove them
330+ # Clean up env vars in a separate request after the main code has run
276331 if env_vars :
277- await self .reset_env_vars (env_vars )
332+ self . _cleanup_task = asyncio . create_task ( self ._cleanup_env_vars (env_vars ) )
278333
279334 async def _receive_message (self ):
280335 if not self ._ws :
@@ -434,7 +489,16 @@ async def close(self):
434489 if self ._ws is not None :
435490 await self ._ws .close ()
436491
437- self ._receive_task .cancel ()
492+ if self ._receive_task is not None :
493+ self ._receive_task .cancel ()
494+
495+ # Cancel any pending cleanup task
496+ if self ._cleanup_task and not self ._cleanup_task .done ():
497+ self ._cleanup_task .cancel ()
498+ try :
499+ await self ._cleanup_task
500+ except asyncio .CancelledError :
501+ pass
438502
439503 for execution in self ._executions .values ():
440504 execution .queue .put_nowait (UnexpectedEndOfExecution ())
0 commit comments