diff --git a/langgraph/checkpoint/redis/jsonplus_redis.py b/langgraph/checkpoint/redis/jsonplus_redis.py index 3e2654d..c668d3f 100644 --- a/langgraph/checkpoint/redis/jsonplus_redis.py +++ b/langgraph/checkpoint/redis/jsonplus_redis.py @@ -39,13 +39,10 @@ class JsonPlusRedisSerializer(JsonPlusSerializer): ] def dumps(self, obj: Any) -> bytes: - """Use orjson for simple objects, fallback to parent for complex objects.""" - try: - # Fast path: Use orjson for JSON-serializable objects - return orjson.dumps(obj) - except TypeError: - # Complex objects (Send, etc.) need parent's msgpack serialization - return super().dumps(obj) + """Use orjson for serialization with LangChain object support via default handler.""" + # Use orjson with default handler for LangChain objects + # The _default method from parent class handles LangChain serialization + return orjson.dumps(obj, default=self._default) def loads(self, data: bytes) -> Any: """Use orjson for JSON parsing with reviver support, fallback to parent for msgpack data.""" @@ -54,9 +51,15 @@ def loads(self, data: bytes) -> Any: parsed = orjson.loads(data) # Apply reviver for LangChain objects (lc format) return self._revive_if_needed(parsed) - except orjson.JSONDecodeError: - # Fallback: Parent handles msgpack and other formats - return super().loads(data) + except (orjson.JSONDecodeError, TypeError): + # Fallback: Parent handles msgpack and other formats via loads_typed + # Attempt to detect type and use loads_typed + try: + # Try loading as msgpack via parent's loads_typed + return super().loads_typed(("msgpack", data)) + except Exception: + # If that fails, try loading as json string + return super().loads_typed(("json", data)) def _revive_if_needed(self, obj: Any) -> Any: """Recursively apply reviver to handle LangChain serialized objects. @@ -93,6 +96,7 @@ def dumps_typed(self, obj: Any) -> tuple[str, str]: # type: ignore[override] if isinstance(obj, (bytes, bytearray)): return "base64", base64.b64encode(obj).decode("utf-8") else: + # All objects should be JSON-serializable (LangChain objects are pre-serialized) return "json", self.dumps(obj).decode("utf-8") def loads_typed(self, data: tuple[str, Union[str, bytes]]) -> Any: diff --git a/test_jsonplus_redis_serializer.py b/test_jsonplus_redis_serializer.py new file mode 100644 index 0000000..c0259ec --- /dev/null +++ b/test_jsonplus_redis_serializer.py @@ -0,0 +1,196 @@ +"""Standalone test to verify the JsonPlusRedisSerializer fix works. + +This can be run directly without pytest infrastructure: + python test_fix_standalone.py +""" + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer + + +def test_human_message_serialization(): + """Test that HumanMessage can be serialized without TypeError.""" + print("Testing HumanMessage serialization...") + + serializer = JsonPlusRedisSerializer() + msg = HumanMessage(content="What is the weather?", id="msg-1") + + try: + # This would raise TypeError before the fix + serialized = serializer.dumps(msg) + print(f" ✓ Serialized to {len(serialized)} bytes") + + # Deserialize + deserialized = serializer.loads(serialized) + assert isinstance(deserialized, HumanMessage) + assert deserialized.content == "What is the weather?" + assert deserialized.id == "msg-1" + print(f" ✓ Deserialized correctly: {deserialized.content}") + + return True + except TypeError as e: + print(f" ✗ FAILED: {e}") + return False + + +def test_all_message_types(): + """Test all LangChain message types.""" + print("\nTesting all message types...") + + serializer = JsonPlusRedisSerializer() + messages = [ + HumanMessage(content="Hello"), + AIMessage(content="Hi!"), + SystemMessage(content="System prompt"), + ] + + for msg in messages: + try: + serialized = serializer.dumps(msg) + deserialized = serializer.loads(serialized) + assert type(deserialized) == type(msg) + print(f" ✓ {type(msg).__name__} works") + except Exception as e: + print(f" ✗ {type(msg).__name__} FAILED: {e}") + return False + + return True + + +def test_message_list(): + """Test list of messages (common pattern in LangGraph).""" + print("\nTesting message list...") + + serializer = JsonPlusRedisSerializer() + messages = [ + HumanMessage(content="Question 1"), + AIMessage(content="Answer 1"), + HumanMessage(content="Question 2"), + ] + + try: + serialized = serializer.dumps(messages) + deserialized = serializer.loads(serialized) + + assert isinstance(deserialized, list) + assert len(deserialized) == 3 + assert all(isinstance(m, (HumanMessage, AIMessage)) for m in deserialized) + print(f" ✓ List of {len(deserialized)} messages works") + + return True + except Exception as e: + print(f" ✗ FAILED: {e}") + return False + + +def test_nested_structure(): + """Test nested structure with messages (realistic LangGraph state).""" + print("\nTesting nested structure with messages...") + + serializer = JsonPlusRedisSerializer() + state = { + "messages": [ + HumanMessage(content="Query"), + AIMessage(content="Response"), + ], + "step": 1, + } + + try: + serialized = serializer.dumps(state) + deserialized = serializer.loads(serialized) + + assert "messages" in deserialized + assert len(deserialized["messages"]) == 2 + assert isinstance(deserialized["messages"][0], HumanMessage) + assert isinstance(deserialized["messages"][1], AIMessage) + print(f" ✓ Nested structure works") + + return True + except Exception as e: + print(f" ✗ FAILED: {e}") + return False + + +def test_dumps_typed(): + """Test dumps_typed (what checkpointer actually uses).""" + print("\nTesting dumps_typed...") + + serializer = JsonPlusRedisSerializer() + msg = HumanMessage(content="Test", id="test-123") + + try: + type_str, blob = serializer.dumps_typed(msg) + assert type_str == "json" + assert isinstance(blob, str) + print(f" ✓ dumps_typed returns: type='{type_str}', blob={len(blob)} chars") + + deserialized = serializer.loads_typed((type_str, blob)) + assert isinstance(deserialized, HumanMessage) + assert deserialized.content == "Test" + print(f" ✓ loads_typed works correctly") + + return True + except Exception as e: + print(f" ✗ FAILED: {e}") + return False + + +def test_backwards_compatibility(): + """Test that regular objects still work.""" + print("\nTesting backwards compatibility...") + + serializer = JsonPlusRedisSerializer() + test_cases = [ + ("string", "hello"), + ("int", 42), + ("dict", {"key": "value"}), + ("list", [1, 2, 3]), + ] + + for name, obj in test_cases: + try: + serialized = serializer.dumps(obj) + deserialized = serializer.loads(serialized) + assert deserialized == obj + print(f" ✓ {name} works") + except Exception as e: + print(f" ✗ {name} FAILED: {e}") + return False + + return True + + +def main(): + """Run all tests.""" + print("=" * 70) + print("JsonPlusRedisSerializer Fix Validation") + print("=" * 70) + + tests = [ + test_human_message_serialization, + test_all_message_types, + test_message_list, + test_nested_structure, + test_dumps_typed, + test_backwards_compatibility, + ] + + results = [] + for test in tests: + results.append(test()) + + print("\n" + "=" * 70) + print(f"Results: {sum(results)}/{len(results)} tests passed") + print("=" * 70) + + if all(results): + print("\n✅ ALL TESTS PASSED - Fix is working correctly!") + return 0 + else: + print("\n❌ SOME TESTS FAILED - Fix may not be working") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/test_jsonplus_serializer_default_handler.py b/tests/test_jsonplus_serializer_default_handler.py new file mode 100644 index 0000000..34993d0 --- /dev/null +++ b/tests/test_jsonplus_serializer_default_handler.py @@ -0,0 +1,207 @@ +"""Test JsonPlusRedisSerializer uses orjson with default handler for LangChain objects. + +This test validates the fix for the bug where JsonPlusRedisSerializer.dumps() +was not using the default parameter with orjson, causing TypeError when +serializing LangChain message objects like HumanMessage and AIMessage. + +The fix ensures all LangChain Serializable objects are properly handled by +using orjson.dumps(obj, default=self._default) instead of plain orjson.dumps(obj). +""" + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer + + +def test_serializer_uses_default_handler_for_messages(): + """Test that dumps() uses the default handler for LangChain message objects. + + Before the fix, this would raise: + TypeError: Type is not JSON serializable: HumanMessage + + After the fix, messages are properly serialized via the _default handler. + """ + serializer = JsonPlusRedisSerializer() + + # Test HumanMessage + human_msg = HumanMessage(content="What is the weather?", id="msg-1") + + # This should NOT raise TypeError + serialized_bytes = serializer.dumps(human_msg) + assert isinstance(serialized_bytes, bytes) + + # Deserialize and verify + deserialized = serializer.loads(serialized_bytes) + assert isinstance(deserialized, HumanMessage) + assert deserialized.content == "What is the weather?" + assert deserialized.id == "msg-1" + + +def test_serializer_handles_all_message_types(): + """Test that all LangChain message types are properly serialized. + + This ensures the fix works for all message subclasses, not just HumanMessage. + """ + serializer = JsonPlusRedisSerializer() + + messages = [ + HumanMessage(content="Hello", id="human-1"), + AIMessage(content="Hi there!", id="ai-1"), + SystemMessage(content="You are a helpful assistant", id="sys-1"), + ToolMessage(content="Tool result", tool_call_id="tool-1", id="tool-msg-1"), + ] + + for msg in messages: + # Serialize + serialized = serializer.dumps(msg) + assert isinstance(serialized, bytes) + + # Deserialize + deserialized = serializer.loads(serialized) + + # Verify type is preserved + assert type(deserialized) == type(msg) + assert deserialized.content == msg.content + assert deserialized.id == msg.id + + +def test_serializer_handles_message_lists(): + """Test that lists of messages are properly serialized. + + This is a common pattern in LangGraph state where messages are stored as lists. + """ + serializer = JsonPlusRedisSerializer() + + messages = [ + HumanMessage(content="What's 2+2?"), + AIMessage(content="2+2 equals 4"), + HumanMessage(content="Thanks!"), + ] + + # Serialize the list + serialized = serializer.dumps(messages) + assert isinstance(serialized, bytes) + + # Deserialize + deserialized = serializer.loads(serialized) + + # Verify structure + assert isinstance(deserialized, list) + assert len(deserialized) == 3 + assert all(isinstance(msg, (HumanMessage, AIMessage)) for msg in deserialized) + assert deserialized[0].content == "What's 2+2?" + assert deserialized[1].content == "2+2 equals 4" + + +def test_serializer_handles_nested_structures_with_messages(): + """Test that nested structures containing messages are properly serialized. + + This tests the scenario where messages are embedded in dicts or other structures. + """ + serializer = JsonPlusRedisSerializer() + + state = { + "messages": [ + HumanMessage(content="Query"), + AIMessage(content="Response"), + ], + "metadata": { + "step": 1, + "last_message": HumanMessage(content="Latest"), + }, + } + + # Serialize + serialized = serializer.dumps(state) + assert isinstance(serialized, bytes) + + # Deserialize + deserialized = serializer.loads(serialized) + + # Verify structure + assert "messages" in deserialized + assert len(deserialized["messages"]) == 2 + assert isinstance(deserialized["messages"][0], HumanMessage) + assert isinstance(deserialized["messages"][1], AIMessage) + assert isinstance(deserialized["metadata"]["last_message"], HumanMessage) + + +def test_dumps_typed_with_messages(): + """Test that dumps_typed also properly handles messages. + + This tests the full serialization path used by Redis checkpointer. + """ + serializer = JsonPlusRedisSerializer() + + msg = HumanMessage(content="Test message", id="test-123") + + # Use dumps_typed (what the checkpointer actually calls) + type_str, blob = serializer.dumps_typed(msg) + + assert type_str == "json" + assert isinstance(blob, str) + + # Deserialize + deserialized = serializer.loads_typed((type_str, blob)) + + assert isinstance(deserialized, HumanMessage) + assert deserialized.content == "Test message" + assert deserialized.id == "test-123" + + +def test_serializer_backwards_compatible(): + """Test that the fix doesn't break serialization of regular objects. + + Ensures that non-LangChain objects still serialize correctly. + """ + serializer = JsonPlusRedisSerializer() + + test_cases = [ + "simple string", + 42, + 3.14, + True, + None, + [1, 2, 3], + {"key": "value"}, + {"nested": {"data": [1, 2, 3]}}, + ] + + for obj in test_cases: + serialized = serializer.dumps(obj) + deserialized = serializer.loads(serialized) + assert deserialized == obj + + +def test_serializer_with_langchain_serialized_format(): + """Test that manually constructed LangChain serialized dicts are revived. + + This tests the _revive_if_needed functionality works with the new dumps() implementation. + """ + serializer = JsonPlusRedisSerializer() + + # This is the format that LangChain objects serialize to + message_dict = { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "HumanMessage"], + "kwargs": { + "content": "Manually constructed message", + "type": "human", + "id": "manual-123", + }, + } + + # Serialize and deserialize + serialized = serializer.dumps(message_dict) + deserialized = serializer.loads(serialized) + + # Should be revived as a HumanMessage + assert isinstance(deserialized, HumanMessage) + assert deserialized.content == "Manually constructed message" + assert deserialized.id == "manual-123" + + +if __name__ == "__main__": + # Run tests + pytest.main([__file__, "-v"])