@@ -1656,6 +1656,172 @@ static void ggml_compute_forward_mul_mat_id(
     }
 }
 
+// ggml_compute_forward_delta_net
+
+static void ggml_compute_forward_delta_net(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0]; // query
+    const struct ggml_tensor * src1 = dst->src[1]; // key
+    const struct ggml_tensor * src2 = dst->src[2]; // value
+    const struct ggml_tensor * src3 = dst->src[3]; // gate
+    const struct ggml_tensor * src4 = dst->src[4]; // beta
+    const struct ggml_tensor * src5 = dst->src[5]; // state
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F32);
+    GGML_ASSERT(src3->type == GGML_TYPE_F32);
+    GGML_ASSERT(src4->type == GGML_TYPE_F32);
+    GGML_ASSERT(src5->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS;
+    GGML_TENSOR_LOCALS(int64_t, ne3, src3, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb3, src3, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne4, src4, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb4, src4, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne5, src5, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb5, src5, nb);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t S        = src0->ne[0]; // head dimension
+    const int64_t H        = src0->ne[1]; // number of heads
+    const int64_t n_tokens = src0->ne[2];
+    const int64_t n_seqs   = src0->ne[3];
+
+    GGML_ASSERT(ne00 == S && ne01 == H && ne02 == n_tokens && ne03 == n_seqs);
+    GGML_ASSERT(ne10 == S && ne11 == H && ne12 == n_tokens && ne13 == n_seqs);
+    GGML_ASSERT(ne20 == S && ne21 == H && ne22 == n_tokens && ne23 == n_seqs);
+    GGML_ASSERT(ne30 == S && ne31 == H && ne32 == n_tokens && ne33 == n_seqs);
+    GGML_ASSERT(ne40 == H && ne41 == n_tokens && ne42 == n_seqs && ne43 == 1);
+    GGML_ASSERT(ne50 == S && ne51 == S && ne52 == H && ne53 == n_seqs);
+
+    // Get operation parameters
+    bool use_qk_l2norm = ggml_get_op_params_i32(dst, 1) != 0;
+    float scale;
+    memcpy(&scale, ((int32_t *) dst->op_params) + 4, sizeof(float));
+
+    GGML_ASSERT(ne0 == S * H);
+    GGML_ASSERT(ne1 == n_tokens + S * n_seqs);
+
+    // Parallelize over sequences and heads
+    const int64_t n_total = n_seqs * H;
+    const int64_t n_per_thread = (n_total + nth - 1) / nth;
+    const int64_t n_start = ith * n_per_thread;
+    const int64_t n_end = MIN(n_start + n_per_thread, n_total);
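+    // Each thread owns a contiguous chunk of (sequence, head) pairs; e.g. with
+    // n_seqs = 2, H = 8 and nth = 4, each of the 4 threads processes 4 of the 16 pairs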
+
+    for (int64_t n = n_start; n < n_end; ++n) {
+        const int64_t seq_idx  = n / H;
+        const int64_t head_idx = n % H;
+
+        // Get pointers to current sequence and head
+        float * q_ptr = (float *) ((char *) src0->data + seq_idx * nb03 + head_idx * nb01);
+        float * k_ptr = (float *) ((char *) src1->data + seq_idx * nb13 + head_idx * nb11);
+        float * v_ptr = (float *) ((char *) src2->data + seq_idx * nb23 + head_idx * nb21);
+        float * g_ptr = (float *) ((char *) src3->data + seq_idx * nb33 + head_idx * nb31);
+        // beta is laid out as [H, n_tokens, n_seqs] and the state as [S, S, H, n_seqs],
+        // so their sequence/head offsets use nb42/nb40 and nb53/nb52 respectively
+        float * beta_ptr  = (float *) ((char *) src4->data + seq_idx * nb42 + head_idx * nb40);
+        float * state_ptr = (float *) ((char *) src5->data + seq_idx * nb53 + head_idx * nb52);
+
+        float * out_ptr = (float *) ((char *) dst->data + n * ne0 * sizeof(float));
+        float * new_state_ptr = out_ptr + n_tokens * S;
+
+        // Apply L2 normalization if requested
+        if (use_qk_l2norm) {
+            // Normalize query and key
+            for (int64_t t = 0; t < n_tokens; ++t) {
+                float q_sum = 0.0f, k_sum = 0.0f;
+                for (int64_t s = 0; s < S; ++s) {
+                    float q_val = q_ptr[t * nb02 / sizeof(float) + s];
+                    float k_val = k_ptr[t * nb12 / sizeof(float) + s];
+                    q_sum += q_val * q_val;
+                    k_sum += k_val * k_val;
+                }
+                float q_norm = sqrtf(q_sum + 1e-6f);
+                float k_norm = sqrtf(k_sum + 1e-6f);
+
+                for (int64_t s = 0; s < S; ++s) {
+                    q_ptr[t * nb02 / sizeof(float) + s] /= q_norm;
+                    k_ptr[t * nb12 / sizeof(float) + s] /= k_norm;
+                }
+            }
+        }
+
+        // Apply scaling to query
+        for (int64_t t = 0; t < n_tokens; ++t) {
+            for (int64_t s = 0; s < S; ++s) {
+                q_ptr[t * nb02 / sizeof(float) + s] *= scale;
+            }
+        }
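+        // Note: the normalization and the scaling above are applied in place, writing back
+        // into the query (src0) and key (src1) buffers; both assume the S elements of a
+        // head row are contiguous in memory (nb00 == nb10 == sizeof(float))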
+
+        // Apply sigmoid to beta
+        float * beta_sigmoid = (float *) alloca(n_tokens * sizeof(float));
+        for (int64_t t = 0; t < n_tokens; ++t) {
+            beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb41 / sizeof(float)]));
+        }
+
+        // Complete implementation of the gated delta rule
+        // Based on torch_recurrent_gated_delta_rule from the reference implementation
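+        //
+        // Per token t, with state M (S×S), per-head gate g_t and beta_t = sigmoid(beta_t):
+        //   M      <- exp(g_t) * M               (decay)
+        //   kv_mem <- M k_t                      (value currently stored for key k_t)
+        //   delta  <- beta_t * (v_t - kv_mem)    (delta-rule correction)
+        //   M      <- M + delta k_t^T            (rank-1 update)
+        //   o_t    <- M q_t                      (readout)
+        // For S = 1 this collapses to the scalar recurrence
+        //   m <- exp(g_t) * m;  m <- m + beta_t * (v_t - m * k_t) * k_t;  o_t = m * q_t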
+
+        // Process each token sequentially for recurrent computation
+        for (int64_t t = 0; t < n_tokens; ++t) {
+            // Get pointers to current token data
+            float * q_t = q_ptr + t * (nb02 / sizeof(float));
+            float * k_t = k_ptr + t * (nb12 / sizeof(float));
+            float * v_t = v_ptr + t * (nb22 / sizeof(float));
+            float * g_t = g_ptr + t * (nb32 / sizeof(float));
+
+            // Apply exponential to the gate and look up the sigmoided beta for this token
+            float g_exp = expf(g_t[0]); // g is per-head, not per-dimension
+            float beta_t = beta_sigmoid[t];
+
+            // Update recurrent state: state = state * g_exp
+            for (int64_t i = 0; i < S * S; ++i) {
+                state_ptr[i] *= g_exp;
+            }
+
+            // Compute kv_mem = (state * k_t^T).sum(dim=-1)
+            // This is a matrix-vector multiplication: state[S×S] @ k_t[S]
+            float kv_mem[S];
+            for (int64_t i = 0; i < S; ++i) {
+                kv_mem[i] = 0.0f;
+                for (int64_t j = 0; j < S; ++j) {
+                    kv_mem[i] += state_ptr[i * S + j] * k_t[j];
+                }
+            }
+
+            // Compute delta = (v_t - kv_mem) * beta_t
+            float delta[S];
+            for (int64_t i = 0; i < S; ++i) {
+                delta[i] = (v_t[i] - kv_mem[i]) * beta_t;
+            }
+
+            // Update state: state = state + delta * k_t^T
+            // This is an outer product delta[S] ⊗ k_t[S], so reading the state back
+            // with k_t (as in kv_mem above) moves the stored value towards v_t
+            for (int64_t i = 0; i < S; ++i) {
+                for (int64_t j = 0; j < S; ++j) {
+                    state_ptr[i * S + j] += delta[i] * k_t[j];
+                }
+            }
+
+            // Compute output: out = (state * q_t^T).sum(dim=-1)
+            // This is a matrix-vector multiplication: state[S×S] @ q_t[S]
+            float * out_t = out_ptr + t * S;
+            for (int64_t i = 0; i < S; ++i) {
+                out_t[i] = 0.0f;
+                for (int64_t j = 0; j < S; ++j) {
+                    out_t[i] += state_ptr[i * S + j] * q_t[j];
+                }
+            }
+        }
+
+        // Copy final state to new_state
+        memcpy(new_state_ptr, state_ptr, S * S * sizeof(float));
+    }
+}
+
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -1998,6 +2164,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2291,6 +2461,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;