|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | from asyncio import Queue |
6 | | -from typing import Any, AsyncGenerator, Dict, Optional, Union |
| 6 | +from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union |
7 | 7 | from unittest.mock import MagicMock |
8 | 8 |
|
9 | 9 | import pytest |
| 10 | +from pydantic import Field |
10 | 11 |
|
11 | 12 | import strands |
12 | 13 | from strands import Agent |
@@ -1450,3 +1451,232 @@ def test_function_tool_metadata_validate_signature_missing_context_config(): |
1450 | 1451 | @strands.tool |
1451 | 1452 | def my_tool(tool_context: ToolContext): |
1452 | 1453 | pass |
| 1454 | + |
| 1455 | + |
| 1456 | +def test_tool_decorator_annotated_string_description(): |
| 1457 | + """Test tool decorator with Annotated type hints for descriptions.""" |
| 1458 | + |
| 1459 | + @strands.tool |
| 1460 | + def annotated_tool( |
| 1461 | + name: Annotated[str, "The user's full name"], |
| 1462 | + age: Annotated[int, "The user's age in years"], |
| 1463 | + city: str, # No annotation - should use docstring or generic |
| 1464 | + ) -> str: |
| 1465 | + """Tool with annotated parameters. |
| 1466 | +
|
| 1467 | + Args: |
| 1468 | + city: The user's city (from docstring) |
| 1469 | + """ |
| 1470 | + return f"{name}, {age}, {city}" |
| 1471 | + |
| 1472 | + spec = annotated_tool.tool_spec |
| 1473 | + schema = spec["inputSchema"]["json"] |
| 1474 | + |
| 1475 | + # Check that annotated descriptions are used |
| 1476 | + assert schema["properties"]["name"]["description"] == "The user's full name" |
| 1477 | + assert schema["properties"]["age"]["description"] == "The user's age in years" |
| 1478 | + |
| 1479 | + # Check that docstring is still used for non-annotated params |
| 1480 | + assert schema["properties"]["city"]["description"] == "The user's city (from docstring)" |
| 1481 | + |
| 1482 | + # Verify all are required |
| 1483 | + assert set(schema["required"]) == {"name", "age", "city"} |
| 1484 | + |
| 1485 | + |
| 1486 | +def test_tool_decorator_annotated_pydantic_field_constraints(): |
| 1487 | + """Test tool decorator with Pydantic Field in Annotated.""" |
| 1488 | + |
| 1489 | + @strands.tool |
| 1490 | + def field_annotated_tool( |
| 1491 | + email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")], |
| 1492 | + score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50, |
| 1493 | + ) -> str: |
| 1494 | + """Tool with Pydantic Field annotations.""" |
| 1495 | + return f"{email}: {score}" |
| 1496 | + |
| 1497 | + spec = field_annotated_tool.tool_spec |
| 1498 | + schema = spec["inputSchema"]["json"] |
| 1499 | + |
| 1500 | + # Check descriptions from Field |
| 1501 | + assert schema["properties"]["email"]["description"] == "User's email address" |
| 1502 | + assert schema["properties"]["score"]["description"] == "Score between 0-100" |
| 1503 | + |
| 1504 | + # Check that constraints are preserved |
| 1505 | + assert schema["properties"]["score"]["minimum"] == 0 |
| 1506 | + assert schema["properties"]["score"]["maximum"] == 100 |
| 1507 | + |
| 1508 | + # Check required fields |
| 1509 | + assert "email" in schema["required"] |
| 1510 | + assert "score" not in schema["required"] # Has default |
| 1511 | + |
| 1512 | + |
| 1513 | +def test_tool_decorator_annotated_overrides_docstring(): |
| 1514 | + """Test that Annotated descriptions override docstring descriptions.""" |
| 1515 | + |
| 1516 | + @strands.tool |
| 1517 | + def override_tool(param: Annotated[str, "Description from annotation"]) -> str: |
| 1518 | + """Tool with both annotation and docstring. |
| 1519 | +
|
| 1520 | + Args: |
| 1521 | + param: Description from docstring (should be overridden) |
| 1522 | + """ |
| 1523 | + return param |
| 1524 | + |
| 1525 | + spec = override_tool.tool_spec |
| 1526 | + schema = spec["inputSchema"]["json"] |
| 1527 | + |
| 1528 | + # Annotated description should win |
| 1529 | + assert schema["properties"]["param"]["description"] == "Description from annotation" |
| 1530 | + |
| 1531 | + |
| 1532 | +def test_tool_decorator_annotated_optional_type(): |
| 1533 | + """Test tool with Optional types in Annotated.""" |
| 1534 | + |
| 1535 | + @strands.tool |
| 1536 | + def optional_annotated_tool( |
| 1537 | + required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None |
| 1538 | + ) -> str: |
| 1539 | + """Tool with optional annotated parameter.""" |
| 1540 | + return f"{required}, {optional}" |
| 1541 | + |
| 1542 | + spec = optional_annotated_tool.tool_spec |
| 1543 | + schema = spec["inputSchema"]["json"] |
| 1544 | + |
| 1545 | + # Check descriptions |
| 1546 | + assert schema["properties"]["required"]["description"] == "Required parameter" |
| 1547 | + assert schema["properties"]["optional"]["description"] == "Optional parameter" |
| 1548 | + |
| 1549 | + # Check required list |
| 1550 | + assert "required" in schema["required"] |
| 1551 | + assert "optional" not in schema["required"] |
| 1552 | + |
| 1553 | + |
| 1554 | +def test_tool_decorator_annotated_complex_types(): |
| 1555 | + """Test tool with complex types in Annotated.""" |
| 1556 | + |
| 1557 | + @strands.tool |
| 1558 | + def complex_annotated_tool( |
| 1559 | + tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] |
| 1560 | + ) -> str: |
| 1561 | + """Tool with complex annotated types.""" |
| 1562 | + return f"Tags: {len(tags)}, Config: {len(config)}" |
| 1563 | + |
| 1564 | + spec = complex_annotated_tool.tool_spec |
| 1565 | + schema = spec["inputSchema"]["json"] |
| 1566 | + |
| 1567 | + # Check descriptions |
| 1568 | + assert schema["properties"]["tags"]["description"] == "List of tag strings" |
| 1569 | + assert schema["properties"]["config"]["description"] == "Configuration dictionary" |
| 1570 | + |
| 1571 | + # Check types are preserved |
| 1572 | + assert schema["properties"]["tags"]["type"] == "array" |
| 1573 | + assert schema["properties"]["config"]["type"] == "object" |
| 1574 | + |
| 1575 | + |
| 1576 | +def test_tool_decorator_annotated_mixed_styles(): |
| 1577 | + """Test tool with mixed annotation styles.""" |
| 1578 | + |
| 1579 | + @strands.tool |
| 1580 | + def mixed_tool( |
| 1581 | + plain: str, |
| 1582 | + annotated_str: Annotated[str, "String description"], |
| 1583 | + annotated_field: Annotated[int, Field(description="Field description", ge=0)], |
| 1584 | + docstring_only: int, |
| 1585 | + ) -> str: |
| 1586 | + """Tool with mixed parameter styles. |
| 1587 | +
|
| 1588 | + Args: |
| 1589 | + plain: Plain parameter description |
| 1590 | + docstring_only: Docstring description for this param |
| 1591 | + """ |
| 1592 | + return "mixed" |
| 1593 | + |
| 1594 | + spec = mixed_tool.tool_spec |
| 1595 | + schema = spec["inputSchema"]["json"] |
| 1596 | + |
| 1597 | + # Check each style works correctly |
| 1598 | + assert schema["properties"]["plain"]["description"] == "Plain parameter description" |
| 1599 | + assert schema["properties"]["annotated_str"]["description"] == "String description" |
| 1600 | + assert schema["properties"]["annotated_field"]["description"] == "Field description" |
| 1601 | + assert schema["properties"]["docstring_only"]["description"] == "Docstring description for this param" |
| 1602 | + |
| 1603 | + |
| 1604 | +@pytest.mark.asyncio |
| 1605 | +async def test_tool_decorator_annotated_execution(alist): |
| 1606 | + """Test that annotated tools execute correctly.""" |
| 1607 | + |
| 1608 | + @strands.tool |
| 1609 | + def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str: |
| 1610 | + """Test execution with annotations.""" |
| 1611 | + return f"Hello {name} " * count |
| 1612 | + |
| 1613 | + # Test tool use |
| 1614 | + tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}} |
| 1615 | + stream = execution_test.stream(tool_use, {}) |
| 1616 | + |
| 1617 | + result = (await alist(stream))[-1] |
| 1618 | + assert result["tool_result"]["status"] == "success" |
| 1619 | + assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"] |
| 1620 | + |
| 1621 | + # Test direct call |
| 1622 | + direct_result = execution_test("Bob", 3) |
| 1623 | + assert direct_result == "Hello Bob Hello Bob Hello Bob " |
| 1624 | + |
| 1625 | + |
| 1626 | +def test_tool_decorator_annotated_no_description_fallback(): |
| 1627 | + """Test that Annotated without description falls back to docstring.""" |
| 1628 | + |
| 1629 | + @strands.tool |
| 1630 | + def no_desc_annotated( |
| 1631 | + param: Annotated[str, Field()], # Field without description |
| 1632 | + ) -> str: |
| 1633 | + """Tool with Annotated but no description. |
| 1634 | +
|
| 1635 | + Args: |
| 1636 | + param: Docstring description |
| 1637 | + """ |
| 1638 | + return param |
| 1639 | + |
| 1640 | + spec = no_desc_annotated.tool_spec |
| 1641 | + schema = spec["inputSchema"]["json"] |
| 1642 | + |
| 1643 | + # Should fall back to docstring |
| 1644 | + assert schema["properties"]["param"]["description"] == "Docstring description" |
| 1645 | + |
| 1646 | + |
| 1647 | +def test_tool_decorator_annotated_empty_string_description(): |
| 1648 | + """Test handling of empty string descriptions in Annotated.""" |
| 1649 | + |
| 1650 | + @strands.tool |
| 1651 | + def empty_desc_tool( |
| 1652 | + param: Annotated[str, ""], # Empty string description |
| 1653 | + ) -> str: |
| 1654 | + """Tool with empty annotation description. |
| 1655 | +
|
| 1656 | + Args: |
| 1657 | + param: Docstring description |
| 1658 | + """ |
| 1659 | + return param |
| 1660 | + |
| 1661 | + spec = empty_desc_tool.tool_spec |
| 1662 | + schema = spec["inputSchema"]["json"] |
| 1663 | + |
| 1664 | + # Empty string is still a valid description, should not fall back |
| 1665 | + assert schema["properties"]["param"]["description"] == "" |
| 1666 | + |
| 1667 | + |
| 1668 | +@pytest.mark.asyncio |
| 1669 | +async def test_tool_decorator_annotated_validation_error(alist): |
| 1670 | + """Test that validation works correctly with annotated parameters.""" |
| 1671 | + |
| 1672 | + @strands.tool |
| 1673 | + def validation_tool(age: Annotated[int, "User age"]) -> str: |
| 1674 | + """Tool for validation testing.""" |
| 1675 | + return f"Age: {age}" |
| 1676 | + |
| 1677 | + # Test with wrong type |
| 1678 | + tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}} |
| 1679 | + stream = validation_tool.stream(tool_use, {}) |
| 1680 | + |
| 1681 | + result = (await alist(stream))[-1] |
| 1682 | + assert result["tool_result"]["status"] == "error" |
0 commit comments