@@ -5,7 +5,7 @@ use proc_macro::TokenStream;
55use quote:: format_ident;
66use syn:: {
77 parse:: { Parse , ParseStream } ,
8- parse_macro_input, Expr , ExprPath , Path , Token ,
8+ parse_macro_input, ExprPath , LitStr , Token ,
99} ;
1010
1111/// This macro generates the code to setup a Twisted Edwards elliptic curve for a given modular
@@ -86,7 +86,7 @@ pub fn te_declare(input: TokenStream) -> TokenStream {
8686 let result = TokenStream :: from ( quote:: quote_spanned! { span. into( ) =>
8787 extern "C" {
8888 fn #te_add_extern_func( rd: usize , rs1: usize , rs2: usize ) ;
89- fn #te_setup_extern_func( ) ;
89+ fn #te_setup_extern_func( uninit : * mut core :: ffi :: c_void , p1 : * const u8 , p2 : * const u8 ) ;
9090 }
9191
9292 #[ derive( Eq , PartialEq , Clone , Debug , serde:: Serialize , serde:: Deserialize ) ]
@@ -143,7 +143,15 @@ pub fn te_declare(input: TokenStream) -> TokenStream {
143143 fn set_up_once( ) {
144144 static is_setup: :: openvm_ecc_guest:: once_cell:: race:: OnceBool = :: openvm_ecc_guest:: once_cell:: race:: OnceBool :: new( ) ;
145145 is_setup. get_or_init( || {
146- unsafe { #te_setup_extern_func( ) ; }
146+ let modulus_bytes = <<Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: MODULUS ;
147+ let mut zero = [ 0u8 ; <<Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: NUM_LIMBS ] ;
148+ let curve_a_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_A ) ;
149+ let curve_d_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_D ) ;
150+ let p1 = [ modulus_bytes. as_ref( ) , curve_a_bytes. as_ref( ) ] . concat( ) ;
151+ let p2 = [ curve_d_bytes. as_ref( ) , zero. as_ref( ) ] . concat( ) ;
152+ let mut uninit: core:: mem:: MaybeUninit <[ Self ; 2 ] > = core:: mem:: MaybeUninit :: uninit( ) ;
153+
154+ unsafe { #te_setup_extern_func( uninit. as_mut_ptr( ) as * mut core:: ffi:: c_void, p1. as_ptr( ) , p2. as_ptr( ) ) ; }
147155 <#intmod_type as openvm_algebra_guest:: IntMod >:: set_up_once( ) ;
148156 true
149157 } ) ;
@@ -266,22 +274,16 @@ pub fn te_declare(input: TokenStream) -> TokenStream {
266274}
267275
268276struct TeDefine {
269- items : Vec < Path > ,
277+ items : Vec < String > ,
270278}
271279
272280impl Parse for TeDefine {
273281 fn parse ( input : ParseStream ) -> syn:: Result < Self > {
274- let items = input. parse_terminated ( <Expr as Parse >:: parse, Token ! [ , ] ) ?;
282+ let items = input. parse_terminated ( <LitStr as Parse >:: parse, Token ! [ , ] ) ?;
275283 Ok ( Self {
276284 items : items
277285 . into_iter ( )
278- . map ( |e| {
279- if let Expr :: Path ( p) = e {
280- p. path
281- } else {
282- panic ! ( "expected path" ) ;
283- }
284- } )
286+ . map ( |e| e. value ( ) )
285287 . collect ( ) ,
286288 } )
287289 }
@@ -295,17 +297,11 @@ pub fn te_init(input: TokenStream) -> TokenStream {
295297
296298 let span = proc_macro:: Span :: call_site ( ) ;
297299
298- for ( ec_idx, item) in items. into_iter ( ) . enumerate ( ) {
299- let str_path = item
300- . segments
301- . iter ( )
302- . map ( |x| x. ident . to_string ( ) )
303- . collect :: < Vec < _ > > ( )
304- . join ( "_" ) ;
300+ for ( ec_idx, struct_id) in items. into_iter ( ) . enumerate ( ) {
305301 let add_extern_func =
306- syn:: Ident :: new ( & format ! ( "te_add_extern_func_{}" , str_path ) , span. into ( ) ) ;
302+ syn:: Ident :: new ( & format ! ( "te_add_extern_func_{}" , struct_id ) , span. into ( ) ) ;
307303 let setup_extern_func =
308- syn:: Ident :: new ( & format ! ( "te_setup_extern_func_{}" , str_path ) , span. into ( ) ) ;
304+ syn:: Ident :: new ( & format ! ( "te_setup_extern_func_{}" , struct_id ) , span. into ( ) ) ;
309305 externs. push ( quote:: quote_spanned! { span. into( ) =>
310306 #[ no_mangle]
311307 extern "C" fn #add_extern_func( rd: usize , rs1: usize , rs2: usize ) {
@@ -321,26 +317,19 @@ pub fn te_init(input: TokenStream) -> TokenStream {
321317 }
322318
323319 #[ no_mangle]
324- extern "C" fn #setup_extern_func( ) {
320+ extern "C" fn #setup_extern_func( uninit : * mut core :: ffi :: c_void , p1 : * const u8 , p2 : * const u8 ) {
325321 #[ cfg( target_os = "zkvm" ) ]
326322 {
327- use super :: #item;
328- let modulus_bytes = <<#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: MODULUS ;
329- let mut zero = [ 0u8 ; <<#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: NUM_LIMBS ] ;
330- let curve_a_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_A ) ;
331- let curve_d_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_D ) ;
332- let p1 = [ modulus_bytes. as_ref( ) , curve_a_bytes. as_ref( ) ] . concat( ) ;
333- let p2 = [ curve_d_bytes. as_ref( ) , zero. as_ref( ) ] . concat( ) ;
334- let mut uninit: core:: mem:: MaybeUninit <[ #item; 2 ] > = core:: mem:: MaybeUninit :: uninit( ) ;
323+
335324 openvm:: platform:: custom_insn_r!(
336325 opcode = :: openvm_ecc_guest:: TE_OPCODE ,
337326 funct3 = :: openvm_ecc_guest:: TE_FUNCT3 as usize ,
338327 funct7 = :: openvm_ecc_guest:: TeBaseFunct7 :: TeSetup as usize
339328 + #ec_idx
340329 * ( :: openvm_ecc_guest:: TeBaseFunct7 :: TWISTED_EDWARDS_MAX_KINDS as usize ) ,
341- rd = In uninit. as_mut_ptr ( ) ,
342- rs1 = In p1. as_ptr ( ) ,
343- rs2 = In p2. as_ptr ( ) ,
330+ rd = In uninit,
331+ rs1 = In p1,
332+ rs2 = In p2,
344333 ) ;
345334 }
346335 }
0 commit comments