@@ -4673,6 +4673,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
46734673 ggml_backend_sched_reset (sched.get ());
46744674
46754675 cross->t_embd = t_embd;
4676+ cross->v_embd = embd;
46764677
46774678 // remember the sequence ids used during the encoding - needed for cross attention later
46784679 cross->seq_ids_enc .resize (n_tokens);
@@ -4701,12 +4702,11 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
47014702 // call base functionality
47024703 llama_context_kv_self::input_set (ubatch);
47034704
4704- // if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
4705- // assert(inp.cross_embd->type == GGML_TYPE_F32);
4706- // assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
4705+ if (inp.cross_embd && cross->t_embd ) {
4706+ assert (inp.cross_embd ->type == GGML_TYPE_F32);
47074707
4708- // ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc , 0, ggml_nbytes(inp.cross_embd));
4709- // }
4708+ ggml_backend_tensor_set (inp.cross_embd , cross->v_embd , 0 , ggml_nbytes (inp.cross_embd ));
4709+ }
47104710
47114711 if (inp.cross_kq_mask ) {
47124712 const int64_t n_enc = inp.cross_kq_mask ->ne [0 ];
@@ -4749,16 +4749,17 @@ ggml_cgraph * llama_context_dec::graph_init() {
47494749ggml_tensor * llama_context_dec::build_inp_cross_embd (
47504750 ggml_context * ctx0) {
47514751 // if we have the output embeddings from the encoder, use them directly
4752- if (cross->t_embd ) {
4753- inp.cross_embd = ggml_view_tensor (ctx0, cross->t_embd );
4752+ // TODO: needs more work to be correct, for now just use the tensor shape
4753+ // if (cross->t_embd) {
4754+ // inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
47544755
4755- return inp.cross_embd ;
4756- }
4756+ // return inp.cross_embd;
4757+ // }
47574758
47584759 const auto & hparams = model.hparams ;
47594760
4760- const auto n_embd = hparams.n_embd ;
4761- const auto n_enc = hparams.n_ctx_train ;
4761+ const auto n_embd = cross-> t_embd ? cross-> t_embd -> ne [ 0 ] : hparams.n_embd ;
4762+ const auto n_enc = cross-> t_embd ? cross-> t_embd -> ne [ 1 ] : hparams.n_ctx_train ;
47624763
47634764 inp.cross_embd = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd, n_enc);
47644765 ggml_set_input (inp.cross_embd );
0 commit comments