diff --git a/src/libfuncs/circuit.rs b/src/libfuncs/circuit.rs index 901ff11ae..f84bd17b2 100644 --- a/src/libfuncs/circuit.rs +++ b/src/libfuncs/circuit.rs @@ -601,14 +601,15 @@ fn build_gate_evaluation<'ctx, 'this>( let circuit_modulus_u768 = block.extui(circuit_modulus, u768_type, location)?; // Apply egcd to find gcd and inverse - let euclidean_result = runtime_bindings_meta.extended_euclidean_algorithm( - context, - helper.module, - block, - location, - rhs_value, - circuit_modulus_u768, - )?; + let euclidean_result = runtime_bindings_meta + .u384_extended_euclidean_algorithm( + context, + helper.module, + block, + location, + rhs_value, + circuit_modulus_u768, + )?; // Extract the values from the result struct let gcd = block.extract_value(context, location, euclidean_result, u768_type, 0)?; @@ -636,26 +637,6 @@ fn build_gate_evaluation<'ctx, 'this>( )); block = has_inverse_block; - // if the inverse is negative, then add modulus - let zero = block.const_int_from_type(context, location, 0, u768_type)?; - let is_negative = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Slt, - inverse, - zero, - location, - )) - .result(0)? - .into(); - let wrapped_inverse = block.addi(inverse, circuit_modulus_u768, location)?; - let inverse = block.append_op_result(arith::select( - is_negative, - wrapped_inverse, - inverse, - location, - ))?; - // Truncate back let inverse = block.trunci(inverse, u384_type, location)?; diff --git a/src/libfuncs/felt252.rs b/src/libfuncs/felt252.rs index af79646f1..dfc07c613 100644 --- a/src/libfuncs/felt252.rs +++ b/src/libfuncs/felt252.rs @@ -2,8 +2,8 @@ use super::LibfuncHelper; use crate::{ - error::Result, - metadata::MetadataStorage, + error::{panic::ToNativeAssertError, Result}, + metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, utils::{ProgramRegistryExt, PRIME}, }; use cairo_lang_sierra::{ @@ -19,12 +19,9 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::{ - arith::{self, CmpiPredicate}, - cf, - }, - helpers::{ArithBlockExt, BuiltinBlockExt}, - ir::{r#type::IntegerType, Block, BlockLike, Location, Value, ValueLike}, + dialect::arith::{self, CmpiPredicate}, + helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt}, + ir::{r#type::IntegerType, Block, Location, Value, ValueLike}, Context, }; use num_bigint::{BigInt, Sign}; @@ -149,130 +146,44 @@ pub fn build_binary_operation<'ctx, 'this>( entry.trunci(result, felt252_ty, location)? } Felt252BinaryOperator::Div => { - // The extended euclidean algorithm calculates the greatest common divisor of two integers, - // as well as the bezout coefficients x and y such that for inputs a and b, ax+by=gcd(a,b) - // We use this in felt division to find the modular inverse of a given number - // If a is the number we're trying to find the inverse of, we can do - // ax+y*PRIME=gcd(a,PRIME)=1 => ax = 1 (mod PRIME) - // Hence for input a, we return x - // The input MUST be non-zero - // See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm - let start_block = helper.append_block(Block::new(&[(i512, location)])); - let loop_block = helper.append_block(Block::new(&[ - (i512, location), - (i512, location), - (i512, location), - (i512, location), - ])); - let negative_check_block = helper.append_block(Block::new(&[])); - // Block containing final result - let inverse_result_block = helper.append_block(Block::new(&[(i512, location)])); - // Egcd works by calculating a series of remainders, each the remainder of dividing the previous two - // For the initial setup, r0 = PRIME, r1 = a - // This order is chosen because if we reverse them, then the first iteration will just swap them - let prev_remainder = - start_block.const_int_from_type(context, location, PRIME.clone(), i512)?; - let remainder = start_block.arg(0)?; - // Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a - let prev_inverse = start_block.const_int_from_type(context, location, 0, i512)?; - let inverse = start_block.const_int_from_type(context, location, 1, i512)?; - start_block.append_operation(cf::br( - loop_block, - &[prev_remainder, remainder, prev_inverse, inverse], - location, - )); - - //---Loop body--- - // Arguments are rem_(i-1), rem, inv_(i-1), inv - let prev_remainder = loop_block.arg(0)?; - let remainder = loop_block.arg(1)?; - let prev_inverse = loop_block.arg(2)?; - let inverse = loop_block.arg(3)?; - - // First calculate q = rem_(i-1)/rem_i, rounded down - let quotient = - loop_block.append_op_result(arith::divui(prev_remainder, remainder, location))?; - // Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i - let rem_times_quo = loop_block.muli(remainder, quotient, location)?; - let inv_times_quo = loop_block.muli(inverse, quotient, location)?; - let next_remainder = loop_block.append_op_result(arith::subi( - prev_remainder, - rem_times_quo, - location, - ))?; - let next_inverse = - loop_block.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))?; - - // If r_(i+1) is 0, then inv_i is the inverse - let zero = loop_block.const_int_from_type(context, location, 0, i512)?; - let next_remainder_eq_zero = - loop_block.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)?; - loop_block.append_operation(cf::cond_br( - context, - next_remainder_eq_zero, - negative_check_block, - loop_block, - &[], - &[remainder, next_remainder, inverse, next_inverse], - location, - )); - - // egcd sometimes returns a negative number for the inverse, - // in such cases we must simply wrap it around back into [0, PRIME) - // this suffices because |inv_i| <= divfloor(PRIME,2) - let zero = negative_check_block.const_int_from_type(context, location, 0, i512)?; - - let is_negative = negative_check_block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Slt, - inverse, - zero, - location, - )) - .result(0)? - .into(); - // if the inverse is < 0, add PRIME - let prime = - negative_check_block.const_int_from_type(context, location, PRIME.clone(), i512)?; - let wrapped_inverse = negative_check_block.addi(inverse, prime, location)?; - let inverse = negative_check_block.append_op_result(arith::select( - is_negative, - wrapped_inverse, - inverse, - location, - ))?; - negative_check_block.append_operation(cf::br( - inverse_result_block, - &[inverse], - location, - )); + let runtime_bindings_meta = metadata + .get_mut::() + .to_native_assert_error( + "Unable to get the RuntimeBindingsMeta from MetadataStorage", + )?; - // Div Logic Start - // Fetch operands + let prime = entry.const_int_from_type(context, location, PRIME.clone(), i512)?; let lhs = entry.extui(lhs, i512, location)?; let rhs = entry.extui(rhs, i512, location)?; - // Calculate inverse of rhs, callling the inverse implementation's starting block - entry.append_operation(cf::br(start_block, &[rhs], location)); - // Fetch the inverse result from the result block - let inverse = inverse_result_block.arg(0)?; - // Peform lhs * (1/ rhs) - let result = inverse_result_block.muli(lhs, inverse, location)?; + + // Find 1 / rhs. + let euclidean_result = runtime_bindings_meta.u252_extended_euclidean_algorithm( + context, + helper.module, + entry, + location, + rhs, + prime, + )?; + + let inverse = entry.extract_value(context, location, euclidean_result, i512, 1)?; + + // Peform lhs * (1 / rhs) + let result = entry.muli(lhs, inverse, location)?; // Apply modulo and convert result to felt252 - let result_mod = - inverse_result_block.append_op_result(arith::remui(result, prime, location))?; + let result_mod = entry.append_op_result(arith::remui(result, prime, location))?; let is_out_of_range = - inverse_result_block.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; + entry.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; - let result = inverse_result_block.append_op_result(arith::select( + let result = entry.append_op_result(arith::select( is_out_of_range, result_mod, result, location, ))?; - let result = inverse_result_block.trunci(result, felt252_ty, location)?; + let result = entry.trunci(result, felt252_ty, location)?; - return helper.br(inverse_result_block, 0, &[result], location); + return helper.br(entry, 0, &[result], location); } }; diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index 8f539c881..7a9ffdbef 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -46,7 +46,8 @@ enum RuntimeBinding { DictDup, GetCostsBuiltin, DebugPrint, - ExtendedEuclideanAlgorithm, + U252ExtendedEuclideanAlgorithm, + U384ExtendedEuclideanAlgorithm, CircuitArithOperation, #[cfg(feature = "with-cheatcode")] VtableCheatcode, @@ -72,8 +73,11 @@ impl RuntimeBinding { RuntimeBinding::DictDrop => "cairo_native__dict_drop", RuntimeBinding::DictDup => "cairo_native__dict_dup", RuntimeBinding::GetCostsBuiltin => "cairo_native__get_costs_builtin", - RuntimeBinding::ExtendedEuclideanAlgorithm => { - "cairo_native__extended_euclidean_algorithm" + RuntimeBinding::U252ExtendedEuclideanAlgorithm => { + "cairo_native__u252_extended_euclidean_algorithm" + } + RuntimeBinding::U384ExtendedEuclideanAlgorithm => { + "cairo_native__u384_extended_euclidean_algorithm" } RuntimeBinding::CircuitArithOperation => "cairo_native__circuit_arith_operation", #[cfg(feature = "with-cheatcode")] @@ -124,7 +128,8 @@ impl RuntimeBinding { RuntimeBinding::GetCostsBuiltin => { crate::runtime::cairo_native__get_costs_builtin as *const () } - RuntimeBinding::ExtendedEuclideanAlgorithm => return None, + RuntimeBinding::U252ExtendedEuclideanAlgorithm + | RuntimeBinding::U384ExtendedEuclideanAlgorithm => return None, RuntimeBinding::CircuitArithOperation => return None, #[cfg(feature = "with-cheatcode")] RuntimeBinding::VtableCheatcode => { @@ -202,7 +207,50 @@ impl RuntimeBindingsMeta { /// After checking, calls the MLIR function with arguments `a` and `b` which are the initial remainders /// used in the algorithm and returns a `Value` containing a struct where the first element is the /// greatest common divisor of `a` and `b` and the second element is the bezout coefficient x. - pub fn extended_euclidean_algorithm<'c, 'a>( + /// + /// This implementation is only for felt252, which uses u252 integers. + pub fn u252_extended_euclidean_algorithm<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + a: Value<'c, '_>, + b: Value<'c, '_>, + ) -> Result> + where + 'c: 'a, + { + let integer_type = IntegerType::new(context, 512).into(); + let func_symbol = RuntimeBinding::U252ExtendedEuclideanAlgorithm.symbol(); + if self + .active_map + .insert(RuntimeBinding::U252ExtendedEuclideanAlgorithm) + { + build_egcd_function(module, context, location, func_symbol, integer_type)?; + } + // The struct returned by the function that contains both of the results + let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false); + Ok(block + .append_operation( + OperationBuilder::new("llvm.call", location) + .add_attributes(&[( + Identifier::new(context, "callee"), + FlatSymbolRefAttribute::new(context, func_symbol).into(), + )]) + .add_operands(&[a, b]) + .add_results(&[return_type]) + .build()?, + ) + .result(0)? + .into()) + } + + /// Similar to [felt252_extended_euclidean_algorithm](Self::felt252_extended_euclidean_algorithm). + /// + /// The difference with the other is that this function is meant to be used + /// with circuits, which use u384 integers. + pub fn u384_extended_euclidean_algorithm<'c, 'a>( &mut self, context: &'c Context, module: &Module, @@ -214,14 +262,14 @@ impl RuntimeBindingsMeta { where 'c: 'a, { - let func_symbol = RuntimeBinding::ExtendedEuclideanAlgorithm.symbol(); + let integer_type = IntegerType::new(context, 768).into(); + let func_symbol = RuntimeBinding::U384ExtendedEuclideanAlgorithm.symbol(); if self .active_map - .insert(RuntimeBinding::ExtendedEuclideanAlgorithm) + .insert(RuntimeBinding::U384ExtendedEuclideanAlgorithm) { - build_egcd_function(module, context, location, func_symbol)?; + build_egcd_function(module, context, location, func_symbol, integer_type)?; } - let integer_type: Type = IntegerType::new(context, 384 * 2).into(); // The struct returned by the function that contains both of the results let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false); Ok(block @@ -820,33 +868,22 @@ pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) { /// /// This function declares a MLIR function that given two numbers a and b, returns a MLIR struct with gcd(a, b) /// and the bezout coefficient x. The declaration is done in the body of the module. +/// +/// The primary use of this function is to find the modular multiplicative inverse of a value. To so, it is expected +/// the a represents the value to be inverted and b the modulus of the field field. fn build_egcd_function<'ctx>( module: &Module, context: &'ctx Context, location: Location<'ctx>, func_symbol: &str, + integer_type: Type, ) -> Result<()> { - let integer_type: Type = IntegerType::new(context, 384 * 2).into(); let region = Region::new(); let entry_block = region.append_block(Block::new(&[ (integer_type, location), (integer_type, location), ])); - - let a = entry_block.arg(0)?; - let b = entry_block.arg(1)?; - // The egcd algorithm works by calculating a series of remainders `rem`, being each `rem_i` the remainder of dividing `rem_{i-1}` with `rem_{i-2}` - // For the initial setup, rem_0 = b, rem_1 = a. - // This order is chosen because if we reverse them, then the first iteration will just swap them - let remainder = a; - let prev_remainder = b; - - // Similarly we'll calculate another series which starts 0,1,... and from which we - // will retrieve the modular inverse of a - let prev_inverse = entry_block.const_int_from_type(context, location, 0, integer_type)?; - let inverse = entry_block.const_int_from_type(context, location, 1, integer_type)?; - let loop_block = region.append_block(Block::new(&[ (integer_type, location), (integer_type, location), @@ -858,6 +895,19 @@ fn build_egcd_function<'ctx>( (integer_type, location), ])); + let rhs = entry_block.arg(0)?; + let prime_modulus = entry_block.arg(1)?; + // The egcd algorithm works by calculating a series of remainders `rem`, being each `rem_i` the remainder of dividing `rem_{i-1}` with `rem_{i-2}` + // For the initial setup, rem_0 = b, rem_1 = a. + // This order is chosen because if we reverse them, then the first iteration will just swap them + let remainder = rhs; + let prev_remainder = prime_modulus; + + // Similarly we'll calculate another series which starts 0,1,... and from which we + // will retrieve the modular inverse of a + let prev_inverse = entry_block.const_int_from_type(context, location, 0, integer_type)?; + let inverse = entry_block.const_int_from_type(context, location, 1, integer_type)?; + entry_block.append_operation(cf::br( &loop_block, &[prev_remainder, remainder, prev_inverse, inverse], @@ -900,17 +950,37 @@ fn build_egcd_function<'ctx>( location, )); + let gcd = end_block.arg(0)?; + let inverse = end_block.arg(1)?; + + // EGDC sometimes returns a negative number for the inverse, + // in such cases we must simply wrap it around back into [0, MODULUS) + // this suffices because |inv_i| <= divfloor(MODULUS,2) + let zero = end_block.const_int_from_type(context, location, 0, integer_type)?; + let is_negative = end_block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Slt, + inverse, + zero, + location, + )) + .result(0)? + .into(); + let wrapped_inverse = end_block.addi(inverse, prime_modulus, location)?; + let inverse = end_block.append_op_result(arith::select( + is_negative, + wrapped_inverse, + inverse, + location, + ))?; + // Create the struct that will contain the results let results = end_block.append_op_result(llvm::undef( llvm::r#type::r#struct(context, &[integer_type, integer_type], false), location, ))?; - let results = end_block.insert_values( - context, - location, - results, - &[end_block.arg(0)?, end_block.arg(1)?], - )?; + let results = end_block.insert_values(context, location, results, &[gcd, inverse])?; end_block.append_operation(llvm::r#return(Some(results), location)); let func_name = StringAttribute::new(context, func_symbol);