
Commit 51a0b9d

Authored by Wang, Yi A <yi.a.wang@intel.com>
IPEX support FP8 kvcache/softcap/slidingwindow (#3144)
* IPEX support FP8 kvcache
* add kvcache dtype
* add softcap and slidingwindow
* kv scale in pageattn
* remove triton installation, will be installed with torch
* install xelink lib
* softcap default -1.0

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent f208ba6 commit 51a0b9d

File tree: 3 files changed, +87 -20 lines changed


Dockerfile_intel

Lines changed: 8 additions & 7 deletions
@@ -98,7 +98,8 @@ ENV HF_HOME=/data \
 
 
 WORKDIR /usr/src
-RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
+RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install server
 COPY proto proto
@@ -116,8 +117,8 @@ ENV TORCH_LLM_ALLREDUCE=1
 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
@@ -180,13 +181,13 @@ RUN case ${TARGETPLATFORM} in \
 
 RUN conda install -c conda-forge gperftools mkl
 
-RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
-RUN pip install triton==3.1.0 py-libnuma
+RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu
+RUN pip install triton==3.2.0 py-libnuma
 
 WORKDIR /usr/src
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
 
 
 ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
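
As a quick, purely illustrative sanity check (not part of the commit), the upgraded stack inside either image can be verified from Python; the exact +xpu/+cpu build suffixes depend on which set of wheels above was installed:

# Illustrative only: confirm the versions pinned by Dockerfile_intel above.
import torch
import intel_extension_for_pytorch as ipex

print(torch.__version__)         # expected 2.7.0 build (+xpu or +cpu)
print(ipex.__version__)          # expected 2.7.10+xpu on XPU, 2.7.0+cpu on CPU
print(torch.xpu.is_available())  # True only on an Intel XPU host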

server/text_generation_server/layers/attention/ipex.py

Lines changed: 42 additions & 7 deletions
@@ -8,7 +8,10 @@
     BLOCK_SIZE,
 )
 
-SUPPORTS_WINDOWING = False
+if ATTENTION == "flashdecoding-ipex":
+    SUPPORTS_WINDOWING = True
+else:
+    SUPPORTS_WINDOWING = False
 
 
 def attention(
@@ -25,13 +28,19 @@ def attention(
     causal: bool = True,
     softcap: Optional[float] = None,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
+        if softcap is None:
+            softcap = -1.0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -45,8 +54,18 @@ def attention(
             causal,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.functional.varlen_attention(
             query.contiguous() if query.device.type == "xpu" else query,
             key.contiguous() if key.device.type == "xpu" else key,
@@ -80,12 +99,16 @@ def paged_attention(
     softcap: Optional[float] = None,
     window_size_left: Optional[int] = -1,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
-
     out = torch.empty_like(query)
-
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
+        if softcap is None:
+            softcap = -1.0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -99,9 +122,19 @@ def paged_attention(
             True,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
@@ -114,6 +147,8 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     return out
 
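
The attention changes above follow a single convention: the torch dtype of the KV cache is mapped to the string the IPEX kernels expect ("auto", "fp8_e5m2" or "fp8_e4m3"), a missing softcap is passed as -1.0 (treated as "disabled" per this commit), and only a left sliding window is exposed, with window_size_right forced to 0 whenever a left window is set. A minimal standalone sketch of that convention, with hypothetical helper names that are not part of the commit:

from typing import Optional, Tuple

import torch

def ipex_kv_cache_dtype(cache_dtype: torch.dtype) -> str:
    # Map the cache's torch dtype to the kv_cache_dtype string passed to IPEX.
    if cache_dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if cache_dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"

def ipex_softcap_and_window(
    softcap: Optional[float], window_size_left: int
) -> Tuple[float, int, int]:
    # -1.0 disables logit softcapping; the right window is 0 when a left
    # (sliding) window is active and -1 (unlimited) otherwise.
    softcap = -1.0 if softcap is None else softcap
    window_size_right = -1 if window_size_left == -1 else 0
    return softcap, window_size_left, window_size_right

# Example: an FP8 e5m2 cache, no softcap, sliding window of 4096 tokens.
print(ipex_kv_cache_dtype(torch.float8_e5m2))   # fp8_e5m2
print(ipex_softcap_and_window(None, 4096))      # (-1.0, 4096, 0)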

server/text_generation_server/layers/attention/kv_cache.py

Lines changed: 37 additions & 6 deletions
@@ -68,15 +68,20 @@ def __init__(
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
             if not (
                 (ATTENTION == "flashinfer" and SYSTEM == "cuda")
-                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
+                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
+                or (ATTENTION == "flashdecoding-ipex")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
+            if device.type == "cpu" and dtype == torch.float8_e4m3fn:
+                raise ValueError(
+                    "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
@@ -133,15 +138,16 @@ def can_scale(self, kv_scales: KVScales) -> bool:
             return False
         elif self.dtype == torch.float8_e4m3fn and (
             (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
-            or (ATTENTION == "paged" and SYSTEM == "rocm")
+            or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
+            or (ATTENTION == "flashdecoding-ipex")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX",
             )
             return False
 
@@ -207,8 +213,20 @@ def store(
         elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
             import intel_extension_for_pytorch as ipex
 
+            kv_cache_dtype = "auto"
+            if key_cache.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            if key_cache.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-                key, value, key_cache, value_cache, slots
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_cache_dtype=kv_cache_dtype,
+                k_scale=kv_scales.key_scale_cpu,
+                v_scale=kv_scales.value_scale_cpu,
             )
         else:
             paged_reshape_and_cache(
@@ -267,8 +285,21 @@ def paged_reshape_and_cache(
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
 
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e5m2:
+            kv_cache_dtype = "fp8_e5m2"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8_e4m3"
+
         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
     else:
         raise NotImplementedError(
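
Restated as a simplified, hypothetical sketch (the real checks live in KVCache.__init__ and KVCache.can_scale above): on IPEX, FP8 KV caches are accepted for paged attention and flashdecoding-ipex, float8_e4m3fn caches are rejected on the CPU device, and KV scales are only honored for float8_e4m3fn caches.

import torch

def ipex_fp8_kv_cache_supported(
    attention: str, device_type: str, dtype: torch.dtype
) -> bool:
    # Simplified restatement of the IPEX branches added in KVCache.__init__.
    if dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return True  # non-FP8 caches are not affected by this commit
    if attention not in ("paged", "flashdecoding-ipex"):
        return False
    if device_type == "cpu" and dtype == torch.float8_e4m3fn:
        return False  # e4m3 FP8 cache is not supported on IPEX CPU
    return True

def ipex_fp8_scales_used(dtype: torch.dtype) -> bool:
    # KVCache.can_scale only applies scales for float8_e4m3fn caches.
    return dtype == torch.float8_e4m3fn

print(ipex_fp8_kv_cache_supported("flashdecoding-ipex", "xpu", torch.float8_e4m3fn))  # True
print(ipex_fp8_kv_cache_supported("paged", "cpu", torch.float8_e4m3fn))               # False
print(ipex_fp8_scales_used(torch.float8_e5m2))                                        # False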
