Skip to content

Commit 4794e1d

Browse files
Fix BoundedInt Add libfunc (#1463)
* Add fix and tests for add * Refactor fix * Change add implementation * Remove unnecessary part of the equation * Handle addition of bounded ints with unit ints * Remove todo * Add asserts * Add comments * Refactor of tests * Refactor tests * Improve docs * Remove asserts * Refactor of tests names * Minor changes --------- Co-authored-by: Julian Gonzalez Calderon <gonzalezcalderonjulian@gmail.com>
1 parent 15388a2 commit 4794e1d

File tree

1 file changed

+164
-63
lines changed

1 file changed

+164
-63
lines changed

src/libfuncs/bounded_int.rs

Lines changed: 164 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,22 @@ pub fn build<'ctx, 'this>(
7373
}
7474

7575
/// Generate MLIR operations for the `bounded_int_add` libfunc.
76+
///
77+
/// # Cairo Signature
78+
///
79+
/// ```cairo
80+
/// extern fn bounded_int_add<Lhs, Rhs, impl H: AddHelper<Lhs, Rhs>>(
81+
/// lhs: Lhs, rhs: Rhs,
82+
/// ) -> H::Result nopanic;
83+
/// ```
84+
///
85+
/// A number X as a `BoundedInt` is internally represented as an offset Xd from the lower bound Xo.
86+
/// So X = Xo + Xd.
87+
///
88+
/// Since we want to get C = A + B, we can translate this to
89+
/// Co + Cd = Ao + Ad + Bo + Bd. Where Ao, Bo and Co represent the lower bound
90+
/// of the ranges in the `BoundedInt` and Ad, Bd and Cd represent the offsets. Since
91+
/// we also know that Co = Ao + Bo we can simplify the equation to Cd = Ad + Bd.
7692
#[allow(clippy::too_many_arguments)]
7793
fn build_add<'ctx, 'this>(
7894
context: &'ctx Context,
@@ -86,7 +102,7 @@ fn build_add<'ctx, 'this>(
86102
let lhs_value = entry.arg(0)?;
87103
let rhs_value = entry.arg(1)?;
88104

89-
// Extract the ranges for the operands and the result type.
105+
// Extract the ranges for the operands.
90106
let lhs_ty = registry.get_type(&info.signature.param_signatures[0].ty)?;
91107
let rhs_ty = registry.get_type(&info.signature.param_signatures[1].ty)?;
92108

@@ -96,6 +112,7 @@ fn build_add<'ctx, 'this>(
96112
.get_type(&info.signature.branch_signatures[0].vars[0].ty)?
97113
.integer_range(registry)?;
98114

115+
// Extract the bit width.
99116
let lhs_width = if lhs_ty.is_bounded_int(registry)? {
100117
lhs_range.offset_bit_width()
101118
} else {
@@ -106,31 +123,14 @@ fn build_add<'ctx, 'this>(
106123
} else {
107124
rhs_range.zero_based_bit_width()
108125
};
126+
let dst_width = dst_range.offset_bit_width();
109127

110-
// Calculate the computation range.
111-
let compute_range = Range {
112-
lower: (&lhs_range.lower)
113-
.min(&rhs_range.lower)
114-
.min(&dst_range.lower)
115-
.clone(),
116-
upper: (&lhs_range.upper)
117-
.max(&rhs_range.upper)
118-
.max(&dst_range.upper)
119-
.clone(),
120-
};
121-
let compute_ty = IntegerType::new(context, compute_range.offset_bit_width()).into();
128+
// Get the compute type so we can do the addition without problems
129+
let compute_width = lhs_width.max(rhs_width) + 1;
130+
let compute_ty = IntegerType::new(context, compute_width).into();
122131

123-
// Zero-extend operands into the computation range.
124-
native_assert!(
125-
compute_range.offset_bit_width() >= lhs_width,
126-
"the lhs_range bit_width must be less or equal than the compute_range"
127-
);
128-
native_assert!(
129-
compute_range.offset_bit_width() >= rhs_width,
130-
"the rhs_range bit_width must be less or equal than the compute_range"
131-
);
132-
133-
let lhs_value = if compute_range.offset_bit_width() > lhs_width {
132+
// Get the operands on the same number of bits so we can operate with them
133+
let lhs_value = if compute_width > lhs_width {
134134
if lhs_range.lower.sign() != Sign::Minus || lhs_ty.is_bounded_int(registry)? {
135135
entry.extui(lhs_value, compute_ty, location)?
136136
} else {
@@ -139,7 +139,7 @@ fn build_add<'ctx, 'this>(
139139
} else {
140140
lhs_value
141141
};
142-
let rhs_value = if compute_range.offset_bit_width() > rhs_width {
142+
let rhs_value = if compute_width > rhs_width {
143143
if rhs_range.lower.sign() != Sign::Minus || rhs_ty.is_bounded_int(registry)? {
144144
entry.extui(rhs_value, compute_ty, location)?
145145
} else {
@@ -149,47 +149,18 @@ fn build_add<'ctx, 'this>(
149149
rhs_value
150150
};
151151

152-
// Offset the operands so that they are compatible.
153-
let lhs_offset = if lhs_ty.is_bounded_int(registry)? {
154-
&lhs_range.lower - &compute_range.lower
155-
} else {
156-
lhs_range.lower
157-
};
158-
let lhs_value = if lhs_offset != BigInt::ZERO {
159-
let lhs_offset = entry.const_int_from_type(context, location, lhs_offset, compute_ty)?;
160-
entry.addi(lhs_value, lhs_offset, location)?
161-
} else {
162-
lhs_value
163-
};
164-
165-
let rhs_offset = if rhs_ty.is_bounded_int(registry)? {
166-
&rhs_range.lower - &compute_range.lower
167-
} else {
168-
rhs_range.lower
169-
};
170-
let rhs_value = if rhs_offset != BigInt::ZERO {
171-
let rhs_offset = entry.const_int_from_type(context, location, rhs_offset, compute_ty)?;
172-
entry.addi(rhs_value, rhs_offset, location)?
173-
} else {
174-
rhs_value
175-
};
176-
177-
// Compute the operation.
152+
// Addition and get the result value on the desired range
178153
let res_value = entry.addi(lhs_value, rhs_value, location)?;
179-
180-
// Offset and truncate the result to the output type.
181-
let res_offset = &dst_range.lower - &compute_range.lower;
182-
let res_value = if res_offset != BigInt::ZERO {
183-
let res_offset = entry.const_int_from_type(context, location, res_offset, compute_ty)?;
184-
entry.append_op_result(arith::subi(res_value, res_offset, location))?
185-
} else {
186-
res_value
187-
};
188-
189-
let res_value = if dst_range.offset_bit_width() < compute_range.offset_bit_width() {
154+
let res_value = if compute_width > dst_width {
190155
entry.trunci(
191156
res_value,
192-
IntegerType::new(context, dst_range.offset_bit_width()).into(),
157+
IntegerType::new(context, dst_width).into(),
158+
location,
159+
)?
160+
} else if compute_width < dst_width {
161+
entry.extui(
162+
res_value,
163+
IntegerType::new(context, dst_width).into(),
193164
location,
194165
)?
195166
} else {
@@ -1141,6 +1112,136 @@ mod test {
11411112
run_program_assert_output(&TEST_TRIM_PROGRAM, entry_point, arguments, expected_result);
11421113
}
11431114

1115+
lazy_static! {
1116+
static ref TEST_ADD_PROGRAM: (String, Program) = load_cairo! {
1117+
#[feature("bounded-int-utils")]
1118+
use core::internal::bounded_int::{BoundedInt, add, AddHelper, UnitInt};
1119+
1120+
impl AddHelperBI_1x31_BI_1x1 of AddHelper<BoundedInt<1, 31>, BoundedInt<1, 1>> {
1121+
type Result = BoundedInt<2, 32>;
1122+
}
1123+
1124+
fn bi_1x31_plus_bi_1x1(
1125+
a: felt252,
1126+
b: felt252,
1127+
) -> BoundedInt<2, 32> {
1128+
let a: BoundedInt<1, 31> = a.try_into().unwrap();
1129+
let b: BoundedInt<1, 1> = b.try_into().unwrap();
1130+
return add(a, b);
1131+
}
1132+
1133+
impl AddHelperBI_1x31_BI_m1xm1 of AddHelper<BoundedInt<1, 31>, BoundedInt<-1, -1>> {
1134+
type Result = BoundedInt<0, 30>;
1135+
}
1136+
1137+
fn bi_1x31_plus_bi_m1xm1(
1138+
a: felt252,
1139+
b: felt252,
1140+
) -> BoundedInt<0, 30> {
1141+
let a: BoundedInt<1, 31> = a.try_into().unwrap();
1142+
let b: BoundedInt<-1, -1> = b.try_into().unwrap();
1143+
return add(a, b);
1144+
}
1145+
1146+
impl AddHelperBI_0x30_BI_0x10 of AddHelper<BoundedInt<0, 30>, BoundedInt<0, 10>> {
1147+
type Result = BoundedInt<0, 40>;
1148+
}
1149+
1150+
fn bi_0x30_plus_bi_0x10(
1151+
a: felt252,
1152+
b: felt252,
1153+
) -> BoundedInt<0, 40> {
1154+
let a: BoundedInt<0, 30> = a.try_into().unwrap();
1155+
let b: BoundedInt<0, 10> = b.try_into().unwrap();
1156+
return add(a, b);
1157+
}
1158+
1159+
impl AddHelperBI_m20xm15_BI_0x10 of AddHelper<BoundedInt<-20, -15>, BoundedInt<0, 10>> {
1160+
type Result = BoundedInt<-20, -5>;
1161+
}
1162+
1163+
fn bi_m20xm15_plus_bi_0x10(
1164+
a: felt252,
1165+
b: felt252,
1166+
) -> BoundedInt<-20, -5> {
1167+
let a: BoundedInt<-20, -15> = a.try_into().unwrap();
1168+
let b: BoundedInt<0, 10> = b.try_into().unwrap();
1169+
return add(a, b);
1170+
}
1171+
1172+
impl AddHelperBI_m5xm5_BI_m5xm5 of AddHelper<BoundedInt<-5, -5>, BoundedInt<-5, -5>> {
1173+
type Result = BoundedInt<-10, -10>;
1174+
}
1175+
1176+
fn bi_m5xm5_plus_bi_m5xm5(
1177+
a: felt252,
1178+
b: felt252,
1179+
) -> BoundedInt<-10, -10> {
1180+
let a: BoundedInt<-5, -5> = a.try_into().unwrap();
1181+
let b: BoundedInt<-5, -5> = b.try_into().unwrap();
1182+
return add(a, b);
1183+
}
1184+
1185+
impl AddHelperBI_m5xm5_UI_m1 of AddHelper<BoundedInt<-5, -5>, UnitInt<-1>> {
1186+
type Result = BoundedInt<-6, -6>;
1187+
}
1188+
1189+
fn bi_m5xm5_plus_ui_m1(
1190+
a: felt252,
1191+
b: felt252,
1192+
) -> BoundedInt<-6, -6> {
1193+
let a: BoundedInt<-5, -5> = a.try_into().unwrap();
1194+
let b: UnitInt<-1> = b.try_into().unwrap();
1195+
return add(a, b);
1196+
}
1197+
1198+
impl AddHelperUI_1_BI_m5xm5 of AddHelper<UnitInt<1>, BoundedInt<-5, -5>> {
1199+
type Result = BoundedInt<-4, -4>;
1200+
}
1201+
1202+
fn ui_m1_plus_bi_m5xm5(
1203+
a: felt252,
1204+
b: felt252,
1205+
) -> BoundedInt<-4, -4> {
1206+
let a: UnitInt<1> = a.try_into().unwrap();
1207+
let b: BoundedInt<-5, -5> = b.try_into().unwrap();
1208+
return add(a, b);
1209+
}
1210+
};
1211+
}
1212+
1213+
#[test_case("bi_1x31_plus_bi_1x1", 31, 1, 32)]
1214+
#[test_case("bi_1x31_plus_bi_m1xm1", 31, -1, 30)]
1215+
#[test_case("bi_0x30_plus_bi_0x10", 30, 10, 40)]
1216+
#[test_case("bi_m20xm15_plus_bi_0x10", -15, 10, -5)]
1217+
#[test_case("bi_m20xm15_plus_bi_0x10", -20, 10, -10)]
1218+
#[test_case("bi_m5xm5_plus_bi_m5xm5", -5, -5, -10)]
1219+
#[test_case("bi_m5xm5_plus_ui_m1", -5, -1, -6)]
1220+
#[test_case("ui_m1_plus_bi_m5xm5", 1, -5, -4)]
1221+
fn test_add(entry_point: &str, lhs: i32, rhs: i32, expected_result: i32) {
1222+
let result = run_program(
1223+
&TEST_ADD_PROGRAM,
1224+
entry_point,
1225+
&[
1226+
Value::Felt252(Felt252::from(lhs)),
1227+
Value::Felt252(Felt252::from(rhs)),
1228+
],
1229+
)
1230+
.return_value;
1231+
1232+
if let Value::Enum { value, .. } = result {
1233+
if let Value::Struct { fields, .. } = *value {
1234+
assert!(
1235+
matches!(fields[0], Value::BoundedInt { value, .. } if value == Felt252::from(expected_result))
1236+
)
1237+
} else {
1238+
panic!("Test returned an unexpected value");
1239+
}
1240+
} else {
1241+
panic!("Test returned value was not an Enum as expected");
1242+
}
1243+
}
1244+
11441245
fn assert_bool_output(result: Value, expected_tag: usize) {
11451246
if let Value::Enum { tag, value, .. } = result {
11461247
assert_eq!(tag, 0);

0 commit comments

Comments
 (0)