@@ -347,6 +347,49 @@ def test_populate_invocation_agent_states_no_content(self):
347347 assert not invocation_context .agent_states
348348 assert not invocation_context .end_of_agents
349349
350+ def test_set_agent_state_with_end_of_agent_true (self ):
351+ """Tests that set_agent_state clears agent_state and sets end_of_agent to True."""
352+ invocation_context = self ._create_test_invocation_context (
353+ ResumabilityConfig (is_resumable = True )
354+ )
355+ invocation_context .agent_states ['agent1' ] = {}
356+ invocation_context .end_of_agents ['agent1' ] = False
357+
358+ # Set state with end_of_agent=True, which should clear the existing
359+ # agent_state.
360+ invocation_context .set_agent_state ('agent1' , end_of_agent = True )
361+ assert 'agent1' not in invocation_context .agent_states
362+ assert invocation_context .end_of_agents ['agent1' ]
363+
364+ def test_set_agent_state_with_agent_state (self ):
365+ """Tests that set_agent_state sets agent_state and sets end_of_agent to False."""
366+ agent_state = BaseAgentState ()
367+ invocation_context = self ._create_test_invocation_context (
368+ ResumabilityConfig (is_resumable = True )
369+ )
370+ invocation_context .end_of_agents ['agent1' ] = True
371+
372+ # Set state with agent_state=agent_state, which should set the agent_state
373+ # and reset the end_of_agent flag to False.
374+ invocation_context .set_agent_state ('agent1' , agent_state = agent_state )
375+ assert invocation_context .agent_states ['agent1' ] == agent_state .model_dump (
376+ mode = 'json'
377+ )
378+ assert invocation_context .end_of_agents ['agent1' ] is False
379+
380+ def test_reset_agent_state (self ):
381+ """Tests that set_agent_state clears agent_state and end_of_agent."""
382+ invocation_context = self ._create_test_invocation_context (
383+ ResumabilityConfig (is_resumable = True )
384+ )
385+ invocation_context .agent_states ['agent1' ] = {}
386+ invocation_context .end_of_agents ['agent1' ] = True
387+
388+ # Reset state, which should clear the agent_state and end_of_agent flag.
389+ invocation_context .set_agent_state ('agent1' )
390+ assert 'agent1' not in invocation_context .agent_states
391+ assert 'agent1' not in invocation_context .end_of_agents
392+
350393
351394class TestFindMatchingFunctionCall :
352395 """Test suite for find_matching_function_call."""
0 commit comments