Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit ec3d38a

Browse files
authored
[LLM Runtime] Magicoder graph (#1053)
1 parent db209b4 commit ec3d38a

27 files changed

+160
-84
lines changed

intel_extension_for_transformers/llm/runtime/graph/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ LLM Runtime supports the following models:
193193
<td>✅</td>
194194
<td>✅</td>
195195
<td>Latest</td>
196+
</tr>
197+
<tr>
198+
<td><a href="https://huggingface.co/ise-uiuc/Magicoder-S-DS-6.7B" target="_blank" rel="noopener noreferrer">Magicoder-6.7B</a></td>
199+
<td>✅</td>
200+
<td>✅</td>
201+
<td>✅</td>
202+
<td>✅</td>
203+
<td>Latest</td>
196204
</tr>
197205
<tr>
198206
<td><a href="https://huggingface.co/bigcode/starcoderbase-1b" target="_blank" rel="noopener noreferrer">StarCoder-1B</a>,

intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2980,7 +2980,7 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*
29802980

29812981
struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
29822982
int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
2983-
bool padding_left, float freq_base) {
2983+
bool padding_left, float freq_base, float freq_scale) {
29842984
NE_ASSERT(n_past >= 0 || n_keep >= 0);
29852985
NE_ASSERT(padding_left);
29862986
bool is_node = false;
@@ -3020,7 +3020,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
30203020

30213021
ne_scratch_load(ctx);
30223022

3023-
ne_set_op_params(result, &freq_base, sizeof(freq_base));
3023+
float params[] = {freq_base, freq_scale};
3024+
ne_set_op_params(result, &params, sizeof(params));
3025+
30243026
result->op = NE_OP_ROPE;
30253027
result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL;
30263028
result->src0 = a;
@@ -3031,18 +3033,20 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
30313033
}
30323034

30333035
struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
3034-
int prompt_size, float freq_base) {
3035-
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base);
3036+
int prompt_size, float freq_base, float freq_scale) {
3037+
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
30363038
}
30373039

30383040
struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
3039-
int prompt_size, float freq_base) {
3040-
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base);
3041+
int prompt_size, float freq_base, float freq_scale) {
3042+
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
30413043
}
30423044

30433045
struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
3044-
int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base) {
3045-
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base);
3046+
int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
3047+
float freq_scale) {
3048+
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
3049+
freq_scale);
30463050
}
30473051

30483052
// ne_rope_back
@@ -3078,13 +3082,16 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
30783082
}
30793083

30803084
struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
3081-
int prompt_size, int* n_padding, float freq_base) {
3082-
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base);
3085+
int prompt_size, int* n_padding, float freq_base, float freq_scale) {
3086+
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
3087+
freq_scale);
30833088
}
30843089

30853090
struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
3086-
int mode, int prompt_size, int* n_padding, float freq_base) {
3087-
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base);
3091+
int mode, int prompt_size, int* n_padding, float freq_base,
3092+
float freq_scale) {
3093+
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
3094+
freq_scale);
30883095
}
30893096

30903097
// ne_alibi
@@ -7867,9 +7874,8 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
78677874
NE_ASSERT(src1->type == NE_TYPE_I32);
78687875
NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params
78697876

7870-
float freq_base = 10000.0f;
7871-
memcpy(&freq_base, dst->op_params, sizeof(float));
7872-
static const float freq_scale = 1.0f;
7877+
const float freq_base = ((float*)(dst->op_params))[0];
7878+
const float freq_scale = 1 / ((float*)(dst->op_params))[1];
78737879

78747880
const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
78757881
const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
@@ -8043,7 +8049,10 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
80438049
// row index used to determine which thread to use
80448050
int ir = 0;
80458051

8046-
const float theta_scale = powf(10000.0, -2.0f / n_dims);
8052+
const float freq_base = ((float*)(dst->op_params))[0];
8053+
const float freq_scale = 1 / ((float*)(dst->op_params))[1];
8054+
8055+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
80478056

80488057
const bool skip = mode & 1;
80498058
const bool is_neox = mode & 2;
@@ -8053,7 +8062,7 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
80538062
NE_ASSERT(("shift RoPE is only implemented for the vanilla mode", !is_shift || !(is_glm || is_neox || skip)));
80548063

80558064
if (is_shift) {
8056-
float theta = n_past;
8065+
float theta = n_past * freq_scale;
80578066
ne_fp16_t* cossin = (dst->opt[0] != NULL) ? dst->opt[0]->data : NULL;
80588067
if (cossin == NULL) {
80598068
cossin = malloc(ne0 * sizeof(ne_fp16_t));
@@ -8098,7 +8107,7 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
80988107
if (ir++ < ir0) continue;
80998108
if (ir > ir1) break;
81008109

8101-
float theta = (float)p;
8110+
float theta = freq_scale * (float)p;
81028111

81038112
if (!is_neox) {
81048113
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -8172,11 +8181,14 @@ static void ne_compute_forward_rope_jblas(const struct ne_compute_params* params
81728181
const int seq_len = dst->ne[1];
81738182
const int head_size = dst->ne[0];
81748183

8184+
const float freq_base = ((float*)(dst->op_params))[0];
8185+
const float freq_scale = 1 / ((float*)(dst->op_params))[1];
8186+
81758187
if (is_shift) {
81768188
ne_fp16_t* cossin = (dst->opt[0] != NULL) ? dst->opt[0]->data : NULL;
81778189
if (cossin == NULL) {
8178-
float theta = n_past;
8179-
const float theta_scale = powf(10000.0, -2.0f / n_dims);
8190+
float theta = n_past * freq_scale;
8191+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
81808192
cossin = malloc(head_size * sizeof(ne_fp16_t));
81818193
for (int i0 = 0; i0 < head_size; i0 += 2) {
81828194
cossin[i0 + 0] = NE_FP32_TO_FP16(cosf(theta));
@@ -10016,7 +10028,7 @@ static void ne_compute_backward(struct ne_context* ctx, struct ne_tensor* tensor
1001610028
const int n_dims = ((int32_t*)src1->data)[1];
1001710029
const int mode = ((int32_t*)src1->data)[2];
1001810030
src0->grad =
10019-
ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0, 10000.0), inplace);
10031+
ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0, 10000.0, 1.0), inplace);
1002010032
}
1002110033
if (src1->grad) {
1002210034
// noop

intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,29 +403,30 @@ NE_API struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_t
403403
// if mode & 4 == 1, especially for glm
404404
// TODO: avoid creating a new tensor every time
405405
NE_API struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
406-
int prompt_size, float freq_base);
406+
int prompt_size, float freq_base, float freq_scale);
407407

408408
// in-place, returns view(a)
409409
NE_API struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
410-
int prompt_size, float freq_base);
410+
int prompt_size, float freq_base, float freq_scale);
411411

412412
// shift all tokens by a give p (n_shift)
413413
// Optionally give a 1d tensor of precomputed interleaved cos/sin value of n_shift*scale^k for k \in [0, n_dims)
414414
NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
415415
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
416-
float freq_base);
416+
float freq_base, float freq_scale);
417417

418418
// rotary position embedding backward, i.e compute dx from dy
419419
// a - dy
420420
NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);
421421

422422
NE_API struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
423-
int mode, int prompt_size, int* n_padding, float freq_base);
423+
int mode, int prompt_size, int* n_padding, float freq_base,
424+
float freq_scale);
424425

425426
// in-place, returns view(a)
426427
NE_API struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past,
427428
int n_dims, int mode, int prompt_size, int* n_padding,
428-
float freq_base);
429+
float freq_base, float freq_scale);
429430

430431
// alibi position embedding
431432
// in-place, returns view(a)

intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,14 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
137137

138138
ne_set_name(query_layer, "query_layer");
139139
query_layer = ne_rope_with_padding_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size,
140-
n_padding.data(), hparams.freq_base);
140+
n_padding.data(), hparams.freq_base, hparams.freq_scale);
141141
query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [bs, heads, qlen, head_size]
142142

143143
ne_tensor* key_layer =
144144
ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur),
145145
cur->nb[1], cur->nb[1] * qlen, head_size * ne_element_size(cur)); // [bs, qlen, heads, head_size]
146146
key_layer = ne_rope_with_padding_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size,
147-
n_padding.data(), hparams.freq_base);
147+
n_padding.data(), hparams.freq_base, hparams.freq_scale);
148148

149149
ne_tensor* value_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size,
150150
3 * head_size * ne_element_size(cur), cur->nb[1], cur->nb[1] * qlen,

intel_extension_for_transformers/llm/runtime/graph/models/chatglm/chatglm2.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,15 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
146146
ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
147147
0); // [N, heads, head_size]
148148
ne_set_name(query_layer, "query_layer");
149-
query_layer = ne_rope_inplace(ctx0, query_layer, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);
149+
query_layer = ne_rope_inplace(ctx0, query_layer, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base,
150+
hparams.freq_scale);
150151

151152
struct ne_tensor* key_layer =
152153
ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
153154
hidden_size * ne_element_size(cur)); // [N, kv_heads, head_size]
154155
ne_set_name(key_layer, "key_layer");
155156
key_layer = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
156-
ctx0, key_layer, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);
157+
ctx0, key_layer, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
157158

158159
struct ne_tensor* value_layer =
159160
ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
@@ -198,7 +199,8 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
198199
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
199200
// in a single eval execution
200201
if (N == 1) cossin_cache = kv_self.cossin;
201-
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
202+
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
203+
hparams.freq_scale);
202204
key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // perm back
203205
}
204206

@@ -253,7 +255,8 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
253255
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
254256
// in a single eval execution
255257
if (N == 1) cossin_cache = kv_self.cossin;
256-
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
258+
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
259+
hparams.freq_scale);
257260
}
258261
value_layer =
259262
ne_view_3d(ctx0, model.layers[il].v_cache, // tensor

intel_extension_for_transformers/llm/runtime/graph/models/falcon/falcon.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in
162162
fused_qkv_row_nb, (n_embd + n_head_kv * head_dim) * ne_element_size(cur));
163163

164164
// using mode = 2 for neox mode
165-
Qcur = ne_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0, hparams.freq_base);
166-
Kcur = ne_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0, hparams.freq_base);
165+
Qcur = ne_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0, hparams.freq_base, hparams.freq_scale);
166+
Kcur = ne_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0, hparams.freq_base, hparams.freq_scale);
167167

168168
// self-attention
169169
const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));

intel_extension_for_transformers/llm/runtime/graph/models/gptj/gptj.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,10 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
186186
Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head, N, batch_size);
187187
Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
188188
}
189-
Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);
189+
Qcur =
190+
ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
190191
Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
191-
ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);
192+
ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
192193
ne_set_name(Qcur, "Qcur");
193194
ne_set_name(Kcur, "Kcur");
194195
ne_set_name(Vcur, "Vcur");
@@ -293,7 +294,8 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
293294
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
294295
// in a single eval execution
295296
if (N == 1) cossin_cache = kv_self.cossin;
296-
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
297+
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
298+
hparams.freq_scale);
297299
}
298300
const auto v_size = kv_cache_info.v_bytes;
299301
V = ne_view_4d(ctx0, kv_self.v, // tensor
@@ -321,7 +323,8 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
321323
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
322324
// a single eval execution
323325
if (N == 1) cossin_cache = kv_self.cossin;
324-
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
326+
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
327+
hparams.freq_scale);
325328
K = ne_permute(ctx0, K, 0, 2, 1, 3);
326329
}
327330
} else {
@@ -332,7 +335,8 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
332335
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
333336
// a single eval execution
334337
if (N == 1) cossin_cache = kv_self.cossin;
335-
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
338+
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
339+
hparams.freq_scale);
336340
K = ne_permute(ctx0, K, 0, 2, 1, 3);
337341
}
338342

intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i
188188

189189
// using mode = 2 for GPT-NeoX mode
190190
Qcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0,
191-
hparams.freq_base);
191+
hparams.freq_base, hparams.freq_scale);
192192
Kcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Kcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0,
193-
hparams.freq_base);
193+
hparams.freq_base, hparams.freq_scale);
194194
const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
195195
// store key and value to memory
196196
if (!run_mha_reordered) {

0 commit comments

Comments
 (0)