@@ -2980,7 +2980,7 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*
29802980
29812981struct ne_tensor * ne_rope_impl (struct ne_context * ctx , struct ne_tensor * a , int n_past , int n_dims , int mode ,
29822982 int prompt_size , bool inplace , int n_keep , struct ne_tensor * cossin , int * n_padding ,
2983- bool padding_left , float freq_base ) {
2983+ bool padding_left , float freq_base , float freq_scale ) {
29842984 NE_ASSERT (n_past >= 0 || n_keep >= 0 );
29852985 NE_ASSERT (padding_left );
29862986 bool is_node = false;
@@ -3020,7 +3020,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
30203020
30213021 ne_scratch_load (ctx );
30223022
3023- ne_set_op_params (result , & freq_base , sizeof (freq_base ));
3023+ float params [] = {freq_base , freq_scale };
3024+ ne_set_op_params (result , & params , sizeof (params ));
3025+
30243026 result -> op = NE_OP_ROPE ;
30253027 result -> grad = is_node ? ne_dup_tensor (ctx , result ) : NULL ;
30263028 result -> src0 = a ;
@@ -3031,18 +3033,20 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
30313033}
30323034
30333035struct ne_tensor * ne_rope (struct ne_context * ctx , struct ne_tensor * a , int n_past , int n_dims , int mode ,
3034- int prompt_size , float freq_base ) {
3035- return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , false, -1 , NULL , NULL , true, freq_base );
3036+ int prompt_size , float freq_base , float freq_scale ) {
3037+ return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , false, -1 , NULL , NULL , true, freq_base , freq_scale );
30363038}
30373039
30383040struct ne_tensor * ne_rope_inplace (struct ne_context * ctx , struct ne_tensor * a , int n_past , int n_dims , int mode ,
3039- int prompt_size , float freq_base ) {
3040- return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , true, -1 , NULL , NULL , true, freq_base );
3041+ int prompt_size , float freq_base , float freq_scale ) {
3042+ return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , true, -1 , NULL , NULL , true, freq_base , freq_scale );
30413043}
30423044
30433045struct ne_tensor * ne_rope_shift_inplace (struct ne_context * ctx , struct ne_tensor * a , int n_shift , int n_dims , int mode ,
3044- int prompt_size , int n_keep , struct ne_tensor * cossin , float freq_base ) {
3045- return ne_rope_impl (ctx , a , n_shift , n_dims , mode , prompt_size , true, n_keep , cossin , NULL , true, freq_base );
3046+ int prompt_size , int n_keep , struct ne_tensor * cossin , float freq_base ,
3047+ float freq_scale ) {
3048+ return ne_rope_impl (ctx , a , n_shift , n_dims , mode , prompt_size , true, n_keep , cossin , NULL , true, freq_base ,
3049+ freq_scale );
30463050}
30473051
30483052// ne_rope_back
@@ -3078,13 +3082,16 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
30783082}
30793083
30803084struct ne_tensor * ne_rope_with_padding (struct ne_context * ctx , struct ne_tensor * a , int n_past , int n_dims , int mode ,
3081- int prompt_size , int * n_padding , float freq_base ) {
3082- return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , false, -1 , NULL , n_padding , true, freq_base );
3085+ int prompt_size , int * n_padding , float freq_base , float freq_scale ) {
3086+ return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , false, -1 , NULL , n_padding , true, freq_base ,
3087+ freq_scale );
30833088}
30843089
30853090struct ne_tensor * ne_rope_with_padding_inplace (struct ne_context * ctx , struct ne_tensor * a , int n_past , int n_dims ,
3086- int mode , int prompt_size , int * n_padding , float freq_base ) {
3087- return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , true, -1 , NULL , n_padding , true, freq_base );
3091+ int mode , int prompt_size , int * n_padding , float freq_base ,
3092+ float freq_scale ) {
3093+ return ne_rope_impl (ctx , a , n_past , n_dims , mode , prompt_size , true, -1 , NULL , n_padding , true, freq_base ,
3094+ freq_scale );
30883095}
30893096
30903097// ne_alibi
@@ -7867,9 +7874,8 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
78677874 NE_ASSERT (src1 -> type == NE_TYPE_I32 );
78687875 NE_ASSERT (ne_nelements (src1 ) == 5 + bs ); // 5 + bs params
78697876
7870- float freq_base = 10000.0f ;
7871- memcpy (& freq_base , dst -> op_params , sizeof (float ));
7872- static const float freq_scale = 1.0f ;
7877+ const float freq_base = ((float * )(dst -> op_params ))[0 ];
7878+ const float freq_scale = 1 / ((float * )(dst -> op_params ))[1 ];
78737879
78747880 const int64_t n_past = ((int32_t * )src1 -> data )[ROPE_NPAST_IDX ];
78757881 const int64_t n_dims = ((int32_t * )src1 -> data )[ROPE_NDIMS_IDX ];
@@ -8043,7 +8049,10 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
80438049 // row index used to determine which thread to use
80448050 int ir = 0 ;
80458051
8046- const float theta_scale = powf (10000.0 , -2.0f / n_dims );
8052+ const float freq_base = ((float * )(dst -> op_params ))[0 ];
8053+ const float freq_scale = 1 / ((float * )(dst -> op_params ))[1 ];
8054+
8055+ const float theta_scale = powf (freq_base , -2.0f / n_dims );
80478056
80488057 const bool skip = mode & 1 ;
80498058 const bool is_neox = mode & 2 ;
@@ -8053,7 +8062,7 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
80538062 NE_ASSERT (("shift RoPE is only implemented for the vanilla mode" , !is_shift || !(is_glm || is_neox || skip )));
80548063
80558064 if (is_shift ) {
8056- float theta = n_past ;
8065+ float theta = n_past * freq_scale ;
80578066 ne_fp16_t * cossin = (dst -> opt [0 ] != NULL ) ? dst -> opt [0 ]-> data : NULL ;
80588067 if (cossin == NULL ) {
80598068 cossin = malloc (ne0 * sizeof (ne_fp16_t ));
@@ -8098,7 +8107,7 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
80988107 if (ir ++ < ir0 ) continue ;
80998108 if (ir > ir1 ) break ;
81008109
8101- float theta = (float )p ;
8110+ float theta = freq_scale * (float )p ;
81028111
81038112 if (!is_neox ) {
81048113 for (int64_t i0 = 0 ; i0 < ne0 ; i0 += 2 ) {
@@ -8172,11 +8181,14 @@ static void ne_compute_forward_rope_jblas(const struct ne_compute_params* params
81728181 const int seq_len = dst -> ne [1 ];
81738182 const int head_size = dst -> ne [0 ];
81748183
8184+ const float freq_base = ((float * )(dst -> op_params ))[0 ];
8185+ const float freq_scale = 1 / ((float * )(dst -> op_params ))[1 ];
8186+
81758187 if (is_shift ) {
81768188 ne_fp16_t * cossin = (dst -> opt [0 ] != NULL ) ? dst -> opt [0 ]-> data : NULL ;
81778189 if (cossin == NULL ) {
8178- float theta = n_past ;
8179- const float theta_scale = powf (10000.0 , -2.0f / n_dims );
8190+ float theta = n_past * freq_scale ;
8191+ const float theta_scale = powf (freq_base , -2.0f / n_dims );
81808192 cossin = malloc (head_size * sizeof (ne_fp16_t ));
81818193 for (int i0 = 0 ; i0 < head_size ; i0 += 2 ) {
81828194 cossin [i0 + 0 ] = NE_FP32_TO_FP16 (cosf (theta ));
@@ -10016,7 +10028,7 @@ static void ne_compute_backward(struct ne_context* ctx, struct ne_tensor* tensor
1001610028 const int n_dims = ((int32_t * )src1 -> data )[1 ];
1001710029 const int mode = ((int32_t * )src1 -> data )[2 ];
1001810030 src0 -> grad =
10019- ne_add_impl (ctx , src0 -> grad , ne_rope (ctx , tensor -> grad , n_past , n_dims , mode , 0 , 10000.0 ), inplace );
10031+ ne_add_impl (ctx , src0 -> grad , ne_rope (ctx , tensor -> grad , n_past , n_dims , mode , 0 , 10000.0 , 1.0 ), inplace );
1002010032 }
1002110033 if (src1 -> grad ) {
1002210034 // noop
0 commit comments