|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import os |
3 | 4 | import random |
4 | 5 |
|
5 | 6 | import pytest |
6 | 7 | import torch |
7 | 8 |
|
| 9 | +from tests.utils import multi_gpu_test |
8 | 10 | from vllm import _custom_ops as ops |
| 11 | +from vllm.distributed import ( |
| 12 | + init_distributed_environment, |
| 13 | + initialize_model_parallel, |
| 14 | + tensor_model_parallel_all_gather, |
| 15 | + tensor_model_parallel_all_reduce, |
| 16 | +) |
| 17 | +from vllm.distributed.parallel_state import ( |
| 18 | + get_tensor_model_parallel_world_size, |
| 19 | +) |
9 | 20 | from vllm.lora.ops.triton_ops import fused_moe_lora |
10 | 21 | from vllm.platforms import current_platform |
| 22 | +from vllm.utils.network_utils import get_open_port |
11 | 23 |
|
12 | 24 |
|
13 | 25 | @pytest.fixture(autouse=True) |
@@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel( |
122 | 134 | max_loras, |
123 | 135 | num_experts, |
124 | 136 | block_size, |
| 137 | + fully_sharded=False, |
| 138 | + offset=0, |
125 | 139 | ): |
126 | 140 | max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) |
127 | 141 | max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) |
@@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel( |
195 | 209 | config["NUM_STAGES"], |
196 | 210 | config["SPLIT_K"], |
197 | 211 | mul_routed_weight, |
| 212 | + fully_sharded=fully_sharded, |
| 213 | + offset=offset, |
198 | 214 | ) |
199 | 215 |
|
200 | | - return output |
201 | | - |
202 | 216 |
|
203 | 217 | def use_torch( |
204 | 218 | hidden_states, |
@@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel( |
317 | 331 | ) |
318 | 332 |
|
319 | 333 | torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) |
| 334 | + |
| 335 | + |
| 336 | +@multi_gpu_test(num_gpus=2) |
| 337 | +@pytest.mark.parametrize("num_tokens", [100]) |
| 338 | +@pytest.mark.parametrize("top_k_num", [6]) |
| 339 | +@pytest.mark.parametrize("num_experts", [64]) |
| 340 | +@pytest.mark.parametrize("max_loras", [4]) |
| 341 | +@pytest.mark.parametrize("N", [1408]) |
| 342 | +@pytest.mark.parametrize("K", [2048]) |
| 343 | +@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) |
| 344 | +@pytest.mark.parametrize("block_size", [16]) |
| 345 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 346 | +@pytest.mark.parametrize("seed", SEED) |
| 347 | +@pytest.mark.parametrize("column_parallel", [True, False]) |
| 348 | +def test_fused_moe_lora_kernel_fully_sharded( |
| 349 | + num_tokens, |
| 350 | + top_k_num, |
| 351 | + num_experts, |
| 352 | + max_loras, |
| 353 | + N, |
| 354 | + K, |
| 355 | + max_lora_rank, |
| 356 | + block_size, |
| 357 | + dtype, |
| 358 | + seed, |
| 359 | + column_parallel, |
| 360 | +): |
| 361 | + current_platform.seed_everything(seed) |
| 362 | + # the number of randomly generated sentences. |
| 363 | + num_sequences = 10 |
| 364 | + # generate data |
| 365 | + topk_ids, topk_weights, token_lora_mapping = sample_data( |
| 366 | + num_tokens, num_sequences, max_loras, num_experts, top_k_num |
| 367 | + ) |
| 368 | + |
| 369 | + def run_torch_spawn(fn, nprocs): |
| 370 | + torch.multiprocessing.spawn( |
| 371 | + fn, |
| 372 | + args=( |
| 373 | + nprocs, |
| 374 | + f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", |
| 375 | + dtype, |
| 376 | + seed, |
| 377 | + N, |
| 378 | + K, |
| 379 | + num_tokens, |
| 380 | + topk_ids, |
| 381 | + topk_weights, |
| 382 | + token_lora_mapping, |
| 383 | + max_lora_rank, |
| 384 | + top_k_num, |
| 385 | + max_loras, |
| 386 | + num_experts, |
| 387 | + block_size, |
| 388 | + column_parallel, |
| 389 | + ), |
| 390 | + nprocs=nprocs, |
| 391 | + ) |
| 392 | + |
| 393 | + run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2) |
| 394 | + |
| 395 | + |
| 396 | +def use_fused_moe_lora_kernel_tensor_parallel( |
| 397 | + local_rank, |
| 398 | + world_size, |
| 399 | + init_method, |
| 400 | + dtype, |
| 401 | + seed, |
| 402 | + N, |
| 403 | + K, |
| 404 | + num_tokens, |
| 405 | + topk_ids, |
| 406 | + topk_weights, |
| 407 | + token_lora_mapping, |
| 408 | + max_lora_rank, |
| 409 | + top_k_num, |
| 410 | + max_loras, |
| 411 | + num_experts, |
| 412 | + block_size, |
| 413 | + column_parallel, |
| 414 | +): |
| 415 | + def _get_shard_slice(shard_size): |
| 416 | + return slice(local_rank * shard_size, (local_rank + 1) * shard_size) |
| 417 | + |
| 418 | + current_platform.seed_everything(seed) |
| 419 | + |
| 420 | + device = torch.device(f"cuda:{local_rank}") |
| 421 | + torch.cuda.set_device(device) |
| 422 | + torch.set_default_device(device) |
| 423 | + torch.set_default_dtype(dtype) |
| 424 | + |
| 425 | + init_distributed_environment( |
| 426 | + world_size=world_size, |
| 427 | + rank=local_rank, |
| 428 | + local_rank=local_rank, |
| 429 | + distributed_init_method=init_method, |
| 430 | + ) |
| 431 | + initialize_model_parallel(world_size, 1) |
| 432 | + tp_size = get_tensor_model_parallel_world_size() |
| 433 | + |
| 434 | + input_dim = K if column_parallel else N |
| 435 | + output_dim = N if column_parallel else K |
| 436 | + |
| 437 | + # init lora weights |
| 438 | + lora_a = torch.rand( |
| 439 | + ( |
| 440 | + max_loras, |
| 441 | + num_experts, |
| 442 | + max_lora_rank, |
| 443 | + input_dim, |
| 444 | + ), |
| 445 | + dtype=dtype, |
| 446 | + ) |
| 447 | + lora_b = torch.rand( |
| 448 | + ( |
| 449 | + max_loras, |
| 450 | + num_experts, |
| 451 | + output_dim, |
| 452 | + max_lora_rank, |
| 453 | + ), |
| 454 | + dtype=dtype, |
| 455 | + ) |
| 456 | + |
| 457 | + hidden_states = torch.rand( |
| 458 | + ( |
| 459 | + num_tokens, |
| 460 | + input_dim, |
| 461 | + ), |
| 462 | + dtype=dtype, |
| 463 | + ) |
| 464 | + |
| 465 | + output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype) |
| 466 | + topk_ids = topk_ids.to(device) |
| 467 | + topk_weights = topk_weights.to(device) |
| 468 | + token_lora_mapping = token_lora_mapping.to(device) |
| 469 | + |
| 470 | + ref_output = use_torch( |
| 471 | + hidden_states, |
| 472 | + token_lora_mapping, |
| 473 | + topk_ids, |
| 474 | + [lora_a], |
| 475 | + [lora_b], |
| 476 | + top_k_num, |
| 477 | + ) |
| 478 | + |
| 479 | + if column_parallel: |
| 480 | + # Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim, |
| 481 | + # and Lora B is sliced along the output dim |
| 482 | + lora_a_shard_size = max_lora_rank // tp_size |
| 483 | + lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :] |
| 484 | + max_lora_rank = lora_a_shard_size |
| 485 | + offset = 0 |
| 486 | + |
| 487 | + lora_b_shard_size = output_dim // tp_size |
| 488 | + lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :] |
| 489 | + output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous() |
| 490 | + else: |
| 491 | + # Row parallel (e.g. down proj): LoRA A is sliced along the input dim, |
| 492 | + # and LoRA B is sliced along the output dim |
| 493 | + lora_a_shard_size = input_dim // tp_size |
| 494 | + lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)] |
| 495 | + hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)] |
| 496 | + |
| 497 | + lora_b_shard_size = output_dim // tp_size |
| 498 | + lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :] |
| 499 | + offset = lora_b_shard_size * local_rank |
| 500 | + |
| 501 | + use_fused_moe_lora_kernel( |
| 502 | + topk_ids, |
| 503 | + topk_weights, |
| 504 | + token_lora_mapping, |
| 505 | + max_lora_rank, |
| 506 | + top_k_num, |
| 507 | + [lora_a], |
| 508 | + [lora_b], |
| 509 | + hidden_states, |
| 510 | + output, |
| 511 | + max_loras, |
| 512 | + num_experts, |
| 513 | + block_size, |
| 514 | + fully_sharded=True, |
| 515 | + offset=offset, |
| 516 | + ) |
| 517 | + |
| 518 | + if column_parallel: |
| 519 | + output = tensor_model_parallel_all_gather(output) |
| 520 | + else: |
| 521 | + output = tensor_model_parallel_all_reduce(output) |
| 522 | + |
| 523 | + torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1) |
0 commit comments