|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import ntpath |
| 16 | +import os |
15 | 17 | from pathlib import Path |
| 18 | +from textwrap import dedent |
16 | 19 | from typing import Literal |
17 | 20 | from typing import Type |
| 21 | +from unittest import mock |
18 | 22 |
|
19 | 23 | from google.adk.agents import config_agent_utils |
20 | 24 | from google.adk.agents.agent_config import AgentConfig |
21 | 25 | from google.adk.agents.base_agent import BaseAgent |
22 | 26 | from google.adk.agents.base_agent_config import BaseAgentConfig |
| 27 | +from google.adk.agents.common_configs import AgentRefConfig |
23 | 28 | from google.adk.agents.llm_agent import LlmAgent |
24 | 29 | from google.adk.agents.loop_agent import LoopAgent |
25 | 30 | from google.adk.agents.parallel_agent import ParallelAgent |
@@ -280,3 +285,91 @@ class MyCustomAgentConfig(BaseAgentConfig): |
280 | 285 | config.root.model_dump() |
281 | 286 | ) |
282 | 287 | assert my_custom_config.other_field == "other value" |
| 288 | + |
| 289 | + |
| 290 | +@pytest.mark.parametrize( |
| 291 | + ("config_rel_path", "child_rel_path", "child_name", "instruction"), |
| 292 | + [ |
| 293 | + ( |
| 294 | + Path("main.yaml"), |
| 295 | + Path("sub_agents/child.yaml"), |
| 296 | + "child_agent", |
| 297 | + "I am a child agent", |
| 298 | + ), |
| 299 | + ( |
| 300 | + Path("level1/level2/nested_main.yaml"), |
| 301 | + Path("sub/nested_child.yaml"), |
| 302 | + "nested_child", |
| 303 | + "I am nested", |
| 304 | + ), |
| 305 | + ], |
| 306 | +) |
| 307 | +def test_resolve_agent_reference_resolves_relative_paths( |
| 308 | + config_rel_path: Path, |
| 309 | + child_rel_path: Path, |
| 310 | + child_name: str, |
| 311 | + instruction: str, |
| 312 | + tmp_path: Path, |
| 313 | +): |
| 314 | + """Verify resolve_agent_reference resolves relative sub-agent paths.""" |
| 315 | + config_file = tmp_path / config_rel_path |
| 316 | + config_file.parent.mkdir(parents=True, exist_ok=True) |
| 317 | + |
| 318 | + child_config_path = config_file.parent / child_rel_path |
| 319 | + child_config_path.parent.mkdir(parents=True, exist_ok=True) |
| 320 | + child_config_path.write_text(dedent(f""" |
| 321 | + agent_class: LlmAgent |
| 322 | + name: {child_name} |
| 323 | + model: gemini-2.0-flash |
| 324 | + instruction: {instruction} |
| 325 | + """).lstrip()) |
| 326 | + |
| 327 | + config_file.write_text(dedent(f""" |
| 328 | + agent_class: LlmAgent |
| 329 | + name: main_agent |
| 330 | + model: gemini-2.0-flash |
| 331 | + instruction: I am the main agent |
| 332 | + sub_agents: |
| 333 | + - config_path: {child_rel_path.as_posix()} |
| 334 | + """).lstrip()) |
| 335 | + |
| 336 | + ref_config = AgentRefConfig(config_path=child_rel_path.as_posix()) |
| 337 | + agent = config_agent_utils.resolve_agent_reference( |
| 338 | + ref_config, str(config_file) |
| 339 | + ) |
| 340 | + |
| 341 | + assert agent.name == child_name |
| 342 | + |
| 343 | + config_dir = os.path.dirname(str(config_file.resolve())) |
| 344 | + assert config_dir == str(config_file.parent.resolve()) |
| 345 | + |
| 346 | + expected_child_path = os.path.join(config_dir, *child_rel_path.parts) |
| 347 | + assert os.path.exists(expected_child_path) |
| 348 | + |
| 349 | + |
| 350 | +def test_resolve_agent_reference_uses_windows_dirname(): |
| 351 | + """Ensure Windows-style config references resolve via os.path.dirname.""" |
| 352 | + ref_config = AgentRefConfig(config_path="sub\\child.yaml") |
| 353 | + recorded: dict[str, str] = {} |
| 354 | + |
| 355 | + def fake_from_config(path: str): |
| 356 | + recorded["path"] = path |
| 357 | + return "sentinel" |
| 358 | + |
| 359 | + with ( |
| 360 | + mock.patch.object( |
| 361 | + config_agent_utils, |
| 362 | + "from_config", |
| 363 | + autospec=True, |
| 364 | + side_effect=fake_from_config, |
| 365 | + ), |
| 366 | + mock.patch.object(config_agent_utils.os, "path", ntpath), |
| 367 | + ): |
| 368 | + referencing = r"C:\workspace\agents\main.yaml" |
| 369 | + result = config_agent_utils.resolve_agent_reference(ref_config, referencing) |
| 370 | + |
| 371 | + expected_path = ntpath.join( |
| 372 | + ntpath.dirname(referencing), ref_config.config_path |
| 373 | + ) |
| 374 | + assert result == "sentinel" |
| 375 | + assert recorded["path"] == expected_path |
0 commit comments