Changes from all commits
87 commits, all authored by LouisRouss:

91de4b3 (Sep 10, 2025) Add RewardModel and PrefGRPORewardModel classes for reward computation
db14e8a (Sep 11, 2025) Refactor RewardModel and PrefGRPORewardModel to enhance image handlin…
02e1ca8 (Sep 13, 2025) Add return_latents option to Diffuser's denoise method for latent rep…
2ff6348 (Sep 13, 2025) Add attribute delegation and enhanced dir() support to Diffuser class
b54d0c7 (Sep 13, 2025) Fix dtype argument in model initialization
7e60ff7 (Sep 13, 2025) Add one_step_denoise_grpo method for GRPO training in Flow class
b0240f0 (Sep 16, 2025) Refactor training classes to use a common trainer and reorganize impo…
f60fd65 (Sep 16, 2025) Add GRPO support to Diffuser and Flow classes with new methods and ut…
2a3ffc8 (Sep 16, 2025) Enhance RewardModel and PrefGRPORewardModel with n_image_per_prompt s…
43f6cd3 (Sep 19, 2025) Add GRPO support with new BatchData structures and update training cl…
6b014f7 (Sep 19, 2025) fix typing
fd4252a (Sep 20, 2025) fix loss calculation grpo flow
ba8f074 (Sep 20, 2025) Refactor loss computation in Flow class to use a list for step-wise l…
10a823d (Sep 20, 2025) Refactor trainer imports and implement validation step in GRPOTrainer
7519c8c (Sep 20, 2025) Finish GRPO training loop and fix epoch level scheduler logic
2f1940e (Sep 21, 2025) Add clip in reward model
d442082 (Sep 23, 2025) Implement StepResult and Sampler classes for diffusion process; add E…
693e403 (Sep 23, 2025) adapt to abstraction sampler and clean GRPO logic
342982a (Sep 23, 2025) Refactor ContextEmbedder to implement properties for n_output and out…
f7c4200 (Sep 23, 2025) Refactor PrefGRPORewardModel to standardize clip model ID usage and i…
dc033ab (Sep 25, 2025) Refactor sampler classes to standardize set_steps method for improved…
8f17901 (Sep 25, 2025) Add DDIM and DDPM sampler implementations with step and parameter set…
2a9847a (Sep 25, 2025) Refactor Flow and EulerMaruyama classes for improved parameter handli…
c64db67 (Sep 25, 2025) improve tensor handling and device compatibility in flow and euler me…
2d476d4 (Sep 27, 2025) - Refactor model input handling in Diffuser, Flow, and GRPOTrainer cl…
d2425fb (Sep 27, 2025) Add a generic abstract sampler class over modelization specific sampl…
0b9e159 (Sep 27, 2025) Refactor diffusion model classes to standardize sampler initializatio…
3c81ce9 (Sep 27, 2025) update docstring
854d82a (Sep 27, 2025) Refactor denoise method signatures in Diffuser, Flow, and GaussianDif…
89d8c90 (Sep 28, 2025) Allow MMDiT to use a context embedder without pooled embedding
d079027 (Sep 28, 2025) - Add loguru dependency
4ae13d2 (Sep 28, 2025) Refactor preprocess method in DinoV2
8838fba (Sep 28, 2025) Enhance input validation in encode method of DCAE class to support ad…
ee3d906 (Sep 29, 2025) Implement DDT architecture and refactor modulation classes for enhanc…
562f1f3 (Oct 1, 2025) add dinoV3 and precompute functions
2b8a642 (Oct 5, 2025) Refactor SD3TextEmbedder to improve type casting and add attention ma…
5c24b79 (Oct 22, 2025) Update step method docstring in Euler and EulerMaruyama classes to re…
20945ec (Oct 22, 2025) add dependencies
9960d67 (Oct 22, 2025) improve attn unet
94c9c44 (Oct 22, 2025) Refactor ContextEmbedder to use ContextEmbedderOutput for forward method
d333dc5 (Oct 26, 2025) Enhance MMDiTAttention and MMDiTBlock to support attention masks and …
da43da7 (Oct 26, 2025) - rename mask to attn mask in context embedder output
ee8f7d0 (Oct 26, 2025) Update DDT to utilize ContextEmbedderOutput for improved context hand…
5aed411 (Oct 26, 2025) use transformers instead of open clip
ba5028a (Oct 26, 2025) Add attention mask to U-Net and use torch scaled dot product attn
59b1d1d (Oct 26, 2025) finish forward method of PrefGRPORewardModel
1cd7e40 (Oct 28, 2025) Refactor input tensor handling in DCAE to streamline min/max detectio…
048a5ab (Oct 28, 2025) Add RAE decoder
f39f838 (Oct 28, 2025) Remove 'local/' from Pyright include paths in pyproject.toml
4b60ddd (Oct 28, 2025) add RAE vision tower
08f489b (Oct 29, 2025) fix torch stack xt_std
d7ab331 (Oct 29, 2025) fix unet
a74f3b0 (Oct 29, 2025) Add GRPO loss computation and update method signatures in Diffuser, F…
ce670ba (Oct 29, 2025) Remove 'local/' from Pyright include paths in pyproject.toml
42c1a64 (Oct 30, 2025) Fix feature appending condition in DDT class to check for None
5f2e856 (Oct 30, 2025) Fix default value for attn_mask in PreComputedEmbedder to ensure prop…
c38be13 (Oct 30, 2025) fix reward model
f084799 (Oct 30, 2025) Refactor encoding input range checks in DCAE class for clarity and ac…
f2840de (Nov 2, 2025) clean code - update docstring
f58555b (Nov 2, 2025) Merge branch 'main' into feature/PrefGRPO
1cba0ef (Nov 2, 2025) Merge remote-tracking branch 'origin/feature/PrefGRPO' into feature/RAE
1c62a25 (Nov 2, 2025) add docstring to RAEDecoder and RAE architectures
d64ba01 (Nov 2, 2025) add losses for rae decoder training
81dc873 (Nov 3, 2025) refactor GRPOTrainer: handle vision tower data shape based on patch size
ada351c (Nov 3, 2025) add RAEDiscriminator
5cf98a7 (Nov 3, 2025) add support for canceling affine transformation in DinoV3 initialization
f2edb7a (Nov 3, 2025) fix: inherit from nn.Module in RAEDiscriminator class
8074048 (Nov 3, 2025) add base imagenet dataset class
62e769a (Nov 11, 2025) Add rae training + fix code
43fa81b (Nov 11, 2025) add size to dinoV3
e5ca566 (Nov 12, 2025) fix forward discriminator for gradient propagation
a393780 (Nov 12, 2025) fix image logging
535a06a (Nov 12, 2025) downgrade base model for dinov3 repa
ccd60c0 (Nov 12, 2025) downgrade base model rae
466a718 (Nov 12, 2025) add configs for rae decoder training
23c2765 (Nov 12, 2025) remove best val loss => doesn't make much sense here, TODO: add CMMD
ef61972 (Nov 17, 2025) add shift in t sampling
a86a5ed (Nov 17, 2025) refactor RAETrainer: improve tensor handling and add scheduler checkp…
a619c31 (Nov 18, 2025) - remove precomputed embedder for context => need to think more thoro…
6430601 (Nov 18, 2025) set qwen model to eval + remove grad
052b3f4 (Nov 18, 2025) bump transformers for qwen 3 vl
1381596 (Nov 18, 2025) refactor RAETrainer: update tracker keys for discriminator loss logging
74162c0 (Nov 20, 2025) feat(RAETrainer): add hinge GAN option for loss calculation
8801006 (Nov 21, 2025) add axialRope
653647b (Nov 21, 2025) use AxialRope in networks
f0ef335 (Nov 21, 2025) use AxialRope in networks
7d40aa1 (Nov 21, 2025) Merge branch 'feature/RAE' of github.com:LouisRouss/DiffuLab into fea…
35 changes: 1 addition & 34 deletions .gitignore
@@ -27,8 +27,6 @@ share/python-wheels/
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

@@ -83,36 +81,12 @@ notebooks/
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
# PEP 582
__pypackages__/

# Celery stuff
@@ -155,13 +129,6 @@ dmypy.json
# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# running logs
examples/wandb
outputs/
4 changes: 2 additions & 2 deletions README.md
@@ -63,10 +63,10 @@ Here is a To-Do list, feel welcome to help to any point along this list. The alr
- [ ] add some more context embedders
- [ ] add reflow algorithm
- [ ] add EDM
- [ ] think about how to add a sampler abstraction and use it in the different Diffusion classes (generalist class with euler, heuns etc)
- [x] think about how to add a sampler abstraction and use it in the different Diffusion classes (generalist class with euler, heuns etc)
- [ ] Train our models on toy datasets for different tasks (conditional generation, Image to Image ...)
- [ ] Add possibility to train LORA/DORA
- [ ] add different sampler
- [x] add different sampler
- [ ] Try out Differential Transformers
- [ ] Check to add https://arxiv.org/pdf/2406.02507
- [ ] inject lessons learned from nvidia https://developer.nvidia.com/blog/rethinking-how-to-train-diffusion-models/
@@ -1,13 +1,13 @@
# CIFAR10 dataset configuration
train:
_target_: diffulab.datasets.ImageNetLatentREPA
_target_: diffulab.datasets.ImageNetLatent
data_path: "data/imagenet"
local: true
batch_size: 128
split: "train"

val:
_target_: diffulab.datasets.ImageNetLatentREPA
_target_: diffulab.datasets.ImageNetLatent
data_path: "data/imagenet"
local: true
batch_size: 128
Expand Down
16 changes: 16 additions & 0 deletions configs/dataset/imagenet_noised_latents.yaml
@@ -0,0 +1,16 @@
# CIFAR10 dataset configuration
train:
_target_: diffulab.datasets.ImageNetNoisyLatent
data_path: "data/imagenet"
local: true
batch_size: 128
split: "train"
noise_tau: 0.8

val:
_target_: diffulab.datasets.ImageNetNoisyLatent
data_path: "data/imagenet"
local: true
batch_size: 128
split: "val"
noise_tau : 0
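A minimal sketch of consuming this config, assuming ImageNetNoisyLatent accepts the keys above as constructor arguments the way ImageNetLatent does in examples/train_rae_decoder.py further down; the paths and DataLoader options are illustrative only.

from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

cfg = OmegaConf.load("configs/dataset/imagenet_noised_latents.yaml")
train_dataset = instantiate(cfg.train)  # noise_tau=0.8, presumably the latent noising strength
val_dataset = instantiate(cfg.val)      # noise_tau=0 keeps validation latents clean
train_loader = DataLoader(train_dataset, batch_size=cfg.train.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=cfg.val.batch_size, shuffle=False)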
3 changes: 3 additions & 0 deletions configs/discriminator/rae.yaml
@@ -0,0 +1,3 @@
_target_: diffulab.networks.disc.RAEDiscriminator
model_name: facebook/dino-vits8
features_depth: [2, 5, 8, 11]
5 changes: 5 additions & 0 deletions configs/optimizer/adam.yaml
@@ -0,0 +1,5 @@
_target_: torch.optim.Adam
lr: 1e-4
weight_decay: 0
betas: [0.9, 0.999]
eps: 1e-8
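This optimizer group is meant to be instantiated through Hydra with the model parameters supplied at call time, the same pattern examples/train_rae_decoder.py uses further down; the Linear module here is only a stand-in for the denoiser or the RAE decoder.

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

opt_cfg = OmegaConf.load("configs/optimizer/adam.yaml")
model = torch.nn.Linear(16, 16)  # placeholder module
optimizer = instantiate(opt_cfg, params=model.parameters())  # _target_ resolves to torch.optim.Adam
print(type(optimizer).__name__, optimizer.defaults["lr"])  # Adam 0.0001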
4 changes: 2 additions & 2 deletions configs/train_imagenet_flow_matching_repa.yaml
@@ -1,10 +1,10 @@
# configs/train_cifar10_flow_matching.yaml
# configs/train_imagenet_flow_matching_repa.yaml
# @package _global_
defaults:
- model: dit
- diffuser: rectified_flow
- trainer: default
- dataset: imagenet_repa
- dataset: imagenet_latents
- dataloader: default
- optimizer: adamw
- vision_tower: dcae
41 changes: 41 additions & 0 deletions configs/train_imagenet_rae_decoder.yaml
@@ -0,0 +1,41 @@
defaults:
- vision_tower: rae
- discriminator: rae
- dataloader: default
- dataset: imagenet_noised_latents
- /optimizer@optimizer.rae_decoder: adam
- /optimizer@optimizer.discriminator: adam
- trainer: default

vision_tower:
load_encoder: false

optimizer:
rae_decoder:
betas: [0.5, 0.9]
lr : 2e-4
discriminator:
betas: [0.5, 0.9]
lr : 2e-4

dataloader:
batch_size: 32

dataset:
train:
batch_size: 32
val:
batch_size: 32

trainer:
project_name: imagenet_rae_decoder
n_epoch: 16
precision_type: "bf16"
per_batch_scheduler: true
disc_epoch_start: 6
gan_epoch_start: 8
lpips_epoch_start: 1
lambda_lpips: 1
lambda_gan: 0.75
use_adaptive_weight_loss: true
gradient_accumulation_step: 16
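Two details of this config are easy to miss: the /optimizer@optimizer.rae_decoder and /optimizer@optimizer.discriminator defaults reuse the adam group above under two separate packages, so each network gets its own optimizer instance, and gradient accumulation makes the per-update batch much larger than the DataLoader batch. A rough check, assuming one optimizer step every gradient_accumulation_step micro-batches on a single process (multi-GPU would scale it further):

# Back-of-envelope effective batch size for configs/train_imagenet_rae_decoder.yaml
batch_size = 32                  # dataloader / dataset batch_size
gradient_accumulation_step = 16  # trainer.gradient_accumulation_step
effective_batch = batch_size * gradient_accumulation_step
print(effective_batch)  # 512 samples contribute to each decoder/discriminator update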
19 changes: 19 additions & 0 deletions configs/vision_tower/rae.yaml
@@ -0,0 +1,19 @@
_target_: diffulab.networks.vision_towers.RAE
decoder:
_target_: diffulab.networks.vision_towers.RAEDecoder
out_size: [256, 256]
out_channels: 3
encoder_dim: 768
input_dim: 1024
hidden_dim: 1024
num_heads: 16
mlp_ratio: 4
patch_size: 16
depth: 24
partial_rotary_factor: 1
use_checkpoint: False
dropout_attn: 0
dropout_mlp: 0
dinov3_id: facebook/dinov3-vitb16-pretrain-lvd1689m
load_encoder: true
encoder_patch_size: 16
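As a sanity check on these dimensions: a 256x256 output with 16x16 patches on both the DINOv3 encoder side and the decoder side implies a 16x16 token grid, i.e. 256 latent tokens per image. A quick sketch, ignoring any class or register tokens and whatever projection relates encoder_dim, input_dim and hidden_dim internally:

# Token-grid arithmetic implied by configs/vision_tower/rae.yaml (illustrative only)
out_h, out_w = 256, 256   # decoder.out_size
encoder_patch_size = 16   # DINOv3 ViT-B/16
decoder_patch_size = 16   # decoder.patch_size
grid = (out_h // encoder_patch_size, out_w // encoder_patch_size)
n_tokens = grid[0] * grid[1]
print(grid, n_tokens)                        # (16, 16) 256 latent tokens per image
print(n_tokens * decoder_patch_size**2 * 3)  # 196608 = 3 * 256 * 256 reconstructed values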
9 changes: 6 additions & 3 deletions examples/train_diffusion.py
@@ -5,7 +5,7 @@
from torch.utils.data import DataLoader

from diffulab.diffuse import Diffuser
from diffulab.training import Trainer
from diffulab.training import BaseTrainer


@hydra.main(version_base=None, config_path="../configs", config_name="train_mnist_flow_matching")
@@ -52,8 +52,7 @@ def count_parameters(model: torch.nn.Module) -> int:
params=denoiser.parameters(),
)

# TODO: add a run name for wandb
trainer = Trainer(
trainer = BaseTrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,
@@ -62,6 +61,10 @@ def count_parameters(model: torch.nn.Module) -> int:
ema_update_after_step=cfg.trainer.get("ema_update_after_step", 0),
ema_update_every=cfg.trainer.get("ema_update_every", 10),
run_config=OmegaConf.to_container(cfg, resolve=True), # type: ignore[reportArgumentType]
compile=cfg.trainer.get("compile", False),
init_kwargs={
"wandb": cfg.trainer.get("wandb", {}),
},
)

trainer.train(
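The new init_kwargs argument follows the usual experiment-tracker convention (Accelerate's init_trackers, for example, forwards the nested "wandb" dict to wandb.init). Assuming BaseTrainer does the same, which is not shown in this diff, a trainer.wandb fragment could look like the following; the keys are standard wandb.init arguments and the values are made up for illustration.

# Illustrative settings consumed via init_kwargs={"wandb": ...}; not taken from this PR
wandb_settings = {
    "name": "mnist-flow-matching-baseline",
    "tags": ["flow-matching", "mnist"],
    "notes": "example run for configs/train_mnist_flow_matching.yaml",
}
init_kwargs = {"wandb": wandb_settings}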
125 changes: 125 additions & 0 deletions examples/train_rae_decoder.py
@@ -0,0 +1,125 @@
import math

import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from diffulab.datasets.imagenet import ImageNetLatent
from diffulab.networks.disc import RAEDiscriminator
from diffulab.networks.vision_towers.rae import RAE
from diffulab.training.trainers.extra.rae_trainer import RAETrainer


def cosine_with_warmup_and_min_lr_lambda(
current_step: int, num_warmup_steps: int, num_training_steps: int, min_lr_factor: float = 0.1
) -> float:
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = (current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
# interpolate between 1 and min_lr_factor
return min_lr_factor + (1 - min_lr_factor) * cosine


def get_cosine_schedule_with_warmup_and_min_lr(
optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_factor: float = 0.1
) -> LambdaLR:
lr_lambda = lambda step: cosine_with_warmup_and_min_lr_lambda( # type: ignore
step, # type: ignore
num_warmup_steps,
num_training_steps,
min_lr_factor, # type: ignore
)
return LambdaLR(optimizer, lr_lambda) # type: ignore


@hydra.main(version_base=None, config_path="../configs", config_name="train_imagenet_rae_decoder")
def train(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))

train_dataset: ImageNetLatent = instantiate(cfg.dataset.train)
val_dataset: ImageNetLatent = instantiate(cfg.dataset.val)

train_dataset.set_latent_scale(1)
val_dataset.set_latent_scale(1)

dl_cfg = cfg.get("dataloader", {})
train_loader = DataLoader(
dataset=train_dataset,
batch_size=dl_cfg.get("batch_size", 32),
shuffle=dl_cfg.get("shuffle", True),
num_workers=dl_cfg.get("num_workers", 0),
pin_memory=dl_cfg.get("pin_memory", False),
)

val_loader = DataLoader(
dataset=val_dataset,
batch_size=dl_cfg.get("batch_size", 32),
shuffle=dl_cfg.get("shuffle", False),
num_workers=dl_cfg.get("num_workers", 0),
pin_memory=dl_cfg.get("pin_memory", False),
)

rae: RAE = instantiate(cfg.vision_tower)
discriminator: RAEDiscriminator = instantiate(cfg.discriminator)

def count_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of trainable parameters in the rae decoder: {count_parameters(rae.decoder):,}")
print(f"Number of trainable parameters in the discriminator: {count_parameters(discriminator):,}")

rae_optimizer = instantiate(cfg.optimizer.rae_decoder, params=rae.decoder.parameters())
disc_optimizer = instantiate(cfg.optimizer.discriminator, params=discriminator.parameters())

rae_scheduler = get_cosine_schedule_with_warmup_and_min_lr(
optimizer=rae_optimizer,
num_warmup_steps=len(train_loader),
num_training_steps=cfg.trainer.n_epoch * len(train_loader),
)
disc_scheduler = get_cosine_schedule_with_warmup_and_min_lr(
optimizer=disc_optimizer,
num_warmup_steps=len(train_loader),
num_training_steps=(cfg.trainer.n_epoch - cfg.trainer.disc_epoch_start) * len(train_loader),
)

rae_trainer = RAETrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,
project_name=cfg.trainer.project_name,
use_ema=cfg.trainer.use_ema,
ema_update_after_step=cfg.trainer.get("ema_update_after_step", 0),
ema_update_every=cfg.trainer.get("ema_update_every", 10),
run_config=OmegaConf.to_container(cfg, resolve=True), # type: ignore[reportArgumentType]
compile=cfg.trainer.get("compile", False),
init_kwargs={
"wandb": cfg.trainer.get("wandb", {}),
},
)

rae_trainer.train(
rae=rae,
disc=discriminator,
rae_optimizer=rae_optimizer,
disc_optimizer=disc_optimizer,
train_dataloader=train_loader,
val_dataloader=val_loader,
rae_scheduler=rae_scheduler,
disc_scheduler=disc_scheduler,
per_batch_scheduler=cfg.trainer.get("per_batch_scheduler", True),
log_validation_images=cfg.trainer.get("log_validation_images", True),
disc_epoch_start=cfg.trainer.get("disc_epoch_start", 6),
gan_epoch_start=cfg.trainer.get("gan_epoch_start", 8),
lpips_epoch_start=cfg.trainer.get("lpips_epoch_start", 1),
lambda_lpips=cfg.trainer.get("lambda_lpips", 1.0),
lambda_gan=cfg.trainer.get("lambda_gan", 0.75),
use_adaptive_weight_loss=cfg.trainer.get("use_adaptive_weight_loss", True),
)


if __name__ == "__main__":
train()
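The learning-rate schedule above warms up linearly for one epoch worth of steps, then follows a cosine decay that bottoms out at 10% of the base LR instead of zero. A standalone sanity check of that behaviour, restating the lambda so the snippet runs on its own, with made-up step counts:

import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(step: int, warmup: int = 10, total: int = 100, min_lr_factor: float = 0.1) -> float:
    if step < warmup:
        return step / max(1, warmup)  # linear warmup 0 -> 1
    progress = (step - warmup) / max(1, total - warmup)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay 1 -> 0
    return min_lr_factor + (1 - min_lr_factor) * cosine  # floor at min_lr_factor


optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)
scheduler = LambdaLR(optimizer, lr_lambda)
lrs: list[float] = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs[9], lrs[-1])  # 2e-4 at the end of warmup, 2e-5 (0.1 * base LR) at the end of training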
4 changes: 2 additions & 2 deletions examples/train_repa.py
@@ -7,7 +7,7 @@
from torch.utils.data import DataLoader

from diffulab.diffuse import Diffuser
from diffulab.training import Trainer
from diffulab.training import BaseTrainer
from diffulab.training.losses.repa import RepaLoss


@@ -77,7 +77,7 @@ def count_parameters(model: torch.nn.Module) -> int:
+ list(repa_loss.resampler.parameters() if repa_loss.resampler else []),
)

trainer = Trainer(
trainer = BaseTrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,