Skip to content

Commit 7fba952

Browse files
Task/raw model answer (#1947)
* Add full_response to llm provider output * Semver * Small leftover cleanup * Add pyi to suppress Pyright errors. full_content is optional * Format * Add missing stubs
1 parent fb4fe72 commit 7fba952

File tree

5 files changed

+78
-3
lines changed

5 files changed

+78
-3
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Add full LLM response to LLM Provider output"
4+
}

graphrag/language_model/providers/fnllm/models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ async def achat(
8383
else:
8484
response = await self.model(prompt, history=history, **kwargs)
8585
return BaseModelResponse(
86-
output=BaseModelOutput(content=response.output.content),
86+
output=BaseModelOutput(
87+
content=response.output.content,
88+
full_response=response.output.raw_model.to_dict(),
89+
),
8790
parsed_response=response.parsed_json,
8891
history=response.history,
8992
cache_hit=response.cache_hit,
@@ -282,7 +285,10 @@ async def achat(
282285
else:
283286
response = await self.model(prompt, history=history, **kwargs)
284287
return BaseModelResponse(
285-
output=BaseModelOutput(content=response.output.content),
288+
output=BaseModelOutput(
289+
content=response.output.content,
290+
full_response=response.output.raw_model.to_dict(),
291+
),
286292
parsed_response=response.parsed_json,
287293
history=response.history,
288294
cache_hit=response.cache_hit,

graphrag/language_model/response/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ def content(self) -> str:
1818
"""Return the textual content of the output."""
1919
...
2020

21+
@property
22+
def full_response(self) -> dict[str, Any] | None:
23+
"""Return the complete JSON response returned by the model."""
24+
...
25+
2126

2227
class ModelResponse(Protocol, Generic[T]):
2328
"""Protocol for LLM response."""
@@ -43,6 +48,10 @@ class BaseModelOutput(BaseModel):
4348

4449
content: str = Field(..., description="The textual content of the output.")
4550
"""The textual content of the output."""
51+
full_response: dict[str, Any] | None = Field(
52+
None, description="The complete JSON response returned by the LLM provider."
53+
)
54+
"""The complete JSON response returned by the LLM provider."""
4655

4756

4857
class BaseModelResponse(BaseModel, Generic[T]):
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""Type stubs for the base language-model response types.

Declares the structural protocols and the pydantic-backed concrete models so
type checkers see explicit ``__init__`` signatures; ``full_response`` is
optional throughout.
"""

from typing import Any, Generic, Protocol, TypeVar

from pydantic import BaseModel

# Covariant: a response parsed into a BaseModel subclass is acceptable
# wherever a response parsed into the base type is expected.
_T_co = TypeVar("_T_co", bound=BaseModel, covariant=True)

class ModelOutput(Protocol):
    """Structural interface for a single model output."""

    @property
    def content(self) -> str:
        """Textual content of the output."""
        ...

    @property
    def full_response(self) -> dict[str, Any] | None:
        """Complete JSON payload returned by the provider, when captured."""
        ...

class ModelResponse(Protocol, Generic[_T_co]):
    """Structural interface for a full model response."""

    @property
    def output(self) -> ModelOutput:
        """The raw output of the response."""
        ...

    @property
    def parsed_response(self) -> _T_co | None:
        """The response parsed into a model, if parsing was requested."""
        ...

    @property
    def history(self) -> list[Any]:
        """Conversation history associated with the response."""
        ...

class BaseModelOutput(BaseModel):
    """Concrete pydantic output: textual content plus optional raw payload."""

    content: str
    full_response: dict[str, Any] | None

    def __init__(
        self,
        content: str,
        full_response: dict[str, Any] | None = None,
    ) -> None: ...

class BaseModelResponse(BaseModel, Generic[_T_co]):
    """Concrete pydantic response wrapping an output and call metadata."""

    output: BaseModelOutput
    parsed_response: _T_co | None
    history: list[Any]
    tool_calls: list[Any]
    metrics: Any | None
    cache_hit: bool | None

    def __init__(
        self,
        output: BaseModelOutput,
        parsed_response: _T_co | None = None,
        history: list[Any] = ...,  # default provided by Pydantic
        tool_calls: list[Any] = ...,  # default provided by Pydantic
        metrics: Any | None = None,
        cache_hit: bool | None = None,
    ) -> None: ...

tests/integration/language_model/test_factory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ async def achat(
3333
def chat(
    self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
    """Return a canned response carrying fixed content and a raw payload."""
    output = BaseModelOutput(content="content", full_response={"key": "value"})
    return BaseModelResponse(output=output)
3741

3842
async def achat_stream(
3943
self, prompt: str, history: list | None = None, **kwargs: Any
@@ -49,9 +53,11 @@ def chat_stream(
4953
assert isinstance(model, CustomChatModel)
5054
response = await model.achat("prompt")
5155
assert response.output.content == "content"
56+
assert response.output.full_response is None
5257

5358
response = model.chat("prompt")
5459
assert response.output.content == "content"
60+
assert response.output.full_response == {"key": "value"}
5561

5662

5763
async def test_create_custom_embedding_llm():

0 commit comments

Comments
 (0)