diff --git a/src/extensions_ref/src/quantized_matmul.cpp b/src/extensions_ref/src/quantized_matmul.cpp
index 0420414..86886b9 100644
--- a/src/extensions_ref/src/quantized_matmul.cpp
+++ b/src/extensions_ref/src/quantized_matmul.cpp
@@ -1,7 +1,8 @@
 #include <...>
-#include <...>
-#include <...>
+#include "mlx/array.h"
+#include "mlx/device.h"
+#include "mlx/dtype.h"
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/utils.h"
@@ -9,7 +10,6 @@
 
 #ifdef _METAL_
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/utils.h"
 #endif
 
 namespace tiny_llm_ext_ref {
@@ -23,8 +23,8 @@ mx::array quantized_matmul(const mx::array &scales,  // Input array scale
                            const bool transpose_b,          // Whether to transpose b
                            mx::StreamOrDevice s /* = {} */  // Stream on which to schedule the operation
 ) {
-    if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16) {
-        throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16");
+    if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16 && scales.dtype() != mx::float32) {
+        throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16 or float32");
     }
     if (scales.dtype() != biases.dtype()) {
         throw std::runtime_error("quantized_matmul: scales and biases must be the same dtype");
@@ -143,6 +143,77 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
     });
 }
 
+template <typename T>
+void quantized_matmul_impl_typed(
+    const mx::array &scales, const mx::array &biases,
+    const mx::array &a, const mx::array &b,
+    mx::array &out, int group_size, int bits, mx::Stream stream) {
+
+    out.set_data(mx::allocator::malloc(out.nbytes()));
+    auto &encoder = mx::cpu::get_command_encoder(stream);
+    encoder.set_input_array(scales);
+    encoder.set_input_array(biases);
+    encoder.set_input_array(a);
+    encoder.set_input_array(b);
+    encoder.set_output_array(out);
+
+    encoder.dispatch([out_ptr = out.data<T>(), out_shape = out.shape(), out_strides = out.strides(),
+                      group_size = group_size, bits = bits,
+                      a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b),
+                      scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() {
+
+        // each `group_size` consecutive weight elements are packed into a group, and each weight is quantized into `bits` bits;
+        // thus each `group_size` consecutive weight elements take `group_size * bits / 32` uint32_t elements in b
+        // when decoding a group of weights, the same scale and bias are reused `group_size` times (shared by all elements in the group)
+        int m = a.shape()[0], n = a.shape()[1], k = b.shape()[0];
+
+        // row => group => item => pack
+        const int group_per_row = n / group_size;  // b[k, :] = [ group_0, group_1, ..., group_(group_per_row-1) ]
+        const int packs_per_item = 32 / bits;      // each uint32_t element can store `packs_per_item` packed elements
+        const int items_per_group = group_size / packs_per_item;  // each group contains `items_per_group` uint32_t elements
+
+        const T *a_ptr = a.data<T>(),
+                *scales_ptr = scales.data<T>(), *biases_ptr = biases.data<T>();
+        const uint32_t *b_ptr = b.data<uint32_t>();
+
+        uint32_t pack_mask = (1 << bits) - 1;
+
+        for (int i = 0; i < m; i++) {
+            for (int j = 0; j < k; j++) {
+                float sum = 0;
+                for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
+                    int64_t scales_idx = mx::elem_to_loc(j * group_per_row + group_idx, scales.shape(), scales.strides());
+                    int64_t biases_idx = mx::elem_to_loc(j * group_per_row + group_idx, biases.shape(), biases.strides());
+                    T scale = scales_ptr[scales_idx], bias = biases_ptr[biases_idx];
+
+                    int64_t a_idx = mx::elem_to_loc(i * n + group_idx * group_size, a.shape(), a.strides());
+                    int64_t b_idx = mx::elem_to_loc((j * n + group_idx * group_size) / packs_per_item, b.shape(), b.strides());
+
+                    for (int item_idx = 0; item_idx < items_per_group; item_idx++) {
+                        uint32_t b_val = b_ptr[b_idx];  // fetch one uint32_t element (item) of the current group
+                        uint8_t *b_bytes = reinterpret_cast<uint8_t *>(&b_val);  // reinterpret the uint32_t element as a byte array (32 bits = 4 bytes)
+
+                        for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
+                            // extract the pack (4 bits) from the byte array:
+                            // pack_idx / 2 is the index into the byte array, and (pack_idx % 2) * bits is the shift amount;
+                            // when pack_idx is even, extract the low 4 bits, otherwise extract the high 4 bits
+                            // (pack_7, pack_6, pack_5, pack_4, pack_3, pack_2, pack_1, pack_0) => (b_bytes[3], b_bytes[2], b_bytes[1], b_bytes[0])
+                            uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & pack_mask;
+                            float a_val = static_cast<float>(a_ptr[a_idx]);
+                            float b_val_real = static_cast<float>(item_val) * static_cast<float>(scale) + static_cast<float>(bias);
+                            sum += a_val * b_val_real;
+                            a_idx += 1;
+                        }
+                        b_idx += 1;
+                    }
+                }
+                int64_t out_idx = mx::elem_to_loc(i * k + j, out_shape, out_strides);
+                out_ptr[out_idx] = static_cast<T>(sum);
+            }
+        }
+    });
+}
+
 void QuantizedMatmul::eval_cpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
     auto &scales = inputs[0];
     auto &biases = inputs[1];
@@ -150,8 +221,20 @@ void QuantizedMatmul::eval_cpu(const std::vector<mx::array> &inputs, std::vector
     auto &b = inputs[3];
     auto &out = outputs[0];
 
-    // TODO: dispatch to f32, f16, bf16
-    quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream());
+    // quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream());
+    switch (a.dtype()) {
+        case mx::float16:
+            quantized_matmul_impl_typed<mx::float16_t>(scales, biases, a, b, out, group_size_, bits_, stream());
+            break;
+        case mx::float32:
+            quantized_matmul_impl_typed<float>(scales, biases, a, b, out, group_size_, bits_, stream());
+            break;
+        case mx::bfloat16:
+            quantized_matmul_impl_typed<mx::bfloat16_t>(scales, biases, a, b, out, group_size_, bits_, stream());
+            break;
+        default:
+            throw std::runtime_error("Unsupported dtype for quantized_matmul");
+    }
 }
 
 void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
diff --git a/src/extensions_ref/src/tiny_llm_ext.h b/src/extensions_ref/src/tiny_llm_ext.h
index 0fe663e..599ba9f 100644
--- a/src/extensions_ref/src/tiny_llm_ext.h
+++ b/src/extensions_ref/src/tiny_llm_ext.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "mlx/ops.h"
+#include "mlx/utils.h"
 #include "mlx/primitives.h"
 
 namespace mx = mlx::core;
diff --git a/tests/utils.py b/tests/utils.py
index c34584f..aef7e25 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,7 +20,7 @@ def assert_allclose(
         b = np.array(b)
     if precision == mx.float32:
         rtol = rtol or 1.0e-5
-        atol = atol or 1.0e-8
+        atol = atol or 1.0e-6
     elif precision == mx.float16:
         rtol = rtol or 3.0e-2
         atol = atol or 1.0e-5
diff --git a/tests_refsol/test_week_2_day_2.py b/tests_refsol/test_week_2_day_2.py
index fcb7f1c..56d918b 100644
--- a/tests_refsol/test_week_2_day_2.py
+++ b/tests_refsol/test_week_2_day_2.py
@@ -49,3 +49,11 @@ def test_task_2_quantized_matmul_simple_f16_gpu():
 
 
 def test_task_2_quantized_matmul_complex_f16_gpu():
     quantized_matmul_helper(mx.gpu, False, mx.float16)
+
+
+def test_task_1_quantized_matmul_simple_f32_cpu():
+    quantized_matmul_helper(mx.cpu, True, mx.float32)
+
+
+def test_task_1_quantized_matmul_complex_f32_cpu():
+    quantized_matmul_helper(mx.cpu, False, mx.float32)
\ No newline at end of file
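
The inner loops of quantized_matmul_impl_typed decode MLX's group-wise affine quantization: for bits == 4, each uint32_t word of b holds eight weights, two per byte, with pack 0 in the low nibble of byte 0. The standalone sketch below is not part of the patch; it only illustrates the byte/shift arithmetic the kernel relies on (the names q and item are made up, and a little-endian layout is assumed, as it is by the kernel's reinterpret_cast).

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int bits = 4;
    const int packs_per_item = 32 / bits;         // 8 quantized values per uint32_t
    const uint32_t pack_mask = (1u << bits) - 1;  // 0xF

    uint8_t q[8] = {3, 12, 0, 15, 7, 1, 9, 4};    // quantized values, each in [0, 15]

    // Pack: value pack_idx lands in byte pack_idx / 2, nibble pack_idx % 2,
    // which for 4 bits is the same as a plain shift by pack_idx * bits.
    uint32_t item = 0;
    for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
        item |= static_cast<uint32_t>(q[pack_idx] & pack_mask) << (pack_idx * bits);
    }

    // Unpack with the same byte/shift arithmetic as the kernel (little-endian assumed).
    const uint8_t *bytes = reinterpret_cast<const uint8_t *>(&item);
    for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
        uint8_t v = (bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & pack_mask;
        assert(v == q[pack_idx]);
    }
    printf("4-bit pack/unpack round trip ok\n");
    return 0;
}

Equivalently, for bits == 4 the extraction is (item >> (pack_idx * bits)) & pack_mask; the kernel walks bytes instead, which gives the same result once the word is reinterpreted in memory.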
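
One level up, each row of b is split into n / group_size groups that share one scale and one bias, and each group occupies group_size * bits / 32 consecutive uint32_t words. The sketch below, again not part of the patch and using made-up sizes and values (group_size = 64, bits = 4, dummy packed weights), computes a single output element the same way the kernel's group / item / pack loops do, with the affine dequantization w = q * scale + bias.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int bits = 4, group_size = 64, n = 128;              // one row of b holds n weights
    const int packs_per_item = 32 / bits;                      // 8 weights per uint32_t word
    const int items_per_group = group_size / packs_per_item;   // 16 words per group
    const int group_per_row = n / group_size;                  // 2 groups per row
    const uint32_t pack_mask = (1u << bits) - 1;

    std::vector<uint32_t> b_row(n / packs_per_item, 0x21212121u);  // dummy packed weights: nibbles 1 and 2
    std::vector<float> scales = {0.5f, 0.25f};                     // one scale per group
    std::vector<float> biases = {0.0f, 1.0f};                      // one bias per group
    std::vector<float> a(n, 1.0f);                                 // one activation row

    float sum = 0.0f;  // the output element for this (activation row, b row) pair
    for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
        float scale = scales[group_idx], bias = biases[group_idx];
        int a_idx = group_idx * group_size;        // first activation covered by this group
        int b_idx = group_idx * items_per_group;   // first packed word of this group
        for (int item_idx = 0; item_idx < items_per_group; item_idx++) {
            uint32_t word = b_row[b_idx + item_idx];
            for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
                uint8_t q = (word >> (pack_idx * bits)) & pack_mask;
                float w = static_cast<float>(q) * scale + bias;  // affine dequantization
                sum += a[a_idx++] * w;
            }
        }
    }
    printf("out element = %f\n", sum);  // 272 for these dummy values
    return 0;
}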