@@ -3824,7 +3824,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
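The same change repeats in every graph builder touched by this commit: the abstract `memory` pointer is downcast to `llama_kv_cache_unified` and the rope-factor tensor is fetched through a callback stored on the cache, instead of through `llama_model::build_rope_factors`. The callback struct itself is defined in the KV-cache header, which this diff does not show; the following is only a minimal sketch of the shape those call sites assume, with illustrative names:

```cpp
// Minimal sketch (illustrative names, not the actual llama.cpp headers) of the
// callback shape implied by `cbs.get_rope_factors(...)` at the new call sites.
#include <cstdint>
#include <functional>

struct ggml_tensor; // opaque: only used through pointers here

struct kv_cache_callbacks_sketch {
    // per-layer RoPE frequency factors; may return nullptr for models
    // that have none (e.g. llama2-style checkpoints)
    std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};

struct kv_cache_unified_sketch {
    kv_cache_callbacks_sketch cbs;
    // ... cache buffers omitted ...
};

// the call-site pattern used by the graph builders in this diff
inline ggml_tensor * rope_factors_for_layer(const kv_cache_unified_sketch * kv,
                                            uint32_t n_ctx_per_seq, int il) {
    return kv->cbs.get_rope_factors ? kv->cbs.get_rope_factors(n_ctx_per_seq, il)
                                    : nullptr;
}
```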
@@ -3998,7 +3998,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -6156,7 +6156,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 struct ggml_tensor * attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,
@@ -6879,7 +6879,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
-            struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
             // norm
             cur = build_norm(inpL,
@@ -7801,7 +7801,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -8715,7 +8715,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9872,7 +9872,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -10682,17 +10682,38 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
     }
 };
 
-ggml_tensor * llama_model::build_rope_factors(uint32_t n_ctx_per_seq, int il) const {
-    // choose long/short freq factors based on the context size
-    if (layers[il].rope_freqs != nullptr) {
-        return layers[il].rope_freqs;
-    }
+llama_memory_i * llama_model::create_memory() const {
+    llama_memory_i * res;
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-        return layers[il].rope_long;
+    switch (arch) {
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_MAMBA:
+            {
+                res = new llama_kv_cache_recurrent(hparams, {
+                    /*.get_rope_factors =*/ nullptr
+                });
+            } break;
+        default:
+            {
+                res = new llama_kv_cache_unified(hparams, {
+                    /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
+                        // choose long/short freq factors based on the context size
+                        if (layers[il].rope_freqs != nullptr) {
+                            return layers[il].rope_freqs;
+                        }
+
+                        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+                            return layers[il].rope_long;
+                        }
+
+                        return layers[il].rope_short;
+                    }
+                });
+            }
     }
 
-    return layers[il].rope_short;
+    return res;
 }
 
 llm_graph_result_ptr llama_model::build_graph(
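Taken together, `create_memory()` moves the long/short rope-factor selection out of a model method and into a lambda that captures `this` and lives inside the KV cache, so the model must outlive whatever memory object it creates. Below is a self-contained sketch of that ownership pattern, using simplified stand-in types; the `_sketch` names are illustrative, not the real llama.cpp interfaces:

```cpp
// Hypothetical, self-contained sketch of the pattern introduced by
// create_memory(): the unified cache stores a lambda that captures the model
// pointer, so the model must outlive the memory object it hands out.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <utility>

struct tensor_sketch { const char * name; };

struct memory_i_sketch {
    virtual ~memory_i_sketch() = default;
};

struct kv_cache_unified_sketch : memory_i_sketch {
    struct callbacks {
        std::function<const tensor_sketch * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
    } cbs;

    explicit kv_cache_unified_sketch(callbacks c) : cbs(std::move(c)) {}
};

struct model_sketch {
    tensor_sketch rope_long  {"rope_long"};
    tensor_sketch rope_short {"rope_short"};
    uint32_t n_ctx_orig_yarn = 8192;

    // analogous to llama_model::create_memory(): the selection logic moves
    // into a lambda owned by the cache rather than a model method
    memory_i_sketch * create_memory() const {
        return new kv_cache_unified_sketch({
            /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int /*il*/) {
                return n_ctx_per_seq > n_ctx_orig_yarn ? &rope_long : &rope_short;
            }
        });
    }
};

int main() {
    model_sketch model;
    std::unique_ptr<memory_i_sketch> mem(model.create_memory());

    auto * kv = static_cast<kv_cache_unified_sketch *>(mem.get());
    // call-site pattern mirroring the graph builders in the diff
    std::printf("%s\n", kv->cbs.get_rope_factors(16384, 0)->name); // rope_long
    std::printf("%s\n", kv->cbs.get_rope_factors(4096,  0)->name); // rope_short
}
```

Compared to the removed `build_rope_factors()` helper, the selection logic itself is unchanged; only where it lives and how the graph builders reach it differs.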