Skip to content

Commit 4f1cb63

Browse files
committed
revert
1 parent 1ceff80 commit 4f1cb63

File tree

5 files changed

+14
-50
lines changed

5 files changed

+14
-50
lines changed

torchtitan/models/flux/inference/infer.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,17 @@ def inference(config: JobConfig):
2626
world_size = int(os.environ["WORLD_SIZE"])
2727
global_rank = int(os.environ["RANK"])
2828

29-
single_prompt_mode = config.inference.prompt is not None
30-
31-
# Use single prompt if specified, otherwise read from file
32-
if single_prompt_mode:
33-
original_prompts = [config.inference.prompt]
34-
logger.info(f"Using single prompt: {config.inference.prompt}")
35-
bs = 1
36-
# If only single prompt, each rank will generate an image with the same prompt
37-
prompts = original_prompts
38-
else:
39-
original_prompts = open(config.inference.prompts_path).readlines()
40-
logger.info(f"Reading prompts from: {config.inference.prompts_path}")
41-
bs = config.inference.local_batch_size
42-
# Distribute prompts across processes using round-robin assignment
43-
prompts = original_prompts[global_rank::world_size]
29+
original_prompts = open(config.inference.prompts_path).readlines()
30+
logger.info(f"Reading prompts from: {config.inference.prompts_path}")
31+
if len(original_prompts) < world_size:
32+
raise ValueError(
33+
f"Number of prompts ({len(original_prompts)}) must be >= number of ranks ({world_size}). "
34+
f"FSDP all-gather will hang if some ranks have no prompts to process."
35+
)
36+
37+
bs = config.inference.local_batch_size
38+
# Distribute prompts across processes using round-robin assignment
39+
prompts = original_prompts[global_rank::world_size]
4440

4541
total_prompts = len(original_prompts)
4642

@@ -58,13 +54,7 @@ def inference(config: JobConfig):
5854
config.inference.save_img_folder,
5955
)
6056
# Create mapping from local indices to global prompt indices
61-
if single_prompt_mode:
62-
# In single prompt mode, all ranks process the same prompt (index 0)
63-
# But each rank generates a different image (different seed/rank)
64-
global_ids = [0] * len(prompts)
65-
else:
66-
# In multi-prompt mode, use round-robin distribution
67-
global_ids = list(range(global_rank, total_prompts, world_size))
57+
global_ids = list(range(global_rank, total_prompts, world_size))
6858

6959
for i in range(0, len(prompts), bs):
7060
images = generate_image(

torchtitan/models/flux/inference/prompts.txt

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,3 @@ A serene mountain landscape at sunset with a crystal clear lake reflecting the g
22
A futuristic cityscape with flying cars and neon lights illuminating the night sky
33
A cozy cafe interior with steam rising from coffee cups and warm lighting
44
A magical forest with glowing mushrooms and fireflies dancing between ancient trees
5-
A peaceful beach scene with turquoise waves and palm trees swaying in the breeze
6-
A steampunk-inspired mechanical dragon soaring through clouds
7-
A mystical library with floating books and magical artifacts
8-
A Japanese garden in spring with cherry blossoms falling gently
9-
A space station orbiting a colorful nebula
10-
A medieval castle on a hilltop during a dramatic thunderstorm
11-
A underwater scene with bioluminescent creatures and coral reefs
12-
A desert oasis with a majestic palace and palm trees
13-
A cyberpunk street market with holographic signs and diverse crowds
14-
A cozy winter cabin surrounded by snow-covered pine trees
15-
A fantasy tavern filled with unique characters and magical atmosphere
16-
A tropical rainforest with exotic birds and waterfalls
17-
A steampunk airship navigating through storm clouds
18-
A peaceful zen garden with a traditional Japanese tea house
19-
A magical potion shop with bubbling cauldrons and mysterious ingredients
20-
A futuristic space colony on Mars with domed habitats
21-
A mystical temple hidden in the clouds
22-
A vintage train station with steam locomotives and period architecture
23-
A magical bakery with floating pastries and enchanted ingredients
24-
A peaceful countryside scene with rolling hills and a rustic farmhouse
25-
A underwater city with advanced technology and marine life
26-
A fantasy marketplace with magical creatures and exotic goods
27-
A peaceful meditation garden with lotus flowers and koi ponds
28-
A steampunk laboratory with intricate machinery and glowing elements

torchtitan/models/flux/job_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ class Inference:
6262
"""Path to save the inference results"""
6363
prompts_path: str = "./torchtitan/experiments/flux/inference/prompts.txt"
6464
"""Path to file with newline separated prompts to generate images for"""
65-
prompt: str = ""
66-
"""Single prompt to generate image for. If specified, takes precedence over prompts_path"""
6765
local_batch_size: int = 2
6866
"""Batch size for inference"""
6967
img_size: int = 256

torchtitan/models/flux/run_infer.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ set -ex
1010
# use envs as local overrides for convenience
1111
# e.g.
1212
# LOG_RANK=0,1 NGPU=4 ./torchtitan/models/flux/run_train.sh
13-
NGPU=${NGPU:-"4"}
13+
NGPU=${NGPU:-"8"}
1414
export LOG_RANK=${LOG_RANK:-0}
1515
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}
1616

torchtitan/models/flux/run_train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ set -ex
1010
# use envs as local overrides for convenience
1111
# e.g.
1212
# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh
13-
NGPU=${NGPU:-"4"}
13+
NGPU=${NGPU:-"8"}
1414
export LOG_RANK=${LOG_RANK:-0}
1515
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}
1616

0 commit comments

Comments
 (0)