Commit f6005d6

xpu lora support (#3232)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent: 429dcd9 · commit: f6005d6

File tree: 2 files changed (+257, -65 lines)

server/text_generation_server/adapters/lora.py

Lines changed: 147 additions & 23 deletions
@@ -11,9 +11,8 @@
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
-from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
@@ -128,17 +127,27 @@ def __init__(
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
 
-        self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False
+        if SYSTEM == "ipex":
+            self._use_cutlass_shrink = False
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
+
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
+            # [num_layers, hidden_size, r]
+            weights_a = [
+                punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
+                for w in weights_a
+            ]
+            self._weights_a = torch.stack(weights_a)
 
-        # [num_layers, hidden_size, r]
-        weights_a = [
-            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
-        ]
-        self._weights_a = torch.stack(weights_a)
-
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)
 
         self.adapter_config = adapter_config
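
To make the layout change above concrete, here is a small standalone sketch (toy shapes, not part of the commit) of how transposing and stacking the per-layer LoRA matrices produces the [num_layers, r, hidden_size] / [num_layers, hidden_size, r] tensors the IPEX branch stores:

```python
# Toy illustration of the IPEX weight layout above; shapes are made up
# (2 layers, hidden_size=8, rank r=4), only torch is required.
import torch

num_layers, hidden_size, r = 2, 8, 4
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]  # per-layer A: [hidden_size, r]
weights_b = [torch.randn(r, hidden_size) for _ in range(num_layers)]  # per-layer B: [r, hidden_size]

# Transpose each matrix and stack across layers, as the SYSTEM == "ipex" branch does.
stacked_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
stacked_b = torch.stack([w.transpose(0, 1).contiguous() for w in weights_b])

print(stacked_a.shape)  # torch.Size([2, 4, 8])  -> [num_layers, r, hidden_size]
print(stacked_b.shape)  # torch.Size([2, 8, 4])  -> [num_layers, hidden_size, r]
```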

@@ -175,7 +184,10 @@ def _transpose_weights(self):
 
     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]
 
     # prepare pre-loaded lora weights for use in the model.
     #
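
For readers skimming the diff: get_batch_types() now picks the batch-weights implementation per platform, so the rest of the adapter machinery never branches on the backend itself. A hedged, self-contained sketch of that dispatch pattern (class and function names below are illustrative stand-ins, not from the repository):

```python
# Illustrative dispatch-by-platform sketch; these names are hypothetical
# stand-ins for BatchLoraWeights / IPEXBatchLoraWeights.
from typing import List, Type


class BatchWeightsBase: ...
class CudaBatchWeights(BatchWeightsBase): ...
class IpexBatchWeights(BatchWeightsBase): ...


def batch_types_for(system: str) -> List[Type[BatchWeightsBase]]:
    # "ipex" is the flag the diff checks for Intel XPU builds; anything else
    # keeps the existing punica/SGMV path.
    if system == "ipex":
        return [IpexBatchWeights]
    return [CudaBatchWeights]


print(batch_types_for("ipex"))  # [<class '__main__.IpexBatchWeights'>]
```
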
@@ -245,17 +257,20 @@ def prepare_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
-        ]
-
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+        if SYSTEM != "ipex":
+            lora_a_list = [
+                punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
+                for w in lora_a_list
+            ]
+            lora_b_list = [
+                punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
+                for w in lora_b_list
+            ]
+
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank
 
         return LoraWeights(
             *shard_lora_weights(
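
The guard above means rank padding only happens on the punica/SGMV path; the IPEX path keeps the adapter's original rank. A hypothetical illustration of what that kind of padding amounts to (this is not the real punica_sgmv.pad_rank, whose set of supported ranks lives in the kernel; the target rank here is an assumption):

```python
# Hypothetical rank padding, illustrating the non-IPEX branch above.
import torch
import torch.nn.functional as F


def pad_rank_to(w: torch.Tensor, dim: int, target_rank: int) -> torch.Tensor:
    """Zero-pad `dim` of `w` up to `target_rank`; no-op if already large enough."""
    pad_amount = max(0, target_rank - w.size(dim))
    if pad_amount == 0:
        return w
    # F.pad lists pads for the last dimension first:
    # (dimN_left, dimN_right, ..., dim0_left, dim0_right)
    pad = [0, 0] * w.dim()
    pad[2 * (w.dim() - 1 - dim) + 1] = pad_amount
    return F.pad(w, pad)


lora_a = torch.randn(8, 6)                               # [hidden_size, r] with r=6
print(pad_rank_to(lora_a, dim=1, target_rank=8).shape)   # torch.Size([8, 8])
```
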
@@ -471,6 +486,115 @@ def load(
         )
 
 
+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
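
A toy walk-through (made-up adapter ids and segments, not from the commit) of the index bookkeeping in IPEXBatchLoraWeights.load: for each rank bucket, every token's adapter id is translated into the row of the stacked lora_a/lora_b tensors that holds its weights, or -1 when the token's adapter is not in that bucket:

```python
# Toy reproduction of the adapters_indices mapping built per rank bucket above.
import torch

segment_indices = [3, 1]                         # adapter id owning each segment
adapter_indices = torch.tensor([3, 3, 1, 1, 1])  # adapter id per token in the batch
lora_a_keys = [3, 1]                             # stacking order of the adapters' weights

adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}              # {3: 0, 1: 1}
batch_indices = [adapter_to_segment[idx] for idx in adapter_indices.tolist()]   # [0, 0, 1, 1, 1]

indices_for_rank = [0, 1]                        # segments whose adapters share this LoRA rank
adapters_indices = [
    lora_a_keys.index(segment_indices[seg]) if seg in indices_for_rank else -1
    for seg in batch_indices
]
print(adapters_indices)                          # [0, 0, 1, 1, 1]
```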
