diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml index c07d7e5ac..06528dd50 100644 --- a/.github/workflows/RunTests.yml +++ b/.github/workflows/RunTests.yml @@ -121,6 +121,19 @@ jobs: container_resource_option: "--privileged" is_scheduled_run: ${{ github.event_name == 'schedule' }} + tpu_pathways_integration_tests: + needs: tpu_image + uses: ./.github/workflows/run_pathways_tests_internal.yml + with: + device_type: tpu + device_name: v4-8 + cloud_runner: linux-x86-ct4p-240-4tpu + pytest_marker: 'not cpu_only and not gpu_only and integration_test' + xla_python_client_mem_fraction: 0.75 + tf_force_gpu_allow_growth: false + container_resource_option: "--privileged" + is_scheduled_run: ${{ github.event_name == 'schedule' }} + gpu_unit_tests: needs: gpu_image uses: ./.github/workflows/run_tests_internal.yml @@ -151,7 +164,7 @@ jobs: clean_up: if: ${{ always() }} - needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests] + needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests, tpu_pathways_integration_tests] name: "Clean up" runs-on: ["self-hosted"] permissions: @@ -170,7 +183,7 @@ jobs: notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests] + needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests, tpu_pathways_integration_tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: @@ -202,7 +215,7 @@ jobs: name: Close issue after 3 successful builds # This job runs only if all the preceding test jobs succeeded if: ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }} - needs: [cpu_unit_tests, 
gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests] + needs: [cpu_unit_tests, gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests, tpu_pathways_unit_tests, tpu_pathways_integration_tests] runs-on: ubuntu-latest permissions: issues: write diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py index 2237c9162..46cd7ffa1 100644 --- a/tests/integration_tests/__init__.py +++ b/tests/integration_tests/__init__.py @@ -11,3 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Integration test package setup: runs pathwaysutils.initialize() at import time. +""" + +import pathwaysutils + +pathwaysutils.initialize() diff --git a/tests/integration_tests/checkpointing_test.py b/tests/integration_tests/checkpointing_test.py index 4350c7324..aa35fc659 100644 --- a/tests/integration_tests/checkpointing_test.py +++ b/tests/integration_tests/checkpointing_test.py @@ -42,22 +42,32 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention "base_num_decoder_layers=8", "head_dim=128", ] - return [ - None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), - f"hardware={hardware}", - f"run_name=runner_{run_date}", - f"steps={steps}", - "max_target_length=128", - "per_device_batch_size=1", - f"metrics_file={metrics_file}", - "checkpoint_period=3", - "base_output_directory=gs://runner-maxtext-logs", - f"dataset_path={dataset_path}", - f"dataset_type={dataset_type}", - "async_checkpointing=False", - f"attention={attention_type}", - ] + model_params + pathways_command = [] + if os.getenv("JAX_PLATFORMS") == "proxy": + pathways_command = [ + "enable_single_controller=True", + "checkpoint_storage_use_zarr3=False", + ] + return ( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + f"hardware={hardware}", + f"run_name=runner_{run_date}", + f"steps={steps}", + 
"max_target_length=128", + "per_device_batch_size=1", + f"metrics_file={metrics_file}", + "checkpoint_period=3", + "base_output_directory=gs://runner-maxtext-logs", + f"dataset_path={dataset_path}", + f"dataset_type={dataset_type}", + "async_checkpointing=False", + f"attention={attention_type}", + ] + + model_params + + pathways_command + ) def check_loss(metrics_file, target): diff --git a/tests/integration_tests/generate_param_only_checkpoint_test.py b/tests/integration_tests/generate_param_only_checkpoint_test.py index 8e55e2315..08a8c5a03 100644 --- a/tests/integration_tests/generate_param_only_checkpoint_test.py +++ b/tests/integration_tests/generate_param_only_checkpoint_test.py @@ -53,6 +53,10 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta "per_device_batch_size=1", ] + model_config + pathways_command = [] + if os.getenv("JAX_PLATFORMS") == "proxy": + pathways_command = ["enable_single_controller=True"] + if state_path is None: # Run training to get a checkpoint train_main( @@ -69,17 +73,25 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta state_path = f"gs://runner-maxtext-logs/runner_{run_date}/checkpoints/0/items" # Generate parameter-only checkpoint - generate_param_only_ckpt_config = test_config + [ - f"run_name=generate_param_{run_date}", - f"load_full_state_path={state_path}", - ] + generate_param_only_ckpt_config = ( + test_config + + [ + f"run_name=generate_param_{run_date}", + f"load_full_state_path={state_path}", + ] + + pathways_command + ) generate_param_only_ckpt_main(generate_param_only_ckpt_config) # Run inference on parameter-only checkpoint - decode_config = test_config + [ - f"run_name=decode_{run_date}", - f"load_parameters_path=gs://runner-maxtext-logs/generate_param_{run_date}/checkpoints/0/items", - ] + decode_config = ( + test_config + + [ + f"run_name=decode_{run_date}", + 
f"load_parameters_path=gs://runner-maxtext-logs/generate_param_{run_date}/checkpoints/0/items", + ] + + pathways_command + ) decode_main(decode_config)