97 changes: 90 additions & 7 deletions src/extensions_ref/src/quantized_matmul.cpp
@@ -1,15 +1,15 @@
#include <cstdint>
#include <iostream>
#include <sstream>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h"
#include "tiny_llm_ext.h"

#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#endif

namespace tiny_llm_ext_ref {
@@ -23,8 +23,8 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale
                           const bool transpose_b,          // Whether to transpose b
                           mx::StreamOrDevice s /* = {} */  // Stream on which to schedule the operation
) {
    if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16) {
        throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16");
    if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16 && scales.dtype() != mx::float32) {
        throw std::runtime_error("quantized_matmul: scales must be float16, bfloat16, or float32");
    }
    if (scales.dtype() != biases.dtype()) {
        throw std::runtime_error("quantized_matmul: scales and biases must be the same dtype");
@@ -143,15 +143,98 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con
});
}

template<typename T>
void quantized_matmul_impl_typed(
    const mx::array &scales, const mx::array &biases,
    const mx::array &a, const mx::array &b,
    mx::array &out, int group_size, int bits, mx::Stream stream) {

    out.set_data(mx::allocator::malloc(out.nbytes()));
    auto &encoder = mx::cpu::get_command_encoder(stream);
    encoder.set_input_array(scales);
    encoder.set_input_array(biases);
    encoder.set_input_array(a);
    encoder.set_input_array(b);
    encoder.set_output_array(out);

    encoder.dispatch([out_ptr = out.data<T>(), out_shape = out.shape(), out_strides = out.strides(),
                      group_size = group_size, bits = bits,
                      a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b),
                      scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() {

        // Every `group_size` consecutive weights are quantized to `bits` bits and packed together into a group,
        // so each group occupies `group_size * bits / 32` uint32_t elements of b.
        // When decoding a group, one scale/bias pair is shared by all `group_size` elements of that group.
        int m = a.shape()[0], n = a.shape()[1], k = b.shape()[0];

        // row => group => item => pack
        const int group_per_row = n / group_size;                 // each row of b is [ group_0, group_1, ..., group_(group_per_row-1) ]
        const int packs_per_item = 32 / bits;                     // each uint32_t element stores `packs_per_item` packed values
        const int items_per_group = group_size / packs_per_item;  // each group spans `items_per_group` uint32_t elements

        const T *a_ptr = a.data<T>(),
                *scales_ptr = scales.data<T>(), *biases_ptr = biases.data<T>();
        const uint32_t *b_ptr = b.data<uint32_t>();

        uint32_t pack_mask = (1 << bits) - 1;

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < k; j++) {
                float sum = 0;
                for (int group_idx = 0; group_idx < group_per_row; group_idx++) {
                    int64_t scales_idx = mx::elem_to_loc(j * group_per_row + group_idx, scales.shape(), scales.strides());
                    int64_t biases_idx = mx::elem_to_loc(j * group_per_row + group_idx, biases.shape(), biases.strides());
                    T scale = scales_ptr[scales_idx], bias = biases_ptr[biases_idx];

                    int64_t a_idx = mx::elem_to_loc(i * n + group_idx * group_size, a.shape(), a.strides());
                    int64_t b_idx = mx::elem_to_loc((j * n + group_idx * group_size) / packs_per_item, b.shape(), b.strides());

                    for (int item_idx = 0; item_idx < items_per_group; item_idx++) {
                        uint32_t b_val = b_ptr[b_idx];                           // fetch the current packed item (one uint32_t) of the group
                        uint8_t *b_bytes = reinterpret_cast<uint8_t *>(&b_val);  // view the 32-bit item as four bytes

                        for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) {
                            // Extract one pack (`bits` bits) from the byte view. This indexing assumes bits == 4,
                            // i.e. two packs per byte: pack_idx / 2 selects the byte and (pack_idx % 2) * bits selects
                            // the low (even pack_idx) or high (odd pack_idx) nibble, so on a little-endian host
                            // (pack_7, pack_6, ..., pack_0) maps to (b_bytes[3], b_bytes[2], b_bytes[1], b_bytes[0]).
                            uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & pack_mask;
                            float a_val = static_cast<float>(a_ptr[a_idx]);
                            float b_val_real = static_cast<float>(item_val) * static_cast<float>(scale) + static_cast<float>(bias);
                            sum += a_val * b_val_real;
                            a_idx += 1;
                        }
                        b_idx += 1;
                    }
                }
                int64_t out_idx = mx::elem_to_loc(i * k + j, out_shape, out_strides);
                out_ptr[out_idx] = static_cast<T>(sum);
            }
        }
    });
}

void QuantizedMatmul::eval_cpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
    auto &scales = inputs[0];
    auto &biases = inputs[1];
    auto &a = inputs[2];
    auto &b = inputs[3];
    auto &out = outputs[0];

    // TODO: dispatch to f32, f16, bf16
    quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream());
    // quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream());
    // Dispatch on the activation dtype; the typed kernel assumes a, scales, biases, and out all share it.
    switch (a.dtype()) {
        case mx::float16:
            quantized_matmul_impl_typed<float16_t>(scales, biases, a, b, out, group_size_, bits_, stream());
            break;
        case mx::float32:
            quantized_matmul_impl_typed<float>(scales, biases, a, b, out, group_size_, bits_, stream());
            break;
        case mx::bfloat16:
            quantized_matmul_impl_typed<mx::bfloat16_t>(scales, biases, a, b, out, group_size_, bits_, stream());
            break;
        default:
            throw std::runtime_error("Unsupported dtype for quantized_matmul");
    }
}

void QuantizedMatmul::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
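
The CPU kernel above decodes MLX's packed 4-bit layout: eight quantized values per uint32_t, one scale/bias pair per group of `group_size` weights. The Python sketch below (not part of this PR) reproduces that decode with plain integer arithmetic and cross-checks it against mx.dequantize; the packing order (first element of a group in the lowest bits), the shapes, and the tolerances are illustrative assumptions.

import mlx.core as mx
import numpy as np

group_size, bits = 64, 4
packs_per_item = 32 // bits                      # 8 quantized values per uint32_t
w = mx.random.normal((8, 128))                   # last dim divisible by group_size
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

w_q_np = np.array(w_q, dtype=np.uint32)
scales_np = np.array(scales, dtype=np.float32)
biases_np = np.array(biases, dtype=np.float32)

rows, n = w.shape
mask = (1 << bits) - 1
manual = np.zeros((rows, n), dtype=np.float32)
for r in range(rows):
    for col in range(n):
        # assumes MLX packs the first element of each group into the lowest bits
        item = w_q_np[r, col // packs_per_item]                 # the uint32_t holding this weight
        q = (item >> ((col % packs_per_item) * bits)) & mask    # same result as the byte-wise trick in the kernel
        g = col // group_size                                   # one scale/bias pair per group
        manual[r, col] = q * scales_np[r, g] + biases_np[r, g]

reference = np.array(mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits))
np.testing.assert_allclose(manual, reference, rtol=1e-4, atol=1e-4)
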
2 changes: 1 addition & 1 deletion src/extensions_ref/src/tiny_llm_ext.h
@@ -1,6 +1,6 @@
#pragma once

#include "mlx/ops.h"
#include "mlx/utils.h"
#include "mlx/primitives.h"

namespace mx = mlx::core;
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -20,7 +20,7 @@ def assert_allclose(
    b = np.array(b)
    if precision == mx.float32:
        rtol = rtol or 1.0e-5
        atol = atol or 1.0e-8
        atol = atol or 1.0e-6
    elif precision == mx.float16:
        rtol = rtol or 3.0e-2
        atol = atol or 1.0e-5
8 changes: 8 additions & 0 deletions tests_refsol/test_week_2_day_2.py
@@ -49,3 +49,11 @@ def test_task_2_quantized_matmul_simple_f16_gpu():

def test_task_2_quantized_matmul_complex_f16_gpu():
    quantized_matmul_helper(mx.gpu, False, mx.float16)


def test_task_1_quantized_matmul_simple_f32_cpu():
    quantized_matmul_helper(mx.cpu, True, mx.float32)


def test_task_1_quantized_matmul_complex_f32_cpu():
    quantized_matmul_helper(mx.cpu, False, mx.float32)
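
quantized_matmul_helper is defined elsewhere in the test suite; as a reference point, here is a hypothetical sketch of the kind of float32 check these new tests exercise, built only from public MLX APIs (mx.quantize, mx.dequantize, mx.quantized_matmul) plus the assert_allclose above. The shapes, the import path, and the use of mx.quantized_matmul as a stand-in for the extension's kernel are assumptions, not the actual helper.

import mlx.core as mx
from tests.utils import assert_allclose   # assumed import path for the helper above


def quantized_matmul_reference_check(device, simple: bool, precision, group_size=64, bits=4):
    with mx.stream(device):
        mx.random.seed(0)
        m, n, k = (1, 64, 64) if simple else (5, 256, 128)    # illustrative shapes
        a = mx.random.normal((m, n)).astype(precision)
        w = mx.random.normal((k, n))                          # quantized along the last axis
        w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
        # Dense reference: dequantize the packed weights and run an ordinary matmul.
        w_deq = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)
        expected = a @ w_deq.astype(precision).T
        # The real helper would call the extension's quantized_matmul here;
        # MLX's built-in kernel stands in for it in this sketch.
        actual = mx.quantized_matmul(a, w_q, scales.astype(precision), biases.astype(precision),
                                     transpose=True, group_size=group_size, bits=bits)
        assert_allclose(actual, expected, precision=precision)


quantized_matmul_reference_check(mx.cpu, True, mx.float32)
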