@@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
 // llama_context_enc
 //
 
+// TODO: avoid copy-paste of the entire encode() function
 int llama_context_enc::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
 
-    cross->n_outputs = n_tokens;
-    cross->embd_enc  = embd;
+    cross->t_embd = t_embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
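
For orientation, here is a minimal sketch of the `cross` state these hunks write to. Only the identifiers that actually appear in the diff (`t_embd`, `seq_ids_enc`) are taken from the source; the surrounding struct layout is an assumption:

```cpp
#include <set>
#include <vector>
#include <cstdint>

struct ggml_tensor;             // opaque here; defined by ggml
typedef int32_t llama_seq_id;   // matches the typedef in llama.h

// sketch (assumed): the encoder/decoder bridge referenced as `cross`
struct llama_cross {
    // output embeddings of the encoder - after this change the decoder views
    // this tensor directly instead of reading a host-side copy
    ggml_tensor * t_embd = nullptr;

    // per encoder output position: the sequence ids present in the encoded
    // batch, needed later to build the cross-attention mask
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
```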
@@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
 
 void llama_context_dec::reserve() {
     // simulate full KV cache
-    cross->n_outputs = cparams.n_ubatch;
-
-    LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
+    cross->t_embd = nullptr;
 
     llama_context_kv_self::reserve();
 }
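
Setting `cross->t_embd = nullptr` here is what makes the reservation path work without an encoder pass: as the later hunks show, both `build_inp_cross_embd()` and `build_attn_inp()` fall back to the worst-case size `hparams.n_ctx_train` when no encoder output tensor is available, so the replaced `n_outputs = cparams.n_ubatch` bookkeeping (and its debug log) is no longer needed.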
@@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    if (inp.cross_embd) {
-        assert(inp.cross_embd->type == GGML_TYPE_F32);
-        assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
+    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
+    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
 
-        ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    }
+    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
+    //}
 
     if (inp.cross_kq_mask) {
-        const int64_t n_output_enc = cross->n_outputs;
+        const int64_t n_enc    = inp.cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
@@ -4721,21 +4719,21 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
+                for (int i = 0; i < n_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[j][s];
                         if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
             }
 
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
                 }
             }
         }
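
The loop above fills the cross-attention KQ mask: entry `[j*n_enc + i]` is `0.0f` when decoder token `j` shares a sequence id with encoder position `i`, and `-INFINITY` otherwise, with the padded rows beyond `n_tokens` fully masked. A self-contained illustration with toy sizes and hypothetical sequence assignments (the padding value 32 stands in for `GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)`, whose real value comes from ggml; the `-INFINITY` initialization here covers the padding rows that the diff masks explicitly):

```cpp
#include <cstdio>
#include <cmath>
#include <set>
#include <vector>

int main() {
    const int n_enc    = 3;  // encoder output positions
    const int n_tokens = 2;  // decoder tokens in the ubatch
    const int n_pad    = 32; // stand-in for GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)

    // hypothetical: which sequence ids were present at each encoder position
    std::vector<std::set<int>> seq_ids_enc = {{0}, {0}, {1}};

    // hypothetical: decoder token j belongs to sequence seq_id[j]
    const int seq_id[n_tokens] = {0, 0};

    // everything starts masked, which also covers the padding rows
    std::vector<float> data(n_pad*n_enc, -INFINITY);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_enc; ++i) {
            // 0.0f = may attend, -INF = masked: tokens of sequence 0 attend
            // only to encoder positions that saw sequence 0
            if (seq_ids_enc[i].count(seq_id[j])) {
                data[j*n_enc + i] = 0.0f;
            }
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_enc; ++i) {
            printf("%5.0f ", data[j*n_enc + i]);
        }
        printf("\n"); // prints two rows: 0 0 -inf
    }
    return 0;
}
```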
@@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
 
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
+    // if we have the output embeddings from the encoder, use them directly
+    if (cross->t_embd) {
+        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+
+        return inp.cross_embd;
+    }
+
     const auto & hparams = model.hparams;
-    const int64_t n_embd = hparams.n_embd;
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const auto n_embd = hparams.n_embd;
+    const auto n_enc  = hparams.n_ctx_train;
 
-    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
 
     return inp.cross_embd;
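
The early-return path relies on `ggml_view_tensor`, which creates a new tensor header aliasing the source tensor's data rather than copying it, so the decoder graph reads the encoder output in place. A minimal sketch of that aliasing behavior (memory size and shapes are arbitrary illustration values):

```cpp
#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // pretend this is the encoder output: [n_embd, n_enc]
    ggml_tensor * t_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 768, 512);

    // same shape, same underlying data - no copy is made
    ggml_tensor * view = ggml_view_tensor(ctx, t_embd);
    GGML_ASSERT(view->data == t_embd->data);

    ggml_free(ctx);
    return 0;
}
```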
@@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
         bool swa) {
     llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
 
-    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     ggml_set_input(inp.cross_kq_mask);
 
     inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
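
One detail worth spelling out in the ternary above: ggml stores the embedding dimension in `ne[0]` and the number of positions in `ne[1]`, so `cross->t_embd->ne[1]` is the number of encoder output positions the decoder can attend to. On the reservation path, where `reserve()` cleared `t_embd` to `nullptr`, the mask is sized for the worst case `n_ctx_train` instead, matching the fallback in `build_inp_cross_embd()`.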