1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ omegaconf
 opencv-python
 pydantic
 compel
+kornia
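kornia is a differentiable computer-vision library for PyTorch; this change only adds the dependency, so its call sites in train.py are not visible here. A hypothetical illustration of the kind of differentiable image op it provides:

import torch
import kornia

# Hypothetical usage only; the diff adds the dependency without showing a call site.
frames = torch.rand(2, 3, 64, 64)  # (batch, channels, height, width)
blurred = kornia.filters.gaussian_blur2d(frames, kernel_size=(3, 3), sigma=(1.5, 1.5))
print(blurred.shape)  # torch.Size([2, 3, 64, 64])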
92 changes: 56 additions & 36 deletions train.py
@@ -38,7 +38,7 @@
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers.models.clip.modeling_clip import CLIPEncoder
 from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
-    ImageDataset, VideoFolderDataset, CachedDataset
+    ImageDataset, VideoFolderDataset, CachedDataset, ConcatInterleavedDataset
 from einops import rearrange, repeat
 from utils.lora_handler import LoraHandler, LORA_VERSIONS

@@ -275,7 +275,7 @@ def handle_cache_latents(
     # Cache latents by storing them in VRAM.
     # Speeds up training and saves memory by not encoding during the train loop.
     if not should_cache: return None
-    vae.to('cuda', dtype=torch.float16)
+    vae.to('cuda', dtype=torch.float32)
     vae.enable_slicing()

     cached_latent_dir = (
@@ -287,15 +287,17 @@ def handle_cache_latents(
     os.makedirs(cache_save_dir, exist_ok=True)

     for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):

         save_name = f"cached_{i}"
         full_out_path = f"{cache_save_dir}/{save_name}.pt"

-        pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
-        batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
-        for k, v in batch.items(): batch[k] = v[0]
-
+        pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float32)
+        batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
+        for k, v in batch.items():
+            batch[k] = v[0]
+
         torch.save(batch, full_out_path)
+
         del pixel_values
         del batch
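The cache loop now keeps the VAE and pixel values in float32 rather than float16, which costs VRAM but avoids half-precision overflow and NaN issues during VAE encoding. tensor_to_vae_latent is defined earlier in train.py and is not shown in this diff; a minimal sketch of the conventional pattern, assuming batches carry video as (batch, frames, channels, height, width):

import torch
from einops import rearrange

def tensor_to_vae_latent_sketch(t, vae):
    # Sketch of the assumed helper; the real tensor_to_vae_latent is defined
    # earlier in train.py. Assumed layout: (batch, frames, channels, height, width).
    video_length = t.shape[1]
    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    # 0.18215 is the standard Stable Diffusion latent scaling factor.
    return latents * 0.18215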

@@ -308,8 +310,8 @@ def handle_cache_latents(
     return torch.utils.data.DataLoader(
         CachedDataset(cache_dir=cache_save_dir),
         batch_size=train_batch_size,
-        shuffle=True,
-        num_workers=0
+        shuffle=False,
+        num_workers=0,
     )

 def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
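With shuffle now False, cached batches are replayed in the order they were written (cached_0.pt, cached_1.pt, ...), which matters once samples have been interleaved across datasets at cache time. CachedDataset lives in utils/dataset.py and is not part of this diff; a minimal sketch of the behavior it would need, assuming one saved batch per file:

import os
import torch
from torch.utils.data import Dataset

class CachedDatasetSketch(Dataset):
    # Sketch only; the real CachedDataset is defined in utils/dataset.py.
    def __init__(self, cache_dir):
        def file_index(name):
            # handle_cache_latents names files cached_{i}.pt.
            return int(name.split("_")[-1].split(".")[0])
        self.cached_files = sorted(
            (os.path.join(cache_dir, f) for f in os.listdir(cache_dir) if f.endswith(".pt")),
            key=lambda path: file_index(os.path.basename(path)),
        )

    def __len__(self):
        return len(self.cached_files)

    def __getitem__(self, index):
        # Each file holds one pre-encoded batch saved by handle_cache_latents.
        return torch.load(self.cached_files[index])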
@@ -455,6 +457,7 @@ def main(
     train_data: Dict,
     validation_data: Dict,
     extra_train_data: list = [],
+    interleave_datasets: bool = False,
     dataset_types: Tuple[str] = ('json'),
     validation_steps: int = 100,
     trainable_modules: Tuple[str] = ("attn1", "attn2"),
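The new interleave_datasets flag pairs with the per-dataset enabled key handled further down in this diff. A hypothetical way to drive the new options; key names other than interleave_datasets, extra_train_data, dataset_types, train_data, enabled, and n_sample_frames are illustrative guesses:

# Hypothetical configuration for main(); values are illustrative only.
config = dict(
    interleave_datasets=True,
    extra_train_data=[
        dict(
            enabled=True,  # new per-dataset toggle, read via dataset.get('enabled', True)
            dataset_types=("json",),
            train_data=dict(
                json_path="./data/extra.json",  # assumed key for the json dataset type
                n_sample_frames=8,
            ),
        ),
    ],
)
# main(**config, ...)  # remaining required arguments (train_data etc.) omitted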
@@ -601,40 +604,57 @@
         num_training_steps=max_train_steps * gradient_accumulation_steps,
     )

-    # Get the training dataset based on types (json, single_video, image)
-    train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)
+    train_dataloader = None
+
+    if cached_latent_dir is None:
+        # Get the training dataset based on types (json, single_video, image)
+        if extra_train_data is None:
+            train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)
+        else:
+            train_datasets = []

-    # If you have extra train data, you can add a list of however many you would like.
-    # Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}]
-    try:
-        if extra_train_data is not None and len(extra_train_data) > 0:
-            for dataset in extra_train_data:
-                d_t, t_d = dataset['dataset_types'], dataset['train_data']
-                train_datasets += get_train_dataset(d_t, t_d, tokenizer)
+        # If you have extra train data, you can add a list of however many you would like.
+        # Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}]
+        try:
+            if extra_train_data is not None:
+                for dataset in extra_train_data:
+                    dataset_enabled = dataset.get('enabled', True)
+                    if not dataset_enabled:
+                        continue

-    except Exception as e:
-        print(f"Could not process extra train datasets due to an error : {e}")
+                    d_t, t_d = dataset['dataset_types'], dataset['train_data']
+                    train_datasets += get_train_dataset(d_t, t_d, tokenizer)

-    # Extend datasets that are less than the greatest one. This allows for more balanced training.
-    attrs = ['train_data', 'frames', 'image_dir', 'video_files']
-    extend_datasets(train_datasets, attrs, extend=extend_dataset)
+                    # Allows for joint video / text encoder training.
+                    if t_d['n_sample_frames'] > 1 and (train_text_encoder or use_text_lora):
+                        t_d_single = t_d.copy()
+                        t_d_single['frame_step'] = 30
+                        t_d_single['n_sample_frames'] = 1

-    # Process one dataset
-    if len(train_datasets) == 1:
-        train_dataset = train_datasets[0]
-
-    # Process many datasets
-    else:
-        train_dataset = torch.utils.data.ConcatDataset(train_datasets)
+                        train_datasets += get_train_dataset(d_t, t_d_single, tokenizer)

-    # DataLoaders creation:
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=train_batch_size,
-        shuffle=True
-    )
+        except Exception as e:
+            print(f"Could not process extra train datasets due to an error : {e}")
+
+        # Process one dataset
+        if len(train_datasets) == 1:
+            train_dataset = train_datasets[0]
+
+        # Process many datasets
+        else:
+            if interleave_datasets and extend_dataset:
+                train_dataset = ConcatInterleavedDataset(train_datasets)
+            else:
+                train_dataset = torch.utils.data.ConcatDataset(train_datasets)
+
+        # DataLoaders creation:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=train_batch_size,
+            shuffle=(not interleave_datasets and not extend_dataset)
+        )

     # Latents caching
     cached_data_loader = handle_cache_latents(
         cache_latents,
         output_dir,
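ConcatInterleavedDataset is imported from utils.dataset and its implementation is outside this diff. Note that the DataLoader above now only shuffles when neither interleaving nor extension is requested, which suggests the interleaved sample order is meant to be preserved, including into the latent cache. A minimal guess at round-robin interleaving over the concatenated sources:

from torch.utils.data import Dataset

class ConcatInterleavedSketch(Dataset):
    # Guess at the behavior; the real ConcatInterleavedDataset lives in utils/dataset.py.
    def __init__(self, datasets):
        self.datasets = datasets
        # Truncate to the shortest source so every round-robin pass is complete.
        self.min_len = min(len(d) for d in datasets)

    def __len__(self):
        return self.min_len * len(self.datasets)

    def __getitem__(self, index):
        dataset_idx = index % len(self.datasets)
        sample_idx = index // len(self.datasets)
        return self.datasets[dataset_idx][sample_idx]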