1+ import abc
12from dataclasses import dataclass
23from typing import Any
34
1213_WRAPPER_DICT_KEY = "response"
1314
1415
16+ class AgentOutputSchemaBase (abc .ABC ):
17+ """An object that captures the JSON schema of the output, as well as validating/parsing JSON
18+ produced by the LLM into the output type.
19+ """
20+
21+ @abc .abstractmethod
22+ def is_plain_text (self ) -> bool :
23+ """Whether the output type is plain text (versus a JSON object)."""
24+ pass
25+
26+ @abc .abstractmethod
27+ def name (self ) -> str :
28+ """The name of the output type."""
29+ pass
30+
31+ @abc .abstractmethod
32+ def json_schema (self ) -> dict [str , Any ]:
33+ """Returns the JSON schema of the output. Will only be called if the output type is not
34+ plain text.
35+ """
36+ pass
37+
38+ @abc .abstractmethod
39+ def is_strict_json_schema (self ) -> bool :
40+ """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
41+ features, but guarantees valis JSON. See here for details:
42+ https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
43+ """
44+ pass
45+
46+ @abc .abstractmethod
47+ def validate_json (self , json_str : str ) -> Any :
48+ """Validate a JSON string against the output type. You must return the validated object,
49+ or raise a `ModelBehaviorError` if the JSON is invalid.
50+ """
51+ pass
52+
53+
1554@dataclass (init = False )
16- class AgentOutputSchema :
55+ class AgentOutputSchema ( AgentOutputSchemaBase ) :
1756 """An object that captures the JSON schema of the output, as well as validating/parsing JSON
1857 produced by the LLM into the output type.
1958 """
@@ -32,7 +71,7 @@ class AgentOutputSchema:
3271 _output_schema : dict [str , Any ]
3372 """The JSON schema of the output."""
3473
35- strict_json_schema : bool
74+ _strict_json_schema : bool
3675 """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
3776 as it increases the likelihood of correct JSON input.
3877 """
@@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
4584 setting this to True, as it increases the likelihood of correct JSON input.
4685 """
4786 self .output_type = output_type
48- self .strict_json_schema = strict_json_schema
87+ self ._strict_json_schema = strict_json_schema
4988
5089 if output_type is None or output_type is str :
5190 self ._is_wrapped = False
@@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
70109 self ._type_adapter = TypeAdapter (output_type )
71110 self ._output_schema = self ._type_adapter .json_schema ()
72111
73- if self .strict_json_schema :
74- self ._output_schema = ensure_strict_json_schema (self ._output_schema )
112+ if self ._strict_json_schema :
113+ try :
114+ self ._output_schema = ensure_strict_json_schema (self ._output_schema )
115+ except UserError as e :
116+ raise UserError (
117+ "Strict JSON schema is enabled, but the output type is not valid. "
118+ "Either make the output type strict, or pass output_schema_strict=False to "
119+ "your Agent()"
120+ ) from e
75121
76122 def is_plain_text (self ) -> bool :
77123 """Whether the output type is plain text (versus a JSON object)."""
78124 return self .output_type is None or self .output_type is str
79125
126+ def is_strict_json_schema (self ) -> bool :
127+ """Whether the JSON schema is in strict mode."""
128+ return self ._strict_json_schema
129+
80130 def json_schema (self ) -> dict [str , Any ]:
81131 """The JSON schema of the output type."""
82132 if self .is_plain_text ():
83133 raise UserError ("Output type is plain text, so no JSON schema is available" )
84134 return self ._output_schema
85135
86- def validate_json (self , json_str : str , partial : bool = False ) -> Any :
136+ def validate_json (self , json_str : str ) -> Any :
87137 """Validate a JSON string against the output type. Returns the validated object, or raises
88138 a `ModelBehaviorError` if the JSON is invalid.
89139 """
90- validated = _json .validate_json (json_str , self ._type_adapter , partial )
140+ validated = _json .validate_json (json_str , self ._type_adapter , partial = False )
91141 if self ._is_wrapped :
92142 if not isinstance (validated , dict ):
93143 _error_tracing .attach_error_to_current_span (
@@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any:
113163 return validated [_WRAPPER_DICT_KEY ]
114164 return validated
115165
116- def output_type_name (self ) -> str :
166+ def name (self ) -> str :
117167 """The name of the output type."""
118168 return _type_to_str (self .output_type )
119169
0 commit comments