
Commit 459fbde

transformers flash llm/vlm enabling in ipex (#3152)
* transformers flash llm/vlm enabling in xpu

  Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ipex cpu could also support in function

  Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 449cee4 commit 459fbde

File tree

5 files changed (+23, -10)

Dockerfile_intel

Lines changed: 2 additions & 4 deletions
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
 
 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -98,9 +98,7 @@ ENV HF_HOME=/data \
 
 
 WORKDIR /usr/src
-RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
-
-RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install server
 COPY proto proto
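The torch pin moves from the test XPU channel to the stable XPU wheel index and gains a matching torchvision, and the separate triton-xpu pin is dropped. A minimal sanity check for the resulting image, a sketch assuming you run it with the container's Python interpreter (not part of the commit):

# Sketch: confirm the XPU build of PyTorch installed above is usable.
# torch.xpu ships with recent PyTorch, so no extra import is needed.
import torch

print("torch:", torch.__version__)  # expect a 2.6.0+xpu build
if hasattr(torch, "xpu") and torch.xpu.is_available():
    print("xpu devices:", torch.xpu.device_count())
else:
    print("no XPU visible; the IPEX CPU path would be used instead")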

server/text_generation_server/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -201,7 +201,8 @@
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
+
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
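With this change the flash transformers backend is considered available on any IPEX system (XPU or CPU), not only on CUDA machines. The try block is truncated in this hunk; a sketch of the full gating pattern, where the except branch is an assumption about the surrounding module rather than something shown in the diff:

# Sketch of the gating pattern around the changed line. The fallback in the
# except branch is assumed, not part of this diff.
import torch

from text_generation_server.utils.import_utils import SYSTEM

FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"

try:
    from text_generation_server.models.transformers_flash_causal_lm import (
        TransformersFlashCausalLM,
    )
except ImportError:
    # assumed: disable the backend if the optional import fails
    FLASH_TRANSFORMERS_BACKEND = False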

server/text_generation_server/models/transformers_flash_causal_lm.py

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,7 @@
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
-
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -115,8 +115,11 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
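Device selection now branches on SYSTEM instead of probing torch.xpu directly, so an IPEX CPU build falls through to CPU instead of raising, and XPU devices are indexed by rank like their CUDA counterparts. The same logic as a standalone helper, a sketch in which pick_device is a hypothetical name:

# Sketch of the selection order introduced above: CUDA first, then a
# rank-indexed XPU under IPEX, then plain CPU under IPEX, else an error.
import torch


def pick_device(rank: int, system: str) -> torch.device:
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if system == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            # one process per card: each rank binds to its own XPU
            return torch.device(f"xpu:{rank}")
        # IPEX CPU-only build: run on CPU rather than failing
        return torch.device("cpu")
    raise ValueError("flash transformers backend needs CUDA or IPEX")


# e.g. pick_device(0, "ipex") gives xpu:0 on an XPU host, cpu otherwise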

server/text_generation_server/models/transformers_flash_vlm.py

Lines changed: 6 additions & 2 deletions
@@ -14,6 +14,7 @@
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 import torch.nn.functional as F
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -174,8 +175,11 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
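This is the same device-selection change as in transformers_flash_causal_lm.py, applied to the VLM model class; the helper sketched after that diff covers both call sites. The two behavioral differences from the old code are the rank-indexed xpu:{rank} device string and the CPU fallback for IPEX builds without a visible XPU.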

server/text_generation_server/utils/dist.py

Lines changed: 7 additions & 0 deletions
@@ -73,6 +73,13 @@ def initialize_torch_distributed():
         if SYSTEM == "ipex":
             import intel_extension_for_pytorch as ipex
 
+            if torch.xpu.is_available():
+                assert (
+                    WORLD_SIZE <= torch.xpu.device_count()
+                ), "Each process is one xpu"
+                device = RANK % torch.xpu.device_count()
+                torch.xpu.set_device(device)
+
             ipex.distributed.init_process_group(
                 backend="ccl",
                 world_size=WORLD_SIZE,
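Before the ccl process group is created, each rank is now pinned to one XPU, and the assert guards against launching more processes than there are cards. The same pinning step as a standalone sketch, assuming RANK and WORLD_SIZE come from the environment the way typical launchers set them (TGI resolves them elsewhere in this module):

# Sketch: pin this process to its XPU before collective init.
# RANK/WORLD_SIZE sourcing here is an assumption for illustration.
import os

import torch

RANK = int(os.environ.get("RANK", "0"))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))

if hasattr(torch, "xpu") and torch.xpu.is_available():
    # one process per card, so world size may not exceed the card count
    assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one xpu"
    torch.xpu.set_device(RANK % torch.xpu.device_count())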
