Skip to content
Open
203 changes: 140 additions & 63 deletions src/libfuncs/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,20 @@ fn build_add<'ctx, 'this>(
}

/// Generate MLIR operations for the `bounded_int_sub` libfunc.
///
/// # Cairo Signature
/// ```cairo
/// extern fn bounded_int_sub<Lhs, Rhs, impl H: SubHelper<Lhs, Rhs>>(
/// lhs: Lhs, rhs: Rhs,
/// ) -> H::Result nopanic;
/// ```
///
/// A number X as a `BoundedInt` is internally represented as an offset Xd from the lower bound Xo.
/// So X = Xo + Xd.
///
/// Since we want to get C = A - B, we can translate this to
/// Co + Cd = (Ao + Ad) - (Bo + Bd). Where Ao, Bo and Co represent the lower bound
/// of the ranges in the `BoundedInt` and Ad, Bd and Cd represent the offsets.
#[allow(clippy::too_many_arguments)]
fn build_sub<'ctx, 'this>(
context: &'ctx Context,
Expand All @@ -213,7 +227,7 @@ fn build_sub<'ctx, 'this>(
let lhs_value = entry.arg(0)?;
let rhs_value = entry.arg(1)?;

// Extract the ranges for the operands and the result type.
// Extract the ranges for the operands.
let lhs_ty = registry.get_type(&info.signature.param_signatures[0].ty)?;
let rhs_ty = registry.get_type(&info.signature.param_signatures[1].ty)?;

Expand All @@ -223,6 +237,7 @@ fn build_sub<'ctx, 'this>(
.get_type(&info.signature.branch_signatures[0].vars[0].ty)?
.integer_range(registry)?;

// Extract the bit width.
let lhs_width = if lhs_ty.is_bounded_int(registry)? {
lhs_range.offset_bit_width()
} else {
Expand All @@ -233,31 +248,17 @@ fn build_sub<'ctx, 'this>(
} else {
rhs_range.zero_based_bit_width()
};
let dst_width = dst_range.offset_bit_width();

// Calculate the computation range.
let compute_range = Range {
lower: (&lhs_range.lower)
.min(&rhs_range.lower)
.min(&dst_range.lower)
.clone(),
upper: (&lhs_range.upper)
.max(&rhs_range.upper)
.max(&dst_range.upper)
.clone(),
};
let compute_ty = IntegerType::new(context, compute_range.offset_bit_width()).into();
// Get the compute type so we can do the subtraction without problems
let compile_time_val = lhs_range.lower.clone() - rhs_range.lower.clone() - dst_range.lower;
let compile_time_val_width = u32::try_from(compile_time_val.bits())?;

// Zero-extend operands into the computation range.
native_assert!(
compute_range.offset_bit_width() >= lhs_width,
"the lhs_range bit_width must be less or equal than the compute_range"
);
native_assert!(
compute_range.offset_bit_width() >= rhs_width,
"the rhs_range bit_width must be less or equal than the compute_range"
);
let compute_width = lhs_width.max(rhs_width).max(compile_time_val_width) + 1;
let compute_ty = IntegerType::new(context, compute_width).into();

let lhs_value = if compute_range.offset_bit_width() > lhs_width {
// Get the operands on the same number of bits so we can operate with them
let lhs_value = if compute_width > lhs_width {
if lhs_range.lower.sign() != Sign::Minus || lhs_ty.is_bounded_int(registry)? {
entry.extui(lhs_value, compute_ty, location)?
} else {
Expand All @@ -266,7 +267,7 @@ fn build_sub<'ctx, 'this>(
} else {
lhs_value
};
let rhs_value = if compute_range.offset_bit_width() > rhs_width {
let rhs_value = if compute_width > rhs_width {
if rhs_range.lower.sign() != Sign::Minus || rhs_ty.is_bounded_int(registry)? {
entry.extui(rhs_value, compute_ty, location)?
} else {
Expand All @@ -276,47 +277,23 @@ fn build_sub<'ctx, 'this>(
rhs_value
};

// Offset the operands so that they are compatible.
let lhs_offset = if lhs_ty.is_bounded_int(registry)? {
&lhs_range.lower - &compute_range.lower
} else {
lhs_range.lower
};
let lhs_value = if lhs_offset != BigInt::ZERO {
let lhs_offset = entry.const_int_from_type(context, location, lhs_offset, compute_ty)?;
entry.addi(lhs_value, lhs_offset, location)?
} else {
lhs_value
};

let rhs_offset = if rhs_ty.is_bounded_int(registry)? {
&rhs_range.lower - &compute_range.lower
} else {
rhs_range.lower
};
let rhs_value = if rhs_offset != BigInt::ZERO {
let rhs_offset = entry.const_int_from_type(context, location, rhs_offset, compute_ty)?;
entry.addi(rhs_value, rhs_offset, location)?
} else {
rhs_value
};

// Compute the operation.
let res_value = entry.append_op_result(arith::subi(lhs_value, rhs_value, location))?;

// Offset and truncate the result to the output type.
let res_offset = dst_range.lower.clone();
let res_value = if res_offset != BigInt::ZERO {
let res_offset = entry.const_int_from_type(context, location, res_offset, compute_ty)?;
entry.append_op_result(arith::subi(res_value, res_offset, location))?
} else {
res_value
};

let res_value = if dst_range.offset_bit_width() < compute_range.offset_bit_width() {
let compile_time_val =
entry.const_int_from_type(context, location, compile_time_val, compute_ty)?;
// First we do -> intermediate_res = Ad - Bd
let res_value = entry.subi(lhs_value, rhs_value, location)?;
// Then we do -> intermediate_res += (Ao - Bo - Co)
let res_value = entry.addi(res_value, compile_time_val, location)?;
// Get the result value on the desired range
let res_value = if compute_width > dst_width {
entry.trunci(
res_value,
IntegerType::new(context, dst_range.offset_bit_width()).into(),
IntegerType::new(context, dst_width).into(),
location,
)?
} else if compute_width < dst_width {
entry.extui(
res_value,
IntegerType::new(context, dst_width).into(),
location,
)?
} else {
Expand Down Expand Up @@ -1146,6 +1123,106 @@ mod test {
run_program_assert_output(&TEST_TRIM_PROGRAM, entry_point, arguments, expected_result);
}

lazy_static! {
static ref TEST_SUB_PROGRAM: (String, Program) = load_cairo! {
#[feature("bounded-int-utils")]
use core::internal::bounded_int::{BoundedInt, sub, SubHelper};

impl SubHelperBI_1x1_BI_1x5 of SubHelper<BoundedInt<1, 1>, BoundedInt<1, 5>> {
type Result = BoundedInt<-4, 0>;
}

fn bi_1x1_minus_bi_1x5(
a: felt252,
b: felt252,
) -> BoundedInt<-4, 0> {
let a: BoundedInt<1, 1> = a.try_into().unwrap();
let b: BoundedInt<1, 5> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_1x1_BI_1x1 of SubHelper<BoundedInt<1, 1>, BoundedInt<1, 1>> {
type Result = BoundedInt<0, 0>;
}

fn bi_1x1_minus_bi_1x1(
a: felt252,
b: felt252,
) -> BoundedInt<0, 0> {
let a: BoundedInt<1, 1> = a.try_into().unwrap();
let b: BoundedInt<1, 1> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_m3xm3_BI_m3xm3 of SubHelper<BoundedInt<-3, -3>, BoundedInt<-3, -3>> {
type Result = BoundedInt<0, 0>;
}

fn bi_m3xm3_minus_bi_m3xm3(
a: felt252,
b: felt252,
) -> BoundedInt<0, 0> {
let a: BoundedInt<-3, -3> = a.try_into().unwrap();
let b: BoundedInt<-3, -3> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_m6xm3_BI_1x3 of SubHelper<BoundedInt<-6, -3>, BoundedInt<1, 3>> {
type Result = BoundedInt<-9, -4>;
}

fn bi_m6xm3_minus_bi_1x3(
a: felt252,
b: felt252,
) -> BoundedInt<-9, -4> {
let a: BoundedInt<-6, -3> = a.try_into().unwrap();
let b: BoundedInt<1, 3> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_m6xm2_BI_m20xm10 of SubHelper<BoundedInt<-6, -2>, BoundedInt<-20, -10>> {
type Result = BoundedInt<4, 18>;
}

fn bi_m6xm2_minus_bi_m20xm10(
a: felt252,
b: felt252,
) -> BoundedInt<4, 18> {
let a: BoundedInt<-6, -2> = a.try_into().unwrap();
let b: BoundedInt<-20, -10> = b.try_into().unwrap();
return sub(a, b);
}
};
}

#[test_case("bi_1x1_minus_bi_1x5", 1, 5, -4)]
#[test_case("bi_1x1_minus_bi_1x1", 1, 1, 0)]
#[test_case("bi_m3xm3_minus_bi_m3xm3", -3, -3, 0)]
#[test_case("bi_m6xm3_minus_bi_1x3", -6, 3, -9)]
#[test_case("bi_m6xm2_minus_bi_m20xm10", -2, -20, 18)]
fn test_sub(entry_point: &str, lhs: i32, rhs: i32, expected_result: i32) {
let result = run_program(
&TEST_SUB_PROGRAM,
entry_point,
&[
Value::Felt252(Felt252::from(lhs)),
Value::Felt252(Felt252::from(rhs)),
],
)
.return_value;
if let Value::Enum { value, .. } = result {
if let Value::Struct { fields, .. } = *value {
assert!(
matches!(fields[0], Value::BoundedInt { value, .. } if value == Felt252::from(expected_result))
)
} else {
panic!("Test returned an unexpected value");
}
} else {
panic!("Test didn't return an enum as expected");
}
}

fn assert_bool_output(result: Value, expected_tag: usize) {
if let Value::Enum { tag, value, .. } = result {
assert_eq!(tag, 0);
Expand Down
Loading