Skip to content

Commit 3ac64f7

Browse files
authored
fix: #2073 Improve type hinting in memory extension modules (#2077)
1 parent 73e7843 commit 3ac64f7

File tree

4 files changed

+38
-27
lines changed

4 files changed

+38
-27
lines changed

examples/memory/advanced_sqlite_session_example.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ async def main():
132132
# Show current conversation
133133
print("Current conversation:")
134134
current_items = await session.get_items()
135-
for i, item in enumerate(current_items, 1):
135+
for i, item in enumerate(current_items, 1): # type: ignore[assignment]
136136
role = str(item.get("role", item.get("type", "unknown")))
137137
if item.get("type") == "function_call":
138138
content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
@@ -151,8 +151,8 @@ async def main():
151151
# Show available turns for branching
152152
print("\nAvailable turns for branching:")
153153
turns = await session.get_conversation_turns()
154-
for turn in turns:
155-
print(f" Turn {turn['turn']}: {turn['content']}")
154+
for turn in turns: # type: ignore[assignment]
155+
print(f" Turn {turn['turn']}: {turn['content']}") # type: ignore[index]
156156

157157
# Create a branch from turn 2
158158
print("\nCreating new branch from turn 2...")
@@ -163,7 +163,7 @@ async def main():
163163
branch_items = await session.get_items()
164164
print(f"Items copied to new branch: {len(branch_items)}")
165165
print("New branch contains:")
166-
for i, item in enumerate(branch_items, 1):
166+
for i, item in enumerate(branch_items, 1): # type: ignore[assignment]
167167
role = str(item.get("role", item.get("type", "unknown")))
168168
if item.get("type") == "function_call":
169169
content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
@@ -198,7 +198,7 @@ async def main():
198198
print("\n=== New Conversation Branch ===")
199199
new_conversation = await session.get_items()
200200
print("New conversation with branch:")
201-
for i, item in enumerate(new_conversation, 1):
201+
for i, item in enumerate(new_conversation, 1): # type: ignore[assignment]
202202
role = str(item.get("role", item.get("type", "unknown")))
203203
if item.get("type") == "function_call":
204204
content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
@@ -224,8 +224,8 @@ async def main():
224224
# Show conversation turns in current branch
225225
print("\nConversation turns in current branch:")
226226
current_turns = await session.get_conversation_turns()
227-
for turn in current_turns:
228-
print(f" Turn {turn['turn']}: {turn['content']}")
227+
for turn in current_turns: # type: ignore[assignment]
228+
print(f" Turn {turn['turn']}: {turn['content']}") # type: ignore[index]
229229

230230
print("\n=== Branch Switching Demo ===")
231231
print("We can switch back to the main branch...")

examples/memory/dapr_session_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,8 @@ async def demonstrate_multi_store():
417417
r_items = await redis_session.get_items()
418418
p_items = await pg_session.get_items()
419419

420-
r_example = r_items[-1]["content"] if r_items else "empty"
421-
p_example = p_items[-1]["content"] if p_items else "empty"
420+
r_example = r_items[-1]["content"] if r_items else "empty" # type: ignore[typeddict-item]
421+
p_example = p_items[-1]["content"] if p_items else "empty" # type: ignore[typeddict-item]
422422

423423
print(f"{redis_store}: {len(r_items)} items; example: {r_example}")
424424
print(f"{pg_store}: {len(p_items)} items; example: {p_example}")

src/agents/extensions/memory/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,18 @@
88

99
from __future__ import annotations
1010

11-
from typing import Any
11+
from typing import TYPE_CHECKING, Any
12+
13+
if TYPE_CHECKING:
14+
from .advanced_sqlite_session import AdvancedSQLiteSession
15+
from .dapr_session import (
16+
DAPR_CONSISTENCY_EVENTUAL,
17+
DAPR_CONSISTENCY_STRONG,
18+
DaprSession,
19+
)
20+
from .encrypt_session import EncryptedSession
21+
from .redis_session import RedisSession
22+
from .sqlalchemy_session import SQLAlchemySession
1223

1324
__all__: list[str] = [
1425
"AdvancedSQLiteSession",

tests/extensions/memory/test_dapr_session.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ async def _create_test_session(
182182
session = DaprSession(
183183
session_id=session_id,
184184
state_store_name="statestore",
185-
dapr_client=fake_dapr_client,
185+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
186186
)
187187

188188
# Clean up any existing data
@@ -260,12 +260,12 @@ async def test_session_isolation(fake_dapr_client: FakeDaprClient):
260260
session1 = DaprSession(
261261
session_id="session_1",
262262
state_store_name="statestore",
263-
dapr_client=fake_dapr_client,
263+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
264264
)
265265
session2 = DaprSession(
266266
session_id="session_2",
267267
state_store_name="statestore",
268-
dapr_client=fake_dapr_client,
268+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
269269
)
270270

271271
try:
@@ -386,7 +386,7 @@ async def test_pop_from_empty_session(fake_dapr_client: FakeDaprClient):
386386
session = DaprSession(
387387
session_id="empty_session",
388388
state_store_name="statestore",
389-
dapr_client=fake_dapr_client,
389+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
390390
)
391391
try:
392392
await session.clear_session()
@@ -540,7 +540,7 @@ async def test_dapr_connectivity(fake_dapr_client: FakeDaprClient):
540540
session = DaprSession(
541541
session_id="connectivity_test",
542542
state_store_name="statestore",
543-
dapr_client=fake_dapr_client,
543+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
544544
)
545545
try:
546546
# Test ping
@@ -555,7 +555,7 @@ async def test_ttl_functionality(fake_dapr_client: FakeDaprClient):
555555
session = DaprSession(
556556
session_id="ttl_test",
557557
state_store_name="statestore",
558-
dapr_client=fake_dapr_client,
558+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
559559
ttl=3600, # 1 hour TTL
560560
)
561561

@@ -586,15 +586,15 @@ async def test_consistency_levels(fake_dapr_client: FakeDaprClient):
586586
session_eventual = DaprSession(
587587
session_id="eventual_test",
588588
state_store_name="statestore",
589-
dapr_client=fake_dapr_client,
589+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
590590
consistency=DAPR_CONSISTENCY_EVENTUAL,
591591
)
592592

593593
# Test strong consistency
594594
session_strong = DaprSession(
595595
session_id="strong_test",
596596
state_store_name="statestore",
597-
dapr_client=fake_dapr_client,
597+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
598598
consistency=DAPR_CONSISTENCY_STRONG,
599599
)
600600

@@ -621,7 +621,7 @@ async def test_external_client_not_closed(fake_dapr_client: FakeDaprClient):
621621
session = DaprSession(
622622
session_id="external_client_test",
623623
state_store_name="statestore",
624-
dapr_client=fake_dapr_client,
624+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
625625
)
626626

627627
try:
@@ -650,7 +650,7 @@ async def test_internal_client_ownership(fake_dapr_client: FakeDaprClient):
650650
session = DaprSession(
651651
session_id="internal_client_test",
652652
state_store_name="statestore",
653-
dapr_client=fake_dapr_client,
653+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
654654
)
655655
session._owns_client = True # Simulate ownership
656656

@@ -732,7 +732,7 @@ async def test_close_method_coverage(fake_dapr_client: FakeDaprClient):
732732
session1 = DaprSession(
733733
session_id="close_test_1",
734734
state_store_name="statestore",
735-
dapr_client=fake_dapr_client,
735+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
736736
)
737737

738738
# Verify _owns_client is False for external client
@@ -749,7 +749,7 @@ async def test_close_method_coverage(fake_dapr_client: FakeDaprClient):
749749
session2 = DaprSession(
750750
session_id="close_test_2",
751751
state_store_name="statestore",
752-
dapr_client=fake_dapr_client2,
752+
dapr_client=fake_dapr_client2, # type: ignore[arg-type]
753753
)
754754
session2._owns_client = True # Simulate ownership
755755

@@ -788,8 +788,8 @@ async def test_already_deserialized_messages(fake_dapr_client: FakeDaprClient):
788788
# Should handle both string and dict messages
789789
items = await session.get_items()
790790
assert len(items) == 2
791-
assert items[0]["content"] == "First message"
792-
assert items[1]["content"] == "Second message"
791+
assert items[0]["content"] == "First message" # type: ignore[typeddict-item]
792+
assert items[1]["content"] == "Second message" # type: ignore[typeddict-item]
793793

794794
await session.close()
795795

@@ -800,7 +800,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient):
800800
async with DaprSession(
801801
"test_cm_session",
802802
state_store_name="statestore",
803-
dapr_client=fake_dapr_client,
803+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
804804
) as session:
805805
# Verify we got the session object back
806806
assert session.session_id == "test_cm_session"
@@ -809,7 +809,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient):
809809
await session.add_items([{"role": "user", "content": "Test message"}])
810810
items = await session.get_items()
811811
assert len(items) == 1
812-
assert items[0]["content"] == "Test message"
812+
assert items[0]["content"] == "Test message" # type: ignore[typeddict-item]
813813

814814
# After exiting context manager, close should have been called
815815
# Verify we can still check the state (fake client doesn't truly disconnect)
@@ -819,7 +819,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient):
819819
owned_session = DaprSession(
820820
"test_cm_owned",
821821
state_store_name="statestore",
822-
dapr_client=fake_dapr_client,
822+
dapr_client=fake_dapr_client, # type: ignore[arg-type]
823823
)
824824
# Manually set ownership to simulate from_address behavior
825825
owned_session._owns_client = True

0 commit comments

Comments
 (0)