diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c886..5c58110384aa6 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -6,8 +6,17 @@ #include #include +#include #include #include +#include + +// verbosity flag set via the params.verbosity CLI flag. This is used for two +// things: +// 1. If > 0, tensors are printed with 8 digits of precision instead of 5 +// 2. If > 1, all tensor values are printed instead of the pretty-printed +// partial output +static int verbosity = 0; /** * This the arbitrary data which will be passed to each callback. @@ -61,6 +70,10 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * } static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + std::stringstream ss; + const int float_digits = verbosity > 0 ? 8 : 4; + ss << "%12." << float_digits << "f"; + const auto float_fmt = ss.str(); GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { @@ -93,7 +106,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne i0 = ne[0] - n; } const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); - LOG("%12.4f", v); + LOG(float_fmt.c_str(), v); if (i0 < ne[0] - 1) LOG(", "); } LOG("],\n"); @@ -153,8 +166,9 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } if (!ggml_is_quantized(t->type)) { + const int print_width = verbosity > 1 ? std::numeric_limits::max() : 3; uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); - ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + ggml_print_tensor(data, t->type, t->ne, t->nb, print_width); } return true; @@ -192,6 +206,9 @@ int main(int argc, char ** argv) { common_init(); + // set verbosity for printing + verbosity = params.verbosity; + llama_backend_init(); llama_numa_init(params.numa); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c1ed1a21c81c4..4acec8216f28b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -475,6 +475,7 @@ extern "C" { GGML_OP_COS, GGML_OP_SUM, GGML_OP_SUM_ROWS, + GGML_OP_CUMSUM, GGML_OP_MEAN, GGML_OP_ARGMAX, GGML_OP_COUNT_EQUAL, @@ -530,6 +531,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_TRI, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -582,6 +584,7 @@ extern "C" { GGML_UNARY_OP_CEIL, GGML_UNARY_OP_ROUND, GGML_UNARY_OP_TRUNC, + GGML_UNARY_OP_SOFTPLUS, GGML_UNARY_OP_COUNT, }; @@ -620,6 +623,13 @@ extern "C" { GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; + enum ggml_tri_type { + GGML_TRI_TYPE_UPPER_DIAG = 0, // upper including diag + GGML_TRI_TYPE_UPPER = 1, // upper excluding diag + GGML_TRI_TYPE_LOWER_DIAG = 2, // lower including diag + GGML_TRI_TYPE_LOWER = 3 // lower excluding diag + }; + struct ggml_init_params { // memory pool size_t mem_size; // bytes @@ -983,6 +993,17 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // Cumulative sum along the specified dimension + GGML_API struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a, + int dim); + + // Convenience function: cumulative sum along dimension 0 + GGML_API struct ggml_tensor * ggml_cumsum_0( + struct ggml_context * ctx, + struct ggml_tensor * a); + // mean along rows GGML_API struct ggml_tensor * ggml_mean( struct ggml_context * ctx, @@ -1194,6 
+1215,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // softplus(x) = log(1 + exp(beta * x)) / beta, for x * beta <= threshold + // = x, otherwise + GGML_API struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a); // xIELU activation function @@ -2187,6 +2213,27 @@ extern "C" { int shift2, int shift3); + // Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value + // dim_x and dim_y specify which two dimensions to compare for triangular masking. They must have equal size. + // Default is dim_x=0, dim_y=1 (compares indices in dim 0 vs indices in dim 1) + GGML_API struct ggml_tensor * ggml_tri_dims( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype, + int dim_x, + int dim_y); + + GGML_API struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype); + + GGML_API struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype); // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b5466dd703d1d..8255acb26e283 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_sum_rows(params, tensor); } break; + case GGML_OP_CUMSUM: + { + ggml_compute_forward_cumsum(params, tensor); + } break; case GGML_OP_MEAN: { ggml_compute_forward_mean(params, tensor); @@ -1943,6 +1947,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; + case GGML_OP_TRI: + { + ggml_compute_forward_tri(params, tensor); + } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor); @@ -2153,6 +2161,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_ARGMAX: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: { n_tasks = 1; } break; @@ -2192,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_SOFTPLUS: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8235f69594391..76bc67b0b3a4c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9,6 +9,7 @@ #include #include +#include // ggml_compute_forward_dup @@ -1394,6 +1395,186 @@ void ggml_compute_forward_sum( } } +// ggml_compute_forward_cumsum + +// General implementation for arbitrary dimensions +template +static void ggml_compute_forward_cumsum_general( + const ggml_compute_params * params, + ggml_tensor * dst, + int dim) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + + GGML_TENSOR_UNARY_OP_LOCALS + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + for (int64_t i0 = 0; i0 < ne00; i0++) { + const T * src_ptr = (const T *)((const char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + T * dst_ptr = (T *)((char *) dst->data + i3*nb3 + 
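// --- illustrative usage of the new public API (not part of the patch) -------
// A minimal sketch assuming the usual ggml CPU helpers (ggml_init,
// ggml_new_graph, ggml_graph_compute_with_ctx). ggml_cumsum(ctx, a, 0) is an
// inclusive prefix sum along dim 0, ggml_tri_keep() keeps the selected
// triangle of the source values and zeroes the rest, and ggml_softplus()
// applies log(1 + exp(x)) elementwise (with the large-x shortcut noted above).
//
//     struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
//     struct ggml_context * ctx = ggml_init(ip);
//
//     struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
//     // ... fill a->data with 16 floats ...
//
//     struct ggml_tensor * cs = ggml_cumsum(ctx, a, 0);                          // running sum within each row
//     struct ggml_tensor * lo = ggml_tri_keep(ctx, a, GGML_TRI_TYPE_LOWER_DIAG); // keep lower triangle incl. diagonal
//     struct ggml_tensor * sp = ggml_softplus(ctx, a);
//
//     struct ggml_cgraph * gf = ggml_new_graph(ctx);
//     ggml_build_forward_expand(gf, cs);
//     ggml_build_forward_expand(gf, lo);
//     ggml_build_forward_expand(gf, sp);
//     ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);
//     ggml_free(ctx);
// -----------------------------------------------------------------------------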
i2*nb2 + i1*nb1 + i0*nb0); + + // Determine position in the cumsum dimension + int64_t i_vals[4] = {i0, i1, i2, i3}; + int64_t i_dim = i_vals[dim]; + + if (i_dim == 0) { + // First element: just copy + dst_ptr[0] = src_ptr[0]; + } else { + // Accumulate: add current value to previous cumsum value + const T * prev_dst_ptr = (const T *)((const char *) dst_ptr - dst->nb[dim]); + dst_ptr[0] = type_conversion_table::from_f32( + type_conversion_table::to_f32(prev_dst_ptr[0]) + + type_conversion_table::to_f32(src_ptr[0])); + } + } + } + } + } +} + +static void ggml_compute_forward_cumsum_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + const float * src_row = (const float *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f32(ne00, dst_row, src_row); + } + } + } +} + +static void ggml_compute_forward_cumsum_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(dst->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + const ggml_fp16_t * src_row = (const ggml_fp16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f16(ne00, dst_row, src_row); + } + } + } +} + +static void ggml_compute_forward_cumsum_bf16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(ggml_bf16_t)); + GGML_ASSERT(dst->nb[0] == sizeof(ggml_bf16_t)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + const ggml_bf16_t * src_row = (const ggml_bf16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + ggml_bf16_t * dst_row = (ggml_bf16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_bf16(ne00, dst_row, src_row); + } + } + } +} + +void ggml_compute_forward_cumsum( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + const int dim = ggml_get_op_params_i32(dst, 0); + const bool use_general = dim != 0 || !ggml_is_contiguous_rows(src0); + + switch (src0->type) { + case GGML_TYPE_F32: + { + if (use_general) { + ggml_compute_forward_cumsum_general(params, dst, dim); + } else { + ggml_compute_forward_cumsum_f32(params, dst); + } + } break; + case GGML_TYPE_F16: + { + if (use_general) { + 
ggml_compute_forward_cumsum_general(params, dst, dim); + } else { + ggml_compute_forward_cumsum_f16(params, dst); + } + } break; + case GGML_TYPE_BF16: + { + if (use_general) { + ggml_compute_forward_cumsum_general(params, dst, dim); + } else { + ggml_compute_forward_cumsum_bf16(params, dst); + } + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum_rows static void ggml_compute_forward_sum_rows_f32( @@ -2140,6 +2321,153 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_tri + +// General implementation for arbitrary dimensions +template +static void ggml_compute_forward_tri_general( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + T c_val, + bool keep_org_val, + int dim_x, + int dim_y) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(dim_x >= 0 && dim_x < GGML_MAX_DIMS); + GGML_ASSERT(dim_y >= 0 && dim_y < GGML_MAX_DIMS); + GGML_ASSERT(dim_x != dim_y); + GGML_ASSERT(src0->ne[dim_x] == src0->ne[dim_y]); + + GGML_TENSOR_UNARY_OP_LOCALS + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + for (int64_t i00 = 0; i00 < ne0; ++i00) { + const T * src = (const T *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + T * dst_ptr = ( T *)(( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 + i00*nb00); + int64_t i_vals[4] = {i00, i01, i02, i03}; + int64_t iX = i_vals[dim_x]; + int64_t iY = i_vals[dim_y]; + dst_ptr[0] = _ggml_vec_tri_cmp(iX, iY, ttype) ? + (keep_org_val ? src[0] : c_val) : + type_conversion_table::from_f32(0.f); + } + } +} + +static void ggml_compute_forward_tri_f32( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + float c_val, + bool keep_org_val) { + const ggml_tensor * src0 = dst->src[0]; + GGML_TENSOR_UNARY_OP_LOCALS + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + float * dst_ptr = (float *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const float * src = (const float *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c_val, ttype); + } +} + +static void ggml_compute_forward_tri_f16( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + ggml_fp16_t c_val, + bool keep_org_val) { + const ggml_tensor * src0 = dst->src[0]; + GGML_TENSOR_UNARY_OP_LOCALS + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const ggml_fp16_t * src = (const ggml_fp16_t *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f16(ne0, i01, dst_ptr, src, keep_org_val, c_val, ttype); + } +} + +static void ggml_compute_forward_tri_bf16( + const ggml_compute_params * params, + ggml_tensor * dst, + ggml_tri_type ttype, + ggml_bf16_t c_val, + bool keep_org_val) { + const ggml_tensor * src0 = dst->src[0]; + GGML_TENSOR_UNARY_OP_LOCALS + const auto [ir0, 
ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const ggml_bf16_t * src = (const ggml_bf16_t *)((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_bf16(ne0, i01, dst_ptr, src, keep_org_val, c_val, ttype); + } +} + +void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const int dim_x = ggml_get_op_params_i32(dst, 2); + const int dim_y = ggml_get_op_params_i32(dst, 3); + const bool use_general = dim_x != 0 || dim_y != 1 || !ggml_is_contiguous(src0); + const bool keep_org_val = isnan(c); + + switch (src0->type) { + case GGML_TYPE_F32: + { + if (use_general) { + ggml_compute_forward_tri_general(params, dst, ttype, c, keep_org_val, dim_x, dim_y); + } else { + ggml_compute_forward_tri_f32(params, dst, ttype, c, keep_org_val); + } + } break; + case GGML_TYPE_F16: + { + ggml_fp16_t c_val = GGML_FP32_TO_FP16(c); + if (use_general) { + ggml_compute_forward_tri_general(params, dst, ttype, c_val, keep_org_val, dim_x, dim_y); + } else { + ggml_compute_forward_tri_f16(params, dst, ttype, c_val, keep_org_val); + } + } break; + case GGML_TYPE_BF16: + { + ggml_bf16_t c_val = GGML_FP32_TO_BF16(c); + if (use_general) { + ggml_compute_forward_tri_general(params, dst, ttype, c_val, keep_org_val, dim_x, dim_y); + } else { + ggml_compute_forward_tri_bf16(params, dst, ttype, c_val, keep_org_val); + } + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_erf static void ggml_compute_forward_gelu_erf_f32( @@ -4230,6 +4558,60 @@ static void ggml_compute_forward_scale_f32( } } +static void ggml_compute_forward_scale_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + float s; // scale factor + float b; // bias + + memcpy(&s, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&b, (float *) dst->op_params + 1, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + if (b == 0.0f) { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(ggml_fp16_t)); + } + ggml_vec_scale_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*nb1), s); + } + } else { + //TODO: support bias! 
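// The commented-out code below sketches a vectorized bias path; a plain scalar
// fallback (hypothetical, shown for clarity only) would compute y = x*s + b
// per element, matching the Metal kernel_scale_f16 added further down:
//
//     for (int i1 = ir0; i1 < ir1; i1++) {
//         const ggml_fp16_t * x = (const ggml_fp16_t *) ((const char *) src0->data + i1*nb01);
//         ggml_fp16_t       * y = (      ggml_fp16_t *) ((      char *) dst->data  + i1*nb1);
//         for (int i0 = 0; i0 < nc; i0++) {
//             y[i0] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i0])*s + b);
//         }
//     }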
+ GGML_ABORT("fatal error"); + // for (int i1 = ir0; i1 < ir1; i1++) { + // ggml_vec_mad1_f16(nc, + // (ggml_fp16_t *) ((char *) dst->data + i1*nb1), + // (ggml_fp16_t *) ((char *) src0->data + i1*nb1), + // s, b); + // } + } +} + void ggml_compute_forward_scale( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4241,6 +4623,10 @@ void ggml_compute_forward_scale( { ggml_compute_forward_scale_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_scale_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -8639,7 +9025,8 @@ void ggml_compute_forward_flash_attn_back( // ggml_compute_forward_ssm_conv -static void ggml_compute_forward_ssm_conv_f32( +template +static void ggml_compute_forward_ssm_conv_impl( const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // conv_x @@ -8655,9 +9042,10 @@ static void ggml_compute_forward_ssm_conv_f32( const int n_s = dst->ne[2]; // number of sequences in the batch GGML_ASSERT( dst->ne[0] == nr); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(src_t)); + GGML_ASSERT(src1->nb[0] == sizeof(conv_t)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(src_t)); + GGML_ASSERT(dst->type == src0->type); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -8671,9 +9059,9 @@ static void ggml_compute_forward_ssm_conv_f32( for (int i2 = 0; i2 < n_t; ++i2) { // {d_conv - 1 + n_t, d_inner, n_seqs} // sliding window - const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} - float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + const src_t * s = (const src_t *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const conv_t * c = (const conv_t *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + src_t * x = ( src_t *) (( char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} // TODO: transpose the output for smaller strides for big batches? 
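// The mixed-precision dot product below goes through a conversion-traits
// helper so that accumulation always happens in float regardless of the
// storage type. Roughly (illustrative sketch; the actual type_conversion_table
// in the CPU backend may differ in detail):
//
//     template <typename T> struct type_conversion_table;
//     template <> struct type_conversion_table<float> {
//         static float to_f32(float x)   { return x; }
//         static float from_f32(float x) { return x; }
//     };
//     template <> struct type_conversion_table<ggml_fp16_t> {
//         static float       to_f32(ggml_fp16_t x) { return GGML_CPU_FP16_TO_FP32(x); }
//         static ggml_fp16_t from_f32(float x)     { return GGML_CPU_FP32_TO_FP16(x); }
//     };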
// d_inner @@ -8684,13 +9072,80 @@ static void ggml_compute_forward_ssm_conv_f32( // d_conv for (int i0 = 0; i0 < nc; ++i0) { - sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + sumf += type_conversion_table::to_f32(s[i0 + i1*ncs]) * type_conversion_table::to_f32(c[i0 + i1*nc]); } - x[i1] = sumf; - } - } - } -} + x[i1] = type_conversion_table::from_f32(sumf); + } + } + } +} + +// static void ggml_compute_forward_ssm_conv_q_f32( +// const ggml_compute_params * params, +// ggml_tensor * dst) { +// const ggml_tensor * src0 = dst->src[0]; // conv_x +// const ggml_tensor * src1 = dst->src[1]; // conv1d.weight + +// const int ith = params->ith; +// const int nth = params->nth; + +// const int nc = src1->ne[0]; // d_conv +// const int ncs = src0->ne[0]; // d_conv - 1 + n_t +// const int nr = src0->ne[1]; // d_inner +// const int n_t = dst->ne[1]; // tokens per sequence +// const int n_s = dst->ne[2]; // number of sequences in the batch + +// const ggml_type type0 = src0->type; +// const size_t type0_size = ggml_type_size(type0); +// ggml_to_float_t const dequantize_row0_q = ggml_get_type_traits(type0)->to_float; +// ggml_from_float_t const quantize_row0_q = ggml_get_type_traits_cpu(type0)->from_float; + +// const ggml_type type1 = src1->type; +// const size_t type1_size = ggml_type_size(type1); +// ggml_to_float_t const dequantize_row1_q = ggml_get_type_traits(type1)->to_float; +// ggml_from_float_t const quantize_row1_q = ggml_get_type_traits_cpu(type1)->from_float; + +// GGML_ASSERT( dst->ne[0] == nr); +// GGML_ASSERT(src0->nb[0] == type0_size); +// GGML_ASSERT(src1->nb[0] == type1_size); +// GGML_ASSERT(src0->nb[1] == src0->ne[0]*type0_size); +// GGML_ASSERT(dst->type == src0->type); + +// // rows per thread +// const int dr = (nr + nth - 1)/nth; + +// // row range for this thread +// const int ir0 = dr*ith; +// const int ir1 = MIN(ir0 + dr, nr); +// const int ir = ir1 - ir0; + +// // temporary storage for dequantized lines +// float * wdata = (float *) params->wdata + (src0->ne[0] + CACHE_LINE_SIZE_F32) * ith; + +// for (int i3 = 0; i3 < n_s; ++i3) { +// for (int i2 = 0; i2 < n_t; ++i2) { +// // {d_conv - 1 + n_t, d_inner, n_seqs} +// // sliding window +// const void * s = (const void *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} +// const void * c = (const void *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} +// void * x = ( void *) (( char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + +// // TODO: transpose the output for smaller strides for big batches? 
+// // d_inner +// for (int i1 = 0; i1 < ir; ++i1) { +// // rowwise dot product +// // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision +// float sumf = 0.0f; + +// // d_conv +// for (int i0 = 0; i0 < nc; ++i0) { +// sumf += type_conversion_table::to_f32(s[i0 + i1*ncs]) * type_conversion_table::to_f32(c[i0 + i1*nc]); +// } +// x[i1] = type_conversion_table::from_f32(sumf); +// } +// } +// } +// } void ggml_compute_forward_ssm_conv( const ggml_compute_params * params, @@ -8698,8 +9153,68 @@ void ggml_compute_forward_ssm_conv( switch (dst->src[0]->type) { case GGML_TYPE_F32: { - ggml_compute_forward_ssm_conv_f32(params, dst); + switch (dst->src[1]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + } break; + case GGML_TYPE_F16: + { + switch (dst->src[1]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + } break; + case GGML_TYPE_BF16: + { + switch (dst->src[1]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_ssm_conv_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } } break; + // TODO: Support quantized types default: { GGML_ABORT("fatal error"); @@ -8770,7 +9285,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = ggml_softplus(dt[h]); + const float dt_soft_plus = ggml_op_softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -8867,7 +9382,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = ggml_softplus(dt[h]); + const float dt_soft_plus = ggml_op_softplus(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim @@ -9150,6 +9665,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_xielu(params, dst); } break; + case GGML_UNARY_OP_SOFTPLUS: + { + ggml_compute_forward_softplus(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 9824a03b45833..d6f8dedcd9c55 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cumsum(const 
struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -85,6 +86,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_back( const struct ggml_compute_params * params, diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index a047537b34f78..6723592c8215c 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -319,3 +319,7 @@ void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor unary_op_functor(params, dst, xielu_op_params); } +void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index fa45d9f0e636f..a4b1022db68cc 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -27,6 +27,7 @@ void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 65c7dfb6b9a49..94031a1b01008 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1404,6 +1404,8 @@ inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const } } +// sum + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; @@ -1440,6 +1442,80 @@ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16 *s = sum; } +// tri + +// Applies a triangular mask to the input vector 'src' and writes the result to 'dst'. +// Parameters: +// n - number of elements +// r - current row index +// dst - output array +// src - input array +// keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c' +// c - constant value to use when not keeping original value +// type - type of triangular mask (lower, upper, etc.) 
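// Worked example (illustrative): with n = 4 and row r = 1,
// GGML_TRI_TYPE_LOWER_DIAG satisfies i <= r for columns 0 and 1 only, so the
// row becomes [src[0], src[1], 0, 0] when keep_org_val is true, or
// [c, c, 0, 0] when the constant is used.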
+inline static bool _ggml_vec_tri_cmp(const int i, const int r, const enum ggml_tri_type type) { + switch (type) { + case GGML_TRI_TYPE_LOWER: return i < r; break; + case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; break; + case GGML_TRI_TYPE_UPPER: return i > r; break; + case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; break; + default: GGML_ABORT("Invalid tri type"); + } +} + +inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) { + for (int i = 0; i < n; ++i) { + dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : 0.0f; + } +} + +inline static void ggml_vec_tri_f16(const int n, const int r, ggml_fp16_t * dst, const ggml_fp16_t * src, bool keep_org_val, ggml_fp16_t c, enum ggml_tri_type type) { + for (int i = 0; i < n; ++i) { + dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : 0; + } +} + +inline static void ggml_vec_tri_bf16(const int n, const int r, ggml_bf16_t * dst, const ggml_bf16_t * src, bool keep_org_val, ggml_bf16_t c, enum ggml_tri_type type) { + const ggml_bf16_t zero = ggml_fp32_to_bf16(0); + for (int i = 0; i < n; ++i) { + dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : zero; + } +} + +// cumsum + +inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = y[i - 1] + x[i]; + } + } +} + +inline static void ggml_vec_cumsum_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i - 1]) + GGML_CPU_FP16_TO_FP32(x[i])); + } + } +} + +inline static void ggml_vec_cumsum_bf16(const int n, ggml_bf16_t * y, const ggml_bf16_t * x) { + for (int i = 0; i < n; ++i) { + if (i == 0) { + y[i] = x[i]; + } else { + y[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(y[i - 1]) + GGML_BF16_TO_FP32(x[i])); + } + } +} + +// max + inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE float max = -INFINITY; @@ -1452,6 +1528,8 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #endif } +// norm inv + inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1.f/(*s); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ca876459d404d..07b426336218a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -414,6 +414,56 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +template +static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const T t = __shfl_up_sync(mask, x, offset, width); + if (lane_id >= offset) { + x += t; + } + } + return x; +} + +template +static __device__ __forceinline__ float warp_prefix_inclusive_sum(float2 a) { + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const float t_x = __shfl_up_sync(mask, a.x, offset, width); + const float t_y = __shfl_up_sync(mask, a.y, offset, width); + if (lane_id >= offset) { + a.x += t_x; + a.y += t_y; + } + } + return a; +} + +template +static __device__ __forceinline__ 
half2 warp_prefix_inclusive_sum(half2 a) { +#ifdef FP16_AVAILABLE + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const t = __hadd2(__shfl_up_sync(mask, a, offset, width)); + if (lane_id >= offset) { + a += t; + } + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + template static __device__ __forceinline__ int warp_reduce_all(int x) { if (width == ggml_cuda_get_physical_warp_size()) { diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 0000000000000..f6c5e1c3d1aac --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,173 @@ +#include "cumsum.cuh" + +// Kernel to compute cumulative sum along an arbitrary dimension +// Each block processes one position in the non-cumsum dimensions +// Algorithm matches Metal implementation: +// 1. Each warp computes prefix sum within itself +// 2. Last thread of each warp stores result in shared memory +// 3. All warps sync +// 4. Each element adds the sum of all preceding warps + +template +static __global__ void cumsum_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const int dim) { + + // Shared memory to store warp sums (always use float for accumulation) + extern __shared__ float shmem[]; + + // Map block indices to actual tensor dimensions + // blockIdx.x, blockIdx.y, blockIdx.z represent the 3 non-cumsum dimensions + // threadIdx.x represents position in the cumsum dimension + int64_t grid_indices[3] = {blockIdx.x, blockIdx.y, blockIdx.z}; + int64_t i_vals[4]; + + int grid_idx = 0; + for (int d = 0; d < 4; ++d) { + if (d == dim) { + i_vals[d] = 0; // Will be set in the loop below + } else { + i_vals[d] = grid_indices[grid_idx++]; + } + } + + const int64_t i0 = i_vals[0]; + const int64_t i1 = i_vals[1]; + const int64_t i2 = i_vals[2]; + const int64_t i3 = i_vals[3]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01 || i0 >= ne00) { + return; + } + + const int64_t ne_dim = (dim == 0) ? ne00 : (dim == 1) ? ne01 : (dim == 2) ? ne02 : ne03; + const int64_t nb_dim_src = (dim == 0) ? nb00 : (dim == 1) ? nb01 : (dim == 2) ? nb02 : nb03; + const int64_t nb_dim_dst = (dim == 0) ? nb0 : (dim == 1) ? nb1 : (dim == 2) ? 
nb2 : nb3; + + const int tid = threadIdx.x; + const int lane_id = tid % WARP_SIZE; + + // Phase 1: Each thread processes elements at stride blockDim.x + // Compute warp-level prefix sums + for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) { + const int64_t offset_src = i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03 + i_dim*nb_dim_src; + const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst; + + const T * src_ptr = (const T *) ((const char *) src + offset_src); + T * dst_ptr = (T *) (( char *) dst + offset_dst); + + // Load value and compute prefix sum within warp + float val = static_cast(src_ptr[0]); + val = warp_prefix_inclusive_sum(val); + dst_ptr[0] = static_cast(val); + + // Last thread of warp stores its sum to shared memory at position based on data index + if (lane_id == WARP_SIZE - 1 || i_dim == ne_dim - 1) { + const int shmem_idx = i_dim / WARP_SIZE; + shmem[shmem_idx] = val; + } + } + + // Sync once after all warp prefix sums are computed + __syncthreads(); + + // Phase 2: Add the sum of all preceding warp groups to each element + for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) { + const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst; + T * dst_ptr = (T *) ((char *) dst + offset_dst); + + const int shmem_idx = i_dim / WARP_SIZE; + float sum = 0.0f; + for (int j = 0; j < shmem_idx; ++j) { + sum += shmem[j]; + } + dst_ptr[0] = static_cast(static_cast(dst_ptr[0]) + sum); + } +} + +template +static void cumsum_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const int dim, + cudaStream_t stream) { + + // Dimension being accumulated + const int64_t ne_dims[4] = {ne00, ne01, ne02, ne03}; + const int64_t ne_dim = ne_dims[dim]; + + // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions + int64_t grid_dims_arr[GGML_MAX_DIMS - 1]; + int grid_idx = 0; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d != dim) { + grid_dims_arr[grid_idx++] = ne_dims[d]; + } + } + + dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); + dim3 grid_dims(grid_dims_arr[0], grid_dims_arr[1], grid_dims_arr[2]); + + // Shared memory size: one float per warp + const int num_warps = (ne_dim + WARP_SIZE - 1) / WARP_SIZE; + const size_t shmem_size = num_warps * sizeof(float); + + cumsum_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3, + dim + ); +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const int dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(src0->type == dst->type); + switch(src0->type) { + case GGML_TYPE_F32: + { + cumsum_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + dim, + stream + ); + } break; + case GGML_TYPE_F16: + { + cumsum_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + dim, + stream + ); + } break; + case GGML_TYPE_BF16: + { + cumsum_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + 
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + dim, + stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 0000000000000..782d1d92e9bb1 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 415a7e962d779..3bbbb6da5927a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" @@ -47,6 +48,7 @@ #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" +#include "ggml-cuda/tri.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" @@ -2513,6 +2515,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_softplus(ctx, dst); + break; case GGML_UNARY_OP_XIELU: ggml_cuda_op_xielu(ctx, dst); break; @@ -2694,6 +2699,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; case GGML_OP_RWKV_WKV6: ggml_cuda_op_rwkv_wkv6(ctx, dst); break; @@ -3802,6 +3813,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SOFTPLUS: return ggml_is_contiguous(op->src[0]); default: return false; @@ -4105,6 +4117,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 0000000000000..1940e368ac451 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,118 @@ +#include "tri.cuh" +#include "ggml.h" +#include + +// Triangle type comparison - determines which elements to keep +__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) { + switch (type) { + case GGML_TRI_TYPE_LOWER: return i < r; + case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; + case GGML_TRI_TYPE_UPPER: return i > r; + case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; + default: return false; + } +} + +template +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const float c, const ggml_tri_type ttype, 
const int dim_x, const int dim_y) { + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const bool keep_org_val = isnan(c); + const T c_val = static_cast(c); + const T zero_val = static_cast(0.f); + + // Each thread processes elements at stride blockDim.x + for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { + // Create index array matching CPU implementation + int64_t i_vals[4] = {i0, i1, i2, i3}; + int64_t iX = i_vals[dim_x]; + int64_t iY = i_vals[dim_y]; + + const T * src_ptr = (const T *) ((const char *) src + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_ptr = (T *) (( char *) dst + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + dst_ptr[0] = tri_compare(iX, iY, ttype) + ? (keep_org_val ? src_ptr[0] : c_val) + : zero_val; + } +} + +template +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const float c, const ggml_tri_type ttype, const int dim_x, const int dim_y, + cudaStream_t stream) { + + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3, + c, ttype, dim_x, dim_y + ); +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const ggml_tri_type ttype = static_cast(ggml_get_op_params_i32(dst, 0)); + const float c = ggml_get_op_params_f32(dst, 1); + const int dim_x = ggml_get_op_params_i32(dst, 2); + const int dim_y = ggml_get_op_params_i32(dst, 3); + + GGML_ASSERT(src0->type == dst->type); + + switch(src0->type) { + case GGML_TYPE_F32: + { + tri_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, dim_x, dim_y, stream + ); + } break; + case GGML_TYPE_F16: + { + tri_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, dim_x, dim_y, stream + ); + } break; + case GGML_TYPE_BF16: + { + tri_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, dim_x, dim_y, stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 0000000000000..a4cc66750d3b5 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index c1dc6ddbf8f81..5bc5d3c9e1361 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -101,6 +101,10 @@ static __device__ __forceinline__ float op_trunc(float x) { return trunc(x); } +static __device__ __forceinline__ float 
op_softplus(float x) { + return (x > 20.0f) ? x : logf(1.0f + expf(x)); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -233,6 +237,11 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } + +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 2800c75ba3f7a..f73ab69df47a3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -71,6 +71,8 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index ec37a25337b64..882fc19d8ecba 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } -static inline float ggml_softplus(float input) { +static inline float ggml_op_softplus(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } // diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 5607deaf414a2..20106d89fc69f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -211,13 +211,14 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; case GGML_UNARY_OP_EXP: op_str = "exp"; break; + case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); }; const char * suffix = ""; - if (n % 4 == 0) { + if (n % 4 == 0 && op->type == GGML_TYPE_F32) { suffix = "_4"; } @@ -318,6 +319,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const char * op_str = "cumsum"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + // reuse existing precompiled pipeline, but allow memory size setting + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (!res) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + // one shared memory element for each simd group in the threadgroup + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + const int nsg = (ne00 + 31)/32; + ggml_metal_pipeline_set_smem(res, nsg*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "tri"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + 
snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + return ggml_metal_library_compile_pipeline(lib, base, name, nullptr); +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); @@ -349,7 +394,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(ggml_is_contiguous(op->src[1])); @@ -359,7 +403,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar const char * suffix = ""; - if (op->src[1]->ne[0] % 4 == 0) { + if (op->src[1]->ne[0] % 4 == 0 && op->src[1]->type == GGML_TYPE_F32) { suffix = "_4"; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4d58297481813..5b29904f11e3b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -111,6 +111,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 0cadd19a30fe9..0ba568c07f57e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -619,6 +619,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -666,6 +667,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_TRI: + return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_CUMSUM: + return has_simdgroup_reduction; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7a878a657bc12..6bc0725d9db34 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ 
b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -582,6 +582,49 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_sum_rows; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_cumsum; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float c; + uint32_t ttype; + int32_t dim_x; + int32_t dim_y; +} ggml_metal_kargs_tri; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7a85edbdcdb84..b1582c59358c6 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -310,6 +310,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_sum_rows(ctx, idx); } break; + case GGML_OP_CUMSUM: + { + n_fuse = ggml_metal_op_cumsum(ctx, idx); + } break; + case GGML_OP_TRI: + { + n_fuse = ggml_metal_op_tri(ctx, idx); + } break; case GGML_OP_SOFT_MAX: { n_fuse = ggml_metal_op_soft_max(ctx, idx); @@ -956,6 +964,136 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t dim = (int32_t) op->op_params[0]; + + ggml_metal_kargs_cumsum args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cumsum(lib, op); + + // Dimension being accumulated + const int64_t ne_dim = op->src[0]->ne[dim]; + + // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions + int64_t grid_dims[GGML_MAX_DIMS - 1]; + int grid_idx = 0; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d != dim) { + grid_dims[grid_idx++] = op->src[0]->ne[d]; + } + } + + int nth = 32; // SIMD width + + while (nth < ne_dim && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, (int)ne_dim); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, grid_dims[0], grid_dims[1], grid_dims[2], nth, 1, 1); + + return 1; +} + +int ggml_metal_op_tri(ggml_metal_op_t ctx, 
int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; + const float c = *((float *) &(op->op_params[1])); + const int32_t dim_x = (int32_t) op->op_params[2]; + const int32_t dim_y = (int32_t) op->op_params[3]; + + ggml_metal_kargs_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.c =*/ c, + /*.ttype =*/ static_cast(ttype), + /*.dim_x =*/ dim_x, + /*.dim_y =*/ dim_y + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 0d9cb8af7c1d0..86dc84b5d40f8 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -52,6 +52,8 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 424c400f24b9b..482a3a400b847 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1235,6 +1235,14 @@ kernel void kernel_scale_f32( dst[tpig] = src0[tpig] * args.scale + args.bias; } +kernel void kernel_scale_f16( + constant ggml_metal_kargs_scale & args, + device const half * src0, + device half * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * args.scale + args.bias; +} + kernel void kernel_scale_f32_4( constant ggml_metal_kargs_scale & args, device const float4 * src0, @@ -1398,6 +1406,22 @@ kernel void kernel_silu_f32_4( dst[tpig] = x / (1.0f + exp(-x)); } +kernel void kernel_softplus_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 20.0f) ? 
x : log(1.0f + exp(x)); +} + +kernel void kernel_softplus_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); +} + kernel void kernel_elu_f32( device const float * src0, device float * dst, @@ -1826,6 +1850,149 @@ typedef decltype(kernel_sum_rows) kernel_sum_rows_t; template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template <typename T> +kernel void kernel_cumsum( + constant ggml_metal_kargs_cumsum & args, + device const char * src0, + device const char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + // Figure out the size and stride of the cumsum dim + const int64_t ne_dim = (args.dim == 0) ? args.ne00 : (args.dim == 1) ? args.ne01 : (args.dim == 2) ? args.ne02 : args.ne03; + const int64_t nb_dim_src = (args.dim == 0) ? args.nb00 : (args.dim == 1) ? args.nb01 : (args.dim == 2) ? args.nb02 : args.nb03; + const int64_t nb_dim_dst = (args.dim == 0) ? args.nb0 : (args.dim == 1) ? args.nb1 : (args.dim == 2) ? args.nb2 : args.nb3; + + // Map threadgroup indices to actual tensor dimensions + // tgpig.x, tgpig.y, tgpig.z represent the 3 non-cumsum dimensions + // tpitg.x represents position in the cumsum dimension + int64_t grid_indices[3] = {int64_t(tgpig.x), int64_t(tgpig.y), int64_t(tgpig.z)}; + int64_t i_vals[4]; + + int grid_idx = 0; + for (int d = 0; d < 4; ++d) { + if (d == args.dim) { + i_vals[d] = 0; // Will be set in the loop below + } else { + i_vals[d] = grid_indices[grid_idx++]; + } + } + + // Base index offsets. 
The cumsum dim will be further offset by the position + // in the threadgroup + const int64_t i0 = i_vals[0]; + const int64_t i1 = i_vals[1]; + const int64_t i2 = i_vals[2]; + const int64_t i3 = i_vals[3]; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01 || i0 >= args.ne00) { + return; + } + + // Each thread processes elements at stride ntg.x along the cumsum dimension + for (int64_t i_dim = tpitg.x; i_dim < ne_dim; i_dim += ntg.x) { + const int64_t offset_src = i0*args.nb00 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03 + i_dim*nb_dim_src; + const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst; + + device const T * src_ptr = (device const T *) ((device const char *) src0 + offset_src); + device T * dst_ptr = (device T *) ((device char *) dst + offset_dst); + + // Each thread does simd_prefix_inclusive_sum + float sumf = static_cast<float>(src_ptr[0]); + sumf = simd_prefix_inclusive_sum(sumf); + dst_ptr[0] = static_cast<T>(sumf); + + // If this is the last element of the simd group, store its value in shared memory + if (tiisg == N_SIMDWIDTH - 1 || i_dim == ne_dim - 1) { + const ushort shmem_idx = i_dim / N_SIMDWIDTH; + shmem_f32[shmem_idx] = sumf; + } + } + + // Ensure all simd groups sync here before proceeding + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each element then adds the final value of all preceding simd groups + for (int64_t i_dim = tpitg.x; i_dim < ne_dim; i_dim += ntg.x) { + const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst; + device T * dst_ptr = (device T *) ((device char *) dst + offset_dst); + + const ushort shmem_idx = i_dim / N_SIMDWIDTH; + for (ushort j = 0; j < shmem_idx; ++j) { + dst_ptr[0] += static_cast<T>(shmem_f32[j]); + } + } +} + +typedef decltype(kernel_cumsum<float>) kernel_cumsum_t; + +template [[host_name("kernel_cumsum_f32")]] kernel kernel_cumsum_t kernel_cumsum<float>; +template [[host_name("kernel_cumsum_f16")]] kernel kernel_cumsum_t kernel_cumsum<half>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cumsum_bf16")]] kernel kernel_cumsum_t kernel_cumsum<bfloat>; +#endif + +inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) { + switch (type) { + // ggml.h:620 + case /* GGML_TRI_TYPE_LOWER */ 3: return i < r; break; + case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r; break; + case /* GGML_TRI_TYPE_UPPER */ 1: return i > r; break; + case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r; break; + } +} + +template <typename T> +kernel void kernel_tri( + constant ggml_metal_kargs_tri & args, + device const char * src0, + device const char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + const bool keep_org_val = isnan(args.c); + const T c_val = static_cast<T>(args.c); + const T zero_val = static_cast<T>(0.f); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + int64_t i_vals[4] = {i0, i1, i2, i3}; + int64_t iX = i_vals[args.dim_x]; + int64_t iY = i_vals[args.dim_y]; + + device const T * src_ptr = (device 
const T *) ((device const char *) src0 + i0*args.nb00 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_ptr = (device T *) ((device char *) dst + i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + dst_ptr[0] = _ggml_vec_tri_cmp(iX, iY, args.ttype) + ? (keep_org_val ? src_ptr[0] : c_val) + : zero_val; + } +} + +typedef decltype(kernel_tri<float>) kernel_tri_t; + +template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri<float>; +template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri<half>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri<bfloat>; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args, @@ -2048,8 +2215,9 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -kernel void kernel_ssm_conv_f32_f32( +// ref: ggml.c:ggml_compute_forward_ssm_conv_impl +template <typename src_t, typename conv_t> +kernel void kernel_ssm_conv_impl( constant ggml_metal_kargs_ssm_conv & args, device const void * src0, device const void * src1, @@ -2067,14 +2235,14 @@ kernel void kernel_ssm_conv_f32_f32( //const int64_t n_t = args.ne1; //const int64_t n_s = args.ne2; - device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); - device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); - device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + device const src_t * s = (device const src_t *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const conv_t * c = (device const conv_t *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - sumf += s[i0] * c[i0]; + sumf += static_cast<float>(s[i0]) * static_cast<float>(c[i0]); } x[0] = sumf; @@ -2111,6 +2279,13 @@ kernel void kernel_ssm_conv_f32_f32_4( x[0] = sumf; } +typedef decltype(kernel_ssm_conv_impl<float, float>) kernel_ssm_conv_t; +template [[host_name("kernel_ssm_conv_f32_f32")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float, float>; +template [[host_name("kernel_ssm_conv_f32_f16")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float, half>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_ssm_conv_f32_bf16")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float, bfloat>; +#endif + // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9be35c1be8456..4ce50c867b844 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "COS", "SUM", "SUM_ROWS", + "CUMSUM", "MEAN", "ARGMAX", "COUNT_EQUAL", @@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "TRI", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -1019,7 +1021,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1039,6 +1041,7 
@@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cos(x)", "Σx", "Σx_k", + "cumsum(x)", "Σx/n", "argmax(x)", "count_equal(x)", @@ -1094,6 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", + "tri(x)", "flash_attn_ext(x)", "flash_attn_back(x)", @@ -1123,7 +1127,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1148,9 +1152,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "CEIL", "ROUND", "TRUNC", + "SOFTPLUS", }; -static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20"); +static_assert(GGML_UNARY_OP_COUNT == 21, "GGML_UNARY_OP_COUNT != 21"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2341,6 +2346,31 @@ struct ggml_tensor * ggml_sum_rows( return result; } +// ggml_cumsum + +struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a, + int dim) { + + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, dim); + + result->op = GGML_OP_CUMSUM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cumsum_0( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cumsum(ctx, a, 0); +} + // ggml_mean struct ggml_tensor * ggml_mean( @@ -2668,8 +2698,8 @@ struct ggml_tensor * ggml_xielu( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); - ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); - ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 1, beta + ggml_op_softplus(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_op_softplus(alpha_p)); ggml_set_op_params_f32(result, 3, beta); ggml_set_op_params_f32(result, 4, eps); @@ -2724,6 +2754,14 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } +// ggml_softplus + +struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS); +} + // ggml_glu static struct ggml_tensor * ggml_glu_impl( @@ -5028,6 +5066,50 @@ struct ggml_tensor * ggml_timestep_embedding( return result; } +// ggml_tri + +struct ggml_tensor * ggml_tri_dims( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype, + int dim_x, + int dim_y) { + + GGML_ASSERT(dim_x >= 0 && dim_x < GGML_MAX_DIMS); + GGML_ASSERT(dim_y >= 0 && dim_y < GGML_MAX_DIMS); + GGML_ASSERT(dim_x != dim_y); + GGML_ASSERT(a->ne[dim_x] == a->ne[dim_y]); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, tritype); + ggml_set_op_params_f32(result, 1, constant); + ggml_set_op_params_i32(result, 2, dim_x); + ggml_set_op_params_i32(result, 3, dim_y); + + result->op = GGML_OP_TRI; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype) { + return ggml_tri_dims(ctx, a, constant, tritype, 0, 1); +} + +struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type 
tritype) { + + return ggml_tri(ctx, a, nan(""), tritype); +} + // ggml_argsort struct ggml_tensor * ggml_argsort( diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 2b39366271ff9..777b97ec6bbb1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1380,7 +1380,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - return std::max(1024u, 8u*model.n_tensors()); + return std::max(1024u, 16u*model.n_tensors()); //DEBUG!!!!!!!!! } llm_graph_result * llama_context::get_gf_res_reserve() const { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f9751b3183694..a3d8444223c73 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -251,6 +251,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= s_copy->ne[0] == mctx->get_n_rs(); + + res &= s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); + + return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -458,8 +476,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - inp_attn->set_input(ubatch); - inp_rs->set_input(ubatch); + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; } // @@ -1843,6 +1899,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } @@ -1911,10 +1970,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const 
auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); - auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f67927..caba9779b5d48 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * s_copy; // I32 [n_rs] // views of s_copy, computed once per graph @@ -232,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i { ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; + + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -364,22 +370,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( + const llama_cparams & cparams, std::unique_ptr inp_attn, - std::unique_ptr inp_rs, - const llama_memory_hybrid_context * mctx) : + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), inp_rs(std::move(inp_rs)), + cparams(cparams), mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + std::unique_ptr inp_attn; std::unique_ptr inp_rs; llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + const llama_cparams cparams; + const llama_memory_hybrid_context * mctx; }; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index dfb8439e01bdf..a1b45e4a3cce3 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ubatches(std::move(ubatches)), // note: here we copy the ubatches. 
not sure if this is ideal ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), - ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1987135ca6a2e..7f6f016a8eb5b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6839,8 +6839,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_type_r */ params.type_k, + /* recurrent_type_s */ params.type_v, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a56b2626ae1c5..428f3b8db9751 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -421,6 +421,18 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } ++qs.i_ffn_up; } + else if (name.find("ssm_conv1d") != std::string::npos) { + // go as low as F16 for now + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: + case LLAMA_FTYPE_MOSTLY_BF16: + break; + default: + { + new_type = GGML_TYPE_F16; + } + } + } // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; //} @@ -859,9 +871,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - // do not quantize Mamba's small yet 2D weights + // do not quantize shortconv 2D weights // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index b9a363b32b6b3..13b70f88e3e39 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-impl.h" + llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, @@ -241,9 +243,129 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + if (n_seq_tokens == 1) { + // if (true) { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); + // If single-token, use ssm_scan op + ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + } else { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); + + // otherwise, use the SSD formulation + 
+ // extract the state(s) for the sequences identified by ids + if (ssm->ne[3] != ids->ne[0]) { + ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 + ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, + ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape + ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows + ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape + GGML_ASSERT(ssm->ne[3] == ids->ne[0]); + } + // ssm -> {d_state, head_dim, n_head, n_seqs} + + // step 1: compute dt softplus + // NOTE: In other implementations, the bias is added after + // the softplus. This shouldn't be a problem, but it's a + // difference. + ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); + cb(dt_softplus, "dt_softplus", il); + + // step 2: compute dtA and dtX + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + cb(dtA, "dtA", il); + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + cb(dtX, "dtX", il); + + + // step 3: compute CB + uint32_t repeats = n_head / n_group; + ggml_tensor * C_perm = ggml_permute(ctx, C, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * B_perm = ggml_permute(ctx, B, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs} + cb(CB, "CB", il); + + // step 4: compute decay + ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA, + dtA->ne[0], dtA->ne[1], dtA->ne[2], dtA->ne[3] * n_seq_tokens); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + cb(segsum, "segsum", il); + ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1} + decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs} + cb(decay, "decay", il); + + // step 5: compute surrogate_attention_matrix + ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay)); + ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); + + // step 6: compute y + ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX, 1, 2, 0, 3)); + ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); + cb(y, "y", il); + + // step 7: compute dtxdecay + ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + decay->ne[0], 1, decay->ne[2], decay->ne[3], + decay->nb[1], decay->nb[2], decay->nb[3], + (decay->ne[1] - 1) * decay->nb[1]); + decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + cb(decay_last, "decay_last", il); + B_perm = ggml_cont(ctx, B_perm); + B_perm = ggml_repeat_4d(ctx, B_perm, + 
B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX, decay_last); + dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + cb(dtxdecay, "dtxdecay", il); + + // step 8: compute next_state + ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + if (next_state->type != ssm->type) { + next_state = ggml_cast(ctx, next_state, ssm->type); + } + cb(next_state, "next_state", il); + + // TODO: Skip y and state updates if no previous state + + // step 9: update from previous state + ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA, 1)); // {n_head, chunk_size_i, n_seqs} + cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); + ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs} + cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); + ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm))); + cb(next_state, "next_state_updated", il); + + // step 10: update from previous y + ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C, 0, 2, 1, 3), ssm); + cb(y_prev, "y_prev", il); + y_prev = ggml_mul(ctx, + ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), + ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0))); + cb(y_prev, "y_prev_mul", il); + y = ggml_add(ctx, y, y_prev); + cb(y, "y_updated", il); + + // Concat the output y and state + if (ssm->type != y->type) { + ssm = ggml_cast(ctx, ssm, y->type); + } + ggml_tensor * out = ggml_concat(ctx, + ggml_view_1d(ctx, y, ggml_nelements(y), 0), + ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), + 0); + return out; + } }; ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 967a53c63d86d..2d08de1740758 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -175,6 +176,33 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t)); } +static std::vector ggml_get_float_value(uint8_t * buf, ggml_type type, size_t i, size_t bs, + bool quantized, std::vector & vq) { + const auto * tt = ggml_get_type_traits(type); + std::vector tv; + if (type == GGML_TYPE_F16) { + tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); + } else if (type == GGML_TYPE_BF16) { + tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); + } else if (type == GGML_TYPE_F32) { + tv.push_back(*(float *) &buf[i]); + } else if (type == GGML_TYPE_I64) { + tv.push_back((float)*(int64_t *) &buf[i]); + } else if (type == GGML_TYPE_I32) { + tv.push_back((float)*(int32_t *) &buf[i]); + } else if (type == GGML_TYPE_I16) { + tv.push_back((float)*(int16_t *) &buf[i]); + } else if (type == GGML_TYPE_I8) { + tv.push_back((float)*(int8_t *) &buf[i]); + } else if (quantized) { + tt->to_float(&buf[i], vq.data(), bs); + tv.insert(tv.end(), vq.begin(), vq.end()); + } else { + GGML_ABORT("fatal error"); + } + return tv; +} + 
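Stepping back to the multi-token branch of build_mamba2_layer above: steps 1-10 appear to implement the chunked SSD ("state-space duality") form of the Mamba-2 recurrence. The summary below is one reading of those tensor ops, not text from the PR; as the comments above already note, the exact placement of the dt bias may differ from other implementations.

```latex
% per head, chunk length T, \Delta_t = \mathrm{clamp}(\mathrm{softplus}(dt_t)), A a (negative) scalar  -- steps 1-2
h_t = e^{\Delta_t A}\, h_{t-1} + \Delta_t\, B_t\, x_t, \qquad y_t = C_t^{\top} h_t
% unrolling within the chunk gives the masked-decay ("surrogate attention") form  -- steps 3-6
L_{ts} = \exp\Big(\sum_{k=s+1}^{t} \Delta_k A\Big) \;\; (t \ge s), \qquad L_{ts} = 0 \;\; (t < s)
Y_{\mathrm{intra}} = \big(\mathrm{tril}(C B^{\top}) \odot L\big)\,(\Delta \odot X)
% carry-in from the previous state h_0  -- steps 9-10
Y = Y_{\mathrm{intra}} + \mathrm{diag}\big(e^{\operatorname{cumsum}_t(\Delta_t A)}\big)\, C\, h_0
% carry-out state for the next chunk  -- steps 7-9
h_T = e^{\sum_{k=1}^{T} \Delta_k A}\, h_0 + \sum_{s=1}^{T} e^{\sum_{k=s+1}^{T} \Delta_k A}\, \Delta_s\, B_s\, x_s
```

The single-token path keeps using ggml_ssm_scan, which computes the same recurrence one step at a time.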
static std::vector tensor_to_float(const ggml_tensor * t) { std::vector tv; tv.reserve(ggml_nelements(t)); @@ -182,7 +210,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { std::vector buf(ggml_nbytes(t)); ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); - const auto * tt = ggml_get_type_traits(t->type); size_t bs = ggml_blck_size(t->type); std::vector vq(ggml_blck_size(t->type)); bool quantized = ggml_is_quantized(t->type); @@ -193,26 +220,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; - if (t->type == GGML_TYPE_F16) { - tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); - } else if (t->type == GGML_TYPE_BF16) { - tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); - } else if (t->type == GGML_TYPE_F32) { - tv.push_back(*(float *) &buf[i]); - } else if (t->type == GGML_TYPE_I64) { - tv.push_back((float)*(int64_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I32) { - tv.push_back((float)*(int32_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I16) { - tv.push_back((float)*(int16_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I8) { - tv.push_back((float)*(int8_t *) &buf[i]); - } else if (quantized) { - tt->to_float(&buf[i], vq.data(), bs); - tv.insert(tv.end(), vq.begin(), vq.end()); - } else { - GGML_ABORT("fatal error"); - } + const auto fvs = ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq); + tv.insert(tv.end(), fvs.begin(), fvs.end()); } } } @@ -221,6 +230,108 @@ static std::vector tensor_to_float(const ggml_tensor * t) { return tv; } +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static void ggml_print_tensor(ggml_tensor * t, int64_t verbose = 0) { + int n = verbose >= 2 ? 
std::numeric_limits::max() : 3; + GGML_ASSERT(t != nullptr); + GGML_ASSERT(n > 0); + + std::stringstream src_ss; + src_ss << "("; + size_t last_src = 0; + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (t->src[i] != nullptr) { + last_src = i; + } + } + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (t->src[i] != nullptr) { + src_ss << t->src[i]->name << "{" << ggml_ne_string(t->src[i]) <<"}"; + } + if (i <= last_src) { + src_ss << ", "; + } + } + src_ss << ")"; + + printf("%s: %24s = (%s) %10s%s = {%s}\n", __func__, + t->name, ggml_type_name(t->type), ggml_op_desc(t), + src_ss.str().c_str(), + ggml_ne_string(t).c_str()); + + std::vector tv; + tv.reserve(ggml_nelements(t)); + + std::vector buf(ggml_nbytes(t)); + ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); + + size_t bs = ggml_blck_size(t->type); + std::vector vq(ggml_blck_size(t->type)); + bool quantized = ggml_is_quantized(t->type); + + float sum = 0; + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + for (const auto & val : ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq)) { + sum += val; + } + } + } + } + } + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + printf(" [\n"); + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + if (i2 == n && t->ne[2] > 2*n) { + printf(" ..., \n"); + i2 = t->ne[2] - n; + } + printf(" [\n"); + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + if (i1 == n && t->ne[1] > 2*n) { + printf(" ..., \n"); + i1 = t->ne[1] - n; + } + printf(" ["); + for (int64_t i0 = 0; i0 < t->ne[0]; i0++) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + if (i0 == n && t->ne[0] > 2*n) { + printf("..., "); + i0 = t->ne[0] - n; + } + for (const auto & v : ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq)) { + printf("%12.4f", v); + } + if (i0 < t->ne[0] - 1) printf(", "); + } + printf("],\n"); + } + printf(" ],\n"); + } + printf(" ]\n"); + printf(" sum = %f\n", sum); + } + + // TODO: make this abort configurable/optional? 
+ if (std::isnan(sum)) { + printf("encountered NaN - aborting\n"); + exit(0); + } +} + // normalized mean squared error = mse(a, b) / mse(a, 0) static double nmse(const float * a, const float * b, size_t n) { double mse_a_b = 0.0; @@ -993,6 +1104,8 @@ static std::unique_ptr create_printer(output_formats format) { GGML_ABORT("invalid output format"); } +// test case definition + struct test_case { virtual ~test_case() {} @@ -1039,6 +1152,9 @@ struct test_case { virtual void initialize_tensors(ggml_context * ctx) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -1069,6 +1185,9 @@ struct test_case { std::vector sentinels; + // set to 1 to print tensors, 2 to fully print tensors + int verbose = 0; + std::string current_op_name; void add_sentinel(ggml_context * ctx) { @@ -1220,6 +1339,7 @@ struct test_case { // compare struct callback_userdata { bool ok; + int verbose; double max_err; ggml_backend_t backend1; ggml_backend_t backend2; @@ -1227,6 +1347,7 @@ struct test_case { callback_userdata ud { true, + verbose, max_nmse_err(), backend1, backend2 @@ -1251,6 +1372,11 @@ struct test_case { } } + if (ud->verbose) { + ggml_print_tensor(t1, ud->verbose); + ggml_print_tensor(t2, ud->verbose); + } + std::vector f1 = tensor_to_float(t1); std::vector f2 = tensor_to_float(t2); @@ -1280,11 +1406,12 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); - //exit(1); + if (ud->verbose) { + for (int i = 0; i < (int) f1.size(); i++) { + printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + } + printf("\n"); + } ud->ok = false; } return true; @@ -1327,7 +1454,7 @@ struct test_case { ggml_tensor * out = build_graph(ctx.get()); current_op_name = op_desc(out); if (!matches_filter(out, op_names_filter)) { - //printf(" %s: skipping\n", op_desc(out).c_str()); + // printf(" %s: skipping\n", op_desc(out).c_str()); return true; } @@ -1830,6 +1957,9 @@ struct test_unary : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -1895,6 +2025,9 @@ struct test_glu : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -1953,6 +2086,9 @@ struct test_glu_split : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2013,6 +2149,9 @@ struct test_swiglu_oai : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check for NaNs in GELU init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2070,6 
+2209,9 @@ struct test_get_rows : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2123,6 +2265,9 @@ struct test_get_rows_back : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2208,6 +2353,9 @@ struct test_set_rows : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -2334,6 +2482,9 @@ struct test_argmax : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -2395,6 +2546,9 @@ struct test_count_equal : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2649,6 +2803,9 @@ struct test_cpy : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // test extended range of values to check if casting between f32 and i32 is consistent init_tensor_uniform(t, -150.f, 150.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -2747,6 +2904,9 @@ struct test_bin_bcast : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -2821,6 +2981,9 @@ struct test_add_id : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -3081,6 +3244,9 @@ struct test_rms_norm : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -10.f, 10.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3124,6 +3290,9 @@ struct test_rms_norm_back : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -10.f, 10.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -3180,6 +3349,9 @@ struct test_rms_norm_mul_add : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -10.f, 10.f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3270,6 +3442,9 @@ struct test_ssm_scan : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -3563,6 +3738,9 @@ struct test_mul_mat_id : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3677,6 +3855,9 @@ struct test_sqrt : public test_case { // fill with positive values for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, 50.0f, 100.0f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3717,6 +3898,9 @@ struct test_log : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass: init_tensor_uniform(t, 0.9f, 1.1f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3752,6 +3936,9 @@ struct test_sin : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = 
ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi]. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -3795,6 +3982,9 @@ struct test_cos : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi]. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -4221,6 +4411,9 @@ struct test_rope : public test_case { init_tensor_uniform(t); } } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -4748,6 +4941,9 @@ struct test_argsort : public test_case { } else { GGML_ABORT("fatal error"); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -5074,6 +5270,89 @@ struct test_sum_rows : public test_case { } }; +// GGML_OP_CUMSUM +struct test_cumsum : public test_case { + const ggml_type type; + const std::array ne; + const int64_t dim; + const std::array permute; + + std::string vars() override { + return VARS_TO_STR4(type, ne, dim, permute); + } + + test_cumsum(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}, + int64_t dim = 0, + std::array permute = {-1, -1, -1, -1}) + : type(type), ne(ne), dim(dim), permute(permute) {} + + + double max_nmse_err() override { + // Lower precision types have expected precision errors in lower bits + if (type == GGML_TYPE_BF16 || type == GGML_TYPE_F16) { + return 1e-5; + } + return 1e-7; + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + if (permute[0] != -1) { + a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]); + } + + ggml_tensor * out = ggml_cumsum(ctx, a, dim); + ggml_set_name(out, "out"); + + return out; + } +}; + +// GGML_OP_TRI +struct test_tri : public test_case { + const ggml_type type; + const std::array ne; + const ggml_tri_type tri_type; + const float c; + const int64_t dim_x; + const int64_t dim_y; + const std::array permute; + + std::string vars() override { + return VARS_TO_STR7(type, ne, tri_type, c, dim_x, dim_y, permute); + } + + test_tri(ggml_tri_type tri_type, + ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 1, 1}, + float c = nan(""), + int64_t dim_x = 0, + int64_t dim_y = 1, + // don't permute by default + std::array permute = {-1, -1, -1, -1}) + : type(type), ne(ne), tri_type(tri_type), c(c), + dim_x(dim_x), dim_y(dim_y), permute(permute) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + if (permute[0] != -1) { + a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]); + } + + ggml_tensor * out = ggml_tri_dims(ctx, a, c, tri_type, dim_x, dim_y); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_MEAN struct test_mean : public test_case { const ggml_type type; @@ -5584,6 +5863,9 @@ struct test_flash_attn_ext : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5628,6 +5910,9 @@ struct test_cross_entropy_loss : public test_case { // For larger abs. diffs between logits softmax is more linear, therefore more precise num. gradients. 
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, -100.0f, 100.0f); + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5713,6 +5998,9 @@ struct test_opt_step_adamw : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5752,6 +6040,9 @@ struct test_opt_step_sgd : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t, 0.0f, 1.0f); // sgd_params need non-negative values. + if (verbose) { + ggml_print_tensor(t, verbose); + } } } @@ -5894,6 +6185,9 @@ struct test_llm : public test_case { } else { init_tensor_uniform(t); } + if (verbose) { + ggml_print_tensor(t, verbose); + } } } }; @@ -6193,7 +6487,7 @@ static const ggml_type other_types[] = { }; // Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low -static std::vector> make_test_cases_eval() { +static std::vector> make_test_cases_eval(int verbose = 0) { std::vector> test_cases; std::default_random_engine rng(0); @@ -7222,6 +7516,31 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + // non-contiguous + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 0, {1, 0, 2, 3})); + // alternate dim + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 1)); + + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_BF16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16})); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); + // non-contiguous + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 0, 1, {2, 0, 1, 3})); + // alternate dims + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 1, 2)); + for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); @@ -7329,6 +7648,11 @@ static 
std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_falcon(2)); #endif + // set verbose on all test cases + for (auto & tc : test_cases) { + tc->verbose = verbose; + } + return test_cases; } @@ -7407,6 +7731,31 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true)); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + // non-contiguous + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 0, {1, 0, 2, 3})); + // alternate dim + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 1)); + + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_BF16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16})); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); + // non-contiguous + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 0, 1, {2, 0, 1, 3})); + // alternate dims + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {1, 8, 8, 1}, nan(""), 1, 2)); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { @@ -7489,11 +7838,23 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it)); } + for (int64_t d_conv : {3, 4}) { + for (int64_t d_inner: {1024, 1536, 2048}) { + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 4, 1}, {d_conv, d_inner, 1, 1})); + } + } + + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1 + return test_cases; } static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter, - printer * output_printer) { + printer * output_printer, int verbose) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -7512,7 +7873,7 @@ static bool test_backend(ggml_backend_t 
backend, test_mode mode, const char * op }; if (mode == MODE_TEST) { - auto test_cases = make_test_cases_eval(); + auto test_cases = make_test_cases_eval(verbose); filter_test_cases(test_cases, params_filter); ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL); if (backend_cpu == NULL) { @@ -7701,6 +8062,7 @@ static void usage(char ** argv) { printf(" --output specifies output format (default: console, options: console, sql, csv)\n"); printf(" --list-ops lists all available GGML operations\n"); printf(" --show-coverage shows test coverage\n"); + printf(" --verbose | -v print tensors during ops (can specify multiple times)\n"); } int main(int argc, char ** argv) { @@ -7709,6 +8071,7 @@ int main(int argc, char ** argv) { const char * op_names_filter = nullptr; const char * backend_filter = nullptr; const char * params_filter = nullptr; + int verbose = 0; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "test") == 0) { @@ -7756,6 +8119,8 @@ int main(int argc, char ** argv) { } else if (strcmp(argv[i], "--show-coverage") == 0) { show_test_coverage(); return 0; + } else if (strcmp(argv[i], "--verbose") == 0 || strcmp(argv[i], "-v") == 0) { + ++verbose; } else { usage(argv); return 1; @@ -7808,7 +8173,7 @@ int main(int argc, char ** argv) { false, "", ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024, true)); - bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get()); + bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), verbose); if (ok) { n_ok++;
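To close, here is a small standalone sketch (not part of this PR) of how the new public ops compose into the softplus → tri → cumsum → exp chain that the SSD branch uses for its per-chunk decay matrix. It assumes a CPU-only ggml build with the API added above; the toy size, fill value, and the program itself are invented for illustration.

```c
#include "ggml.h"
#include "ggml-cpu.h"

#include <math.h>
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // toy 8x8 matrix of dt values (stand-in for the repeated dtA of step 4)
    struct ggml_tensor * dt = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
    for (int i = 0; i < 8*8; ++i) {
        ((float *) dt->data)[i] = 0.1f;
    }

    struct ggml_tensor * x = ggml_softplus(ctx, dt);          // step 1
    x = ggml_scale (ctx, x, -1.0f);                           // stand-in for multiplying by A < 0
    x = ggml_tri   (ctx, x, nanf(""), GGML_TRI_TYPE_LOWER);   // NaN constant keeps the original values in the strict lower triangle
    x = ggml_cumsum(ctx, x, 1);                               // segment sums along dim 1
    x = ggml_exp   (ctx, x);                                  // per-chunk decay matrix

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, x);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("decay[1][0] = %f\n", ((float *) x->data)[8 + 0]);

    ggml_free(ctx);
    return 0;
}
```

Passing NaN as the constant is also what ggml_tri_keep does internally: values inside the selected triangle are preserved and everything else is zeroed, which is exactly the masking the decay computation relies on.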