Commit b834b4c

[USAGE] Improve error handling for weight initialization in Unquantized… (vllm-project#20321)
Signed-off-by: Rafael Marcelino Koike <rafael.koike@oracle.com>
Signed-off-by: Rafael Koike <koike.rafael@gmail.com>
1 parent 740f064 commit b834b4c

2 files changed: +43 -8 lines changed

vllm/attention/layer.py

Lines changed: 21 additions & 4 deletions
@@ -25,7 +25,7 @@
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import GiB_bytes, direct_register_custom_op
 
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
@@ -225,9 +225,26 @@ def __init__(
             ).parallel_config.pipeline_parallel_size)
         ]
 
-        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
-        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
-        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
+        try:
+            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+        except torch.cuda.OutOfMemoryError as e:
+            logger.error(
+                "Failed to initialize attention q/k/v range constants: %s", e)
+            if torch.cuda.is_available():
+                logger.debug("CUDA device: %s", torch.cuda.current_device())
+                logger.debug("Allocated: %.2f GiB",
+                             torch.cuda.memory_allocated() / GiB_bytes)
+                logger.debug("Reserved: %.2f GiB",
+                             torch.cuda.memory_reserved() / GiB_bytes)
+            raise RuntimeError(
+                "Failed to initialize q/k/v range constants. "
+                "This may be caused by insufficient memory to allocate "
+                "kv cache.") from e
 
     def forward(
         self,
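
For reference, here is the pattern this hunk introduces, as a minimal standalone sketch. The helper name make_range_constant is hypothetical, and GIB stands in for vllm.utils.GiB_bytes (1 << 30); the rest mirrors the diff above.

import logging

import torch

logger = logging.getLogger(__name__)
GIB = 1 << 30  # bytes per GiB; stands in for vllm.utils.GiB_bytes


def make_range_constant(value: float) -> torch.Tensor:
    # Hypothetical helper mirroring the q/k/v range initialization above.
    # Even a tiny allocation can raise OutOfMemoryError when the current
    # device's memory is already exhausted during model construction.
    try:
        return torch.tensor(value, dtype=torch.float32)
    except torch.cuda.OutOfMemoryError as e:
        # Log the failure and allocator state, then re-raise with a more
        # actionable message; `from e` keeps the original traceback chained.
        logger.error("Failed to initialize range constant: %s", e)
        if torch.cuda.is_available():
            logger.debug("CUDA device: %s", torch.cuda.current_device())
            logger.debug("Allocated: %.2f GiB",
                         torch.cuda.memory_allocated() / GIB)
            logger.debug("Reserved: %.2f GiB",
                         torch.cuda.memory_reserved() / GIB)
        raise RuntimeError(
            "Failed to initialize q/k/v range constants. This may be "
            "caused by insufficient memory to allocate kv cache.") from e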

vllm/model_executor/layers/linear.py

Lines changed: 22 additions & 4 deletions
@@ -29,6 +29,7 @@
 # yapf: enable
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import GiB_bytes
 
 logger = init_logger(__name__)
 
@@ -190,10 +191,27 @@ def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        # This method creates unquantized linear weights.
+        # The weights are not quantized, and they are not sharded.
+        # The amount of memory allocated for the weights is
+        # sum(output_partition_sizes) * input_size_per_partition.
+        try:
+            weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                           input_size_per_partition,
+                                           dtype=params_dtype),
+                               requires_grad=False)
+        except torch.cuda.OutOfMemoryError as e:
+            logger.error("Failed to create unquantized linear weights: %s", e)
+            if torch.cuda.is_available():
+                logger.debug("CUDA device: %s", torch.cuda.current_device())
+                logger.debug("Allocated: %.2f GiB",
+                             torch.cuda.memory_allocated() / GiB_bytes)
+                logger.debug("Reserved: %.2f GiB",
+                             torch.cuda.memory_reserved() / GiB_bytes)
+            raise RuntimeError(
+                "Failed to create unquantized linear weights. "
+                "This may be caused by insufficient memory to allocate "
+                "the weight.") from e
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
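
The new comment block pins down the allocation size: sum(output_partition_sizes) * input_size_per_partition elements, which is why this torch.empty call is a natural place to catch OOM for large layers. A quick back-of-the-envelope check, using hypothetical shapes not taken from any particular model:

# Hypothetical sizes for a merged projection; 2 bytes per parameter
# covers torch.float16 / torch.bfloat16.
output_partition_sizes = [4096, 4096, 4096]
input_size_per_partition = 4096
bytes_per_param = 2

n_params = sum(output_partition_sizes) * input_size_per_partition
size_gib = n_params * bytes_per_param / (1 << 30)
print(f"{n_params} params, {size_gib:.2f} GiB")  # ~0.09 GiB for this layer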
