Commit 344331c

First draft
1 parent cd08fc3 commit 344331c

File tree

11 files changed
+856 -10 lines changed


convert_hf_to_gguf.py

Lines changed: 22 additions & 0 deletions

@@ -3741,6 +3741,28 @@ def set_vocab(self):
 
         super().set_vocab()
 
+@ModelBase.register("Qwen3NextForCausalLM")
+class Qwen3NextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3NEXT
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["linear_conv_kernel_dim"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["linear_key_head_dim"]))
+        self.gguf_writer.add_ssm_group_count(self.find_hparam(["linear_num_key_heads"]))
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["linear_num_value_heads"]))
+        self.gguf_writer.add_ssm_inner_size(self.find_hparam(["hidden_size"]) * (self.find_hparam(["linear_num_value_heads"]) // self.find_hparam(["linear_num_key_heads"])))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif "conv1d" in name:
+            data_torch = data_torch.squeeze()
+
+        return Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
+
 
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
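Note (not part of the diff): the add_ssm_inner_size call stores the hidden size scaled by the integer ratio of linear value heads to key heads. With purely hypothetical hyperparameter values, chosen only to make the arithmetic concrete:

$$ d_{\text{inner}} = d_{\text{hidden}} \cdot \left\lfloor H_v / H_k \right\rfloor = 2048 \cdot (32 / 16) = 4096 $$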

ggml/include/ggml.h

Lines changed: 27 additions & 1 deletion

@@ -539,7 +539,8 @@ extern "C" {
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
-
+        GGML_OP_DELTA_NET,
+
         GGML_OP_UNARY,
 
         GGML_OP_MAP_CUSTOM1,

@@ -2278,6 +2279,31 @@ extern "C" {
             struct ggml_tensor * state,
             float scale);
 
+    // Delta-Net linear layer activation
+    // Implements the complete Delta-Net gated linear attention mechanism
+    // This includes causal convolution preprocessing and gated delta rule computation
+    // k, v, q, g:  [S, H, n_tokens, n_seqs] - key, value, query, gate tensors
+    // conv_weight: [conv_dim, 1, conv_kernel_size] - convolution kernel weights
+    // conv_bias:   [conv_dim] - convolution bias (optional, can be NULL)
+    // beta:        [H, n_tokens, n_seqs] - beta parameter for delta rule
+    // state:       [S, S, H, n_seqs] - recurrent state tensor
+    // chunk_size:  chunk size for chunked computation (0 for recurrent mode)
+    // use_qk_l2norm: whether to apply L2 normalization to query and key
+    // scale:       attention scaling factor
+    GGML_API struct ggml_tensor * ggml_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * conv_weight,
+            struct ggml_tensor  * conv_bias,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            int                   chunk_size,
+            bool                  use_qk_l2norm,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
             struct ggml_context * ctx,
             struct ggml_tensor  * r,
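Note (not part of the commit): a minimal sketch of how the new operator might be wired into a graph, assuming an already-initialized ggml context. All dimension values, the conv_dim layout, and the scale are illustrative assumptions, not values taken from the model:

    #include <math.h>
    #include "ggml.h"

    // Build a ggml_delta_net node in an existing context (assumed pre-initialized).
    static struct ggml_tensor * build_delta_net_example(struct ggml_context * ctx) {
        const int64_t S = 64, H = 8, T = 16, N = 1; // head dim, heads, tokens, seqs (assumed)
        const int64_t conv_dim = 3*S*H;             // assumed channel count for the causal conv
        const int64_t conv_k   = 4;                 // assumed conv kernel size

        // Inputs shaped per the header comment: [S, H, n_tokens, n_seqs]
        struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, T, N);
        struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, T, N);
        struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, T, N);
        struct ggml_tensor * g    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, T, N);
        struct ggml_tensor * cw   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, conv_dim, 1, conv_k);
        struct ggml_tensor * beta = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, T, N);
        struct ggml_tensor * st   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, S, H, N);

        // chunk_size = 0 selects the recurrent path; conv_bias may be NULL per the comment.
        return ggml_delta_net(ctx, k, v, q, g, cw, /*conv_bias*/ NULL, beta, st,
                /*chunk_size*/ 0, /*use_qk_l2norm*/ true, 1.0f/sqrtf((float) S));
    }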

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 171 additions & 0 deletions

@@ -1656,6 +1656,172 @@ static void ggml_compute_forward_mul_mat_id(
     }
 }
 
+// ggml_compute_forward_delta_net
+
+static void ggml_compute_forward_delta_net(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0]; // query
+    const struct ggml_tensor * src1 = dst->src[1]; // key
+    const struct ggml_tensor * src2 = dst->src[2]; // value
+    const struct ggml_tensor * src3 = dst->src[3]; // gate
+    const struct ggml_tensor * src4 = dst->src[4]; // beta
+    const struct ggml_tensor * src5 = dst->src[5]; // state
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F32);
+    GGML_ASSERT(src3->type == GGML_TYPE_F32);
+    GGML_ASSERT(src4->type == GGML_TYPE_F32);
+    GGML_ASSERT(src5->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS;
+    GGML_TENSOR_LOCALS(int64_t, ne3, src3, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb3, src3, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne4, src4, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb4, src4, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne5, src5, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb5, src5, nb);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t S        = src0->ne[0]; // head dimension
+    const int64_t H        = src0->ne[1]; // number of heads
+    const int64_t n_tokens = src0->ne[2];
+    const int64_t n_seqs   = src0->ne[3];
+
+    GGML_ASSERT(ne00 == S && ne01 == H && ne02 == n_tokens && ne03 == n_seqs);
+    GGML_ASSERT(ne10 == S && ne11 == H && ne12 == n_tokens && ne13 == n_seqs);
+    GGML_ASSERT(ne20 == S && ne21 == H && ne22 == n_tokens && ne23 == n_seqs);
+    GGML_ASSERT(ne30 == S && ne31 == H && ne32 == n_tokens && ne33 == n_seqs);
+    GGML_ASSERT(ne40 == H && ne41 == n_tokens && ne42 == n_seqs && ne43 == 1);
+    GGML_ASSERT(ne50 == S && ne51 == S && ne52 == H && ne53 == n_seqs);
+
+    // Get operation parameters
+    bool use_qk_l2norm = ggml_get_op_params_i32(dst, 1) != 0;
+    float scale;
+    memcpy(&scale, ((int32_t *) dst->op_params) + 4, sizeof(float));
+
+    GGML_ASSERT(ne0 == S * H);
+    GGML_ASSERT(ne1 == n_tokens + S * n_seqs);
+
+    // Parallelize over sequences and heads
+    const int64_t n_total      = n_seqs * H;
+    const int64_t n_per_thread = (n_total + nth - 1) / nth;
+    const int64_t n_start      = ith * n_per_thread;
+    const int64_t n_end        = MIN(n_start + n_per_thread, n_total);
+
+    for (int64_t n = n_start; n < n_end; ++n) {
+        const int64_t seq_idx  = n / H;
+        const int64_t head_idx = n % H;
+
+        // Get pointers to current sequence and head
+        float * q_ptr     = (float *) ((char *) src0->data + seq_idx * nb03 + head_idx * nb01);
+        float * k_ptr     = (float *) ((char *) src1->data + seq_idx * nb13 + head_idx * nb11);
+        float * v_ptr     = (float *) ((char *) src2->data + seq_idx * nb23 + head_idx * nb21);
+        float * g_ptr     = (float *) ((char *) src3->data + seq_idx * nb33 + head_idx * nb31);
+        float * beta_ptr  = (float *) ((char *) src4->data + seq_idx * nb43);
+        float * state_ptr = (float *) ((char *) src5->data + seq_idx * nb53 + head_idx * nb51);
+
+        float * out_ptr       = (float *) ((char *) dst->data + n * ne0 * sizeof(float));
+        float * new_state_ptr = out_ptr + n_tokens * S;
+
+        // Apply L2 normalization if requested
+        if (use_qk_l2norm) {
+            // Normalize query and key
+            for (int64_t t = 0; t < n_tokens; ++t) {
+                float q_sum = 0.0f, k_sum = 0.0f;
+                for (int64_t s = 0; s < S; ++s) {
+                    float q_val = q_ptr[t * nb02 / sizeof(float) + s];
+                    float k_val = k_ptr[t * nb12 / sizeof(float) + s];
+                    q_sum += q_val * q_val;
+                    k_sum += k_val * k_val;
+                }
+                float q_norm = sqrtf(q_sum + 1e-6f);
+                float k_norm = sqrtf(k_sum + 1e-6f);
+
+                for (int64_t s = 0; s < S; ++s) {
+                    q_ptr[t * nb02 / sizeof(float) + s] /= q_norm;
+                    k_ptr[t * nb12 / sizeof(float) + s] /= k_norm;
+                }
+            }
+        }
+
+        // Apply scaling to query
+        for (int64_t i = 0; i < n_tokens * S; ++i) {
+            q_ptr[i] *= scale;
+        }
+
+        // Apply sigmoid to beta
+        float * beta_sigmoid = (float *) alloca(n_tokens * sizeof(float));
+        for (int64_t t = 0; t < n_tokens; ++t) {
+            beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb42 / sizeof(float)]));
+        }
+
+        // Complete implementation of gated delta rule
+        // Based on torch_recurrent_gated_delta_rule from the reference implementation
+
+        // Process each token sequentially for recurrent computation
+        for (int64_t t = 0; t < n_tokens; ++t) {
+            // Get pointers to current token data
+            float * q_t = q_ptr + t * (nb02 / sizeof(float));
+            float * k_t = k_ptr + t * (nb12 / sizeof(float));
+            float * v_t = v_ptr + t * (nb22 / sizeof(float));
+            float * g_t = g_ptr + t * (nb32 / sizeof(float));
+
+            // Apply exponential to gate and multiply by beta
+            float g_exp  = expf(g_t[0]); // g is per-head, not per-dimension
+            float beta_t = beta_sigmoid[t];
+
+            // Update recurrent state: state = state * g_exp
+            for (int64_t i = 0; i < S * S; ++i) {
+                state_ptr[i] *= g_exp;
+            }
+
+            // Compute kv_mem = (state * k_t^T).sum(dim=-1)
+            // This is a matrix-vector multiplication: state[S×S] @ k_t[S]
+            float kv_mem[S];
+            for (int64_t i = 0; i < S; ++i) {
+                kv_mem[i] = 0.0f;
+                for (int64_t j = 0; j < S; ++j) {
+                    kv_mem[i] += state_ptr[i * S + j] * k_t[j];
+                }
+            }
+
+            // Compute delta = (v_t - kv_mem) * beta_t
+            float delta[S];
+            for (int64_t i = 0; i < S; ++i) {
+                delta[i] = (v_t[i] - kv_mem[i]) * beta_t;
+            }
+
+            // Update state: state = state + k_t * delta^T
+            // This is an outer product: k_t[S] ⊗ delta[S]
+            for (int64_t i = 0; i < S; ++i) {
+                for (int64_t j = 0; j < S; ++j) {
+                    state_ptr[i * S + j] += k_t[i] * delta[j];
+                }
+            }
+
+            // Compute output: out = (state * q_t^T).sum(dim=-1)
+            // This is a matrix-vector multiplication: state[S×S] @ q_t[S]
+            float * out_t = out_ptr + t * S;
+            for (int64_t i = 0; i < S; ++i) {
+                out_t[i] = 0.0f;
+                for (int64_t j = 0; j < S; ++j) {
+                    out_t[i] += state_ptr[i * S + j] * q_t[j];
+                }
+            }
+        }
+
+        // Copy final state to new_state
+        memcpy(new_state_ptr, state_ptr, S * S * sizeof(float));
+    }
+}
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {

@@ -1998,6 +2164,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);

@@ -2291,6 +2461,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
+       case GGML_OP_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;
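Note (not part of the commit): per head, the token loop in ggml_compute_forward_delta_net implements the following gated delta rule recurrence, after the optional L2 normalization of q and k and the pre-scaling of q (notation mine, matching the loops):

$$
\begin{aligned}
\bar S_t &= e^{g_t}\, S_{t-1} && \text{(gate decay of the state)} \\
\delta_t &= \sigma(b_t)\,\bigl(v_t - \bar S_t k_t\bigr) && \text{(delta rule error, } \sigma \text{ = logistic sigmoid of raw beta)} \\
S_t &= \bar S_t + k_t\,\delta_t^{\top} && \text{(rank-1 state update)} \\
o_t &= S_t\, q_t && \text{(output readout)}
\end{aligned}
$$

where $S_t \in \mathbb{R}^{S \times S}$ is the per-head recurrent state carried across tokens and written back as new_state at the end of the sequence.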

0 commit comments
