@@ -73,6 +73,22 @@ pub fn build<'ctx, 'this>(
7373}
7474
7575/// Generate MLIR operations for the `bounded_int_add` libfunc.
76+ ///
77+ /// # Cairo Signature
78+ ///
79+ /// ```cairo
80+ /// extern fn bounded_int_add<Lhs, Rhs, impl H: AddHelper<Lhs, Rhs>>(
81+ /// lhs: Lhs, rhs: Rhs,
82+ /// ) -> H::Result nopanic;
83+ /// ```
84+ ///
85+ /// A number X as a `BoundedInt` is internally represented as an offset Xd from the lower bound Xo.
86+ /// So X = Xo + Xd.
87+ ///
88+ /// Since we want to get C = A + B, we can translate this to
89+ /// Co + Cd = Ao + Ad + Bo + Bd. Where Ao, Bo and Co represent the lower bound
90+ /// of the ranges in the `BoundedInt` and Ad, Bd and Cd represent the offsets. Since
91+ /// we also know that Co = Ao + Bo we can simplify the equation to Cd = Ad + Bd.
7692#[ allow( clippy:: too_many_arguments) ]
7793fn build_add < ' ctx , ' this > (
7894 context : & ' ctx Context ,
@@ -86,7 +102,7 @@ fn build_add<'ctx, 'this>(
86102 let lhs_value = entry. arg ( 0 ) ?;
87103 let rhs_value = entry. arg ( 1 ) ?;
88104
89- // Extract the ranges for the operands and the result type .
105+ // Extract the ranges for the operands.
90106 let lhs_ty = registry. get_type ( & info. signature . param_signatures [ 0 ] . ty ) ?;
91107 let rhs_ty = registry. get_type ( & info. signature . param_signatures [ 1 ] . ty ) ?;
92108
@@ -96,6 +112,7 @@ fn build_add<'ctx, 'this>(
96112 . get_type ( & info. signature . branch_signatures [ 0 ] . vars [ 0 ] . ty ) ?
97113 . integer_range ( registry) ?;
98114
115+ // Extract the bit width.
99116 let lhs_width = if lhs_ty. is_bounded_int ( registry) ? {
100117 lhs_range. offset_bit_width ( )
101118 } else {
@@ -106,31 +123,14 @@ fn build_add<'ctx, 'this>(
106123 } else {
107124 rhs_range. zero_based_bit_width ( )
108125 } ;
126+ let dst_width = dst_range. offset_bit_width ( ) ;
109127
110- // Calculate the computation range.
111- let compute_range = Range {
112- lower : ( & lhs_range. lower )
113- . min ( & rhs_range. lower )
114- . min ( & dst_range. lower )
115- . clone ( ) ,
116- upper : ( & lhs_range. upper )
117- . max ( & rhs_range. upper )
118- . max ( & dst_range. upper )
119- . clone ( ) ,
120- } ;
121- let compute_ty = IntegerType :: new ( context, compute_range. offset_bit_width ( ) ) . into ( ) ;
128+ // Get the compute type so we can do the addition without problems
129+ let compute_width = lhs_width. max ( rhs_width) + 1 ;
130+ let compute_ty = IntegerType :: new ( context, compute_width) . into ( ) ;
122131
123- // Zero-extend operands into the computation range.
124- native_assert ! (
125- compute_range. offset_bit_width( ) >= lhs_width,
126- "the lhs_range bit_width must be less or equal than the compute_range"
127- ) ;
128- native_assert ! (
129- compute_range. offset_bit_width( ) >= rhs_width,
130- "the rhs_range bit_width must be less or equal than the compute_range"
131- ) ;
132-
133- let lhs_value = if compute_range. offset_bit_width ( ) > lhs_width {
132+ // Get the operands on the same number of bits so we can operate with them
133+ let lhs_value = if compute_width > lhs_width {
134134 if lhs_range. lower . sign ( ) != Sign :: Minus || lhs_ty. is_bounded_int ( registry) ? {
135135 entry. extui ( lhs_value, compute_ty, location) ?
136136 } else {
@@ -139,7 +139,7 @@ fn build_add<'ctx, 'this>(
139139 } else {
140140 lhs_value
141141 } ;
142- let rhs_value = if compute_range . offset_bit_width ( ) > rhs_width {
142+ let rhs_value = if compute_width > rhs_width {
143143 if rhs_range. lower . sign ( ) != Sign :: Minus || rhs_ty. is_bounded_int ( registry) ? {
144144 entry. extui ( rhs_value, compute_ty, location) ?
145145 } else {
@@ -149,47 +149,18 @@ fn build_add<'ctx, 'this>(
149149 rhs_value
150150 } ;
151151
152- // Offset the operands so that they are compatible.
153- let lhs_offset = if lhs_ty. is_bounded_int ( registry) ? {
154- & lhs_range. lower - & compute_range. lower
155- } else {
156- lhs_range. lower
157- } ;
158- let lhs_value = if lhs_offset != BigInt :: ZERO {
159- let lhs_offset = entry. const_int_from_type ( context, location, lhs_offset, compute_ty) ?;
160- entry. addi ( lhs_value, lhs_offset, location) ?
161- } else {
162- lhs_value
163- } ;
164-
165- let rhs_offset = if rhs_ty. is_bounded_int ( registry) ? {
166- & rhs_range. lower - & compute_range. lower
167- } else {
168- rhs_range. lower
169- } ;
170- let rhs_value = if rhs_offset != BigInt :: ZERO {
171- let rhs_offset = entry. const_int_from_type ( context, location, rhs_offset, compute_ty) ?;
172- entry. addi ( rhs_value, rhs_offset, location) ?
173- } else {
174- rhs_value
175- } ;
176-
177- // Compute the operation.
152+ // Addition and get the result value on the desired range
178153 let res_value = entry. addi ( lhs_value, rhs_value, location) ?;
179-
180- // Offset and truncate the result to the output type.
181- let res_offset = & dst_range. lower - & compute_range. lower ;
182- let res_value = if res_offset != BigInt :: ZERO {
183- let res_offset = entry. const_int_from_type ( context, location, res_offset, compute_ty) ?;
184- entry. append_op_result ( arith:: subi ( res_value, res_offset, location) ) ?
185- } else {
186- res_value
187- } ;
188-
189- let res_value = if dst_range. offset_bit_width ( ) < compute_range. offset_bit_width ( ) {
154+ let res_value = if compute_width > dst_width {
190155 entry. trunci (
191156 res_value,
192- IntegerType :: new ( context, dst_range. offset_bit_width ( ) ) . into ( ) ,
157+ IntegerType :: new ( context, dst_width) . into ( ) ,
158+ location,
159+ ) ?
160+ } else if compute_width < dst_width {
161+ entry. extui (
162+ res_value,
163+ IntegerType :: new ( context, dst_width) . into ( ) ,
193164 location,
194165 ) ?
195166 } else {
@@ -1141,6 +1112,136 @@ mod test {
11411112 run_program_assert_output ( & TEST_TRIM_PROGRAM , entry_point, arguments, expected_result) ;
11421113 }
11431114
1115+ lazy_static ! {
1116+ static ref TEST_ADD_PROGRAM : ( String , Program ) = load_cairo! {
1117+ #[ feature( "bounded-int-utils" ) ]
1118+ use core:: internal:: bounded_int:: { BoundedInt , add, AddHelper , UnitInt } ;
1119+
1120+ impl AddHelperBI_1x31_BI_1x1 of AddHelper <BoundedInt <1 , 31 >, BoundedInt <1 , 1 >> {
1121+ type Result = BoundedInt <2 , 32 >;
1122+ }
1123+
1124+ fn bi_1x31_plus_bi_1x1(
1125+ a: felt252,
1126+ b: felt252,
1127+ ) -> BoundedInt <2 , 32 > {
1128+ let a: BoundedInt <1 , 31 > = a. try_into( ) . unwrap( ) ;
1129+ let b: BoundedInt <1 , 1 > = b. try_into( ) . unwrap( ) ;
1130+ return add( a, b) ;
1131+ }
1132+
1133+ impl AddHelperBI_1x31_BI_m1xm1 of AddHelper <BoundedInt <1 , 31 >, BoundedInt <-1 , -1 >> {
1134+ type Result = BoundedInt <0 , 30 >;
1135+ }
1136+
1137+ fn bi_1x31_plus_bi_m1xm1(
1138+ a: felt252,
1139+ b: felt252,
1140+ ) -> BoundedInt <0 , 30 > {
1141+ let a: BoundedInt <1 , 31 > = a. try_into( ) . unwrap( ) ;
1142+ let b: BoundedInt <-1 , -1 > = b. try_into( ) . unwrap( ) ;
1143+ return add( a, b) ;
1144+ }
1145+
1146+ impl AddHelperBI_0x30_BI_0x10 of AddHelper <BoundedInt <0 , 30 >, BoundedInt <0 , 10 >> {
1147+ type Result = BoundedInt <0 , 40 >;
1148+ }
1149+
1150+ fn bi_0x30_plus_bi_0x10(
1151+ a: felt252,
1152+ b: felt252,
1153+ ) -> BoundedInt <0 , 40 > {
1154+ let a: BoundedInt <0 , 30 > = a. try_into( ) . unwrap( ) ;
1155+ let b: BoundedInt <0 , 10 > = b. try_into( ) . unwrap( ) ;
1156+ return add( a, b) ;
1157+ }
1158+
1159+ impl AddHelperBI_m20xm15_BI_0x10 of AddHelper <BoundedInt <-20 , -15 >, BoundedInt <0 , 10 >> {
1160+ type Result = BoundedInt <-20 , -5 >;
1161+ }
1162+
1163+ fn bi_m20xm15_plus_bi_0x10(
1164+ a: felt252,
1165+ b: felt252,
1166+ ) -> BoundedInt <-20 , -5 > {
1167+ let a: BoundedInt <-20 , -15 > = a. try_into( ) . unwrap( ) ;
1168+ let b: BoundedInt <0 , 10 > = b. try_into( ) . unwrap( ) ;
1169+ return add( a, b) ;
1170+ }
1171+
1172+ impl AddHelperBI_m5xm5_BI_m5xm5 of AddHelper <BoundedInt <-5 , -5 >, BoundedInt <-5 , -5 >> {
1173+ type Result = BoundedInt <-10 , -10 >;
1174+ }
1175+
1176+ fn bi_m5xm5_plus_bi_m5xm5(
1177+ a: felt252,
1178+ b: felt252,
1179+ ) -> BoundedInt <-10 , -10 > {
1180+ let a: BoundedInt <-5 , -5 > = a. try_into( ) . unwrap( ) ;
1181+ let b: BoundedInt <-5 , -5 > = b. try_into( ) . unwrap( ) ;
1182+ return add( a, b) ;
1183+ }
1184+
1185+ impl AddHelperBI_m5xm5_UI_m1 of AddHelper <BoundedInt <-5 , -5 >, UnitInt <-1 >> {
1186+ type Result = BoundedInt <-6 , -6 >;
1187+ }
1188+
1189+ fn bi_m5xm5_plus_ui_m1(
1190+ a: felt252,
1191+ b: felt252,
1192+ ) -> BoundedInt <-6 , -6 > {
1193+ let a: BoundedInt <-5 , -5 > = a. try_into( ) . unwrap( ) ;
1194+ let b: UnitInt <-1 > = b. try_into( ) . unwrap( ) ;
1195+ return add( a, b) ;
1196+ }
1197+
1198+ impl AddHelperUI_1_BI_m5xm5 of AddHelper <UnitInt <1 >, BoundedInt <-5 , -5 >> {
1199+ type Result = BoundedInt <-4 , -4 >;
1200+ }
1201+
1202+ fn ui_m1_plus_bi_m5xm5(
1203+ a: felt252,
1204+ b: felt252,
1205+ ) -> BoundedInt <-4 , -4 > {
1206+ let a: UnitInt <1 > = a. try_into( ) . unwrap( ) ;
1207+ let b: BoundedInt <-5 , -5 > = b. try_into( ) . unwrap( ) ;
1208+ return add( a, b) ;
1209+ }
1210+ } ;
1211+ }
1212+
1213+ #[ test_case( "bi_1x31_plus_bi_1x1" , 31 , 1 , 32 ) ]
1214+ #[ test_case( "bi_1x31_plus_bi_m1xm1" , 31 , -1 , 30 ) ]
1215+ #[ test_case( "bi_0x30_plus_bi_0x10" , 30 , 10 , 40 ) ]
1216+ #[ test_case( "bi_m20xm15_plus_bi_0x10" , -15 , 10 , -5 ) ]
1217+ #[ test_case( "bi_m20xm15_plus_bi_0x10" , -20 , 10 , -10 ) ]
1218+ #[ test_case( "bi_m5xm5_plus_bi_m5xm5" , -5 , -5 , -10 ) ]
1219+ #[ test_case( "bi_m5xm5_plus_ui_m1" , -5 , -1 , -6 ) ]
1220+ #[ test_case( "ui_m1_plus_bi_m5xm5" , 1 , -5 , -4 ) ]
1221+ fn test_add ( entry_point : & str , lhs : i32 , rhs : i32 , expected_result : i32 ) {
1222+ let result = run_program (
1223+ & TEST_ADD_PROGRAM ,
1224+ entry_point,
1225+ & [
1226+ Value :: Felt252 ( Felt252 :: from ( lhs) ) ,
1227+ Value :: Felt252 ( Felt252 :: from ( rhs) ) ,
1228+ ] ,
1229+ )
1230+ . return_value ;
1231+
1232+ if let Value :: Enum { value, .. } = result {
1233+ if let Value :: Struct { fields, .. } = * value {
1234+ assert ! (
1235+ matches!( fields[ 0 ] , Value :: BoundedInt { value, .. } if value == Felt252 :: from( expected_result) )
1236+ )
1237+ } else {
1238+ panic ! ( "Test returned an unexpected value" ) ;
1239+ }
1240+ } else {
1241+ panic ! ( "Test returned value was not an Enum as expected" ) ;
1242+ }
1243+ }
1244+
11441245 fn assert_bool_output ( result : Value , expected_tag : usize ) {
11451246 if let Value :: Enum { tag, value, .. } = result {
11461247 assert_eq ! ( tag, 0 ) ;
0 commit comments