Commit f9a1fac

fix format
1 parent 4f1cb63 commit f9a1fac

File tree

  • torchtitan/models/flux/inference

1 file changed: +4 -6 lines changed

torchtitan/models/flux/inference/infer.py

Lines changed: 4 additions & 6 deletions
@@ -25,21 +25,19 @@ def inference(config: JobConfig):
     # Distributed processing setup: Each GPU/process handles a subset of prompts
     world_size = int(os.environ["WORLD_SIZE"])
     global_rank = int(os.environ["RANK"])
-
     original_prompts = open(config.inference.prompts_path).readlines()
-    logger.info(f"Reading prompts from: {config.inference.prompts_path}")
-    if len(original_prompts) < world_size:
+    total_prompts = len(original_prompts)
+
+    if total_prompts < world_size:
         raise ValueError(
-            f"Number of prompts ({len(prompts)}) must be >= number of ranks ({world_size}). "
+            f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). "
             f"FSDP all-gather will hang if some ranks have no prompts to process."
         )

     bs = config.inference.local_batch_size
     # Distribute prompts across processes using round-robin assignment
     prompts = original_prompts[global_rank::world_size]

-    total_prompts = len(original_prompts)
-
     trainer.checkpointer.load(step=config.checkpoint.load_step)
     t5_tokenizer, clip_tokenizer = build_flux_tokenizer(config)

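For context, a minimal runnable sketch (not part of the commit) of the round-robin sharding that prompts = original_prompts[global_rank::world_size] performs in infer.py; the prompt list and world_size below are made-up values for illustration only:

# Illustration only: round-robin sharding of prompts across ranks,
# as done in infer.py via original_prompts[global_rank::world_size].
original_prompts = [f"prompt {i}" for i in range(10)]  # hypothetical prompt list
world_size = 4                                         # hypothetical number of ranks

for global_rank in range(world_size):
    # rank r takes items r, r+W, r+2W, ... where W is world_size
    shard = original_prompts[global_rank::world_size]
    print(global_rank, [p.split()[-1] for p in shard])
# 0 ['0', '4', '8']
# 1 ['1', '5', '9']
# 2 ['2', '6']
# 3 ['3', '7']
# If len(original_prompts) < world_size, some shards would be empty, which is
# why the patched code raises a ValueError before any FSDP collective can hang.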