11 | 11 | from peft import LoraConfig as _LoraConfig |
12 | 12 | from torch.distributed import ProcessGroup |
13 | 13 | from text_generation_server.utils.log import log_master |
14 | | - |
15 | | -from text_generation_server.adapters.config import AdapterConfig, ModuleMap |
16 | 14 | from text_generation_server.utils.import_utils import SYSTEM |
| 15 | +from text_generation_server.adapters.config import AdapterConfig, ModuleMap |
17 | 16 | from text_generation_server.utils.kernels import load_kernel |
18 | 17 | from text_generation_server.adapters.weights import ( |
19 | 18 | AdapterBatchMetadata, |
@@ -128,17 +127,27 @@ def __init__( |
128 | 127 | self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 |
129 | 128 | self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 |
130 | 129 |
|
131 | | - self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r) |
132 | 130 | self._is_transposed = False |
| 131 | + if SYSTEM == "ipex": |
| 132 | + self._use_cutlass_shrink = False |
| 133 | + # [num_layers, r, hidden_size] |
| 134 | + weights_a = [w.transpose(0, 1).contiguous() for w in weights_a] |
| 135 | + self._weights_a = torch.stack(weights_a) |
| 136 | + |
| 137 | + # [num_layers, hidden_size, r] |
| 138 | + weights_b = [w.transpose(0, 1).contiguous() for w in weights_b] |
| 139 | + self._weights_b = torch.stack(weights_b) |
| 140 | + else: |
| 141 | + self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r) |
| 142 | + # [num_layers, hidden_size, r] |
| 143 | + weights_a = [ |
| 144 | + punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() |
| 145 | + for w in weights_a |
| 146 | + ] |
| 147 | + self._weights_a = torch.stack(weights_a) |
133 | 148 |
|
134 | | - # [num_layers, hidden_size, r] |
135 | | - weights_a = [ |
136 | | - punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a |
137 | | - ] |
138 | | - self._weights_a = torch.stack(weights_a) |
139 | | - |
140 | | - # [num_layers, r, hidden_size] |
141 | | - self._weights_b = torch.stack(weights_b) |
| 149 | + # [num_layers, r, hidden_size] |
| 150 | + self._weights_b = torch.stack(weights_b) |
142 | 151 |
|
143 | 152 | self.adapter_config = adapter_config |
144 | 153 |
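For readers less familiar with the layout change, here is a self-contained sketch of what the new IPEX branch above does to the per-layer LoRA A matrices; the dimensions are made up and only the transpose-and-stack step mirrors the diff.

```python
# Sketch of the IPEX branch above: each per-layer LoRA A matrix of shape
# [hidden_size, r] is transposed to [r, hidden_size] and the layers are
# stacked into one [num_layers, r, hidden_size] tensor (LoRA B is handled
# symmetrically and ends up as [num_layers, hidden_size, r]).
import torch

num_layers, hidden_size, r = 4, 64, 8  # illustrative sizes only
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]

stacked_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
assert stacked_a.shape == (num_layers, r, hidden_size)
```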
|
@@ -175,7 +184,10 @@ def _transpose_weights(self): |
175 | 184 |
|
176 | 185 | @classmethod |
177 | 186 | def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: |
178 | | - return [BatchLoraWeights] |
| 187 | + if SYSTEM == "ipex": |
| 188 | + return [IPEXBatchLoraWeights] |
| 189 | + else: |
| 190 | + return [BatchLoraWeights] |
179 | 191 |
|
180 | 192 | # prepare pre-loaded lora weights for use in the model. |
181 | 193 | # |
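For reference, a minimal standalone sketch of the SYSTEM-based dispatch introduced in `get_batch_types` above; the stub classes and the `system` parameter are placeholders, only the selection logic mirrors the diff.

```python
# Stub illustration of the dispatch above: the IPEX path avoids the punica
# SGMV-specific setup, so it advertises its own batch implementation.
from typing import List, Type


class BatchLoraWeightsStub: ...
class IPEXBatchLoraWeightsStub(BatchLoraWeightsStub): ...


def get_batch_types(system: str) -> List[Type[BatchLoraWeightsStub]]:
    if system == "ipex":
        return [IPEXBatchLoraWeightsStub]
    return [BatchLoraWeightsStub]


assert get_batch_types("ipex") == [IPEXBatchLoraWeightsStub]
assert get_batch_types("cuda") == [BatchLoraWeightsStub]
```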
@@ -245,17 +257,20 @@ def prepare_weights( |
245 | 257 | lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale |
246 | 258 |
|
247 | 259 | # pad lora ranks to be compatible with sgmv |
248 | | - lora_a_list = [ |
249 | | - punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list |
250 | | - ] |
251 | | - lora_b_list = [ |
252 | | - punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list |
253 | | - ] |
254 | | - |
255 | | - if lora_a_list: |
256 | | - # update rank if it was padded |
257 | | - padded_rank = lora_a_list[0].size(1) |
258 | | - config.r = padded_rank |
| 260 | + if SYSTEM != "ipex": |
| 261 | + lora_a_list = [ |
| 262 | + punica_sgmv.pad_rank(w, dim=1, world_size=world_size) |
| 263 | + for w in lora_a_list |
| 264 | + ] |
| 265 | + lora_b_list = [ |
| 266 | + punica_sgmv.pad_rank(w, dim=0, world_size=world_size) |
| 267 | + for w in lora_b_list |
| 268 | + ] |
| 269 | + |
| 270 | + if lora_a_list: |
| 271 | + # update rank if it was padded |
| 272 | + padded_rank = lora_a_list[0].size(1) |
| 273 | + config.r = padded_rank |
259 | 274 |
|
260 | 275 | return LoraWeights( |
261 | 276 | *shard_lora_weights( |
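To make the skipped step concrete: on non-IPEX systems the rank dimension is padded so the SGMV kernels can consume it, which is why the IPEX path can leave the ranks untouched. Below is a rough illustration of such padding, assuming a hypothetical block size of 16; the actual policy lives in `punica_sgmv.pad_rank`.

```python
# Illustration only: zero-padding the rank dimension of a LoRA A matrix of
# shape [hidden_size, r]. The real padding rules belong to
# punica_sgmv.pad_rank; BLOCK = 16 is an assumed alignment.
import torch
import torch.nn.functional as F

BLOCK = 16
lora_a = torch.randn(64, 10)  # [hidden_size, r] with r = 10

pad = (-lora_a.size(1)) % BLOCK
padded = F.pad(lora_a, (0, pad))  # pad the last (rank) dimension on the right
assert padded.shape == (64, 16)
```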
@@ -471,6 +486,115 @@ def load( |
471 | 486 | ) |
472 | 487 |
|
473 | 488 |
|
| 489 | +@dataclass |
| 490 | +class IPEXBatchLoraWeights(BatchLoraWeights): |
| 491 | + @classmethod |
| 492 | + def load( |
| 493 | + self, |
| 494 | + adapter_weights: Dict[int, AdapterWeights], |
| 495 | + meta: AdapterBatchMetadata, |
| 496 | + prefill: bool, |
| 497 | + prefill_head_indices: Optional[torch.Tensor], |
| 498 | + ) -> Optional["BatchLoraWeights"]: |
| 499 | + adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()} |
| 500 | + adapter_weights = { |
| 501 | + k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights) |
| 502 | + } |
| 503 | + if not adapter_weights: |
| 504 | + return None |
| 505 | + |
| 506 | + first_weights = next(iter(adapter_weights.values())) |
| 507 | + device = first_weights.weights_a.device |
| 508 | + segment_indices = meta.segment_indices |
| 509 | + |
| 510 | + lora_a = { |
| 511 | + idx: adapter_weights[idx].weights_a |
| 512 | + for idx in segment_indices |
| 513 | + if idx in adapter_weights |
| 514 | + } |
| 515 | + lora_b = { |
| 516 | + idx: adapter_weights[idx].weights_b |
| 517 | + for idx in segment_indices |
| 518 | + if idx in adapter_weights |
| 519 | + } |
| 520 | + adapter_index_configs = { |
| 521 | + idx: adapter_weights[idx].adapter_config |
| 522 | + for idx in segment_indices |
| 523 | + if idx in adapter_weights |
| 524 | + } |
| 525 | + if len(lora_a) != 0: |
| 526 | + lora_a_ptr = torch.stack(list(lora_a.values())) |
| 527 | + if len(lora_b) != 0: |
| 528 | + lora_b_ptr = torch.stack(list(lora_b.values())) |
| 529 | + |
| 530 | + use_sgmv = True if prefill else False |
| 531 | + |
| 532 | + adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} |
| 533 | + |
| 534 | + rank_indices = defaultdict(list) |
| 535 | + for segment_idx, adapter_idx in enumerate(segment_indices): |
| 536 | + if adapter_idx not in adapter_weights: |
| 537 | + continue |
| 538 | + rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) |
| 539 | + |
| 540 | + if prefill_head_indices is not None: |
| 541 | + j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] |
| 542 | + for head_index in prefill_head_indices: |
| 543 | + # j cannot go out of bounds as that would mean there are tokens without corresponding adapters |
| 544 | + if head_index < meta.adapter_segments[j]: |
| 545 | + prefill_head_segment_ends[-1] += 1 |
| 546 | + else: |
| 547 | + prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) |
| 548 | + prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) |
| 549 | + j += 1 |
| 550 | + |
| 551 | + rank_data = {} |
| 552 | + segment_starts = None |
| 553 | + segment_ends = None |
| 554 | + if use_sgmv: |
| 555 | + segment_starts = meta.adapter_segments[:-1] |
| 556 | + segment_ends = meta.adapter_segments[1:] |
| 557 | + if prefill_head_indices is not None: |
| 558 | + segment_starts = prefill_head_segment_starts[:-1] |
| 559 | + segment_ends = prefill_head_segment_ends[1:] |
| 560 | + batch_indices = [ |
| 561 | + adapter_to_segment[idx] for idx in meta.adapter_indices.tolist() |
| 562 | + ] |
| 563 | + for rank, indices in rank_indices.items(): |
| 564 | + adapters_indices = [] |
| 565 | + lora_a_keys = list(lora_a.keys()) |
| 566 | + for segment_idx in batch_indices: |
| 567 | + if segment_idx in indices: |
| 568 | + adapters_indices.append( |
| 569 | + lora_a_keys.index(segment_indices[segment_idx]) |
| 570 | + ) |
| 571 | + else: |
| 572 | + adapters_indices.append(-1) |
| 573 | + adapters_indices = torch.tensor( |
| 574 | + adapters_indices, dtype=torch.int64, device=device |
| 575 | + ) |
| 576 | + if use_sgmv: |
| 577 | + adapters_indices = adapters_indices[segment_starts] |
| 578 | + rank_data[rank] = RankSegments( |
| 579 | + rank=rank, |
| 580 | + tmp_shrink=None, |
| 581 | + tmp_expand=None, |
| 582 | + lora_a_ptr=lora_a_ptr, |
| 583 | + lora_b_ptr=lora_b_ptr, |
| 584 | + segment_starts=segment_starts, |
| 585 | + segment_ends=segment_ends, |
| 586 | + indices=adapters_indices, |
| 587 | + ) |
| 588 | + |
| 589 | + return BatchLoraWeights( |
| 590 | + lora_a=lora_a, |
| 591 | + lora_b=lora_b, |
| 592 | + adapter_index_configs=adapter_index_configs, |
| 593 | + rank_data=rank_data, |
| 594 | + use_sgmv=use_sgmv, |
| 595 | + ) |
| 596 | + |
| 597 | + |
474 | 598 | def get_scaling_factor( |
475 | 599 | lora_alpha: int, |
476 | 600 | r: int, |
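To make the bookkeeping in `IPEXBatchLoraWeights.load` easier to follow, here is a toy walkthrough of the `adapter_to_segment` and `rank_indices` structures it builds; the adapter ids and ranks are invented, and the lookup into `adapter_weights` is simplified away.

```python
# Toy version of the segment bookkeeping in IPEXBatchLoraWeights.load.
# adapter_to_segment maps an adapter id to a segment position (the last one
# if an adapter appears in several segments); rank_indices groups segment
# positions by the adapter's LoRA rank. Ids and ranks are made up.
from collections import defaultdict

segment_indices = [2, 0, 2, 1]       # adapter id assigned to each segment
adapter_ranks = {0: 8, 1: 16, 2: 8}  # hypothetical lora_a_r per adapter

adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
    rank_indices[adapter_ranks[adapter_idx]].append(segment_idx)

print(adapter_to_segment)   # {2: 2, 0: 1, 1: 3}
print(dict(rank_indices))   # {8: [0, 1, 2], 16: [3]}
```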
|