Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ fn list_error_json(bench: &mut Bencher) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "list[int]");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 100);
Expand Down Expand Up @@ -184,7 +183,6 @@ fn list_error_python_input(py: Python<'_>) -> (SchemaValidator, PyObject) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "list[int]");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 100);
Expand Down Expand Up @@ -357,7 +355,6 @@ fn dict_value_error(bench: &mut Bencher) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "dict[str,constrained-int]");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 100);
Expand Down Expand Up @@ -484,7 +481,6 @@ fn typed_dict_deep_error(bench: &mut Bencher) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "typed-dict");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 1);
Expand Down
60 changes: 60 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Hashable, Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from fractions import Fraction
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Literal, Union

Expand Down Expand Up @@ -809,6 +810,61 @@ def decimal_schema(
serialization=serialization,
)

class FractionSchema(TypedDict, total=False):
type: Required[Literal['decimal']]
le: Decimal
ge: Decimal
lt: Decimal
gt: Decimal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think some decimal -> fraction replacement needed here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, too much copy + paste ;)

strict: bool
ref: str
metadata: dict[str, Any]
serialization: SerSchema

def fraction_schema(
*,
le: Fraction | None = None,
ge: Fraction | None = None,
lt: Fraction | None = None,
gt: Fraction | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: dict[str, Any] | None = None,
serialization: SerSchema | None = None,
) -> FractionSchema:
"""
Returns a schema that matches a fraction value, e.g.:

```py
from fractions import Fraction
from pydantic_core import SchemaValidator, core_schema

schema = core_schema.fraction_schema(le=0.8, ge=0.2)
v = SchemaValidator(schema)
assert v.validate_python(1, 2) == Fraction(1, 2)
```

Args:
le: The value must be less than or equal to this number
ge: The value must be greater than or equal to this number
lt: The value must be strictly less than this number
gt: The value must be strictly greater than this number
strict: Whether the value should be a float or a value that can be converted to a float
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='fraction',
gt=gt,
ge=ge,
lt=lt,
le=le,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)

class ComplexSchema(TypedDict, total=False):
type: Required[Literal['complex']]
Expand Down Expand Up @@ -4109,6 +4165,7 @@ def definition_reference_schema(
IntSchema,
FloatSchema,
DecimalSchema,
FractionSchema,
StringSchema,
BytesSchema,
DateSchema,
Expand Down Expand Up @@ -4168,6 +4225,7 @@ def definition_reference_schema(
'int',
'float',
'decimal',
'fraction',
'str',
'bytes',
'date',
Expand Down Expand Up @@ -4318,6 +4376,8 @@ def definition_reference_schema(
'uuid_version',
'decimal_type',
'decimal_parsing',
'fraction_type',
'fraction_parsing',
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
Expand Down
5 changes: 5 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ error_types! {
DecimalWholeDigits {
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
},
// Fraction errors
FractionType {},
FractionParsing {},
// Complex errors
ComplexType {},
ComplexStrParsing {},
Expand Down Expand Up @@ -579,6 +582,8 @@ impl ErrorType {
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
Self::FractionParsing {..} => "Fraction input should be a tuple of two integers, a string or a Fraction object",
Self::FractionType {..} => "Fraction input should be a tuple of two integers, or a string or Fraction object",
Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
}
Expand Down
2 changes: 2 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ pub trait Input<'py>: fmt::Debug {

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;

fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;

type Dict<'a>: ValidatedDict<'py>
where
Self: 'a;
Expand Down
14 changes: 14 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::fraction::create_fraction;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
Expand Down Expand Up @@ -199,6 +200,15 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
create_fraction(&self.into_pyobject(py)?, self).map(ValidationMatch::strict)
}
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
}

type Dict<'a>
= &'a JsonObject<'data>
where
Expand Down Expand Up @@ -454,6 +464,10 @@ impl<'py> Input<'py> for str {
create_decimal(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax)
}

fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
create_fraction(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax)
}

type Dict<'a> = Never;

#[cfg_attr(has_coverage_attribute, coverage(off))]
Expand Down
47 changes: 30 additions & 17 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::str::from_utf8;
use pyo3::intern;
use pyo3::prelude::*;

use pyo3::sync::PyOnceLock;
use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
Expand All @@ -18,6 +17,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError,
use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::fraction::{create_fraction, get_fraction_type};
use crate::validators::Exactness;
use crate::validators::TemporalUnitMode;
use crate::validators::ValBytesMode;
Expand Down Expand Up @@ -48,20 +48,6 @@ use super::{
Input,
};

static FRACTION_TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();

pub fn get_fraction_type(py: Python<'_>) -> &Bound<'_, PyType> {
FRACTION_TYPE
.get_or_init(py, || {
py.import("fractions")
.and_then(|fractions_module| fractions_module.getattr("Fraction"))
.unwrap()
.extract()
.unwrap()
})
.bind(py)
}

pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> {
input.as_python().and_then(|any| any.downcast::<T>().ok())
}
Expand Down Expand Up @@ -290,8 +276,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
float_as_int(self, self.extract::<f64>()?)
} else if let Ok(decimal) = self.validate_decimal(true, self.py()) {
decimal_as_int(self, &decimal.into_inner())
} else if self.is_instance(get_fraction_type(self.py()))? {
fraction_as_int(self)
} else if let Ok(fraction) = self.validate_fraction(true, self.py()) {
fraction_as_int(self, &fraction.into_inner())
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else if let Some(enum_val) = maybe_as_enum(self) {
Expand Down Expand Up @@ -349,6 +335,33 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
}

fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
let fraction_type = get_fraction_type(py);

// Fast path for existing decimal objects
if self.is_exact_instance(fraction_type) {
return Ok(ValidationMatch::exact(self.to_owned().clone()));
}

if !strict && self.is_instance_of::<PyString>() {
return create_fraction(self, self).map(ValidationMatch::lax);
}

let error_type = if strict {
ErrorType::IsInstanceOf {
class: fraction_type
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "Fraction".to_owned()),
context: None,
}
} else {
ErrorTypeDefaults::FractionType
};

Err(ValError::new(error_type, self))
}

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
let decimal_type = get_decimal_type(py);

Expand Down
8 changes: 8 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::fraction::create_fraction;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
Expand Down Expand Up @@ -154,6 +155,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn validate_fraction(&self, _strict: bool, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
Self::String(s) => create_fraction(s, self).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
}

type Dict<'a>
= StringMappingDict<'py>
where
Expand Down
30 changes: 13 additions & 17 deletions src/input/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,22 +228,18 @@ pub fn decimal_as_int<'py>(
Ok(EitherInt::Py(numerator))
}

pub fn fraction_as_int<'py>(input: &Bound<'py, PyAny>) -> ValResult<EitherInt<'py>> {
#[cfg(Py_3_12)]
let is_integer = input.call_method0("is_integer")?.extract::<bool>()?;
#[cfg(not(Py_3_12))]
let is_integer = input.getattr("denominator")?.extract::<i64>().map_or(false, |d| d == 1);

if is_integer {
#[cfg(Py_3_11)]
let as_int = input.call_method0("__int__");
#[cfg(not(Py_3_11))]
let as_int = input.call_method0("__trunc__");
match as_int {
Ok(i) => Ok(EitherInt::Py(i.as_any().to_owned())),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntType, input)),
}
} else {
Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input))
pub fn fraction_as_int<'py>(
input: &(impl Input<'py> + ?Sized),
fraction: &Bound<'py, PyAny>,
) -> ValResult<EitherInt<'py>> {
let py = fraction.py();

// as_integer_ratio was added in Python 3.8, so this should be fine
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// as_integer_ratio was added in Python 3.8, so this should be fine

I don't think it's necessary to mention addition of such methods, as we don't support 3.8 anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

let (numerator, denominator) = fraction
.call_method0(intern!(py, "as_integer_ratio"))?
.extract::<(Bound<'_, PyAny>, Bound<'_, PyAny>)>()?;
if denominator.extract::<i64>().map_or(true, |d| d != 1) {
return Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input));
}
Ok(EitherInt::Py(numerator))
}
3 changes: 3 additions & 0 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ pub(crate) fn infer_to_python_known(
v.into_py_any(py)?
}
ObType::Decimal => value.to_string().into_py_any(py)?,
ObType::Fraction => value.to_string().into_py_any(py)?,
ObType::StrSubclass => PyString::new(py, value.downcast::<PyString>()?.to_str()?).into(),
ObType::Bytes => extra
.config
Expand Down Expand Up @@ -431,6 +432,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
}
ObType::Decimal => value.to_string().serialize(serializer),
ObType::Fraction => value.to_string().serialize(serializer),
ObType::Str | ObType::StrSubclass => {
let py_str = value.downcast::<PyString>().map_err(py_err_se_err)?;
super::type_serializers::string::serialize_py_str(py_str, serializer)
Expand Down Expand Up @@ -613,6 +615,7 @@ pub(crate) fn infer_json_key_known<'a>(
}
}
ObType::Decimal => Ok(Cow::Owned(key.to_string())),
ObType::Fraction => Ok(Cow::Owned(key.to_string())),
ObType::Bool => super::type_serializers::simple::bool_json_key(key),
ObType::Str | ObType::StrSubclass => key.downcast::<PyString>()?.to_cow(),
ObType::Bytes => extra
Expand Down
Loading