|
18 | 18 |
|
19 | 19 | _DEFAULT_NATIVE_RESPONSE_TYPES = [Citations] |
20 | 20 |
|
| 21 | + |
21 | 22 | class Adapter: |
22 | | - def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False, native_response_types: list[type[Type]] | None = None): |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + callbacks: list[BaseCallback] | None = None, |
| 26 | + use_native_function_calling: bool = False, |
| 27 | + native_response_types: list[type[Type]] | None = None, |
| 28 | + ): |
23 | 29 | self.callbacks = callbacks or [] |
24 | 30 | self.use_native_function_calling = use_native_function_calling |
25 | 31 | self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES |
@@ -68,7 +74,11 @@ def _call_preprocess( |
68 | 74 |
|
69 | 75 | # Handle custom types that use native response |
70 | 76 | for name, field in signature.output_fields.items(): |
71 | | - if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types: |
| 77 | + if ( |
| 78 | + isinstance(field.annotation, type) |
| 79 | + and issubclass(field.annotation, Type) |
| 80 | + and field.annotation in self.native_response_types |
| 81 | + ): |
72 | 82 | signature = signature.delete(name) |
73 | 83 |
|
74 | 84 | return signature |
@@ -117,7 +127,11 @@ def _call_postprocess( |
117 | 127 |
|
118 | 128 | # Parse custom types that does not rely on the adapter parsing |
119 | 129 | for name, field in original_signature.output_fields.items(): |
120 | | - if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types: |
| 130 | + if ( |
| 131 | + isinstance(field.annotation, type) |
| 132 | + and issubclass(field.annotation, Type) |
| 133 | + and field.annotation in self.native_response_types |
| 134 | + ): |
121 | 135 | value[name] = field.annotation.parse_lm_response(output) |
122 | 136 |
|
123 | 137 | if output_logprobs: |
@@ -404,7 +418,6 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: |
404 | 418 | return name |
405 | 419 | return None |
406 | 420 |
|
407 | | - |
408 | 421 | def format_conversation_history( |
409 | 422 | self, |
410 | 423 | signature: type[Signature], |
|
0 commit comments