5 files changed, +23 −10 lines (server/text_generation_server)

Dockerfile_intel
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
 
 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -98,9 +98,7 @@ ENV HF_HOME=/data \
 
 
 WORKDIR /usr/src
-RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
-
-RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install server
 COPY proto proto
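
A quick sanity check that the XPU wheels were actually picked up inside the built image (a sketch, not part of this diff; assumes python is on the image's PATH):

import torch
import torchvision

# Wheels from the XPU index are expected to carry an "+xpu" local version
# suffix, e.g. "2.6.0+xpu" (an assumption about the wheel tags).
print(torch.__version__)
print(torchvision.__version__)

# torch.xpu reports whether an Intel GPU and its driver stack are usable.
print(hasattr(torch, "xpu") and torch.xpu.is_available())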

server/text_generation_server/models/__init__.py
@@ -201,7 +201,8 @@
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
+
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
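
For context, a simplified sketch of how a platform flag like import_utils.SYSTEM can be derived, so that the new `or SYSTEM == "ipex"` clause enables the flash transformers backend on Intel systems even when CUDA is absent (assumed logic, not TGI's exact detector):

import importlib.util
import torch

def detect_system() -> str:
    # ROCm builds of torch expose a non-None torch.version.hip.
    if torch.version.hip is not None:
        return "rocm"
    if torch.cuda.is_available():
        return "cuda"
    # Intel platforms ship intel_extension_for_pytorch (IPEX).
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return "ipex"
    return "cpu"

SYSTEM = detect_system()
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"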

server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -12,7 +12,7 @@
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
-
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -115,8 +115,11 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
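
The behavioral change above: each rank now binds to its own XPU (xpu:{rank}) instead of the bare "xpu" device, and falls back to CPU when IPEX is present without a usable GPU. A standalone sketch of the pattern (the select_device helper is hypothetical):

import torch

def select_device(rank: int) -> torch.device:
    # One accelerator per process, indexed by the process rank.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device(f"xpu:{rank}")
    # IPEX systems can be CPU-only; degrade gracefully.
    return torch.device("cpu")

print(select_device(rank=0))  # cuda:0, xpu:0, or cpu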

server/text_generation_server/models/transformers_flash_vlm.py
@@ -14,6 +14,7 @@
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 import torch.nn.functional as F
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -174,8 +175,11 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device("xpu")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(

server/text_generation_server/utils/dist.py
@@ -73,6 +73,13 @@ def initialize_torch_distributed():
         if SYSTEM == "ipex":
             import intel_extension_for_pytorch as ipex
 
+            if torch.xpu.is_available():
+                assert (
+                    WORLD_SIZE <= torch.xpu.device_count()
+                ), "Each process is one xpu"
+                device = RANK % torch.xpu.device_count()
+                torch.xpu.set_device(device)
+
             ipex.distributed.init_process_group(
                 backend="ccl",
                 world_size=WORLD_SIZE,
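
Here each rank pins itself to a distinct XPU before the CCL process group is created, using the usual RANK % device_count mapping. A toy restatement in plain Python (hypothetical assign_device helper, runnable without an XPU):

def assign_device(rank: int, world_size: int, device_count: int) -> int:
    # One process per device; fail fast if the job is oversubscribed.
    assert world_size <= device_count, "Each process is one xpu"
    return rank % device_count

# With 4 ranks on a 4-XPU host, ranks map to devices 0..3:
print([assign_device(r, 4, 4) for r in range(4)])  # [0, 1, 2, 3]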