Skip to content

Commit 58f3771

Browse files
turian and sayakpaul authored
Add optional precision-preserving preprocessing for examples/unconditional_image_generation/train_unconditional.py (#12596)
* Add optional precision-preserving preprocessing * Document decoder caveat for precision flag --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 6198f8a commit 58f3771

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

examples/unconditional_image_generation/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
104104
- you can either provide your own folder as `--train_data_dir`
105105
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
106106

107+
If your dataset contains 16- or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
108+
107109
Below, we explain both in more detail.
108110

109111
#### Provide the dataset as a folder

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
5252
return res.expand(broadcast_shape)
5353

5454

55+
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
56+
"""
57+
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
58+
"""
59+
if tensor.ndim == 2:
60+
tensor = tensor.unsqueeze(0)
61+
channels = tensor.shape[0]
62+
if channels == 3:
63+
return tensor
64+
if channels == 1:
65+
return tensor.repeat(3, 1, 1)
66+
if channels == 2:
67+
return torch.cat([tensor, tensor[:1]], dim=0)
68+
if channels > 3:
69+
return tensor[:3]
70+
raise ValueError(f"Unsupported number of channels: {channels}")
71+
72+
5573
def parse_args():
5674
parser = argparse.ArgumentParser(description="Simple example of a training script.")
5775
parser.add_argument(
@@ -260,6 +278,11 @@ def parse_args():
260278
parser.add_argument(
261279
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
262280
)
281+
parser.add_argument(
282+
"--preserve_input_precision",
283+
action="store_true",
284+
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
285+
)
263286

264287
args = parser.parse_args()
265288
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -453,19 +476,41 @@ def load_model_hook(models, input_dir):
453476
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
454477

455478
# Preprocessing the datasets and DataLoaders creation.
# Spatial ops shared by both pipelines (resize, crop, optional flip).
spatial_augmentations = [
    transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
    transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]

# Default pipeline: images arrive as 8-bit RGB PIL images (converted by the
# dataset transform), so ToTensor handles the float conversion.
augmentations = transforms.Compose(
    [
        *spatial_augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Precision-preserving pipeline: PILToTensor + ConvertImageDtype run before the
# spatial ops so higher-bit-depth data is kept as float32 end to end.
precision_augmentations = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Lambda(_ensure_three_channels),
        transforms.ConvertImageDtype(torch.float32),
        *spatial_augmentations,
        transforms.Normalize([0.5], [0.5]),
    ]
)
466503
def transform_images(examples):
    """Map raw PIL images in a batch to normalized 3-channel training tensors."""
    use_precision = args.preserve_input_precision
    processed = []
    for image in examples["image"]:
        if use_precision:
            # Palette-mode images must be expanded to RGB before PILToTensor.
            precise_image = image.convert("RGB") if image.mode == "P" else image
            processed.append(precision_augmentations(precise_image))
        else:
            processed.append(augmentations(image.convert("RGB")))
    return {"input": processed}
469514

470515
logger.info(f"Dataset size: {len(dataset)}")
471516

0 commit comments

Comments
 (0)