Skip to content

Commit f003a2a

Browse files
FrancoGiachettagabrielbosioJulianGCalderon
authored
Fix downcast for bounded ints (#1474)
* fix trivial case * fix clippy * fix tests * add more tests with bounded_ints * add more tests with felts * add documentation to the libfunc * fmt * simplify implementation * remove unwanted changes * correct if condition * bring back value width comparison * roll back change * reviews * change tests * change downcast_felt test * better comment Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> * increment range check in trivial case * fmt * format2 * Unify comment --------- Co-authored-by: Gabriel Bosio <38794644+gabrielbosio@users.noreply.github.com> Co-authored-by: Julián González Calderón <gonzalezcalderonjulian@gmail.com>
1 parent 2241ffb commit f003a2a

File tree

1 file changed

+142
-25
lines changed

1 file changed

+142
-25
lines changed

src/libfuncs/cast.rs

Lines changed: 142 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use super::LibfuncHelper;
44
use crate::{
55
error::Result,
6+
libfuncs::increment_builtin_counter,
67
metadata::MetadataStorage,
78
native_assert, native_panic,
89
types::TypeBuilder,
@@ -46,7 +47,16 @@ pub fn build<'ctx, 'this>(
4647
}
4748
}
4849

49-
/// Generate MLIR operations for the `downcast` libfunc.
50+
/// Generate MLIR operations for the `downcast` libfunc which converts from a
51+
/// source type `T` to a target type `U`, where `U` might not fully include `T`.
52+
/// This means that the operation can fail.
53+
///
54+
/// ## Signature
55+
/// ```cairo
56+
/// pub extern const fn downcast<FromType, ToType>(
57+
/// x: FromType,
58+
/// ) -> Option<ToType> implicits(RangeCheck) nopanic;
59+
/// ```
5060
pub fn build_downcast<'ctx, 'this>(
5161
context: &'ctx Context,
5262
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
@@ -59,18 +69,6 @@ pub fn build_downcast<'ctx, 'this>(
5969
let range_check = entry.arg(0)?;
6070
let src_value: Value = entry.arg(1)?;
6171

62-
if info.signature.param_signatures[1].ty == info.signature.branch_signatures[0].vars[1].ty {
63-
let k0 = entry.const_int(context, location, 0, 1)?;
64-
return helper.cond_br(
65-
context,
66-
entry,
67-
k0,
68-
[0, 1],
69-
[&[range_check, src_value], &[range_check]],
70-
location,
71-
);
72-
}
73-
7472
let src_ty = registry.get_type(&info.signature.param_signatures[1].ty)?;
7573
let dst_ty = registry.get_type(&info.signature.branch_signatures[0].vars[1].ty)?;
7674

@@ -91,6 +89,28 @@ pub fn build_downcast<'ctx, 'this>(
9189
src_ty.integer_range(registry)?
9290
};
9391

92+
// When the source type is the same as the target type, we just return the
93+
// value as it cannot fail. However, for backwards compatibility, we need to
94+
// increment the range check as if we were checking the upper bound. See:
95+
// - https://github.com/starkware-libs/cairo/tree/v2.12.3/crates/cairo-lang-sierra/src/extensions/modules/casts.rs#L67.
96+
// - https://github.com/starkware-libs/cairo/tree/v2.12.3/crates/cairo-lang-sierra-to-casm/src/invocations/casts.rs#L56.
97+
if info.signature.param_signatures[1].ty == info.signature.branch_signatures[0].vars[1].ty {
98+
let range_check = if src_range.lower == 0.into() {
99+
increment_builtin_counter(context, entry, location, range_check)?
100+
} else {
101+
range_check
102+
};
103+
let k1 = entry.const_int(context, location, 1, 1)?;
104+
return helper.cond_br(
105+
context,
106+
entry,
107+
k1,
108+
[0, 1],
109+
[&[range_check, src_value], &[range_check]],
110+
location,
111+
);
112+
}
113+
94114
let src_width = if src_ty.is_bounded_int(registry)? {
95115
src_range.offset_bit_width()
96116
} else {
@@ -108,6 +128,7 @@ pub fn build_downcast<'ctx, 'this>(
108128

109129
let is_signed = src_range.lower.sign() == Sign::Minus;
110130

131+
// If the target type is wider than the source type, extend the value representation width.
111132
let src_value = if compute_width > src_width {
112133
if is_signed && !src_ty.is_bounded_int(registry)? && !src_ty.is_felt252(registry)? {
113134
entry.extsi(
@@ -126,6 +147,11 @@ pub fn build_downcast<'ctx, 'this>(
126147
src_value
127148
};
128149

150+
// Correct the value representation accordingly.
151+
// 1. if it is a felt, then we need to convert the value from [0,P) to
152+
// [-P/2, P/2].
153+
// 2. if it is a bounded_int, we need to offset the value to get the
154+
// actual value.
129155
let src_value = if is_signed && src_ty.is_felt252(registry)? {
130156
if src_range.upper.is_one() {
131157
let adj_offset =
@@ -159,7 +185,10 @@ pub fn build_downcast<'ctx, 'this>(
159185
src_value
160186
};
161187

162-
if !(dst_range.lower > src_range.lower || dst_range.upper < src_range.upper) {
188+
// Check if the source type is included in the target type. If it is not
189+
// then check if the value is in bounds. If the value is also not in
190+
// bounds then return an error.
191+
if dst_range.lower <= src_range.lower && dst_range.upper >= src_range.upper {
163192
let dst_value = if dst_ty.is_bounded_int(registry)? && dst_range.lower != BigInt::ZERO {
164193
let dst_offset = entry.const_int_from_type(
165194
context,
@@ -193,6 +222,7 @@ pub fn build_downcast<'ctx, 'this>(
193222
location,
194223
)?;
195224
} else {
225+
// Check if the value is in bounds with respect to the lower bound.
196226
let lower_check = if dst_range.lower > src_range.lower {
197227
let dst_lower = entry.const_int_from_type(
198228
context,
@@ -214,6 +244,7 @@ pub fn build_downcast<'ctx, 'this>(
214244
} else {
215245
None
216246
};
247+
// Check if the value is in bounds with respect to the upper bound.
217248
let upper_check = if dst_range.upper < src_range.upper {
218249
let dst_upper = entry.const_int_from_type(
219250
context,
@@ -456,10 +487,12 @@ pub fn build_upcast<'ctx, 'this>(
456487

457488
#[cfg(test)]
458489
mod test {
459-
use crate::{jit_enum, jit_struct, load_cairo, utils::testing::run_program_assert_output};
490+
use crate::{
491+
jit_enum, jit_struct, load_cairo, utils::testing::run_program_assert_output, Value,
492+
};
460493
use cairo_lang_sierra::program::Program;
461-
use cairo_vm::Felt252;
462494
use lazy_static::lazy_static;
495+
use starknet_types_core::felt::Felt;
463496
use test_case::test_case;
464497

465498
lazy_static! {
@@ -484,6 +517,55 @@ mod test {
484517
)
485518
}
486519
};
520+
static ref DOWNCAST_BOUNDED_INT: (String, Program) = load_cairo! {
521+
#[feature("bounded-int-utils")]
522+
use core::internal::bounded_int::BoundedInt;
523+
524+
extern const fn downcast<FromType, ToType>( x: FromType, ) -> Option<ToType> implicits(RangeCheck) nopanic;
525+
526+
fn test_x_y<
527+
X,
528+
Y,
529+
+TryInto<felt252, X>,
530+
+Into<Y, felt252>
531+
>(v: felt252) -> felt252 {
532+
let v: X = v.try_into().unwrap();
533+
let v: Y = downcast(v).unwrap();
534+
v.into()
535+
}
536+
537+
fn b0x30_b0x30(v: felt252) -> felt252 { test_x_y::<BoundedInt<0,30>, BoundedInt<0,30>>(v) }
538+
fn bm31x30_b31x30(v: felt252) -> felt252 { test_x_y::<BoundedInt<-31,30>, BoundedInt<-31,30>>(v) }
539+
fn bm31x30_bm5x30(v: felt252) -> felt252 { test_x_y::<BoundedInt<-31,30>, BoundedInt<-5,30>>(v) }
540+
fn bm31x30_b5x30(v: felt252) -> felt252 { test_x_y::<BoundedInt<-31,30>, BoundedInt<5,30>>(v) }
541+
fn b5x30_b31x31(v: felt252) -> felt252 { test_x_y::<BoundedInt<5,31>, BoundedInt<31,31>>(v) }
542+
fn bm100x100_bm100xm1(v: felt252) -> felt252 { test_x_y::<BoundedInt<-100,100>, BoundedInt<-100,-1>>(v) }
543+
fn bm31xm31_bm31xm31(v: felt252) -> felt252 { test_x_y::<BoundedInt<-31,-31>, BoundedInt<-31,-31>>(v) }
544+
// Check if the target type is wider than the source type
545+
fn b0x30_b5x40(v: felt252) -> felt252 { test_x_y::<BoundedInt<0,30>, BoundedInt<5,40>>(v) }
546+
// Check if the source's lower and upper bound are included in the
547+
// target type.
548+
fn b0x30_bm40x40(v: felt252) -> felt252 { test_x_y::<BoundedInt<0,30>, BoundedInt<-40,40>>(v) }
549+
};
550+
static ref DOWNCAST_FELT: (String, Program) = load_cairo! {
551+
extern const fn downcast<FromType, ToType>( x: FromType, ) -> Option<ToType> implicits(RangeCheck) nopanic;
552+
553+
fn test_x_y<
554+
X,
555+
Y,
556+
+TryInto<felt252, X>,
557+
+Into<Y, felt252>
558+
>(v: felt252) -> felt252 {
559+
let v: X = v.try_into().unwrap();
560+
let v: Y = downcast(v).unwrap();
561+
v.into()
562+
}
563+
564+
fn felt252_i8(v: felt252) -> felt252 { test_x_y::<felt252, i8>(v) }
565+
fn felt252_i16(v: felt252) -> felt252 { test_x_y::<felt252, i16>(v) }
566+
fn felt252_i32(v: felt252) -> felt252 { test_x_y::<felt252, i32>(v) }
567+
fn felt252_i64(v: felt252) -> felt252 { test_x_y::<felt252, i64>(v) }
568+
};
487569
}
488570

489571
#[test]
@@ -504,25 +586,60 @@ mod test {
504586
jit_enum!(1, jit_struct!()),
505587
jit_enum!(1, jit_struct!()),
506588
jit_enum!(1, jit_struct!()),
507-
jit_enum!(1, jit_struct!()),
589+
jit_enum!(0, u8::MAX.into()),
508590
),
509591
jit_struct!(
510592
jit_enum!(1, jit_struct!()),
511593
jit_enum!(1, jit_struct!()),
512594
jit_enum!(1, jit_struct!()),
513-
jit_enum!(1, jit_struct!()),
595+
jit_enum!(0, u16::MAX.into()),
514596
),
515597
jit_struct!(
516598
jit_enum!(1, jit_struct!()),
517599
jit_enum!(1, jit_struct!()),
518-
jit_enum!(1, jit_struct!()),
600+
jit_enum!(0, u32::MAX.into()),
519601
),
520-
jit_struct!(jit_enum!(1, jit_struct!()), jit_enum!(1, jit_struct!())),
521-
jit_struct!(jit_enum!(1, jit_struct!())),
602+
jit_struct!(jit_enum!(1, jit_struct!()), jit_enum!(0, u64::MAX.into())),
603+
jit_struct!(jit_enum!(0, u128::MAX.into())),
522604
),
523605
);
524606
}
525607

608+
#[test_case("b0x30_b0x30", 5.into())]
609+
#[test_case("bm31x30_b31x30", 5.into())]
610+
#[test_case("bm31x30_bm5x30", (-5).into())]
611+
#[test_case("bm31x30_b5x30", 30.into())]
612+
#[test_case("b5x30_b31x31", 31.into())]
613+
#[test_case("bm100x100_bm100xm1", (-90).into())]
614+
#[test_case("bm31xm31_bm31xm31", (-31).into())]
615+
#[test_case("b0x30_b5x40", 10.into())]
616+
#[test_case("b0x30_bm40x40", 10.into())]
617+
fn downcast_bounded_int(entry_point: &str, value: Felt) {
618+
run_program_assert_output(
619+
&DOWNCAST_BOUNDED_INT,
620+
entry_point,
621+
&[Value::Felt252(value)],
622+
jit_enum!(0, jit_struct!(Value::Felt252(value))),
623+
);
624+
}
625+
626+
#[test_case("felt252_i8", i8::MAX.into())]
627+
#[test_case("felt252_i8", i8::MIN.into())]
628+
#[test_case("felt252_i16", i16::MAX.into())]
629+
#[test_case("felt252_i16", i16::MIN.into())]
630+
#[test_case("felt252_i32", i32::MAX.into())]
631+
#[test_case("felt252_i32", i32::MIN.into())]
632+
#[test_case("felt252_i64", i64::MAX.into())]
633+
#[test_case("felt252_i64", i64::MIN.into())]
634+
fn downcast_felt(entry_point: &str, value: Felt) {
635+
run_program_assert_output(
636+
&DOWNCAST_FELT,
637+
entry_point,
638+
&[Value::Felt252(value)],
639+
jit_enum!(0, jit_struct!(Value::Felt252(value))),
640+
);
641+
}
642+
526643
lazy_static! {
527644
static ref TEST_UPCAST_PROGRAM: (String, Program) = load_cairo! {
528645
#[feature("bounded-int-utils")]
@@ -669,13 +786,13 @@ mod test {
669786
#[test_case("b2x5_b1x10", 5.into())]
670787
#[test_case("b0x5_bm10x10", 0.into())]
671788
#[test_case("b0x5_bm10x10", 5.into())]
672-
#[test_case("bm5x5_bm10x10", Felt252::from(-5))]
789+
#[test_case("bm5x5_bm10x10", Felt::from(-5))]
673790
#[test_case("bm5x5_bm10x10", 5.into())]
674-
#[test_case("i8_bm200x200", Felt252::from(-128))]
791+
#[test_case("i8_bm200x200", Felt::from(-128))]
675792
#[test_case("i8_bm200x200", 127.into())]
676-
#[test_case("bm100x100_i8", Felt252::from(-100))]
793+
#[test_case("bm100x100_i8", Felt::from(-100))]
677794
#[test_case("bm100x100_i8", 100.into())]
678-
fn upcast(entry_point: &str, value: Felt252) {
795+
fn upcast(entry_point: &str, value: Felt) {
679796
let arguments = &[value.into()];
680797
let expected_result = jit_enum!(0, jit_struct!(value.into(),));
681798
run_program_assert_output(

0 commit comments

Comments
 (0)