diff --git a/requirements.txt b/requirements.txt index 15c7b1e..44dd0bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ omegaconf opencv-python pydantic compel +kornia \ No newline at end of file diff --git a/train.py b/train.py index 592a761..6a3cf27 100644 --- a/train.py +++ b/train.py @@ -38,7 +38,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPEncoder from utils.dataset import VideoJsonDataset, SingleVideoDataset, \ - ImageDataset, VideoFolderDataset, CachedDataset + ImageDataset, VideoFolderDataset, CachedDataset, ConcatInterleavedDataset from einops import rearrange, repeat from utils.lora_handler import LoraHandler, LORA_VERSIONS @@ -275,7 +275,7 @@ def handle_cache_latents( # Cache latents by storing them in VRAM. # Speeds up training and saves memory by not encoding during the train loop. if not should_cache: return None - vae.to('cuda', dtype=torch.float16) + vae.to('cuda', dtype=torch.float32) vae.enable_slicing() cached_latent_dir = ( @@ -287,15 +287,17 @@ def handle_cache_latents( os.makedirs(cache_save_dir, exist_ok=True) for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")): - save_name = f"cached_{i}" full_out_path = f"{cache_save_dir}/{save_name}.pt" - pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16) - batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae) - for k, v in batch.items(): batch[k] = v[0] + pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float32) + batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae) + for k, v in batch.items(): + batch[k] = v[0] + torch.save(batch, full_out_path) + del pixel_values del batch @@ -308,8 +310,8 @@ def handle_cache_latents( return torch.utils.data.DataLoader( CachedDataset(cache_dir=cache_save_dir), batch_size=train_batch_size, - shuffle=True, - num_workers=0 + shuffle=False, + num_workers=0, ) def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None): @@ -455,6 +457,7 @@ def main( train_data: Dict, validation_data: Dict, extra_train_data: list = [], + interleave_datasets: bool = False, dataset_types: Tuple[str] = ('json'), validation_steps: int = 100, trainable_modules: Tuple[str] = ("attn1", "attn2"), @@ -601,40 +604,57 @@ def main( num_training_steps=max_train_steps * gradient_accumulation_steps, ) - # Get the training dataset based on types (json, single_video, image) - train_datasets = get_train_dataset(dataset_types, train_data, tokenizer) + train_dataloader = None + + if cached_latent_dir is None: + # Get the training dataset based on types (json, single_video, image) + if extra_train_data is None: + train_datasets = get_train_dataset(dataset_types, train_data, tokenizer) + else: + train_datasets = [] - # If you have extra train data, you can add a list of however many you would like. - # Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}] - try: - if extra_train_data is not None and len(extra_train_data) > 0: - for dataset in extra_train_data: - d_t, t_d = dataset['dataset_types'], dataset['train_data'] - train_datasets += get_train_dataset(d_t, t_d, tokenizer) + # If you have extra train data, you can add a list of however many you would like. 
+ # Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}] + try: + if extra_train_data is not None: + for dataset in extra_train_data: + dataset_enabled = dataset.get('enabled', True) + if not dataset_enabled: + continue - except Exception as e: - print(f"Could not process extra train datasets due to an error : {e}") + d_t, t_d = dataset['dataset_types'], dataset['train_data'] + train_datasets += get_train_dataset(d_t, t_d, tokenizer) - # Extend datasets that are less than the greatest one. This allows for more balanced training. - attrs = ['train_data', 'frames', 'image_dir', 'video_files'] - extend_datasets(train_datasets, attrs, extend=extend_dataset) + # Allows for joint video / text encoder training. + if t_d['n_sample_frames'] > 1 and (train_text_encoder or use_text_lora): + t_d_single = t_d.copy() + t_d_single['frame_step'] = 30 + t_d_single['n_sample_frames'] = 1 - # Process one dataset - if len(train_datasets) == 1: - train_dataset = train_datasets[0] - - # Process many datasets - else: - train_dataset = torch.utils.data.ConcatDataset(train_datasets) + train_datasets += get_train_dataset(d_t, t_d_single, tokenizer) - # DataLoaders creation: - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=train_batch_size, - shuffle=True - ) + except Exception as e: + print(f"Could not process extra train datasets due to an error : {e}") + + # Process one dataset + if len(train_datasets) == 1: + train_dataset = train_datasets[0] + + # Process many datasets + else: + if interleave_datasets and extend_dataset: + train_dataset = ConcatInterleavedDataset(train_datasets) + else: + train_dataset = torch.utils.data.ConcatDataset(train_datasets) + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=(not interleave_datasets and not extend_dataset) + ) - # Latents caching + # Latents caching cached_data_loader = handle_cache_latents( cache_latents, output_dir, diff --git a/utils/dataset.py b/utils/dataset.py index 219b62b..da0fd86 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -12,12 +12,15 @@ from itertools import islice from pathlib import Path from .bucketing import sensible_buckets +from .dataset_processors import ConditionProcessors decord.bridge.set_bridge('torch') from torch.utils.data import Dataset from einops import rearrange, repeat +TRAIN_DATA_VARS = ['train_data', 'frames', 'image_dir', 'video_files'] + def get_prompt_ids(prompt, tokenizer): prompt_ids = tokenizer( prompt, @@ -85,8 +88,96 @@ def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_ba return video, vr + +class DatasetProcessor(object): + + def __init__(self, cond_processor=None): + self.condition_processor = self.get_condition_processor(cond_processor) + + def get_condition_processor(self, cond_processor=None) : + + # The return condition is a function, so create a function + # that doesn't return anything when it's called. 
+ def no_cond(*args, **kwargs): + return torch.empty(1) + + AVAILABLE_PROCESSORS = ['canny', 'threshold'] + cond_processor = [ + p for p in AVAILABLE_PROCESSORS if p == cond_processor + ] + cond_processor = ( + cond_processor[0] if len(cond_processor) > 0 else "" + ) + + return ConditionProcessors.get(cond_processor, no_cond) + + def get_frame_range(self, vr): + return get_video_frames( + vr, + self.sample_start_idx, + self.frame_step, + self.n_sample_frames + ) + + def get_frame_buckets(self, vr): + _, h, w = vr[0].shape + width, height = sensible_buckets(self.width, self.height, h, w) + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def get_frame_batch(self, vr, resize=None): + frame_range = self.get_frame_range(vr) + frames = vr.get_batch(frame_range) + video = rearrange(frames, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + + return video, vr + + # Inspired by VideoMAE + def normalize_input( + self, + item, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ): + if item.dtype == torch.uint8: + item = rearrange(item, 'f c h w -> f h w c') + item = item.float() / 255.0 + mean = torch.tensor(mean) + std = torch.tensor(std) + + out = rearrange((item - mean) / std, 'f h w c -> f c h w') + + return out + else: + return item / (127.5 - 1.0) + + def _example(self, item, prompt_ids, prompt): + example = { + "pixel_values": self.normalize_input(item), + "condition": self.condition_processor(item), + "prompt_ids": prompt_ids[0], + "text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example + # https://github.com/ExponentialML/Video-BLIP2-Preprocessor -class VideoJsonDataset(Dataset): +class VideoJsonDataset(DatasetProcessor, Dataset): def __init__( self, tokenizer = None, @@ -100,8 +191,10 @@ def __init__( vid_data_key: str = "video_path", preprocessed: bool = False, use_bucketing: bool = False, + condition_processor = None, **kwargs ): + DatasetProcessor.__init__(self, condition_processor) self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") self.use_bucketing = use_bucketing self.tokenizer = tokenizer @@ -154,51 +247,6 @@ def load_from_json(self, path, json_data): def validate_json(self, base_path, path): return os.path.exists(f"{base_path}/{path}") - def get_frame_range(self, vr): - return get_video_frames( - vr, - self.sample_start_idx, - self.frame_step, - self.n_sample_frames - ) - - def get_vid_idx(self, vr, vid_data=None): - frames = self.n_sample_frames - - if vid_data is not None: - idx = vid_data['frame_index'] - else: - idx = self.sample_start_idx - - return idx - - def get_frame_buckets(self, vr): - _, h, w = vr[0].shape - width, height = sensible_buckets(self.width, self.height, h, w) - resize = T.transforms.Resize((height, width), antialias=True) - - return resize - - def get_frame_batch(self, vr, resize=None): - frame_range = self.get_frame_range(vr) - frames = vr.get_batch(frame_range) - video = rearrange(frames, "f h w c -> f c h w") - - if resize is not None: video = resize(video) - return video - - def process_video_wrapper(self, vid_path): - video, vr = process_video( - vid_path, - self.use_bucketing, - self.width, - self.height, - self.get_frame_buckets, - self.get_frame_batch - ) - - return video, vr - def train_data_batch(self, index): # If we are training on 
individual clips. @@ -257,17 +305,10 @@ def __getitem__(self, index): if self.train_data is not None: video, prompt, prompt_ids = self.train_data_batch(index) - example = { - "pixel_values": (video / 127.5 - 1.0), - "prompt_ids": prompt_ids[0], - "text_prompt": prompt, - 'dataset': self.__getname__() - } + return self._example(video, prompt_ids, prompt) - return example - -class SingleVideoDataset(Dataset): +class SingleVideoDataset(DatasetProcessor, Dataset): def __init__( self, tokenizer = None, @@ -279,22 +320,25 @@ def __init__( single_video_prompt: str = "", use_caption: bool = False, use_bucketing: bool = False, + condition_processor = None, **kwargs ): + DatasetProcessor.__init__(self, condition_processor) self.tokenizer = tokenizer self.use_bucketing = use_bucketing self.frames = [] self.index = 1 - self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") self.n_sample_frames = n_sample_frames self.frame_step = frame_step self.single_video_path = single_video_path self.single_video_prompt = single_video_prompt + self.create_video_chunks() self.width = width self.height = height + def create_video_chunks(self): # Create a list of frames separated by sample frames # [(1,2,3), (4,5,6), ...] @@ -304,12 +348,9 @@ def create_video_chunks(self): self.frames = list(self.chunk(vr_range, self.n_sample_frames)) # Delete any list that contains an out of range index. - for i, inner_frame_nums in enumerate(self.frames): - for frame_num in inner_frame_nums: - if frame_num > len(vr): - print(f"Removing out of range index list at position: {i}...") - del self.frames[i] - + self.frames = list( + filter(lambda x: len(x) == self.n_sample_frames, self.frames) + ) return self.frames def chunk(self, it, size): @@ -324,25 +365,6 @@ def get_frame_batch(self, vr, resize=None): if resize is not None: video = resize(video) return video - def get_frame_buckets(self, vr): - _, h, w = vr[0].shape - width, height = sensible_buckets(self.width, self.height, h, w) - resize = T.transforms.Resize((height, width), antialias=True) - - return resize - - def process_video_wrapper(self, vid_path): - video, vr = process_video( - vid_path, - self.use_bucketing, - self.width, - self.height, - self.get_frame_buckets, - self.get_frame_batch - ) - - return video, vr - def single_video_batch(self, index): train_data = self.single_video_path self.index = index @@ -358,26 +380,19 @@ def single_video_batch(self, index): raise ValueError(f"Single video is not a video type. 
Types: {self.vid_types}") @staticmethod - def __getname__(): return 'single_video' + def __getname__(): + return 'single_video' def __len__(self): - - return len(self.create_video_chunks()) + return len(self.frames) def __getitem__(self, index): video, prompt, prompt_ids = self.single_video_batch(index) - example = { - "pixel_values": (video / 127.5 - 1.0), - "prompt_ids": prompt_ids[0], - "text_prompt": prompt, - 'dataset': self.__getname__() - } - - return example + return self._example(video, prompt_ids, prompt) -class ImageDataset(Dataset): +class ImageDataset(DatasetProcessor, Dataset): def __init__( self, @@ -391,8 +406,10 @@ def __init__( single_img_prompt: str = '', use_bucketing: bool = False, fallback_prompt: str = '', + condition_processor = None, **kwargs ): + DatasetProcessor.__init__(self, condition_processor) self.tokenizer = tokenizer self.img_types = (".png", ".jpg", ".jpeg", '.bmp') self.use_bucketing = use_bucketing @@ -462,16 +479,10 @@ def __len__(self): def __getitem__(self, index): img, prompt, prompt_ids = self.image_batch(index) - example = { - "pixel_values": (img / 127.5 - 1.0), - "prompt_ids": prompt_ids[0], - "text_prompt": prompt, - 'dataset': self.__getname__() - } - return example + return self._example(img, prompt_ids, prompt) -class VideoFolderDataset(Dataset): +class VideoFolderDataset(DatasetProcessor, Dataset): def __init__( self, tokenizer=None, @@ -482,8 +493,10 @@ def __init__( path: str = "./data", fallback_prompt: str = "", use_bucketing: bool = False, + condition_processor = None, **kwargs ): + DatasetProcessor.__init__(self, condition_processor) self.tokenizer = tokenizer self.use_bucketing = use_bucketing @@ -524,26 +537,6 @@ def get_frame_batch(self, vr, resize=None): if resize is not None: video = resize(video) return video, vr - def process_video_wrapper(self, vid_path): - video, vr = process_video( - vid_path, - self.use_bucketing, - self.width, - self.height, - self.get_frame_buckets, - self.get_frame_batch - ) - return video, vr - - def get_prompt_ids(self, prompt): - return self.tokenizer( - prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids - @staticmethod def __getname__(): return 'folder' @@ -560,12 +553,13 @@ def __getitem__(self, index): else: prompt = self.fallback_prompt - prompt_ids = self.get_prompt_ids(prompt) + prompt_ids = get_prompt_ids(prompt, tokenizer) - return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()} + return self._example(video[0], prompt_ids, prompt) -class CachedDataset(Dataset): - def __init__(self,cache_dir: str = ''): +class CachedDataset(DatasetProcessor, Dataset): + def __init__(self, cache_dir: str = ''): + DatasetProcessor.__init__(self) self.cache_dir = cache_dir self.cached_data_list = self.get_files_list() @@ -578,4 +572,115 @@ def __len__(self): def __getitem__(self, index): cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0') + return cached_latent + +class ConcatInterleavedDataset(Dataset): + def __init__(self, datasets): + self.datasets = datasets + self.train_data_vars = TRAIN_DATA_VARS + + self.interleave_datasets() + + def get_parent_dataset(self): + + # There's a chance that the subset images may be bigger than the video if doing text training. + # If it has the attribute "is_subset", we can simply ignore it to ensure it isn't the biggest + # length. 
+        dataset_lengths = [d.__len__() if not hasattr(d, 'is_subset') else 0 for d in self.datasets]
+        max_dataset_index = dataset_lengths.index(max(dataset_lengths))
+
+        parent_dataset = self.datasets[max_dataset_index]
+
+        return parent_dataset, max_dataset_index
+
+    def process_dataset(self, dataset):
+        processed_dataset = []
+        train_data_var_name = self.get_dataset_data_var_name(dataset)[0]
+        train_data_var = getattr(dataset, train_data_var_name)
+
+        for idx, item in enumerate(train_data_var):
+            if isinstance(item, dict) and 'idx_modulo' in item:
+                ref_idx = item['idx_modulo']
+                already_processed_item = processed_dataset[ref_idx]
+
+                # Dataset items are assumed to be of type Dict.
+                already_processed_item['reference_idx'] = ref_idx
+                processed_dataset.append(already_processed_item)
+            else:
+                processed_dataset.append(dataset[idx])
+
+        return processed_dataset
+
+    def get_dataset_data_var_name(self, dataset):
+        return [v for v in self.train_data_vars if v in dataset.__dict__.keys()]
+
+    def create_data_val_dict(self, val, idx, length, idx_modulo):
+        return dict(
+            value=val,
+            idx=idx,
+            length=length,
+            idx_modulo=idx_modulo
+        )
+
+    def interleave_datasets(self):
+        parent_dataset, parent_dataset_index = self.get_parent_dataset()
+        child_datasets = self.datasets.copy()
+        child_datasets.pop(parent_dataset_index)
+
+        parent_dataset_length = parent_dataset.__len__()
+
+        for dataset in child_datasets.copy():
+            if dataset.__len__() <= 0:
+                child_datasets.remove(dataset)
+                continue
+
+            var_name = self.get_dataset_data_var_name(dataset)
+            var_name = var_name[0] if len(var_name) == 1 else None
+
+            if var_name is None:
+                continue
+
+            original_dataset_length = dataset.__len__()
+
+            train_data_var = getattr(dataset, var_name)
+            train_data_var *= parent_dataset_length
+            new_train_data_val = train_data_var[:parent_dataset_length]
+
+            # Reference items that were already accessed.
+            # Since some __getitem__ functions are heavy (numpy computations, video reads, etc.),
+            # we want to avoid performing the same expensive work multiple times.
+            # We simply point to the corresponding index so that when we interleave, we can just copy the __getitem__ result.
+            for i, val in enumerate(new_train_data_val):
+                if i >= original_dataset_length:
+                    clamped_idx = i % original_dataset_length
+                    new_train_data_val[i] = self.create_data_val_dict(
+                        val,
+                        i,
+                        original_dataset_length,
+                        clamped_idx
+                    )
+
+            setattr(dataset, var_name, new_train_data_val)
+
+        from itertools import chain
+
+        print("Interleaving Datasets. Please wait...")
+        train_datasets = [parent_dataset] + child_datasets
+
+        # Zip the items of all datasets together. We do this to call __getitem__ on all of our data.
+        # Example (d == Dataset): [(d1_item1, d2_item1, d3_item1), (d1_item2, d2_item2, d3_item2), (...)]
+        interleave_datasets = zip(*[self.process_dataset(d) for d in train_datasets])
+
+        # Now we flatten it into a single list that serves as the concatenated dataset.
+        # Example: [d1_item1, d2_item1, d3_item1, d1_item2, d2_item2, d3_item2, ...]
+        interleaved_dataset = list(chain(*interleave_datasets))
+        self.datasets = interleaved_dataset
+
+        print("Finished interleaving datasets.")
+
+    def __len__(self):
+        return len(self.datasets)
+
+    def __getitem__(self, index):
+        return self.datasets[index]
\ No newline at end of file
diff --git a/utils/dataset_processors.py b/utils/dataset_processors.py
new file mode 100644
index 0000000..b436d65
--- /dev/null
+++ b/utils/dataset_processors.py
@@ -0,0 +1,66 @@
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from kornia.filters import Canny
+from einops import rearrange
+
+canny = Canny().to('cuda')
+
+def test_canny():
+    img_path = "/dir/1/1.png"
+    image = Image.open(img_path)
+
+    image = np.asarray(image)[:, :, :3]
+
+    image = torch.from_numpy(image).float() / 127. - 1
+
+    image = rearrange(image, '(b h) w c -> b c h w', b=1)
+    _, canny_edges = canny(image)
+
+    out = T.ToPILImage()(canny_edges.squeeze(0))
+
+    out.save('./test.png')
+    print(canny_edges)
+
+def item_norm(item):
+    return item / 127. - 1
+
+
+# Defaults to the VAE downsampling factor of 8. TODO: make this configurable.
+def resize_item(item, resize_factor=8):
+    f, c, h, w = item.shape
+    resize = T.Resize(size=(h // resize_factor, w // resize_factor))
+
+    return resize(item)
+
+def canny_processor(item, img_path=None, from_dataloader=True):
+
+    if img_path is not None and not from_dataloader:
+        get_img = Image.open(img_path)
+
+        # Remove the alpha channel if the input has one.
+        np_img = np.asarray(get_img)[:, :, :3]
+        image = item_norm(torch.from_numpy(np_img).float())
+        image = rearrange(image, '(b h) w c -> b c h w', b=1)
+        item = image
+    else:
+        item = item_norm(item.float())
+
+    _, canny_edges = canny(item)
+
+    return resize_item(canny_edges)
+
+def threshold_processor(item):
+    item_processed = resize_item(item_norm(item)) > 0.5
+
+    # The Stable Diffusion VAE latent has 4 channels, so we pad with an arbitrary extra channel (RGB 3 + 1).
+    null_channel = torch.zeros_like(item_processed)[:, :1, ...]
+    item_processed = torch.cat([item_processed, null_channel], dim=1)
+
+    return item_processed
+
+ConditionProcessors = dict(
+    canny=canny_processor,
+    threshold=threshold_processor
+)
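
For reference, a minimal sketch of how the options introduced in this change (extra_train_data with a per-dataset enabled flag, interleave_datasets, and condition_processor) might be wired into a training config. Every path, size, and prompt below is illustrative, not a config shipped with this patch; only the key names come from the new code.

# Illustrative only: values are placeholders, key names mirror the arguments read in this diff.
from omegaconf import OmegaConf

config = OmegaConf.create({
    "interleave_datasets": True,
    "dataset_types": ["folder"],
    "train_data": {
        "path": "./data/videos",            # hypothetical folder of training clips
        "n_sample_frames": 16,
        "width": 384,
        "height": 384,
        "condition_processor": "canny",     # or "threshold", per utils/dataset_processors.py
    },
    "extra_train_data": [
        {
            "enabled": True,                # honored by the new dataset.get('enabled', True) check
            "dataset_types": ["image"],
            "train_data": {
                "image_dir": "./data/images",   # hypothetical image folder
                "condition_processor": "canny",
            },
        }
    ],
})

print(OmegaConf.to_yaml(config))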
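
The zip/chain pattern in ConcatInterleavedDataset.interleave_datasets round-robins items across datasets once the children have been repeated or truncated to the parent length. A tiny standalone illustration of the resulting order, with strings standing in for dataset items:

from itertools import chain

# Three toy "datasets" of equal length, as they look after extension to the parent length.
d1 = ["d1_item1", "d1_item2"]
d2 = ["d2_item1", "d2_item2"]
d3 = ["d3_item1", "d3_item2"]

interleaved = list(chain(*zip(d1, d2, d3)))
print(interleaved)
# ['d1_item1', 'd2_item1', 'd3_item1', 'd1_item2', 'd2_item2', 'd3_item2']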
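
And a small smoke test for the new condition processors. This is a sketch only: it assumes a CUDA device, since dataset_processors.py instantiates Canny on 'cuda' at import time, and it uses the (f, c, h, w) frame layout the datasets produce; the random tensor is stand-in data, not anything from this repo.

import torch
from utils.dataset_processors import canny_processor, threshold_processor

# Fake clip: 4 frames, RGB, 256x256, with 0..255 values like decoded video frames.
frames = torch.randint(0, 256, (4, 3, 256, 256), dtype=torch.uint8).float().to("cuda")

edges = canny_processor(frames)     # expected shape: (4, 1, 32, 32) after the VAE-scale resize
mask = threshold_processor(frames)  # expected shape: (4, 4, 32, 32) with the padded null channel
print(edges.shape, mask.shape)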