diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 8ae631ab5..f9f84debe 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -322,6 +322,7 @@ class SchemaSerializer: warnings: bool | Literal['none', 'warn', 'error'] = True, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, + polymorphic_serialization: bool | None = None, context: Any | None = None, ) -> Any: """ @@ -345,6 +346,7 @@ class SchemaSerializer: fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + polymorphic_serialization: Whether to override configured model and dataclass polymorphic serialization for this call. context: The context to use for serialization, this is passed to functional serializers as [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. @@ -371,6 +373,7 @@ class SchemaSerializer: warnings: bool | Literal['none', 'warn', 'error'] = True, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, + polymorphic_serialization: bool | None = None, context: Any | None = None, ) -> bytes: """ @@ -394,7 +397,9 @@ class SchemaSerializer: "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + serialize_as_any: Whether to serialize fields with duck-typing + serialization behavior. + polymorphic_serialization: Whether to override configured model and dataclass polymorphic serialization for this call. context: The context to use for serialization, this is passed to functional serializers as [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. @@ -425,6 +430,7 @@ def to_json( serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, + polymorphic_serialization: bool | None = None, context: Any | None = None, ) -> bytes: """ @@ -453,6 +459,7 @@ def to_json( fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + polymorphic_serialization: Whether to override configured model and dataclass polymorphic serialization for this call. context: The context to use for serialization, this is passed to functional serializers as [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. @@ -510,6 +517,7 @@ def to_jsonable_python( serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, + polymorphic_serialization: bool | None = None, context: Any | None = None, ) -> Any: """ @@ -536,6 +544,7 @@ def to_jsonable_python( fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + polymorphic_serialization: Whether to override configured model and dataclass polymorphic serialization for this call. context: The context to use for serialization, this is passed to functional serializers as [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index c8a3b6da6..9e5d3a0b4 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -79,6 +79,7 @@ class CoreConfig(TypedDict, total=False): validate_by_alias: Whether to use the field's alias when validating against the provided input data. Default is `True`. validate_by_name: Whether to use the field's name when validating against the provided input data. Default is `False`. Replacement for `populate_by_name`. serialize_by_alias: Whether to serialize by alias. Default is `False`, expected to change to `True` in V3. + polymorphic_serialization: Whether to enable polymorphic serialization for models and dataclasses. Default is `False`. url_preserve_empty_path: Whether to preserve empty URL paths when validating values for a URL type. Defaults to `False`. """ @@ -120,6 +121,7 @@ class CoreConfig(TypedDict, total=False): validate_by_alias: bool # default: True validate_by_name: bool # default: False serialize_by_alias: bool # default: False + polymorphic_serialization: bool # default: False url_preserve_empty_path: bool # default: False @@ -181,6 +183,11 @@ def serialize_as_any(self) -> bool: """The `serialize_as_any` argument set during serialization.""" ... + @property + def polymorphic_serialization(self) -> bool | None: + """The `polymorphic_serialization` argument set during serialization, if any.""" + ... + @property def round_trip(self) -> bool: """The `round_trip` argument set during serialization.""" diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 5216accb8..6fcd1874f 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -352,6 +352,7 @@ impl ValidationError { None, false, None, + None, ); let mut state = SerializationState::new(config, WarningsMode::None, None, None, extra)?; let mut serializer = ValidationErrorSerializer { diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 6d76db48f..c328511f1 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -180,6 +180,8 @@ pub(crate) struct Extra<'a, 'py> { pub serialize_unknown: bool, pub fallback: Option<&'a Bound<'py, PyAny>>, pub serialize_as_any: bool, + /// Whether `polymorphic_serialization` is globally enabled / disabled for this serialization process + pub polymorphic_serialization: Option, pub context: Option<&'a Bound<'py, PyAny>>, } @@ -197,6 +199,7 @@ impl<'a, 'py> Extra<'a, 'py> { serialize_unknown: bool, fallback: Option<&'a Bound<'py, PyAny>>, serialize_as_any: bool, + polymorphic_serialization: Option, context: Option<&'a Bound<'py, PyAny>>, ) -> Self { Self { @@ -211,6 +214,7 @@ impl<'a, 'py> Extra<'a, 'py> { serialize_unknown, fallback, serialize_as_any, + polymorphic_serialization, context, } } @@ -256,6 +260,7 @@ pub(crate) struct ExtraOwned { serialize_unknown: bool, pub fallback: Option>, serialize_as_any: bool, + polymorphic_serialization: Option, pub context: Option>, include: Option>, exclude: Option>, @@ -299,6 +304,7 @@ impl ExtraOwned { serialize_unknown: extra.serialize_unknown, fallback: extra.fallback.map(|model| model.clone().into()), serialize_as_any: extra.serialize_as_any, + polymorphic_serialization: extra.polymorphic_serialization, context: extra.context.map(|model| model.clone().into()), include: state.include().map(|m| m.clone().into()), exclude: state.exclude().map(|m| m.clone().into()), @@ -318,6 +324,7 @@ impl ExtraOwned { serialize_unknown: self.serialize_unknown, fallback: self.fallback.as_ref().map(|m| m.bind(py)), serialize_as_any: self.serialize_as_any, + polymorphic_serialization: self.polymorphic_serialization, context: self.context.as_ref().map(|m| m.bind(py)), } } diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index c44380bc4..5f3bfaa4f 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -24,6 +24,7 @@ mod fields; mod filter; mod infer; mod ob_type; +mod polymorphism_trampoline; mod prebuilt; pub mod ser; mod shared; @@ -71,7 +72,8 @@ impl SchemaSerializer { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (value, *, mode = None, include = None, exclude = None, by_alias = None, exclude_unset = false, exclude_defaults = false, exclude_none = false, exclude_computed_fields = false, - round_trip = false, warnings = WarningsArg::Bool(true), fallback = None, serialize_as_any = false, context = None))] + round_trip = false, warnings = WarningsArg::Bool(true), fallback = None, serialize_as_any = false, + polymorphic_serialization = None, context = None))] pub fn to_python( &self, py: Python, @@ -88,6 +90,7 @@ impl SchemaSerializer { warnings: WarningsArg, fallback: Option<&Bound<'_, PyAny>>, serialize_as_any: bool, + polymorphic_serialization: Option, context: Option<&Bound<'_, PyAny>>, ) -> PyResult> { let mode: SerMode = mode.into(); @@ -107,6 +110,7 @@ impl SchemaSerializer { false, fallback, serialize_as_any, + polymorphic_serialization, context, ); let mut state = SerializationState::new(self.config, warnings_mode, include, exclude, extra)?; @@ -118,7 +122,8 @@ impl SchemaSerializer { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (value, *, indent = None, ensure_ascii = false, include = None, exclude = None, by_alias = None, exclude_unset = false, exclude_defaults = false, exclude_none = false, exclude_computed_fields = false, - round_trip = false, warnings = WarningsArg::Bool(true), fallback = None, serialize_as_any = false, context = None))] + round_trip = false, warnings = WarningsArg::Bool(true), fallback = None, serialize_as_any = false, + polymorphic_serialization = None, context = None))] pub fn to_json( &self, py: Python, @@ -136,6 +141,7 @@ impl SchemaSerializer { warnings: WarningsArg, fallback: Option<&Bound<'_, PyAny>>, serialize_as_any: bool, + polymorphic_serialization: Option, context: Option<&Bound<'_, PyAny>>, ) -> PyResult> { let warnings_mode = match warnings { @@ -154,6 +160,7 @@ impl SchemaSerializer { false, fallback, serialize_as_any, + polymorphic_serialization, context, ); let mut state = SerializationState::new(self.config, warnings_mode, include, exclude, extra)?; @@ -201,7 +208,7 @@ impl SchemaSerializer { #[pyo3(signature = (value, *, indent = None, ensure_ascii = false, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, timedelta_mode = "iso8601", temporal_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None, - serialize_as_any = false, context = None))] + serialize_as_any = false, polymorphic_serialization = None, context = None))] pub fn to_json( py: Python, value: &Bound<'_, PyAny>, @@ -219,6 +226,7 @@ pub fn to_json( serialize_unknown: bool, fallback: Option<&Bound<'_, PyAny>>, serialize_as_any: bool, + polymorphic_serialization: Option, context: Option<&Bound<'_, PyAny>>, ) -> PyResult> { let config = SerializationConfig::from_args(timedelta_mode, temporal_mode, bytes_mode, inf_nan_mode)?; @@ -234,6 +242,7 @@ pub fn to_json( serialize_unknown, fallback, serialize_as_any, + polymorphic_serialization, context, ); let mut state = SerializationState::new(config, WarningsMode::None, include, exclude, extra)?; @@ -254,7 +263,7 @@ pub fn to_json( #[pyfunction] #[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, timedelta_mode = "iso8601", temporal_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", - serialize_unknown = false, fallback = None, serialize_as_any = false, context = None))] + serialize_unknown = false, fallback = None, serialize_as_any = false, polymorphic_serialization = None, context = None))] pub fn to_jsonable_python( py: Python, value: &Bound<'_, PyAny>, @@ -270,6 +279,7 @@ pub fn to_jsonable_python( serialize_unknown: bool, fallback: Option<&Bound<'_, PyAny>>, serialize_as_any: bool, + polymorphic_serialization: Option, context: Option<&Bound<'_, PyAny>>, ) -> PyResult> { let config = SerializationConfig::from_args(timedelta_mode, temporal_mode, bytes_mode, inf_nan_mode)?; @@ -285,6 +295,7 @@ pub fn to_jsonable_python( serialize_unknown, fallback, serialize_as_any, + polymorphic_serialization, context, ); let mut state = SerializationState::new(config, WarningsMode::None, include, exclude, extra)?; diff --git a/src/serializers/polymorphism_trampoline.rs b/src/serializers/polymorphism_trampoline.rs new file mode 100644 index 000000000..2112b0de8 --- /dev/null +++ b/src/serializers/polymorphism_trampoline.rs @@ -0,0 +1,96 @@ +use std::{borrow::Cow, sync::Arc}; + +use pyo3::{prelude::*, types::PyType}; + +use crate::serializers::{ + errors::unwrap_ser_error, + extra::SerCheck, + infer::call_pydantic_serializer, + shared::{serialize_to_json, serialize_to_python, DoSerialize, TypeSerializer}, + CombinedSerializer, SerializationState, +}; + +/// The polymorphism trampoline detects subclasses of its target type and dispatches to their +/// `__pydantic_serializer__` serializer for serialization. +/// +/// This exists as a separate structure to allow for cases such as model serializers where the +/// inner serializer may just be a function serializer and so cannot handle polymorphism itself. +#[derive(Debug)] +pub struct PolymorphismTrampoline { + class: Py, + /// Inner serializer used when the type is not a subclass (responsible for any fallback etc) + pub(crate) serializer: Arc, + /// Whether polymorphic serialization is enabled from config + enabled_from_config: bool, +} + +impl_py_gc_traverse!(PolymorphismTrampoline { class, serializer }); + +impl PolymorphismTrampoline { + pub fn new(class: Py, serializer: Arc, enabled_from_config: bool) -> Self { + Self { + class, + serializer, + enabled_from_config, + } + } + + fn is_subclass(&self, value: &Bound<'_, PyAny>) -> PyResult { + Ok(!value.get_type().is(&self.class) && value.is_instance(self.class.bind(value.py()))?) + } + + fn serialize<'py, T, E: From>( + &self, + value: &Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + do_serialize: impl DoSerialize<'py, T, E>, + ) -> Result { + let runtime_polymorphic = state.extra.polymorphic_serialization; + if state.check != SerCheck::Strict // strict disables polymorphism + && runtime_polymorphic.unwrap_or(self.enabled_from_config) + && self.is_subclass(value)? + { + call_pydantic_serializer(value, state, do_serialize) + } else { + do_serialize.serialize_no_infer(&self.serializer, value, state) + } + } +} + +impl TypeSerializer for PolymorphismTrampoline { + fn to_python<'py>( + &self, + value: &Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + ) -> PyResult> { + self.serialize(value, state, serialize_to_python()) + } + + fn json_key<'a, 'py>( + &self, + key: &'a Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + ) -> PyResult> { + // json key serialization for models and dataclasses was always polymorphic anyway + // FIXME: make this consistent with the other cases? + self.serializer.json_key(key, state) + } + + fn serde_serialize<'py, S: serde::ser::Serializer>( + &self, + value: &Bound<'py, PyAny>, + serializer: S, + state: &mut SerializationState<'_, 'py>, + ) -> Result { + self.serialize(value, state, serialize_to_json(serializer)) + .map_err(unwrap_ser_error) + } + + fn get_name(&self) -> &str { + self.serializer.get_name() + } + + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() + } +} diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs index 15ff8145a..88cc2ae24 100644 --- a/src/serializers/prebuilt.rs +++ b/src/serializers/prebuilt.rs @@ -3,9 +3,9 @@ use std::borrow::Cow; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::common::prebuilt::get_prebuilt; use crate::serializers::SerializationState; use crate::SchemaSerializer; +use crate::{common::prebuilt::get_prebuilt, serializers::polymorphism_trampoline::PolymorphismTrampoline}; use super::shared::{CombinedSerializer, TypeSerializer}; @@ -18,12 +18,24 @@ impl PrebuiltSerializer { pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult> { get_prebuilt(type_, schema, "__pydantic_serializer__", |py_any| { let schema_serializer = py_any.extract::>()?; - if matches!( - schema_serializer.get().serializer.as_ref(), - CombinedSerializer::FunctionWrap(_) - ) { + + let mut serializer = schema_serializer.get().serializer.as_ref(); + + // it is very likely that the prebuilt serializer is a polymorphism trampoline, peek + // through it for the sake of the check below + if let CombinedSerializer::PolymorphismTrampoline(PolymorphismTrampoline { + serializer: inner_serializer, + .. + }) = serializer + { + serializer = inner_serializer.as_ref(); + } + + // don't allow wrap serializers as prebuilt serializers (leads to double wrapping) + if matches!(serializer, CombinedSerializer::FunctionWrap(_)) { return Ok(None); } + Ok(Some(Self { schema_serializer }.into())) }) } diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 98e91098c..908ee2180 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -19,6 +19,7 @@ use crate::build_tools::py_schema_error_type; use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; use crate::serializers::errors::WrappedSerError; +use crate::serializers::polymorphism_trampoline::PolymorphismTrampoline; use crate::serializers::ser::PythonSerializer; use crate::serializers::type_serializers::any::AnySerializer; use crate::tools::{py_err, SchemaDict}; @@ -91,6 +92,9 @@ combined_serializer! { Fields: super::fields::GeneralFieldsSerializer; // prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum Prebuilt: super::prebuilt::PrebuiltSerializer; + // polymorphism trampoline is manually constructed to wrap models and dataclasses with + // polymorphic serialization + PolymorphismTrampoline: super::polymorphism_trampoline::PolymorphismTrampoline; } // `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer` // but aren't actually used for serialization, e.g. their `build` method must return another serializer @@ -162,7 +166,8 @@ impl CombinedSerializer { config: Option<&Bound<'_, PyDict>>, definitions: &mut DefinitionsBuilder>, ) -> PyResult> { - Self::_build(schema, config, definitions, false) + let serializer = Self::_build(schema, config, definitions, false)?; + Self::maybe_wrap_in_polymorphism_trampoline(serializer, schema) } fn _build( @@ -219,7 +224,7 @@ impl CombinedSerializer { let type_ = type_.to_str()?; if use_prebuilt { - // if we have a SchemaValidator on the type already, use it + // if we have a SchemaSerializer on the type already, use it if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) { @@ -230,6 +235,35 @@ impl CombinedSerializer { Self::find_serializer(type_, schema, config, definitions) } + fn maybe_wrap_in_polymorphism_trampoline( + serializer: Arc, + schema: &Bound<'_, PyDict>, + ) -> PyResult> { + let py = schema.py(); + let type_: Bound<'_, PyString> = schema.get_as_req(intern!(py, "type"))?; + let type_ = type_.to_str()?; + + if type_ == "model" || type_ == "dataclass" { + // Get polymorphic serialization from config + let config = schema.get_as::>(intern!(py, "config"))?; + let polymorphic_serialization: bool = config + .and_then(|cfg| cfg.get_as(intern!(py, "polymorphic_serialization")).transpose()) + .unwrap_or(Ok(false))?; + + // Unconditionally wrap in PolymorphismTrampoline, because runtime flag might still enable it + Ok(Arc::new( + PolymorphismTrampoline::new( + schema.get_as_req(intern!(py, "cls"))?, + serializer, + polymorphic_serialization, + ) + .into(), + )) + } else { + Ok(serializer) + } + } + /// Main recursive way to call serializers, supports possible recursive type inference by /// switching to type inference mode eagerly. pub fn to_python<'py>( @@ -308,7 +342,8 @@ impl BuildSerializer for CombinedSerializer { config: Option<&Bound<'_, PyDict>>, definitions: &mut DefinitionsBuilder>, ) -> PyResult> { - Self::_build(schema, config, definitions, true) + let serializer = Self::_build(schema, config, definitions, true)?; + Self::maybe_wrap_in_polymorphism_trampoline(serializer, schema) } } @@ -356,6 +391,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::PolymorphismTrampoline(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index 6bcb2fe6f..7f4e33f40 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -540,47 +540,40 @@ struct SerializationInfo { field_name: Option, #[pyo3(get)] serialize_as_any: bool, + #[pyo3(get)] + polymorphic_serialization: Option, } impl SerializationInfo { fn new(state: &SerializationState<'_, '_>, is_field_serializer: bool) -> PyResult { let extra = &state.extra; - if is_field_serializer { - match state.field_name.as_ref() { - Some(field_name) => Ok(Self { - include: state.include().map(|i| i.clone().unbind()), - exclude: state.exclude().map(|e| e.clone().unbind()), - context: extra.context.map(|c| c.clone().unbind()), - _mode: extra.mode.clone(), - by_alias: extra.by_alias, - exclude_unset: extra.exclude_unset, - exclude_defaults: extra.exclude_defaults, - exclude_none: extra.exclude_none, - exclude_computed_fields: extra.exclude_none, - round_trip: extra.round_trip, - field_name: Some(field_name.to_string()), - serialize_as_any: extra.serialize_as_any, - }), - _ => Err(PyRuntimeError::new_err( + + let field_name = if is_field_serializer { + let Some(field_name) = state.field_name.as_ref() else { + return Err(PyRuntimeError::new_err( "Model field context expected for field serialization info but no model field was found", - )), - } + )); + }; + Some(field_name.to_string()) } else { - Ok(Self { - include: state.include().map(|i| i.clone().unbind()), - exclude: state.exclude().map(|e| e.clone().unbind()), - context: extra.context.map(|c| c.clone().unbind()), - _mode: extra.mode.clone(), - by_alias: extra.by_alias, - exclude_unset: extra.exclude_unset, - exclude_defaults: extra.exclude_defaults, - exclude_none: extra.exclude_none, - exclude_computed_fields: extra.exclude_computed_fields, - round_trip: extra.round_trip, - field_name: None, - serialize_as_any: extra.serialize_as_any, - }) - } + None + }; + + Ok(Self { + include: state.include().map(|i| i.clone().unbind()), + exclude: state.exclude().map(|e| e.clone().unbind()), + context: extra.context.map(|c| c.clone().unbind()), + _mode: extra.mode.clone(), + by_alias: extra.by_alias, + exclude_unset: extra.exclude_unset, + exclude_defaults: extra.exclude_defaults, + exclude_none: extra.exclude_none, + exclude_computed_fields: extra.exclude_computed_fields, + round_trip: extra.round_trip, + field_name, + serialize_as_any: extra.serialize_as_any, + polymorphic_serialization: extra.polymorphic_serialization, + }) } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { diff --git a/tests/serializers/test_dataclasses.py b/tests/serializers/test_dataclasses.py index 4d8ef9ff6..e5eebf6aa 100644 --- a/tests/serializers/test_dataclasses.py +++ b/tests/serializers/test_dataclasses.py @@ -286,3 +286,119 @@ class Foo: ) s = SchemaSerializer(schema) assert s.to_python(Foo(my_field='hello'), by_alias=runtime) == expected + + +@pytest.mark.parametrize('config', [True, False, None]) +@pytest.mark.parametrize('runtime', [True, False, None]) +def test_polymorphic_serialization(config: bool, runtime: bool) -> None: + @dataclasses.dataclass + class ClassA: + a: int + + @dataclasses.dataclass + class ClassB(ClassA): + b: str + + model_config = core_schema.CoreConfig(polymorphic_serialization=config) if config is not None else None + + schema_a = core_schema.dataclass_schema( + ClassA, + core_schema.dataclass_args_schema( + 'ClassA', [core_schema.dataclass_field(name='a', schema=core_schema.int_schema())] + ), + ['a'], + config=model_config, + ) + + schema_b = core_schema.dataclass_schema( + ClassB, + core_schema.dataclass_args_schema( + 'ClassB', + [ + core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.str_schema()), + ], + ), + ['a', 'b'], + ) + + ClassA.__pydantic_serializer__ = SchemaSerializer(schema_a) + ClassB.__pydantic_serializer__ = SchemaSerializer(schema_b) + + kwargs = {} + if runtime is not None: + kwargs['polymorphic_serialization'] = runtime + + assert ClassA.__pydantic_serializer__.to_python(ClassA(123), **kwargs) == {'a': 123} + assert ClassA.__pydantic_serializer__.to_json(ClassA(123), **kwargs) == b'{"a":123}' + + polymorphism_enabled = runtime if runtime is not None else config + if polymorphism_enabled: + assert ClassA.__pydantic_serializer__.to_python(ClassB(123, 'test'), **kwargs) == {'a': 123, 'b': 'test'} + assert ClassA.__pydantic_serializer__.to_json(ClassB(123, 'test'), **kwargs) == b'{"a":123,"b":"test"}' + else: + assert ClassA.__pydantic_serializer__.to_python(ClassB(123, 'test'), **kwargs) == {'a': 123} + assert ClassA.__pydantic_serializer__.to_json(ClassB(123, 'test'), **kwargs) == b'{"a":123}' + + +@pytest.mark.parametrize('config', [True, False, None]) +@pytest.mark.parametrize('runtime', [True, False, None]) +def test_polymorphic_serialization_with_model_serializer(config: bool, runtime: bool) -> None: + @dataclasses.dataclass + class ClassA: + a: int + + def serialize(self, info: core_schema.SerializationInfo) -> str: + assert info.polymorphic_serialization is runtime + return 'ClassA' + + @dataclasses.dataclass + class ClassB(ClassA): + b: str + + def serialize(self, info: core_schema.SerializationInfo) -> str: + assert info.polymorphic_serialization is runtime + return 'ClassB' + + model_config = core_schema.CoreConfig(polymorphic_serialization=config) if config is not None else None + + schema_a = core_schema.dataclass_schema( + ClassA, + core_schema.dataclass_args_schema( + 'ClassA', [core_schema.dataclass_field(name='a', schema=core_schema.int_schema())] + ), + ['a'], + config=model_config, + serialization=core_schema.plain_serializer_function_ser_schema(ClassA.serialize, info_arg=True), + ) + + schema_b = core_schema.dataclass_schema( + ClassB, + core_schema.dataclass_args_schema( + 'ClassB', + [ + core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.str_schema()), + ], + ), + ['a', 'b'], + serialization=core_schema.plain_serializer_function_ser_schema(ClassB.serialize, info_arg=True), + ) + + ClassA.__pydantic_serializer__ = SchemaSerializer(schema_a) + ClassB.__pydantic_serializer__ = SchemaSerializer(schema_b) + + kwargs = {} + if runtime is not None: + kwargs['polymorphic_serialization'] = runtime + + assert ClassA.__pydantic_serializer__.to_python(ClassA(123), **kwargs) == 'ClassA' + assert ClassA.__pydantic_serializer__.to_json(ClassA(123), **kwargs) == b'"ClassA"' + + polymorphism_enabled = runtime if runtime is not None else config + if polymorphism_enabled: + assert ClassA.__pydantic_serializer__.to_python(ClassB(123, 'test'), **kwargs) == 'ClassB' + assert ClassA.__pydantic_serializer__.to_json(ClassB(123, 'test'), **kwargs) == b'"ClassB"' + else: + assert ClassA.__pydantic_serializer__.to_python(ClassB(123, 'test'), **kwargs) == 'ClassA' + assert ClassA.__pydantic_serializer__.to_json(ClassB(123, 'test'), **kwargs) == b'"ClassA"' diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index b162064e3..fa1479a97 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -1284,3 +1284,111 @@ def __init__(self, my_field: int) -> None: ) s = SchemaSerializer(schema) assert s.to_python(Model(1), by_alias=runtime) == expected + + +@pytest.mark.parametrize('config', [True, False, None]) +@pytest.mark.parametrize('runtime', [True, False, None]) +def test_polymorphic_serialization(config: bool, runtime: bool) -> None: + class ModelA: + def __init__(self, a: int) -> None: + self.a = a + + class ModelB(ModelA): + def __init__(self, a: int, b: str) -> None: + super().__init__(a) + self.b = b + + model_config = core_schema.CoreConfig(polymorphic_serialization=config) if config is not None else None + + schema_a = core_schema.model_schema( + ModelA, + core_schema.model_fields_schema({'a': core_schema.model_field(core_schema.int_schema())}), + config=model_config, + ) + + schema_b = core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + { + 'a': core_schema.model_field(core_schema.int_schema()), + 'b': core_schema.model_field(core_schema.str_schema()), + } + ), + ) + + ModelA.__pydantic_serializer__ = SchemaSerializer(schema_a) + ModelB.__pydantic_serializer__ = SchemaSerializer(schema_b) + + kwargs = {} + if runtime is not None: + kwargs['polymorphic_serialization'] = runtime + + assert ModelA.__pydantic_serializer__.to_python(ModelA(123), **kwargs) == {'a': 123} + assert ModelA.__pydantic_serializer__.to_json(ModelA(123), **kwargs) == b'{"a":123}' + + polymorphism_enabled = runtime if runtime is not None else config + if polymorphism_enabled: + assert ModelA.__pydantic_serializer__.to_python(ModelB(123, 'test'), **kwargs) == {'a': 123, 'b': 'test'} + assert ModelA.__pydantic_serializer__.to_json(ModelB(123, 'test'), **kwargs) == b'{"a":123,"b":"test"}' + else: + assert ModelA.__pydantic_serializer__.to_python(ModelB(123, 'test'), **kwargs) == {'a': 123} + assert ModelA.__pydantic_serializer__.to_json(ModelB(123, 'test'), **kwargs) == b'{"a":123}' + + +@pytest.mark.parametrize('config', [True, False, None]) +@pytest.mark.parametrize('runtime', [True, False, None]) +def test_polymorphic_serialization_with_model_serializer(config: bool, runtime: bool) -> None: + class ModelA: + def __init__(self, a: int) -> None: + self.a = a + + def serialize(value, info: core_schema.SerializationInfo) -> str: + assert info.polymorphic_serialization is runtime + return 'ModelA' + + class ModelB(ModelA): + def __init__(self, a: int, b: str) -> None: + super().__init__(a) + self.b = b + + def serialize(value, info: core_schema.SerializationInfo) -> str: + assert info.polymorphic_serialization is runtime + return 'ModelB' + + model_config = core_schema.CoreConfig(polymorphic_serialization=config) if config is not None else None + + schema_a = core_schema.model_schema( + ModelA, + core_schema.model_fields_schema({'a': core_schema.model_field(core_schema.int_schema())}), + config=model_config, + serialization=core_schema.plain_serializer_function_ser_schema(ModelA.serialize, info_arg=True), + ) + + schema_b = core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + { + 'a': core_schema.model_field(core_schema.int_schema()), + 'b': core_schema.model_field(core_schema.str_schema()), + } + ), + serialization=core_schema.plain_serializer_function_ser_schema(ModelB.serialize, info_arg=True), + ) + + ModelA.__pydantic_serializer__ = SchemaSerializer(schema_a) + ModelB.__pydantic_serializer__ = SchemaSerializer(schema_b) + + kwargs = {} + if runtime is not None: + kwargs['polymorphic_serialization'] = runtime + + assert ModelA.__pydantic_serializer__.to_python(ModelA(123), **kwargs) == 'ModelA' + assert ModelA.__pydantic_serializer__.to_json(ModelA(123), **kwargs) == b'"ModelA"' + + polymorphism_enabled = runtime if runtime is not None else config + if polymorphism_enabled: + assert ModelA.__pydantic_serializer__.to_python(ModelB(123, 'test'), **kwargs) == 'ModelB' + assert ModelA.__pydantic_serializer__.to_json(ModelB(123, 'test'), **kwargs) == b'"ModelB"' + else: + assert ModelA.__pydantic_serializer__.to_python(ModelB(123, 'test'), **kwargs) == 'ModelA' + assert ModelA.__pydantic_serializer__.to_json(ModelB(123, 'test'), **kwargs) == b'"ModelA"' diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 7cf11639a..3dcd28e63 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -100,6 +100,7 @@ class SubclassA(ModelA): @pytest.mark.parametrize('input_value', [ModelA(b'bite', 2.3456), SubclassA(b'bite', 2.3456)]) def test_model_a(model_serializer: SchemaSerializer, input_value): + print(model_serializer, input_value) assert model_serializer.to_python(input_value) == {'a': b'bite', 'b': '2.3'} assert model_serializer.to_python(input_value, mode='json') == {'a': 'bite', 'b': '2.3'} assert model_serializer.to_json(input_value) == b'{"a":"bite","b":"2.3"}' diff --git a/tests/test.rs b/tests/test.rs index 8a3f1d5dc..fa7efeab7 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -156,6 +156,7 @@ dump_json_input_2 = {'a': 'something'} None, false, None, + None, ) .unwrap(); let repr = format!("{}", serialization_result.bind(py).repr().unwrap()); @@ -179,6 +180,7 @@ dump_json_input_2 = {'a': 'something'} None, false, None, + None, ) .unwrap(); let repr = format!("{}", serialization_result.bind(py).repr().unwrap()); diff --git a/tests/test_prebuilt.py b/tests/test_prebuilt.py index c5a795b0f..c71694a7f 100644 --- a/tests/test_prebuilt.py +++ b/tests/test_prebuilt.py @@ -99,6 +99,9 @@ def __init__(self, inner: InnerModel) -> None: outer_serializer = SchemaSerializer(outer_schema) + print(inner_serializer) + print(outer_serializer) + # the custom serialization function does apply for the inner model inner_instance = InnerModel(x='hello') assert inner_serializer.to_python(inner_instance) == {'x': 'hello modified'}