Skip to content
Merged
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 140 additions & 63 deletions src/libfuncs/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,20 @@ fn build_add<'ctx, 'this>(
}

/// Generate MLIR operations for the `bounded_int_sub` libfunc.
///
/// # Cairo Signature
/// ```cairo
/// extern fn bounded_int_sub<Lhs, Rhs, impl H: SubHelper<Lhs, Rhs>>(
/// lhs: Lhs, rhs: Rhs,
/// ) -> H::Result nopanic;
/// ```
///
/// A number X as a `BoundedInt` is internally represented as an offset Xd from the lower bound Xo.
/// So X = Xo + Xd.
///
/// Since we want to get C = A - B, we can translate this to
/// Co + Cd = (Ao + Ad) - (Bo + Bd). Where Ao, Bo and Co represent the lower bound
/// of the ranges in the `BoundedInt` and Ad, Bd and Cd represent the offsets.
#[allow(clippy::too_many_arguments)]
fn build_sub<'ctx, 'this>(
context: &'ctx Context,
Expand All @@ -213,7 +227,7 @@ fn build_sub<'ctx, 'this>(
let lhs_value = entry.arg(0)?;
let rhs_value = entry.arg(1)?;

// Extract the ranges for the operands and the result type.
// Extract the ranges for the operands.
let lhs_ty = registry.get_type(&info.signature.param_signatures[0].ty)?;
let rhs_ty = registry.get_type(&info.signature.param_signatures[1].ty)?;

Expand All @@ -223,6 +237,7 @@ fn build_sub<'ctx, 'this>(
.get_type(&info.signature.branch_signatures[0].vars[0].ty)?
.integer_range(registry)?;

// Extract the bit width.
let lhs_width = if lhs_ty.is_bounded_int(registry)? {
lhs_range.offset_bit_width()
} else {
Expand All @@ -233,31 +248,17 @@ fn build_sub<'ctx, 'this>(
} else {
rhs_range.zero_based_bit_width()
};
let dst_width = dst_range.offset_bit_width();

// Calculate the computation range.
let compute_range = Range {
lower: (&lhs_range.lower)
.min(&rhs_range.lower)
.min(&dst_range.lower)
.clone(),
upper: (&lhs_range.upper)
.max(&rhs_range.upper)
.max(&dst_range.upper)
.clone(),
};
let compute_ty = IntegerType::new(context, compute_range.offset_bit_width()).into();
// Get the compute type so we can do the subtraction without problems
let compile_time_val = lhs_range.lower.clone() - rhs_range.lower.clone() - dst_range.lower;
let compile_time_val_width = u32::try_from(compile_time_val.bits())?;

// Zero-extend operands into the computation range.
native_assert!(
compute_range.offset_bit_width() >= lhs_width,
"the lhs_range bit_width must be less or equal than the compute_range"
);
native_assert!(
compute_range.offset_bit_width() >= rhs_width,
"the rhs_range bit_width must be less or equal than the compute_range"
);
let compute_width = lhs_width.max(rhs_width).max(compile_time_val_width) + 1;
let compute_ty = IntegerType::new(context, compute_width).into();

let lhs_value = if compute_range.offset_bit_width() > lhs_width {
// Get the operands on the same number of bits so we can operate with them
let lhs_value = if compute_width > lhs_width {
if lhs_range.lower.sign() != Sign::Minus || lhs_ty.is_bounded_int(registry)? {
entry.extui(lhs_value, compute_ty, location)?
} else {
Expand All @@ -266,7 +267,7 @@ fn build_sub<'ctx, 'this>(
} else {
lhs_value
};
let rhs_value = if compute_range.offset_bit_width() > rhs_width {
let rhs_value = if compute_width > rhs_width {
if rhs_range.lower.sign() != Sign::Minus || rhs_ty.is_bounded_int(registry)? {
entry.extui(rhs_value, compute_ty, location)?
} else {
Expand All @@ -276,47 +277,23 @@ fn build_sub<'ctx, 'this>(
rhs_value
};

// Offset the operands so that they are compatible.
let lhs_offset = if lhs_ty.is_bounded_int(registry)? {
&lhs_range.lower - &compute_range.lower
} else {
lhs_range.lower
};
let lhs_value = if lhs_offset != BigInt::ZERO {
let lhs_offset = entry.const_int_from_type(context, location, lhs_offset, compute_ty)?;
entry.addi(lhs_value, lhs_offset, location)?
} else {
lhs_value
};

let rhs_offset = if rhs_ty.is_bounded_int(registry)? {
&rhs_range.lower - &compute_range.lower
} else {
rhs_range.lower
};
let rhs_value = if rhs_offset != BigInt::ZERO {
let rhs_offset = entry.const_int_from_type(context, location, rhs_offset, compute_ty)?;
entry.addi(rhs_value, rhs_offset, location)?
} else {
rhs_value
};

// Compute the operation.
let res_value = entry.append_op_result(arith::subi(lhs_value, rhs_value, location))?;

// Offset and truncate the result to the output type.
let res_offset = dst_range.lower.clone();
let res_value = if res_offset != BigInt::ZERO {
let res_offset = entry.const_int_from_type(context, location, res_offset, compute_ty)?;
entry.append_op_result(arith::subi(res_value, res_offset, location))?
} else {
res_value
};

let res_value = if dst_range.offset_bit_width() < compute_range.offset_bit_width() {
let compile_time_val =
entry.const_int_from_type(context, location, compile_time_val, compute_ty)?;
// First we do -> Ad - Bd = intermediate_res
let res_value = entry.subi(lhs_value, rhs_value, location)?;
// Then we do -> intermediate_res + (Ao - Bo - Co)
let res_value = entry.addi(res_value, compile_time_val, location)?;
// Get the result value on the desired range
let res_value = if compute_width > dst_width {
entry.trunci(
res_value,
IntegerType::new(context, dst_range.offset_bit_width()).into(),
IntegerType::new(context, dst_width).into(),
location,
)?
} else if compute_width < dst_width {
entry.extui(
res_value,
IntegerType::new(context, dst_width).into(),
location,
)?
} else {
Expand Down Expand Up @@ -1013,6 +990,106 @@ mod test {
assert_eq!(value, Felt252::from(0));
}

lazy_static! {
static ref TEST_SUB_PROGRAM: (String, Program) = load_cairo! {
#[feature("bounded-int-utils")]
use core::internal::bounded_int::{BoundedInt, sub, SubHelper};

impl SubHelperBI_1x1_BI_1x5 of SubHelper<BoundedInt<1, 1>, BoundedInt<1, 5>> {
type Result = BoundedInt<-4, 0>;
}

fn bi_1x1_minus_bi_1x5(
a: felt252,
b: felt252,
) -> BoundedInt<-4, 0> {
let a: BoundedInt<1, 1> = a.try_into().unwrap();
let b: BoundedInt<1, 5> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_1x1_BI_1x1 of SubHelper<BoundedInt<1, 1>, BoundedInt<1, 1>> {
type Result = BoundedInt<0, 0>;
}

fn bi_1x1_minus_bi_1x1(
a: felt252,
b: felt252,
) -> BoundedInt<0, 0> {
let a: BoundedInt<1, 1> = a.try_into().unwrap();
let b: BoundedInt<1, 1> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_m3xm3_BI_m3xm3 of SubHelper<BoundedInt<-3, -3>, BoundedInt<-3, -3>> {
type Result = BoundedInt<0, 0>;
}

fn bi_m3xm3_minus_bi_m3xm3(
a: felt252,
b: felt252,
) -> BoundedInt<0, 0> {
let a: BoundedInt<-3, -3> = a.try_into().unwrap();
let b: BoundedInt<-3, -3> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_m6xm3_BI_1x3 of SubHelper<BoundedInt<-6, -3>, BoundedInt<1, 3>> {
type Result = BoundedInt<-9, -4>;
}

fn bi_m6xm3_minus_bi_1x3(
a: felt252,
b: felt252,
) -> BoundedInt<-9, -4> {
let a: BoundedInt<-6, -3> = a.try_into().unwrap();
let b: BoundedInt<1, 3> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelperBI_m6xm2_BI_m20xm10 of SubHelper<BoundedInt<-6, -2>, BoundedInt<-20, -10>> {
type Result = BoundedInt<4, 18>;
}

fn bi_m6xm2_minus_bi_m20xm10(
a: felt252,
b: felt252,
) -> BoundedInt<4, 18> {
let a: BoundedInt<-6, -2> = a.try_into().unwrap();
let b: BoundedInt<-20, -10> = b.try_into().unwrap();
return sub(a, b);
}
};
}

#[test_case("bi_1x1_minus_bi_1x5", 1, 5, -4)]
#[test_case("bi_1x1_minus_bi_1x1", 1, 1, 0)]
#[test_case("bi_m3xm3_minus_bi_m3xm3", -3, -3, 0)]
#[test_case("bi_m6xm3_minus_bi_1x3", -6, 3, -9)]
#[test_case("bi_m6xm2_minus_bi_m20xm10", -2, -20, 18)]
fn test_sub(entry_point: &str, lhs: i32, rhs: i32, expected_result: i32) {
let result = run_program(
&TEST_SUB_PROGRAM,
entry_point,
&[
Value::Felt252(Felt252::from(lhs)),
Value::Felt252(Felt252::from(rhs)),
],
)
.return_value;
if let Value::Enum { value, .. } = result {
if let Value::Struct { fields, .. } = *value {
assert!(
matches!(fields[0], Value::BoundedInt { value, .. } if value == Felt252::from(expected_result))
)
} else {
panic!("Test returned an unexpected value");
}
} else {
panic!("Test returned value was not an Enum as expected");
}
}

fn assert_bool_output(result: Value, expected_tag: usize) {
if let Value::Enum { tag, value, .. } = result {
assert_eq!(tag, 0);
Expand Down
Loading