Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/cryptography/hazmat/asn1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from cryptography.hazmat.asn1.asn1 import (
Default,
Explicit,
GeneralizedTime,
Implicit,
PrintableString,
UtcTime,
decode_der,
Expand All @@ -14,7 +16,9 @@

__all__ = [
"Default",
"Explicit",
"GeneralizedTime",
"Implicit",
"PrintableString",
"UtcTime",
"decode_der",
Expand Down
24 changes: 21 additions & 3 deletions src/cryptography/hazmat/asn1/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,30 @@ def _is_union(field_type: type) -> bool:
return get_type_origin(field_type) in union_types


def _extract_annotation(metadata: tuple) -> declarative_asn1.Annotation:
def _extract_annotation(
metadata: tuple, field_name: str
) -> declarative_asn1.Annotation:
default = None
encoding = None
for raw_annotation in metadata:
if isinstance(raw_annotation, Default):
if default is not None:
raise TypeError(
f"multiple DEFAULT annotations found in field "
f"'{field_name}'"
)
default = raw_annotation.value
elif isinstance(raw_annotation, declarative_asn1.Encoding):
if encoding is not None:
raise TypeError(
f"multiple IMPLICIT/EXPLICIT annotations found in field "
f"'{field_name}'"
)
encoding = raw_annotation
else:
raise TypeError(f"unsupported annotation: {raw_annotation}")

return declarative_asn1.Annotation(default=default)
return declarative_asn1.Annotation(default=default, encoding=encoding)


def _normalize_field_type(
Expand All @@ -75,7 +90,7 @@ def _normalize_field_type(
# Strip the `Annotated[...]` off, and populate the annotation
# from it if it exists.
if get_type_origin(field_type) is Annotated:
annotation = _extract_annotation(field_type.__metadata__)
annotation = _extract_annotation(field_type.__metadata__, field_name)
field_type, _ = get_type_args(field_type)
else:
annotation = declarative_asn1.Annotation()
Expand Down Expand Up @@ -194,6 +209,9 @@ class Default(typing.Generic[U]):
value: U


Explicit = declarative_asn1.Encoding.Explicit
Implicit = declarative_asn1.Encoding.Implicit

PrintableString = declarative_asn1.PrintableString
UtcTime = declarative_asn1.UtcTime
GeneralizedTime = declarative_asn1.GeneralizedTime
8 changes: 8 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@ class Type:

class Annotation:
default: typing.Any | None
encoding: Encoding | None
def __new__(
cls,
default: typing.Any | None = None,
encoding: Encoding | None = None,
) -> Annotation: ...
def is_empty(self) -> bool: ...

# Encoding is a Rust enum with tuple variants. For now, we express the type
# annotations like this:
class Encoding:
Implicit: typing.ClassVar[type]
Explicit: typing.ClassVar[type]

class AnnotatedType:
inner: Type
annotation: Annotation
Expand Down
69 changes: 50 additions & 19 deletions src/rust/src/declarative_asn1/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,41 @@ use pyo3::types::PyAnyMethods;

use crate::asn1::big_byte_slice_to_py_int;
use crate::declarative_asn1::types::{
type_to_tag, AnnotatedType, GeneralizedTime, PrintableString, Type, UtcTime,
type_to_tag, AnnotatedType, Encoding, GeneralizedTime, PrintableString, Type, UtcTime,
};
use crate::error::CryptographyError;

type ParseResult<T> = Result<T, CryptographyError>;

fn read_value<'a, T: asn1::SimpleAsn1Readable<'a>>(
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<T> {
let value = match encoding {
Some(e) => match e.get() {
Encoding::Implicit(n) => parser.read_implicit_element::<T>(*n),
Encoding::Explicit(n) => parser.read_explicit_element::<T>(*n),
},
None => parser.read_element::<T>(),
}?;
Ok(value)
}

fn decode_pybool<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyBool>> {
let value = parser.read_element::<bool>()?;
let value = read_value::<bool>(parser, encoding)?;
Ok(pyo3::types::PyBool::new(py, value).to_owned())
}

fn decode_pyint<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyInt>> {
let value = parser.read_element::<asn1::BigInt<'a>>()?;
let value = read_value::<asn1::BigInt<'a>>(parser, encoding)?;
let pyint =
big_byte_slice_to_py_int(py, value.as_bytes())?.cast_into::<pyo3::types::PyInt>()?;
Ok(pyint)
Expand All @@ -34,33 +50,37 @@ fn decode_pyint<'a>(
fn decode_pybytes<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyBytes>> {
let value = parser.read_element::<&[u8]>()?;
let value = read_value::<&[u8]>(parser, encoding)?;
Ok(pyo3::types::PyBytes::new(py, value))
}

fn decode_pystr<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyString>> {
let value = parser.read_element::<asn1::Utf8String<'a>>()?;
let value = read_value::<asn1::Utf8String<'a>>(parser, encoding)?;
Ok(pyo3::types::PyString::new(py, value.as_str()))
}

fn decode_printable_string<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, PrintableString>> {
let value = parser.read_element::<asn1::PrintableString<'a>>()?.as_str();
let value = read_value::<asn1::PrintableString<'a>>(parser, encoding)?.as_str();
let inner = pyo3::types::PyString::new(py, value).unbind();
Ok(pyo3::Bound::new(py, PrintableString { inner })?)
}

fn decode_utc_time<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, UtcTime>> {
let value = parser.read_element::<asn1::UtcTime>()?;
let value = read_value::<asn1::UtcTime>(parser, encoding)?;
let dt = value.as_datetime();

let inner = crate::x509::datetime_to_py_utc(py, dt)?
Expand All @@ -73,8 +93,9 @@ fn decode_utc_time<'a>(
fn decode_generalized_time<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
encoding: &Option<pyo3::Py<Encoding>>,
) -> ParseResult<pyo3::Bound<'a, GeneralizedTime>> {
let value = parser.read_element::<asn1::GeneralizedTime>()?;
let value = read_value::<asn1::GeneralizedTime>(parser, encoding)?;
let dt = value.as_datetime();

let microseconds = match value.nanoseconds() {
Expand Down Expand Up @@ -102,11 +123,12 @@ pub(crate) fn decode_annotated_type<'a>(
ann_type: &AnnotatedType,
) -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
let inner = ann_type.inner.get();
let encoding = &ann_type.annotation.get().encoding;

// Handle DEFAULT annotation if field is not present (by
// returning the default value)
if let Some(default) = &ann_type.annotation.get().default {
let expected_tag = type_to_tag(inner);
let expected_tag = type_to_tag(inner, encoding);
let next_tag = parser.peek_tag();
if next_tag != Some(expected_tag) {
return Ok(default.clone_ref(py).into_bound(py));
Expand All @@ -115,7 +137,7 @@ pub(crate) fn decode_annotated_type<'a>(

let decoded = match &inner {
Type::Sequence(cls, fields) => {
let seq_parse_result = parser.read_element::<asn1::Sequence<'_>>()?;
let seq_parse_result = read_value::<asn1::Sequence<'_>>(parser, encoding)?;

seq_parse_result.parse(|d| -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
let kwargs = pyo3::types::PyDict::new(py);
Expand All @@ -130,19 +152,28 @@ pub(crate) fn decode_annotated_type<'a>(
})?
}
Type::Option(cls) => {
let inner_tag = type_to_tag(cls.get().inner.get());
let inner_tag = type_to_tag(cls.get().inner.get(), encoding);
match parser.peek_tag() {
Some(t) if t == inner_tag => decode_annotated_type(py, parser, cls.get())?,
Some(t) if t == inner_tag => {
// For optional types, annotations will always be associated to the `Optional` type
// i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type.
// Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional`
let inner_ann_type = AnnotatedType {
inner: cls.get().inner.clone_ref(py),
annotation: ann_type.annotation.clone_ref(py),
};
decode_annotated_type(py, parser, &inner_ann_type)?
}
_ => pyo3::types::PyNone::get(py).to_owned().into_any(),
}
}
Type::PyBool() => decode_pybool(py, parser)?.into_any(),
Type::PyInt() => decode_pyint(py, parser)?.into_any(),
Type::PyBytes() => decode_pybytes(py, parser)?.into_any(),
Type::PyStr() => decode_pystr(py, parser)?.into_any(),
Type::PrintableString() => decode_printable_string(py, parser)?.into_any(),
Type::UtcTime() => decode_utc_time(py, parser)?.into_any(),
Type::GeneralizedTime() => decode_generalized_time(py, parser)?.into_any(),
Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(),
Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(),
Type::PyBytes() => decode_pybytes(py, parser, encoding)?.into_any(),
Type::PyStr() => decode_pystr(py, parser, encoding)?.into_any(),
Type::PrintableString() => decode_printable_string(py, parser, encoding)?.into_any(),
Type::UtcTime() => decode_utc_time(py, parser, encoding)?.into_any(),
Type::GeneralizedTime() => decode_generalized_time(py, parser, encoding)?.into_any(),
};

match &ann_type.annotation.get().default {
Expand Down
38 changes: 26 additions & 12 deletions src/rust/src/declarative_asn1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@ use asn1::{SimpleAsn1Writable, Writer};
use pyo3::types::PyAnyMethods;

use crate::declarative_asn1::types::{
AnnotatedType, AnnotatedTypeObject, GeneralizedTime, PrintableString, Type, UtcTime,
AnnotatedType, AnnotatedTypeObject, Encoding, GeneralizedTime, PrintableString, Type, UtcTime,
};

fn write_value<T: SimpleAsn1Writable>(
writer: &mut Writer<'_>,
value: &T,
encoding: &Option<pyo3::Py<Encoding>>,
) -> Result<(), asn1::WriteError> {
writer.write_element(value)
match encoding {
Some(e) => match e.get() {
Encoding::Implicit(tag) => writer.write_implicit_element(value, *tag),
Encoding::Explicit(tag) => writer.write_explicit_element(value, *tag),
},
None => writer.write_element(value),
}
}

impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
Expand All @@ -37,6 +44,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
}
}

let encoding = &annotated_type.annotation.get().encoding;
let inner = annotated_type.inner.get();
match &inner {
Type::Sequence(_cls, fields) => write_value(
Expand All @@ -60,14 +68,20 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
}
Ok(())
}),
encoding,
),
Type::Option(cls) => {
if !value.is_none() {
let object = AnnotatedTypeObject {
annotated_type: cls.get(),
let inner_object = AnnotatedTypeObject {
annotated_type: &AnnotatedType {
inner: cls.get().inner.clone_ref(py),
// Since for optional types the annotations are enforced to be associated with the Option
// (instead of the inner type), when encoding the inner type we add the annotations of the Option
annotation: annotated_type.annotation.clone_ref(py),
},
value,
};
object.write(writer)
inner_object.write(writer)
} else {
// Missing OPTIONAL values are omitted from DER encoding
Ok(())
Expand All @@ -77,26 +91,26 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
let val: bool = value
.extract()
.map_err(|_| asn1::WriteError::AllocationError)?;
write_value(writer, &val)
write_value(writer, &val, encoding)
}
Type::PyInt() => {
let val: i64 = value
.extract()
.map_err(|_| asn1::WriteError::AllocationError)?;
write_value(writer, &val)
write_value(writer, &val, encoding)
}
Type::PyBytes() => {
let val: &[u8] = value
.extract()
.map_err(|_| asn1::WriteError::AllocationError)?;
write_value(writer, &val)
write_value(writer, &val, encoding)
}
Type::PyStr() => {
let val: pyo3::pybacked::PyBackedStr = value
.extract()
.map_err(|_| asn1::WriteError::AllocationError)?;
let asn1_string: asn1::Utf8String<'_> = asn1::Utf8String::new(&val);
write_value(writer, &asn1_string)
write_value(writer, &asn1_string, encoding)
}
Type::PrintableString() => {
let val: &pyo3::Bound<'_, PrintableString> = value
Expand All @@ -110,7 +124,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
let printable_string: asn1::PrintableString<'_> =
asn1::PrintableString::new(&inner_str)
.ok_or(asn1::WriteError::AllocationError)?;
write_value(writer, &printable_string)
write_value(writer, &printable_string, encoding)
}
Type::UtcTime() => {
let val: &pyo3::Bound<'_, UtcTime> = value
Expand All @@ -121,7 +135,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
.map_err(|_| asn1::WriteError::AllocationError)?;
let utc_time =
asn1::UtcTime::new(datetime).map_err(|_| asn1::WriteError::AllocationError)?;
write_value(writer, &utc_time)
write_value(writer, &utc_time, encoding)
}
Type::GeneralizedTime() => {
let val: &pyo3::Bound<'_, GeneralizedTime> = value
Expand All @@ -134,7 +148,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
let nanoseconds = microseconds.map(|m| m * 1000);
let generalized_time = asn1::GeneralizedTime::new(datetime, nanoseconds)
.map_err(|_| asn1::WriteError::AllocationError)?;
write_value(writer, &generalized_time)
write_value(writer, &generalized_time, encoding)
}
}
}
Expand Down
Loading