@@ -1131,17 +1131,18 @@ llm_graph_result_ptr llama_context_base::graph_build(
11311131 return model.build_graph (
11321132 {
11331133 /* .ctx =*/ ctx,
1134- /* .model =*/ model,
1134+ /* .arch =*/ model.arch ,
1135+ /* .hparams =*/ model.hparams ,
11351136 /* .cparams =*/ cparams,
11361137 /* .ubatch =*/ ubatch,
11371138 /* .sched =*/ sched.get (),
11381139 /* .backend_cpu =*/ backend_cpu,
1139- /* .backends =*/ backends,
11401140 /* .cvec =*/ &cvec,
11411141 /* .loras =*/ &loras,
11421142 /* .memory =*/ nullptr ,
11431143 /* .cross =*/ nullptr ,
11441144 /* .n_outputs =*/ n_outputs,
1145+ /* .cb =*/ graph_get_cb (),
11451146 }, gf, gtype);
11461147}
11471148
@@ -1172,6 +1173,39 @@ enum ggml_status llama_context_base::graph_compute(
11721173 return status;
11731174}
11741175
1176+ llm_graph_cb llama_context_base::graph_get_cb () const {
1177+ return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
1178+ if (il >= 0 ) {
1179+ ggml_format_name (cur, " %s-%d" , name, il);
1180+ } else {
1181+ ggml_set_name (cur, name);
1182+ }
1183+
1184+ if (!cparams.offload_kqv ) {
1185+ if (strcmp (name, " kqv_merged_cont" ) == 0 ) {
1186+ // all nodes between the KV store and the attention output are run on the CPU
1187+ ggml_backend_sched_set_tensor_backend (sched.get (), cur, backend_cpu);
1188+ }
1189+ }
1190+
1191+ // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1192+ // FIXME: fix in ggml_backend_sched
1193+ const bool full_offload = model.params .n_gpu_layers > (int ) model.hparams .n_layer ;
1194+ if (ubatch.n_tokens < 32 || full_offload) {
1195+ if (il != -1 && strcmp (name, " norm" ) == 0 ) {
1196+ const auto & dev_layer = model.dev_layer (il);
1197+ for (const auto & backend : backends) {
1198+ if (ggml_backend_get_device (backend.get ()) == dev_layer) {
1199+ if (ggml_backend_supports_op (backend.get (), cur)) {
1200+ ggml_backend_sched_set_tensor_backend (sched.get (), cur, backend.get ());
1201+ }
1202+ }
1203+ }
1204+ }
1205+ }
1206+ };
1207+ }
1208+
11751209//
11761210// perf
11771211//
@@ -2567,17 +2601,18 @@ llm_graph_result_ptr llama_context_kv_self::graph_build(
25672601 return model.build_graph (
25682602 {
25692603 /* .ctx =*/ ctx,
2570- /* .model =*/ model,
2604+ /* .arch =*/ model.arch ,
2605+ /* .hparams =*/ model.hparams ,
25712606 /* .cparams =*/ cparams,
25722607 /* .ubatch =*/ ubatch,
25732608 /* .sched =*/ sched.get (),
25742609 /* .backend_cpu =*/ backend_cpu,
2575- /* .backends =*/ backends,
25762610 /* .cvec =*/ &cvec,
25772611 /* .loras =*/ &loras,
25782612 /* .memory =*/ kv_self.get (),
25792613 /* .cross =*/ nullptr ,
25802614 /* .n_outputs =*/ n_outputs,
2615+ /* .cb =*/ graph_get_cb (),
25812616 }, gf, gtype);
25822617}
25832618
@@ -3010,17 +3045,18 @@ llm_graph_result_ptr llama_context_recurrent::graph_build(
30103045 return model.build_graph (
30113046 {
30123047 /* .ctx =*/ ctx,
3013- /* .model =*/ model,
3048+ /* .arch =*/ model.arch ,
3049+ /* .hparams =*/ model.hparams ,
30143050 /* .cparams =*/ cparams,
30153051 /* .ubatch =*/ ubatch,
30163052 /* .sched =*/ sched.get (),
30173053 /* .backend_cpu =*/ backend_cpu,
3018- /* .backends =*/ backends,
30193054 /* .cvec =*/ &cvec,
30203055 /* .loras =*/ &loras,
30213056 /* .memory =*/ kv_self.get (),
30223057 /* .cross =*/ nullptr ,
30233058 /* .n_outputs =*/ n_outputs,
3059+ /* .cb =*/ graph_get_cb (),
30243060 }, gf, gtype);
30253061}
30263062
@@ -3227,17 +3263,18 @@ llm_graph_result_ptr llama_context_dec::graph_build(
32273263 return model.build_graph (
32283264 {
32293265 /* .ctx =*/ ctx,
3230- /* .model =*/ model,
3266+ /* .arch =*/ model.arch ,
3267+ /* .hparams =*/ model.hparams ,
32313268 /* .cparams =*/ cparams,
32323269 /* .ubatch =*/ ubatch,
32333270 /* .sched =*/ sched.get (),
32343271 /* .backend_cpu =*/ backend_cpu,
3235- /* .backends =*/ backends,
32363272 /* .cvec =*/ &cvec,
32373273 /* .loras =*/ &loras,
32383274 /* .memory =*/ kv_self.get (),
32393275 /* .cross =*/ cross,
32403276 /* .n_outputs =*/ n_outputs,
3277+ /* .cb =*/ graph_get_cb (),
32413278 }, gf, gtype);
32423279}
32433280
0 commit comments