diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index 9bc978f2..935d6e16 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -64,6 +64,7 @@ timesteps: 1000
 max_beta: 0.02
 enc_ffn_kernel_size: 3
 use_rope: true
+rope_interleaved: false
 use_stretch_embed: true
 use_variance_scaling: true
 rel_pos: true
diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml
index acbe25df..9d63028f 100644
--- a/configs/templates/config_acoustic.yaml
+++ b/configs/templates/config_acoustic.yaml
@@ -71,6 +71,7 @@ augmentation_args:
 diffusion_type: reflow
 enc_ffn_kernel_size: 3
 use_rope: true
+rope_interleaved: false
 use_stretch_embed: true
 use_variance_scaling: true
 use_shallow_diffusion: true
diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml
index 58a4d3a6..40f4c532 100644
--- a/configs/templates/config_variance.yaml
+++ b/configs/templates/config_variance.yaml
@@ -65,6 +65,7 @@ tension_logit_max: 10.0
 
 enc_ffn_kernel_size: 3
 use_rope: true
+rope_interleaved: false
 use_stretch_embed: false
 use_variance_scaling: true
 hidden_size: 384
diff --git a/configs/variance.yaml b/configs/variance.yaml
index 6bd86cfa..a819c1c4 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -36,6 +36,7 @@ predict_tension: false
 
 enc_ffn_kernel_size: 3
 use_rope: true
+rope_interleaved: false
 use_stretch_embed: false
 use_variance_scaling: true
 rel_pos: true
diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py
index 849dae5d..e7437365 100644
--- a/deployment/exporters/acoustic_exporter.py
+++ b/deployment/exporters/acoustic_exporter.py
@@ -1,6 +1,7 @@
 import json
 from pathlib import Path
 from typing import List, Union, Tuple, Dict
+import warnings
 
 import onnx
 import onnxsim
@@ -78,6 +79,7 @@ def __init__(
             self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
         if self.freeze_spk is not None:
             self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
+        self.rope_interleaved = hparams.get('rope_interleaved', None)
 
     def build_model(self) -> DiffSingerAcousticONNX:
         model = DiffSingerAcousticONNX(
@@ -88,8 +90,21 @@ def build_model(self) -> DiffSingerAcousticONNX:
                 for p in self.phoneme_dictionary.cross_lingual_phonemes
             })
         ).eval().to(self.device)
+        if self.rope_interleaved is None:
+            warnings.warn(
+                "After the RoPE refactor (https://github.com/openvpi/DiffSinger/pull/276), "
+                "the checkpoint no longer contains the relevant RoPE parameters. "
+                "In order to export ONNX with behavior compatible with past checkpoints, "
+                "the model will be loaded with 'strict=False', which no longer checks "
+                "the validity of the checkpoint. Please make sure you understand what you are doing.",
" + "Please understand what you are doing.", + UserWarning, + stacklevel=2 + ) + strict=False + else: + strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, - prefix_in_ckpt='model', strict=True, device=self.device) + prefix_in_ckpt='model', strict=strict, device=self.device) return model def export(self, path: Path): diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 82808ec0..69af991c 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -1,6 +1,7 @@ import json from pathlib import Path from typing import Union, List, Tuple, Dict +import warnings import onnx import onnxsim @@ -81,6 +82,7 @@ def __init__( self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()] if self.freeze_spk is not None: self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1])) + self.rope_interleaved = hparams.get('rope_interleaved', None) def build_model(self) -> DiffSingerVarianceONNX: model = DiffSingerVarianceONNX( @@ -90,6 +92,19 @@ def build_model(self) -> DiffSingerVarianceONNX: for p in self.phoneme_dictionary.cross_lingual_phonemes }) ).eval().to(self.device) + if self.rope_interleaved is None: + warnings.warn( + "After RoPE is refactored, the checkpoint no longer contains relevant parameters. " + "(https://github.com/openvpi/DiffSinger/pull/276)" + "In order to export ONNX with behavior compatible with past checkpoints, " + "it will be set to 'strict=False', which will no longer check the validity of the checkpoint. " + "Please understand what you are doing.", + UserWarning, + stacklevel=2 + ) + strict=False + else: + strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, prefix_in_ckpt='model', strict=True, device=self.device) model.build_smooth_op(self.device) diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py index e0ab05f2..9af4a277 100644 --- a/modules/commons/rotary_embedding_torch.py +++ b/modules/commons/rotary_embedding_torch.py @@ -1,323 +1,77 @@ from __future__ import annotations -from math import pi, log - import torch -from torch.amp import autocast -from torch.nn import Module, ModuleList -from torch import nn, einsum, broadcast_tensors, Tensor - +from torch import nn, einsum, Tensor +from torch.nn import Module from einops import rearrange, repeat -from typing import Literal - -# helper functions - -def exists(val): - return val is not None - -def default(val, d): - return val if exists(val) else d - -# broadcat, as tortoise-tts was using it - -def broadcat(tensors, dim = -1): - broadcasted_tensors = broadcast_tensors(*tensors) - return torch.cat(broadcasted_tensors, dim = dim) - -def slice_at_dim(t, dim_slice: slice, *, dim): - dim += (t.ndim if dim < 0 else 0) - colons = [slice(None)] * t.ndim - colons[dim] = dim_slice - return t[tuple(colons)] - -# rotary embedding helper functions - -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... 
-
-@autocast('cuda', enabled = False)
-def apply_rotary_emb(
-    freqs,
-    t,
-    start_index = 0,
-    scale = 1.,
-    seq_dim = -2,
-    freqs_seq_dim = None
-):
-    dtype = t.dtype
-
-    if not exists(freqs_seq_dim):
-        if freqs.ndim == 2 or t.ndim == 3:
-            freqs_seq_dim = 0
-
+def rotate_half(x: Tensor, interleaved=True) -> Tensor:
+    if not interleaved:
+        # Equivalent to: x1, x2 = x.chunk(2, dim=-1)
+        # torch.split is used instead of chunk for ONNX export compatibility.
+        x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x = rearrange(x, '... (d r) -> ... d r', r=2)
+        x1, x2 = x.unbind(dim=-1)
+        x = torch.stack((-x2, x1), dim=-1)
+        return rearrange(x, '... d r -> ... (d r)')
 
-    if t.ndim == 3 or exists(freqs_seq_dim):
-        seq_len = t.shape[seq_dim]
-        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)
-
+def apply_rotary_emb(freqs: Tensor, t: Tensor, interleaved=True) -> Tensor:
     rot_dim = freqs.shape[-1]
-    end_index = start_index + rot_dim
+    t_to_rotate = t[..., :rot_dim]
+    t_pass_through = t[..., rot_dim:]
+
+    t_rotated = (t_to_rotate * freqs.cos()) + (rotate_half(t_to_rotate, interleaved) * freqs.sin())
+
+    return torch.cat((t_rotated, t_pass_through), dim=-1)
-
-    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
-
-    # Split t into three parts: left, middle (to be transformed), and right
-    t_left = t[..., :start_index]
-    t_middle = t[..., start_index:end_index]
-    t_right = t[..., end_index:]
-
-    # Apply rotary embeddings without modifying t in place
-    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
-
-    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
-
-    return out.type(dtype)
-
-# learned rotation helpers
-
-def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
-    if exists(freq_ranges):
-        rotations = einsum('..., f -> ... f', rotations, freq_ranges)
-        rotations = rearrange(rotations, '... r f -> ... (r f)')
-
-    rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
-    return apply_rotary_emb(rotations, t, start_index = start_index)
-
-# classes
 
 class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
-        custom_freqs: Tensor | None = None,
-        freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
-        theta = 10000,
-        max_freq = 10,
-        num_freqs = 1,
-        learned_freq = False,
-        use_xpos = False,
-        xpos_scale_base = 512,
-        interpolate_factor = 1.,
-        theta_rescale_factor = 1.,
-        seq_before_head_dim = False,
-        cache_if_possible = True,
-        cache_max_seq_len = 8192
+        theta=10000,
+        precompute_len=8192,
+        cache_max_seq_len=8192,
+        interleaved: bool = True
     ):
         super().__init__()
-        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
-        # has some connection to NTK literature
-        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-
-        theta *= theta_rescale_factor ** (dim / (dim - 2))
-
-        self.freqs_for = freqs_for
+        self.interleaved = interleaved
 
-        if exists(custom_freqs):
-            freqs = custom_freqs
-        elif freqs_for == 'lang':
-            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
-        elif freqs_for == 'pixel':
-            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
-        elif freqs_for == 'constant':
-            freqs = torch.ones(num_freqs).float()
+        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
 
-        self.cache_if_possible = cache_if_possible
-        self.cache_max_seq_len = cache_max_seq_len
+        self._cache_max_seq_len = max(precompute_len, cache_max_seq_len)
+        self._precomputed_len = precompute_len
 
-        self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
+        self.register_buffer('cached_freqs', None, persistent=True)
         self.cached_freqs_seq_len = 0
+
+        if self._precomputed_len > 0:
+            self._precompute_cache(self._precomputed_len)
 
-        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
-
-        self.learned_freq = learned_freq
-
-        # dummy for device
-
-        self.register_buffer('dummy', torch.tensor(0), persistent = False)
-
-        # default sequence dimension
-
-        self.seq_before_head_dim = seq_before_head_dim
-        self.default_seq_dim = -3 if seq_before_head_dim else -2
-
-        # interpolation factors
-
-        assert interpolate_factor >= 1.
-        self.interpolate_factor = interpolate_factor
-
-        # xpos
-
-        self.use_xpos = use_xpos
-
-        if not use_xpos:
-            return
-
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
-        self.scale_base = xpos_scale_base
-
-        self.register_buffer('scale', scale, persistent = False)
-        self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False)
-        self.cached_scales_seq_len = 0
-
-        # add apply_rotary_emb as static method
-
-        self.apply_rotary_emb = staticmethod(apply_rotary_emb)
-
-    @property
-    def device(self):
-        return self.dummy.device
-
-    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
-        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
-
-    def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None):
-        seq_dim = default(seq_dim, self.default_seq_dim)
-
-        assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
-
-        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
-
-        seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
-
-        freqs = self.forward(seq, seq_len = seq_len, offset = offset)
-
-        if seq_dim == -3:
-            freqs = rearrange(freqs, 'n d -> n 1 d')
-
-        return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim)
-
-    def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
-        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
-
-        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
-        assert q_len <= k_len
-
-        q_scale = k_scale = 1.
-
-        if self.use_xpos:
-            seq = self.get_seq_pos(k_len, dtype = dtype, device = device)
-
-            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
-            k_scale = self.get_scale(seq).type(dtype)
-
-        rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset)
-        rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1)
-
-        rotated_q = rotated_q.type(q.dtype)
-        rotated_k = rotated_k.type(k.dtype)
-
-        return rotated_q, rotated_k
-
-    def rotate_queries_and_keys(self, q, k, seq_dim = None):
-        seq_dim = default(seq_dim, self.default_seq_dim)
-
-        assert self.use_xpos
-        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
-
-        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
-
-        freqs = self.forward(seq, seq_len = seq_len)
-        scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
-
-        if seq_dim == -3:
-            freqs = rearrange(freqs, 'n d -> n 1 d')
-            scale = rearrange(scale, 'n d -> n 1 d')
-
-        rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
-        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
-
-        rotated_q = rotated_q.type(q.dtype)
-        rotated_k = rotated_k.type(k.dtype)
-
-        return rotated_q, rotated_k
-
-    def get_scale(
-        self,
-        t: Tensor,
-        seq_len: int | None = None,
-        offset = 0
-    ):
-        assert self.use_xpos
-
-        should_cache = (
-            self.cache_if_possible and
-            exists(seq_len) and
-            (offset + seq_len) <= self.cache_max_seq_len
-        )
-
-        if (
-            should_cache and \
-            exists(self.cached_scales) and \
-            (seq_len + offset) <= self.cached_scales_seq_len
-        ):
-            return self.cached_scales[offset:(offset + seq_len)]
-
-        scale = 1.
-        if self.use_xpos:
-            power = (t - len(t) // 2) / self.scale_base
-            scale = self.scale ** rearrange(power, 'n -> n 1')
-            scale = repeat(scale, 'n d -> n (d r)', r = 2)
-
-        if should_cache and offset == 0:
-            self.cached_scales[:seq_len] = scale.detach()
-            self.cached_scales_seq_len = seq_len
-
-        return scale
-
-    def get_axial_freqs(self, *dims):
-        Colon = slice(None)
-        all_freqs = []
-
-        for ind, dim in enumerate(dims):
-            if self.freqs_for == 'pixel':
-                pos = torch.linspace(-1, 1, steps = dim, device = self.device)
-            else:
-                pos = torch.arange(dim, device = self.device)
-
-            freqs = self.forward(pos, seq_len = dim)
-
-            all_axis = [None] * len(dims)
-            all_axis[ind] = Colon
-
-            new_axis_slice = (Ellipsis, *all_axis, Colon)
-            all_freqs.append(freqs[new_axis_slice])
-
-        all_freqs = broadcast_tensors(*all_freqs)
-        return torch.cat(all_freqs, dim = -1)
-
-    @autocast('cuda', enabled = False)
-    def forward(
-        self,
-        t: Tensor,
-        seq_len: int | None = None,
-        offset = 0
-    ):
-        should_cache = (
-            self.cache_if_possible and
-            not self.learned_freq and
-            exists(seq_len) and
-            self.freqs_for != 'pixel' and
-            (offset + seq_len) <= self.cache_max_seq_len
-        )
-
-        if (
-            should_cache and \
-            exists(self.cached_freqs) and \
-            (offset + seq_len) <= self.cached_freqs_seq_len
-        ):
-            freqs = self.cached_freqs[offset:(offset + seq_len)].detach()
-            # Fix issue about 'find_unused_parameters' when DDP training.(#244)
-            freqs = freqs + 0. * self.freqs.sum()
-            return freqs
-
-        freqs = self.freqs
+    def _precompute_cache(self, seq_len: int):
+        seq = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = einsum('i, j -> i j', seq, self.inv_freq)
+
+        if self.interleaved:
+            freqs = repeat(freqs, '... n -> ... (n r)', r=2)
+        else:
+            freqs = torch.cat((freqs, freqs), dim=-1)
 
-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        self.cached_freqs = freqs
+        self.cached_freqs_seq_len = seq_len
 
-        if should_cache and offset == 0:
-            self.cached_freqs[:seq_len] = freqs.detach()
-            self.cached_freqs_seq_len = seq_len
+    def forward(self, t: Tensor, seq_len: int) -> Tensor:
+        if self.cached_freqs is None or seq_len > self.cached_freqs_seq_len:
+            self._precompute_cache(seq_len)
+
+        return self.cached_freqs[0: seq_len].detach()
 
-        return freqs
+    def rotate_queries_or_keys(self, t: Tensor) -> Tensor:
+        device, dtype, seq_len = t.device, t.dtype, t.shape[-2]
+        freqs = self.forward(t, seq_len=seq_len)
+
+        return apply_rotary_emb(freqs.to(device=device, dtype=dtype), t, self.interleaved)
diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py
index 86aa535a..868d383f 100644
--- a/modules/fastspeech/acoustic_encoder.py
+++ b/modules/fastspeech/acoustic_encoder.py
@@ -38,7 +38,7 @@ def __init__(self, vocab_size):
             ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
             dropout=hparams['dropout'], num_heads=hparams['num_heads'],
             use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
-            use_rope=hparams.get('use_rope', False)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
         )
 
         self.pitch_embed = Linear(1, hparams['hidden_size'])
diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py
index 882ebc11..cc840aed 100644
--- a/modules/fastspeech/tts_modules.py
+++ b/modules/fastspeech/tts_modules.py
@@ -369,14 +369,14 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
 
 class FastSpeech2Encoder(nn.Module):
     def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, ffn_act='gelu',
-                 dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False):
+                 dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True):
         super().__init__()
         self.num_layers = num_layers
         embed_dim = self.hidden_size = hidden_size
         self.dropout = dropout
         self.use_pos_embed = use_pos_embed
         if use_pos_embed and use_rope:
-            rotary_embed = RotaryEmbedding(dim = embed_dim // num_heads)
+            rotary_embed = RotaryEmbedding(dim=embed_dim // num_heads, interleaved=rope_interleaved)
         else:
             rotary_embed = None
         self.layers = nn.ModuleList([
diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py
index 70edcebc..ba6994c1 100644
--- a/modules/fastspeech/variance_encoder.py
+++ b/modules/fastspeech/variance_encoder.py
@@ -33,7 +33,7 @@ def __init__(self, vocab_size):
             ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
             dropout=hparams['dropout'], num_heads=hparams['num_heads'],
             use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
-            use_rope=hparams.get('use_rope', False)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
         )
 
         dur_hparams = hparams['dur_prediction_args']
@@ -127,7 +127,7 @@ def get_hparam(key):
             ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
             dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
             use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
-            use_rope=get_hparam('use_rope')
+            use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True)
         )
 
         self.out_proj = Linear(hidden_size, hparams['hidden_size'])
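Note: the snippet below is not part of the patch; it is a minimal usage sketch of the refactored RotaryEmbedding introduced above, assuming query/key tensors shaped (batch, num_heads, seq_len, head_dim). The tensor sizes and the hidden_size/num_heads values are illustrative only.

    import torch
    from modules.commons.rotary_embedding_torch import RotaryEmbedding

    # dim is the per-head size, e.g. hidden_size // num_heads (256 // 4 = 64 here, illustrative).
    # interleaved=False corresponds to the new config key rope_interleaved: false.
    rope = RotaryEmbedding(dim=64, interleaved=False)

    q = torch.randn(2, 4, 100, 64)  # (batch, num_heads, seq_len, head_dim)
    k = torch.randn(2, 4, 100, 64)
    q = rope.rotate_queries_or_keys(q)  # rotates along the -2 (sequence) axis
    k = rope.rotate_queries_or_keys(k)

Checkpoints trained before this change should keep the default interleaved=True (i.e., leave rope_interleaved unset or set it to true), which matches the pairwise rotation layout of the code being removed.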