55LICENSE file in the root directory of this source tree.
66"""
77
8+ import inspect
89import json
910import logging
1011from abc import ABC
1112from io import TextIOWrapper , StringIO
1213from json import JSONDecodeError
1314from typing import Generic , Any , TypeVar
14- from collections .abc import Callable , Mapping
15+ from collections .abc import Awaitable , Callable , Mapping
1516import warnings
1617
1718from pydantic import JsonValue , ValidationError
@@ -225,6 +226,15 @@ def __init_subclass__(cls, **kwargs):
225226 "The DEPRECATED 'process' method must not be implemented "
226227 "alongside 'process_input' or 'process_response'."
227228 )
229+ if is_process_overridden and inspect .iscoroutinefunction (
230+ inspect .unwrap (cls .process )
231+ ):
232+ # we don't want to add async capabilities to the deprecated function
233+ raise TypeError (
234+ f"Cannot create concrete class { cls .__name__ } . "
235+ "The DEPRECATED 'process' method does not support async. "
236+ "Implement 'process_input' and/or 'process_response' instead."
237+ )
228238
229239 return
230240
@@ -875,15 +885,18 @@ async def _parse_and_process(self, request: Request) -> Response:
875885 prompt_hash , response_hash = (None , None )
876886 if input_direction :
877887 prompt_hash = prompt .hash ()
878- result : Result | Reject = self .process_input (
888+ result = await self ._handle_process_function (
889+ self .process_input ,
879890 metadata = metadata ,
880891 parameters = parameters ,
881892 prompt = prompt ,
882893 request = request ,
883894 )
895+
884896 else :
885897 response_hash = response .hash ()
886- result : Result | Reject = self .process_response (
898+ result = await self ._handle_process_function (
899+ self .process_response ,
887900 metadata = metadata ,
888901 parameters = parameters ,
889902 prompt = prompt ,
@@ -1014,13 +1027,22 @@ def _is_method_overridden(self, method_name: str) -> bool:
10141027 # the method object directly from the Processor class, then it has been overridden.
10151028 return instance_class_method_obj is not base_class_method_obj
10161029
1030+ def _process_fallback (self , ** kwargs ) -> Result | Reject :
1031+ warnings .warn (
1032+ f"{ type (self ).__name__ } uses the deprecated 'process' method. "
1033+ "Implement 'process_input' and/or 'process_response' instead." ,
1034+ DeprecationWarning ,
1035+ stacklevel = 2 ,
1036+ )
1037+ return self .process (** kwargs )
1038+
10171039 def process_input (
10181040 self ,
10191041 prompt : PROMPT ,
10201042 metadata : Metadata ,
10211043 parameters : PARAMS ,
10221044 request : Request ,
1023- ) -> Result | Reject :
1045+ ) -> Result | Reject | Awaitable [ Result | Reject ] :
10241046 """
10251047 This abstract method is for implementors of the processor to define
10261048 with their own custom logic. Errors should be raised as a subclass
@@ -1043,23 +1065,17 @@ def process_input(self, prompt, response, metadata, parameters, request):
10431065
10441066 return Result(processor_result=result)
10451067 """
1046- if self ._is_method_overridden ("process" ):
1047- warnings .warn (
1048- f"{ type (self ).__name__ } uses the deprecated 'process' method for input. "
1049- "Implement 'process_input' instead." ,
1050- DeprecationWarning ,
1051- stacklevel = 2 , # Points the warning to the caller of process_input
1068+ if not self ._is_method_overridden ("process" ):
1069+ raise NotImplementedError (
1070+ f"{ type (self ).__name__ } must implement 'process_input' or the "
1071+ "deprecated 'process' method to handle input."
10521072 )
1053- return self .process (
1054- prompt = prompt ,
1055- response = None ,
1056- metadata = metadata ,
1057- parameters = parameters ,
1058- request = request ,
1059- )
1060- raise NotImplementedError (
1061- f"{ type (self ).__name__ } must implement 'process_input' or the "
1062- "deprecated 'process' method to handle input."
1073+ return self ._process_fallback (
1074+ prompt = prompt ,
1075+ response = None ,
1076+ metadata = metadata ,
1077+ parameters = parameters ,
1078+ request = request ,
10631079 )
10641080
10651081 def process_response (
@@ -1069,7 +1085,7 @@ def process_response(
10691085 metadata : Metadata ,
10701086 parameters : PARAMS ,
10711087 request : Request ,
1072- ) -> Result | Reject :
1088+ ) -> Result | Reject | Awaitable [ Result | Reject ] :
10731089 """
10741090 This abstract method is for implementors of the processor to define
10751091 with their own custom logic. Errors should be raised as a subclass
@@ -1096,23 +1112,17 @@ def process_response(self, prompt, response, metadata, parameters, request):
10961112 return Result(processor_result=result)
10971113 """
10981114
1099- if self ._is_method_overridden ("process" ):
1100- warnings .warn (
1101- f"{ type (self ).__name__ } uses the deprecated 'process' method for response. "
1102- "Implement 'process_response' instead." ,
1103- DeprecationWarning ,
1104- stacklevel = 2 , # Points the warning to the caller of process_input
1115+ if not self ._is_method_overridden ("process" ):
1116+ raise NotImplementedError (
1117+ f"{ type (self ).__name__ } must implement 'process_response' or the "
1118+ "deprecated 'process' method to handle input."
11051119 )
1106- return self .process (
1107- prompt = prompt ,
1108- response = response ,
1109- metadata = metadata ,
1110- parameters = parameters ,
1111- request = request ,
1112- )
1113- raise NotImplementedError (
1114- f"{ type (self ).__name__ } must implement 'process_response' or the "
1115- "deprecated 'process' method to handle input."
1120+ return self ._process_fallback (
1121+ prompt = prompt ,
1122+ response = response ,
1123+ metadata = metadata ,
1124+ parameters = parameters ,
1125+ request = request ,
11161126 )
11171127
11181128 def process (
@@ -1159,6 +1169,13 @@ def process(self, prompt, response, metadata, parameters, request):
11591169 "'process_input'/'process_response'."
11601170 )
11611171
1172+ async def _handle_process_function (self , func , ** kwargs ) -> Result | Reject :
1173+ if inspect .iscoroutinefunction (func ):
1174+ result = await func (** kwargs )
1175+ else :
1176+ result = func (** kwargs )
1177+ return result
1178+
11621179
11631180def _validation_error_as_messages (err : ValidationError ) -> list [str ]:
11641181 return [_error_details_to_str (e ) for e in err .errors ()]
0 commit comments