Skip to content

Commit fca6e8f

Browse files
Merge pull request #2617 from AI-Hypercomputer:chzheng/integration_tests
PiperOrigin-RevId: 831554205
2 parents 9642e89 + 2ec16b8 commit fca6e8f

File tree

4 files changed: +70 -27 lines changed

4 files changed: +70 -27 lines changed

.github/workflows/RunTests.yml

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,19 @@ jobs:
121121
container_resource_option: "--privileged"
122122
is_scheduled_run: ${{ github.event_name == 'schedule' }}
123123

124+
tpu_pathways_integration_tests:
125+
needs: tpu_image
126+
uses: ./.github/workflows/run_pathways_tests_internal.yml
127+
with:
128+
device_type: tpu
129+
device_name: v4-8
130+
cloud_runner: linux-x86-ct4p-240-4tpu
131+
pytest_marker: 'not cpu_only and not gpu_only and integration_test'
132+
xla_python_client_mem_fraction: 0.75
133+
tf_force_gpu_allow_growth: false
134+
container_resource_option: "--privileged"
135+
is_scheduled_run: ${{ github.event_name == 'schedule' }}
136+
124137
gpu_unit_tests:
125138
needs: gpu_image
126139
uses: ./.github/workflows/run_tests_internal.yml
@@ -151,7 +164,7 @@ jobs:
151164

152165
clean_up:
153166
if: ${{ always() }}
154-
needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests]
167+
needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests, tpu_pathways_integration_tests]
155168
name: "Clean up"
156169
runs-on: ["self-hosted"]
157170
permissions:
@@ -170,7 +183,7 @@ jobs:
170183

171184
notify_failure:
172185
name: Notify failed build # creates an issue or modifies last open existing issue for failed build
173-
needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests]
186+
needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests, tpu_pathways_integration_tests]
174187
if: ${{ always() }}
175188
runs-on: ubuntu-latest
176189
permissions:
@@ -202,7 +215,7 @@ jobs:
202215
name: Close issue after 3 successful builds
203216
# This job runs only if all the preceding test jobs succeeded
204217
if: ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
205-
needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests]
218+
needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests, tpu_pathways_integration_tests]
206219
runs-on: ubuntu-latest
207220
permissions:
208221
issues: write

tests/integration_tests/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,3 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
"""
16+
Test initialization
17+
"""
18+
19+
import pathwaysutils
20+
21+
pathwaysutils.initialize()

tests/integration_tests/checkpointing_test.py

Lines changed: 26 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -42,22 +42,32 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
4242
"base_num_decoder_layers=8",
4343
"head_dim=128",
4444
]
45-
return [
46-
None,
47-
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
48-
f"hardware={hardware}",
49-
f"run_name=runner_{run_date}",
50-
f"steps={steps}",
51-
"max_target_length=128",
52-
"per_device_batch_size=1",
53-
f"metrics_file={metrics_file}",
54-
"checkpoint_period=3",
55-
"base_output_directory=gs://runner-maxtext-logs",
56-
f"dataset_path={dataset_path}",
57-
f"dataset_type={dataset_type}",
58-
"async_checkpointing=False",
59-
f"attention={attention_type}",
60-
] + model_params
45+
pathways_command = []
46+
if os.getenv("JAX_PLATFORMS") == "proxy":
47+
pathways_command = [
48+
"enable_single_controller=True",
49+
"checkpoint_storage_use_zarr3=False",
50+
]
51+
return (
52+
[
53+
None,
54+
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
55+
f"hardware={hardware}",
56+
f"run_name=runner_{run_date}",
57+
f"steps={steps}",
58+
"max_target_length=128",
59+
"per_device_batch_size=1",
60+
f"metrics_file={metrics_file}",
61+
"checkpoint_period=3",
62+
"base_output_directory=gs://runner-maxtext-logs",
63+
f"dataset_path={dataset_path}",
64+
f"dataset_type={dataset_type}",
65+
"async_checkpointing=False",
66+
f"attention={attention_type}",
67+
]
68+
+ model_params
69+
+ pathways_command
70+
)
6171

6272

6373
def check_loss(metrics_file, target):

tests/integration_tests/generate_param_only_checkpoint_test.py

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,10 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
5353
"per_device_batch_size=1",
5454
] + model_config
5555

56+
pathways_command = []
57+
if os.getenv("JAX_PLATFORMS") == "proxy":
58+
pathways_command = ["enable_single_controller=True"]
59+
5660
if state_path is None:
5761
# Run training to get a checkpoint
5862
train_main(
@@ -69,17 +73,25 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
6973
state_path = f"gs://runner-maxtext-logs/runner_{run_date}/checkpoints/0/items"
7074

7175
# Generate parameter-only checkpoint
72-
generate_param_only_ckpt_config = test_config + [
73-
f"run_name=generate_param_{run_date}",
74-
f"load_full_state_path={state_path}",
75-
]
76+
generate_param_only_ckpt_config = (
77+
test_config
78+
+ [
79+
f"run_name=generate_param_{run_date}",
80+
f"load_full_state_path={state_path}",
81+
]
82+
+ pathways_command
83+
)
7684
generate_param_only_ckpt_main(generate_param_only_ckpt_config)
7785

7886
# Run inference on parameter-only checkpoint
79-
decode_config = test_config + [
80-
f"run_name=decode_{run_date}",
81-
f"load_parameters_path=gs://runner-maxtext-logs/generate_param_{run_date}/checkpoints/0/items",
82-
]
87+
decode_config = (
88+
test_config
89+
+ [
90+
f"run_name=decode_{run_date}",
91+
f"load_parameters_path=gs://runner-maxtext-logs/generate_param_{run_date}/checkpoints/0/items",
92+
]
93+
+ pathways_command
94+
)
8395
decode_main(decode_config)
8496

8597

0 commit comments

Comments
 (0)