@@ -1992,7 +1992,12 @@ ggml_tensor * llm_graph_context::build_rwkv_channel_mix(
19921992 return cur;
19931993}
19941994
1995- void llm_graph_context::build_pooling (ggml_cgraph * gf) const {
1995+ void llm_graph_context::build_pooling (
1996+ ggml_cgraph * gf,
1997+ ggml_tensor * cls,
1998+ ggml_tensor * cls_b,
1999+ ggml_tensor * cls_out,
2000+ ggml_tensor * cls_out_b) const {
19962001 if (!cparams.embeddings ) {
19972002 return ;
19982003 }
@@ -2036,18 +2041,18 @@ void llm_graph_context::build_pooling(ggml_cgraph * gf) const {
20362041
20372042 // classification head
20382043 // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2039- GGML_ASSERT (model. cls != nullptr );
2040- GGML_ASSERT (model. cls_b != nullptr );
2044+ GGML_ASSERT (cls != nullptr );
2045+ GGML_ASSERT (cls_b != nullptr );
20412046
2042- cur = ggml_add (ctx0, ggml_mul_mat (ctx0, model. cls , inp), model. cls_b );
2047+ cur = ggml_add (ctx0, ggml_mul_mat (ctx0, cls, inp), cls_b);
20432048 cur = ggml_tanh (ctx0, cur);
20442049
20452050 // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
20462051 // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2047- if (model. cls_out ) {
2048- GGML_ASSERT (model. cls_out_b != nullptr );
2052+ if (cls_out) {
2053+ GGML_ASSERT (cls_out_b != nullptr );
20492054
2050- cur = ggml_add (ctx0, ggml_mul_mat (ctx0, model. cls_out , cur), model. cls_out_b );
2055+ cur = ggml_add (ctx0, ggml_mul_mat (ctx0, cls_out, cur), cls_out_b);
20512056 }
20522057 } break ;
20532058 default :
0 commit comments