Skip to content

Commit c446bfa

Browse files
committed
test changes
1 parent 4f4918c commit c446bfa

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

tests/integration_tests/flux.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
2727
"--parallelism.data_parallel_replicate_degree 2",
2828
"--parallelism.context_parallel_degree 2",
2929
"--validation.enable",
30+
"--validaiton.step 5" "--checkpoint.enable",
3031
],
32+
[],
3133
],
32-
"HSDP+CP+Validation",
33-
"hsdp+cp+validation",
34+
"HSDP+CP+Validation+Inference",
35+
"hsdp+cp+validation+inference",
3436
ngpu=8,
3537
),
3638
]
@@ -55,7 +57,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
5557
t5_encoder_version_arg = (
5658
"--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
5759
)
58-
tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer"
60+
hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer"
5961

6062
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
6163

@@ -76,7 +78,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
7678
cmd += " " + random_init_encoder_arg
7779
cmd += " " + clip_encoder_version_arg
7880
cmd += " " + t5_encoder_version_arg
79-
cmd += " " + tokenzier_path_arg
81+
cmd += " " + hf_assets_path_arg
8082
if override_arg:
8183
cmd += " " + " ".join(override_arg)
8284

torchtitan/models/flux/train_configs/debug_model.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ enable_wandb = false
2121
[model]
2222
name = "flux"
2323
flavor = "flux-debug"
24+
hf_assets_path = "tests/assets/tokenizer"
2425

2526
[optimizer]
2627
name = "AdamW"
@@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f
4849
[parallelism]
4950
data_parallel_replicate_degree = 1
5051
data_parallel_shard_degree = -1
52+
context_parallel_degree = 1
5153

5254
[activation_checkpoint]
5355
mode = "full"

0 commit comments

Comments
 (0)