diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 7064b7486f267..cabd301ad3572 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -3156,26 +3156,17 @@ static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op
     return (op0 && op0->src[1] == op1->src[1]);
 }
 
+static inline bool is_compute_op(ggml_tensor *node)
+{
+    return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
+}
+
 // scan the graph and figure out last compute op index
 static inline int last_compute_op(ggml_cgraph * graph)
 {
-    int last;
+    int last = 0;
 
     for (int i = 0; i < graph->n_nodes; ++i) {
-        ggml_tensor * node = graph->nodes[i];
-
-        switch (node->op) {
-            case GGML_OP_MUL_MAT:
-            case GGML_OP_MUL_MAT_ID:
-            case GGML_OP_MUL:
-            case GGML_OP_ADD:
-            case GGML_OP_SUB:
-            case GGML_OP_RMS_NORM:
-            case GGML_OP_GLU:
-            case GGML_OP_ADD_ID:
-                last = i;
-                break;
-
-            default:
-                break;
+        if (is_compute_op(graph->nodes[i])) {
+            last = i;
         }
     }
@@ -3194,6 +3185,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
     for (int i = 0; i < graph->n_nodes; ++i) {
         ggml_tensor * node = graph->nodes[i];
 
+        if (!is_compute_op(node)) {
+            continue;
+        }
+
         uint32_t flags = 0;
 
         // skip quantizer if src1 is reused
@@ -3245,14 +3240,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_rope(node, flags);
                 break;
 
-            // non-compute ops
-            case GGML_OP_NONE:
-            case GGML_OP_RESHAPE:
-            case GGML_OP_VIEW:
-            case GGML_OP_PERMUTE:
-            case GGML_OP_TRANSPOSE:
-                break;
-
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }
diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c
index 92c0109d28712..8ed7f67d9c8c9 100644
--- a/ggml/src/ggml-hexagon/htp/binary-ops.c
+++ b/ggml/src/ggml-hexagon/htp/binary-ops.c
@@ -34,6 +34,11 @@ static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32,
 static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
 
 #define htp_binary_preamble \
+    const struct htp_tensor * src0 = &octx->src0; \
+    const struct htp_tensor * src1 = &octx->src1; \
+    const struct htp_tensor * src2 = &octx->src2; \
+    struct htp_tensor * dst = &octx->dst; \
+    \
     const uint32_t ne00 = src0->ne[0]; \
     const uint32_t ne01 = src0->ne[1]; \
     const uint32_t ne02 = src0->ne[2]; \
@@ -62,16 +67,15 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
     const uint32_t nb0 = dst->nb[0]; \
     const uint32_t nb1 = dst->nb[1]; \
     const uint32_t nb2 = dst->nb[2]; \
-    const uint32_t nb3 = dst->nb[3];
-
-static void binary_job_f32_per_thread(const struct htp_tensor * src0,
-                                      const struct htp_tensor * src1,
-                                      struct htp_tensor * dst,
-                                      uint8_t * spad_data,
-                                      uint32_t nth,
-                                      uint32_t ith,
-                                      uint32_t src0_nrows_per_thread,
-                                      enum htp_op op) {
+    const uint32_t nb3 = dst->nb[3]; \
+    \
+    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+static void binary_job_f32_per_thread(struct htp_ops_context * octx,
+                                      uint8_t * spad_data,
+                                      uint32_t nth,
+                                      uint32_t ith,
+                                      enum htp_op op) {
     htp_binary_preamble;
 
     const size_t src0_row_size = nb01;
@@ -107,16 +111,23 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
 
     uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
 
-    const uint32_t nr0 = ne00 / ne10;
-
     const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
     uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
     const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    const uint8_t * restrict src1_ptr = NULL;
+
+    const uint32_t ne02_ne01 = ne02 * ne01;
 
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
+        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
+        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
+        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+
+        const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
+        const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
+        const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
+
+        const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
 
         if (ir + 1 < src0_end_row) {
             htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
@@ -125,6 +136,7 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
             }
         }
 
+        const uint32_t nr0 = ne00 / ne10;
         if (nr0 > 1) {
             if ((1 == is_aligned) && (nr0 == ne00)) {
                 hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
@@ -149,22 +161,17 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
-                                             const struct htp_tensor * src1,
-                                             const struct htp_tensor * src2,
-                                             struct htp_tensor * dst,
-                                             uint8_t * spad_data,
-                                             uint32_t nth,
-                                             uint32_t ith,
-                                             uint32_t src0_nrows_per_thread,
-                                             hvx_elemwise_f32_func func_HVX) {
+static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
+                                             uint8_t * spad_data,
+                                             uint32_t nth,
+                                             uint32_t ith,
+                                             hvx_elemwise_f32_func func_HVX) {
     htp_binary_preamble;
 
     const size_t src0_row_size = nb01;
     const size_t src1_row_size = nb11;
     const size_t dst_row_size = nb1;
 
-    const uint32_t ne02_ne01 = ne02 * ne01;
     const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@@ -187,10 +194,11 @@ static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
     const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
     uint8_t * restrict data_dst = (uint8_t *) dst->data;
 
+    const uint32_t ne02_ne01 = ne02 * ne01;
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
         // src0 indices
-        const uint32_t i03 = ir / ne02_ne01;
-        const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
+        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
+        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
         const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
 
         // src1 indices
@@ -234,13 +242,11 @@ static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * dat
         case HTP_OP_MUL:
         case HTP_OP_ADD:
         case HTP_OP_SUB:
-            binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
-                                      octx->src0_nrows_per_thread, octx->op);
+            binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
             break;
 
         case HTP_OP_ADD_ID:
-            binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
-                                             i, octx->src0_nrows_per_thread, hvx_add_f32);
+            binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
             break;
 
         default:
@@ -321,6 +327,16 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
 
     octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
 
+    octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
+    octx->src0_div3  = init_fastdiv_values(src0->ne[3]);
+    octx->src0_div2  = init_fastdiv_values(src0->ne[2]);
+    octx->src0_div1  = init_fastdiv_values(src0->ne[1]);
+
+    octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
+    octx->src1_div3  = init_fastdiv_values(src1->ne[3]);
+    octx->src1_div2  = init_fastdiv_values(src1->ne[2]);
+    octx->src1_div1  = init_fastdiv_values(src1->ne[1]);
+
     worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
 }
diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h
index f23d578806867..9278f41f4e119 100644
--- a/ggml/src/ggml-hexagon/htp/htp-msg.h
+++ b/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -119,10 +119,10 @@ static const char * htp_type_name(uint32_t t) {
 #define HTP_MAX_DIMS 4
 
 struct htp_tensor {
-    uint32_t data;              // Buffer offset in the messages, and data pointer on the NSP
-    uint32_t type;              // Data type
-    uint32_t ne[HTP_MAX_DIMS];  // Number of elements
-    uint32_t nb[HTP_MAX_DIMS];  // Stride in bytes (see ggml.h ggml_tensor)
+    uint32_t data;             // Buffer offset in the messages, and data pointer on the NSP
+    uint32_t type;             // Data type
+    uint32_t ne[HTP_MAX_DIMS]; // Number of elements
+    uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
 };
 
 #define HTP_MAX_OP_PARAMS 64
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index 45723196791af..e87657436f08b 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -4,6 +4,7 @@
 #include "htp-ctx.h"
 #include "htp-msg.h"
 #include "worker-pool.h"
+#include "ops-utils.h"
 
 #include <stdbool.h>
 #include <stdint.h>
@@ -38,6 +39,16 @@ struct htp_ops_context {
     uint32_t src0_nrows_per_thread;
     uint32_t src1_nrows_per_thread;
 
+    struct fastdiv_values src0_div1;  // fastdiv values for ne1
+    struct fastdiv_values src0_div2;  // fastdiv values for ne2
+    struct fastdiv_values src0_div3;  // fastdiv values for ne3
+    struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
+
+    struct fastdiv_values src1_div1;  // fastdiv values for ne1
+    struct fastdiv_values src1_div2;  // fastdiv values for ne2
+    struct fastdiv_values src1_div3;  // fastdiv values for ne3
+    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
+
     uint32_t flags;
 };
diff --git a/ggml/src/ggml-hexagon/htp/ops-utils.h b/ggml/src/ggml-hexagon/htp/ops-utils.h
index 302f1625216d8..af9c3305f61ff 100644
--- a/ggml/src/ggml-hexagon/htp/ops-utils.h
+++ b/ggml/src/ggml-hexagon/htp/ops-utils.h
@@ -31,6 +31,39 @@ static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
     return m * ((n + m - 1) / m);
 }
 
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+//   n/d = (mulhi(n, mp) + n) >> L;
+struct fastdiv_values {
+    uint32_t mp;
+    uint32_t l;
+};
+
+static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
+    struct fastdiv_values result = { 0, 0 };
+    // compute L = ceil(log2(d));
+    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
+        ++(result.l);
+    }
+
+    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
+    return result;
+}
+
+static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
+    // Compute high 32 bits of n * mp
+    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
+    // add n, apply bit shift
+    return (hi + n) >> vals->l;
+}
+
+static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
+    return n - fastdiv(n, vals) * d;
+}
+
 static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
     const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
     asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
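The ops-utils.h hunk above adds the Granlund-Montgomery round-up construction (mp, L) that fastdiv()/fastmodulo() rely on. Below is a minimal host-side sanity check, not part of the patch: the three helpers are copied verbatim from the diff, while main(), the sample divisors, and the test range are illustrative assumptions chosen to mimic the ne1 / ne2 / ne2*ne1 divisors used by the binary kernels.

// Illustrative sanity check only -- not part of the patch above.
// Build with an ordinary host compiler, e.g.: gcc -O2 -o fastdiv_check fastdiv_check.c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

struct fastdiv_values {
    uint32_t mp;
    uint32_t l;
};

static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
    struct fastdiv_values result = { 0, 0 };
    // compute L = ceil(log2(d));
    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
        ++(result.l);
    }

    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
    return result;
}

static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
    return (hi + n) >> vals->l;
}

static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
    return n - fastdiv(n, vals) * d;
}

int main(void) {
    // Sample divisors of the kind the kernels precompute: ne1, ne2 and ne2 * ne1.
    const uint32_t divisors[] = { 1, 3, 7, 32, 4096, 4096 * 32 };

    for (unsigned i = 0; i < sizeof(divisors) / sizeof(divisors[0]); ++i) {
        const uint32_t d = divisors[i];
        const struct fastdiv_values v = init_fastdiv_values(d);

        // Scan a modest range of row indices and compare against '/' and '%'.
        for (uint32_t n = 0; n < 1000000; ++n) {
            assert(fastdiv(n, &v) == n / d);
            assert(fastmodulo(n, d, &v) == n % d);
        }
    }

    printf("fastdiv/fastmodulo match '/' and '%%' for all tested values\n");
    return 0;
}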