diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 72eb613753aa..2147a45d7503 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -26,7 +26,6 @@ from contextlib import contextmanager from dataclasses import dataclass, field from functools import partial -from types import MethodType from typing import TYPE_CHECKING, Any, Optional, Union import torch @@ -313,120 +312,6 @@ class ConversionEntry: GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4 -# Factory function to create LoadedParameter subclasses dynamically -def get_loaded_parameter_class(base_cls): - """ - base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor - Returns a new class that combines the base_cls with LoadedParameterMixin - - """ - - class LoadedParam(base_cls): - _inplace_methods = [ - "add_", - "mul_", - "clamp_", - "zero_", - "fill_", - "normal_", - "uniform_", - "copy_", - "erfinv_", - "log_", - "__getitem__", - "neg_", - "exp_", - "sub_", - ] - - def __new__(cls, from_existing, **kwargs): - if isinstance(from_existing, torch.nn.Parameter): - inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__) - else: - inst = super().__new__(cls, from_existing) - # we store the original object to get it back later on - inst._original = from_existing - # Explicitly override all in-place methods per instance - for method_name in inst._inplace_methods: - setattr(inst, method_name, MethodType(inst._skip, inst)) - - return inst - - def _skip(self, *args, **kwargs): - """Helper to skip in-place operations.""" - return self - - def __repr__(self): - return f"LoadedParameter(data={self.data})" - - @property - def data(self): - return super().data - - @data.setter - def data(self, new): - pass - - def __lt__(self, other): - return torch.Tensor.__lt__(self, other) - - def __le__(self, other): - return torch.Tensor.__le__(self, other) - - def __gt__(self, other): - return torch.Tensor.__gt__(self, other) - - def __ge__(self, other): - return torch.Tensor.__ge__(self, other) - - def __eq__(self, other): - return torch.Tensor.__eq__(self, other) - - def __ne__(self, other): - return torch.Tensor.__ne__(self, other) - - def __iadd__(self, *args, **kwargs): - return self - - def __isub__(self, *args, **kwargs): - return self - - def __imul__(self, *args, **kwargs): - return self - - def __imatmul__(self, *args, **kwargs): - return self - - def __itruediv__(self, *args, **kwargs): - return self - - def __ifloordiv__(self, *args, **kwargs): - return self - - def __imod__(self, *args, **kwargs): - return self - - def __ipow__(self, *args, **kwargs): - return self - - def __iand__(self, *args, **kwargs): - return self - - def __ior__(self, *args, **kwargs): - return self - - def __ixor__(self, *args, **kwargs): - return self - - def __ilshift__(self, *args, **kwargs): - return self - - def __irshift__(self, *args, **kwargs): - return self - - return LoadedParam - - def _materialize_copy(tensor, dtype=None): tensor = tensor[...] 
if dtype is not None: @@ -527,7 +412,6 @@ def set_param_for_module( param_value = param_value.to_local() if param_name not in module_obj._buffers: param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point()) - param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value) # Remove from missing keys (it's either mismatched, or all good) missing_keys.discard(layer_name) diff --git a/src/transformers/generation/watermarking.py b/src/transformers/generation/watermarking.py index da978c3c107e..348dae74929b 100644 --- a/src/transformers/generation/watermarking.py +++ b/src/transformers/generation/watermarking.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCELoss +from .. import initialization as init from ..modeling_utils import PreTrainedModel from ..utils import ModelOutput, logging from .configuration_utils import PreTrainedConfig, WatermarkingConfig @@ -387,7 +388,7 @@ def __init__(self, config): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Parameter): - module.weight.normal_(mean=0.0, std=0.02) + init.normal_(module.weight, mean=0.0, std=0.02) def _compute_posterior( self, diff --git a/src/transformers/initialization.py b/src/transformers/initialization.py new file mode 100644 index 000000000000..27780e9ac719 --- /dev/null +++ b/src/transformers/initialization.py @@ -0,0 +1,191 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
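The new `src/transformers/initialization.py` module that follows wraps every `torch.nn.init` primitive behind an `_is_hf_initialized` check so that parameters already filled from a checkpoint are never re-initialized, and its `guard_torch_init_functions` context manager hot-patches the wrappers into every imported module (because `from torch.nn.init import xavier_uniform_` binds the name at import time). A minimal sketch of the intended behaviour, assuming the wrappers defined below and an importable `transformers.initialization` submodule:

import torch
from transformers import initialization as init

w = torch.nn.Parameter(torch.empty(4, 4))
init.normal_(w, mean=0.0, std=0.02)   # flag not set yet -> the real torch init runs
w._is_hf_initialized = True           # set by the loading code once the checkpoint value is in place
init.normal_(w, mean=0.0, std=0.02)   # no-op: the wrapper returns the tensor untouched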
+import sys +from collections import defaultdict +from contextlib import contextmanager + +import torch + + +# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch +# in context managers +TORCH_INIT_FUNCTIONS = { + "uniform_": torch.nn.init.uniform_, + "normal_": torch.nn.init.normal_, + "constant_": torch.nn.init.constant_, + "ones_": torch.nn.init.ones_, + "zeros_": torch.nn.init.zeros_, + "eye_": torch.nn.init.eye_, + "dirac_": torch.nn.init.dirac_, + "xavier_uniform_": torch.nn.init.xavier_uniform_, + "xavier_normal_": torch.nn.init.xavier_normal_, + "kaiming_uniform_": torch.nn.init.kaiming_uniform_, + "kaiming_normal_": torch.nn.init.kaiming_normal_, + "trunc_normal_": torch.nn.init.trunc_normal_, + "orthogonal_": torch.nn.init.orthogonal_, + "sparse_": torch.nn.init.sparse_, +} + + +def uniform_( + tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator) + return tensor + + +def normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator) + return tensor + + +def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val) + return tensor + + +def ones_(tensor: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["ones_"](tensor) + return tensor + + +def zeros_(tensor: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["zeros_"](tensor) + return tensor + + +def eye_(tensor: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["eye_"](tensor) + return tensor + + +def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups) + return tensor + + +def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator) + return tensor + + +def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator) + return tensor + + +def kaiming_uniform_( + tensor: torch.Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["kaiming_uniform_"]( + tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator + ) + return tensor + + +def kaiming_normal_( + tensor: torch.Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, 
"_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["kaiming_normal_"]( + tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator + ) + return tensor + + +def trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator) + return tensor + + +def orthogonal_( + tensor: torch.Tensor, + gain: float = 1, + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator) + return tensor + + +def sparse_( + tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator) + return tensor + + +def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + with torch.no_grad(): + return tensor.copy_(other) + return tensor + + +@contextmanager +def guard_torch_init_functions(): + """ + Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be + protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded. + + Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure + and for remote code, we also use this context manager. + """ + originals = defaultdict(dict) + try: + # Replace all torch funcs by the ones in this file + for name in TORCH_INIT_FUNCTIONS.keys(): + # Here, we need to check all modules imported, and hot patch all of them, as usually torch does + # something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules, + # where MultiHeadAttention lives), so the function name is binded at import time and just doing + # `setattr(torch.nn.init, name, gloabls()[name])` is thus not enough + for module in sys.modules.values(): + if module and hasattr(module, name): + originals[module][name] = getattr(module, name) + setattr(module, name, globals()[name]) + yield + finally: + # Set back the original functions on all modules + for module, functions in originals.items(): + for name, func in functions.items(): + setattr(module, name, func) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1383d69aea06..965084d0c24a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -44,6 +44,7 @@ from torch.distributions import constraints from torch.utils.checkpoint import checkpoint +from . 
import initialization as init from .configuration_utils import PreTrainedConfig from .conversion_mapping import get_checkpoint_conversion_mapping from .core_model_loading import WeightConverter, convert_and_load_state_dict_in_model, revert_weight_conversion @@ -471,9 +472,7 @@ def _get_tied_weight_keys(module: nn.Module) -> list[str]: tied_weight_keys: list[str] = [] for name, submodule in module.named_modules(): tied = getattr(submodule, "_tied_weights_keys", {}) or {} - tied_weights_dict = list(tied.keys()) - # tied_weights_dict.extend(tied.values()) - tied_weight_keys.extend([f"{name}.{k}" if name else k for k in tied_weights_dict]) + tied_weight_keys.extend([f"{name}.{k}" if name else k for k in tied.keys()]) return tied_weight_keys @@ -901,35 +900,6 @@ def _get_dtype( return config, dtype, dtype_orig -@contextmanager -def guard_nn_init_functions(flag_name: str = "_is_hf_initialized"): - import torch.nn.init as init - - originals = {} - - def make_wrapper(fn): - @wraps(fn) - def wrapped(*args, **kwargs): - # Tensor can come positionally or as a kwarg (e.g. via DeviceContext) - t = args[0] if args else kwargs.get("tensor", kwargs.get("input")) - if t is not None and getattr(t, flag_name, False): - # mimic init.* return convention (returns the tensor) - return t - return fn(*args, **kwargs) # TODO we could set is init here. - - return wrapped - - try: - for name in TORCH_INIT_FUNCTIONS: - if hasattr(init, name): - originals[name] = getattr(init, name) - setattr(init, name, make_wrapper(originals[name])) - yield - finally: - for name, fn in originals.items(): - setattr(init, name, fn) - - class PipelineParallel(Enum): inputs = 0 outputs = 1 @@ -1521,26 +1491,35 @@ def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's modules properly initialized (such as weight initialization). - - This is also used when the user is running distributed code. We add hooks to the modules here, according to - the model's tp_plan! 
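The `post_init` rewrite that follows makes every model collect its children's `all_tied_weights_keys` and prefix them with the child's attribute name, so the top-most model ends up with a single flat target-to-source mapping. A rough illustration of that aggregation step, using hypothetical module and key names rather than anything from the diff:

# Hypothetical child mapping, in the {target: source} form used by `_tied_weights_keys`
child_keys = {"lm_head.weight": "embed_tokens.weight"}
name = "decoder"  # attribute under which the child module lives on the parent
parent_keys = {f"{name}.{k}": f"{name}.{v}" for k, v in child_keys.items()}
print(parent_keys)  # {'decoder.lm_head.weight': 'decoder.embed_tokens.weight'}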
""" - self.init_weights() - self._backward_compatibility_gradient_checkpointing() - + # Attach the different parallel plans and tied weight keys to the top-most model, so that everything is + # easily available self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {} + # Current submodel should register its tied weights keys only if the config is asking for it + if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder: + self.all_tied_weights_keys = {} + else: + self.all_tied_weights_keys = self._tied_weights_keys.copy() if self._tied_weights_keys is not None else {} # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config if self.base_model is self: self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {} self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {} self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {} for name, module in self.named_children(): + # Parallel plans if plan := getattr(module, "_ep_plan", None): self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) if plan := getattr(module, "_tp_plan", None): self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) if plan := getattr(module, "_pp_plan", None): self._pp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) + # Always attach the keys of the children (if the children's config says to NOT tie, then it's empty) + if tied_keys := getattr(module, "all_tied_weights_keys", None): + self.all_tied_weights_keys.update({f"{name}.{k}": f"{name}.{v}" for k, v in tied_keys.copy().items()}) + + # Maybe initialize the weights and tie the keys + self.init_weights() + self._backward_compatibility_gradient_checkpointing() @property def tp_plan(self) -> dict[str, str]: @@ -2275,22 +2254,25 @@ def _init_weights(self, module): """ if hasattr(self.config, "initializer_range"): std = self.config.initializer_range or 0.02 + elif hasattr(self.config, "init_std"): + std = self.config.init_std + elif hasattr(self.config, "initializer_factor"): + std = self.config.initializer_factor else: # 0.02 is the standard default value across the library std = getattr(self.config.get_text_config(), "initializer_range", 0.02) if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)): if getattr(module, "weight", None) is not None: - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): if getattr(module, "weight", None) is not None: - module.weight.normal_(mean=0.0, std=std) - if getattr( - self.config, "pad_token_id", None - ) is not None and self.config.pad_token_id < module.weight.size(0): - module.weight[self.config.pad_token_id].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.MultiheadAttention): # This uses torch's original init module._reset_parameters() @@ -2303,15 +2285,15 @@ def _init_weights(self, module): ): # Norms can exist without weights (in which case they are None from torch 
primitives) if hasattr(module, "weight") and module.weight is not None: - module.weight.fill_(1.0) + init.ones_(module.weight) if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if isinstance(getattr(module, "gate_up_proj", None), nn.Parameter): - module.gate_up_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) if isinstance(getattr(module, "down_proj", None), nn.Parameter): - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) if isinstance(getattr(module, "gate", None), nn.Parameter): - module.gate.normal_(mean=0.0, std=std) + init.normal_(module.gate, mean=0.0, std=std) def _initialize_weights(self, module): """ @@ -2324,7 +2306,7 @@ def _initialize_weights(self, module): module._is_hf_initialized = True @torch.no_grad() - @guard_nn_init_functions() + @init.guard_torch_init_functions() def initialize_weights(self): """ This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models. @@ -2355,12 +2337,7 @@ def smart_apply(self, fn): # Let the magic happen with this simple call self.smart_apply(self._initialize_weights) - def tie_weight_source_and_target( - self, - top_level: "PreTrainedModel", - missing_keys: Optional[set[str]] = None, - module_prefix: str = "", - ): + def tie_weights(self, missing_keys: Optional[set[str]] = None): """ If set in the config, tie the weights between the input embeddings and the output embeddings, and the encoder and decoder. This relies on the `_tied_weights_keys` dict. @@ -2409,22 +2386,18 @@ def __init__(self): If you call this function, it will always tie. There is only 1 tricky case, if all weights are missing, you still want to mention that the ones you tied were missing. """ - mapping = getattr(self, "_tied_weights_keys", None) + # TODO Cyril: using this fixed set of keys (set in post_init()) does not allow to switch the config flag and re-tie + mapping = getattr(self, "all_tied_weights_keys", None) if not isinstance(mapping, dict): return - if ( # we only tie for ourselves, so we look at our config - not self.config.tie_word_embeddings - and not self.config.tie_encoder_decoder # if missing keys is None we init? - ): - return # TODO let's pray this is not too slow :) - top_level_params = dict(top_level.named_parameters(remove_duplicate=False)) | dict( - top_level.named_buffers(remove_duplicate=False) + top_level_params = dict(self.named_parameters(remove_duplicate=False)) | dict( + self.named_buffers(remove_duplicate=False) ) for target_name, source_name in mapping.items(): - source_name = f"^{module_prefix}.{source_name}" if module_prefix else "^" + source_name - target_name = f"^{module_prefix}.{target_name}" if module_prefix else "^" + target_name + source_name = "^" + source_name + target_name = "^" + target_name source_is_there = bool(missing_keys) and not re.search( source_name, "\n".join(missing_keys), flags=re.MULTILINE @@ -2440,10 +2413,10 @@ def __init__(self): for target_n, source_n in zip(target_params, cycle(source_params)): if "." 
in target_n: parent_path, last = target_n.rsplit(".", 1) - parent = top_level.get_submodule(parent_path) + parent = self.get_submodule(parent_path) else: parent_path, last = "", target_n - parent = top_level # top-level + parent = self # top-level setattr(parent, last, top_level_params[source_n]) self._adjust_bias(parent, top_level_params[source_n]) if missing_keys and source_is_there: # test_model_weights_reload_no_missing_tied_weights @@ -2470,19 +2443,6 @@ def _adjust_bias(self, output_embeddings, input_embeddings): if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): output_embeddings.out_features = input_embeddings.num_embeddings - def tie_weights(self, missing_keys: Optional[set[str]] = None): - """ - Recursively (for all submodels) tie all the weights of the model. - """ - # Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call - if missing_keys is None: - # called from `post_init` - self.tie_weight_source_and_target(self, missing_keys, "") - else: # this is from_pretrained, so its not called on every sub module - for module_prefix, module in self.named_modules(): - if isinstance(module, PreTrainedModel): - module.tie_weight_source_and_target(self, missing_keys, module_prefix) - def _get_no_split_modules(self, device_map: str): """ Get the modules of the model that should not be spit when using device_map. We iterate through the modules to @@ -4301,19 +4261,23 @@ def _load_pretrained_model( for k in all_pointer: k.__exit__(None, None, None) + # Mark tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency) + model.mark_tied_weights_as_initialized() + # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when # loading the weights as they are not in the loaded state dict) - # Remove tied weights keys and etc miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys} model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer) - # correctly initialize the missing (and potentially mismatched) keys + # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag) model._initialize_missing_keys(is_quantized) - missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys, False) - # We make sure we tie after _init_. 
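Stripped of the regex expansion and missing-key bookkeeping, the tying loop above boils down to re-pointing the target attribute at the source parameter so both names share one tensor. A simplified sketch with illustrative module names (not the actual model layout):

import torch.nn as nn

model = nn.ModuleDict({"embed": nn.Embedding(10, 8), "head": nn.Linear(8, 10, bias=False)})
params = dict(model.named_parameters(remove_duplicate=False))
target, source = "head.weight", "embed.weight"
parent_path, last = target.rsplit(".", 1)
setattr(model.get_submodule(parent_path), last, params[source])
assert model["head"].weight is model["embed"].weight  # tied: a single shared Parameter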
We need the missing keys to remove the ones we do tie, and not random remove + # Tie the weights model.tie_weights(missing_keys) + # Adjust missing and unexpected keys + missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys) + # Post-processing for tensor parallelism if device_mesh is not None: # When using TP, the device map is a single device for all parameters @@ -4556,7 +4520,10 @@ def _move_missing_keys_from_meta_to_cpu( return model_state_dict = self.state_dict() - for key in missing_keys: + # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway) + # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they + # will be re-initialized for nothing (which can be quite long) + for key in missing_keys - self.all_tied_weights_keys.keys(): param = model_state_dict[key] # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them if param.device == torch.device("meta"): @@ -4585,23 +4552,13 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None: else: self.initialize_weights() - # Replace the loaded parameters class back to nn.Parameter (they were changed to easily skip initialization - # when performed in-place on the tensors) - for name, p in list(self.named_parameters()) + list(self.named_buffers()): - # We get back the original parameter that we stored in _original. This attribute was created when we initialized LoadedParam when loading the checkpoints. - if hasattr(p, "_original"): - if "." in name: - module, name = name.rsplit(".", 1) - module = self.get_submodule(module) - else: - module = self - setattr(module, name, p._original) - def _adjust_missing_and_unexpected_keys( - self, missing_keys: set[str], unexpected_keys: set[str], loading_task_model_from_base_state_dict: bool + self, missing_keys: set[str], unexpected_keys: set[str] ) -> tuple[set[str], set[str]]: """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid raising unneeded warnings/errors. + Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to + be tied anyway. """ # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to @@ -4625,14 +4582,22 @@ def _adjust_missing_and_unexpected_keys( if ignore_unexpected_regex is not None: unexpected_keys = {key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None} - # Note: only the unexpected keys should remove the added prefix here, to correctly display the original name - # in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model - if loading_task_model_from_base_state_dict: - _prefix = f"{self.base_model_prefix}." - unexpected_keys = {k.removeprefix(_prefix) for k in unexpected_keys} - return missing_keys, unexpected_keys + def mark_tied_weights_as_initialized(self): + """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them + later as they will be tied (overwritten) anyway. 
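Taken together, the hunks above give `_load_pretrained_model` the following high-level ordering; this is a condensed, argument-elided view of the calls shown in the diff, not literal code:

model.mark_tied_weights_as_initialized()          # flag tied params so the guarded inits skip them
model._move_missing_keys_from_meta_to_cpu(...)    # tied keys are excluded from the move as well
model._initialize_missing_keys(is_quantized)      # only touches params without `_is_hf_initialized`
model.tie_weights(missing_keys)                   # re-point tied params at their sources
missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)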
+ This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so + running inits on them is very costly.""" + for tied_param in self.all_tied_weights_keys.keys(): + # It's always a proper weight except for 2 or 3 old models where it's a regex or module set to None + # -> just skip it in those cases (they will just re-init before tying, so they loose the added optimization) + try: + param = self.get_parameter(tied_param) + param._is_hf_initialized = True + except AttributeError: + pass + def get_parameter_or_buffer(self, target: str): """ Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index b56f70d1e90d..4c7e81ed59cc 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -29,6 +29,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask @@ -410,9 +411,9 @@ def _init_weights(self, module): super()._init_weights(module) if hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): - module.logit_scale.fill_(math.log(1 / 0.07)) + init.constant_(module.logit_scale, math.log(1 / 0.07)) elif isinstance(module, Aimv2AttentionPoolingHead): - module.cls_token.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git a/src/transformers/models/aimv2/modular_aimv2.py b/src/transformers/models/aimv2/modular_aimv2.py index e65f5ab5a612..4262efd5b45b 100644 --- a/src/transformers/models/aimv2/modular_aimv2.py +++ b/src/transformers/models/aimv2/modular_aimv2.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -453,9 +454,9 @@ def _init_weights(self, module): super()._init_weights(module) if hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): - module.logit_scale.fill_(math.log(1 / 0.07)) + init.constant_(module.logit_scale, math.log(1 / 0.07)) elif isinstance(module, Aimv2AttentionPoolingHead): - module.cls_token.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index ac4337e4f269..bb3679064b0c 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...masking_utils import create_bidirectional_mask from ...modeling_outputs import ( @@ -306,18 +307,19 @@ class AlbertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, AlbertMLMHead): - module.bias.zero_() + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 6ec6d72a4771..84ac48e7675c 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -828,20 +829,21 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, AlignModel): - nn.init.xavier_uniform_(module.text_projection.weight) - module.text_projection.bias.zero_() - module.temperature.fill_(self.config.temperature_init_value) + init.xavier_uniform_(module.text_projection.weight) + init.zeros_(module.text_projection.bias) + init.constant_(module.temperature, self.config.temperature_init_value) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 1c45432d5f20..8cb6bbb3ba2f 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -776,43 +777,44 @@ def _init_weights(self, module): factor = self.config.initializer_factor if isinstance(module, AltCLIPVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, AltCLIPAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, AltCLIPMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, AltCLIPModel): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_factor) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_factor) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_factor) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_factor) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) class AltCLIPVisionTransformer(nn.Module): diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index f430972d61b7..76f3da90b1da 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -25,6 
+25,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -589,7 +590,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, AriaGroupedExpertsGemm): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -613,7 +614,7 @@ class AriaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, AriaProjector): - nn.init.trunc_normal_(module.query, std=self.config.initializer_range) + init.trunc_normal_(module.query, std=self.config.initializer_range) class AriaTextRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index dce6584ed6af..28b62c390f7d 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig @@ -1200,7 +1201,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, AriaGroupedExpertsGemm): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) class AriaPreTrainedModel(LlamaPreTrainedModel): @@ -1213,7 +1214,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, AriaProjector): - nn.init.trunc_normal_(module.query, std=self.config.initializer_range) + init.trunc_normal_(module.query, std=self.config.initializer_range) class AriaTextModel(LlamaModel): diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 1f270b96aa95..0f76c15e1f08 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput @@ -304,22 +305,16 @@ class ASTPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, ASTEmbeddings): - module.cls_token.zero_() - module.position_embeddings.zero_() - module.distillation_token.zero_() + init.zeros_(module.cls_token) + init.zeros_(module.position_embeddings) + init.zeros_(module.distillation_token) @auto_docstring diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 7947fca148be..1e6691d1bc6a 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -264,28 +264,6 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - @torch.no_grad() - def _init_weights(self, module): - # important: this ported version of AudioFlamingo3 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.audio_config.initializer_range - ) - - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 782ef440d0a7..581985fa0c4c 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -24,6 +24,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( @@ -318,9 +319,9 @@ class AutoformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) + super().__init__(num_positions, embedding_dim, _freeze=True) - def _init_weight(self): + def create_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] @@ -333,7 +334,7 @@ def _init_weight(self): sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - self.weight = nn.Parameter(out, requires_grad=False) + return out @torch.no_grad() def forward( @@ -830,18 +831,19 @@ class AutoformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, AutoformerSinusoidalPositionalEmbedding): - module._init_weight() + init.copy_(module.weight, module.create_weight()) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) # copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index ed07f9345e2b..2428222e0dbe 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -32,6 +32,7 @@ from transformers.activations import ACT2FN +from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub @@ -1130,9 +1131,9 @@ class BambaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, BambaMixer): - module.dt_bias.fill_(1.0) - module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) - module.D.fill_(1.0) + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1))) + init.ones_(module.D) @auto_docstring diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 024e8415fffe..d29273940b8a 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,6 +42,7 @@ segment_sum, ) +from ... 
import initialization as init from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -804,9 +805,9 @@ class BambaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, BambaMixer): - module.dt_bias.fill_(1.0) - module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) - module.D.fill_(1.0) + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1))) + init.ones_(module.D) @auto_docstring diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index e00068e34f0c..791678604f12 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -329,25 +329,6 @@ class BarkPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _supports_flash_attn = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear,)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if getattr(module, "bias", None) is not None: - module.bias.zero_() - module.weight.fill_(1.0) - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - @property def device(self) -> torch.device: """ @@ -1318,6 +1299,8 @@ def __init__(self, config): self.config = config + self.post_init() + @classmethod def can_generate(cls) -> bool: # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d08608268a15..6e25fa4d30fc 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -476,21 +476,6 @@ class BartPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 1d83a26a4b3f..76e20239648b 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -677,29 +678,19 @@ class BeitPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, BeitEmbeddings): - module.cls_token.zero_() + super()._init_weights(module) + if isinstance(module, BeitEmbeddings): + init.zeros_(module.cls_token) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, BeitRelativePositionBias): - module.relative_position_bias_table.zero_() + init.zeros_(module.relative_position_bias_table) elif isinstance(module, BeitLayer): if module.lambda_1 is not None: - module.lambda_1.fill_(self.config.layer_scale_init_value) - module.lambda_2.fill_(self.config.layer_scale_init_value) + init.constant_(module.lambda_1, self.config.layer_scale_init_value) + init.constant_(module.lambda_2, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index bf7d54108b32..b89382ca00a2 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -565,19 +566,9 @@ class BertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, BertLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, BertLMPredictionHead): + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 359ef6889a45..176ce4d003ae 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -459,19 +460,9 @@ class BertGenerationPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, BertGenerationOnlyLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, BertGenerationOnlyLMHead): + init.zeros_(module.bias) @auto_docstring( diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index ccdc0dd8b842..1a90dd028c49 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -1517,19 +1518,9 @@ class BigBirdPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, BigBirdLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, BigBirdLMPredictionHead): + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 220b050496a1..9fb140bd1e2b 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1539,21 +1539,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index fe80fcda4dc8..61879e54b5a5 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -22,6 +22,7 @@ 
import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BackboneOutput, @@ -631,17 +632,17 @@ class BitPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. elif isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + init.constant_(module.weight, 1) + init.constant_(module.bias, 0) @auto_docstring diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index bd7790f5a7a4..63021a27ca48 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -435,24 +435,8 @@ class BlenderbotPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index bd1a36cb4d22..85e342619ff9 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -428,24 +428,8 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index aa812903f311..a1f06ae3a37e 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ 
b/src/transformers/models/blip/modeling_blip.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn.functional import normalize +from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer @@ -422,32 +423,13 @@ class BlipPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - factor = self.config.initializer_range - if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)): - module.weight.normal_(mean=0.0, std=factor) - if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() - + super()._init_weights(module) + std = self.config.initializer_range if isinstance(module, BlipVisionEmbeddings): if hasattr(self.config, "vision_config"): - factor = self.config.vision_config.initializer_range - nn.init.trunc_normal_( - module.position_embedding, - mean=0.0, - std=factor, - ) - - nn.init.trunc_normal_( - module.class_embedding, - mean=0.0, - std=factor, - ) - - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + std = self.config.vision_config.initializer_range + init.trunc_normal_(module.position_embedding, mean=0.0, std=std) + init.trunc_normal_(module.class_embedding, mean=0.0, std=std) class BlipEncoder(nn.Module): diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 6e9e3bb7c2c3..eb67aaa45c3c 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -504,17 +504,6 @@ class BlipTextPreTrainedModel(PreTrainedModel): base_model_prefix = "bert" _no_split_modules = [] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() - # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571 class BlipTextModel(BlipTextPreTrainedModel): diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 45297cb4e57c..542eaca5ca29 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer @@ -411,20 +412,11 @@ class Blip2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - factor = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=factor) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor) - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, Blip2VisionEmbeddings): - nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) - nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, Blip2VisionEmbeddings): + init.trunc_normal_(module.position_embedding, mean=0.0, std=std) + init.trunc_normal_(module.class_embedding, mean=0.0, std=std) elif isinstance( module, ( @@ -435,7 +427,7 @@ def _init_weights(self, module): Blip2ForImageTextRetrieval, ), ): - module.query_tokens.zero_() + init.zeros_(module.query_tokens) # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2 diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 3b0be0bef90a..8b8d4f715274 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -418,27 +418,8 @@ class BloomPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BloomBlock"] _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class BloomModel(BloomPreTrainedModel): diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index a44eb7bfabb1..66de121e78c9 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN, QuickGELUActivation from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...masking_utils import create_bidirectional_mask, create_causal_mask @@ -927,24 +928,24 @@ def _init_weights(self, module: nn.Module): attn_std = self.config.hidden_size**-0.5 fc_std = (2 * self.config.hidden_size) ** -0.5 for block in module.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std) - block.attn.in_proj_bias.zero_() - nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std) - - nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std) - nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std) + init.normal_(block.attn.in_proj_weight, std=attn_std * std) + init.zeros_(block.attn.in_proj_bias) + init.normal_(block.attn.out_proj.weight, std=proj_std * std) + init.normal_(block.mlp.c_fc.weight, std=fc_std * std) + init.normal_(block.mlp.c_proj.weight, std=proj_std * std) + + init.normal_(module.embeddings.class_embedding, std=attn_std * std) + init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)): - module.weight.normal_(mean=0.0, std=0.05 * std) + init.normal_(module.weight, mean=0.0, std=0.05 * std) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, BridgeTowerForContrastiveLearning): - module.logit_scale.fill_(self.config.logit_scale_init_value) + init.constant_(module.logit_scale, self.config.logit_scale_init_value) if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class BridgeTowerVisionModel(BridgeTowerPreTrainedModel): diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 74da9e9c8ae8..1705d6a47b95 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -517,20 +518,10 @@ class BrosPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" + super()._init_weights(module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, BrosRelationExtractor): - nn.init.normal_(module.dummy_node, std=std) + if isinstance(module, BrosRelationExtractor): + init.normal_(module.dummy_node, std=std) @auto_docstring @@ -549,7 +540,7 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = BrosPooler(config) if add_pooling_layer else None - self.init_weights() + self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings @@ -693,7 +684,7 @@ def __init__(self, config): self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -812,7 +803,7 @@ def __init__(self, config): # Subsequent token classification for Entity Extraction (NER) self.subsequent_token_classifier = BrosRelationExtractor(config) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -949,7 +940,7 @@ def __init__(self, config): self.entity_linker = BrosRelationExtractor(config) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 267aafe5959e..797ecaea2bd8 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -27,6 +27,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN, gelu from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -413,19 +414,9 @@ class CamembertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, CamembertLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, CamembertLMHead): + init.zeros_(module.bias) class CamembertEmbeddings(nn.Module): diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 2b0a1e897266..1a9fd30d419e 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -719,21 +719,6 @@ class CaninePreTrainedModel(PreTrainedModel): base_model_prefix = "canine" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class CanineModel(CaninePreTrainedModel): diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 815124aa45f8..1bf5e50beed6 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_layers import GradientCheckpointingLayer @@ -568,47 +569,47 @@ def _init_weights(self, module): factor = self.config.initializer_factor if isinstance(module, ChineseCLIPVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, ChineseCLIPTextEmbeddings): - nn.init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range) - nn.init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range) - nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range) + init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range) + init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range) + init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range) for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]: if embedding.padding_idx is not None: - embedding.weight[embedding.padding_idx].zero_() + init.zeros_(embedding.weight[embedding.padding_idx]) elif isinstance(module, ChineseCLIPVisionAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, ChineseCLIPVisionMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, ChineseCLIPModel): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if 
module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 420d14cb816c..583ac01290b8 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -1314,23 +1315,23 @@ def _init_weights(self, module: nn.Module): factor = self.config.initializer_factor if isinstance(module, ClapTextEmbeddings): - module.position_embeddings.weight.normal_(mean=0.0, std=factor * 0.02) - module.token_type_embeddings.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.position_embeddings.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.token_type_embeddings.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, ClapModel): - module.logit_scale_a.fill_(math.log(self.config.logit_scale_init_value)) - module.logit_scale_t.fill_(math.log(self.config.logit_scale_init_value)) + init.constant_(module.logit_scale_a, math.log(self.config.logit_scale_init_value)) + init.constant_(module.logit_scale_t, math.log(self.config.logit_scale_init_value)) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, (nn.Conv2d, nn.Linear)): in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor - nn.init.normal_(module.weight, std=in_proj_std) + init.normal_(module.weight, std=in_proj_std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, ClapAudioSelfAttention): - module.relative_position_bias_table.zero_() + init.zeros_(module.relative_position_bias_table) class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 8ce33c4a0dcf..49bab19b971f 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer @@ -413,57 +414,57 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, CLIPTextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, CLIPVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, CLIPAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, CLIPMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, CLIPModel): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) elif isinstance(module, CLIPVisionModelWithProjection): - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, CLIPTextModelWithProjection): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, CLIPForImageClassification): - nn.init.normal_( + init.normal_( module.classifier.weight, std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, ) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class 
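The CLIP initializer above keeps the depth-scaled standard deviations of the original release: residual-branch projections get an extra `(2 * num_hidden_layers) ** -0.5` factor so activations do not grow with depth. The values are plain functions of the config; for example, with ViT-B-like sizes (`embed_dim = hidden_size = 768`, `num_hidden_layers = 12`, `initializer_factor = 1.0` — illustrative numbers, not taken from this diff):

embed_dim, num_hidden_layers, factor = 768, 12, 1.0

in_proj_std = (embed_dim ** -0.5) * ((2 * num_hidden_layers) ** -0.5) * factor   # ~0.0074
out_proj_std = (embed_dim ** -0.5) * factor                                      # ~0.0361
fc_std = (2 * embed_dim) ** -0.5 * factor                                        # ~0.0255
print(in_proj_std, out_proj_std, fc_std)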
CLIPEncoder(nn.Module): diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 9f14686630ba..bba971644a23 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -23,6 +23,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -432,42 +433,42 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, CLIPSegTextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, CLIPSegVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, CLIPSegAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, CLIPSegMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, CLIPSegModel): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg diff --git a/src/transformers/models/clvp/modeling_clvp.py 
b/src/transformers/models/clvp/modeling_clvp.py index 9893b6bd1442..6c8171547aa3 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN, get_activation from ...cache_utils import Cache, DynamicCache from ...generation import GenerationConfig, GenerationMixin @@ -786,37 +787,37 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.weight, mean=0.0, std=factor * 0.02) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, ClvpRMSNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, ClvpEncoderMLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, ClvpEncoder): config = self.config.get_text_config() factor = config.initializer_factor - module.projection.weight.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5)) + init.normal_(module.projection.weight, mean=0.0, std=factor * (config.hidden_size**-0.5)) elif isinstance(module, ClvpConditioningEncoder): - module.mel_conv.weight.normal_(mean=0.0, std=factor) - module.mel_conv.bias.zero_() + init.normal_(module.mel_conv.weight, mean=0.0, std=factor) + init.zeros_(module.mel_conv.bias) elif isinstance(module, ClvpForCausalLM): for name, p in module.named_parameters(): if name == "c_proj.weight": - p.normal_( - mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)) + init.normal_( + p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers) ) elif isinstance(module, ClvpModelForConditionalGeneration): - module.logit_scale.fill_(self.config.logit_scale_init_value) + init.constant_(module.logit_scale, self.config.logit_scale_init_value) if isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) class ClvpEncoder(ClvpPreTrainedModel): diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index b5e350d79d1a..307a1c9b7a41 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -277,27 +277,8 @@ class CodeGenPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CodeGenBlock"] _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, 
(nn.Linear,)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class CodeGenModel(CodeGenPreTrainedModel): diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 954722e2b144..75e35c4126c0 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -22,6 +22,7 @@ from transformers import AutoModelForImageTextToText +from ... import initialization as init from ...cache_utils import Cache from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple @@ -47,13 +48,14 @@ def _init_weights(self, module): ) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) @dataclass diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 27b897f70490..922d9bf3bd86 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -26,6 +26,7 @@ from transformers import AutoModelForImageTextToText +from ... import initialization as init from ...cache_utils import Cache from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available @@ -55,13 +56,14 @@ def _init_weights(self, module): ) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) @dataclass diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index c358dd3c2c82..9843dda12a60 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -976,21 +977,22 @@ def _init_weights(self, module): xavier_std = self.config.init_xavier_std if isinstance(module, ConditionalDetrMHAttentionMap): - nn.init.zeros_(module.k_linear.bias) - nn.init.zeros_(module.q_linear.bias) - nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) - nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + init.zeros_(module.k_linear.bias) + init.zeros_(module.q_linear.bias) + init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) elif isinstance(module, ConditionalDetrLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) # Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR @@ -1910,8 +1912,8 @@ def __init__(self, dim, fpn_dims, context_dim): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_uniform_(m.weight, a=1) - nn.init.constant_(m.bias, 0) + init.kaiming_uniform_(m.weight, a=1) + init.constant_(m.bias, 0) def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]): # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 392f8ec79a1c..ac33f13ad2b2 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
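The comment about the padding row losing the flag (repeated in the ColPali, ColQwen2 and Conditional DETR hunks above) comes down to the fact that indexing a parameter returns a fresh tensor object, so per-tensor bookkeeping attached to the full weight is not visible on the slice. A small illustration, assuming the `_is_hf_initialized` semantics those comments suggest:

import torch
from torch import nn

emb = nn.Embedding(10, 4, padding_idx=0)
emb.weight._is_hf_initialized = True                  # pretend the full weight was already loaded

row = emb.weight[emb.padding_idx]                     # a new view object every time
print(getattr(row, "_is_hf_initialized", False))      # False: the flag does not follow the slice

# hence the explicit guard before zeroing only the padding row
if emb.padding_idx is not None and not getattr(emb.weight, "_is_hf_initialized", False):
    with torch.no_grad():
        emb.weight[emb.padding_idx].zero_()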
import initialization as init from ...activations import ACT2FN, get_activation from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -111,22 +112,12 @@ class ConvBertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, SeparableConv1D): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, SeparableConv1D): + init.zeros_(module.bias) elif isinstance(module, GroupedLinearLayer): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - module.bias.zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.bias) class SeparableConv1D(nn.Module): diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index c0cbc8e55476..851bee0060e1 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BackboneOutput, @@ -243,16 +244,10 @@ class ConvNextPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, ConvNextLayer): + super()._init_weights(module) + if isinstance(module, ConvNextLayer): if module.layer_scale_parameter is not None: - module.layer_scale_parameter.fill_(self.config.layer_scale_init_value) + init.constant_(module.layer_scale_parameter, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index de320116bd16..02e780aa70aa 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BackboneOutput, @@ -263,16 +264,10 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, ConvNextV2GRN): - module.weight.zero_() - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, ConvNextV2GRN): + init.zeros_(module.weight) + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 9f8ce38b2b08..c1e778bef1a6 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -528,21 +529,11 @@ class CpmAntPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.init_std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.init_std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, CpmAntLayerNorm): - module.weight.fill_(1.0) + super()._init_weights(module) + if isinstance(module, CpmAntLayerNorm): + init.ones_(module.weight) elif isinstance(module, CpmAntSegmentPositionEmbedding): - module.relative_attention_bias.normal_(mean=0.0, std=self.config.init_std) + init.normal_(module.relative_attention_bias, mean=0.0, std=self.config.init_std) @auto_docstring diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 7c2e8c676864..e92630e9591c 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -28,6 +28,7 @@ from transformers.utils.generic import check_model_inputs +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -415,7 +416,7 @@ def _init_weights(self, module): if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index d1cb056f64bd..96ccda50cecb 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -21,6 +21,7 @@ from transformers.utils.generic import check_model_inputs +from ... 
import initialization as init from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask @@ -146,7 +147,7 @@ def _init_weights(self, module): if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index f3a5472410ce..a494b17c1cc9 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -26,7 +26,6 @@ from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import Conv1D from ...utils import ( auto_docstring, logging, @@ -188,21 +187,6 @@ class CTRLPreTrainedModel(PreTrainedModel): config: CTRLConfig base_model_prefix = "transformer" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class CTRLModel(CTRLPreTrainedModel): diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 55b251a087e7..698e1f8d33a1 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -493,17 +494,15 @@ class CvtPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.copy_(nn.init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, CvtStage): if self.config.cls_token[module.stage]: - module.cls_token.copy_( - nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) - ) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index c1a620dbb75b..d61b96ce566b 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -24,9 +24,9 @@ import torch import torch.nn.functional as F -import torch.nn.init as init from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2CLS, ACT2FN from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput @@ -453,22 +453,22 @@ def _init_weights(self, module): for layer in module.class_embed: prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(layer.weight) - nn.init.constant_(layer.bias, bias) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) if module.bbox_embed is not None: for layer in module.bbox_embed: - nn.init.constant_(layer.layers[-1].weight, 0) - nn.init.constant_(layer.layers[-1].bias, 0) + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) if hasattr(module, "reg_scale"): - module.reg_scale.fill_(self.config.reg_scale) + init.constant_(module.reg_scale, self.config.reg_scale) if hasattr(module, "up"): - module.up.fill_(self.config.up) + init.constant_(module.up, self.config.up) if isinstance(module, DFineMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -478,22 +478,21 @@ def _init_weights(self, module): grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) grid_init *= scaling - with torch.no_grad(): - module.sampling_offsets.bias[...] 
= grid_init.flatten() + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) if isinstance(module, DFineModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(module.enc_score_head.weight) - nn.init.constant_(module.enc_score_head.bias, bias) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) @@ -505,13 +504,13 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].weight, 0) if isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) if hasattr(module, "weight_embedding") and self.config.learn_initial_query: - nn.init.xavier_uniform_(module.weight_embedding.weight) + init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: - nn.init.xavier_uniform_(module.denoising_class_embed.weight) + init.xavier_uniform_(module.denoising_class_embed.weight) class DFineIntegral(nn.Module): diff --git a/src/transformers/models/d_fine/modular_d_fine.py b/src/transformers/models/d_fine/modular_d_fine.py index 4ce91d1b98a7..e0468b32e048 100644 --- a/src/transformers/models/d_fine/modular_d_fine.py +++ b/src/transformers/models/d_fine/modular_d_fine.py @@ -17,9 +17,9 @@ import torch import torch.nn.functional as F -import torch.nn.init as init from torch import nn +from ... import initialization as init from ...activations import ACT2CLS from ...configuration_utils import PreTrainedConfig from ...image_transforms import corners_to_center_format @@ -591,28 +591,29 @@ def forward( class DFinePreTrainedModel(RTDetrPreTrainedModel): @torch.no_grad() def _init_weights(self, module): + """Initialize the weights""" # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: for layer in module.class_embed: prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(layer.weight) - nn.init.constant_(layer.bias, bias) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) if module.bbox_embed is not None: for layer in module.bbox_embed: - nn.init.constant_(layer.layers[-1].weight, 0) - nn.init.constant_(layer.layers[-1].bias, 0) + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) if hasattr(module, "reg_scale"): - module.reg_scale.fill_(self.config.reg_scale) + init.constant_(module.reg_scale, self.config.reg_scale) if hasattr(module, "up"): - module.up.fill_(self.config.up) + init.constant_(module.up, self.config.up) if isinstance(module, DFineMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -622,22 +623,21 @@ def _init_weights(self, module): grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) grid_init *= scaling - with torch.no_grad(): - module.sampling_offsets.bias[...] = grid_init.flatten() + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) if isinstance(module, DFineModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(module.enc_score_head.weight) - nn.init.constant_(module.enc_score_head.bias, bias) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) @@ -649,13 +649,13 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].weight, 0) if isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) if hasattr(module, "weight_embedding") and self.config.learn_initial_query: - nn.init.xavier_uniform_(module.weight_embedding.weight) + init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: - nn.init.xavier_uniform_(module.denoising_class_embed.weight) + init.xavier_uniform_(module.denoising_class_embed.weight) class DFineIntegral(nn.Module): diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index f4606ccd0499..e1d25bc121d0 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ 
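`init.copy_` in the two D-Fine files above replaces the previous `with torch.no_grad(): module.sampling_offsets.bias[...] = grid_init.flatten()`. Assuming it behaves like an ordinary gradient-free in-place copy, the two forms are interchangeable; a minimal stand-in:

import torch
from torch import nn


def copy_(param: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Assumed behaviour of init.copy_: in-place copy without gradient tracking."""
    with torch.no_grad():
        param.copy_(value)
    return param


n_heads, n_points = 8, 4
bias = nn.Parameter(torch.zeros(n_heads * n_points * 2))   # 2 offset coordinates per point
grid_init = torch.randn(n_heads, n_points, 2)              # placeholder for the angular grid
copy_(bias, grid_init.flatten())                           # same effect as bias[...] = grid_init.flatten()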
b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -821,29 +822,30 @@ def _init_weights(self, module): xavier_std = self.config.init_xavier_std if isinstance(module, DabDetrMHAttentionMap): - nn.init.zeros_(module.k_linear.bias) - nn.init.zeros_(module.q_linear.bias) - nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) - nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + init.zeros_(module.k_linear.bias) + init.zeros_(module.q_linear.bias) + init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, DabDetrForObjectDetection): - nn.init.constant_(module.bbox_predictor.layers[-1].weight, 0) - nn.init.constant_(module.bbox_predictor.layers[-1].bias, 0) + init.constant_(module.bbox_predictor.layers[-1].weight, 0) + init.constant_(module.bbox_predictor.layers[-1].bias, 0) # init prior_prob setting for focal loss prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias_value = -math.log((1 - prior_prob) / prior_prob) - module.class_embed.bias.fill_(bias_value) + init.constant_(module.class_embed.bias, bias_value) elif isinstance(module, nn.PReLU): module.reset_parameters() diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index 54f1d1a32d49..d74a369d3d33 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ b/src/transformers/models/dac/modeling_dac.py @@ -23,6 +23,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
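Both the D-Fine and DAB-DETR heads above initialize the classification bias from a prior probability so that the detector starts out predicting almost everything as background (the usual focal-loss trick). With the default `prior_prob = 1 / (num_labels + 1)` the number works out as follows (the label count is illustrative):

import math

num_labels = 80                                      # e.g. a COCO-sized label set
prior_prob = 1 / (num_labels + 1)
bias_value = float(-math.log((1 - prior_prob) / prior_prob))
print(bias_value)                                    # ~ -4.38
print(1 / (1 + math.exp(-bias_value)))               # sigmoid(bias) == prior_prob ~ 0.0123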
import initialization as init from ...modeling_utils import PreTrainedAudioTokenizerBase from ...utils import ModelOutput, auto_docstring from .configuration_dac import DacConfig @@ -480,14 +481,14 @@ class DacPreTrainedModel(PreTrainedAudioTokenizerBase): @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv1d): - nn.init.trunc_normal_(module.weight, std=0.02) - nn.init.constant_(module.bias, 0) + init.trunc_normal_(module.weight, std=0.02) + init.constant_(module.bias, 0) elif isinstance(module, Snake1d): - module.alpha.fill_(1.0) + init.ones_(module.alpha) elif isinstance(module, nn.ConvTranspose1d): module.reset_parameters() elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=0.02) + init.normal_(module.weight, mean=0.0, std=0.02) def apply_weight_norm(self): weight_norm = nn.utils.weight_norm diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 5ce39db6f403..5288c36de86e 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -29,6 +29,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -485,26 +486,26 @@ def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Data2VecAudioFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, Data2VecAudioPositionalConvLayer): - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if module.weight is not None: - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None @@ -1043,7 +1044,7 @@ def __init__(self, config): self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.num_labels = config.num_labels - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ @@ -1199,7 +1200,7 @@ def __init__(self, config): self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index b7a2a7ed2300..190ebea39aa6 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py 
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -494,25 +494,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): "cross_attentions": Data2VecTextCrossAttention, } - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() - if hasattr(module, "weight") and module.weight is not None: - module.weight.fill_(1.0) - class Data2VecTextEncoder(nn.Module): def __init__(self, config): diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 4a906aac83a3..d21dfd325e4b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -691,29 +692,19 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, Data2VecVisionEmbeddings): - module.cls_token.zero_() + super()._init_weights(module) + if isinstance(module, Data2VecVisionEmbeddings): + init.zeros_(module.cls_token) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, Data2VecVisionRelativePositionBias): - module.relative_position_bias_table.zero_() + init.zeros_(module.relative_position_bias_table) elif isinstance(module, Data2VecVisionLayer): if module.lambda_1 is not None: - module.lambda_1.fill_(self.config.layer_scale_init_value) - module.lambda_2.fill_(self.config.layer_scale_init_value) + init.constant_(module.lambda_1, self.config.layer_scale_init_value) + init.constant_(module.lambda_2, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 070439d7de4e..78e739911cb5 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import Wav2Vec2BaseModelOutput @@ -149,26 +150,26 @@ def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Data2VecAudioFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, Data2VecAudioPositionalConvLayer): - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if module.weight is not None: - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_adapters(self): raise AttributeError("Not needed for Data2VecAudio") diff --git a/src/transformers/models/data2vec/modular_data2vec_text.py b/src/transformers/models/data2vec/modular_data2vec_text.py index ad0dc81c8e01..7cce78815954 100644 --- a/src/transformers/models/data2vec/modular_data2vec_text.py +++ b/src/transformers/models/data2vec/modular_data2vec_text.py @@ -81,25 +81,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): "cross_attentions": Data2VecTextCrossAttention, } - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() - if hasattr(module, "weight") and module.weight is not None: - module.weight.fill_(1.0) - @auto_docstring class Data2VecTextModel(RobertaModel): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index db212fd6378e..ddf5fce4dfce 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -25,6 +25,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -468,23 +469,12 @@ class DbrxPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): + super()._init_weights(module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, DbrxExpertGLU): - module.w1.normal_(mean=0.0, std=std) - module.v1.normal_(mean=0.0, std=std) - module.w2.normal_(mean=0.0, std=std) + if isinstance(module, DbrxExpertGLU): + init.normal_(module.w1, mean=0.0, std=std) + init.normal_(module.v1, mean=0.0, std=std) + init.normal_(module.w2, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/dbrx/modular_dbrx.py b/src/transformers/models/dbrx/modular_dbrx.py index c9633e20fe1e..42a9079cb012 100644 --- a/src/transformers/models/dbrx/modular_dbrx.py +++ b/src/transformers/models/dbrx/modular_dbrx.py @@ -21,6 +21,7 @@ import torch.utils.checkpoint from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -338,23 +339,12 @@ class DbrxPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): + super()._init_weights(module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, DbrxExpertGLU): - module.w1.normal_(mean=0.0, std=std) - module.v1.normal_(mean=0.0, std=std) - module.w2.normal_(mean=0.0, std=std) + if isinstance(module, DbrxExpertGLU): + init.normal_(module.w1, mean=0.0, std=std) + init.normal_(module.v1, mean=0.0, std=std) + init.normal_(module.w2, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 3b2ea9b53724..b97938feb4fc 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -20,6 +20,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -617,22 +618,12 @@ class DebertaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, DebertaLayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, DisentangledSelfAttention): - module.q_bias.zero_() - module.v_bias.zero_() + super()._init_weights(module) + if isinstance(module, DisentangledSelfAttention): + init.zeros_(module.q_bias) + init.zeros_(module.v_bias) elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)): - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 791e433e4d2c..52299f766fbb 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -696,19 +697,9 @@ class DebertaV2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)): + init.zeros_(module.bias) @auto_docstring @@ -1282,7 +1273,7 @@ def __init__(self, config): drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out self.dropout = nn.Dropout(drop_out) - self.init_weights() + self.post_init() def get_input_embeddings(self): return self.deberta.get_input_embeddings() diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 678b21808b1d..efbb6cbc36d7 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer @@ -374,17 +375,7 @@ def __init__(self, *inputs, **kwargs): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + super()._init_weights(module) # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale @@ -396,7 +387,7 @@ def _init_weights(self, module): for name, p in module.named_parameters(): if "c_proj" in name and "weight" in name: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -614,21 +605,6 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel): main_input_name = "states" supports_gradient_checkpointing = False - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index e1785ef9a851..1d0ccb2e8f2a 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -465,7 +466,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DeepseekV2Moe): - module.gate.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index 6994ba582e5a..3deed89d041f 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...cache_utils import Cache from ...modeling_rope_utils import RopeParameters, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -436,7 +437,7 @@ class DeepseekV2PreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DeepseekV2Moe): - module.gate.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) class DeepseekV2Model(LlamaModel): diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index e619afd25773..a18d0b5083fd 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -552,7 +553,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DeepseekV3TopkRouter): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 5a92d135870d..23d0f1e568d4 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification @@ -308,7 +309,7 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DeepseekV3TopkRouter): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) class DeepseekV3Model(LlamaModel): diff --git a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py index 849eb5ef34f0..e8dffe07e84f 100644 --- a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py @@ -132,15 +132,6 @@ class DeepseekVLPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_param_buffer_assignment = False - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - # Required only for Linear layer in DeepseekVLAligner - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) - if module.bias is not None: - module.bias.zero_() - @auto_docstring class DeepseekVLModel(DeepseekVLPreTrainedModel): diff --git a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py index 038ffc4c8c0a..18f2b65a669e 100644 --- a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py @@ -134,15 +134,6 @@ def forward(self, vision_encodings: torch.Tensor) -> torch.Tensor: class DeepseekVLPreTrainedModel(JanusPreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - # Required only for Linear layer in DeepseekVLAligner - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) - if module.bias is not None: - module.bias.zero_() - @auto_docstring class DeepseekVLModel(JanusModel): diff --git a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py index 17fed96166ce..99e83417a903 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py @@ -24,6 +24,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput @@ -218,18 +219,18 @@ class DeepseekVLHybridPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, DeepseekVLHybridLayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, DeepseekVLHybridModel): - module.high_res_vision_alpha.zero_() + init.zeros_(module.high_res_vision_alpha) DEEPSEEK_VL_COMMON_CUSTOM_ARGS = r""" diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index c8f5be1638d4..4501ed7810d2 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -18,6 +18,7 @@ import torch.nn as nn from torchvision.transforms.v2 import functional as F +from ... import initialization as init from ...cache_utils import Cache from ...image_processing_utils_fast import ( BaseImageProcessorFast, @@ -220,18 +221,18 @@ class DeepseekVLHybridPreTrainedModel(DeepseekVLPreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, DeepseekVLHybridLayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, DeepseekVLHybridModel): - module.high_res_vision_alpha.zero_() + init.zeros_(module.high_res_vision_alpha) class DeepseekVLHybridModel(DeepseekVLModel): diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 553eb8b7a2b5..e92ce998e8b8 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -931,10 +932,10 @@ def _init_weights(self, module): std = self.config.init_std if isinstance(module, DeformableDetrLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) elif isinstance(module, DeformableDetrMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -947,27 +948,28 @@ def _init_weights(self, module): ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) - nn.init.constant_(module.reference_points.bias, 0.0) + init.xavier_uniform_(module.reference_points.weight, gain=1.0) + init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): - nn.init.normal_(module.level_embed) + init.normal_(module.level_embed) class DeformableDetrEncoder(DeformableDetrPreTrainedModel): diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index b80a02d83a14..62a787e11507 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ...
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -370,24 +371,18 @@ class DeiTPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, DeiTEmbeddings): - module.cls_token.zero_() - module.position_embeddings.zero_() - module.distillation_token.zero_() + init.zeros_(module.cls_token) + init.zeros_(module.position_embeddings) + init.zeros_(module.distillation_token) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) @auto_docstring diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index d7336d304a76..aa416298f548 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from .... import initialization as init from ....activations import ACT2FN from ....file_utils import ( ModelOutput, @@ -573,7 +574,7 @@ def __init__(self, config: DetaConfig, num_heads: int, n_points: int): self._reset_parameters() def _reset_parameters(self): - nn.init.constant_(self.sampling_offsets.weight.data, 0.0) + init.constant_(self.sampling_offsets.weight.data, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) @@ -586,12 +587,12 @@ def _reset_parameters(self): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(self.attention_weights.weight.data, 0.0) - nn.init.constant_(self.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(self.value_proj.weight.data) - nn.init.constant_(self.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(self.output_proj.weight.data) - nn.init.constant_(self.output_proj.bias.data, 0.0) + init.constant_(self.attention_weights.weight.data, 0.0) + init.constant_(self.attention_weights.bias.data, 0.0) + init.xavier_uniform_(self.value_proj.weight.data) + init.constant_(self.value_proj.bias.data, 0.0) + init.xavier_uniform_(self.output_proj.weight.data) + init.constant_(self.output_proj.bias.data, 0.0) def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings @@ -993,23 +994,24 @@ def _init_weights(self, module): std = self.config.init_std if isinstance(module, DetaLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + 
init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) elif isinstance(module, DetaMultiscaleDeformableAttention): module._reset_parameters() elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) - nn.init.constant_(module.reference_points.bias, 0.0) + init.xavier_uniform_(module.reference_points.weight, gain=1.0) + init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): - nn.init.normal_(module.level_embed) + init.normal_(module.level_embed) DETA_START_DOCSTRING = r""" @@ -1812,15 +1814,15 @@ def __init__(self, config: DetaConfig): prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) self.class_embed.bias.data.fill_(bias_value) - nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) - nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + init.constant_(self.bbox_embed.layers[-1].bias.data, 0) # if two-stage, the last class_embed and bbox_embed is for region proposal generation num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers if config.with_box_refine: self.class_embed = _get_clones(self.class_embed, num_pred) self.bbox_embed = _get_clones(self.bbox_embed, num_pred) - nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) # hack implementation for iterative bounding box refinement self.model.decoder.bbox_embed = self.bbox_embed self._tied_weights_keys.update( { } ) else: - nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) self.model.decoder.bbox_embed = None @@ -1842,7 +1844,7 @@ def __init__(self, config: DetaConfig): } ) for box_embed in self.bbox_embed: - nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py b/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py index f3303da0f6fd..8f1f657c5f42 100644 --- a/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py +++ b/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py @@ -498,17 +498,6 @@ class EfficientFormerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = False -
@torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - EFFICIENTFORMER_START_DOCSTRING = r""" This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) subclass. Use it as a diff --git a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py index 7ed73c5a49a8..a22e21bd7a49 100755 --- a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -368,21 +368,6 @@ class ErnieMPreTrainedModel(PreTrainedModel): config: ErnieMConfig base_model_prefix = "ernie_m" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index eaeb2eb035b1..291677c6caa4 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn +from .... 
import initialization as init from ....activations import ACT2FN from ....cache_utils import Cache from ....modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions @@ -533,56 +534,56 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, nn.LayerNorm): - module.weight.fill_(factor * 1.0) - module.bias.zero_() + init.constant_(module.weight, factor * 1.0) + init.zeros_(module.bias) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, GPTSanJapaneseModel): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.embed_tokens.weight.normal_(mean=0.0, std=factor * 1.0) - module.position_embeddings.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.embed_tokens.weight, mean=0.0, std=factor * 1.0) + init.normal_(module.position_embeddings.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None: - module.extra_position_embeddings.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.extra_position_embeddings.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, (GPTSanJapaneseModel, GPTSanJapaneseForConditionalGeneration)): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.final_logits_bias.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.final_logits_bias, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, GPTSanJapaneseDenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, GPTSanJapaneseAttention): # Multi-headed attention d_model = self.config.d_model key_value_proj_dim = self.config.d_model n_heads = self.config.num_heads - module.k_proj.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.v_proj.weight.normal_(mean=0.0, std=factor * ((d_model * 
key_value_proj_dim) ** -0.5)) - module.q_proj.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.out_proj.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.k_proj.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.v_proj.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.q_proj.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.out_proj.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) elif isinstance(module, GPTSanJapaneseSparseMLP): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_model n_heads = self.config.num_heads - module.router.classifier.weight.normal_(mean=0.0, std=factor * 1) + init.normal_(module.router.classifier.weight, mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.experts[f"expert_{idx}"].wo.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.experts[f"expert_{idx}"].wi.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.experts[f"expert_{idx}"].wo.weight, mean=0.0, std=factor * (d_model**-0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/deprecated/graphormer/modeling_graphormer.py b/src/transformers/models/deprecated/graphormer/modeling_graphormer.py index bc74d7a5e7d5..0069abd2ce6f 100755 --- a/src/transformers/models/deprecated/graphormer/modeling_graphormer.py +++ b/src/transformers/models/deprecated/graphormer/modeling_graphormer.py @@ -22,6 +22,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from .... 
import initialization as init from ....activations import ACT2FN from ....modeling_outputs import ( BaseModelOutputWithNoAttention, @@ -346,17 +347,17 @@ def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with # the scaled initialization - nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) - nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) - nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) else: - nn.init.xavier_uniform_(self.k_proj.weight) - nn.init.xavier_uniform_(self.v_proj.weight) - nn.init.xavier_uniform_(self.q_proj.weight) + init.xavier_uniform_(self.k_proj.weight) + init.xavier_uniform_(self.v_proj.weight) + init.xavier_uniform_(self.q_proj.weight) - nn.init.xavier_uniform_(self.out_proj.weight) + init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: - nn.init.constant_(self.out_proj.bias, 0.0) + init.constant_(self.out_proj.bias, 0.0) def forward( self, @@ -709,27 +710,23 @@ class GraphormerPreTrainedModel(PreTrainedModel): main_input_name_nodes = "input_nodes" main_input_name_edges = "input_edges" - def normal_(self, data: torch.Tensor): - # with FSDP, module params will be on CUDA, so we cast them back to CPU - # so that the RNG is consistent with and without FSDP - data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) - def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]): """ Initialize the weights specific to the Graphormer Model. """ if isinstance(module, nn.Linear): - self.normal_(module.weight.data) + init.normal_(module.weight.data, mean=0.0, std=0.02) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if isinstance(module, nn.Embedding): - self.normal_(module.weight.data) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + init.normal_(module.weight.data, mean=0.0, std=0.02) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if isinstance(module, GraphormerMultiheadAttention): - self.normal_(module.q_proj.weight.data) - self.normal_(module.k_proj.weight.data) - self.normal_(module.v_proj.weight.data) + init.normal_(module.q_proj.weight.data, mean=0.0, std=0.02) + init.normal_(module.k_proj.weight.data, mean=0.0, std=0.02) + init.normal_(module.v_proj.weight.data, mean=0.0, std=0.02) @torch.no_grad() def _init_weights( self, @@ -741,31 +738,16 @@ def _init_weights( """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Conv2d)): - # We might be missing part of the Linear init, dependent on the layer num - module.weight.normal_(mean=0.0, std=0.02) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=0.02) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, GraphormerMultiheadAttention): - module.q_proj.weight.normal_(mean=0.0, std=0.02) - module.k_proj.weight.normal_(mean=0.0, std=0.02) - module.v_proj.weight.normal_(mean=0.0, std=0.02) + super()._init_weights(module) + if
isinstance(module, GraphormerMultiheadAttention): + init.normal_(module.q_proj.weight, mean=0.0, std=0.02) + init.normal_(module.k_proj.weight, mean=0.0, std=0.02) + init.normal_(module.v_proj.weight, mean=0.0, std=0.02) module.reset_parameters() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) elif isinstance(module, GraphormerGraphEncoder): if module.apply_graphormer_init: module.apply(self.init_graphormer_params) - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - class GraphormerModel(GraphormerPreTrainedModel): """The Graphormer model is a graph-encoder model. diff --git a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py index d71fadd8bf6c..e71d82b745e3 100755 --- a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py +++ b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import LayerNorm as FusedLayerNorm +from .... import initialization as init from ....activations import ACT2FN from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, logging @@ -604,20 +605,20 @@ class JukeboxVQVAE(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Embedding): # embed_tokens - module.weight.normal_(mean=0.0, std=0.02 * self.config.init_scale) + init.normal_(module.weight, mean=0.0, std=0.02 * self.config.init_scale) elif isinstance(module, JukeboxConv1D): if self.config.zero_out: - module.weight.zero_() + init.zeros_(module.weight) else: - module.weight.normal_(mean=0.0, std=0.02 * self.config.init_scale) + init.normal_(module.weight, mean=0.0, std=0.02 * self.config.init_scale) elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: - module.conv1d_2.weight.zero_() - module.conv1d_2.bias.zero_() + init.zeros_(module.conv1d_2.weight) + init.zeros_(module.conv1d_2.bias) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def __init__(self, config: JukeboxVQVAEConfig): super().__init__(config) @@ -1796,28 +1797,28 @@ def _init_weights(self, module): init_scale = self.config.init_scale if isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=0.02 * init_scale) + init.normal_(module.weight, mean=0.0, std=0.02 * init_scale) elif isinstance(module, JukeboxConv1D): if self.config.zero_out: - module.weight.zero_() + init.zeros_(module.weight) else: - module.weight.normal_(mean=0.0, std=0.02 * init_scale) + init.normal_(module.weight, mean=0.0, std=0.02 * init_scale) elif isinstance(module, JukeboxPositionalEmbedding): - module.pos_emb.normal_(mean=0.0, std=0.01 * init_scale) + init.normal_(module.pos_emb, mean=0.0, std=0.01 * init_scale) elif isinstance(module, JukeboxRangeEmbedding): - module.emb.weight.normal_(mean=0.0, std=0.01 * init_scale) + init.normal_(module.emb.weight, mean=0.0, std=0.01 * init_scale) elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"): - module.lm_head.weight.normal_(mean=0.0, std=0.02 * init_scale) + init.normal_(module.lm_head.weight, mean=0.0, std=0.02 * init_scale) elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"): - module.start_token.normal_(mean=0.0, 
std=0.01 * init_scale) + init.normal_(module.start_token, mean=0.0, std=0.01 * init_scale) elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: - module.conv1d_2.weight.zero_() - module.conv1d_2.bias.zero_() + init.zeros_(module.conv1d_2.weight) + init.zeros_(module.conv1d_2.bias) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None): super().__init__(config) diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index db7c475dabd4..8f0530c40756 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -20,6 +20,7 @@ import torch from torch import nn +from .... import initialization as init from ....activations import ACT2FN from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ....integrations.deepspeed import is_deepspeed_zero3_enabled @@ -395,25 +396,10 @@ class MCTCTPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, MCTCTLayerNorm): - module.singleton_weight.fill_(1.0) - module.singleton_bias.zero_() - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, MCTCTLayerNorm): + init.ones_(module.singleton_weight) + init.zeros_(module.singleton_bias) def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ diff --git a/src/transformers/models/deprecated/mega/modeling_mega.py b/src/transformers/models/deprecated/mega/modeling_mega.py index d66848e1d2b1..eb0e93d636db 100644 --- a/src/transformers/models/deprecated/mega/modeling_mega.py +++ b/src/transformers/models/deprecated/mega/modeling_mega.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from .... import initialization as init from ....activations import ACT2FN from ....cache_utils import Cache from ....modeling_outputs import ( @@ -1338,44 +1339,49 @@ def _init_weights(self, module): if isinstance(module, MegaMultiDimensionDampedEma): with torch.no_grad(): # delta & alpha - nn.init.normal_(module.damping_factor, mean=0.0, std=self.config.ema_delta_alpha_range) - nn.init.normal_(module.decay_factor, mean=0.0, std=self.config.ema_delta_alpha_range) + init.normal_(module.damping_factor, mean=0.0, std=self.config.ema_delta_alpha_range) + init.normal_(module.decay_factor, mean=0.0, std=self.config.ema_delta_alpha_range) # beta [1, -1, 1, -1, ...] seems more stable. 
val = torch.ones(self.config.ema_projection_size, 1) if self.config.ema_projection_size > 1: idx = torch.tensor(list(range(1, self.config.ema_projection_size, 2))) val.index_fill_(0, idx, -1.0) - module.ema_expansion_matrix.normal_(mean=0.0, std=self.config.ema_beta_range).add_(val) + init.copy_( + module.ema_expansion_matrix, + torch.normal(mean=0.0, std=self.config.ema_beta_range, size=module.ema_expansion_matrix.shape) + + val, + ) # gamma & omega - nn.init.normal_(module.kernel_projection_matrix, mean=0.0, std=self.config.ema_gamma_omega_range) - nn.init.normal_(module.residual_weight, mean=0.0, std=self.config.ema_gamma_omega_range) + init.normal_(module.kernel_projection_matrix, mean=0.0, std=self.config.ema_gamma_omega_range) + init.normal_(module.residual_weight, mean=0.0, std=self.config.ema_gamma_omega_range) elif isinstance(module, MegaSimpleRelativePositionalBias): - nn.init.normal_(module.rel_pos_bias, mean=0.0, std=self.config.initializer_range) + init.normal_(module.rel_pos_bias, mean=0.0, std=self.config.initializer_range) elif isinstance(module, MegaRotaryRelativePositionalBias): - nn.init.normal_(module.alpha, mean=0.0, std=self.config.initializer_range) - nn.init.normal_(module.b_param, mean=0.0, std=self.config.initializer_range) + init.normal_(module.alpha, mean=0.0, std=self.config.initializer_range) + init.normal_(module.b_param, mean=0.0, std=self.config.initializer_range) elif isinstance(module, MegaScaleNorm): if self.config.norm_affine: - nn.init.constant_(module.scalar, 1.0) + init.constant_(module.scalar, 1.0) elif isinstance(module, MegaRMSNorm): if self.config.norm_affine: - nn.init.constant_(module.weight, 1.0) + init.constant_(module.weight, 1.0) elif isinstance(module, MegaMovingAverageGatedAttention): # linear layers covered separately by the generic nn.Linear init below - nn.init.normal_(module.qk_weight, mean=0.0, std=self.config.initializer_range) - nn.init.constant_(module.qk_bias, 0.0) + init.normal_(module.qk_weight, mean=0.0, std=self.config.initializer_range) + init.constant_(module.qk_bias, 0.0) elif isinstance(module, nn.Linear): # initializes all linear layers in the entire network - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) MEGA_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/nat/modeling_nat.py b/src/transformers/models/deprecated/nat/modeling_nat.py index a43562406ce6..20e2d62ee57f 100644 --- a/src/transformers/models/deprecated/nat/modeling_nat.py +++ b/src/transformers/models/deprecated/nat/modeling_nat.py @@ -592,17 +592,6 @@ class NatPreTrainedModel(PreTrainedModel): base_model_prefix = "nat" main_input_name = "pixel_values" - @torch.no_grad() - def _init_weights(self, module): - """Initialize
the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - NAT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index 8e3cb0cd3f4b..3bb957339e42 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -587,21 +587,6 @@ class NezhaPreTrainedModel(PreTrainedModel): base_model_prefix = "nezha" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @dataclass class NezhaForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 7da07eca1e34..261fa5c97708 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from .... 
import initialization as init from ....activations import ACT2FN from ....cache_utils import Cache from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -443,16 +444,17 @@ class OpenLlamaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): if self.config.use_stable_embedding: - torch.nn.init.xavier_normal_(module.weight) + init.xavier_normal_(module.weight) else: - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) OPEN_LLAMA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index f395fe51d645..ca7f491508f2 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -597,21 +597,6 @@ class QDQBertPreTrainedModel(PreTrainedModel): base_model_prefix = "bert" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index 0b8062c5c900..d4ea5198965a 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -787,21 +787,6 @@ class RealmPreTrainedModel(PreTrainedModel): config: RealmConfig base_model_prefix = "realm" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - def _flatten_inputs(self, *inputs): """Flatten inputs' shape to (-1, input_shape[-1])""" flattened_inputs = [] diff --git a/src/transformers/models/deprecated/retribert/modeling_retribert.py b/src/transformers/models/deprecated/retribert/modeling_retribert.py index 7a762e46b890..9158e8cb83e3 100644 --- a/src/transformers/models/deprecated/retribert/modeling_retribert.py +++ b/src/transformers/models/deprecated/retribert/modeling_retribert.py @@ -42,21 +42,6 
@@ class RetriBertPreTrainedModel(PreTrainedModel): config: RetriBertConfig base_model_prefix = "retribert" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - RETRIBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index 821467abccba..f0ca446cc838 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from .... import initialization as init from ....activations import ACT2FN from ....cache_utils import Cache from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask @@ -373,20 +374,10 @@ class Speech2Text2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, Speech2Text2SinusoidalPositionalEmbedding): + super()._init_weights(module) + if isinstance(module, Speech2Text2SinusoidalPositionalEmbedding): weight = module.get_embedding(*module.weight.shape, module.padding_idx) - weight = nn.Parameter(weight, requires_grad=False) - weight.detach_() - module.weight = weight + init.copy_(module.weight, weight) SPEECH_TO_TEXT_2_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 2bc57636b944..3b8cb4e2edf0 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import functional as F +from .... 
import initialization as init from ....cache_utils import Cache from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel @@ -87,19 +88,19 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, EinLinear): for i in range(module.n_models): - nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range) + init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight[i]) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight[i]) bound = (1 / math.sqrt(fan_in)) * self.config.initializer_range - nn.init.uniform_(module.bias[i], -bound, bound) + init.uniform_(module.bias[i], -bound, bound) TRAJECTORY_TRANSFORMER_START_DOCSTRING = r""" @@ -158,11 +159,11 @@ def __init__(self, n_models, in_features, out_features, bias): def reset_parameters(self): for i in range(self.n_models): - nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5)) + init.kaiming_uniform_(self.weight[i], a=math.sqrt(5)) if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i]) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight[i]) bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(self.bias[i], -bound, bound) + init.uniform_(self.bias[i], -bound, bound) def forward(self, input): """ diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py index b28613d71b7f..d890edab89d9 100644 --- a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from .... 
import initialization as init from ....modeling_utils import PreTrainedModel from ....utils import ( ModelOutput, @@ -331,12 +332,12 @@ class TransfoXLPreTrainedModel(PreTrainedModel): def _init_weight(self, weight): if self.config.init == "uniform": - nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) + init.uniform_(weight, -self.config.init_range, self.config.init_range) elif self.config.init == "normal": - nn.init.normal_(weight, 0.0, self.config.init_std) + init.normal_(weight, 0.0, self.config.init_std) def _init_bias(self, bias): - nn.init.constant_(bias, 0.0) + init.constant_(bias, 0.0) def _init_weights(self, m): """Initialize the weights.""" @@ -350,7 +351,7 @@ def _init_weights(self, m): if hasattr(m, "emb_projs"): for i in range(len(m.emb_projs)): if m.emb_projs[i] is not None: - nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) + init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) elif classname.find("Embedding") != -1: if hasattr(m, "weight"): self._init_weight(m.weight) @@ -362,10 +363,10 @@ def _init_weights(self, m): if hasattr(m, "out_projs"): for i in range(len(m.out_projs)): if m.out_projs[i] is not None: - nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) + init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) elif classname.find("LayerNorm") != -1: if hasattr(m, "weight"): - nn.init.normal_(m.weight, 1.0, self.config.init_std) + init.normal_(m.weight, 1.0, self.config.init_std) if hasattr(m, "bias") and m.bias is not None: self._init_bias(m.bias) else: diff --git a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py index fbea2e2b77a3..99872fe5ea52 100644 --- a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -548,17 +548,6 @@ class TvltPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - TVLT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 007b74755e5d..3b2b484487a1 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -21,6 +21,7 @@ import torch from torch import nn +from .... 
import initialization as init from ....activations import ACT2FN from ....modeling_outputs import ( BaseModelOutputWithNoAttention, @@ -363,18 +364,18 @@ class VanPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) + init.trunc_normal_(module.weight, std=self.config.initializer_range) if isinstance(module, nn.Linear) and module.bias is not None: - nn.init.constant_(module.bias, 0) + init.constant_(module.bias, 0) elif isinstance(module, nn.LayerNorm): - nn.init.constant_(module.bias, 0) - nn.init.constant_(module.weight, 1.0) + init.constant_(module.bias, 0) + init.constant_(module.weight, 1.0) elif isinstance(module, nn.Conv2d): fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels fan_out //= module.groups - module.weight.normal_(0, math.sqrt(2.0 / fan_out)) + init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index bbc6554ff5d5..d414277e3e70 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -21,6 +21,7 @@ import torch from torch import nn +from .... import initialization as init from ....activations import ACT2FN from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput @@ -461,34 +462,16 @@ class ViTHybridPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, ViTHybridEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) - module.mask_token.zero_() + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index c592e756b7c9..05fa850a176b 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ 
b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -520,17 +520,6 @@ class XLMProphetNetPreTrainedModel(PreTrainedModel): base_model_prefix = "prophetnet" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.init_std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.init_std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index d6dae7cb72ee..158179e04184 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -216,17 +216,6 @@ class DepthAnythingPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - class DepthAnythingNeck(nn.Module): """ diff --git a/src/transformers/models/depth_pro/modeling_depth_pro.py b/src/transformers/models/depth_pro/modeling_depth_pro.py index b754cf9074c1..23e96f838c4d 100644 --- a/src/transformers/models/depth_pro/modeling_depth_pro.py +++ b/src/transformers/models/depth_pro/modeling_depth_pro.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int from ..auto import AutoModel @@ -612,16 +613,16 @@ class DepthProPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 84b4fbf9af49..968859bf3d21 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -733,21 +734,22 @@ def _init_weights(self, module): xavier_std = self.config.init_xavier_std if isinstance(module, DetrMHAttentionMap): - nn.init.zeros_(module.k_linear.bias) - nn.init.zeros_(module.q_linear.bias) - nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) - nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + init.zeros_(module.k_linear.bias) + init.zeros_(module.q_linear.bias) + init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) elif isinstance(module, DetrLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) class DetrEncoder(DetrPreTrainedModel): @@ -1613,8 +1615,8 @@ def __init__(self, dim, fpn_dims, context_dim): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_uniform_(m.weight, a=1) - nn.init.constant_(m.bias, 0) + init.kaiming_uniform_(m.weight, a=1) + init.constant_(m.bias, 0) def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]): # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 7e67ac52768c..d19e9299f8f5 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -28,6 +28,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin @@ -600,10 +601,10 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DiffLlamaAttention): - module.lambda_q1.normal_(0, self.config.lambda_std_dev) - module.lambda_k1.normal_(0, self.config.lambda_std_dev) - module.lambda_q2.normal_(0, self.config.lambda_std_dev) - module.lambda_k2.normal_(0, self.config.lambda_std_dev) + init.normal_(module.lambda_q1, 0, self.config.lambda_std_dev) + init.normal_(module.lambda_k1, 0, self.config.lambda_std_dev) + init.normal_(module.lambda_q2, 0, self.config.lambda_std_dev) + init.normal_(module.lambda_k2, 0, self.config.lambda_std_dev) @auto_docstring diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 97b1cc051660..c99d7b2feec2 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... import initialization as init from ...cache_utils import Cache, StaticCache from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_utils import PreTrainedModel @@ -403,10 +404,10 @@ class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DiffLlamaAttention): - module.lambda_q1.normal_(0, self.config.lambda_std_dev) - module.lambda_k1.normal_(0, self.config.lambda_std_dev) - module.lambda_q2.normal_(0, self.config.lambda_std_dev) - module.lambda_k2.normal_(0, self.config.lambda_std_dev) + init.normal_(module.lambda_q1, 0, self.config.lambda_std_dev) + init.normal_(module.lambda_k1, 0, self.config.lambda_std_dev) + init.normal_(module.lambda_q2, 0, self.config.lambda_std_dev) + init.normal_(module.lambda_k2, 0, self.config.lambda_std_dev) class DiffLlamaModel(LlamaModel): diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 103e12ce5ed9..d2b111226608 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -561,17 +561,6 @@ class DinatPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class DinatModel(DinatPreTrainedModel): diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 49693d507733..1ce9ff9a63cf 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput @@ -418,39 +419,19 @@ class Dinov2PreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, Dinov2Embeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) - - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) - + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) if self.config.use_mask_token: - module.mask_token.zero_() + init.zeros_(module.mask_token) elif isinstance(module, Dinov2LayerScale): - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index ddbc6e05b1a5..5feb0e7b5b44 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput @@ -435,39 +436,19 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, Dinov2WithRegistersEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) - - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) - - module.mask_token.zero_() - module.register_tokens.zero_() + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) + init.zeros_(module.register_tokens) elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py index 1cb6cf79bc0b..2ae67b405f6c 100644 --- a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py @@ -27,6 +27,7 @@ Dinov2PatchEmbeddings, Dinov2PreTrainedModel, ) +from ... 
import initialization as init from ...configuration_utils import PreTrainedConfig from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...processing_utils import Unpack @@ -281,39 +282,19 @@ class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, Dinov2WithRegistersEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) - - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) - - module.mask_token.zero_() - module.register_tokens.zero_() + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) + init.zeros_(module.register_tokens) elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) class Dinov2WithRegistersModel(Dinov2Model): diff --git a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py index 286cc87c3ca3..3ed4f38217dd 100644 --- a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +++ b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPoolingAndNoAttention from ...modeling_utils import PreTrainedModel @@ -194,16 +195,10 @@ class DINOv3ConvNextPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, (nn.LayerNorm, DINOv3ConvNextLayerNorm)): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, DINOv3ConvNextLayer): + super()._init_weights(module) + if isinstance(module, DINOv3ConvNextLayer): if module.gamma is not None: - module.gamma.fill_(self.config.layer_scale_init_value) + init.constant_(module.gamma, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index ad88e87671a0..09edeed17543 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPooling @@ -452,39 +453,19 @@ class DINOv3ViTPreTrainedModel(PreTrainedModel): def _init_weights(self, module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_( - module.weight.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.weight.dtype) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, DINOv3ViTEmbeddings): - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) if module.config.num_register_tokens > 0: - module.register_tokens.copy_( - nn.init.trunc_normal_( - module.register_tokens.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.register_tokens.dtype) - ) - module.mask_token.zero_() + init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) elif isinstance(module, DINOv3ViTLayerScale): - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py index b773c8fb9b3d..7ae7d1632053 100644 --- a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py @@ -32,6 +32,7 @@ from transformers.models.llama.modeling_llama import LlamaMLP from 
transformers.models.pixtral.modeling_pixtral import PixtralAttention, rotate_half +from ... import initialization as init from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -344,42 +345,22 @@ class DINOv3ViTPreTrainedModel(Dinov2PreTrainedModel): } @torch.no_grad() - def _init_weights(self, module): + def _init_weights(self, module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_( - module.weight.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.weight.dtype) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, DINOv3ViTEmbeddings): - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) if module.config.num_register_tokens > 0: - module.register_tokens.copy_( - nn.init.trunc_normal_( - module.register_tokens.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.register_tokens.dtype) - ) - module.mask_token.zero_() + init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) elif isinstance(module, DINOv3ViTLayerScale): - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 0638a99124b6..4de5389e8908 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import get_activation from ...configuration_utils import PreTrainedConfig from ...integrations.deepspeed import is_deepspeed_zero3_enabled @@ -65,9 +66,9 @@ def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): with deepspeed.zero.GatheredParameters(out, modifier_rank=0): if torch.distributed.get_rank() == 0: - _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out) + return _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out) else: - _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out) + return _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out) def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): @@ -76,6 +77,7 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) out.detach_() + return out class Embeddings(nn.Module): @@ -302,20 +304,15 @@ class DistilBertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: - create_sinusoidal_embeddings( - self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight + super()._init_weights(module) + if isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: + init.copy_( + module.position_embeddings.weight, + create_sinusoidal_embeddings( + self.config.max_position_embeddings, + self.config.dim, + torch.empty_like(module.position_embeddings.weight), + ), ) diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index c3cc3033d5bf..b9ebf9856264 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -29,6 +29,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -530,12 +531,12 @@ def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DogeAttention): if hasattr(module, "A"): - module.A.zero_() + init.zeros_(module.A) elif isinstance(module, DogeDecoderLayer): if hasattr(module, "input_residual"): - module.input_residual.fill_(1.0) + init.ones_(module.input_residual) if hasattr(module, "post_attention_residual"): - module.post_attention_residual.fill_(1.0) + init.ones_(module.post_attention_residual) @auto_docstring diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index 99ad20e90e37..008466fbf4ac 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig @@ -546,12 +547,12 @@ def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DogeAttention): if hasattr(module, "A"): - module.A.zero_() + init.zeros_(module.A) elif isinstance(module, DogeDecoderLayer): if hasattr(module, "input_residual"): - module.input_residual.fill_(1.0) + init.ones_(module.input_residual) if hasattr(module, "post_attention_residual"): - module.post_attention_residual.fill_(1.0) + init.ones_(module.post_attention_residual) class DogeModel(MixtralModel): diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index e7d9422e69e2..7e8a6ae5d90f 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -25,6 +25,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -792,20 +793,14 @@ class DonutSwinPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, DonutSwinEmbeddings): + super()._init_weights(module) + if isinstance(module, DonutSwinEmbeddings): if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, DonutSwinSelfAttention): - module.relative_position_bias_table.zero_() + init.zeros_(module.relative_position_bias_table) @auto_docstring diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index f2df365ffff4..ac5e156f9049 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -470,7 +471,7 @@ class Dots1PreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Dots1TopkRouter): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 6ed58db0184c..e2886062951f 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -105,21 +105,6 @@ class DPRReaderOutput(ModelOutput): class DPRPreTrainedModel(PreTrainedModel): _supports_sdpa = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - class DPREncoder(DPRPreTrainedModel): base_model_prefix = "bert_model" diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 6562e7891772..e5a3d7367c58 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -28,6 +28,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput @@ -735,16 +736,10 @@ class DPTPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.zero_() - module.weight.fill_(1.0) + super()._init_weights(module) if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)): - module.cls_token.zero_() - module.position_embeddings.zero_() + init.zeros_(module.cls_token) + init.zeros_(module.position_embeddings) @auto_docstring diff --git a/src/transformers/models/edgetam/configuration_edgetam.py b/src/transformers/models/edgetam/configuration_edgetam.py index a76547809572..d0034a642b2f 100644 --- a/src/transformers/models/edgetam/configuration_edgetam.py +++ b/src/transformers/models/edgetam/configuration_edgetam.py @@ -18,7 +18,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from ...configuration_utils import PreTrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig diff --git a/src/transformers/models/edgetam/modeling_edgetam.py b/src/transformers/models/edgetam/modeling_edgetam.py index ecc506857fee..b23370e47f14 100644 --- a/src/transformers/models/edgetam/modeling_edgetam.py +++ b/src/transformers/models/edgetam/modeling_edgetam.py @@ -32,6 +32,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -310,21 +311,10 @@ class EdgeTamPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + super()._init_weights(module) if isinstance(module, EdgeTamModel): if module.no_memory_embedding is not None: - module.no_memory_embedding.zero_() + init.zeros_(module.no_memory_embedding) # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding diff --git a/src/transformers/models/edgetam/modular_edgetam.py b/src/transformers/models/edgetam/modular_edgetam.py index 594cb6084aa0..0037d49af0ae 100644 --- a/src/transformers/models/edgetam/modular_edgetam.py +++ b/src/transformers/models/edgetam/modular_edgetam.py @@ -17,7 +17,6 @@ from typing import Optional, Union import torch -import torch.nn as nn import torch.utils.checkpoint from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig @@ -33,7 +32,9 @@ ) from transformers.utils.generic import TransformersKwargs, check_model_inputs +from ... 
import initialization as init from ...configuration_utils import PreTrainedConfig +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( auto_docstring, @@ -176,21 +177,10 @@ class EdgeTamFeedForward(Sam2FeedForward): class EdgeTamPreTrainedModel(Sam2PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + PreTrainedModel._init_weights(self, module) if isinstance(module, EdgeTamModel): if module.no_memory_embedding is not None: - module.no_memory_embedding.zero_() + init.zeros_(module.no_memory_embedding) @auto_docstring( diff --git a/src/transformers/models/edgetam_video/modeling_edgetam_video.py b/src/transformers/models/edgetam_video/modeling_edgetam_video.py index 6b648c0eaa6a..f6f0b02874fe 100644 --- a/src/transformers/models/edgetam_video/modeling_edgetam_video.py +++ b/src/transformers/models/edgetam_video/modeling_edgetam_video.py @@ -34,6 +34,7 @@ from transformers.utils.generic import OutputRecorder +from ... import initialization as init from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -780,30 +781,19 @@ class EdgeTamVideoPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, EdgeTamVideoLayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, EdgeTamVideoModel): + super()._init_weights(module) + if isinstance(module, EdgeTamVideoModel): if module.no_memory_positional_encoding is not None: - module.no_memory_positional_encoding.zero_() + init.zeros_(module.no_memory_positional_encoding) if module.memory_temporal_positional_encoding is not None: - module.memory_temporal_positional_encoding.zero_() + init.zeros_(module.memory_temporal_positional_encoding) if module.no_object_pointer is not None: - module.no_object_pointer.zero_() + init.zeros_(module.no_object_pointer) if module.occlusion_spatial_embedding_parameter is not None: - module.occlusion_spatial_embedding_parameter.zero_() + init.zeros_(module.occlusion_spatial_embedding_parameter) if isinstance(module, EdgeTamVideoMemoryFuserCXBlock): if module.scale is not None: - module.scale.zero_() + init.zeros_(module.scale) class EdgeTamVideoInferenceCache: diff --git a/src/transformers/models/efficientloftr/modeling_efficientloftr.py b/src/transformers/models/efficientloftr/modeling_efficientloftr.py index 5f21d7cad00f..bdf6dd67ae48 100644 --- a/src/transformers/models/efficientloftr/modeling_efficientloftr.py +++ b/src/transformers/models/efficientloftr/modeling_efficientloftr.py @@ -18,6 +18,7 @@ import torch from torch 
import nn +from ... import initialization as init from ...activations import ACT2CLS, ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput @@ -679,12 +680,12 @@ class EfficientLoFTRPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 4c55a3058b98..b4635b3fcadb 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutputWithNoAttention, @@ -440,9 +441,9 @@ class EfficientNetPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 2fd477541986..a7698b31fad0 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -532,21 +532,6 @@ class ElectraPreTrainedModel(PreTrainedModel): "cross_attentions": ElectraCrossAttention, } - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @dataclass @auto_docstring( @@ -1317,7 +1302,7 @@ def __init__(self, config): self.generator_predictions = ElectraGeneratorPredictions(config) self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size) - self.init_weights() + self.post_init() def get_output_embeddings(self): return self.generator_lm_head diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 8e5eaf82ac31..bf0bb66c8880 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -29,6 +29,7 @@ import torch.nn as nn import torch.nn.functional 
as F +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -941,24 +942,25 @@ class Emu3VQVAE(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Conv2d, nn.Conv3d)): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1.0) - nn.init.constant_(module.bias, 0.0) + init.constant_(module.weight, 1.0) + init.constant_(module.bias, 0.0) elif isinstance(module, nn.Embedding): - module.weight.normal_() - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) def __init__(self, config: Emu3VQVAEConfig): super().__init__(config) diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 88d6451a6abe..a70fd220b582 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -22,6 +22,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutputWithPast @@ -691,24 +692,25 @@ class Emu3VQVAE(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Conv2d, nn.Conv3d)): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1.0) - nn.init.constant_(module.bias, 0.0) + init.constant_(module.weight, 1.0) + init.constant_(module.bias, 0.0) elif isinstance(module, nn.Embedding): - module.weight.normal_() - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) def __init__(self, config: Emu3VQVAEConfig): super().__init__(config) diff --git a/src/transformers/models/encodec/modeling_encodec.py index a9449caa707f..a02ee98e9a30 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...modeling_utils import PreTrainedAudioTokenizerBase from ...utils import ( ModelOutput, @@ -458,21 +459,21 @@ class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.GroupNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.ConvTranspose1d): module.reset_parameters() elif isinstance(module, nn.LSTM): for name, param in module.named_parameters(): if "weight" in name: - nn.init.xavier_uniform_(param) + init.xavier_uniform_(param) elif "bias" in name: - nn.init.constant_(param, 0.0) + init.constant_(param, 0.0) @auto_docstring( diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index e62cb8f623cc..fe3bec2f57cc 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -164,7 +164,7 @@ def __init__( ) # tie encoder, decoder weights if config set accordingly - self.tie_weights() + self.post_init() @torch.no_grad() def _init_weights(self, module): diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index e52e98364c09..344342b470ba 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -30,6 +30,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_layers import GradientCheckpointingLayer @@ -1000,26 +1001,25 @@ class EomtPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=1) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=1) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, EomtLayerScale): if hasattr(module, "lambda1"): - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) elif isinstance(module, EomtEmbeddings): - module.cls_token.copy_( - nn.init.trunc_normal_(module.cls_token.to(torch.float32), mean=0.0, std=std).to(module.cls_token.dtype) - ) - module.register_tokens.zero_() + init.trunc_normal_(module.cls_token, mean=0.0, std=std) + init.zeros_(module.register_tokens) @auto_docstring( diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py index 2c95affa154e..c22571acbb97 100644 --- a/src/transformers/models/eomt/modular_eomt.py +++ b/src/transformers/models/eomt/modular_eomt.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...file_utils import ( ModelOutput, @@ -405,26 +406,25 @@ class EomtPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=1) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=1) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, EomtLayerScale): if hasattr(module, "lambda1"): - module.lambda1.fill_(self.config.layerscale_value) + init.constant_(module.lambda1, self.config.layerscale_value) elif isinstance(module, EomtEmbeddings): - module.cls_token.copy_( - nn.init.trunc_normal_(module.cls_token.to(torch.float32), mean=0.0, std=std).to(module.cls_token.dtype) - ) - module.register_tokens.zero_() + init.trunc_normal_(module.cls_token, mean=0.0, std=std) + init.zeros_(module.register_tokens) @auto_docstring( diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 24890d50ac2e..5f01a8afec6d 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -28,6 +28,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -549,21 +550,9 @@ class ErniePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, ErnieLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, ErnieLMPredictionHead): + init.zeros_(module.bias) @auto_docstring( diff --git a/src/transformers/models/ernie/modular_ernie.py b/src/transformers/models/ernie/modular_ernie.py index 4bf0440d7c16..f633a4c25091 100644 --- a/src/transformers/models/ernie/modular_ernie.py +++ b/src/transformers/models/ernie/modular_ernie.py @@ -21,6 +21,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_outputs import ( @@ -165,21 +166,9 @@ class ErniePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, ErnieLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, ErnieLMPredictionHead): + init.zeros_(module.bias) class ErnieModel(BertModel): diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 8ff07d9f638f..633f97fd2090 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -499,7 +500,7 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Ernie4_5_MoeStatics): - module.e_score_correction_bias.zero_() + init.zeros_(module.e_score_correction_bias) @auto_docstring diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index fe403f81afad..1cd1f216f674 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask @@ -240,7 +241,7 @@ class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Ernie4_5_MoeStatics): - module.e_score_correction_bias.zero_() + init.zeros_(module.e_score_correction_bias) @auto_docstring diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index a3f1fbdf58b5..c4b3e1ea11ec 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -554,19 +555,9 @@ class EsmPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, EsmLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, EsmLMHead): + init.zeros_(module.bias) def get_output_embeddings(self): # NOTE: get_output_embeddings() must return None to prevent accidental weight tying. diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 0c676d631b24..0e195b57e289 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -24,12 +24,12 @@ import torch.nn as nn from torch.nn import LayerNorm +from ... 
import initialization as init from ...integrations.deepspeed import is_deepspeed_available from ...modeling_outputs import ModelOutput from ...utils import ( ContextManagers, auto_docstring, - is_scipy_available, logging, ) from .modeling_esm import EsmModel, EsmPreTrainedModel @@ -207,33 +207,6 @@ def dict_multimap(fn, dicts): return new_dict -def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): - shape = weights.shape - scale = scale / max(1, shape[1]) - - if not is_scipy_available(): - logger.warning( - "This init requires scipy, but scipy was not found, default to an approximation that might not be" - " equivalent." - ) - std = math.sqrt(scale) - torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std) - - else: - from scipy.stats import truncnorm - - std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1) - samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel()) - samples = np.reshape(samples, shape) - weights.copy_(torch.tensor(samples, device=weights.device)) - - -def ipa_point_weights_init_(weights): - with torch.no_grad(): - softplus_inverse_1 = 0.541324854612918 - weights.fill_(softplus_inverse_1) - - class EsmFoldLinear(nn.Linear): """ A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. @@ -923,40 +896,47 @@ def _init_weights(self, module): if module.init_fn is not None: module.init_fn(module.weight, module.bias) elif module.init == "default": - trunc_normal_init_(module.weight, scale=1.0) + shape = module.weight.shape + scale = 1.0 / max(1, shape[1]) + std = math.sqrt(scale) + init.normal_(module.weight, std=std) elif module.init == "relu": - trunc_normal_init_(module.weight, scale=2.0) + shape = module.weight.shape + scale = 2.0 / max(1, shape[1]) + std = math.sqrt(scale) + init.normal_(module.weight, std=std) elif module.init == "glorot": - nn.init.xavier_uniform_(module.weight, gain=1) + init.xavier_uniform_(module.weight, gain=1) elif module.init == "gating": - module.weight.fill_(0.0) + init.zeros_(module.weight) if module.bias: - module.bias.fill_(1.0) + init.ones_(module.bias) elif module.init == "normal": - torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear") + init.kaiming_normal_(module.weight, nonlinearity="linear") elif module.init == "final": - module.weight.fill_(0.0) + init.zeros_(module.weight) elif isinstance(module, EsmFoldInvariantPointAttention): - ipa_point_weights_init_(module.head_weights) + softplus_inverse_1 = 0.541324854612918 + init.constant_(module.head_weights, softplus_inverse_1) elif isinstance(module, EsmFoldTriangularSelfAttentionBlock): - torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight) - torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias) - torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight) - torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias) - torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight) - torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias) - torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight) - torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias) - - torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight) - torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias) - torch.nn.init.zeros_(module.pair_to_sequence.linear.weight) - torch.nn.init.zeros_(module.seq_attention.o_proj.weight) - torch.nn.init.zeros_(module.seq_attention.o_proj.bias) - torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight) - torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias) - 
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight) - torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias) + init.zeros_(module.tri_mul_in.linear_z.weight) + init.zeros_(module.tri_mul_in.linear_z.bias) + init.zeros_(module.tri_mul_out.linear_z.weight) + init.zeros_(module.tri_mul_out.linear_z.bias) + init.zeros_(module.tri_att_start.mha.linear_o.weight) + init.zeros_(module.tri_att_start.mha.linear_o.bias) + init.zeros_(module.tri_att_end.mha.linear_o.weight) + init.zeros_(module.tri_att_end.mha.linear_o.bias) + + init.zeros_(module.sequence_to_pair.o_proj.weight) + init.zeros_(module.sequence_to_pair.o_proj.bias) + init.zeros_(module.pair_to_sequence.linear.weight) + init.zeros_(module.seq_attention.o_proj.weight) + init.zeros_(module.seq_attention.o_proj.bias) + init.zeros_(module.mlp_seq.mlp[-2].weight) + init.zeros_(module.mlp_seq.mlp[-2].bias) + init.zeros_(module.mlp_pair.mlp[-2].weight) + init.zeros_(module.mlp_pair.mlp[-2].bias) else: super()._init_weights(module) @@ -975,12 +955,12 @@ def __init__(self, embed_dim, num_heads, head_width, gated=False): self.gated = gated if gated: self.g_proj = nn.Linear(embed_dim, embed_dim) - torch.nn.init.zeros_(self.g_proj.weight) - torch.nn.init.ones_(self.g_proj.bias) + init.zeros_(self.g_proj.weight) + init.ones_(self.g_proj.bias) self.rescale_factor = self.head_width**-0.5 - torch.nn.init.zeros_(self.o_proj.bias) + init.zeros_(self.o_proj.bias) def forward(self, x, mask=None, bias=None, indices=None): """ @@ -1053,8 +1033,8 @@ def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) - torch.nn.init.zeros_(self.proj.bias) - torch.nn.init.zeros_(self.o_proj.bias) + init.zeros_(self.proj.bias) + init.zeros_(self.o_proj.bias) def forward(self, sequence_state): """ @@ -2052,6 +2032,8 @@ def __init__(self, config): nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins), ) + self.post_init() + @staticmethod def _af2_to_esm_from_vocab_list(vocab_list: list[str]) -> torch.Tensor: # Remember that t is shifted from residue_constants by 1 (0 is padding). diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 994ce020f811..f4d0ce11255f 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -517,22 +518,6 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel): ], } - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): def __init__(self, config: SaProtConfig): @@ -1274,11 +1259,11 @@ def _init_weights(self, module): std = self.config.initializer_range super()._init_weights(module) if isinstance(module, EvollaSequenceAlignerCrossAttention): - module.gate_attention.zero_() - module.gate_ffw.zero_() - module.attention_norm.weight.fill_(1.0) + init.zeros_(module.gate_attention) + init.zeros_(module.gate_ffw) + init.ones_(module.attention_norm.weight) elif isinstance(module, EvollaSequenceCompressorResampler): - module.latents.normal_(mean=0.0, std=std) + init.normal_(module.latents, mean=0.0, std=std) class EvollaModel(EvollaPreTrainedModel): diff --git a/src/transformers/models/evolla/modular_evolla.py b/src/transformers/models/evolla/modular_evolla.py index b31f6645c5be..ed6ff25b1cdb 100644 --- a/src/transformers/models/evolla/modular_evolla.py +++ b/src/transformers/models/evolla/modular_evolla.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_bidirectional_mask, create_causal_mask @@ -202,22 +203,6 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel): ], } - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): def __init__(self, config: SaProtConfig): @@ -738,11 +723,11 @@ def _init_weights(self, module): std = self.config.initializer_range PreTrainedModel._init_weights(self, module) if isinstance(module, EvollaSequenceAlignerCrossAttention): - module.gate_attention.zero_() - module.gate_ffw.zero_() - module.attention_norm.weight.fill_(1.0) + init.zeros_(module.gate_attention) + init.zeros_(module.gate_ffw) + init.ones_(module.attention_norm.weight) elif isinstance(module, EvollaSequenceCompressorResampler): - module.latents.normal_(mean=0.0, std=std) + init.normal_(module.latents, mean=0.0, std=std) class EvollaModel(EvollaPreTrainedModel): diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 4446169eb6c6..7df7c3a0b148 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -23,6 +23,7 @@ from torch.nn 
import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F +from ... import initialization as init from ...activations import get_activation from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin @@ -672,26 +673,16 @@ class FalconPreTrainedModel(PreTrainedModel): _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn = True _supports_sdpa = True - _can_compile_fullgraph = True - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" - if isinstance(module, (nn.Linear, FalconLinear)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + super()._init_weights(module) + if isinstance(module, FalconLinear): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa @classmethod diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index f15f8ee1c3b1..a6fd7a5aba99 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -33,6 +33,7 @@ from transformers.activations import ACT2FN +from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub @@ -1196,21 +1197,11 @@ class FalconH1PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Module): - for name, param in module.named_parameters(recurse=True): - if not param.requires_grad: - continue - if "layernorm" in name.lower() and "weight" in name: - # LayerNorm weights usually initialized to 1 - param.fill_(1.0) - elif "bias" in name: - param.zero_() - else: - try: - param.normal_(mean=0.0, std=std) - except Exception as e: - print(f"Skipping init for {name} due to error: {e}") + super()._init_weights(module) + if isinstance(module, FalconH1Mixer): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1))) + init.ones_(module.D) def compute_mup_vector(config): diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 5371cab2bf20..386266b89bf3 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -45,6 +45,7 @@ segment_sum, ) +from ... 
import initialization as init from ...cache_utils import Cache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -922,21 +923,11 @@ class FalconH1PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Module): - for name, param in module.named_parameters(recurse=True): - if not param.requires_grad: - continue - if "layernorm" in name.lower() and "weight" in name: - # LayerNorm weights usually initialized to 1 - param.fill_(1.0) - elif "bias" in name: - param.zero_() - else: - try: - param.normal_(mean=0.0, std=std) - except Exception as e: - print(f"Skipping init for {name} due to error: {e}") + super()._init_weights(module) + if isinstance(module, FalconH1Mixer): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1))) + init.ones_(module.D) def compute_mup_vector(config): diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index d7acfd8f1a53..1819b9d0af0b 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin @@ -577,14 +578,14 @@ def _init_weights(self, module): # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale if self.config.time_step_init_scheme == "constant": - nn.init.constant_(module.dt_proj.weight, dt_init_std) + init.constant_(module.dt_proj.weight, dt_init_std) elif self.config.time_step_init_scheme == "random": - nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) dt = torch.exp( torch.rand(self.config.intermediate_size) @@ -593,14 +594,12 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_proj.bias.copy_(inv_dt) - module.dt_proj.bias._no_reinit = True + init.copy_(module.dt_proj.bias, inv_dt) - nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) if module.conv1d.bias is not None: - if not getattr(module.conv1d.bias, "_no_reinit", False): - nn.init.zeros_(module.conv1d.bias) - nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) + init.zeros_(module.conv1d.bias) + init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) if self.config.rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: @@ -617,15 +616,13 @@ def _init_weights(self, module): p /= math.sqrt(self.config.num_hidden_layers) if isinstance(module, nn.Linear): - if not getattr(module.weight, "_no_reinit", False): - 
nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, FalconMambaRMSNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) @dataclass diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 51f50d298e27..5db835f1fd94 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -995,24 +996,25 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1))) + init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1))) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-key, b=key) + init.uniform_(module.bias, a=-key, b=key) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Embedding): - module.weight.normal_() - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, FastSpeech2ConformerAttention): - nn.init.xavier_uniform_(module.pos_bias_u) - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_v) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, FastSpeech2ConformerEncoder): @@ -1404,14 +1406,6 @@ def __init__(self, config: FastSpeech2ConformerHifiGanConfig): # Initialize weights and apply final processing self.post_init() - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): @@ -1498,6 +1492,8 @@ def __init__(self, config: FastSpeech2ConformerWithHifiGanConfig): self.config = config + self.post_init() + @auto_docstring def forward( self, diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 
5a22aff9c047..93ca89acd1e2 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import gelu, get_activation from ...cache_utils import DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -51,6 +52,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) out.detach_() + return out # Copied from transformers.models.xlm.modeling_xlm.get_masks @@ -676,20 +678,26 @@ def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Embedding): if self.config is not None and self.config.embed_init_std is not None: - nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0, std=self.config.embed_init_std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if isinstance(module, nn.Linear): if self.config is not None and self.config.init_std is not None: - nn.init.normal_(module.weight, mean=0, std=self.config.init_std) + init.normal_(module.weight, mean=0, std=self.config.init_std) if module.bias is not None: - nn.init.constant_(module.bias, 0.0) + init.constant_(module.bias, 0.0) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings: - create_sinusoidal_embeddings( - self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight + init.copy_( + module.position_embeddings.weight, + create_sinusoidal_embeddings( + self.config.max_position_embeddings, + self.config.emb_dim, + out=torch.empty_like(module.position_embeddings.weight), + ), ) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index bcca5d13d528..effb5111cf96 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -23,6 +23,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -668,29 +669,19 @@ class FlavaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, FlavaMaskedPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, FlavaMaskedPredictionHead): + init.zeros_(module.bias) elif isinstance(module, FlavaImageEmbeddings): - module.cls_token.zero_() - module.position_embeddings.zero_() + init.zeros_(module.cls_token) + init.zeros_(module.position_embeddings) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) elif isinstance(module, FlavaMultimodalModel): if module.use_cls_token: - module.cls_token.zero_() + init.zeros_(module.cls_token) elif isinstance(module, FlavaModel): - module.logit_scale.fill_(self.config.logit_scale_init_value) + init.constant_(module.logit_scale, self.config.logit_scale_init_value) @auto_docstring diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index e55c8e02a150..b948b420ad63 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -434,10 +435,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, FlexOlmoExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, FlexOlmoTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 5cc5c870fa9e..a273167710cc 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -374,22 +374,6 @@ class FNetPreTrainedModel(PreTrainedModel): base_model_prefix = "fnet" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - # NOTE: Original code uses same initialization as weights for biases as well. 
- if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @dataclass @auto_docstring( diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index a297378f5492..e55cde0bb98d 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput @@ -584,20 +585,14 @@ class FocalNetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, FocalNetEmbeddings): + super()._init_weights(module) + if isinstance(module, FocalNetEmbeddings): if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) elif isinstance(module, FocalNetLayer): if self.config.use_layerscale: - module.gamma_1.fill_(self.config.layerscale_value) - module.gamma_2.fill_(self.config.layerscale_value) + init.constant_(module.gamma_1, self.config.layerscale_value) + init.constant_(module.gamma_2, self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 7cfc86744e74..cc4607023fbd 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -34,6 +34,7 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss, LayerNorm +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -223,18 +224,17 @@ class PretrainedFSMTModel(PreTrainedModel): def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, SinusoidalPositionalEmbedding): weight = module.get_embedding(*module.weight.shape, module.padding_idx) - weight = nn.Parameter(weight, requires_grad=False) - weight.detach_() - module.weight = weight + init.copy_(module.weight, weight) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) @property def dummy_inputs(self): diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index 7290c54e091a..f23e962347f4 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, @@ -682,20 +683,20 @@ def _init_weights(self, module): std = np.sqrt(1.0 / float(fan_in + fan_out)) else: std = self.config.initializer_std - nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) if getattr(module, "bias", None) is not None: - nn.init.constant_(module.bias, 0.0) + init.constant_(module.bias, 0.0) elif classname == "FunnelRelMultiheadAttention": - nn.init.uniform_(module.r_w_bias, b=self.config.initializer_range) - nn.init.uniform_(module.r_r_bias, b=self.config.initializer_range) - nn.init.uniform_(module.r_kernel, b=self.config.initializer_range) - nn.init.uniform_(module.r_s_bias, b=self.config.initializer_range) - nn.init.uniform_(module.seg_embed, b=self.config.initializer_range) + init.uniform_(module.r_w_bias, b=self.config.initializer_range) + init.uniform_(module.r_r_bias, b=self.config.initializer_range) + init.uniform_(module.r_kernel, b=self.config.initializer_range) + init.uniform_(module.r_s_bias, b=self.config.initializer_range) + init.uniform_(module.seg_embed, b=self.config.initializer_range) elif classname == "FunnelEmbeddings": std = 1.0 if self.config.initializer_std is None else self.config.initializer_std - nn.init.normal_(module.word_embeddings.weight, std=std) + init.normal_(module.word_embeddings.weight, std=std) if module.word_embeddings.padding_idx is not None: - module.word_embeddings.weight[module.word_embeddings.padding_idx].zero_() + init.zeros_(module.word_embeddings.weight[module.word_embeddings.padding_idx]) class FunnelClassificationHead(nn.Module): diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 0adb011378a5..e0a969fea3a8 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ 
b/src/transformers/models/fuyu/modeling_fuyu.py @@ -44,18 +44,6 @@ class FuyuPreTrainedModel(PreTrainedModel): _no_split_modules = [] _skip_keys_device_placement = "past_key_values" - @torch.no_grad() - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 1acb039017dc..8834ba8c3564 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -25,6 +25,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -352,10 +353,9 @@ class GemmaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) @auto_docstring diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index d1b3070a5ad0..1445baef96bf 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...masking_utils import create_causal_mask @@ -397,10 +398,9 @@ class GemmaPreTrainedModel(LlamaPreTrainedModel): @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) - # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) class GemmaModel(LlamaModel): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6db748900375..69e486032107 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -384,10 +385,9 @@ class Gemma2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) @auto_docstring diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index e17cbddfb4c6..b7aef2390f20 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig @@ -470,10 +471,10 @@ class Gemma3PreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Gemma3MultiModalProjector): - module.mm_input_projection_weight.zero_() + init.zeros_(module.mm_input_projection_weight) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index b3d34234dcd8..47e1b49ac7bb 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask @@ -573,10 +574,10 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Gemma3MultiModalProjector): - module.mm_input_projection_weight.zero_() + init.zeros_(module.mm_input_projection_weight) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 1f8631e156ec..5111a69cebfe 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -28,6 +28,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -1605,11 +1606,11 @@ class Gemma3nPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Gemma3nAudioCumulativeGroupNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, Gemma3nAudioAttention): - module.per_dim_scale.zero_() + init.zeros_(module.per_dim_scale) elif isinstance(module, Gemma3nTextAltUp): - module.correct_output_scale.zero_() + init.zeros_(module.correct_output_scale) class Gemma3nRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 7a58d2fc6313..375bd93f2723 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -21,6 +21,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig, layer_type_validation @@ -1879,11 +1880,11 @@ class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Gemma3nAudioCumulativeGroupNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, Gemma3nAudioAttention): - module.per_dim_scale.zero_() + init.zeros_(module.per_dim_scale) elif isinstance(module, Gemma3nTextAltUp): - module.correct_output_scale.zero_() + init.zeros_(module.correct_output_scale) @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.") diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 24ce421e1d5e..66f11594e27c 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -23,6 +23,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -392,20 +393,21 @@ class GitPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, GitVisionEmbeddings): - nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range) - nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range) - nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range) + init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range) + init.normal_(module.patch_embedding.weight, std=self.config.initializer_range) + init.normal_(module.position_embedding.weight, std=self.config.initializer_range) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index de56ee2ad2a7..1c6575f3420a 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -496,7 +497,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Glm4MoeTopkRouter): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 6cff89f8a37b..692ccd2ea9fa 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -28,6 +28,7 @@ import torch.nn.functional as F from torch.nn import LayerNorm +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -557,7 +558,7 @@ class Glm4vMoePreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Glm4vMoeTextTopkRouter): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @dataclass diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 4255ae22f47f..42059a73bc4e 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -389,21 +389,6 @@ class GLPNPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class GLPNModel(GLPNPreTrainedModel): diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 578fff824817..ad07a2d61e6d 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -30,6 +30,7 @@ from transformers.utils.generic import check_model_inputs +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -291,11 +292,11 @@ def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: - module.rel_pos_h.zero_() - module.rel_pos_w.zero_() + init.zeros_(module.rel_pos_h) + init.zeros_(module.rel_pos_w) elif isinstance(module, GotOcr2VisionEncoder): if module.pos_embed is not None: - module.pos_embed.zero_() + init.zeros_(module.pos_embed) @dataclass diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 9312ed42ff38..36f55a80583d 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig from ...modeling_utils import PreTrainedModel @@ -294,11 +295,11 @@ def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: - module.rel_pos_h.zero_() - module.rel_pos_w.zero_() + init.zeros_(module.rel_pos_h) + init.zeros_(module.rel_pos_w) elif isinstance(module, GotOcr2VisionEncoder): if module.pos_embed is not None: - module.pos_embed.zero_() + init.zeros_(module.pos_embed) class GotOcr2Model(LlavaModel): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 824a781c5b58..d2179fa0c335 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN, get_activation from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -484,16 +485,17 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale @@ -505,7 +507,7 @@ def _init_weights(self, module): for name, p in module.named_parameters(): if name == "c_proj.weight": # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) @dataclass diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index ce2b34e775c3..15080b672521 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -365,6 +366,7 @@ def __init__(self, *inputs, **kwargs): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" + super()._init_weights(module) if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale @@ -372,21 +374,9 @@ def _init_weights(self, module): # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - module.c_proj.weight.normal_( - mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + init.normal_( + module.c_proj.weight, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer) ) - module.c_proj._is_hf_initialized = True - elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index c591ef2ec914..84bf9f84b11f 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -381,24 +381,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _can_compile_fullgraph = False # TODO: needs a hybrid cache - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear,)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class GPTNeoModel(GPTNeoPreTrainedModel): diff --git 
a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index a906004dd41e..f723defcd088 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -47,26 +48,15 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): base_model_prefix = "gpt_neox_japanese" _no_split_modules = ["GPTNeoXJapaneseLayer"] _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, GPTNeoXJapaneseAttention): + super()._init_weights(module) + if isinstance(module, GPTNeoXJapaneseAttention): if module.dense_bias is not None: - module.dense_bias.zero_() + init.zeros_(module.dense_bias) # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoXJapanese diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 11e323544806..8e1ce9df0b97 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import functional as F +from ... 
import initialization as init from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations.hub_kernels import use_kernel_forward_from_hub @@ -442,29 +443,18 @@ class GptOssPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): + super()._init_weights(module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Parameter): - module.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, GptOssRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, GptOssExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.gate_up_proj_bias.zero_() - module.down_proj.normal_(mean=0.0, std=std) - module.down_proj_bias.zero_() + if isinstance(module, GptOssExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.zeros_(module.gate_up_proj_bias) + init.normal_(module.down_proj, mean=0.0, std=std) + init.zeros_(module.down_proj_bias) elif isinstance(module, GptOssAttention): - module.sinks.normal_(mean=0.0, std=std) + init.normal_(module.sinks, mean=0.0, std=std) elif isinstance(module, GptOssTopKRouter): - module.weight.normal_(mean=0.0, std=std) - module.bias.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) + init.normal_(module.bias, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 4f33517001b3..57acfea8df64 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -19,6 +19,7 @@ from torch import nn from torch.nn import functional as F +from ... 
import initialization as init from ...cache_utils import Cache, DynamicCache from ...integrations.hub_kernels import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask @@ -26,7 +27,7 @@ MoeModelOutputWithPast, ) from ...modeling_rope_utils import dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -358,29 +359,18 @@ class GptOssPreTrainedModel(LlamaPreTrainedModel): @torch.no_grad() def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Parameter): - module.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, GptOssRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, GptOssExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.gate_up_proj_bias.zero_() - module.down_proj.normal_(mean=0.0, std=std) - module.down_proj_bias.zero_() + if isinstance(module, GptOssExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.zeros_(module.gate_up_proj_bias) + init.normal_(module.down_proj, mean=0.0, std=std) + init.zeros_(module.down_proj_bias) elif isinstance(module, GptOssAttention): - module.sinks.normal_(mean=0.0, std=std) + init.normal_(module.sinks, mean=0.0, std=std) elif isinstance(module, GptOssTopKRouter): - module.weight.normal_(mean=0.0, std=std) - module.bias.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) + init.normal_(module.bias, mean=0.0, std=std) class GptOssModel(MixtralModel): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 8d8004577e57..2c04b1af0c67 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -444,24 +444,6 @@ class GPTJPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_param_buffer_assignment = False - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear,)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class GPTJModel(GPTJPreTrainedModel): diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 07e7c2573e99..7aad1523def9 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput @@ -289,21 +290,9 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, GraniteSpeechEncoderProjector): - module.query.normal_() + super()._init_weights(module) + if isinstance(module, GraniteSpeechEncoderProjector): + init.normal_(module.query) @auto_docstring( diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 0b3a893b9883..f722ad416a2f 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import functional as F +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -465,7 +466,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeParallelExperts): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/granitemoe/modular_granitemoe.py b/src/transformers/models/granitemoe/modular_granitemoe.py index 53692da91773..94b79494a0f2 100644 --- a/src/transformers/models/granitemoe/modular_granitemoe.py +++ b/src/transformers/models/granitemoe/modular_granitemoe.py @@ -18,6 +18,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask @@ -152,7 +153,7 @@ class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, GraniteMoeParallelExperts): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index dc39370b7559..760e7372d220 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -28,6 +28,7 @@ from transformers.activations import ACT2FN +from ... 
import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub @@ -1205,13 +1206,13 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeHybridParallelExperts): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if isinstance(module, GraniteMoeHybridMambaLayer): - module.dt_bias.fill_(1.0) - module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) - module.D.fill_(1.0) + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1))) + init.ones_(module.D) elif isinstance(module, GraniteMoeHybridRMSNormGated): - module.weight.fill_(1.0) + init.ones_(module.weight) @auto_docstring diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index ed0676752fbc..6ee9999a0afb 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... import initialization as init from ...cache_utils import Cache from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast @@ -180,11 +181,11 @@ class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeHybridMambaLayer): - module.dt_bias.fill_(1.0) - module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) - module.D.fill_(1.0) + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1))) + init.ones_(module.D) elif isinstance(module, GraniteMoeHybridRMSNormGated): - module.weight.fill_(1.0) + init.ones_(module.weight) class GraniteMoeHybridModel(GraniteMoeSharedModel): diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index d2f228d0f197..606a59390e6e 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import functional as F +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -471,7 +472,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeSharedParallelExperts): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) class GraniteMoeSharedRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 5333f222fb39..d1e8ea3e9fb8 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...file_utils import ModelOutput, is_timm_available, requires_backends from ...integrations import use_kernel_forward_from_hub @@ -1374,10 +1375,10 @@ def _init_weights(self, module): std = self.config.init_std if isinstance(module, GroundingDinoLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) elif isinstance(module, GroundingDinoMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -1390,50 +1391,51 @@ def _init_weights(self, module): ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, GroundingDinoBiMultiHeadAttention): - nn.init.xavier_uniform_(module.vision_proj.weight) - module.vision_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.text_proj.weight) - module.text_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.values_vision_proj.weight) - module.values_vision_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.values_text_proj.weight) - module.values_text_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.out_vision_proj.weight) - module.out_vision_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.out_text_proj.weight) - module.out_text_proj.bias.fill_(0) + init.xavier_uniform_(module.vision_proj.weight) + init.zeros_(module.vision_proj.bias) + init.xavier_uniform_(module.text_proj.weight) + init.zeros_(module.text_proj.bias) + init.xavier_uniform_(module.values_vision_proj.weight) + init.zeros_(module.values_vision_proj.bias) + init.xavier_uniform_(module.values_text_proj.weight) + init.zeros_(module.values_text_proj.bias) + init.xavier_uniform_(module.out_vision_proj.weight) + init.zeros_(module.out_vision_proj.bias) + init.xavier_uniform_(module.out_text_proj.weight) + init.zeros_(module.out_text_proj.bias) elif isinstance(module, GroundingDinoFusionLayer): - module.vision_param.fill_(1e-4) - module.text_param.fill_(1e-4) + init.constant_(module.vision_param, 1e-4) + init.constant_(module.text_param, 1e-4) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif 
isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, GroundingDinoMLPPredictionHead): - nn.init.constant_(module.layers[-1].weight, 0) - nn.init.constant_(module.layers[-1].bias, 0) + init.constant_(module.layers[-1].weight, 0) + init.constant_(module.layers[-1].bias, 0) if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) - nn.init.constant_(module.reference_points.bias, 0.0) + init.xavier_uniform_(module.reference_points.weight, gain=1.0) + init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): - nn.init.normal_(module.level_embed) + init.normal_(module.level_embed) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, GroundingDinoDecoder): diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 0c51c9052afc..7e0877456e60 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -754,31 +755,31 @@ def _init_weights(self, module): init_range = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=init_range) + init.normal_(module.weight, mean=0.0, std=init_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) factor = self.config.initializer_factor if isinstance(module, GroupViTTextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, GroupViTAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, GroupViTMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** 
-0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) class GroupViTVisionEncoder(nn.Module): diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index 85cfa57ca7d8..54a6931a6431 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -782,20 +783,20 @@ def _init_weights(self, module) -> None: std = self.config.initializer_range if isinstance(module, HieraEmbeddings): - nn.init.trunc_normal_(module.position_embeddings, std=std) + init.trunc_normal_(module.position_embeddings, std=std) elif isinstance(module, HieraDecoder): - nn.init.trunc_normal_(module.mask_token, std=std) - nn.init.trunc_normal_(module.decoder_position_embeddings, std=std) + init.trunc_normal_(module.mask_token, std=std) + init.trunc_normal_(module.decoder_position_embeddings, std=std) elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): - nn.init.trunc_normal_(module.weight, std=std) + init.trunc_normal_(module.weight, std=std) if module.bias is not None: - nn.init.constant_(module.bias, std) + init.constant_(module.bias, std) elif isinstance(module, nn.LayerNorm): - nn.init.constant_(module.bias, std) - nn.init.constant_(module.weight, self.config.layer_norm_init) + init.constant_(module.bias, std) + init.constant_(module.weight, self.config.layer_norm_init) class HieraPooler(nn.Module): diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 9967c09b2fbe..614492f23ae0 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -27,6 +27,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -641,33 +642,33 @@ class HubertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, HubertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.uniform_() + init.uniform_(module.masked_spec_embed) elif isinstance(module, HubertForSequenceClassification): if hasattr(module, "layer_weights"): - module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) + init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index a906db997877..bc8aadb25e27 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput @@ -138,33 +139,33 @@ class HubertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, HubertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.uniform_() + init.uniform_(module.masked_spec_embed) elif isinstance(module, HubertForSequenceClassification): if hasattr(module, "layer_weights"): - module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) + init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index b55d9e3ccf5e..556e91b453ea 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -290,18 +290,6 @@ class HunYuanDenseV1PreTrainedModel(PreTrainedModel): "attentions": HunYuanDenseV1Attention, } - @torch.no_grad() - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - class HunYuanDenseV1RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py index 945d2d1c27b1..d41b5f236759 100644 --- a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py @@ -120,17 +120,7 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): class HunYuanDenseV1PreTrainedModel(LlamaPreTrainedModel): - @torch.no_grad() - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - 
elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + pass class HunYuanDenseV1RotaryEmbedding(LlamaRotaryEmbedding): diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index d62058cb7ab9..1d2fb7d936ca 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import gelu from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -589,18 +590,19 @@ class IBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (QuantLinear, nn.Linear)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (QuantEmbedding, nn.Embedding)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, (IntLayerNorm, nn.LayerNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, IBertLMHead): - module.bias.zero_() + init.zeros_(module.bias) def resize_token_embeddings(self, new_num_tokens=None): raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.") diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 1e7fdb05360c..cb57a0fd8d20 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -836,34 +837,21 @@ def _init_weights(self, module): # important: this ported version of Idefics isn't meant for training from scratch - only # inference and fine-tuning - so the proper init weights code has been removed - the m4 code # base should be used for training from scratch and it contains the correct code. 
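The Idefics hunk here follows the same pattern as the rest of this patch: the hand-written nn.Linear / nn.Embedding / nn.LayerNorm branches are deleted, the generic cases are delegated to `super()._init_weights(module)`, and only the model-specific tensors are touched through the `init.*` helpers. As a rough sketch (the actual `transformers.initialization` module is not reproduced in these hunks), such a helper presumably wraps the corresponding in-place torch op, skips tensors already flagged as `_is_hf_initialized`, and records the flag afterwards; that would explain why GPT-BigCode no longer sets `module.c_proj._is_hf_initialized = True` by hand, and why the embedding `padding_idx` branches need an explicit flag check once the weight has been sliced.

import torch

def normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
    # Illustrative sketch only: the flag name and the call signature mirror how the
    # helpers are used in the surrounding hunks, not the real implementation.
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor  # already filled in by weight loading/conversion, skip
    with torch.no_grad():
        tensor.normal_(mean=mean, std=std)
    tensor._is_hf_initialized = True  # assumed: the helper records the flag itself
    return tensor

A typical `_init_weights` after this refactor then reduces to `super()._init_weights(module)` plus a handful of `init.*` calls for the model-specific parameters, as in the Granite, JetMoe and Idefics hunks.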
- std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, IdeficsRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, IdeficsVisionEmbeddings): - module.class_embedding.normal_() + super()._init_weights(module) + if isinstance(module, IdeficsVisionEmbeddings): + init.normal_(module.class_embedding) elif isinstance(module, IdeficsGatedCrossAttentionLayer): if self.config.alpha_initializer == "zeros": - module.alpha_cross_attn.zero_() - module.alpha_dense.zero_() + init.zeros_(module.alpha_cross_attn) + init.zeros_(module.alpha_dense) elif self.config.alpha_initializer == "ones": - module.alpha_cross_attn.fill_(1.0) - module.alpha_dense.fill_(1.0) + init.ones_(module.alpha_cross_attn) + init.ones_(module.alpha_dense) elif self.config.alpha_initializer in {"normal", "gaussian", "random"}: - module.alpha_cross_attn.normal_(mean=0.0, std=self.config.alphas_initializer_range) - module.alpha_dense.normal_(mean=0.0, std=self.config.alphas_initializer_range) + init.normal_(module.alpha_cross_attn, mean=0.0, std=self.config.alphas_initializer_range) + init.normal_(module.alpha_dense, mean=0.0, std=self.config.alphas_initializer_range) elif isinstance(module, IdeficsPerceiverResampler): - module.latents.normal_() + init.normal_(module.latents) @auto_docstring diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 2caaf2ab2706..dcfa6d9cd23b 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -419,27 +420,11 @@ class Idefics2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, Idefics2RMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, nn.MultiheadAttention): - module._reset_parameters() # native torch init - elif isinstance(module, Idefics2MultiheadAttentionPoolingHead): - module.probe.normal_() + super()._init_weights(module) + if isinstance(module, Idefics2MultiheadAttentionPoolingHead): + init.normal_(module.probe) elif isinstance(module, Idefics2PerceiverResampler): - module.latents.fill_(1.0) + init.ones_(module.latents) @auto_docstring( diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 6a57af9d49d8..38d6f29c3f04 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -430,27 +430,8 @@ class Idefics3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_attention_backend = True - @torch.no_grad() - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, Idefics3RMSNorm): - module.weight.fill_(1.0) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index a8c5878f35ef..c7f2fd39f824 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput @@ -328,28 +329,16 @@ class IJepaPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, IJepaEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) class IJepaEncoder(nn.Module): diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index 095945a3f39d..9ddde3f87e4b 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -5,6 +5,7 @@ from transformers.models.ijepa.configuration_ijepa import IJepaConfig +from ... 
import initialization as init from ...modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, torch_int @@ -91,28 +92,16 @@ class IJepaPreTrainedModel(ViTPreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, IJepaEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) class IJepaModel(IJepaPreTrainedModel, ViTModel): diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index b4c844eb4f49..c5effca80166 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -366,22 +367,10 @@ class ImageGPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ImageGPTBlock"] - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, ImageGPTLayerNorm): - module.weight.fill_(1.0) + super()._init_weights(module) # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale @@ -393,7 +382,7 @@ def _init_weights(self, module): for name, p in module.named_parameters(): if "c_proj" in name and "weight" in name: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) @auto_docstring diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index a8f618a43b69..dcb05efb752f 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -26,6 +26,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...masking_utils import create_bidirectional_mask, create_causal_mask @@ -203,9 +204,9 @@ class InformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) + super().__init__(num_positions, embedding_dim, _freeze=True) - def _init_weight(self): + def create_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] @@ -218,7 +219,7 @@ def _init_weight(self): sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - self.weight = nn.Parameter(out, requires_grad=False) + return out @torch.no_grad() def forward( @@ -254,7 +255,7 @@ class InformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, InformerSinusoidalPositionalEmbedding): - module._init_weight() + init.copy_(module.weight, module.create_weight()) def eager_attention_forward( diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 0066f41a3e47..7c5a4e85f392 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... import initialization as init from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -90,7 +91,7 @@ class InformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, InformerSinusoidalPositionalEmbedding): - module._init_weight() + init.copy_(module.weight, module.create_weight()) class InformerAttention(BartAttention): diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 25b54f2d2b9f..5432283b4b0c 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -327,22 +328,13 @@ class InstructBlipPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) factor = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=factor) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor) - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, InstructBlipVisionEmbeddings): - nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) - nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + if isinstance(module, InstructBlipVisionEmbeddings): + init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)): - module.query_tokens.zero_() + init.zeros_(module.query_tokens) # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index f48baf11b925..0375ec5042cf 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -150,22 +151,13 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) factor = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=factor) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor) - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, InstructBlipVideoVisionEmbeddings): - nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) - nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + if isinstance(module, InstructBlipVideoVisionEmbeddings): + init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)): - module.query_tokens.zero_() + init.zeros_(module.query_tokens) # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32 diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 3d41bb9aba32..0ec57c60a20c 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -28,6 +28,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -416,14 +417,14 @@ def _init_weights(self, module): """Initialize the weights""" super()._init_weights(module) if isinstance(module, InternVLVisionEmbeddings): - module.cls_token.zero_() + init.zeros_(module.cls_token) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, InternVLVisionLayer): - module.lambda_1.fill_(self.config.layer_scale_init_value) - module.lambda_2.fill_(self.config.layer_scale_init_value) + init.constant_(module.lambda_1, self.config.layer_scale_init_value) + init.constant_(module.lambda_2, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 62ee383ce566..f276a49ec5ce 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_layers import GradientCheckpointingLayer @@ -373,14 +374,14 @@ def _init_weights(self, module): """Initialize the weights""" super()._init_weights(module) if isinstance(module, InternVLVisionEmbeddings): - module.cls_token.zero_() + init.zeros_(module.cls_token) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, InternVLVisionLayer): - module.lambda_1.fill_(self.config.layer_scale_init_value) - module.lambda_2.fill_(self.config.layer_scale_init_value) + init.constant_(module.lambda_1, self.config.layer_scale_init_value) + init.constant_(module.lambda_2, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 609fff07ab80..450c06d165a9 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -29,6 +29,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub @@ -728,8 +729,8 @@ def _init_weights(self, module): if isinstance(module, JambaMambaMixer): A = torch.arange(1, module.ssm_state_size + 1)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer} diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 1c362c3f802a..4cbe69d0e1b7 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -23,6 +23,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -613,8 +614,8 @@ def _init_weights(self, module): if isinstance(module, JambaMambaMixer): A = torch.arange(1, module.ssm_state_size + 1)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) @auto_docstring diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 28a3dc151d70..afa139d66f49 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import functional as F +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -585,20 +586,11 @@ class JetMoePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, JetMoeRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, JetMoeParallelExperts): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + super()._init_weights(module) + if isinstance(module, JetMoeParallelExperts): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, JetMoeMoA | JetMoeMoE): - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/jetmoe/modular_jetmoe.py b/src/transformers/models/jetmoe/modular_jetmoe.py index 82c8e582d070..41babe708c63 100644 --- a/src/transformers/models/jetmoe/modular_jetmoe.py +++ b/src/transformers/models/jetmoe/modular_jetmoe.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import functional as F +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -30,7 +31,7 @@ GenericForSequenceClassification, ) from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import OutputRecorder, check_model_inputs @@ -438,20 +439,11 @@ class JetMoePreTrainedModel(MixtralPreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, JetMoeRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, JetMoeParallelExperts): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + PreTrainedModel._init_weights(self, module) + if isinstance(module, JetMoeParallelExperts): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, JetMoeMoA | JetMoeMoE): - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 5726eeacaad6..3d2c64865201 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -1134,44 +1135,44 @@ def _init_weights(self, module: nn.Module): std = self.config.text_config.init_std if isinstance(module, Kosmos2VisionEmbeddings): - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, Kosmos2VisionAttention): in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, Kosmos2VisionMLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, KosmosTextAttention): - nn.init.normal_(module.q_proj.weight, std=std) - nn.init.normal_(module.k_proj.weight, std=std) - nn.init.normal_(module.v_proj.weight, std=std) - nn.init.normal_(module.out_proj.weight, std=std) + init.normal_(module.q_proj.weight, std=std) + init.normal_(module.k_proj.weight, std=std) + init.normal_(module.v_proj.weight, std=std) + init.normal_(module.out_proj.weight, std=std) elif isinstance(module, Kosmos2TextFFN): - nn.init.normal_(module.fc1.weight, std=std) - nn.init.normal_(module.fc2.weight, std=std) + init.normal_(module.fc1.weight, std=std) + init.normal_(module.fc2.weight, std=std) elif isinstance(module, Kosmos2TextForCausalLM): - nn.init.normal_(module.lm_head.weight, std=std) + init.normal_(module.lm_head.weight, std=std) elif isinstance(module, Kosmos2ImageToTextProjection): - nn.init.normal_(module.dense.weight, std=std) - nn.init.normal_(module.latent_query) + init.normal_(module.dense.weight, std=std) + init.normal_(module.latent_query) elif isinstance(module, Kosmos2TextTransformer): - module.embed_tokens.weight.normal_(mean=0.0, std=std) + init.normal_(module.embed_tokens.weight, mean=0.0, std=std) if module.embed_tokens.padding_idx is not None: - module.embed_tokens.weight[module.embed_tokens.padding_idx].zero_() + init.zeros_(module.embed_tokens.weight[module.embed_tokens.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + 
init.zeros_(module.bias) class Kosmos2VisionModel(Kosmos2PreTrainedModel): diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py index c0313f33eca2..a9a456d68d15 100644 --- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -1238,19 +1239,20 @@ def _init_weights(self, module): elif isinstance(self, (Kosmos2_5Model, Kosmos2_5ForConditionalGeneration)): std = self.config.text_config.init_std if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the explicit check, as slicing the weight in the `zeros_` call loses the `_is_hf_initialized` flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, (nn.LayerNorm, Kosmos2_5LayerNorm)): - module.weight.fill_(1.0) + init.ones_(module.weight) if getattr(module, "bias", None) is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, Kosmos2_5ImageToTextProjection): - module.latent_query.normal_(mean=0.0, std=1.0) + init.normal_(module.latent_query, mean=0.0, std=1.0) class Kosmos2_5VisionModel(Kosmos2_5PreTrainedModel): diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 10d4b4c7eb57..af6ffaf04674 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -27,6 +27,7 @@ import torch import torch.nn as nn +from ...
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationConfig, GenerationMixin @@ -54,25 +55,6 @@ logger = logging.get_logger(__name__) -class KyutaiSpeechToTextRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) # Ignore copy - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - # Ignore copy - def forward(self, x): - output = self._norm(x.float()) - output = output * self.weight.float() - return output.type_as(x) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" - - class KyutaiSpeechToTextFlexibleLinear(nn.Module): def __init__(self, input_size, output_size, num_layers): super().__init__() @@ -126,20 +108,9 @@ class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, KyutaiSpeechToTextFlexibleLinear): - module.weight.normal_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, KyutaiSpeechToTextRMSNorm): - module.weight.fill_(1.0) + super()._init_weights(module) + if isinstance(module, KyutaiSpeechToTextFlexibleLinear): + init.normal_(module.weight) class KyutaiSpeechToTextConv1dPaddingCache: @@ -251,6 +222,25 @@ def forward(self, input_ids): return inputs_embeds +class KyutaiSpeechToTextRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) # Ignore copy + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + # Ignore copy + def forward(self, x): + output = self._norm(x.float()) + output = output * self.weight.float() + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + class KyutaiSpeechToTextLinear(nn.Module): def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False): super().__init__() diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 146c395aa9ee..57537beb1f47 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -427,19 +428,9 @@ class LayoutLMPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, LayoutLMLayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, LayoutLMLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, LayoutLMLMPredictionHead): + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index e276407a720b..c98b84ddd39c 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -461,24 +462,14 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, LayoutLMv2SelfAttention): + super()._init_weights(module) + if isinstance(module, LayoutLMv2SelfAttention): if self.config.fast_qkv: - module.q_bias.zero_() - module.v_bias.zero_() + init.zeros_(module.q_bias) + init.zeros_(module.v_bias) elif isinstance(module, LayoutLMv2Model): if hasattr(module, "visual_segment_embedding"): - module.visual_segment_embedding.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.visual_segment_embedding, mean=0.0, std=self.config.initializer_range) def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index a04875e72646..bb020f1614ab 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -206,21 +207,11 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, LayoutLMv3Model): + super()._init_weights(module) + if isinstance(module, LayoutLMv3Model): if self.config.visual_embed: - module.cls_token.zero_() - module.pos_embed.zero_() + init.zeros_(module.cls_token) + init.zeros_(module.pos_embed) class LayoutLMv3SelfAttention(nn.Module): @@ -600,7 +591,7 @@ def __init__(self, config): self.encoder = LayoutLMv3Encoder(config) - self.init_weights() + self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings @@ -890,7 +881,7 @@ def __init__(self, config): else: self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False) - self.init_weights() + self.post_init() @auto_docstring def forward( @@ -989,7 +980,7 @@ def __init__(self, config): self.layoutlmv3 = LayoutLMv3Model(config) self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False) - self.init_weights() + self.post_init() @auto_docstring def forward( @@ -1108,7 +1099,7 @@ def __init__(self, config): self.layoutlmv3 = LayoutLMv3Model(config) self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False) - self.init_weights() + self.post_init() @auto_docstring def forward( diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 418f60f77a61..40d6b82c9ede 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1067,18 +1067,6 @@ class LEDPreTrainedModel(PreTrainedModel): base_model_prefix = "led" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index ca7cc7589be7..250070b44b8f 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -472,17 +472,6 @@ class LevitPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["LevitResidualLayer"] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class 
LevitModel(LevitPreTrainedModel): diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index ec924e5000d6..78e0546bd51e 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -500,21 +500,6 @@ class LiltPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class LiltModel(LiltPreTrainedModel): diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index b558135f2ef4..355a4f145779 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -24,6 +24,7 @@ from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -474,30 +475,18 @@ class Llama4PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): + super()._init_weights(module) std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, Llama4TextRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, Llama4TextExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + if isinstance(module, Llama4TextExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, Llama4VisionModel): - module.class_embedding.normal_(std=module.scale) - module.positional_embedding_vlm.normal_(std=module.scale) + init.normal_(module.class_embedding, std=module.scale) + init.normal_(module.positional_embedding_vlm, std=module.scale) @auto_docstring diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index d494e0400f2a..cd155dbce452 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -240,12 +241,12 @@ def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, LlavaNextModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.normal_(mean=0.0, std=embed_std) + init.normal_(module.image_newline, mean=0.0, std=embed_std) @auto_docstring( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index e4bb765e4a2a..34feba5e18c7 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -181,12 +182,12 @@ def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, LlavaNextVideoModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.normal_(mean=0.0, std=embed_std) + init.normal_(module.image_newline, mean=0.0, std=embed_std) def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 193ab3a2ea04..1bd91aee4715 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -122,12 +123,12 @@ def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, LlavaOnevisionModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.normal_(mean=0.0, std=embed_std) + init.normal_(module.image_newline, mean=0.0, std=embed_std) class LlavaOnevisionMultiModalProjector(nn.Module): diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 516bfee99677..8c6b9add8cdb 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -561,10 +562,10 @@ class LongcatFlashPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, LongcatFlashTopkRouter): - module.classifier.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.classifier.weight, mean=0.0, std=self.config.initializer_range) if isinstance(module, LongcatFlashExperts): - module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) - module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/longcat_flash/modular_longcat_flash.py b/src/transformers/models/longcat_flash/modular_longcat_flash.py index 6a9148dab617..9e4b6486631b 100644 --- a/src/transformers/models/longcat_flash/modular_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modular_longcat_flash.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask @@ -345,10 +346,10 @@ class LongcatFlashPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, LongcatFlashTopkRouter): - module.classifier.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.classifier.weight, mean=0.0, std=self.config.initializer_range) if isinstance(module, LongcatFlashExperts): - module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) - module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) class LongcatFlashModel(DeepseekV3Model): diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 1168e9366f1d..f1fbe1658084 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1292,21 +1292,6 @@ class LongformerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LongformerSelfAttention"] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class LongformerModel(LongformerPreTrainedModel): diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 0aea13dc01b8..df716432bc26 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import 
CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -1181,40 +1182,42 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, LongT5LayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, LongT5DenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, LongT5DenseGatedActDense): - module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) if 
isinstance(module, LongT5TransientGlobalAttention): - module.global_relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_( + module.global_relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5) + ) # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 79b63ac33d86..fc6fd7a80810 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN, gelu from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -770,19 +771,20 @@ class LukePreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): if module.embedding_dim == 1: # embedding for bias parameters - module.weight.zero_() + init.zeros_(module.weight) else: - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the explicit check, as slicing the weight in the `zeros_` call loses the `_is_hf_initialized` flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring( diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 69fc0eb1b71a..75d31c0399b3 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, SmoothL1Loss +from ...
import initialization as init from ...activations import ACT2FN, gelu from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -677,19 +678,9 @@ class LxmertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, LxmertLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, LxmertLMPredictionHead): + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 60f41cd6ad00..676938c6789f 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -512,25 +512,9 @@ class M2M100PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model _can_compile_fullgraph = False - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - class M2M100Encoder(M2M100PreTrainedModel): """ diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index f17bd66649af..7437082349ac 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin @@ -513,14 +514,14 @@ def _init_weights(self, module): # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale if self.config.time_step_init_scheme == "constant": - nn.init.constant_(module.dt_proj.weight, dt_init_std) + init.constant_(module.dt_proj.weight, dt_init_std) elif self.config.time_step_init_scheme == "random": - nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) dt = torch.exp( torch.rand(self.config.intermediate_size) @@ -529,14 +530,12 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_proj.bias.copy_(inv_dt) - module.dt_proj.bias._no_reinit = True + init.copy_(module.dt_proj.bias, inv_dt) - nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) if module.conv1d.bias is not None: - if not getattr(module.conv1d.bias, "_no_reinit", False): - nn.init.zeros_(module.conv1d.bias) - nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) + init.zeros_(module.conv1d.bias) + init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) if self.config.rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: @@ -553,15 +552,13 @@ def _init_weights(self, module): p /= math.sqrt(self.config.num_hidden_layers) if isinstance(module, nn.Linear): - if not getattr(module.weight, "_no_reinit", False): - nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, MambaRMSNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) @dataclass diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 716f62e5d1b1..d2e7add0d4f3 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer @@ -725,8 +726,8 @@ def _init_weights(self, module): # S4D real initialization. These are not discretized! # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded A = torch.arange(1, self.config.num_heads + 1) - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) dt = torch.exp( torch.rand(self.config.num_heads) @@ -736,14 +737,12 @@ def _init_weights(self, module): # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_bias.copy_(inv_dt) - module.dt_bias._no_reinit = True + init.copy_(module.dt_bias, inv_dt) - nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) if module.conv1d.bias is not None: - if not getattr(module.conv1d.bias, "_no_reinit", False): - nn.init.zeros_(module.conv1d.bias) - nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) + init.zeros_(module.conv1d.bias) + init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) if self.config.rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: @@ -760,15 +759,13 @@ def _init_weights(self, module): p /= math.sqrt(self.config.num_hidden_layers) if isinstance(module, nn.Linear): - if not getattr(module.weight, "_no_reinit", False): - nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, (Mamba2RMSNorm, MambaRMSNormGated)): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=std) + init.normal_(module.weight, std=std) @dataclass diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 11adf1cdbe20..3a211ddbe715 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -71,9 +72,9 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) + super().__init__(num_positions, embedding_dim, _freeze=True) - def _init_weight(self): + def create_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] @@ -86,7 +87,7 @@ def _init_weight(self): sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - self.weight = nn.Parameter(out, requires_grad=False) + return out @torch.no_grad() def forward( @@ -446,21 +447,10 @@ class MarianPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True @torch.no_grad() - def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, MarianSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, MarianSinusoidalPositionalEmbedding): + init.copy_(module.weight, module.create_weight()) @property def dummy_inputs(self): diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 60be191c8285..c5daabe2abde 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -513,19 +514,9 @@ class MarkupLMPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, MarkupLMLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, MarkupLMLMPredictionHead): + init.zeros_(module.bias) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 24b1d1078b82..5ac2b4a5f2ea 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn +from ... 
import initialization as init from ...activations import ACT2FN from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_layers import GradientCheckpointingLayer @@ -2111,11 +2112,11 @@ def _init_weights(self, module: nn.Module): if module.input_projections is not None: for input_projection in module.input_projections: if not isinstance(input_projection, nn.Sequential): - nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) - nn.init.constant_(input_projection.bias, 0) + init.xavier_uniform_(input_projection.weight, gain=xavier_std) + init.constant_(input_projection.bias, 0) elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( @@ -2125,42 +2126,43 @@ def _init_weights(self, module: nn.Module): ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, Mask2FormerMaskedAttentionDecoderLayer): for p in module.parameters(): if p.dim() > 1: - nn.init.xavier_uniform_(p, gain=xavier_std) - module.cross_attn.in_proj_bias.zero_() + init.xavier_uniform_(p, gain=xavier_std) + init.zeros_(module.cross_attn.in_proj_bias) elif isinstance(module, Mask2FormerPixelDecoder): - nn.init.normal_(module.level_embed, std=0) + init.normal_(module.level_embed, std=0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the explicit check, as slicing the weight in the `zeros_` call loses the `_is_hf_initialized` flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if hasattr(module, "reference_points"): - nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) - nn.init.constant_(module.reference_points.bias, 0.0) + init.xavier_uniform_(module.reference_points.weight, gain=1.0) + init.constant_(module.reference_points.bias, 0.0) @auto_docstring diff --git 
a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index b2dc868f0138..22a29fe96b70 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -1442,37 +1443,38 @@ def _init_weights(self, module: nn.Module): std = self.config.init_std if isinstance(module, MaskFormerTransformerModule): if module.input_projection is not None: - nn.init.xavier_uniform_(module.input_projection.weight, gain=xavier_std) - nn.init.constant_(module.input_projection.bias, 0) + init.xavier_uniform_(module.input_projection.weight, gain=xavier_std) + init.constant_(module.input_projection.bias, 0) # FPN elif isinstance(module, MaskFormerFPNModel): - nn.init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std) + init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std) elif isinstance(module, MaskFormerFPNLayer): - nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std) + init.xavier_uniform_(module.proj[0].weight, gain=xavier_std) elif isinstance(module, MaskFormerFPNConvLayer): - nn.init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std) + init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std) # The MLP head elif isinstance(module, MaskformerMLPPredictionHead): # I was not able to find the correct initializer in the original implementation # we'll use xavier for submodule in module.modules(): if isinstance(submodule, nn.Linear): - nn.init.xavier_uniform_(submodule.weight, gain=xavier_std) - nn.init.constant_(submodule.bias, 0) + init.xavier_uniform_(submodule.weight, gain=xavier_std) + init.constant_(submodule.bias, 0) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) # copied from DETR if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the explicit check, as slicing the weight in the `zeros_` call loses the `_is_hf_initialized` flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) @auto_docstring diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index b735b419c10d..05a58d4527df 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -24,6 +24,7 @@ import torch from torch import Tensor, nn +from ...
import initialization as init from ...activations import ACT2FN from ...file_utils import ModelOutput from ...modeling_layers import GradientCheckpointingLayer @@ -704,18 +705,12 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, MaskFormerSwinEmbeddings): + super()._init_weights(module) + if isinstance(module, MaskFormerSwinEmbeddings): if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, MaskFormerSwinSelfAttention): - module.relative_position_bias_table.zero_() + init.zeros_(module.relative_position_bias_table) class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): @@ -731,6 +726,8 @@ def __init__(self, config, add_pooling_layer=True): self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + self.post_init() + def get_input_embeddings(self): return self.embeddings.patch_embeddings diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 08cde27d7cce..b4a3ed8ebd02 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -476,24 +476,8 @@ class MBartPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index d7a869cfd89a..8dfea90bbae1 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -524,15 +525,9 @@ class MegatronBertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, MegatronBertLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, MegatronBertLMPredictionHead): + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/metaclip_2/modeling_metaclip_2.py b/src/transformers/models/metaclip_2/modeling_metaclip_2.py index c66bababfbe5..2407a4147fb1 100644 --- a/src/transformers/models/metaclip_2/modeling_metaclip_2.py +++ b/src/transformers/models/metaclip_2/modeling_metaclip_2.py @@ -11,6 +11,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer @@ -303,57 +304,57 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, MetaClip2TextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, MetaClip2VisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, MetaClip2Attention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, MetaClip2MLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + 
init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, MetaClip2Model): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) elif isinstance(module, MetaClip2VisionModelWithProjection): - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, MetaClip2TextModelWithProjection): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, MetaClip2ForImageClassification): - nn.init.normal_( + init.normal_( module.classifier.weight, std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, ) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class MetaClip2Encoder(nn.Module): diff --git a/src/transformers/models/metaclip_2/modular_metaclip_2.py b/src/transformers/models/metaclip_2/modular_metaclip_2.py index 79cdf35be7e9..49914ab1e6b3 100644 --- a/src/transformers/models/metaclip_2/modular_metaclip_2.py +++ b/src/transformers/models/metaclip_2/modular_metaclip_2.py @@ -3,6 +3,7 @@ import torch from torch import nn +from ... import initialization as init from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...processing_utils import Unpack @@ -222,57 +223,57 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, MetaClip2TextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, MetaClip2VisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, MetaClip2Attention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + 
init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, MetaClip2MLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, MetaClip2Model): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) elif isinstance(module, MetaClip2VisionModelWithProjection): - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, MetaClip2TextModelWithProjection): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, MetaClip2ForImageClassification): - nn.init.normal_( + init.normal_( module.classifier.weight, std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, ) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class MetaClip2TextTransformer(CLIPTextTransformer): diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 819d5d38fcc1..73176c800d1a 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -289,15 +290,15 @@ def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" std = self.config.initializer_range if isinstance(module, MgpstrEmbeddings): - nn.init.trunc_normal_(module.pos_embed, mean=0.0, std=std) - nn.init.trunc_normal_(module.cls_token, mean=0.0, std=std) + init.trunc_normal_(module.pos_embed, mean=0.0, std=std) + init.trunc_normal_(module.cls_token, mean=0.0, std=std) elif isinstance(module, (nn.Linear, nn.Conv2d)): - nn.init.trunc_normal_(module.weight, mean=0.0, std=std) + init.trunc_normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 46d5817bed76..7a367e05261a 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...masking_utils import create_causal_mask @@ -1388,19 +1389,19 @@ class MimiPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, MimiLayerScale): - module.scale.fill_(self.config.layer_scale_initial_scale) + init.constant_(module.scale, self.config.layer_scale_initial_scale) @auto_docstring( diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index b9d3a4ac0a29..004ed68cef23 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -607,10 +608,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, MiniMaxExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, MiniMaxTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index d1205fdf39cc..1faff1f4dcea 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -33,6 +33,7 @@ from transformers.utils.generic import check_model_inputs +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -414,10 +415,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, MixtralExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, MixtralTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index c6c4335ac2ef..fabb84db688e 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask @@ -275,10 +276,10 @@ def _init_weights(self, module): PreTrainedModel._init_weights(self, module) std = self.config.initializer_range if isinstance(module, MixtralExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, MixtralTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) class MixtralModel(MistralModel): diff --git a/src/transformers/models/mlcd/configuration_mlcd.py b/src/transformers/models/mlcd/configuration_mlcd.py index d4f79140c804..e1b15cc27192 100644 --- a/src/transformers/models/mlcd/configuration_mlcd.py +++ b/src/transformers/models/mlcd/configuration_mlcd.py @@ -18,7 +18,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ...configuration_utils import PreTrainedConfig diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index a4dd82865202..72e26db9bd1c 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -24,6 +24,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -421,31 +422,31 @@ def _init_weights(self, module): factor = self.config.initializer_factor if isinstance(module, MLCDVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, MLCDAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, MLCDMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, MLCDVisionTransformer): factor = self.config.initializer_factor pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor - nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) + init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class MLCDVisionTransformer(nn.Module): diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index e3a70b798496..4cfc743948ef 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn +from ... 
import initialization as init from ...configuration_utils import PreTrainedConfig from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -360,31 +361,31 @@ def _init_weights(self, module): factor = self.config.initializer_factor if isinstance(module, MLCDVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, MLCDAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, MLCDMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, MLCDVisionTransformer): factor = self.config.initializer_factor pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor - nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) + init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class MLCDVisionTransformer(CLIPVisionTransformer): diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 1f52d30be45c..f55131274194 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -818,32 +819,33 @@ def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, MllamaTextRMSNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, MllamaVisionModel): - nn.init.normal_(module.class_embedding, std=std) + init.normal_(module.class_embedding, std=std) elif isinstance(module, MllamaPrecomputedPositionEmbedding): - nn.init.normal_(module.embedding, std=std) - nn.init.zeros_(module.gate) + init.normal_(module.embedding, std=std) + init.zeros_(module.gate) elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated: - nn.init.normal_(module.gate_attn, std=std) - nn.init.normal_(module.gate_ffn, std=std) + init.normal_(module.gate_attn, std=std) + init.normal_(module.gate_ffn, std=std) elif isinstance(module, MllamaCrossAttentionDecoderLayer): - module.cross_attn_attn_gate.zero_() - module.cross_attn_mlp_gate.zero_() + init.zeros_(module.cross_attn_attn_gate) + init.zeros_(module.cross_attn_mlp_gate) elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding): if module.is_gated: - module.gate.zero_() + init.zeros_(module.gate) # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( diff --git a/src/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py index ec7f4af7fb4c..fe153b1622e5 100644 --- a/src/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py @@ -18,7 +18,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ...configuration_utils import PreTrainedConfig from ...utils import logging from ...utils.backbone_utils import verify_backbone_config_arguments diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index 583e40bb8cae..ec65fd185373 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ...
import initialization as init from ...activations import ACT2FN from ...file_utils import ModelOutput, is_timm_available, requires_backends from ...integrations import use_kernel_forward_from_hub @@ -511,10 +512,10 @@ def _init_weights(self, module): std = self.config.init_std if isinstance(module, MMGroundingDinoLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) elif isinstance(module, MMGroundingDinoMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -527,52 +528,53 @@ def _init_weights(self, module): ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, MMGroundingDinoBiMultiHeadAttention): - nn.init.xavier_uniform_(module.vision_proj.weight) - module.vision_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.text_proj.weight) - module.text_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.values_vision_proj.weight) - module.values_vision_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.values_text_proj.weight) - module.values_text_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.out_vision_proj.weight) - module.out_vision_proj.bias.fill_(0) - nn.init.xavier_uniform_(module.out_text_proj.weight) - module.out_text_proj.bias.fill_(0) + init.xavier_uniform_(module.vision_proj.weight) + init.zeros_(module.vision_proj.bias) + init.xavier_uniform_(module.text_proj.weight) + init.zeros_(module.text_proj.bias) + init.xavier_uniform_(module.values_vision_proj.weight) + init.zeros_(module.values_vision_proj.bias) + init.xavier_uniform_(module.values_text_proj.weight) + init.zeros_(module.values_text_proj.bias) + init.xavier_uniform_(module.out_vision_proj.weight) + init.zeros_(module.out_vision_proj.bias) + init.xavier_uniform_(module.out_text_proj.weight) + init.zeros_(module.out_text_proj.bias) elif isinstance(module, MMGroundingDinoFusionLayer): - module.vision_param.fill_(1e-4) - module.text_param.fill_(1e-4) + init.constant_(module.vision_param, 1e-4) + init.constant_(module.text_param, 1e-4) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) 
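# Note: the `init.*` helpers used throughout this diff are assumed (not shown here) to skip any
# tensor that already carries the `_is_hf_initialized` flag, e.g. a weight restored from a
# checkpoint. Indexing a tensor returns a fresh Tensor object that does not inherit that Python
# attribute, which is why the nn.Embedding branch below checks the flag on the full `module.weight`
# before zeroing the padding row. A minimal sketch of the assumed flag behaviour:
#
#     w = torch.empty(10, 4)
#     w._is_hf_initialized = True                 # e.g. set while loading pretrained weights
#     getattr(w, "_is_hf_initialized", False)     # True  -> guarded init helpers leave w alone
#     getattr(w[0], "_is_hf_initialized", False)  # False -> the sliced row no longer carries the flag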
elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, MMGroundingDinoMLPPredictionHead): - nn.init.constant_(module.layers[-1].weight, 0) - nn.init.constant_(module.layers[-1].bias, 0) + init.constant_(module.layers[-1].weight, 0) + init.constant_(module.layers[-1].bias, 0) if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) - nn.init.constant_(module.reference_points.bias, 0.0) + init.xavier_uniform_(module.reference_points.weight, gain=1.0) + init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): - nn.init.normal_(module.level_embed) + init.normal_(module.level_embed) if isinstance(module, MMGroundingDinoContrastiveEmbedding): - nn.init.constant_(module.bias, -math.log((1 - 0.01) / 0.01)) + init.constant_(module.bias, -math.log((1 - 0.01) / 0.01)) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, MMGroundingDinoDecoder): diff --git a/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py index 0168a3f0bec9..91648343e440 100644 --- a/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py @@ -17,6 +17,7 @@ import torch from torch import nn +from ... import initialization as init from ...configuration_utils import PreTrainedConfig from ...utils import logging from ...utils.backbone_utils import verify_backbone_config_arguments @@ -323,7 +324,7 @@ class MMGroundingDinoPreTrainedModel(GroundingDinoPreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, MMGroundingDinoContrastiveEmbedding): - nn.init.constant_(module.bias, -math.log((1 - 0.01) / 0.01)) + init.constant_(module.bias, -math.log((1 - 0.01) / 0.01)) class MMGroundingDinoConvEncoder(GroundingDinoConvEncoder): diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 58964f4ad234..6874dfecd1d3 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -29,6 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...
import initialization as init from ...activations import ACT2FN from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer @@ -549,19 +550,12 @@ class MobileBertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, NoNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + super()._init_weights(module) + if isinstance(module, NoNorm): + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, MobileBertLMPredictionHead): - module.bias.zero_() + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py index a75da78ae3fb..de8fefa142c3 100755 --- a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py @@ -132,17 +132,6 @@ class MobileNetV1PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = [] - @torch.no_grad() - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.BatchNorm2d): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class MobileNetV1Model(MobileNetV1PreTrainedModel): diff --git a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py index ae5979de21b2..5a107d3fcddc 100755 --- a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py @@ -258,17 +258,6 @@ class MobileNetV2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = [] - @torch.no_grad() - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.BatchNorm2d): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class MobileNetV2Model(MobileNetV2PreTrainedModel): diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index e2646d6c3e46..a41155c1bd71 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -611,12 +612,12 @@ class MobileViTPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index d87aee1d7e63..9eefb9eb77e9 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -578,12 +579,12 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.GroupNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 33d9411941e4..8069f2bec2ff 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -29,6 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -628,7 +629,7 @@ def _init_weights(self, module: nn.Module): cutoff_factor = 3 def init_weight(module: nn.Module, std: float): - nn.init.trunc_normal_( + init.trunc_normal_( module.weight, mean=0.0, std=std, @@ -638,7 +639,7 @@ def init_weight(module: nn.Module, std: float): if isinstance(module, nn.Linear): if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) stds = { "in": self.config.initializer_range, @@ -670,9 +671,9 @@ def init_weight(module: nn.Module, std: float): ): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _check_and_adjust_attn_implementation( self, attn_implementation: Optional[str], is_init_check: bool = False diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 3cbdf0d0a6c7..8783517edb8a 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig, layer_type_validation from ...modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -809,7 +810,7 @@ def _init_weights(self, module: nn.Module): cutoff_factor = 3 def init_weight(module: nn.Module, std: float): - nn.init.trunc_normal_( + init.trunc_normal_( module.weight, mean=0.0, std=std, @@ -819,7 +820,7 @@ def init_weight(module: nn.Module, std: float): if isinstance(module, nn.Linear): if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) stds = { "in": self.config.initializer_range, @@ -851,9 +852,9 @@ def init_weight(module: nn.Module, std: float): ): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _check_and_adjust_attn_implementation( self, attn_implementation: Optional[str], is_init_check: bool = False diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 75d46ef20df7..7564e375716b 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -401,7 +402,7 @@ def _init_weights(self, module: nn.Module): cutoff_factor = 3 def init_weight(module: nn.Module, std: float): - nn.init.trunc_normal_( + init.trunc_normal_( module.weight, mean=0.0, std=std, @@ -411,7 +412,7 @@ def init_weight(module: nn.Module, std: float): if isinstance(module, nn.Linear): if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) stds = { "in": self.config.initializer_range, @@ -437,9 +438,9 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, ModernBertDecoderForCausalLM): init_weight(module.decoder, stds["out"]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index b5a38f6f716c..24c63f499bb2 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin @@ -427,7 +428,7 @@ def _init_weights(self, module: nn.Module): cutoff_factor = 3 def init_weight(module: nn.Module, std: float): - nn.init.trunc_normal_( + init.trunc_normal_( module.weight, mean=0.0, std=std, @@ -437,7 +438,7 @@ def init_weight(module: nn.Module, std: float): if isinstance(module, nn.Linear): if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) stds = { "in": self.config.initializer_range, @@ -463,9 +464,9 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, ModernBertDecoderForCausalLM): init_weight(module.decoder, stds["out"]) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) + init.ones_(module.weight) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check): raise AttributeError("No need to inherit!") diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index edfb33cf756b..eb379f0bcf1e 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -23,6 +23,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationConfig, GenerationMixin @@ -827,20 +828,9 @@ class MoshiPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, MoshiFlexibleLinear): - module.weight.normal_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, MoshiRMSNorm): - module.weight.fill_(1.0) + super()._init_weights(module) + if isinstance(module, MoshiFlexibleLinear): + init.normal_(module.weight) class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index 975dd0eaff57..a585a63d72ad 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN, gelu from ...modeling_outputs import ( BaseModelOutput, @@ -48,19 +49,9 @@ class MPNetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, MPNetLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, MPNetLMHead): + init.zeros_(module.bias) class MPNetEmbeddings(nn.Module): diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 0d666447910b..a0dae399bf73 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -223,22 +223,6 @@ class MptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MptBlock"] - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, LayerNorm): - if module.bias is not None: - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class MptModel(MptPreTrainedModel): diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 9bd95879a05b..08bebf0c8766 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, 
MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -792,20 +793,9 @@ class MraPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, MraLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, MraLMPredictionHead): + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 9416b191e77f..cf4b1189ee95 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -564,55 +565,55 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, MT5LayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance( module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering), ): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - module.qa_outputs.bias.zero_() + init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.qa_outputs.bias) elif isinstance(module, MT5ForTokenClassification): if hasattr(module, "classifier"): - module.classifier.weight.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.zero_() + init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0) + init.zeros_(module.classifier.bias) elif isinstance(module, MT5ClassificationHead): - module.dense.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.dense.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.zero_() - module.out_proj.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.dense.bias) + init.normal_(module.out_proj.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.zero_() + init.zeros_(module.out_proj.bias) elif isinstance(module, MT5DenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * 
((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, MT5DenseGatedActDense): - module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, MT5Attention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 86988f9da002..35969de777ab 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -26,6 +26,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import ( @@ -420,16 +421,17 @@ class MusicgenPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) class MusicgenDecoder(MusicgenPreTrainedModel): diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 0e48bab3a768..bd8f29a86a5d 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,6 +26,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import ( @@ -391,16 +392,17 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody @@ -1312,9 +1314,9 @@ def _init_weights(self, module): # Projection layers still need to be initialized.
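# Note: the changes in this file follow the same mechanical pattern as the rest of the diff:
# in-place calls such as `module.weight.normal_(mean=0.0, std=std)` become
# `init.normal_(module.weight, mean=0.0, std=std)`. A plausible shape for such a helper is
# sketched below purely as an assumption, not the actual implementation:
#
#     def normal_(tensor, mean=0.0, std=1.0):
#         if getattr(tensor, "_is_hf_initialized", False):
#             return tensor                        # leave already-initialized weights untouched
#         return torch.nn.init.normal_(tensor, mean=mean, std=std)
#
# Only the projection layers created directly on this composite model are handled below; the
# wrapped sub-models initialize their own weights through their own `_init_weights`.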
std = self.decoder.config.initializer_factor if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def get_text_encoder(self): return self.text_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index c4d3350dc129..28fbf6960545 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -469,18 +469,6 @@ class MvpPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @property def dummy_inputs(self): pad_token = self.config.pad_token_id diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 88643bbc0133..af1d14ee2da0 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from torch import Size, Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin @@ -612,18 +613,10 @@ class NemotronPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, NemotronLayerNorm1P): - module.weight.fill_(1.0) - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, NemotronLayerNorm1P): + init.ones_(module.weight) + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b8bdd3efb14f..7f0de0ce824c 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -665,22 +665,6 @@ class NllbMoePreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights""" - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - class NllbMoeEncoder(NllbMoePreTrainedModel): _can_record_outputs = { diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index cbde955ecde2..92e613d9d2fa 100755 --- 
a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -413,21 +413,6 @@ class NystromformerPreTrainedModel(PreTrainedModel): base_model_prefix = "nystromformer" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class NystromformerModel(NystromformerPreTrainedModel): diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index fe899ef89e98..35804ec9c9be 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2CLS, ACT2FN from ...file_utils import ( ModelOutput, @@ -991,36 +992,36 @@ class OmDetTurboPreTrainedModel(PreTrainedModel): def _init_weights(self, module): def linear_init_(module_to_init): bound = 1 / math.sqrt(module_to_init.weight.shape[0]) - nn.init.uniform_(module_to_init.weight, -bound, bound) + init.uniform_(module_to_init.weight, -bound, bound) if hasattr(module_to_init, "bias") and module_to_init.bias is not None: - nn.init.uniform_(module_to_init.bias, -bound, bound) + init.uniform_(module_to_init.bias, -bound, bound) if isinstance(module, OmDetTurboEncoderLayer): linear_init_(module.fc1) linear_init_(module.fc2) elif isinstance(module, OmDetTurboDecoder): - nn.init.constant_(module.encoder_bbox_head.layers[-1].weight, 0.0) - nn.init.constant_(module.encoder_bbox_head.layers[-1].bias, 0.0) + init.constant_(module.encoder_bbox_head.layers[-1].weight, 0.0) + init.constant_(module.encoder_bbox_head.layers[-1].bias, 0.0) for mlp in module.decoder_bbox_head: - nn.init.constant_(mlp.layers[-1].weight, 0.0) - nn.init.constant_(mlp.layers[-1].bias, 0.0) + init.constant_(mlp.layers[-1].weight, 0.0) + init.constant_(mlp.layers[-1].bias, 0.0) linear_init_(module.encoder_vision_features[0]) - nn.init.xavier_uniform_(module.encoder_vision_features[0].weight) + init.xavier_uniform_(module.encoder_vision_features[0].weight) if module.learn_initial_query: - nn.init.xavier_uniform_(module.tgt_embed.weight) - nn.init.xavier_uniform_(module.query_position_head.layers[0].weight) - nn.init.xavier_uniform_(module.query_position_head.layers[1].weight) + init.xavier_uniform_(module.tgt_embed.weight) + init.xavier_uniform_(module.query_position_head.layers[0].weight) + init.xavier_uniform_(module.query_position_head.layers[1].weight) for layer in module.channel_projection_layers: - nn.init.xavier_uniform_(layer[0].weight) + init.xavier_uniform_(layer[0].weight) elif isinstance(module, OmDetTurboLanguageBackbone): - nn.init.normal_(module.text_projection, std=self.config.text_projection_in_dim**-0.5) + init.normal_(module.text_projection, std=self.config.text_projection_in_dim**-0.5) elif isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, 
std=self.config.init_std) + init.normal_(module.weight, mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, OmDetTurboDecoder): diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 0f4b16d072b1..86d5e8d670ce 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -24,6 +24,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -2774,13 +2775,13 @@ def _init_weights(self, module: nn.Module): if module.input_projections is not None: for input_projection in module.input_projections: if not isinstance(input_projection, nn.Sequential): - nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) - nn.init.constant_(input_projection.bias, 0) + init.xavier_uniform_(input_projection.weight, gain=xavier_std) + init.constant_(input_projection.bias, 0) elif isinstance(module, OneFormerTransformerDecoder): - nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) - nn.init.constant_(module.query_input_projection.bias, 0) + init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) + init.constant_(module.query_input_projection.bias, 0) elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( @@ -2790,56 +2791,57 @@ def _init_weights(self, module: nn.Module): ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, OneFormerPixelDecoder): - nn.init.normal_(module.level_embed, std=0) + init.normal_(module.level_embed, std=0) elif isinstance(module, (OneFormerTransformerDecoderLayer, OneFormerTransformerDecoderQueryTransformer)): for p in module.parameters(): if p.dim() > 1: - nn.init.xavier_uniform_(p, gain=xavier_std) + init.xavier_uniform_(p, gain=xavier_std) elif isinstance(module, OneFormerTextTransformer): proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5) attn_std = module.width**-0.5 fc_std = (2 * module.width) ** 
-0.5 for layer in module.layers: - nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std) - nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std) - nn.init.normal_(layer.mlp.fc1.weight, std=fc_std) - nn.init.normal_(layer.mlp.fc2.weight, std=proj_std) + init.normal_(layer.self_attn.in_proj_weight, std=attn_std) + init.normal_(layer.self_attn.out_proj.weight, std=proj_std) + init.normal_(layer.mlp.fc1.weight, std=fc_std) + init.normal_(layer.mlp.fc2.weight, std=proj_std) elif isinstance(module, OneFormerTextEncoder): - nn.init.normal_(module.token_embedding.weight, std=0.02) - nn.init.normal_(module.positional_embedding, std=0.01) + init.normal_(module.token_embedding.weight, std=0.02) + init.normal_(module.positional_embedding, std=0.01) if hasattr(module, "reference_points"): - nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) - nn.init.constant_(module.reference_points.bias, 0.0) + init.xavier_uniform_(module.reference_points.weight, gain=1.0) + init.constant_(module.reference_points.bias, 0.0) elif isinstance(module, OneFormerMLPPredictionHead): for submodule in module.modules(): if isinstance(submodule, nn.Linear): - nn.init.xavier_uniform_(submodule.weight, gain=xavier_std) - nn.init.constant_(submodule.bias, 0) + init.xavier_uniform_(submodule.weight, gain=xavier_std) + init.constant_(submodule.bias, 0) elif isinstance(module, nn.MultiheadAttention): - module.in_proj_weight.normal_(mean=0.0, std=std) - module.in_proj_bias.zero_() + init.normal_(module.in_proj_weight, mean=0.0, std=std) + init.zeros_(module.in_proj_bias) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, OneFormerLoss): - module.logit_scale.fill_(np.log(1 / self.config.contrastive_temperature)) + init.constant_(module.logit_scale, np.log(1 / self.config.contrastive_temperature)) @auto_docstring diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 18a12bce9dc8..4c15f2719be3 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -259,21 +259,6 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): config: OpenAIGPTConfig base_model_prefix = "transformer" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): -
module.bias.zero_() - module.weight.fill_(1.0) - @dataclass @auto_docstring( diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 2d88858a6c0d..025ea7d2cd80 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -301,24 +301,8 @@ class OPTPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - class OPTDecoder(OPTPreTrainedModel): """ diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index f10631a7071a..16cfd68e49bb 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -572,41 +573,41 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, Owlv2TextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, Owlv2VisionEmbeddings): - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, Owlv2Attention): in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, Owlv2MLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - 
nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, Owlv2Model): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor, ) - module.logit_scale.fill_(self.config.logit_scale_init_value) + init.constant_(module.logit_scale, self.config.logit_scale_init_value) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=factor) + init.normal_(module.weight, mean=0.0, std=factor) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTEncoder with OwlViT->Owlv2 diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 95cd4ccb6034..bd8b23ca38eb 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -559,41 +560,41 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, OwlViTTextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, OwlViTVisionEmbeddings): - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, OwlViTAttention): in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, OwlViTMLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, 
std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, OwlViTModel): - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor, ) - module.logit_scale.fill_(self.config.logit_scale_init_value) + init.constant_(module.logit_scale, self.config.logit_scale_init_value) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=factor) + init.normal_(module.weight, mean=0.0, std=factor) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class OwlViTEncoder(nn.Module): @@ -1198,6 +1199,8 @@ def __init__(self, config: OwlViTConfig): self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width) + self.post_init() + @staticmethod def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor: # Create grid coordinates using torch diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index dcbe454a9867..509906fa9e41 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -219,24 +219,12 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = False _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True - @torch.no_grad() - def _init_weights(self, module): - # important: this ported version of PaliGemmaisn't meant for training from scratch - only - # inference and fine-tuning - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 3c8698c7c9b0..ef8841a51b3f 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -27,6 +27,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput @@ -467,8 +468,8 @@ def _init_weights(self, module): if isinstance(module, ParakeetEncoderAttention): # Initialize positional bias parameters - module.bias_u.normal_(mean=0.0, std=std) - module.bias_v.normal_(mean=0.0, std=std) + init.normal_(module.bias_u, mean=0.0, std=std) + init.normal_(module.bias_v, mean=0.0, std=std) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index f792b19c9315..e1d552cbf43d 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput @@ -343,8 +344,8 @@ def _init_weights(self, module): if isinstance(module, ParakeetEncoderAttention): # Initialize positional bias parameters - module.bias_u.normal_(mean=0.0, std=std) - module.bias_v.normal_(mean=0.0, std=std) + init.normal_(module.bias_u, mean=0.0, std=std) + init.normal_(module.bias_v, mean=0.0, std=std) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 3402386596d2..410ef5abb7f4 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -25,6 +25,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput +from ... 
import initialization as init from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack @@ -691,17 +692,17 @@ def _init_weights(self, module): if isinstance(module, PatchTSMixerPositionalEncoding): # initialize positional encoding if self.config.positional_encoding_type == "random": - nn.init.normal_(module.position_enc, mean=0.0, std=0.1) + init.normal_(module.position_enc, mean=0.0, std=0.1) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, PatchTSMixerBatchNorm): - module.batchnorm.bias.zero_() - module.batchnorm.weight.fill_(1.0) + init.zeros_(module.batchnorm.bias) + init.ones_(module.batchnorm.weight) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.init_std) + init.normal_(module.weight, mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class PatchTSMixerPretrainHead(nn.Module): diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index fe99982803d9..5317476b1a42 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2CLS from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput @@ -567,20 +568,20 @@ def _init_weights(self, module: nn.Module): ) // self.config.patch_stride + 1 # initialize cls_token if self.config.use_cls_token: - nn.init.normal_(module.cls_token, std=0.02) + init.normal_(module.cls_token, std=0.02) num_patches += 1 # initialize positional encoding - module.position_enc = module._init_pe(self.config, num_patches) + init.copy_(module.position_enc, module._init_pe(self.config, num_patches)) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, PatchTSTBatchNorm): - module.batchnorm.bias.zero_() - module.batchnorm.weight.fill_(1.0) + init.zeros_(module.batchnorm.bias) + init.ones_(module.batchnorm.weight) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.init_std) + init.normal_(module.weight, mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (PatchTSTEncoder)): diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index e1009cc96e5a..34205b2e5972 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -73,9 +74,9 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) + super().__init__(num_positions, embedding_dim, _freeze=True) - def _init_weight(self): + def create_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] @@ -88,7 +89,7 @@ def _init_weight(self): sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - self.weight = nn.Parameter(out, requires_grad=False) + return out @torch.no_grad() def forward( @@ -435,25 +436,13 @@ class PegasusPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = True @torch.no_grad() def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, PegasusSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, PegasusSinusoidalPositionalEmbedding): + init.copy_(module.weight, module.create_weight()) class PegasusEncoder(PegasusPreTrainedModel): @@ -512,7 +501,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.config.d_model, self.padding_idx, ) - self.embed_positions._init_weight() + init.copy_(self.embed_positions.weight, self.embed_positions.create_weight()) self.embed_positions.to(self.device) def get_position_embeddings(self) -> nn.Embedding: @@ -684,7 +673,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.config.d_model, self.padding_idx, ) - self.embed_positions._init_weight() + init.copy_(self.embed_positions.weight, self.embed_positions.create_weight()) self.embed_positions.to(self.device) def get_position_embeddings(self) -> nn.Embedding: diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 0e9b8bc1e255..53e44ef414d0 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -744,22 +744,8 @@ class PegasusXPreTrainedModel(PreTrainedModel): # Flaky logits _supports_sdpa = False _supports_flex_attn = True - _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - class 
PegasusXEncoder(PegasusXPreTrainedModel): """ diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index 4ddad1c5b2c6..1b22c7ca1802 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -535,23 +536,24 @@ class PerceiverPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif hasattr(module, "latents"): - module.latents.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.latents, mean=0.0, std=self.config.initializer_range) elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding): - module.position_embeddings.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.ParameterDict): for modality in module: - module[modality].normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module[modality], mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8bb936c41461..4b09a2dd75bf 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -423,27 +423,11 @@ class PersimmonPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True _supports_sdpa = True _supports_flash_attn = True _supports_attention_backend = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @auto_docstring class PersimmonModel(PersimmonPreTrainedModel): diff --git
a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 31ef21fbda1e..eab15068d252 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -19,7 +19,6 @@ # limitations under the License. import math -import warnings from collections.abc import Callable from typing import Optional, Union @@ -29,6 +28,7 @@ from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -209,69 +209,7 @@ def forward( return BaseModelOutput(last_hidden_state=hidden_states) -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. 
- - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): +def variance_scaling_(tensor, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in @@ -280,18 +218,15 @@ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 - variance = scale / denom + variance = 1.0 / denom if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) + init.normal_(tensor, std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) + init.uniform_(tensor, -bound, bound) else: raise ValueError(f"invalid distribution {distribution}") @@ -331,34 +266,34 @@ def _init_weights(self, module): if isinstance(self.config, Phi4MultimodalVisionConfig) else self.config.hidden_size ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, Phi4MultimodalVisionAttention): - nn.init.normal_(module.q_proj.weight) - nn.init.normal_(module.k_proj.weight) - nn.init.normal_(module.v_proj.weight) - nn.init.normal_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) + init.normal_(module.q_proj.weight) + init.normal_(module.k_proj.weight) + init.normal_(module.v_proj.weight) + init.normal_(module.out_proj.weight) + init.zeros_(module.q_proj.bias) + init.zeros_(module.k_proj.bias) + init.zeros_(module.v_proj.bias) + init.zeros_(module.out_proj.bias) elif isinstance(module, Phi4MultimodalVisionMLP): - nn.init.normal_(module.fc1.weight) - nn.init.normal_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) + init.normal_(module.fc1.weight) + init.normal_(module.fc2.weight) + init.normal_(module.fc1.bias, std=1e-6) + init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): - nn.init.normal_(module.probe) - nn.init.normal_(module.attention.in_proj_weight) - nn.init.zeros_(module.attention.in_proj_bias) + init.normal_(module.probe) + init.normal_(module.attention.in_proj_weight) + init.zeros_(module.attention.in_proj_bias) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) class Phi4MultimodalVisionEmbeddings(nn.Module): @@ -944,8 +879,8 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): def 
_init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): - module.b1.zero_() - module.b2.zero_() + init.zeros_(module.b1) + init.zeros_(module.b2) def unfold_tensor(tensor, max_seq_len): @@ -1503,8 +1438,8 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalImageEmbedding): - module.global_img_feature_extensor.zero_() - module.sub_img_feature_extensor.zero_() + init.zeros_(module.global_img_feature_extensor) + init.zeros_(module.sub_img_feature_extensor) class Phi4MultimodalRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 62c7fb50748f..58cb33d62320 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig @@ -555,34 +556,34 @@ def _init_weights(self, module): if isinstance(self.config, Phi4MultimodalVisionConfig) else self.config.hidden_size ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, Phi4MultimodalVisionAttention): - nn.init.normal_(module.q_proj.weight) - nn.init.normal_(module.k_proj.weight) - nn.init.normal_(module.v_proj.weight) - nn.init.normal_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) + init.normal_(module.q_proj.weight) + init.normal_(module.k_proj.weight) + init.normal_(module.v_proj.weight) + init.normal_(module.out_proj.weight) + init.zeros_(module.q_proj.bias) + init.zeros_(module.k_proj.bias) + init.zeros_(module.v_proj.bias) + init.zeros_(module.out_proj.bias) elif isinstance(module, Phi4MultimodalVisionMLP): - nn.init.normal_(module.fc1.weight) - nn.init.normal_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) + init.normal_(module.fc1.weight) + init.normal_(module.fc2.weight) + init.normal_(module.fc1.bias, std=1e-6) + init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): - nn.init.normal_(module.probe) - nn.init.normal_(module.attention.in_proj_weight) - nn.init.zeros_(module.attention.in_proj_bias) + init.normal_(module.probe) + init.normal_(module.attention.in_proj_weight) + init.zeros_(module.attention.in_proj_bias) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) class Phi4MultimodalVisionEmbeddings(SiglipVisionEmbeddings): @@ -1124,8 +1125,8 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): - 
module.b1.zero_() - module.b2.zero_() + init.zeros_(module.b1) + init.zeros_(module.b2) class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): @@ -1447,8 +1448,8 @@ class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Phi4MultimodalImageEmbedding): - module.global_img_feature_extensor.zero_() - module.sub_img_feature_extensor.zero_() + init.zeros_(module.global_img_feature_extensor) + init.zeros_(module.sub_img_feature_extensor) class Phi4MultimodalModel(Phi3Model): diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 50479af0dac8..12e41214094d 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -26,6 +26,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -628,10 +629,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, PhimoeExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, PhimoeTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index f47e9f005e02..8a58a353617f 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -355,7 +356,7 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, Pix2StructLayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance(module, Pix2StructTextDenseGatedActDense): hidden_size = ( self.config.text_config.hidden_size @@ -364,15 +365,15 @@ def _init_weights(self, module): ) d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff - module.wi_0.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, Pix2StructTextAttention): hidden_size = ( self.config.text_config.hidden_size @@ -388,12 +389,12 @@ def _init_weights(self, module): else self.config.num_heads ) - module.query.weight.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) - module.key.weight.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) - module.value.weight.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) - module.output.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.query.weight, mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) + init.normal_(module.key.weight, mean=0.0, std=factor * (hidden_size**-0.5)) + init.normal_(module.value.weight, mean=0.0, std=factor * (hidden_size**-0.5)) + init.normal_(module.output.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5)) elif isinstance(module, nn.Embedding): hidden_size = ( self.config.text_config.hidden_size @@ -401,9 +402,10 @@ def _init_weights(self, module): else self.config.hidden_size ) - module.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5)) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, Pix2StructTextModel): hidden_size = ( self.config.text_config.hidden_size @@ -411,24 +413,19 @@ def _init_weights(self, module): else self.config.hidden_size ) -
module.lm_head.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5)) elif isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, Pix2StructLayerNorm): if module.weight is not None: - module.weight.fill_(1.0) + init.ones_(module.weight) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct def _shift_right(self, input_ids): diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index f9a408193387..3e6d468a8184 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -441,16 +441,6 @@ class PixtralPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _no_split_modules = ["PixtralAttentionLayer"] - @torch.no_grad() - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, PixtralRMSNorm): - module.weight.fill_(1.0) - def generate_block_attention_mask(patch_embeds_list, tensor): dtype = tensor.dtype diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 028c22e180f8..e1b9f375c542 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -841,7 +841,7 @@ def __init__(self, config: PLBartConfig): self.encoder = PLBartEncoder(config) self.decoder = PLBartDecoder(config) - self.init_weights() + self.post_init() def get_input_embeddings(self): return self.shared @@ -970,7 +970,7 @@ def __init__(self, config: PLBartConfig): self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) - self.init_weights() + self.post_init() def get_encoder(self): return self.model.get_encoder() diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index e67705ef697b..796f63d54632 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -82,7 +82,7 @@ def __init__(self, config: PLBartConfig): self.encoder = PLBartEncoder(config) self.decoder = PLBartDecoder(config) - self.init_weights() + self.post_init() def get_input_embeddings(self):
return self.shared @@ -211,7 +211,7 @@ def __init__(self, config: PLBartConfig): self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) - self.init_weights() + self.post_init() def get_encoder(self): return self.model.get_encoder() diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 0e7dc6fe24f0..21f01878ef87 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention from ...modeling_utils import PreTrainedModel @@ -248,17 +249,11 @@ class PoolFormerPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.GroupNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, PoolFormerLayer): + super()._init_weights(module) + if isinstance(module, PoolFormerLayer): if hasattr(module, "layer_scale_1"): - module.layer_scale_1.fill_(self.config.layer_scale_init_value) - module.layer_scale_2.fill_(self.config.layer_scale_init_value) + init.constant_(module.layer_scale_1, self.config.layer_scale_init_value) + init.constant_(module.layer_scale_2, self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 086fc951dd04..493b9d4776ec 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -24,6 +24,7 @@ from transformers.generation import GenerationConfig +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -549,40 +550,40 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, Pop2PianoLayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance(module, Pop2PianoConcatEmbeddingToMel): - module.embedding.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.embedding.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, Pop2PianoForConditionalGeneration): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head"): - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, Pop2PianoDenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, Pop2PianoDenseGatedActDense): - module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, Pop2PianoAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor 
* ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 5674114cebc4..619769b9d2b7 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -332,17 +332,6 @@ class ProphetNetPreTrainedModel(PreTrainedModel): base_model_prefix = "prophetnet" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.init_std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.init_std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 2a296a5e09e8..caee0f0dfb78 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel @@ -426,30 +427,16 @@ def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - nn.init.trunc_normal_(module.weight, mean=0.0, std=std) + init.trunc_normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, PvtPatchEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings, - mean=0.0, - std=std, - ) - ) + init.trunc_normal_(module.position_embeddings, mean=0.0, std=std) if module.cls_token is not None: - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token, - mean=0.0, - std=std, - ) - ) + init.trunc_normal_(module.cls_token, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/src/transformers/models/pvt_v2/modeling_pvt_v2.py index 010e91b9d479..9973c6dfcad6 100644 --- a/src/transformers/models/pvt_v2/modeling_pvt_v2.py +++ b/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, ImageClassifierOutput @@ -372,20 +373,18 @@ class PvtV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, nn.Linear): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_(nn.init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv2d): fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels fan_out //= module.groups - module.weight.normal_(0, math.sqrt(2.0 / fan_out)) + init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 94114633ef47..7cad197e7e72 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -257,28 +257,6 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - @torch.no_grad() - def _init_weights(self, module): - # important: this ported version of Qwen2Audio isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.audio_config.initializer_range - ) - - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index bf642609c9fe..8bda140d3cdb 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -31,6 +31,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -443,10 +444,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, Qwen2MoeExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, Qwen2MoeTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index e709a7d84709..477694d5fb2b 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -376,10 +377,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, Qwen3MoeExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, Qwen3MoeTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) class Qwen3MoeRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 9096064b1cc2..362c8fab007f 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -992,16 +993,16 @@ class Qwen3NextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Qwen3NextGatedDeltaNet): - module.dt_bias.fill_(1.0) - module.A_log.uniform_(0, 16).log_() + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif isinstance(module, Qwen3NextRMSNorm): - module.weight.zero_() - if isinstance(module, Qwen3NextExperts): - module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) - module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) - if isinstance(module, Qwen3NextSparseMoeBlock): - module.gate.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.zeros_(module.weight) + elif isinstance(module, Qwen3NextExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Qwen3NextSparseMoeBlock): + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) class Qwen3NextModel(Qwen3NextPreTrainedModel): diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 8630da2a06b0..7deedb9c868b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...masking_utils import create_causal_mask @@ -740,16 +741,16 @@ class Qwen3NextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Qwen3NextGatedDeltaNet): - module.dt_bias.fill_(1.0) - module.A_log.uniform_(0, 16).log_() + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif isinstance(module, Qwen3NextRMSNorm): - module.weight.zero_() - if isinstance(module, Qwen3NextExperts): - module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) - module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) - if isinstance(module, Qwen3NextSparseMoeBlock): - module.gate.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.zeros_(module.weight) + elif isinstance(module, Qwen3NextExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Qwen3NextSparseMoeBlock): + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) class Qwen3NextModel(Qwen3NextPreTrainedModel): diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index aa9d41366be0..1be0487cea98 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -31,6 +31,7 @@ from torch.nn import Parameter from torch.nn import functional as F +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -81,9 +82,9 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock): - module.experts.gate_up_proj.normal_(mean=0.0, std=std) - module.experts.down_proj.normal_(mean=0.0, std=std) - module.router.weight.normal_(mean=0.0, std=std) + init.normal_(module.experts.gate_up_proj, mean=0.0, std=std) + init.normal_(module.experts.down_proj, mean=0.0, std=std) + init.normal_(module.router.weight, mean=0.0, std=std) def _get_feat_extract_output_lengths(input_lengths): @@ -1607,10 +1608,10 @@ def _init_weights(self, module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, Qwen3OmniMoeThinkerTextExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) elif isinstance(module, Qwen3OmniMoeThinkerTextTopKRouter): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) @use_kernel_forward_from_hub("RMSNorm") diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index e478dfccb50a..ea6ac6860133 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import functional as F +from ... import initialization as init from ...activations import ACT2FN from ...audio_utils import AudioInput from ...cache_utils import Cache, DynamicCache @@ -796,9 +797,9 @@ def _init_weights(self, module): PreTrainedModel._init_weights(self, module) std = self.config.initializer_range if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock): - module.experts.gate_up_proj.normal_(mean=0.0, std=std) - module.experts.down_proj.normal_(mean=0.0, std=std) - module.router.weight.normal_(mean=0.0, std=std) + init.normal_(module.experts.gate_up_proj, mean=0.0, std=std) + init.normal_(module.experts.down_proj, mean=0.0, std=std) + init.normal_(module.router.weight, mean=0.0, std=std) class Qwen3OmniMoePreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration): diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 7765344c2a1b..efd0e8d24926 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -27,6 +27,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -413,8 +414,8 @@ def _init_weights(self, module): else: std = getattr(self.config.get_text_config(), "initializer_range", 0.02) if isinstance(module, Qwen3VLMoeTextExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) class Qwen3VLMoeVisionMLP(nn.Module): diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index 459d45159fdc..006fa186fe44 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig @@ -367,8 +368,8 @@ def _init_weights(self, module): else: std = getattr(self.config.get_text_config(), "initializer_range", 0.02) if isinstance(module, Qwen3VLMoeTextExperts): - module.gate_up_proj.normal_(mean=0.0, std=std) - module.down_proj.normal_(mean=0.0, std=std) + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) class Qwen3VLMoeVisionModel(Qwen3VLVisionModel): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index a1d58064207e..dc1a3d4951e2 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -557,49 +558,50 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) if isinstance(module, nn.Conv1d): - torch.nn.init.normal_(module.weight, mean=0.0, std=std) - torch.nn.init.zeros_(module.bias) + init.normal_(module.weight, mean=0.0, std=std) + init.zeros_(module.bias) elif isinstance(module, RecurrentGemmaSdpaAttention): - torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) - torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) - torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + init.normal_(module.q_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + init.normal_(module.k_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + init.normal_(module.v_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) std = math.sqrt(self.config.final_w_init_variance_scale / self.config.hidden_size) - torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=std) + init.normal_(module.o_proj.weight, mean=0.0, std=std) elif isinstance(module, RecurrentGemmaRecurrentBlock): - torch.nn.init.zeros_(module.linear_x.bias) - torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + init.zeros_(module.linear_x.bias) + init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) - torch.nn.init.zeros_(module.linear_y.bias) - torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + init.zeros_(module.linear_y.bias) + init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) std = math.sqrt(self.config.final_w_init_variance_scale / self.config.lru_width) - torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std) - torch.nn.init.zeros_(module.linear_out.bias) + init.normal_(module.linear_out.weight, mean=0.0, std=std) + init.zeros_(module.linear_out.bias) elif isinstance(module, RecurrentGemmaRglru): std = math.sqrt( self.config.w_init_variance_scale / (self.config.lru_width // self.config.num_attention_heads) ) - torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std) - torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std) - torch.nn.init.zeros_(module.input_gate_bias) - torch.nn.init.zeros_(module.recurrent_gate_bias) - - module.recurrent_param.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8) - module.recurrent_param.log_().mul_(0.5) - module.recurrent_param.neg_().exp_().sub_(1.0).log_() + init.normal_(module.input_gate_weight, mean=0.0, std=std) + init.normal_(module.recurrent_gate_weight, mean=0.0, std=std) + init.zeros_(module.input_gate_bias) + init.zeros_(module.recurrent_gate_bias) + + recurrent_param = torch.empty_like(module.recurrent_param).uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8) + recurrent_param.log_().mul_(0.5).neg_().exp_().sub_(1.0).log_() + init.copy_(module.recurrent_param, recurrent_param) elif isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: - torch.nn.init.zeros_(module.bias) + 
init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif isinstance(module, RecurrentGemmaRMSNorm): - module.weight.zero_() + init.zeros_(module.weight) def _setup_cache(self, config, batch, device, dtype): layers = getattr(self, "model", self).layers diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 24a598251956..5eccfcd7b870 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -29,6 +29,7 @@ from torch.autograd.function import Function from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput @@ -1846,20 +1847,10 @@ def dummy_inputs(self): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) if isinstance(module, AxialPositionEmbeddings): for weight in module.weights: - nn.init.normal_(weight, std=self.config.axial_norm_std) - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.normal_(weight, std=self.config.axial_norm_std) @dataclass diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index fd6416f46ec6..a5ce03ac783f 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -20,6 +20,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutputWithNoAttention, @@ -266,17 +267,17 @@ class RegNetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
elif isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + init.constant_(module.weight, 1) + init.constant_(module.bias, 0) @auto_docstring diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 13651c32f5da..d0145af76cab 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -488,21 +488,6 @@ class RemBertPreTrainedModel(PreTrainedModel): base_model_prefix = "rembert" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index dba8200edba1..3f4fec571cd7 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -20,6 +20,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BackboneOutput, @@ -253,17 +254,17 @@ class ResNetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. elif isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + init.constant_(module.weight, 1) + init.constant_(module.bias, 0) @auto_docstring diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index f5b315f38f26..091a9649885b 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -27,6 +27,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN, gelu from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -497,19 +498,9 @@ class RobertaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, RobertaLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, RobertaLMHead): + init.zeros_(module.bias) class RobertaEncoder(nn.Module): diff --git a/src/transformers/models/roberta/modular_roberta.py b/src/transformers/models/roberta/modular_roberta.py index 54049e1189da..f93e12245285 100644 --- a/src/transformers/models/roberta/modular_roberta.py +++ b/src/transformers/models/roberta/modular_roberta.py @@ -21,6 +21,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import gelu from ...generation import GenerationMixin from ...modeling_outputs import ( @@ -168,19 +169,9 @@ class RobertaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, RobertaLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, RobertaLMHead): + init.zeros_(module.bias) class RobertaModel(BertModel): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index bdc93a1cc73c..787e2556d14b 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN, gelu from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -557,19 +558,9 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, RobertaPreLayerNormLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, RobertaPreLayerNormLMHead): + init.zeros_(module.bias) @auto_docstring( diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 6800fa2fbfa5..6772ab8b1bc7 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -617,19 +618,9 @@ class RoCBertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, RoCBertLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, RoCBertLMPredictionHead): + init.zeros_(module.bias) @auto_docstring( diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 0aa4cb11bf51..1d21119d4bb6 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN, get_activation from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -50,9 +51,9 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) + super().__init__(num_positions, embedding_dim, _freeze=True) - def _init_weight(self): + def create_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] @@ -65,7 +66,7 @@ def _init_weight(self): sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - self.weight = nn.Parameter(out, requires_grad=False) + return out @torch.no_grad() def forward( @@ -637,22 +638,11 @@ class RoFormerPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, RoFormerSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + super()._init_weights(module) + if isinstance(module, RoFormerSinusoidalPositionalEmbedding): + init.copy_(module.weight, module.create_weight()) elif isinstance(module, RoFormerLMPredictionHead): - module.bias.zero_() + init.zeros_(module.bias) @auto_docstring( diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 00b661be3acc..d91c8fd92d90 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2CLS, ACT2FN from ...image_transforms import center_to_corners_format, corners_to_center_format from ...integrations import use_kernel_forward_from_hub @@ -1017,16 +1018,16 @@ def _init_weights(self, module): for layer in module.model.decoder.class_embed: prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(layer.weight) - nn.init.constant_(layer.bias, bias) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) if module.model.decoder.bbox_embed is not None: for layer in module.model.decoder.bbox_embed: - nn.init.constant_(layer.layers[-1].weight, 0) - nn.init.constant_(layer.layers[-1].bias, 0) + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) elif isinstance(module, RTDetrMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -1040,34 +1041,33 @@ def _init_weights(self, module): for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - if not getattr(module.sampling_offsets.bias, "_is_hf_initialized", False): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + 
init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, RTDetrModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(module.enc_score_head.weight) - nn.init.constant_(module.enc_score_head.bias, bias) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) if hasattr(module, "weight_embedding") and self.config.learn_initial_query: - nn.init.xavier_uniform_(module.weight_embedding.weight) + init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: - nn.init.xavier_uniform_(module.denoising_class_embed.weight) + init.xavier_uniform_(module.denoising_class_embed.weight) class RTDetrEncoder(nn.Module): diff --git a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py index 12f9d90d8eb5..69d99dce3d86 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention from ...modeling_utils import PreTrainedModel @@ -307,17 +308,17 @@ class RTDetrResNetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. 
elif isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(module.bias, -bound, bound) + init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + init.constant_(module.weight, 1) + init.constant_(module.bias, 0) @auto_docstring( diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index a763e925a6e1..1a5a510ce11c 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2CLS, ACT2FN from ...image_transforms import center_to_corners_format, corners_to_center_format from ...modeling_outputs import BaseModelOutput @@ -464,16 +465,16 @@ def _init_weights(self, module): for layer in module.model.decoder.class_embed: prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(layer.weight) - nn.init.constant_(layer.bias, bias) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) if module.model.decoder.bbox_embed is not None: for layer in module.model.decoder.bbox_embed: - nn.init.constant_(layer.layers[-1].weight, 0) - nn.init.constant_(layer.layers[-1].bias, 0) + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) elif isinstance(module, RTDetrV2MultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight, 0.0) + init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -487,34 +488,33 @@ def _init_weights(self, module): for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 - if not getattr(module.sampling_offsets.bias, "_is_hf_initialized", False): - module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight, 0.0) - nn.init.constant_(module.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight) - nn.init.constant_(module.value_proj.bias, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight) - nn.init.constant_(module.output_proj.bias, 0.0) + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, RTDetrV2Model): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) - nn.init.xavier_uniform_(module.enc_score_head.weight) - nn.init.constant_(module.enc_score_head.bias, bias) + 
init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) if hasattr(module, "weight_embedding") and self.config.learn_initial_query: - nn.init.xavier_uniform_(module.weight_embedding.weight) + init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: - nn.init.xavier_uniform_(module.denoising_class_embed.weight) + init.xavier_uniform_(module.denoising_class_embed.weight) @dataclass diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 2f0a434720a2..ca72ec5f5565 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -399,12 +400,12 @@ def _init_weights(self, module: nn.Module): * 0.5 ) - module.time_decay.copy_(decay_speed) - module.time_first.copy_(torch.ones_like(module.time_first * math.log(0.3) + zigzag)) + init.copy_(module.time_decay, decay_speed) + init.copy_(module.time_first, torch.ones_like(module.time_first * math.log(0.3) + zigzag)) - module.time_mix_key.copy_(torch.pow(time_weight, ratio_1_to_almost0)) - module.time_mix_value.copy_(torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) - module.time_mix_receptance.copy_(torch.pow(time_weight, 0.5 * ratio_1_to_almost0)) + init.copy_(module.time_mix_key, torch.pow(time_weight, ratio_1_to_almost0)) + init.copy_(module.time_mix_value, torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + init.copy_(module.time_mix_receptance, torch.pow(time_weight, 0.5 * ratio_1_to_almost0)) elif isinstance(module, RwkvFeedForward): layer_id = module.layer_id num_hidden_layers = module.config.num_hidden_layers @@ -419,28 +420,28 @@ def _init_weights(self, module: nn.Module): ) time_weight = time_weight[None, None, :] - module.time_mix_key.copy_(torch.pow(time_weight, ratio_1_to_almost0)) - module.time_mix_receptance.copy_(torch.pow(time_weight, ratio_1_to_almost0)) + init.copy_(module.time_mix_key, torch.pow(time_weight, ratio_1_to_almost0)) + init.copy_(module.time_mix_receptance, torch.pow(time_weight, ratio_1_to_almost0)) elif isinstance(module, nn.Linear): shape = module.weight.shape gain = 1.0 scale = 1.0 # extra scale for gain if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) if shape[0] == self.config.vocab_size and shape[1] == self.config.hidden_size: # final projection? 
scale = 0.5 gain *= scale - nn.init.orthogonal_(module.weight, gain=gain) + init.orthogonal_(module.weight, gain=gain) elif isinstance(module, nn.Embedding): shape = module.weight.shape gain = 1e-4 * math.sqrt(max(shape[0], shape[1])) - nn.init.orthogonal_(module.weight, gain=gain) + init.orthogonal_(module.weight, gain=gain) elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() + init.ones_(module.weight) + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 138a59716b2f..00e6bd1ab23c 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -26,6 +26,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -1008,11 +1009,11 @@ def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, SamVisionAttention): if module.use_rel_pos: - module.rel_pos_h.zero_() - module.rel_pos_w.zero_() + init.zeros_(module.rel_pos_h) + init.zeros_(module.rel_pos_w) elif isinstance(module, SamVisionEncoder): if self.config.use_abs_pos: - module.pos_embed.zero_() + init.zeros_(module.pos_embed) class SamVisionEncoder(SamPreTrainedModel): diff --git a/src/transformers/models/sam2/modeling_sam2.py b/src/transformers/models/sam2/modeling_sam2.py index 6c0f6b77cc0e..8e62d6d99e76 100644 --- a/src/transformers/models/sam2/modeling_sam2.py +++ b/src/transformers/models/sam2/modeling_sam2.py @@ -32,6 +32,7 @@ from transformers.utils.generic import OutputRecorder +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -558,26 +559,15 @@ class Sam2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + super()._init_weights(module) if isinstance(module, Sam2HieraDetModel): if module.pos_embed is not None: - module.pos_embed.zero_() + init.zeros_(module.pos_embed) if module.pos_embed_window is not None: - module.pos_embed_window.zero_() + init.zeros_(module.pos_embed_window) if isinstance(module, Sam2Model): if module.no_memory_embedding is not None: - module.no_memory_embedding.zero_() + init.zeros_(module.no_memory_embedding) class Sam2HieraDetModel(Sam2PreTrainedModel): diff --git a/src/transformers/models/sam2/modular_sam2.py b/src/transformers/models/sam2/modular_sam2.py index 64bc3bfc30ca..d2052bf3b7eb 100644 --- a/src/transformers/models/sam2/modular_sam2.py +++ b/src/transformers/models/sam2/modular_sam2.py @@ -23,6 +23,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
import initialization as init from ...activations import ACT2FN from ...image_processing_utils import BatchFeature, get_size_dict from ...image_processing_utils_fast import BaseImageProcessorFast @@ -679,26 +680,15 @@ class Sam2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() + super()._init_weights(module) if isinstance(module, Sam2HieraDetModel): if module.pos_embed is not None: - module.pos_embed.zero_() + init.zeros_(module.pos_embed) if module.pos_embed_window is not None: - module.pos_embed_window.zero_() + init.zeros_(module.pos_embed_window) if isinstance(module, Sam2Model): if module.no_memory_embedding is not None: - module.no_memory_embedding.zero_() + init.zeros_(module.no_memory_embedding) class Sam2HieraDetModel(Sam2PreTrainedModel): diff --git a/src/transformers/models/sam2_video/modeling_sam2_video.py b/src/transformers/models/sam2_video/modeling_sam2_video.py index 7437130aaee8..3158b55bc319 100644 --- a/src/transformers/models/sam2_video/modeling_sam2_video.py +++ b/src/transformers/models/sam2_video/modeling_sam2_video.py @@ -32,6 +32,7 @@ from torch import Tensor from tqdm import tqdm +from ... import initialization as init from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -668,30 +669,19 @@ class Sam2VideoPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, Sam2VideoLayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, Sam2VideoModel): + super()._init_weights(module) + if isinstance(module, Sam2VideoModel): if module.no_memory_positional_encoding is not None: - module.no_memory_positional_encoding.zero_() + init.zeros_(module.no_memory_positional_encoding) if module.memory_temporal_positional_encoding is not None: - module.memory_temporal_positional_encoding.zero_() + init.zeros_(module.memory_temporal_positional_encoding) if module.no_object_pointer is not None: - module.no_object_pointer.zero_() + init.zeros_(module.no_object_pointer) if module.occlusion_spatial_embedding_parameter is not None: - module.occlusion_spatial_embedding_parameter.zero_() + init.zeros_(module.occlusion_spatial_embedding_parameter) if isinstance(module, Sam2VideoMemoryFuserCXBlock): if module.scale is not None: - module.scale.zero_() + init.zeros_(module.scale) class Sam2VideoVisionRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/sam2_video/modular_sam2_video.py b/src/transformers/models/sam2_video/modular_sam2_video.py index 97550a96d19b..e701a7111bb4 100644 --- 
a/src/transformers/models/sam2_video/modular_sam2_video.py +++ b/src/transformers/models/sam2_video/modular_sam2_video.py @@ -27,6 +27,7 @@ from torch import Tensor from tqdm import tqdm +from ... import initialization as init from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -989,30 +990,19 @@ class Sam2VideoPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, Sam2VideoLayerNorm)): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, Sam2VideoModel): + super()._init_weights(module) + if isinstance(module, Sam2VideoModel): if module.no_memory_positional_encoding is not None: - module.no_memory_positional_encoding.zero_() + init.zeros_(module.no_memory_positional_encoding) if module.memory_temporal_positional_encoding is not None: - module.memory_temporal_positional_encoding.zero_() + init.zeros_(module.memory_temporal_positional_encoding) if module.no_object_pointer is not None: - module.no_object_pointer.zero_() + init.zeros_(module.no_object_pointer) if module.occlusion_spatial_embedding_parameter is not None: - module.occlusion_spatial_embedding_parameter.zero_() + init.zeros_(module.occlusion_spatial_embedding_parameter) if isinstance(module, Sam2VideoMemoryFuserCXBlock): if module.scale is not None: - module.scale.zero_() + init.zeros_(module.scale) class Sam2VideoVisionRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 7bca21b31892..db01c871bf11 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -32,6 +32,7 @@ from transformers.modeling_outputs import ModelOutput from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -427,11 +428,11 @@ def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, SamHQVisionAttention): if module.use_rel_pos: - module.rel_pos_h.zero_() - module.rel_pos_w.zero_() + init.zeros_(module.rel_pos_h) + init.zeros_(module.rel_pos_w) elif isinstance(module, SamHQVisionEncoder): if self.config.use_abs_pos: - module.pos_embed.zero_() + init.zeros_(module.pos_embed) class SamHQPatchEmbeddings(nn.Module): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 7efe8936d837..13ebc98ef56d 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -1347,37 +1348,38 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, SeamlessM4TConformerSelfAttention): if hasattr(module, "pos_bias_u"): - nn.init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, SeamlessM4TConformerPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, SeamlessM4TConformerFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride @@ -2400,22 +2402,6 @@ def forward( return hidden_states, lengths - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): @@ -2973,12 +2959,12 @@ def __init__(self, config: SeamlessM4TConfig): self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - self.t2u_model =
SeamlessM4TTextToUnitForConditionalGeneration(config) self.vocoder = SeamlessM4TCodeHifiGan(config) + # Initialize weights and apply final processing + self.post_init() + def get_encoder(self): return self.text_encoder @@ -3617,9 +3603,6 @@ def __init__(self, config, current_modality="text"): self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - self.current_modality = current_modality if current_modality == "speech": self.main_input_name = "input_features" @@ -3628,6 +3611,9 @@ def __init__(self, config, current_modality="text"): self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) self.vocoder = SeamlessM4TCodeHifiGan(config) + # Initialize weights and apply final processing + self.post_init() + def set_modality(self, modality="text"): if modality == "text": self.main_input_name = "input_ids" diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 16aba775566c..9096c40e9f0c 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -1263,33 +1264,34 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, SeamlessM4Tv2ConformerSelfAttention): if hasattr(module, "pos_bias_u"): - nn.init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, SeamlessM4Tv2ConformerFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, SeamlessM4Tv2TextToUnitDecoder): - module.pos_emb_alpha_char.fill_(1) - module.pos_emb_alpha.fill_(1) + init.ones_(module.pos_emb_alpha_char) + init.ones_(module.pos_emb_alpha) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / 
(module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TPreTrainedModel._compute_sub_sample_lengths_from_attention_mask def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): @@ -2602,22 +2604,6 @@ def forward( return hidden_states, lengths - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm def apply_weight_norm(self): weight_norm = nn.utils.weight_norm @@ -3186,12 +3172,12 @@ def __init__(self, config: SeamlessM4Tv2Config): self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - self.t2u_model = SeamlessM4Tv2TextToUnitForConditionalGeneration(config) self.vocoder = SeamlessM4Tv2CodeHifiGan(config) + # Initialize weights and apply final processing + self.post_init() + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.get_encoder def get_encoder(self): return self.text_encoder @@ -3903,9 +3889,6 @@ def __init__(self, config, current_modality="text"): self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - self.current_modality = current_modality if current_modality == "speech": self.main_input_name = "input_features" @@ -3914,6 +3897,9 @@ def __init__(self, config, current_modality="text"): self.t2u_model = SeamlessM4Tv2TextToUnitForConditionalGeneration(config) self.vocoder = SeamlessM4Tv2CodeHifiGan(config) + # Initialize weights and apply final processing + self.post_init() + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.set_modality def set_modality(self, modality="text"): if modality == "text": diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index ea0a58568101..fe18dd33cd3c 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -414,21 +414,6 @@ class SegformerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.zero_() - module.weight.fill_(1.0) - 
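
Every hunk in this section applies the same two-part recipe: boilerplate `_init_weights` branches for `nn.Linear`, `nn.Conv*`, `nn.Embedding` and the norm layers are either deleted outright or collapsed into a single `super()._init_weights(module)` call, and each remaining in-place initializer (`weight.normal_()`, `bias.zero_()`, `nn.init.*`) is routed through the helpers imported as `init` from the `transformers.initialization` module introduced by this diff. A minimal sketch of what one such helper presumably looks like follows; the `_is_hf_initialized` guard is inferred from the comments added next to the `padding_idx` handling, and the signature is an assumption rather than a copy of the real module:

import torch

def zeros_(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: skip tensors that were already materialized from a
    # checkpoint, which the loading path is assumed to tag with `_is_hf_initialized`.
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    with torch.no_grad():
        return tensor.zero_()

Because `module.weight[module.padding_idx]` creates a fresh view, a flag set on `module.weight` does not survive the slicing; this is why the `nn.Embedding` branches above and below add an explicit `getattr(module.weight, "_is_hf_initialized", False)` check before zeroing the padding row instead of relying on the guard inside the helper.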
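A second recurring change, visible in the SeamlessM4T and SeamlessM4Tv2 constructors above and in `SpeechEncoderDecoderModel` further down, is that `self.post_init()` moves to the very end of `__init__`, after submodules such as `t2u_model` and `vocoder` have been created, so that weight initialization and final processing run over the complete module tree. The ordering, illustrated on a toy module rather than the real classes:

from torch import nn

class ToyTextToSpeech(nn.Module):
    """Illustrative ordering only, not the real SeamlessM4T class."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 100):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.t2u_model = nn.Linear(hidden_size, hidden_size)  # placeholder submodule
        self.vocoder = nn.Linear(hidden_size, hidden_size)    # placeholder submodule
        self.post_init()  # runs last, once every submodule exists

    def post_init(self):
        # Stand-in for PreTrainedModel.post_init(): initialize all weights.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.normal_(p, mean=0.0, std=0.02)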
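In the SegGpt and Siglip/Siglip2 hunks that follow, the refactor also removes the local float32 round-trip that worked around `trunc_normal_cpu` not being implemented for half precision, along with the vendored `_trunc_normal_` / `trunc_normal_tf_` copies, and `variance_scaling_` drops its `scale` parameter, hard-wiring the previous default of 1.0. The dtype handling is evidently expected to live inside `init.trunc_normal_` now; one plausible way for such a wrapper to preserve the old behaviour, sketched under that assumption only:

import torch
from torch import nn

def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
    # Sample in float32 and copy back, so half-precision CPU tensors still work.
    with torch.no_grad():
        sampled = nn.init.trunc_normal_(tensor.to(torch.float32), mean=mean, std=std)
        return tensor.copy_(sampled.to(tensor.dtype))
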
@auto_docstring class SegformerModel(SegformerPreTrainedModel): diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py index 80f98707757d..7ff7922029d0 100644 --- a/src/transformers/models/seggpt/modeling_seggpt.py +++ b/src/transformers/models/seggpt/modeling_seggpt.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import functional as F +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -600,47 +601,22 @@ def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=std).to(module.weight.dtype) - ) + init.trunc_normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, SegGptLayerNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, SegGptAttention): - module.rel_pos_h.copy_( - nn.init.trunc_normal_( - module.rel_pos_h.to(torch.float32), - mean=0.0, - std=std, - ).to(module.rel_pos_h.dtype) - ) - - module.rel_pos_w.copy_( - nn.init.trunc_normal_( - module.rel_pos_w.to(torch.float32), - mean=0.0, - std=std, - ).to(module.rel_pos_w.dtype) - ) - + init.trunc_normal_(module.rel_pos_h, mean=0.0, std=std) + init.trunc_normal_(module.rel_pos_w, mean=0.0, std=std) elif isinstance(module, SegGptEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=std, - ).to(module.position_embeddings.dtype) - ) - - torch.nn.init.normal_(module.mask_token, std=std) - torch.nn.init.normal_(module.segment_token_input, std=std) - torch.nn.init.normal_(module.segment_token_prompt, std=std) - torch.nn.init.normal_(module.type_token_semantic, std=std) - torch.nn.init.normal_(module.type_token_instance, std=std) + init.trunc_normal_(module.position_embeddings, mean=0.0, std=std) + init.normal_(module.mask_token, std=std) + init.normal_(module.segment_token_input, std=std) + init.normal_(module.segment_token_prompt, std=std) + init.normal_(module.type_token_semantic, std=std) + init.normal_(module.type_token_instance, std=std) @auto_docstring diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 12870e4817db..28b5541b80e5 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -28,6 +28,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -521,32 +522,32 @@ class SEWPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SEWPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 4db3783036e5..20ddfc1e7f8b 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -259,32 +260,32 @@ class SEWPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SEWPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index a86f800027c5..ca426c3e23be 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_layers import GradientCheckpointingLayer @@ -1179,36 +1180,37 @@ class SEWDPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SEWDPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 171c9a3d1bfa..2b89f8748477 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -15,7 +15,6 @@ """PyTorch Siglip model.""" import math -import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional, Union @@ -25,6 +24,7 @@ from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -43,69 +43,7 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. 
- - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): +def variance_scaling_(tensor, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in @@ -114,18 +52,15 @@ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 - variance = scale / denom + variance = 1.0 / denom if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) + init.normal_(tensor, std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) + init.uniform_(tensor, -bound, bound) else: raise ValueError(f"invalid distribution {distribution}") @@ -494,43 +429,42 @@ def _init_weights(self, module): if isinstance(self.config, SiglipConfig) else self.config.hidden_size ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, SiglipAttention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) + init.xavier_uniform_(module.q_proj.weight) + init.xavier_uniform_(module.k_proj.weight) + init.xavier_uniform_(module.v_proj.weight) + init.xavier_uniform_(module.out_proj.weight) + init.zeros_(module.q_proj.bias) + init.zeros_(module.k_proj.bias) + init.zeros_(module.v_proj.bias) + init.zeros_(module.out_proj.bias) elif isinstance(module, SiglipMLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) + init.xavier_uniform_(module.fc1.weight) + init.xavier_uniform_(module.fc2.weight) + init.normal_(module.fc1.bias, std=1e-6) + init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, SiglipMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe) - nn.init.xavier_uniform_(module.attention.in_proj_weight) - nn.init.zeros_(module.attention.in_proj_bias) + init.xavier_uniform_(module.probe) + init.xavier_uniform_(module.attention.in_proj_weight) + init.zeros_(module.attention.in_proj_bias) elif isinstance(module, SiglipModel): - logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.fill_(logit_scale_init) - module.logit_bias.zero_() + init.zeros_(module.logit_scale) + init.zeros_(module.logit_bias) elif isinstance(module, SiglipForImageClassification): - nn.init.normal_( + init.normal_( module.classifier.weight, 
std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index 8db1e0e68b13..0a2a990ef52a 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional, Union @@ -30,6 +29,7 @@ import torch.nn.functional as F from torch.nn.init import _calculate_fan_in_and_fan_out +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -349,69 +349,7 @@ def forward( return hidden_states -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. 
- - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): +def variance_scaling_(tensor, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in @@ -420,18 +358,15 @@ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 - variance = scale / denom + variance = 1.0 / denom if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) + init.normal_(tensor, std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) + init.uniform_(tensor, -bound, bound) else: raise ValueError(f"invalid distribution {distribution}") @@ -476,43 +411,42 @@ def _init_weights(self, module): if isinstance(self.config, Siglip2Config) else self.config.hidden_size ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, Siglip2Attention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) + init.xavier_uniform_(module.q_proj.weight) + init.xavier_uniform_(module.k_proj.weight) + init.xavier_uniform_(module.v_proj.weight) + init.xavier_uniform_(module.out_proj.weight) + init.zeros_(module.q_proj.bias) + init.zeros_(module.k_proj.bias) + init.zeros_(module.v_proj.bias) + init.zeros_(module.out_proj.bias) elif isinstance(module, Siglip2MLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) + init.xavier_uniform_(module.fc1.weight) + init.xavier_uniform_(module.fc2.weight) + init.normal_(module.fc1.bias, std=1e-6) + init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Siglip2MultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe) - nn.init.xavier_uniform_(module.attention.in_proj_weight) - nn.init.zeros_(module.attention.in_proj_bias) + init.xavier_uniform_(module.probe) + init.xavier_uniform_(module.attention.in_proj_weight) + init.zeros_(module.attention.in_proj_bias) elif isinstance(module, Siglip2Model): - logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.fill_(logit_scale_init) - module.logit_bias.zero_() + init.zeros_(module.logit_scale) + init.zeros_(module.logit_bias) elif isinstance(module, Siglip2ForImageClassification): - nn.init.normal_( + init.normal_( module.classifier.weight, 
std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: - nn.init.zeros_(module.bias) + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) class Siglip2Encoder(nn.Module): diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 625e616b5989..dc72ed10a670 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -49,26 +49,6 @@ logger = logging.get_logger(__name__) -class SmolVLMRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - SmolVLMRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - @auto_docstring class SmolVLMPreTrainedModel(PreTrainedModel): config: SmolVLMConfig @@ -80,27 +60,8 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_attention_backend = True - @torch.no_grad() - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, SmolVLMRMSNorm): - module.weight.fill_(1.0) - class SmolVLMVisionEmbeddings(nn.Module): """ diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 5db0e2ff2605..c53034a66510 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -146,6 +146,8 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. 
Please use a model without LM Head" ) + self.post_init() + def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 0176ef4fa636..a0f017842559 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -495,18 +495,6 @@ class Speech2TextPreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 74744a42e6f5..114c4d444124 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -1175,37 +1176,38 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, SpeechT5PositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, SpeechT5ScaledPositionalEncoding): - module.alpha.fill_(1.0) + init.ones_(module.alpha) elif isinstance(module, SpeechT5FeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not 
getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if hasattr(module, "masked_spec_embed"): - nn.init.uniform_(module.masked_spec_embed) + init.uniform_(module.masked_spec_embed) class SpeechT5Encoder(SpeechT5PreTrainedModel): @@ -3015,14 +3017,6 @@ def __init__(self, config: SpeechT5HifiGanConfig): # Initialize weights and apply final processing self.post_init() - @torch.no_grad() - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index d0fa3699207b..f9b73afdd247 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -331,21 +331,6 @@ class SplinterPreTrainedModel(PreTrainedModel): base_model_prefix = "splinter" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class SplinterModel(SplinterPreTrainedModel): diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index b5418e34a575..94d5f36891d8 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, @@ -408,19 +409,9 @@ class SqueezeBertPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, SqueezeBertLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, SqueezeBertLMPredictionHead): + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index f2ab414ff30c..3b091726fab4 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -449,24 +449,8 @@ class StableLmPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - @auto_docstring class StableLmModel(StableLmPreTrainedModel): diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index fbba759df1b5..e676920347f8 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -23,6 +23,7 @@ from transformers import PreTrainedModel from transformers.models.superglue.configuration_superglue import SuperGlueConfig +from ... 
import initialization as init from ...utils import ModelOutput, auto_docstring, logging from ..auto import AutoModelForKeypointDetection @@ -472,16 +473,9 @@ class SuperGluePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.BatchNorm1d): - module.bias.zero_() - module.weight.fill_(1.0) - + super()._init_weights(module) if hasattr(module, "bin_score"): - module.bin_score.fill_(1.0) + init.ones_(module.bin_score) @auto_docstring( diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index 9e2abdeb863f..e9f808e6037c 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -328,17 +328,6 @@ class SuperPointPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = False - @torch.no_grad() - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: """ Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same, diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index 0fabef2afe44..2440c2cc0550 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2CLS from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention @@ -393,20 +394,20 @@ class SwiftFormerPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Conv2d, nn.Linear)): - nn.init.trunc_normal_(module.weight, std=0.02) + init.trunc_normal_(module.weight, std=0.02) if module.bias is not None: - nn.init.constant_(module.bias, 0) + init.constant_(module.bias, 0) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - nn.init.constant_(module.bias, 0) - nn.init.constant_(module.weight, 1.0) + init.constant_(module.bias, 0) + init.constant_(module.weight, 1.0) elif isinstance(module, (SwiftFormerConvEncoder, SwiftFormerLocalRepresentation)): - module.layer_scale.fill_(1.0) + init.ones_(module.layer_scale) elif isinstance(module, SwiftFormerEncoderBlock): if self.config.use_layer_scale: - module.layer_scale_1.fill_(self.config.layer_scale_init_value) - module.layer_scale_2.fill_(self.config.layer_scale_init_value) + init.constant_(module.layer_scale_1, self.config.layer_scale_init_value) + init.constant_(module.layer_scale_2, self.config.layer_scale_init_value) elif isinstance(module, SwiftFormerEfficientAdditiveAttention): - nn.init.normal_(module.w_g) + init.normal_(module.w_g) @auto_docstring diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 82bf2bfbc173..d28146e90ba2 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -23,6 +23,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput @@ -814,20 +815,14 @@ class SwinPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, SwinEmbeddings): + super()._init_weights(module) + if isinstance(module, SwinEmbeddings): if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, SwinSelfAttention): - module.relative_position_bias_table.zero_() + init.zeros_(module.relative_position_bias_table) @auto_docstring diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 093d34994b3a..2cc6724be0b1 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput @@ -695,12 +696,12 @@ class Swin2SRPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - torch.nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) + init.trunc_normal_(module.weight, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index ffbeff3456ca..c27f8fc745ed 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput @@ -890,19 +891,19 @@ class Swinv2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, Swinv2Embeddings): if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) if module.position_embeddings is not None: - module.position_embeddings.zero_() + init.zeros_(module.position_embeddings) elif isinstance(module, Swinv2SelfAttention): - module.logit_scale.fill_(math.log(10)) + init.constant_(module.logit_scale, math.log(10)) @auto_docstring diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 07ffd1c280c3..8792f50d4e38 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -27,6 +27,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -583,7 +584,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config: SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True - _can_compile_fullgraph = False _no_split_modules = ["SwitchTransformersBlock"] @@ -592,39 +592,39 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, SwitchTransformersLayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance( module, (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), ): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, SwitchTransformersDenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, SwitchTransformersAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, SwitchTransformersSparseMLP): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.router.classifier.weight.normal_(mean=0.0, std=factor * 1) + init.normal_(module.router.classifier.weight, mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.experts[f"expert_{idx}"].wo.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.experts[f"expert_{idx}"].wi.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.experts[f"expert_{idx}"].wo.weight, mean=0.0, 
std=factor * (d_model**-0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/switch_transformers/modular_switch_transformers.py b/src/transformers/models/switch_transformers/modular_switch_transformers.py index d1a9f3788290..33009c21a15a 100644 --- a/src/transformers/models/switch_transformers/modular_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modular_switch_transformers.py @@ -21,6 +21,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -339,7 +340,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config: SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True - _can_compile_fullgraph = False _no_split_modules = ["SwitchTransformersBlock"] @@ -348,39 +348,39 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, SwitchTransformersLayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance( module, (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), ): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, SwitchTransformersDenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, SwitchTransformersAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, SwitchTransformersSparseMLP): d_model = 
self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.router.classifier.weight.normal_(mean=0.0, std=factor * 1) + init.normal_(module.router.classifier.weight, mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.experts[f"expert_{idx}"].wo.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.experts[f"expert_{idx}"].wi.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.experts[f"expert_{idx}"].wo.weight, mean=0.0, std=factor * (d_model**-0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 63c447079897..f3cb42b4bc23 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -567,55 +568,55 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, T5LayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance( module, (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering), ): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - module.qa_outputs.bias.zero_() + init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.qa_outputs.bias) elif isinstance(module, T5ForTokenClassification): if hasattr(module, "classifier"): - module.classifier.weight.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.zero_() + init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0) + init.zeros_(module.classifier.bias) elif isinstance(module, T5ClassificationHead): - module.dense.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.dense.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.zero_() - module.out_proj.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.dense.bias) + init.normal_(module.out_proj.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.zero_() + init.zeros_(module.out_proj.bias) elif isinstance(module, T5DenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - 
module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, T5DenseGatedActDense): - module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, T5Attention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 13b6d3c75f14..5f70baa4812e 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn +from ... 
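The SwitchTransformers and T5 attention hunks above keep the Mesh-TensorFlow-style standard deviations: `(d_model * d_kv) ** -0.5` for `q` but only `d_model ** -0.5` for `k` and `v` (times `initializer_factor`). The extra `d_kv ** -0.5` folded into `q` is what lets these models skip the usual `1 / sqrt(d_kv)` scaling of attention scores. A quick Monte-Carlo sanity check of that effect, assuming roughly unit-variance inputs and `factor = 1` (illustration only, not library code):

```python
# Monte-Carlo check (illustration only): with the extra d_kv ** -0.5 folded into q's std,
# raw q.k scores already have ~unit variance, so no 1/sqrt(d_kv) rescaling is needed.
import torch

torch.manual_seed(0)
d_model, d_kv, n_samples = 512, 64, 4096

x = torch.randn(n_samples, d_model)                          # assume roughly unit-variance inputs
y = torch.randn(n_samples, d_model)
w_q = torch.randn(d_model, d_kv) * (d_model * d_kv) ** -0.5  # std as in the hunks above
w_k = torch.randn(d_model, d_kv) * d_model ** -0.5

scores = ((x @ w_q) * (y @ w_k)).sum(-1)                     # unscaled attention logits
print(scores.var().item())                                   # ~1.0
```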
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -578,16 +579,16 @@ def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, T5GemmaClassificationHead): scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.normal_(mean=0.0, std=std * scale) + init.normal_(module.out_proj.weight, mean=0.0, std=std * scale) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.zero_() + init.zeros_(module.out_proj.bias) elif isinstance(module, T5GemmaLMHead): if not self.config.tie_word_embeddings: scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.normal_(mean=0.0, std=std * scale) + init.normal_(module.out_proj.weight, mean=0.0, std=std * scale) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) def _shift_right(self, input_ids): """ diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index a020827bb757..b3ed006f007e 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin @@ -609,16 +610,16 @@ def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, T5GemmaClassificationHead): scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.normal_(mean=0.0, std=std * scale) + init.normal_(module.out_proj.weight, mean=0.0, std=std * scale) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.zero_() + init.zeros_(module.out_proj.bias) elif isinstance(module, T5GemmaLMHead): if not self.config.tie_word_embeddings: scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.normal_(mean=0.0, std=std * scale) + init.normal_(module.out_proj.weight, mean=0.0, std=std * scale) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) def _shift_right(self, input_ids): """ diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index f0577309ccda..8c4777e74420 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -21,6 +21,7 @@ import torch from torch import Tensor, nn +from ... 
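The T5Gemma hunks above zero-initialize RMSNorm weights, and the in-diff comment explains why: the norm's forward multiplies by `(1 + weight)`, so zeros give a unit scale. A toy norm following that convention (not the library class) makes the point:

```python
# Toy RMSNorm following the Gemma "(1 + weight)" convention; not the library class.
import torch
from torch import nn


class ToyRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zeros => unit scale at init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)          # (1 + 0) == 1 right after init


x = torch.randn(2, 8)
norm = ToyRMSNorm(8)
plain_rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), plain_rms))  # True: zero weight leaves the scale untouched
```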
import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -699,16 +700,17 @@ def _init_weights(self, module): std = self.config.init_std if isinstance(module, TableTransformerLearnedPositionEmbedding): - nn.init.uniform_(module.row_embeddings.weight) - nn.init.uniform_(module.column_embeddings.weight) + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) class TableTransformerEncoder(TableTransformerPreTrainedModel): diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index e0206fc5c0a8..819a90d70aeb 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer @@ -511,19 +512,9 @@ class TapasPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, TapasLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, TapasLMPredictionHead): + init.zeros_(module.bias) @auto_docstring @@ -819,11 +810,11 @@ def __init__(self, config: TapasConfig): self.column_output_weights = nn.Parameter(torch.zeros(config.hidden_size)) else: self.output_weights = nn.Parameter(torch.empty(config.hidden_size)) - nn.init.normal_( + init.normal_( self.output_weights, std=config.initializer_range ) # here, a truncated normal is used in the original implementation self.column_output_weights = nn.Parameter(torch.empty(config.hidden_size)) - nn.init.normal_( + init.normal_( self.column_output_weights, std=config.initializer_range ) # here, a truncated normal is used in the original implementation self.output_bias = nn.Parameter(torch.zeros([])) diff --git a/src/transformers/models/textnet/modeling_textnet.py b/src/transformers/models/textnet/modeling_textnet.py index 616a1a8327c6..3928c571e5aa 100644 --- a/src/transformers/models/textnet/modeling_textnet.py +++ 
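From here on the diff consistently routes initialization through the new `init` helpers and, as in the Tapas hunk above, defers the common cases to `super()._init_weights(module)`. The implementation of `transformers.initialization` is not shown in this section; a minimal sketch of the behaviour these call sites appear to rely on (torch.nn.init-style signatures, a no-grad guard, and skipping tensors already flagged `_is_hf_initialized`, e.g. weights loaded from a checkpoint) might look like the following. This is an assumption, not the actual module; `copy_` is included because deterministic tables are written with it further down in this diff.

```python
# Hypothetical sketch, NOT the real transformers.initialization module.
# It mirrors torch.nn.init but (a) runs under no_grad and (b) skips tensors
# that an earlier step (e.g. checkpoint loading) already marked as initialized.
import torch
from torch import nn


def _guarded(torch_init_fn):
    def wrapper(tensor: torch.Tensor, *args, **kwargs):
        if getattr(tensor, "_is_hf_initialized", False):
            return tensor                      # already handled, leave it alone
        with torch.no_grad():
            torch_init_fn(tensor, *args, **kwargs)
        tensor._is_hf_initialized = True       # mark so it is not re-initialized later
        return tensor
    return wrapper


normal_ = _guarded(nn.init.normal_)
zeros_ = _guarded(nn.init.zeros_)
ones_ = _guarded(nn.init.ones_)
constant_ = _guarded(nn.init.constant_)
uniform_ = _guarded(nn.init.uniform_)
trunc_normal_ = _guarded(nn.init.trunc_normal_)
kaiming_normal_ = _guarded(nn.init.kaiming_normal_)
xavier_uniform_ = _guarded(nn.init.xavier_uniform_)


def copy_(tensor: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    # Deterministic tables (sinusoidal / sincos embeddings) are copied, not sampled.
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    with torch.no_grad():
        tensor.copy_(source)
    tensor._is_hf_initialized = True
    return tensor
```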
b/src/transformers/models/textnet/modeling_textnet.py @@ -221,17 +221,6 @@ class TextNetPreTrainedModel(PreTrainedModel): base_model_prefix = "textnet" main_input_name = "pixel_values" - @torch.no_grad() - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.BatchNorm2d): - module.weight.fill_(1.0) - if module.bias is not None: - module.bias.zero_() - @auto_docstring class TextNetModel(TextNetPreTrainedModel): diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 33dc932e01b4..ab4976f1121d 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...masking_utils import create_bidirectional_mask, create_causal_mask @@ -230,9 +231,9 @@ class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) + super().__init__(num_positions, embedding_dim, _freeze=True) - def _init_weight(self): + def create_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] @@ -245,7 +246,7 @@ def _init_weight(self): sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - self.weight = nn.Parameter(out, requires_grad=False) + return out @torch.no_grad() def forward( @@ -617,17 +618,9 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + super()._init_weights(module) + if isinstance(module, TimeSeriesSinusoidalPositionalEmbedding): + init.copy_(module.weight, module.create_weight()) class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index d8042a82bea9..9799d6b24a14 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -28,6 +28,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
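The TimeSeriesTransformer hunk above splits the sinusoidal embedding into a `create_weight()` that returns the table and an `_init_weights` that writes it with `init.copy_`, with the embedding constructed under `_freeze=True` so it never trains. A standalone sketch of the non-interleaved table (sin features in the first half of the last dimension, cos in the second), using the standard XLM-style frequencies the docstring refers to:

```python
# Non-interleaved sinusoidal table: sin features first, cos features second.
import numpy as np
import torch


def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
    out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    return out


table = sinusoidal_table(16, 8)
print(table.shape)   # torch.Size([16, 8])
print(table[0, :4])  # all zeros: sin(0) for every frequency at position 0
```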
import initialization as init from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput @@ -311,7 +312,7 @@ def _init_weights(self, module): super()._init_weights(module) if isinstance(module, TimesFmAttention): # Initialize scaling parameter - nn.init.ones_(module.scaling) + init.ones_(module.scaling) @auto_docstring diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index dc5e05e33714..5610b5d11c42 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -23,6 +23,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... import initialization as init from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -267,7 +268,7 @@ def _init_weights(self, module): super()._init_weights(module) if isinstance(module, TimesFmAttention): # Initialize scaling parameter - nn.init.ones_(module.scaling) + init.ones_(module.scaling) @auto_docstring diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 5d463c73da91..ee1f6fcab8e4 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput @@ -458,16 +459,15 @@ class TimesformerPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv2d)): - nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) + init.trunc_normal_(module.weight, std=self.config.initializer_range) if module.bias is not None: - nn.init.constant_(module.bias, 0) + init.constant_(module.bias, 0) elif isinstance(module, nn.LayerNorm): - nn.init.constant_(module.bias, 0) - nn.init.constant_(module.weight, 1.0) + init.constant_(module.bias, 0) + init.constant_(module.weight, 1.0) elif isinstance(module, TimesformerEmbeddings): - nn.init.trunc_normal_(module.cls_token, std=self.config.initializer_range) - nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) - module.patch_embeddings.apply(self._init_weights) + init.trunc_normal_(module.cls_token, std=self.config.initializer_range) + init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 40481d26fbac..6e364cf1da05 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -19,6 +19,7 @@ import torch from torch import Tensor, nn +from ... 
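The Timesformer hunk above keeps the `cls_token` / `position_embeddings` truncated-normal init but drops the nested `module.patch_embeddings.apply(self._init_weights)`, presumably because the framework's init pass already reaches every submodule. `nn.Module.apply` itself is recursive, which is easy to verify:

```python
# nn.Module.apply(fn) visits every submodule before the module itself, so calling it
# again on a child inside the same init pass would only repeat work.
from torch import nn


class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, 8, kernel_size=4)


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embeddings = Child()
        self.norm = nn.LayerNorm(8)


visited = []
Parent().apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Conv2d', 'Child', 'LayerNorm', 'Parent']
```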
import initialization as init from ...modeling_outputs import ImageClassifierOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_timm_available, requires_backends @@ -131,9 +132,9 @@ def _init_weights(self, module): initialization, while all other weights should be loaded from the checkpoint. """ if isinstance(module, (nn.Linear)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) def _timm_model_supports_gradient_checkpointing(self): """ diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 78cc9206511d..e3a513f8687e 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -406,18 +406,6 @@ class TrOCRPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["TrOCRDecoderLayer"] - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - class TrOCRDecoder(TrOCRPreTrainedModel): """ diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 303ddfbfb9cb..7648c3d00038 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput @@ -526,27 +527,27 @@ class TvpPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - nn.init.constant_(module.bias, 0) + init.constant_(module.bias, 0) elif isinstance(module, TvpModel): - nn.init.normal_(module.text_prompt) + init.normal_(module.text_prompt) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) if hasattr(module, "pad_up"): - nn.init.normal_(module.pad_up) + init.normal_(module.pad_up) if hasattr(module, "pad_down"): - nn.init.normal_(module.pad_down) + init.normal_(module.pad_down) if hasattr(module, "pad_left"): - nn.init.normal_(module.pad_left) + init.normal_(module.pad_left) if hasattr(module, "pad_right"): - nn.init.normal_(module.pad_right) + init.normal_(module.pad_right) class TvpFrameDownPadPrompter(nn.Module): diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 30d4d1e689fb..c11f1b4a9102 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,6 +34,7 @@ Seq2SeqModelOutput, ) +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -262,55 +263,52 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, UdopLayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=factor) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=factor) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, nn.Conv2d): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=factor).to(module.weight.dtype) - ) + init.trunc_normal_(module.weight, mean=0.0, std=factor) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, RelativePositionBiasBase): factor = self.config.initializer_factor d_model = self.config.d_model - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, UdopModel): - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, UdopForConditionalGeneration): if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, UdopDenseActDense): - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, UdopDenseGatedActDense): - module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and 
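The UDOP embedding branch above (like the TableTransformer one earlier) zeroes the `padding_idx` row only after checking `_is_hf_initialized` on the full weight, because the slice passed to `zeros_` is a fresh tensor object that loses the flag. A two-line illustration of why the guard has to look at the parent parameter:

```python
# Why the guard checks module.weight rather than the sliced row.
import torch
from torch import nn

emb = nn.Embedding(10, 4, padding_idx=0)
emb.weight._is_hf_initialized = True                       # flag lives on the full parameter

row = emb.weight[emb.padding_idx]                          # indexing returns a new tensor
print(getattr(row, "_is_hf_initialized", False))           # False: the slice never inherits it
print(getattr(emb.weight, "_is_hf_initialized", False))    # True:  so the parent is checked
```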
module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, UdopAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right with ProphetNet->Udop def _shift_right(self, input_ids): @@ -1077,6 +1075,7 @@ def __init__(self, config): # get weights from encoder position bias self.relative_bias = self._get_relative_bias(config) + self.post_init() @staticmethod def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated: diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index d5a0f955049d..69f6c77e06ae 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -507,7 +508,7 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, UMT5LayerNorm): - module.weight.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance( module, ( @@ -519,55 +520,55 @@ def _init_weights(self, module): ): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.shared.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - module.qa_outputs.bias.zero_() + init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.qa_outputs.bias) elif isinstance(module, UMT5ForTokenClassification): if hasattr(module, "classifier"): - module.classifier.weight.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.zero_() + init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0) + init.zeros_(module.classifier.bias) elif isinstance(module, UMT5ClassificationHead): - module.dense.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.dense.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.zero_() - module.out_proj.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.zeros_(module.dense.bias) + init.normal_(module.out_proj.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.zero_() + init.zeros_(module.out_proj.bias) elif isinstance(module, UMT5DenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, UMT5DenseGatedActDense): - module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.zero_() - module.wi_1.weight.normal_(mean=0.0, std=factor * 
((self.config.d_model) ** -0.5)) + init.zeros_(module.wi_0.bias) + init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.zero_() - module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi_1.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, UMT5Attention): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index b6f2590e2be5..03160d33bd90 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -29,6 +29,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -744,34 +745,34 @@ def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, UniSpeechPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, UniSpeechFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 84b962e4b1d1..25cc7406c419 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from ... 
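The UniSpeech hunks above preserve the original fan-in-derived uniform bounds when switching to the `init` helpers: `k = sqrt(1 / in_features)` for the feature projection and `k = sqrt(groups / (in_channels * kernel_size[0]))` for conv biases. The latter is exactly `1 / sqrt(fan_in)` of the grouped kernel, as a quick check with illustrative values shows:

```python
# The conv-bias bound used above equals 1/sqrt(fan_in) of the grouped kernel.
import math
import torch
from torch import nn

conv = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=10, groups=16)

k = math.sqrt(conv.groups / (conv.in_channels * conv.kernel_size[0]))
fan_in = conv.in_channels // conv.groups * conv.kernel_size[0]
print(math.isclose(k, 1 / math.sqrt(fan_in)))  # True

with torch.no_grad():
    nn.init.kaiming_normal_(conv.weight)
    nn.init.uniform_(conv.bias, a=-k, b=k)
```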
import initialization as init from ...modeling_outputs import ModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -151,34 +152,34 @@ def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, UniSpeechPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, UniSpeechFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index e1906639f2dc..b83c184e2b46 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -30,6 +30,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -750,34 +751,34 @@ def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechSatGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, UniSpeechSatPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, UniSpeechSatFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ @@ -1438,7 +1439,7 @@ def __init__(self, config): self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.num_labels = config.num_labels - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ @@ -1594,7 +1595,7 @@ def __init__(self, config): self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index abcdb0364810..107d0810f546 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from ... 
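The UniSpeechSat heads above swap `self.init_weights()` for `self.post_init()`, matching the convention used across the rest of this diff: constructors end with `post_init()` so that weight initialization and any other end-of-`__init__` bookkeeping run in one place. A minimal usage sketch with a hypothetical config/model pair (the names are placeholders, and the base class may do more than the comment states):

```python
# Minimal sketch of the post_init() convention; TinyConfig/TinyModel are placeholders.
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class TinyConfig(PretrainedConfig):
    model_type = "tiny-demo"  # hypothetical model type, illustration only

    def __init__(self, hidden_size: int = 4, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range


class TinyModel(PreTrainedModel):
    config_class = TinyConfig
    base_model_prefix = "tiny"

    def __init__(self, config: TinyConfig):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Weight init (and any other finalization the base class defines) happens here.
        self.post_init()


model = TinyModel(TinyConfig())
print(model.proj.weight.shape)  # fully constructed and initialized by the end of __init__
```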
import initialization as init from ...modeling_outputs import ModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -163,34 +164,34 @@ def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechSatGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, UniSpeechSatPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, UniSpeechSatFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/univnet/modeling_univnet.py b/src/transformers/models/univnet/modeling_univnet.py index 1b208acdc5d9..a977ab8aadd5 100644 --- a/src/transformers/models/univnet/modeling_univnet.py +++ b/src/transformers/models/univnet/modeling_univnet.py @@ -591,14 +591,6 @@ def forward( waveform_lengths=waveform_lengths, ) - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 64bd7e958f7b..c897f1c81f8c 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -272,16 +272,6 @@ class UperNetPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [] - @torch.no_grad() - def _init_weights(self, module): - if isinstance(module, nn.Conv2d): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.BatchNorm2d): - module.weight.fill_(1.0) - module.bias.zero_() - @auto_docstring( custom_intro=""" diff --git 
a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index 40977bfc2c42..ee36d7519a53 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -373,10 +374,9 @@ class VaultGemmaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.zero_() + init.zeros_(module.weight) @auto_docstring diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 495719cb22c7..4aaea9d762d5 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -129,32 +130,21 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn = True _supports_sdpa = True - _can_compile_fullgraph = True _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module): + super()._init_weights(module) std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) - if hasattr(module, "class_embedding"): - module.class_embedding.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.class_embedding, mean=0.0, std=std) @auto_docstring( diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index b1a7179771d6..268a2b21bffc 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -392,17 +392,6 @@ class VideoMAEPreTrainedModel(PreTrainedModel): "attentions": VideoMAESelfAttention, } - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class VideoMAEModel(VideoMAEPreTrainedModel): diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 67a2a34b58f2..fc4d226eb098 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -516,21 +516,6 @@ class ViltPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = 
["ViltEmbeddings", "ViltSelfAttention"] - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring class ViltModel(ViltPreTrainedModel): diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 796580aaa0c6..d3c64e1848df 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -145,6 +145,8 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) + self.post_init() + def get_encoder(self): return self.encoder diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py index 0f7f86bb1458..52f1144bf874 100755 --- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -102,6 +102,8 @@ def __init__( self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + self.post_init() + @filter_out_non_signature_kwargs() @auto_docstring def get_text_features( diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index a085f8954f03..0ce9ee4b8a12 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -464,14 +465,14 @@ class VisualBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if hasattr(module, "bias") and module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, VisualBertLMPredictionHead): - module.bias.zero_() + init.zeros_(module.bias) @dataclass diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index bef55534d577..79221145486e 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -369,37 +370,17 @@ class ViTPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, ViTEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) - - module.cls_token.copy_( - nn.init.trunc_normal_( - module.cls_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - ) - + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) @auto_docstring diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index a6b24268ac58..281a142e6f26 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -24,6 +24,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -188,14 +189,14 @@ def initialize_weights(self): pos_embed = get_2d_sincos_pos_embed( self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True ) - self.position_embeddings.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + init.copy_(self.position_embeddings, torch.from_numpy(pos_embed).float().unsqueeze(0)) # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d) w = self.patch_embeddings.projection.weight - torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + init.xavier_uniform_(w.view([w.shape[0], -1])) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
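The ViT hunk above removes the per-call pattern of upcasting to fp32 before `trunc_normal_` (a workaround for `trunc_normal_cpu` not being implemented for half precision) and calls `init.trunc_normal_` directly, which suggests the helper now absorbs that dtype handling. A hedged sketch of such a wrapper, shown only to document the pattern being removed; the real helper may differ:

```python
# Hypothetical dtype-safe truncated normal, mirroring the upcast pattern the diff removes.
import torch
from torch import nn


def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
    with torch.no_grad():
        # Sample in float32 (trunc_normal_ is not implemented for half on CPU),
        # then cast the result back to the tensor's own dtype.
        sampled = nn.init.trunc_normal_(tensor.detach().to(torch.float32), mean=mean, std=std)
        tensor.copy_(sampled.to(tensor.dtype))
    return tensor


param = nn.Parameter(torch.empty(4, 4, dtype=torch.float16))
trunc_normal_(param, std=0.02)
print(param.dtype)  # torch.float16
```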
- torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range) + init.normal_(self.cls_token, std=self.config.initializer_range) # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: @@ -536,17 +537,17 @@ class ViTMAEPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, ViTMAEEmbeddings): module.initialize_weights() elif isinstance(module, ViTMAEDecoder): - module.mask_token.zero_() - module.decoder_pos_embed.zero_() + init.zeros_(module.mask_token) + init.zeros_(module.decoder_pos_embed) @auto_docstring @@ -685,10 +686,10 @@ def initialize_weights(self, num_patches): decoder_pos_embed = get_2d_sincos_pos_embed( self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True ) - self.decoder_pos_embed.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + init.copy_(self.decoder_pos_embed, torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) - torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range) + init.normal_(self.mask_token, std=self.config.initializer_range) def forward(self, hidden_states: torch.Tensor, ids_restore: torch.Tensor, interpolate_pos_encoding: bool = False): # Embed tokens diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index e10dfb6d123f..21e702970dc8 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput @@ -374,17 +375,17 @@ class ViTMSNPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, ViTMSNEmbeddings): - module.cls_token.zero_() - module.position_embeddings.zero_() + init.zeros_(module.cls_token) + init.zeros_(module.position_embeddings) if module.mask_token is not None: - module.mask_token.zero_() + init.zeros_(module.mask_token) @auto_docstring diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index a235b25a57c5..c8aac7a89910 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput @@ -570,20 +571,6 @@ def forward( ) -def caffe2_msra_fill(module: nn.Module) -> None: - """ - Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. Also initializes `module.bias` to 0. - - Source: https://detectron2.readthedocs.io/en/latest/_modules/fvcore/nn/weight_init.html. - - Args: - module (torch.nn.Module): module to initialize. 
- """ - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - @auto_docstring class VitDetPreTrainedModel(PreTrainedModel): config: VitDetConfig @@ -597,53 +584,28 @@ class VitDetPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, VitDetEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) - + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) elif isinstance(module, VitDetAttention) and self.config.use_relative_position_embeddings: - module.rel_pos_h.copy_( - nn.init.trunc_normal_( - module.rel_pos_h.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ) - ) - module.rel_pos_w.copy_( - nn.init.trunc_normal_( - module.rel_pos_w.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ) - ) - + init.trunc_normal_(module.rel_pos_h, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.rel_pos_w, mean=0.0, std=self.config.initializer_range) elif isinstance(module, VitDetResBottleneckBlock): for layer in [module.conv1, module.conv2, module.conv3]: - caffe2_msra_fill(layer) + init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu") + if layer.bias is not None: + init.constant_(layer.bias, 0) for layer in [module.norm1, module.norm2]: - layer.weight.fill_(1.0) - layer.bias.zero_() + init.ones_(layer.weight) + init.zeros_(layer.bias) # zero init last norm layer. - module.norm3.weight.zero_() - module.norm3.bias.zero_() + init.zeros_(module.norm3.weight) + init.zeros_(module.norm3.bias) @auto_docstring diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 8cf9841d1e47..31be5264a81d 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring from ...utils.backbone_utils import load_backbone @@ -61,9 +62,9 @@ class VitMattePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vitpose/modeling_vitpose.py b/src/transformers/models/vitpose/modeling_vitpose.py index f87396b564f7..2a391f119560 100644 --- a/src/transformers/models/vitpose/modeling_vitpose.py +++ b/src/transformers/models/vitpose/modeling_vitpose.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... import initialization as init from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack @@ -70,18 +71,12 @@ class VitPosePreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index c5c5d8ffbe02..a615e04f6b0c 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -26,6 +26,7 @@ import torch from torch import nn +from ... 
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput @@ -361,26 +362,14 @@ class VitPoseBackbonePreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.copy_( - nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( - module.weight.dtype - ) - ) + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, VitPoseBackboneEmbeddings): - module.position_embeddings.copy_( - nn.init.trunc_normal_( - module.position_embeddings.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - ) + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index dd9117e309a3..22eabb6d9299 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1206,29 +1207,30 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) + init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif isinstance(module, VitsAttention): if self.config.window_size: head_dim = self.config.hidden_size // self.config.num_attention_heads - nn.init.normal_(module.emb_rel_k, std=head_dim**-0.5) - nn.init.normal_(module.emb_rel_v, std=head_dim**-0.5) + init.normal_(module.emb_rel_k, std=head_dim**-0.5) + init.normal_(module.emb_rel_v, std=head_dim**-0.5) elif isinstance(module,
VitsElementwiseAffine): - module.translate.zero_() - module.log_scale.zero_() + init.zeros_(module.translate) + init.zeros_(module.log_scale) @auto_docstring( diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index ed55faac7aa0..3f2494783b14 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -20,6 +20,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput @@ -378,20 +379,10 @@ class VivitPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, VivitEmbeddings): - module.cls_token.zero_() - module.position_embeddings.zero_() + super()._init_weights(module) + if isinstance(module, VivitEmbeddings): + init.zeros_(module.cls_token) + init.zeros_(module.position_embeddings) @auto_docstring diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index f2ab5b1f2cf8..f87b7c2854a3 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -19,6 +19,7 @@ import torch from torch import nn +from ... 
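The comment added in the VITS `nn.Embedding` branch above ("we slice the weight in the `zeros_` call, so it loses the flag") is about plain Python attribute bookkeeping: indexing a parameter returns a new tensor object, so a `_is_hf_initialized` marker set on the full weight is not visible on the padding-row slice, and the guard has to be written out explicitly. A tiny standalone illustration:

    import torch

    weight = torch.nn.Parameter(torch.randn(10, 4))
    weight._is_hf_initialized = True               # flag set on the full parameter
    padding_row = weight[3]                        # slicing yields a fresh tensor object
    print(getattr(weight, "_is_hf_initialized", False))       # True
    print(getattr(padding_row, "_is_hf_initialized", False))  # False, hence the explicit check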
import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput @@ -946,34 +947,26 @@ def _init_weights(self, module): """Initialize the weights""" init_std = self.config.initializer_range - - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - def trunc_normal_f32_(weight, std): - data_float_32 = weight.to(torch.float32) - data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std) - weight.copy_(data_init.to(weight.dtype)) - if isinstance(module, VJEPA2AttentivePooler): - trunc_normal_f32_(module.query_tokens, std=init_std) + init.trunc_normal_(module.query_tokens, std=init_std) for i, layer in enumerate(module.self_attention_layers, 1): std = init_std / (i**0.5) - trunc_normal_f32_(layer.self_attn.out_proj.weight, std=std) - trunc_normal_f32_(layer.mlp.fc2.weight, std=std) + init.trunc_normal_(layer.self_attn.out_proj.weight, std=std) + init.trunc_normal_(layer.mlp.fc2.weight, std=std) std = init_std / (len(module.self_attention_layers) + 1) ** 0.5 - trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std) + init.trunc_normal_(module.cross_attention_layer.mlp.fc2.weight, std=std) elif isinstance(module, VJEPA2PredictorEmbeddings): if module.zero_init_mask_tokens: - module.mask_tokens.zero_() + init.zeros_(module.mask_tokens) else: - trunc_normal_f32_(module.mask_tokens, std=init_std) + init.trunc_normal_(module.mask_tokens, std=init_std) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): - trunc_normal_f32_(module.weight, std=init_std) + init.trunc_normal_(module.weight, std=init_std) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) @auto_docstring diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 59848b48b741..9f3321c6fa48 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -231,28 +231,6 @@ class VoxtralPreTrainedModel(PreTrainedModel): _supports_attention_backend = True _can_compile_fullgraph = True - @torch.no_grad() - def _init_weights(self, module): - # important: this ported version of Voxtral isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.audio_config.initializer_range - ) - - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index af33fdad10de..361064816991 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import 
CrossEntropyLoss +from ... import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -976,38 +977,36 @@ def _init_weights(self, module): if isinstance(module, Wav2Vec2ForPreTraining): module.project_hid.reset_parameters() module.project_q.reset_parameters() - module.project_hid._is_hf_initialized = True - module.project_q._is_hf_initialized = True # gumbel softmax requires special init elif isinstance(module, Wav2Vec2GumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, Wav2Vec2PositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, Wav2Vec2FeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None @@ -1918,7 +1917,7 @@ def __init__(self, config): self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.num_labels = config.num_labels - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ @@ -2074,7 +2073,7 @@ def __init__(self, config): self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index 65c53653c191..62c9ed0a31b6 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -13,6 +13,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
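Besides the `init.*` substitutions, the sequence-classification and x-vector heads in `modeling_wav2vec2.py` now end their constructors with `self.post_init()` instead of the older `self.init_weights()`. `post_init()` is the standard end-of-`__init__` hook on `PreTrainedModel`: it applies `_init_weights` across submodules and finishes the remaining model setup. A minimal toy subclass (illustrative only, not part of this diff) shows where the call goes:

    import torch.nn as nn
    from transformers import PreTrainedModel, PretrainedConfig

    class ToyConfig(PretrainedConfig):
        model_type = "toy"

        def __init__(self, hidden_size=16, initializer_range=0.02, **kwargs):
            super().__init__(**kwargs)
            self.hidden_size = hidden_size
            self.initializer_range = initializer_range

    class ToyModel(PreTrainedModel):
        config_class = ToyConfig

        def __init__(self, config):
            super().__init__(config)
            self.proj = nn.Linear(config.hidden_size, config.hidden_size)
            # runs _init_weights over the submodules and finalizes setup
            self.post_init()

        def forward(self, hidden_states):
            return self.proj(hidden_states)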
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -716,38 +717,38 @@ def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Wav2Vec2BertSelfAttention): if hasattr(module, "pos_bias_u"): - nn.init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, Wav2Vec2BertFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, Wav2Vec2BertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.uniform_() + init.uniform_(module.masked_spec_embed) elif isinstance( module, (Wav2Vec2BertForSequenceClassification, Wav2Vec2BertForAudioFrameClassification, Wav2Vec2BertForXVector), ): if hasattr(module, "layer_weights"): - module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) + init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1)) elif isinstance(module, AMSoftmaxLoss): # noqa: F821 - module.weight.normal_() + init.normal_(module.weight) # Ignore copy def _get_feat_extract_output_lengths( @@ -1262,7 +1263,7 @@ def __init__(self, config): self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.num_labels = config.num_labels - self.init_weights() + self.post_init() def freeze_base_model(self): """ @@ -1405,7 +1406,7 @@ def __init__(self, config): self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) - self.init_weights() + self.post_init() def freeze_base_model(self): """ diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py index 4cd67b9db252..a4961ed59528 100644 --- a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py @@ -5,6 +5,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -588,38 +589,38 @@ def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Wav2Vec2BertSelfAttention): if hasattr(module, "pos_bias_u"): - nn.init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, Wav2Vec2BertFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, Wav2Vec2BertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.uniform_() + init.uniform_(module.masked_spec_embed) elif isinstance( module, (Wav2Vec2BertForSequenceClassification, Wav2Vec2BertForAudioFrameClassification, Wav2Vec2BertForXVector), ): if hasattr(module, "layer_weights"): - module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) + init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1)) elif isinstance(module, AMSoftmaxLoss): # noqa: F821 - module.weight.normal_() + init.normal_(module.weight) # Ignore copy def _get_feat_extract_output_lengths( diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index f3ee90ba8576..478507d1e758 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -14,6 +14,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -860,39 +861,39 @@ def _init_weights(self, module): module.project_q.reset_parameters() # gumbel softmax requires special init elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, Wav2Vec2ConformerSelfAttention): if hasattr(module, "pos_bias_u"): - nn.init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, Wav2Vec2ConformerFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None @@ -1646,7 +1647,7 @@ def __init__(self, config): self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.num_labels = config.num_labels - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ @@ -1802,7 +1803,7 @@ def __init__(self, config): self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py index 7b8080798471..731cf1c18add 100644 --- a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py @@ -5,6 +5,7 @@ import torch from torch import nn +from ... 
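The `nn.Conv1d` branch that these Wav2Vec2-family hunks keep intact draws the bias from `uniform(-k, k)` with `k = sqrt(groups / (in_channels * kernel_size))`. That is the same bound PyTorch itself uses for conv layers, `1 / sqrt(fan_in)` with `fan_in = (in_channels / groups) * kernel_size`, so the migration changes which helper applies the distribution, not the distribution itself. A quick numerical check with arbitrary shapes:

    import math

    in_channels, kernel_size, groups = 512, 10, 16

    k = math.sqrt(groups / (in_channels * kernel_size))
    fan_in = (in_channels // groups) * kernel_size
    print(k, 1 / math.sqrt(fan_in))  # both print ~0.0559, i.e. the identical bound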
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -559,39 +560,39 @@ def _init_weights(self, module): module.project_q.reset_parameters() # gumbel softmax requires special init elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, Wav2Vec2ConformerSelfAttention): if hasattr(module, "pos_bias_u"): - nn.init.xavier_uniform_(module.pos_bias_u) + init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): - nn.init.xavier_uniform_(module.pos_bias_v) + init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, Wav2Vec2ConformerFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 958337d3bc3d..25b600086582 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -608,34 +609,34 @@ def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, WavLMGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, WavLMPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, WavLMFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None @@ -1367,7 +1368,7 @@ def __init__(self, config): self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.num_labels = config.num_labels - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ @@ -1523,7 +1524,7 @@ def __init__(self, config): self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) - self.init_weights() + self.post_init() def freeze_feature_encoder(self): """ diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index c50f2a4ec7e1..b9fa58346a6d 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
import initialization as init from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_layers import GradientCheckpointingLayer @@ -518,34 +519,34 @@ def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, WavLMGumbelVectorQuantizer): - module.weight_proj.weight.normal_(mean=0.0, std=1) - module.weight_proj.bias.zero_() - nn.init.uniform_(module.codevectors) + init.normal_(module.weight_proj.weight, mean=0.0, std=1) + init.zeros_(module.weight_proj.bias) + init.uniform_(module.codevectors) elif isinstance(module, WavLMPositionalConvEmbedding): - nn.init.normal_( + init.normal_( module.conv.weight, mean=0, std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), ) - nn.init.constant_(module.conv.bias, 0) + init.constant_(module.conv.bias, 0) elif isinstance(module, WavLMFeatureProjection): k = math.sqrt(1 / module.projection.in_features) - nn.init.uniform_(module.projection.weight, a=-k, b=k) - nn.init.uniform_(module.projection.bias, a=-k, b=k) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) def _get_adapters(self): raise AttributeError("Not needed for WavLM") diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 6e91445ca961..cff7e10b4b2f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... 
import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -540,23 +541,12 @@ class WhisperPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.fill_(1.0) - module.bias.zero_() - elif isinstance(module, WhisperEncoder): - module.embed_positions.weight.copy_(sinusoids(*module.embed_positions.weight.shape)) + super()._init_weights(module) + if isinstance(module, WhisperEncoder): + init.copy_(module.embed_positions.weight, sinusoids(*module.embed_positions.weight.shape)) elif isinstance(module, WhisperForAudioClassification): if self.config.use_weighted_layer_sum: - module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) + init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 36be6ad43294..8b0695d1d140 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -22,6 +22,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -509,48 +510,48 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, XCLIPTextEmbeddings): - module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) elif isinstance(module, XCLIPVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, XCLIPAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, 
std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, XCLIPMLP): factor = self.config.initializer_factor in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, XCLIPModel): factor = self.config.initializer_factor - nn.init.normal_( + init.normal_( module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor, ) - nn.init.normal_( + init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor, ) - nn.init.normal_(module.prompts_visual_projection, mean=0.0, std=module.vision_embed_dim**-0.5 * factor) + init.normal_(module.prompts_visual_projection, mean=0.0, std=module.vision_embed_dim**-0.5 * factor) elif isinstance(module, XCLIPMultiframeIntegrationTransformer): - nn.init.normal_(module.position_embedding, std=self.config.initializer_factor) + init.normal_(module.position_embedding, std=self.config.initializer_factor) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_factor) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_factor) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->XCLIP diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py index 7e5b802e72f7..44fd8c669131 100644 --- a/src/transformers/models/xcodec/modeling_xcodec.py +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -22,6 +22,7 @@ import torch.nn as nn import torch.nn.functional as F +from ... 
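A few hunks further up, in `modeling_whisper.py`, the encoder's position table is now filled via `init.copy_(module.embed_positions.weight, sinusoids(*module.embed_positions.weight.shape))` rather than an in-place `copy_` on the weight. For context, the `sinusoids` helper that file already defines builds the usual sin/cos table; paraphrased roughly as follows (check the file for the exact code, which also asserts an even channel count):

    import math
    import torch

    def sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> torch.Tensor:
        # First half sine, second half cosine, with log-spaced timescales.
        log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
        scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
        return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)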
import initialization as init from ...modeling_utils import PreTrainedAudioTokenizerBase from ...utils import ModelOutput, auto_docstring from ..auto import AutoModel @@ -331,36 +332,34 @@ class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.zero_() + init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) + init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + init.uniform_(module.bias, a=-k, b=k) elif module.__class__.__name__ == "Snake1d": - module.alpha.fill_(1.0) + init.ones_(module.alpha) elif isinstance(module, nn.ConvTranspose1d): module.reset_parameters() elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=0.02) + init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, XcodecModel): # The conv1d are not handled correctly, as `self.acoustic_encoder/decoder` are initialized from a PreTrainedModel, # but then only the submodules are used (which are not PreTrainedModels...) -> here we reinit them as in DacModel for submodule in module.acoustic_encoder.modules(): if isinstance(submodule, nn.Conv1d): - nn.init.trunc_normal_(submodule.weight, std=0.02) - nn.init.constant_(submodule.bias, 0) - submodule._is_hf_initialized = True + init.trunc_normal_(submodule.weight, std=0.02) + init.constant_(submodule.bias, 0) for submodule in module.acoustic_decoder.modules(): if isinstance(submodule, nn.Conv1d): - nn.init.trunc_normal_(submodule.weight, std=0.02) - nn.init.constant_(submodule.bias, 0) - submodule._is_hf_initialized = True + init.trunc_normal_(submodule.weight, std=0.02) + init.constant_(submodule.bias, 0) def apply_weight_norm(self): """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied.""" diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 6edd50844c25..f8e8f5fae2e1 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -361,18 +361,6 @@ class XGLMPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["XGLMDecoderLayer"] - @torch.no_grad() - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - @auto_docstring class XGLMModel(XGLMPreTrainedModel): diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 5ed343824902..31bad0ef2af2 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
import initialization as init from ...activations import gelu, get_activation from ...cache_utils import DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -52,6 +53,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) out.detach_() + return out def get_masks(slen, lengths, causal, padding_mask=None): @@ -619,20 +621,26 @@ def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Embedding): if self.config is not None and self.config.embed_init_std is not None: - nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() + init.normal_(module.weight, mean=0, std=self.config.embed_init_std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) if isinstance(module, nn.Linear): if self.config is not None and self.config.init_std is not None: - nn.init.normal_(module.weight, mean=0, std=self.config.init_std) + init.normal_(module.weight, mean=0, std=self.config.init_std) if module.bias is not None: - nn.init.constant_(module.bias, 0.0) + init.constant_(module.bias, 0.0) if isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) + init.zeros_(module.bias) + init.ones_(module.weight) if isinstance(module, XLMModel) and self.config.sinusoidal_embeddings: - create_sinusoidal_embeddings( - self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight + init.copy_( + module.position_embeddings.weight, + create_sinusoidal_embeddings( + self.config.max_position_embeddings, + self.config.emb_dim, + out=torch.empty_like(module.position_embeddings.weight), + ), ) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 05fa46b23f54..7a8743fd9686 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -27,6 +27,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...
import initialization as init from ...activations import ACT2FN, gelu from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -413,19 +414,9 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, XLMRobertaLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, XLMRobertaLMHead): + init.zeros_(module.bias) class XLMRobertaEmbeddings(nn.Module): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index a6200dc1ddde..8631d33c3a8c 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -32,6 +32,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN, gelu from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -538,19 +539,9 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, XLMRobertaXLLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, XLMRobertaXLLMHead): + init.zeros_(module.bias) class XLMRobertaXLPooler(nn.Module): @@ -784,7 +775,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.lm_head = XLMRobertaXLLMHead(config) - self.init_weights() + self.post_init() def get_output_embeddings(self): return self.lm_head.decoder @@ -887,7 +878,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.lm_head = XLMRobertaXLLMHead(config) - self.init_weights() + self.post_init() def get_output_embeddings(self): return self.lm_head.decoder @@ -958,7 +949,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.classifier = XLMRobertaXLClassificationHead(config) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -1030,7 +1021,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -1121,7 +1112,7 @@ def __init__(self, config): self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, 
config.num_labels) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -1185,7 +1176,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring diff --git a/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py index ec2dcf9a0a39..31c0068ef712 100644 --- a/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py @@ -280,7 +280,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.lm_head = XLMRobertaXLLMHead(config) - self.init_weights() + self.post_init() def get_output_embeddings(self): return self.lm_head.decoder @@ -383,7 +383,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.lm_head = XLMRobertaXLLMHead(config) - self.init_weights() + self.post_init() def get_output_embeddings(self): return self.lm_head.decoder @@ -454,7 +454,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.classifier = XLMRobertaXLClassificationHead(config) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -526,7 +526,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -617,7 +617,7 @@ def __init__(self, config): self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring @@ -681,7 +681,7 @@ def __init__(self, config): self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() + self.post_init() @can_return_tuple @auto_docstring diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index a52ae140e77d..08a6f1df0b55 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... 
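The XLM-RoBERTa and XLM-RoBERTa-XL hunks above (like XLNet, Xmod, Yoso and Zamba further down) delete their hand-written Linear/Embedding/LayerNorm branches, call `super()._init_weights(module)`, and keep only the model-specific extras such as the LM-head bias. This leans on the generic default in `PreTrainedModel._init_weights`, which covers the common module types from `config.initializer_range`. Schematically, the expected behaviour is along these lines (a sketch under that assumption, not the upstream implementation):

    import torch
    from torch import nn
    from transformers import initialization as init  # module introduced by this PR

    @torch.no_grad()
    def _default_init_weights(self, module):
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            if module.bias is not None:
                init.zeros_(module.bias)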
import initialization as init from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel @@ -638,18 +639,8 @@ class XLNetPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, XLNetRelativeAttention): + super()._init_weights(module) + if isinstance(module, XLNetRelativeAttention): for param in [ module.q, module.k, @@ -661,9 +652,9 @@ def _init_weights(self, module): module.r_w_bias, module.seg_embed, ]: - param.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(param, mean=0.0, std=self.config.initializer_range) elif isinstance(module, XLNetModel): - module.mask_emb.normal_(mean=0.0, std=self.config.initializer_range) + init.normal_(module.mask_emb, mean=0.0, std=self.config.initializer_range) @dataclass diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py index 685df5dc42f8..c1a38c0029e9 100644 --- a/src/transformers/models/xlstm/modeling_xlstm.py +++ b/src/transformers/models/xlstm/modeling_xlstm.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ... import initialization as init from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -1211,7 +1212,7 @@ def small_init_method(dim): std = (2 / (5 * dim)) ** (1 / 2) def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=std) + return init.normal_(tensor, mean=0.0, std=std) return init_ @@ -1223,7 +1224,7 @@ def wang_init_method(n_layers, dim): std = 2 / n_layers / dim ** (1 / 2) def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=std) + return init.normal_(tensor, mean=0.0, std=std) return init_ @@ -1251,37 +1252,46 @@ def _init_weights(self, module): small_init_method(self.config.hidden_size)(self.embeddings.weight) elif isinstance(module, nn.Linear): if module.bias is not None: - torch.nn.init.zeros_(module.bias) + init.zeros_(module.bias) if self.config.weight_mode == "single" and "gate" in self._module_name_map(module): - torch.nn.init.zeros_(module.weight) - with torch.no_grad(): - if "igate" in self._module_name_map(module): - module.bias.copy_(-10.0 * torch.ones_like(module.bias)) - elif "fgate" in self._module_name_map(module): - module.bias.copy_( - torch.linspace( - 3.0, - 6.0, - module.bias.shape[-1], - ).to( - device=module.bias.device, - dtype=module.bias.dtype, - ) - ) + init.zeros_(module.weight) + + if "igate" in self._module_name_map(module): + init.copy_(module.bias, -10.0 * torch.ones_like(module.bias)) + elif "fgate" in self._module_name_map(module): + init.copy_( + module.bias, + torch.linspace( + 3.0, + 6.0, + module.bias.shape[-1], + ).to( + device=module.bias.device, + dtype=module.bias.dtype, + ), + ) elif self.config.weight_mode == "fused" and "gate" in self._module_name_map(module): - torch.nn.init.zeros_(module.weight) - with torch.no_grad(): - module.bias[: self.config.num_heads] += 
-module.bias[ - : self.config.num_heads - ] - 10.0 * torch.ones_like(module.bias) - module.bias[: self.config.num_heads] += -module.bias[self.config.num_heads :] + torch.linspace( + init.zeros_(module.weight) + + init.copy_( + module.bias[: self.config.num_heads], + module.bias[: self.config.num_heads] + - module.bias[: self.config.num_heads] + - 10.0 * torch.ones_like(module.bias), + ) + init.copy_( + module.bias[: self.config.num_heads], + module.bias[: self.config.num_heads] + - module.bias[self.config.num_heads :] + + torch.linspace( 3.0, 6.0, module.bias.shape[-1], ).to( device=module.bias.device, dtype=module.bias.dtype, - ) + ), + ) elif "proj_down" in self._module_name_map(module): wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight) elif "out_proj" in self._module_name_map(module): @@ -1289,9 +1299,9 @@ def _init_weights(self, module): elif module.weight is not None: small_init_method(self.config.hidden_size)(module.weight) elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"): - torch.nn.init.ones_(module.weight) + init.ones_(module.weight) if hasattr(module, "bias") and module.bias is not None: - torch.nn.init.zeros_(module.bias) + init.zeros_(module.bias) class xLSTMCache: diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index b50d4fb64600..551797eb5846 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN, gelu from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin @@ -630,19 +631,9 @@ class XmodPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, XmodLMHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, XmodLMHead): + init.zeros_(module.bias) def set_default_language(self, language: str): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index edd6cfd5b10e..77df8447a256 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -17,7 +17,7 @@ import collections.abc from collections.abc import Callable from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional import torch from torch import nn @@ -445,17 +445,6 @@ class YolosPreTrainedModel(PreTrainedModel): "attentions": YolosSelfAttention, } - @torch.no_grad() - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - 
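In the xLSTM hunk above, the gate projections get zero weights, a constant bias of -10.0 for the `igate` branch, and a `linspace(3.0, 6.0)` bias for the `fgate` branch. Reading those names as input and forget gates, the constants are chosen for where they land after the sigmoid: forget gates start nearly open and input gates nearly closed, a common recipe for stable early training of gated recurrent blocks. The numbers:

    import torch

    print(torch.sigmoid(torch.tensor(-10.0)))          # ~4.5e-05: input gate starts almost closed
    print(torch.sigmoid(torch.linspace(3.0, 6.0, 4)))  # ~0.953 .. 0.998: forget gates start almost open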
module.weight.fill_(1.0) - @auto_docstring class YolosModel(YolosPreTrainedModel): diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index ce945d24bdb9..6e319b5b5f92 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -607,20 +608,9 @@ class YosoPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - elif isinstance(module, YosoLMPredictionHead): - module.bias.zero_() + super()._init_weights(module) + if isinstance(module, YosoLMPredictionHead): + init.zeros_(module.bias) @auto_docstring diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index a755cede13ba..fb51f5add858 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -796,20 +797,11 @@ class ZambaPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.Embedding): - module.weight.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx].zero_() - elif isinstance(module, ZambaRMSNorm): - module.weight.fill_(1.0) - elif isinstance(module, ZambaMambaMixer): - module.x_proj_weight.normal_(mean=0.0, std=std) + super()._init_weights(module) + if isinstance(module, ZambaMambaMixer): + init.normal_(module.x_proj_weight, mean=0.0, std=std) dt_init_std = self.config.mamba_dt_rank**-0.5 - nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std) + init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std) mamba_head_dim = self.config.mamba_expand * self.config.hidden_size // self.config.n_mamba_heads dt = torch.exp( @@ -819,12 +811,12 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_proj_bias.copy_(inv_dt) + init.copy_(module.dt_proj_bias, inv_dt) A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A).reshape(module.n_mamba_heads, 
module.mamba_head_dim, -1)) + init.ones_(module.D) @auto_docstring diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 40197d8667ca..5b5f532cebf7 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -28,6 +28,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin @@ -1226,11 +1227,11 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_bias.copy_(inv_dt) + init.copy_(module.dt_bias, inv_dt) A = torch.arange(1, module.num_heads + 1) - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) @auto_docstring diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 9e3875ffa927..b4761a50dd42 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast @@ -914,11 +915,11 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_bias.copy_(inv_dt) + init.copy_(module.dt_bias, inv_dt) A = torch.arange(1, module.num_heads + 1) - module.A_log.copy_(torch.log(A)) - module.D.fill_(1.0) + init.copy_(module.A_log, torch.log(A)) + init.ones_(module.D) class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index f077fd387dd3..8a23f8380375 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -1211,17 +1211,6 @@ class ZoeDepthPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True - @torch.no_grad() - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.zero_() - module.weight.fill_(1.0) - @auto_docstring( custom_intro=""" diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index ca9e7b03ac4c..68b375fee077 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -56,6 +56,7 @@ if is_torch_available(): import torch + import torch.nn as nn from tests.trainer.test_trainer import ( # noqa RegressionModelConfig, @@ -292,8 +293,8 @@ def forward(self, *args, **kwargs): def _init_weights(self, module): super()._init_weights(module) if module is self.new_head: - self.new_head.weight.data.fill_(-100.0) - self.new_head.bias.data.fill_(+100.0) + nn.init.constant_(self.new_head.weight.data, -100.0) + 
nn.init.constant_(self.new_head.bias.data, 100.0) ds_config = { "train_batch_size": 1, diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 3804c3914f23..1701f8ca9e1a 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -555,6 +555,7 @@ def test_custom_model_patched_generation_inheritance(self): # patching was added in v4.45) self.assertTrue("GenerationMixin" in str(model.__class__.__bases__)) + @unittest.skip("@Cyril: add the post_init() on the hub repo") def test_model_with_dotted_name_and_relative_imports(self): """ Test for issue #40496: AutoModel.from_pretrained() doesn't work for models with '.' in their name diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index c30cc27b34c9..8b4d6c1eabfe 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -171,7 +171,7 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict): embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - embed_positions._init_weight() + embed_positions.weight.copy_(embed_positions.create_weight()) embed_positions = embed_positions.to(torch_device) self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight)) diff --git a/tests/models/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py index f8eeb27e3286..bf4fe458ebbc 100644 --- a/tests/models/roformer/test_modeling_roformer.py +++ b/tests/models/roformer/test_modeling_roformer.py @@ -517,7 +517,7 @@ class RoFormerSinusoidalPositionalEmbeddingTest(unittest.TestCase): def test_basic(self): input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device) emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6) - emb1._init_weight() + emb1.weight.copy_(emb1.create_weight()) emb1 = emb1.to(torch_device) emb = emb1(input_ids.shape) desired_weights = torch.tensor( @@ -537,7 +537,7 @@ def test_positional_emb_weights_against_roformer(self): ] ).to(torch_device) emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512).to(torch_device) - emb1._init_weight() + emb1.weight.copy_(emb1.create_weight()) weights = emb1.weight.data[:3, :5].to(torch_device) self.assertTrue( @@ -559,7 +559,7 @@ def test_apply_rotary_position_embeddings(self): -torch.arange(2 * 12 * 16 * 64, dtype=torch.float, device=torch_device).reshape(2, 12, 16, 64) / 100 ).to(torch_device) embed_positions = RoFormerSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=64) - embed_positions._init_weight() + embed_positions.weight.copy_(embed_positions.create_weight()) embed_positions = embed_positions.to(torch_device) sinusoidal_pos = embed_positions([2, 16, 768])[None, None, :, :] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ff2c64daba87..c06e2beba2b8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -447,6 +447,7 @@ def __init__(self, config): self.a = nn.Parameter(torch.tensor(config.a).float()) self.b = nn.Parameter(torch.tensor(config.b).float()) self.double_output = config.double_output + self.post_init() def forward(self, input_x, labels=None, **kwargs): y = input_x * self.a + self.b @@ -466,6 +467,7 @@ def 
__init__(self, config): self.head = nn.Linear(config.hidden_size, 1) self.gradient_checkpointing = False self.double_output = config.double_output + self.post_init() def forward(self, input_x, labels=None, **kwargs): y = input_x.unsqueeze(0) @@ -496,6 +498,7 @@ def __init__(self, config): self.a = nn.Parameter(torch.tensor(config.a).float()) self.b = nn.Parameter(torch.tensor(config.b).float()) self.random_torch = config.random_torch + self.post_init() def forward(self, input_x, labels=None, **kwargs): y = input_x * self.a + self.b diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 2ee3ccbbd9c2..11ad42ca45a9 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -134,6 +134,7 @@ def __init__(self, config): super().__init__(config) self.linear = nn.Linear(5, 5) self.linear_2 = nn.Linear(5, 5) + self.post_init() def forward(self, x): return self.linear_2(self.linear(x)) @@ -147,6 +148,7 @@ def __init__(self, config): super().__init__(config) self.linear = nn.Linear(50, 50) self.linear_2 = nn.Linear(50, 50) + self.post_init() def forward(self, x): return self.linear_2(self.linear(x)) @@ -160,17 +162,20 @@ def __init__(self, config): super().__init__(config) self.linear = nn.Linear(50, 50) self.linear_2 = nn.Linear(50, 50) + self.post_init() def forward(self, x): return self.linear_2(self.linear(x)) class BaseModelWithTiedWeights(PreTrainedModel): config_class = PreTrainedConfig + _tied_weights_keys = {"linear_2.weight": "linear.weight"} def __init__(self, config): super().__init__(config) self.linear = nn.Linear(5, 5) self.linear_2 = nn.Linear(5, 5) + self.post_init() def forward(self, x): return self.linear_2(self.linear(x)) @@ -193,6 +198,7 @@ def __init__(self, config): # linear is a common name between Base and Head on purpose. 
self.linear = nn.Linear(5, 5) self.linear2 = nn.Linear(5, 5) + self.post_init() def forward(self, x): return self.linear2(self.linear(self.base(x))) @@ -209,6 +215,7 @@ def __init__(self, config): # direct params and submodules is helpful for testing offloading logic self.weight = nn.Parameter(torch.rand((5, 5))) self.base = BaseModel(config) + self.post_init() def forward(self, x): return self.base(x @ self.weight.T) @@ -225,6 +232,7 @@ def __init__(self, config): self.submodule = ModelWithDirectParam(config) # needed so model can have at least one module on accelerator self.linear = nn.Linear(5, 5) + self.post_init() def forward(self, x): return self.linear(self.submodule(x)) @@ -232,6 +240,7 @@ def forward(self, x): class ModelWithHeadAndTiedWeights(PreTrainedModel): base_model_prefix = "base" config_class = PreTrainedConfig + _tied_weights_keys = {"decoder.weight": "base.linear.weight"} def _init_weights(self, module): pass @@ -240,6 +249,7 @@ def __init__(self, config): super().__init__(config) self.base = BaseModel(config) self.decoder = nn.Linear(5, 5) + self.post_init() def forward(self, x): return self.decoder(self.base(x)) @@ -1412,6 +1422,7 @@ def test_tied_weights_reload(self): self.assertIs(new_model.linear.weight, new_model.linear_2.weight) # With head + model = BaseModel(PreTrainedConfig()) model.save_pretrained(tmp_dir) new_model, load_info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True) self.assertIs(new_model.base.linear.weight, new_model.decoder.weight) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index d4458f9e1c0e..e569a0fc7b5c 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -401,6 +401,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s # common and important attributes, even if they do not always appear in the modeling files attributes_to_allow = [ "initializer_range", + "init_std", + "initializer_factor", "bos_index", "eos_index", "pad_index", diff --git a/utils/check_init_weights_data.py b/utils/check_init_weights_data.py index 93aebd9f5b2d..cf549102c5ac 100644 --- a/utils/check_init_weights_data.py +++ b/utils/check_init_weights_data.py @@ -33,9 +33,21 @@ def iter_modeling_files(): yield from MODELING_ROOT.rglob(pattern) -def function_has_forbidden_data_usage(fn: ast.FunctionDef) -> int | None: +def full_name(node): """ - Returns the first offending line number if `.data` is used, otherwise `None`. + Return full dotted name from an Attribute or Name node. + """ + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return full_name(node.value) + "." + node.attr + else: + raise ValueError("Not a Name or Attribute node") + + +def function_has_forbidden_usage(fn: ast.FunctionDef) -> int | None: + """ + Returns the first offending line number if we detect an in-place operation on a module's weight, otherwise `None`. """ args = fn.args.args @@ -43,8 +55,14 @@ def function_has_forbidden_data_usage(fn: ast.FunctionDef) -> int | None: return None for node in ast.walk(fn): - if isinstance(node, ast.Attribute) and node.attr == "data": - return node.lineno + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + is_inplace_ops = node.func.attr.endswith("_") + # We allow in-place ops on tensors that are not part of the module itself (see e.g. modeling_qwen3_next.py L997) + is_on_module_weight = isinstance(node.func.value, (ast.Name, ast.Attribute)) and "module." 
in full_name( + node.func.value + ) + if is_inplace_ops and is_on_module_weight: + return node.lineno return None @@ -62,16 +80,17 @@ def main() -> int: for node in ast.walk(tree): if isinstance(node, ast.FunctionDef) and node.name == "_init_weights": - offending_line = function_has_forbidden_data_usage(node) + offending_line = function_has_forbidden_usage(node) if offending_line is not None: violations.append( - f"{file_path}:{offending_line}: `_init_weights(self, module)` uses `.data`. " - "Use tensor ops directly to remain compatible with HFParameter." + f"{file_path}:{offending_line}: `_init_weights(self, module)` uses an in-place operation on a " + "module's weight. Please use the `init` functions primitives instead, usually imported as " + "`from ... import initialization as init`." ) break if violations: - print("Found forbidden `.data` usage inside `_init_weights(self, module)`:\n", file=sys.stderr) + print("Found forbidden usage inside `_init_weights(self, module)`:\n", file=sys.stderr) print("\n".join(violations), file=sys.stderr) return 1 diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py index 0e0df9b66ef3..c67499fb4704 100644 --- a/utils/create_dependency_mapping.py +++ b/utils/create_dependency_mapping.py @@ -55,8 +55,11 @@ def topological_sort(dependencies: dict) -> list[list[str]]: ) -def is_model_import(module: str) -> bool: +def is_model_import(module: str | None) -> bool: """Check whether `module` is a model import or not.""" + # Happens for fully relative import, i.e. `from ... import initialization as init` + if module is None: + return False patterns = "|".join(ALL_FILE_TYPES) regex = rf"(\w+)\.(?:{patterns})_(\w+)" match_object = re.search(regex, module) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 3216495133b0..2e8d91467ac8 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -558,7 +558,8 @@ def visit_ImportFrom(self, node): """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs to be added (because it will be part of the imports)""" - import_module = self.python_module.code_for_node(node.module) + # `node.module` is None for fully relative imports, e.g. `from ... import initialization as init` + import_module = self.python_module.code_for_node(node.module) if node.module is not None else "" import_statement = "." * len(node.relative) + import_module if re.search(rf"^\.({self.match_patterns}).*", import_statement): for imported_object in node.names: @@ -1227,7 +1228,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ - import_module = self.python_module.code_for_node(node.module) + # `node.module` is None for fully relative imports, e.g. `from ... import initialization as init` + import_module = self.python_module.code_for_node(node.module) if node.module is not None else "" import_statement = "." 
* len(node.relative) + import_module if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): return @@ -1283,7 +1285,10 @@ def visit_SimpleStatementLine(self, node): if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): self.imports.append(node) elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): - import_module = self.python_module.code_for_node(node.body[0].module) + # `node.body[0].module` is None for fully relative imports, e.g. `from ... import initialization as init` + import_module = ( + self.python_module.code_for_node(node.body[0].module) if node.body[0].module is not None else "" + ) import_statement = "." * len(node.body[0].relative) + import_module if any( external_file["name"] in import_statement for external_file in self.excluded_external_files diff --git a/utils/test_module/custom_modeling.py b/utils/test_module/custom_modeling.py index fafa7bff253d..94089890d7d7 100644 --- a/utils/test_module/custom_modeling.py +++ b/utils/test_module/custom_modeling.py @@ -11,6 +11,7 @@ class CustomModel(PreTrainedModel): def __init__(self, config): super().__init__(config) self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size) + self.post_init() def forward(self, x): return self.linear(x)