Commit 24f6104

remove state dict compress and disk decompress
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 891da51 commit 24f6104
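
This commit removes the legacy state-dict `compress()` and disk-based `decompress()` pathways from `ModelCompressor`, leaving the in-memory model pathways (`decompress_model`, visible in the hunk context below, together with `update_config`) as the supported route. A minimal usage sketch of that remaining route follows; the import path and the `from_pretrained_model` / `compress_model` signatures are assumptions based on the current compressed-tensors API, not part of this commit.

# Hedged sketch, not part of this commit: compress and decompress a model in
# memory with the remaining ModelCompressor pathways. The import path and the
# from_pretrained_model / compress_model signatures are assumptions.
from compressed_tensors.compressors import ModelCompressor
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/quantized-model")

compressor = ModelCompressor.from_pretrained_model(model)  # None if nothing to compress
if compressor is not None:
    compressor.compress_model(model)    # compress parameters module by module, in place
    # ... model.save_pretrained(save_dir); compressor.update_config(save_dir) ...
    compressor.decompress_model(model)  # restore dense weights, statuses set to FROZEN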

File tree

4 files changed: +13, -469 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 1 addition & 266 deletions
@@ -13,17 +13,12 @@
 # limitations under the License.
 
 import json
-import logging
-import operator
 import os
-import re
-from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union
 
 import compressed_tensors
 import torch
-import transformers
 from compressed_tensors.base import (
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
@@ -39,26 +34,20 @@
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
-    apply_quantization_config,
-    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
     get_execution_device,
     get_offloaded_device,
-    get_safetensors_folder,
-    has_offloaded_params,
     register_offload_parameter,
-    update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
 from compressed_tensors.utils.match import match_named_modules
-from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -71,8 +60,6 @@
 
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 if TYPE_CHECKING:
     # dummy type if not available from transformers
@@ -488,153 +475,6 @@ def decompress_model(self, model: Module):
 
             module.quantization_status = QuantizationStatus.FROZEN
 
-    # ----- state dict compression pathways ----- #
-
-    def compress(
-        self,
-        model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
-        show_progress: bool = False,
-    ) -> Dict[str, Tensor]:
-        """
-        Compresses a dense state dict or model with sparsity and/or quantization
-
-        :param model: uncompressed model to compress
-        :param state_dict: optional uncompressed state_dict to insert into model
-        :return: compressed state dict
-        """
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        if self.quantization_compressor is not None:
-            module_to_scheme = map_module_to_scheme(model)
-            # Note - compress only supports one compression format atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
-            state_dict = quant_compressor.compress(
-                state_dict,
-                names_to_scheme=module_to_scheme,
-                show_progress=show_progress,
-            )
-
-            # TODO: consider sparse compression to also be compression
-            if self.quantization_config.format != CompressionFormat.dense.value:
-                self.quantization_config.quantization_status = (
-                    QuantizationStatus.COMPRESSED
-                )
-
-        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = {
-                module_name
-                for module_name, _module in match_named_modules(
-                    model=model,
-                    targets=self.sparsity_config.targets,
-                    ignore=self.sparsity_config.ignore,
-                )
-            }
-            state_dict = self.sparsity_compressor.compress(
-                state_dict,
-                compression_targets=sparse_compression_targets,
-                show_progress=show_progress,
-            )
-
-        # HACK: Override the dtype_byte_size function in transformers to
-        # support float8 types. Fix is posted upstream
-        # https://github.com/huggingface/transformers/pull/30488
-        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-        return state_dict
-
-    # ----- disk decompression pathways ----- #
-
-    def decompress(self, model_path: str, model: Module):
-        """
-        Overwrites the weights in model with weights decompressed from model_path
-
-        :param model_path: path to compressed weights
-        :param model: pytorch model to load decompressed weights into
-
-        Note: decompress makes use of both _replace_sparsity_weights and
-        _replace_weights. The variations in these methods are a result of the subtle
-        variations between the sparsity and quantization compressors. Specifically,
-        quantization compressors return not just the decompressed weight, but the
-        quantization parameters (e.g scales, zero_point) whereas sparsity compressors
-        only return the decompressed weight.
-
-        """
-        model_path = get_safetensors_folder(model_path)
-        sparse_decompressed = False
-        quant_compressor = (
-            next(iter(self.quantization_compressor.values()))
-            if self.quantization_compressor is not None
-            else None
-        )
-
-        if (
-            self.sparsity_compressor is not None
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            # note - decompress only supports one compressor atm
-            params_to_ignore = None
-            if quant_compressor is not None:
-                params_to_ignore = quant_compressor.compression_param_names
-            # Sparse decompression is applied on the model_path
-            # The compressor will try and load any quantization parameters as well
-            # params_to_skip_load will skip over quantization params from being loaded
-            dense_gen = self.sparsity_compressor.decompress(
-                model_path, params_to_skip_load=params_to_ignore
-            )
-            self._replace_sparsity_weights(dense_gen, model)
-            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
-            sparse_decompressed = True
-
-        if quant_compressor is not None:
-            # Temporarily set quantization status to FROZEN to prevent
-            # quantization during apply_quantization_config. This ensures
-            # that the dtypes of the weights are not unintentionally updated.
-            # The status is restored after quantization params are loaded.
-
-            with override_quantization_status(
-                self.quantization_config, QuantizationStatus.FROZEN
-            ):
-                apply_quantization_config(model, self.quantization_config)
-                names_to_scheme: Set[QuantizationScheme] = {
-                    name: getattr(module, "quantization_scheme")
-                    for name, module in model.named_modules()
-                    if getattr(module, "quantization_scheme", None) is not None
-                }
-                # Load activation scales/zp or any other quantization parameters
-                # Conditionally load the weight quantization parameters if we have a
-                # dense compressor or if a sparsity compressor has already been applied
-                load_weight_qparams = sparse_decompressed or isinstance(
-                    quant_compressor, DenseCompressor
-                )
-                load_pretrained_quantization_parameters(
-                    model,
-                    model_path,
-                    # TODO: all weight quantization params will be moved to the
-                    # compressor in a follow-up including initialization
-                    load_weight_qparams=load_weight_qparams,
-                )
-            model_path_or_state_dict = (
-                model.state_dict() if sparse_decompressed else model_path
-            )
-
-            dense_gen = quant_compressor.decompress(
-                model_path_or_state_dict, names_to_scheme=names_to_scheme
-            )
-            # TODO: all weight quantization params will be moved to the compressor
-            # to prevent duplicate parameter updates in update_parameter_data
-            self._replace_weights(
-                dense_gen, model, load_weight_qparams=not load_weight_qparams
-            )
-
-            def freeze_quantization_status(module):
-                module.quantization_status = QuantizationStatus.FROZEN
-
-            model.apply(freeze_quantization_status)
-            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
-
 
     def update_config(self, save_directory: str):
         """
         Update the model config located at save_directory with compression configs
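
The removed decompress() docstring above explains why two replacement helpers existed: sparsity compressors yield only dense weights, while quantization compressors also yield quantization parameters such as scales and zero-points. A rough sketch of the two generator shapes, inferred from the _replace_sparsity_weights / _replace_weights helpers removed in the next hunk (the fake_* names are illustrative only, not library APIs):

# Hedged sketch of the two decompression generator shapes implied by the
# removed replacement helpers; fake_* names are illustrative, not library APIs.
import torch

def fake_sparsity_decompress():
    # sparsity compressors yield (full parameter name, dense tensor)
    yield "model.layers.0.self_attn.q_proj.weight", torch.randn(8, 8)

def fake_quant_decompress():
    # quantization compressors yield (module path, {param name: tensor}),
    # including quantization parameters such as scales and zero-points
    yield "model.layers.0.self_attn.q_proj", {
        "weight": torch.randn(8, 8),
        "weight_scale": torch.tensor(0.02),
        "weight_zero_point": torch.tensor(0, dtype=torch.int8),
    }

for name, data in fake_sparsity_decompress():
    print(name, data.shape)          # one dense weight per parameter name

for mod_path, params in fake_quant_decompress():
    print(mod_path, sorted(params))  # weight plus its quantization parameters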
@@ -688,79 +528,6 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
-        """
-        Replace the weights of the model with the
-        provided dense weights.
-
-        This method iterates over the dense_weight_generator and
-        updates the corresponding weights in the model. If a parameter
-        name does not exist in the model, it will be skipped.
-
-        :param dense_weight_generator (generator): A generator that yields
-            tuples of (name, data), where 'name' is the parameter name and
-            'data' is the updated param data
-        :param model: The model whose weights are to be updated.
-        """
-        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            split_name = name.split(".")
-            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
-            module = operator.attrgetter(prefix)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-            delattr(module, param_name)
-            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
-            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
-            register_offload_parameter(module, param_name, param)
-
-    def _replace_weights(
-        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
-    ):
-        """
-        Replace the weights of the model with the
-        provided dense weights.
-
-        This method iterates over the dense_weight_generator and
-        updates the corresponding weights in the model. If a parameter
-        name does not exist in the model, it will be skipped.
-
-        :param dense_weight_generator (generator): A generator that yields
-            tuples of (name, data), where 'name' is the parameter name and
-            'data' is the updated param data
-        :param model: The model whose weights are to be updated.
-        """
-
-        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            module = operator.attrgetter(mod_path)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-
-            for param_name, param_data in data.items():
-                if hasattr(module, param_name):
-                    # If compressed, will have an incorrect dtype for transformers >4.49
-                    # TODO: we can also just skip initialization of scales/zp if in
-                    # decompression in init to be consistent with loading which happens
-                    # later as well however, update_data does a good shape check -
-                    # should be moved to the compressor
-
-                    if param_name == "weight":
-                        delattr(module, param_name)
-                        requires_grad = param_data.dtype in (
-                            torch.float16,
-                            torch.float32,
-                            torch.bfloat16,
-                        )
-                        param = torch.nn.Parameter(
-                            param_data.to(device), requires_grad=requires_grad
-                        )
-                        register_offload_parameter(module, param_name, param)
-                    elif load_weight_qparams:
-                        # Should already be registered to the correct device for
-                        # for scales/zero-points
-                        update_parameter_data(module, param_data, param_name)
-
 
 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
@@ -775,35 +542,3 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
             and module.quantization_scheme.weights is not None
         )
     }
-
-
-# HACK: Override the dtype_byte_size function in transformers to support float8 types
-# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
-def new_dtype_byte_size(dtype):
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
-@contextmanager
-def override_quantization_status(
-    config: QuantizationConfig, status: QuantizationStatus
-):
-    """
-    Within this context, the quantization status will be set to the
-    supplied status. After the context exits, the original status
-    will be restored.
-
-    :param config: the quantization config to override
-    :param status: the status to temporarily set
-    """
-    original_status = config.quantization_status
-    config.quantization_status = status
-    try:
-        yield
-    finally:
-        config.quantization_status = original_status
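
For reference, the deleted new_dtype_byte_size hack simply parsed the bit width out of the dtype's string form, so float8 dtypes (whose names do not end in the bit count) still resolve to a byte size. A standalone sketch of that computation; it assumes a torch build that ships float8 dtypes.

# Hedged sketch of the computation performed by the removed new_dtype_byte_size
# hack; requires torch >= 2.1 for the float8 dtypes used in the example calls.
import re
import torch

def dtype_byte_size(dtype: torch.dtype) -> float:
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))  # e.g. "torch.float8_e4m3fn" -> "8"
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8

print(dtype_byte_size(torch.float8_e4m3fn))  # 1
print(dtype_byte_size(torch.bfloat16))       # 2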
