33use super :: LibfuncHelper ;
44use crate :: {
55 error:: Result ,
6+ libfuncs:: increment_builtin_counter,
67 metadata:: MetadataStorage ,
78 native_assert, native_panic,
89 types:: TypeBuilder ,
@@ -46,7 +47,16 @@ pub fn build<'ctx, 'this>(
4647 }
4748}
4849
49- /// Generate MLIR operations for the `downcast` libfunc.
50+ /// Generate MLIR operations for the `downcast` libfunc which converts from a
51+ /// source type `T` to a target type `U`, where `U` might not fully include `T`.
52+ /// This means that the operation can fail.
53+ ///
54+ /// ## Signature
55+ /// ```cairo
56+ /// pub extern const fn downcast<FromType, ToType>(
57+ /// x: FromType,
58+ /// ) -> Option<ToType> implicits(RangeCheck) nopanic;
59+ /// ```
5060pub fn build_downcast < ' ctx , ' this > (
5161 context : & ' ctx Context ,
5262 registry : & ProgramRegistry < CoreType , CoreLibfunc > ,
@@ -59,18 +69,6 @@ pub fn build_downcast<'ctx, 'this>(
5969 let range_check = entry. arg ( 0 ) ?;
6070 let src_value: Value = entry. arg ( 1 ) ?;
6171
62- if info. signature . param_signatures [ 1 ] . ty == info. signature . branch_signatures [ 0 ] . vars [ 1 ] . ty {
63- let k0 = entry. const_int ( context, location, 0 , 1 ) ?;
64- return helper. cond_br (
65- context,
66- entry,
67- k0,
68- [ 0 , 1 ] ,
69- [ & [ range_check, src_value] , & [ range_check] ] ,
70- location,
71- ) ;
72- }
73-
7472 let src_ty = registry. get_type ( & info. signature . param_signatures [ 1 ] . ty ) ?;
7573 let dst_ty = registry. get_type ( & info. signature . branch_signatures [ 0 ] . vars [ 1 ] . ty ) ?;
7674
@@ -91,6 +89,28 @@ pub fn build_downcast<'ctx, 'this>(
9189 src_ty. integer_range ( registry) ?
9290 } ;
9391
92+ // When the source type is the same as the target type, we just return the
93+ // value as it cannot fail. However, for backwards compatibility, we need to
94+ // increment the range check as if we were checking the upper bound. See:
95+ // - https://github.com/starkware-libs/cairo/tree/v2.12.3/crates/cairo-lang-sierra/src/extensions/modules/casts.rs#L67.
96+ // - https://github.com/starkware-libs/cairo/tree/v2.12.3/crates/cairo-lang-sierra-to-casm/src/invocations/casts.rs#L56.
97+ if info. signature . param_signatures [ 1 ] . ty == info. signature . branch_signatures [ 0 ] . vars [ 1 ] . ty {
98+ let range_check = if src_range. lower == 0 . into ( ) {
99+ increment_builtin_counter ( context, entry, location, range_check) ?
100+ } else {
101+ range_check
102+ } ;
103+ let k1 = entry. const_int ( context, location, 1 , 1 ) ?;
104+ return helper. cond_br (
105+ context,
106+ entry,
107+ k1,
108+ [ 0 , 1 ] ,
109+ [ & [ range_check, src_value] , & [ range_check] ] ,
110+ location,
111+ ) ;
112+ }
113+
94114 let src_width = if src_ty. is_bounded_int ( registry) ? {
95115 src_range. offset_bit_width ( )
96116 } else {
@@ -108,6 +128,7 @@ pub fn build_downcast<'ctx, 'this>(
108128
109129 let is_signed = src_range. lower . sign ( ) == Sign :: Minus ;
110130
131+ // If the target type is wider than the source type, extend the value representation width.
111132 let src_value = if compute_width > src_width {
112133 if is_signed && !src_ty. is_bounded_int ( registry) ? && !src_ty. is_felt252 ( registry) ? {
113134 entry. extsi (
@@ -126,6 +147,11 @@ pub fn build_downcast<'ctx, 'this>(
126147 src_value
127148 } ;
128149
150+ // Correct the value representation accordingly.
151+ // 1. if it is a felt, then we need to convert the value from [0,P) to
152+ // [-P/2, P/2].
153+ // 2. if it is a bounded_int, we need to offset the value to get the
154+ // actual value.
129155 let src_value = if is_signed && src_ty. is_felt252 ( registry) ? {
130156 if src_range. upper . is_one ( ) {
131157 let adj_offset =
@@ -159,7 +185,10 @@ pub fn build_downcast<'ctx, 'this>(
159185 src_value
160186 } ;
161187
162- if !( dst_range. lower > src_range. lower || dst_range. upper < src_range. upper ) {
188+ // Check if the source type is included in the target type. If it is not
189+ // then check if the value is in bounds. If the value is also not in
190+ // bounds then return an error.
191+ if dst_range. lower <= src_range. lower && dst_range. upper >= src_range. upper {
163192 let dst_value = if dst_ty. is_bounded_int ( registry) ? && dst_range. lower != BigInt :: ZERO {
164193 let dst_offset = entry. const_int_from_type (
165194 context,
@@ -193,6 +222,7 @@ pub fn build_downcast<'ctx, 'this>(
193222 location,
194223 ) ?;
195224 } else {
225+ // Check if the value is in bounds with respect to the lower bound.
196226 let lower_check = if dst_range. lower > src_range. lower {
197227 let dst_lower = entry. const_int_from_type (
198228 context,
@@ -214,6 +244,7 @@ pub fn build_downcast<'ctx, 'this>(
214244 } else {
215245 None
216246 } ;
247+ // Check if the value is in bounds with respect to the upper bound.
217248 let upper_check = if dst_range. upper < src_range. upper {
218249 let dst_upper = entry. const_int_from_type (
219250 context,
@@ -456,10 +487,12 @@ pub fn build_upcast<'ctx, 'this>(
456487
457488#[ cfg( test) ]
458489mod test {
459- use crate :: { jit_enum, jit_struct, load_cairo, utils:: testing:: run_program_assert_output} ;
490+ use crate :: {
491+ jit_enum, jit_struct, load_cairo, utils:: testing:: run_program_assert_output, Value ,
492+ } ;
460493 use cairo_lang_sierra:: program:: Program ;
461- use cairo_vm:: Felt252 ;
462494 use lazy_static:: lazy_static;
495+ use starknet_types_core:: felt:: Felt ;
463496 use test_case:: test_case;
464497
465498 lazy_static ! {
@@ -484,6 +517,55 @@ mod test {
484517 )
485518 }
486519 } ;
520+ static ref DOWNCAST_BOUNDED_INT : ( String , Program ) = load_cairo! {
521+ #[ feature( "bounded-int-utils" ) ]
522+ use core:: internal:: bounded_int:: BoundedInt ;
523+
524+ extern const fn downcast<FromType , ToType >( x: FromType , ) -> Option <ToType > implicits( RangeCheck ) nopanic;
525+
526+ fn test_x_y<
527+ X ,
528+ Y ,
529+ +TryInto <felt252, X >,
530+ +Into <Y , felt252>
531+ >( v: felt252) -> felt252 {
532+ let v: X = v. try_into( ) . unwrap( ) ;
533+ let v: Y = downcast( v) . unwrap( ) ;
534+ v. into( )
535+ }
536+
537+ fn b0x30_b0x30( v: felt252) -> felt252 { test_x_y:: <BoundedInt <0 , 30 >, BoundedInt <0 , 30 >>( v) }
538+ fn bm31x30_b31x30( v: felt252) -> felt252 { test_x_y:: <BoundedInt <-31 , 30 >, BoundedInt <-31 , 30 >>( v) }
539+ fn bm31x30_bm5x30( v: felt252) -> felt252 { test_x_y:: <BoundedInt <-31 , 30 >, BoundedInt <-5 , 30 >>( v) }
540+ fn bm31x30_b5x30( v: felt252) -> felt252 { test_x_y:: <BoundedInt <-31 , 30 >, BoundedInt <5 , 30 >>( v) }
541+ fn b5x30_b31x31( v: felt252) -> felt252 { test_x_y:: <BoundedInt <5 , 31 >, BoundedInt <31 , 31 >>( v) }
542+ fn bm100x100_bm100xm1( v: felt252) -> felt252 { test_x_y:: <BoundedInt <-100 , 100 >, BoundedInt <-100 , -1 >>( v) }
543+ fn bm31xm31_bm31xm31( v: felt252) -> felt252 { test_x_y:: <BoundedInt <-31 , -31 >, BoundedInt <-31 , -31 >>( v) }
544+ // Check if the target type is wider than the source type
545+ fn b0x30_b5x40( v: felt252) -> felt252 { test_x_y:: <BoundedInt <0 , 30 >, BoundedInt <5 , 40 >>( v) }
546+ // Check if the source's lower and upper bound are included in the
547+ // target type.
548+ fn b0x30_bm40x40( v: felt252) -> felt252 { test_x_y:: <BoundedInt <0 , 30 >, BoundedInt <-40 , 40 >>( v) }
549+ } ;
550+ static ref DOWNCAST_FELT : ( String , Program ) = load_cairo! {
551+ extern const fn downcast<FromType , ToType >( x: FromType , ) -> Option <ToType > implicits( RangeCheck ) nopanic;
552+
553+ fn test_x_y<
554+ X ,
555+ Y ,
556+ +TryInto <felt252, X >,
557+ +Into <Y , felt252>
558+ >( v: felt252) -> felt252 {
559+ let v: X = v. try_into( ) . unwrap( ) ;
560+ let v: Y = downcast( v) . unwrap( ) ;
561+ v. into( )
562+ }
563+
564+ fn felt252_i8( v: felt252) -> felt252 { test_x_y:: <felt252, i8 >( v) }
565+ fn felt252_i16( v: felt252) -> felt252 { test_x_y:: <felt252, i16 >( v) }
566+ fn felt252_i32( v: felt252) -> felt252 { test_x_y:: <felt252, i32 >( v) }
567+ fn felt252_i64( v: felt252) -> felt252 { test_x_y:: <felt252, i64 >( v) }
568+ } ;
487569 }
488570
489571 #[ test]
@@ -504,25 +586,60 @@ mod test {
504586 jit_enum!( 1 , jit_struct!( ) ) ,
505587 jit_enum!( 1 , jit_struct!( ) ) ,
506588 jit_enum!( 1 , jit_struct!( ) ) ,
507- jit_enum!( 1 , jit_struct! ( ) ) ,
589+ jit_enum!( 0 , u8 :: MAX . into ( ) ) ,
508590 ) ,
509591 jit_struct!(
510592 jit_enum!( 1 , jit_struct!( ) ) ,
511593 jit_enum!( 1 , jit_struct!( ) ) ,
512594 jit_enum!( 1 , jit_struct!( ) ) ,
513- jit_enum!( 1 , jit_struct! ( ) ) ,
595+ jit_enum!( 0 , u16 :: MAX . into ( ) ) ,
514596 ) ,
515597 jit_struct!(
516598 jit_enum!( 1 , jit_struct!( ) ) ,
517599 jit_enum!( 1 , jit_struct!( ) ) ,
518- jit_enum!( 1 , jit_struct! ( ) ) ,
600+ jit_enum!( 0 , u32 :: MAX . into ( ) ) ,
519601 ) ,
520- jit_struct!( jit_enum!( 1 , jit_struct!( ) ) , jit_enum!( 1 , jit_struct! ( ) ) ) ,
521- jit_struct!( jit_enum!( 1 , jit_struct! ( ) ) ) ,
602+ jit_struct!( jit_enum!( 1 , jit_struct!( ) ) , jit_enum!( 0 , u64 :: MAX . into ( ) ) ) ,
603+ jit_struct!( jit_enum!( 0 , u128 :: MAX . into ( ) ) ) ,
522604 ) ,
523605 ) ;
524606 }
525607
608+ #[ test_case( "b0x30_b0x30" , 5 . into( ) ) ]
609+ #[ test_case( "bm31x30_b31x30" , 5 . into( ) ) ]
610+ #[ test_case( "bm31x30_bm5x30" , ( -5 ) . into( ) ) ]
611+ #[ test_case( "bm31x30_b5x30" , 30 . into( ) ) ]
612+ #[ test_case( "b5x30_b31x31" , 31 . into( ) ) ]
613+ #[ test_case( "bm100x100_bm100xm1" , ( -90 ) . into( ) ) ]
614+ #[ test_case( "bm31xm31_bm31xm31" , ( -31 ) . into( ) ) ]
615+ #[ test_case( "b0x30_b5x40" , 10 . into( ) ) ]
616+ #[ test_case( "b0x30_bm40x40" , 10 . into( ) ) ]
617+ fn downcast_bounded_int ( entry_point : & str , value : Felt ) {
618+ run_program_assert_output (
619+ & DOWNCAST_BOUNDED_INT ,
620+ entry_point,
621+ & [ Value :: Felt252 ( value) ] ,
622+ jit_enum ! ( 0 , jit_struct!( Value :: Felt252 ( value) ) ) ,
623+ ) ;
624+ }
625+
626+ #[ test_case( "felt252_i8" , i8 :: MAX . into( ) ) ]
627+ #[ test_case( "felt252_i8" , i8 :: MIN . into( ) ) ]
628+ #[ test_case( "felt252_i16" , i16 :: MAX . into( ) ) ]
629+ #[ test_case( "felt252_i16" , i16 :: MIN . into( ) ) ]
630+ #[ test_case( "felt252_i32" , i32 :: MAX . into( ) ) ]
631+ #[ test_case( "felt252_i32" , i32 :: MIN . into( ) ) ]
632+ #[ test_case( "felt252_i64" , i64 :: MAX . into( ) ) ]
633+ #[ test_case( "felt252_i64" , i64 :: MIN . into( ) ) ]
634+ fn downcast_felt ( entry_point : & str , value : Felt ) {
635+ run_program_assert_output (
636+ & DOWNCAST_FELT ,
637+ entry_point,
638+ & [ Value :: Felt252 ( value) ] ,
639+ jit_enum ! ( 0 , jit_struct!( Value :: Felt252 ( value) ) ) ,
640+ ) ;
641+ }
642+
526643 lazy_static ! {
527644 static ref TEST_UPCAST_PROGRAM : ( String , Program ) = load_cairo! {
528645 #[ feature( "bounded-int-utils" ) ]
@@ -669,13 +786,13 @@ mod test {
669786 #[ test_case( "b2x5_b1x10" , 5 . into( ) ) ]
670787 #[ test_case( "b0x5_bm10x10" , 0 . into( ) ) ]
671788 #[ test_case( "b0x5_bm10x10" , 5 . into( ) ) ]
672- #[ test_case( "bm5x5_bm10x10" , Felt252 :: from( -5 ) ) ]
789+ #[ test_case( "bm5x5_bm10x10" , Felt :: from( -5 ) ) ]
673790 #[ test_case( "bm5x5_bm10x10" , 5 . into( ) ) ]
674- #[ test_case( "i8_bm200x200" , Felt252 :: from( -128 ) ) ]
791+ #[ test_case( "i8_bm200x200" , Felt :: from( -128 ) ) ]
675792 #[ test_case( "i8_bm200x200" , 127 . into( ) ) ]
676- #[ test_case( "bm100x100_i8" , Felt252 :: from( -100 ) ) ]
793+ #[ test_case( "bm100x100_i8" , Felt :: from( -100 ) ) ]
677794 #[ test_case( "bm100x100_i8" , 100 . into( ) ) ]
678- fn upcast ( entry_point : & str , value : Felt252 ) {
795+ fn upcast ( entry_point : & str , value : Felt ) {
679796 let arguments = & [ value. into ( ) ] ;
680797 let expected_result = jit_enum ! ( 0 , jit_struct!( value. into( ) , ) ) ;
681798 run_program_assert_output (
0 commit comments