Skip to content

Commit 2730d52

Browse files
Run tests on TPU (#21425)
* added requirements-tensorflow-tpu.txt and tpu configuration in .kokoro * updated .kokoro/github/ubuntu/tpu/build.sh with jax and torch backend configs * Changed the tpu CI config files path to .github from .kokoro * Added new job in .github/workflows/actions.yml to run TPU tests * fixed runs-on option in acvtions.yml for tpu_build job to run on self hosted TPU based runner * Added another runner in the actions TPU job * Update continuous.cfg updated build file path * Update presubmit.cfg updated build file path * Update actions.yml Updated tpu_build job of actions.yml with specific runner label * Developed Dockerfile for TPU build job in actions.yml * Update actions.yml Added container section * Included few more runners in tpu_build job * Using linux-x86-ct6e-44-1tpu * Modified requirement-commmon.txt and updated requirements-tensorflow-tpu.txt * Added Dtypes_TPU_tests.py and requirements-jax-tpu.txt * Progress bar now handles `steps_per_execution`. (#21422) Progress bar would always report the starting batch + 1 at the end of the batch. Now it takes into account `steps_per_execution` for the last batch reported. Fixes #20861 * Fix symbolic call of `logsumexp` with int axis. (#21428) Using `keras.ops.math.logsumexp` with an int for `axis` in a functional model would throw an error. * Only allow deserialization of `KerasSaveable`s by module and name. (#21429) Arbitrary functions and classes are not allowed. - Made `Operation` extend `KerasSaveable`, this required moving imports to avoid circular imports - `Layer` no longer need to extend `KerasSaveable` directly - Made feature space `Cross` and `Feature` extend `KerasSaveable` - Also dissallow public function `enable_unsafe_deserialization` * commented tensorflow deps * Added log of dtypes_test_tpu.py and the test script for the same * modified dtypes_test_tpu.py as per pre-commit standards * Added TPU initiaization and teardown functionalities in conftest.py, developed dtypes_new_test.py to use requires_tpu marker * Added dtypes_test_TPU.py and dtypes_new_test.py, modified conftest.py * Added Dcokerfile and tests list command * Updated Dockerfile * Restored Dockerfile to previous changes * updated actions.yml file to install and configure docker engine on self hosted runner, build the image and check TPU support on jax backend * updated actions.yml file to include container option * updated actions.yml file to include container option without volume binding * updated actions.yml file to change TPU * Updated container path in build-and-test-on-tpu job * seperated TPU workflow from actions.yml * updated trigger condition for TPU tests workflow * updated container usage configuration for TPU tests workflow * updated env vars for TPU tests workflow * updated env vars parsing syntax in TPU tests workflow * updated env vars syntax in TPU tests workflow * updated env vars syntax in TPU tests workflow * updated env vars syntax in TPU tests workflow * updated env vars syntax in TPU tests workflow * updated image name in TPU tests workflow * updated image name with generic ubuntu image * updated tpu-tests to use ghcr * updated tpu-tests to store built image as local tar * updated image name from ubuntu:22.04 to docker:24.0-cli in tpu tests workflow * updated image name from docker:24.0-cli to ubuntu:22.04 in tpu tests workflow and added a step to install docker client * added volume mount from host in load-and-test-job * Reverted tpu-tests.yml to version using ghcr.io for image storage * Removed custom dtypes_test files for TPU testing and restored original actions.yml * Updated tpu-tests.yml to pull image from GCP artifact registry * Resolved conflicts in actions.yml * Added a workflow to check service accounts associated with self hosted runners * Made find_sa.yml specific to linux-x86-ct6e-44-1tpu * Added container tag to find_sa.yml * Checking SA for linux-x86-ct5lp-112-4tpu * Checking SA for linux-x86-ct6e-44-1tpu-nxgm7-runner-vb87c * Using SA for auth in tpu-tests * Updated SA with container tag for auth in tpu-tests * Added docker socket mount test * Updated tpu-tests to just pull and test the image from artifact registry after attaching service accounts as IAM policies * Added pytest command to the workflow * added grain installation command * Pruned unwanted files * included grain in requirements.txt * Updated tpu-tests.yml to use python image and explicitly install specific backend deps * Renamed tpu-tests to tpu-tests-jax and logging TPU device kind * Added a step to check gcloud installation * Running pytest on generic tpu workflow * Made changes as per suggestions in PR * Fixed error in action file * Added a job in tpu workflow to persist failed tests list * using requirements-jax-tpu.txt * Reverted the tpu-tests-jax.yml * Removed a command line option in tpu workflow file * Removed uninstall step from tpu workflow job * reverted tensorflow version in requirements file * Updated the tpu workflow to skip failing tests * Clean up TPU tests workflow by removing comment Removed commented out pip uninstall command from workflow. * Changed the failed tests file path and updated the same in conftest.py * Updated the failed test file path * Updated the workflow and job names for TPU * Added TPU job in actions.yml * Added TPU config print line * Added condition to TPU test job so that it gets triggered only after the PR review is approved * Added TPU specific tests in seperate workflow with PR approval condition * Removed pull_request condition for workflow execution and renamed the workflow --------- Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
1 parent bfde12b commit 2730d52

File tree

7 files changed

+341
-2
lines changed

7 files changed

+341
-2
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,4 @@ jobs:
148148
pip uninstall -y keras keras-nightly
149149
pip install -e "." --progress-bar off --upgrade
150150
- name: Run pre-commit
151-
run: pre-commit run --all-files --hook-stage manual
151+
run: pre-commit run --all-files --hook-stage manual

.github/workflows/tpu_tests.yml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
name: Keras Tests
2+
3+
# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future
4+
# Currently only basic flow tests run with NNX enabled
5+
6+
on:
7+
push:
8+
branches: [ master ]
9+
pull_request_review:
10+
types: [submitted]
11+
release:
12+
types: [created]
13+
14+
permissions:
15+
contents: read
16+
17+
jobs:
18+
19+
test-in-container:
20+
name: Run tests on TPU
21+
runs-on: linux-x86-ct6e-44-1tpu
22+
# Only run on approved PRs, pushes to master, or releases
23+
if: |
24+
github.event_name == 'push' ||
25+
github.event_name == 'release' ||
26+
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved')
27+
28+
strategy:
29+
fail-fast: false
30+
matrix:
31+
backend: [jax]
32+
33+
container:
34+
image: python:3.10-slim
35+
options: --privileged --network host
36+
37+
steps:
38+
- name: Checkout Repository
39+
uses: actions/checkout@v4
40+
41+
- name: Install Dependencies
42+
run: |
43+
pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt \
44+
45+
- name: Set Keras Backend
46+
run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV
47+
48+
- name: Run Verification and Tests
49+
run: |
50+
echo "Successfully running inside the public python container!"
51+
echo "Verifying JAX installation..."
52+
python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices : {jax.devices()}')"
53+
54+
pytest keras --ignore keras/src/applications \
55+
--cov=keras \
56+
--cov-config=pyproject.toml

conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ def pytest_collection_modifyitems(config, items):
3131
line.strip() for line in openvino_skipped_tests if line.strip()
3232
]
3333

34+
tpu_skipped_tests = []
35+
if backend() == "jax":
36+
try:
37+
with open(
38+
"keras/src/backend/jax/excluded_tpu_tests.txt", "r"
39+
) as file:
40+
tpu_skipped_tests = file.readlines()
41+
# it is necessary to check if stripped line is not empty
42+
# and exclude such lines
43+
tpu_skipped_tests = [
44+
line.strip() for line in tpu_skipped_tests if line.strip()
45+
]
46+
except FileNotFoundError:
47+
pass # File doesn't exist, no tests to skip
48+
3449
requires_trainable_backend = pytest.mark.skipif(
3550
backend() in ["numpy", "openvino"],
3651
reason="Trainer not implemented for NumPy and OpenVINO backend.",
@@ -49,6 +64,14 @@ def pytest_collection_modifyitems(config, items):
4964
"Not supported operation by openvino backend",
5065
)
5166
)
67+
# also, skip concrete tests for TPU when using JAX backend
68+
for skipped_test in tpu_skipped_tests:
69+
if skipped_test in item.nodeid:
70+
item.add_marker(
71+
pytest.mark.skip(
72+
reason="Known TPU test failure",
73+
)
74+
)
5275

5376

5477
def skip_if_backend(given_backend, reason):
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
AdditiveAttentionTest::test_attention_correctness
2+
AttentionTest::test_attention_calculate_scores_with_scale
3+
AttentionTest::test_attention_correctness
4+
CircleTest::test_correctness
5+
CircleTest::test_correctness_weighted
6+
CircleTest::test_mean_with_sample_weight_reduction
7+
CircleTest::test_no_reduction
8+
CircleTest::test_sum_reduction
9+
ConvBasicTest::test_enable_lora_with_alpha
10+
ConvCorrectnessTest::test_conv1d0
11+
ConvCorrectnessTest::test_conv1d1
12+
ConvCorrectnessTest::test_conv1d2
13+
ConvCorrectnessTest::test_conv1d3
14+
ConvCorrectnessTest::test_conv1d4
15+
ConvCorrectnessTest::test_conv2d0
16+
ConvCorrectnessTest::test_conv2d1
17+
ConvCorrectnessTest::test_conv2d2
18+
ConvCorrectnessTest::test_conv2d3
19+
ConvCorrectnessTest::test_conv2d4
20+
ConvCorrectnessTest::test_conv2d5
21+
ConvCorrectnessTest::test_conv3d0
22+
ConvCorrectnessTest::test_conv3d1
23+
ConvCorrectnessTest::test_conv3d2
24+
ConvCorrectnessTest::test_conv3d3
25+
ConvCorrectnessTest::test_conv3d4
26+
ConvLSTM1DTest::test_correctness
27+
ConvLSTM1DTest::test_correctness
28+
ConvLSTM2DTest::test_correctness
29+
ConvLSTMCellTest::test_correctness
30+
ConvLSTMTest::test_correctness
31+
ConvTransposeCorrectnessTest::test_conv1d_transpose0
32+
ConvTransposeCorrectnessTest::test_conv1d_transpose1
33+
ConvTransposeCorrectnessTest::test_conv1d_transpose2
34+
ConvTransposeCorrectnessTest::test_conv2d_transpose0
35+
ConvTransposeCorrectnessTest::test_conv2d_transpose1
36+
ConvTransposeCorrectnessTest::test_conv2d_transpose2
37+
ConvTransposeCorrectnessTest::test_conv2d_transpose3
38+
ConvTransposeCorrectnessTest::test_conv3d_transpose0
39+
ConvTransposeCorrectnessTest::test_conv3d_transpose1
40+
ConvTransposeCorrectnessTest::test_conv3d_transpose2
41+
CTCTest::test_correctness
42+
DenseTest::test_dense_sparse
43+
DepthwiseConvCorrectnessTest::test_depthwise_conv1d0
44+
DepthwiseConvCorrectnessTest::test_depthwise_conv1d1
45+
DepthwiseConvCorrectnessTest::test_depthwise_conv1d2
46+
DepthwiseConvCorrectnessTest::test_depthwise_conv2d0
47+
DepthwiseConvCorrectnessTest::test_depthwise_conv2d1
48+
DepthwiseConvCorrectnessTest::test_depthwise_conv2d2
49+
EinsumDenseTest::test_enable_lora_with_alpha
50+
EmbeddingTest::test_enable_lora_with_alpha
51+
ExportArchiveTest::test_jax_endpoint_registration_tf_function
52+
ExportArchiveTest::test_jax_multi_unknown_endpoint_registration
53+
ExportArchiveTest::test_layer_export
54+
ExportArchiveTest::test_low_level_model_export_functional
55+
ExportArchiveTest::test_low_level_model_export_sequential
56+
ExportArchiveTest::test_low_level_model_export_subclass
57+
ExportArchiveTest::test_low_level_model_export_with_alias
58+
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional
59+
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential
60+
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass
61+
ExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs
62+
ExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes
63+
ExportArchiveTest::test_model_combined_with_tf_preprocessing
64+
ExportArchiveTest::test_model_export_method_functional
65+
ExportArchiveTest::test_model_export_method_sequential
66+
ExportArchiveTest::test_model_export_method_subclass
67+
ExportArchiveTest::test_multi_input_output_functional_model
68+
ExportArchiveTest::test_non_standard_layer_signature
69+
ExportArchiveTest::test_non_standard_layer_signature_with_kwargs
70+
ExportArchiveTest::test_track_multiple_layers
71+
ExportONNXTest::test_export_with_input_names
72+
ExportONNXTest::test_export_with_opset_version_18
73+
ExportONNXTest::test_export_with_opset_version_none
74+
ExportONNXTest::test_standard_model_export_functional
75+
ExportONNXTest::test_standard_model_export_lstm
76+
ExportONNXTest::test_standard_model_export_sequential
77+
ExportONNXTest::test_standard_model_export_subclass
78+
ExportOpenVINOTest::test_standard_model_export_functional
79+
ExportOpenVINOTest::test_standard_model_export_sequential
80+
ExportOpenVINOTest::test_standard_model_export_subclass
81+
ExportSavedModelTest::test_input_signature_functional_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
82+
ExportSavedModelTest::test_input_signature_functional_backend_tensor
83+
ExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2)
84+
ExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
85+
ExportSavedModelTest::test_input_signature_sequential_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
86+
ExportSavedModelTest::test_input_signature_sequential_backend_tensor
87+
ExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2)
88+
ExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
89+
ExportSavedModelTest::test_input_signature_subclass_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
90+
ExportSavedModelTest::test_input_signature_subclass_backend_tensor
91+
ExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2)
92+
ExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
93+
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true}
94+
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_none
95+
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': true}
96+
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_none
97+
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': true, 'native_serialization': true}
98+
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none
99+
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true}
100+
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none
101+
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true}
102+
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none
103+
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true}
104+
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none
105+
ExportSavedModelTest::test_model_with_input_structure_array
106+
ExportSavedModelTest::test_model_with_input_structure_dict
107+
ExportSavedModelTest::test_model_with_input_structure_tuple
108+
ExportSavedModelTest::test_model_with_multiple_inputs
109+
ExportSavedModelTest::test_model_with_non_trainable_state_export_functional
110+
ExportSavedModelTest::test_model_with_non_trainable_state_export_sequential
111+
ExportSavedModelTest::test_model_with_non_trainable_state_export_subclass
112+
ExportSavedModelTest::test_model_with_rng_export_functional
113+
ExportSavedModelTest::test_model_with_rng_export_sequential
114+
ExportSavedModelTest::test_model_with_rng_export_subclass
115+
ExportSavedModelTest::test_model_with_tf_data_layer_functional
116+
ExportSavedModelTest::test_model_with_tf_data_layer_sequential
117+
ExportSavedModelTest::test_model_with_tf_data_layer_subclass
118+
ExportSavedModelTest::test_standard_model_export_functional
119+
ExportSavedModelTest::test_standard_model_export_sequential
120+
ExportSavedModelTest::test_standard_model_export_subclass
121+
GRUTest::test_correctness0
122+
GRUTest::test_correctness1
123+
GRUTest::test_legacy_implementation_argument
124+
GRUTest::test_masking
125+
GRUTest::test_pass_initial_state
126+
GRUTest::test_pass_return_state
127+
GRUTest::test_statefulness
128+
ImageOpsCorrectnessTest::test_affine_transform_bilinear_constant
129+
ImageOpsCorrectnessTest::test_affine_transform_bilinear_mirror
130+
ImageOpsCorrectnessTest::test_affine_transform_bilinear_nearest
131+
ImageOpsCorrectnessTest::test_affine_transform_bilinear_reflect
132+
ImageOpsCorrectnessTest::test_affine_transform_bilinear_wrap
133+
LinalgOpsCorrectnessTest::test_cholesky_inverse_lower
134+
LinalgOpsCorrectnessTest::test_cholesky_inverse_upper
135+
LinalgOpsCorrectnessTest::test_eig
136+
LinalgOpsCorrectnessTest::test_svd
137+
LSTMTest::test_correctness0
138+
LSTMTest::test_correctness1
139+
LSTMTest::test_masking
140+
LSTMTest::test_pass_initial_state
141+
LSTMTest::test_statefulness
142+
MathOpsCorrectnessTest::test_extract_sequences
143+
MergingLayersTest::test_correctness_dynamic_dot_3d
144+
MergingLayersTest::test_correctness_static_dot_3d
145+
MuonTest::test_Newton_Schulz
146+
NNOpsCorrectnessTest::test_conv_2d0
147+
NNOpsCorrectnessTest::test_conv_2d1
148+
NNOpsCorrectnessTest::test_conv_2d2
149+
NNOpsCorrectnessTest::test_conv_2d3
150+
NNOpsCorrectnessTest::test_conv_2d4
151+
NNOpsCorrectnessTest::test_conv_2d5
152+
NNOpsCorrectnessTest::test_conv_3d0
153+
NNOpsCorrectnessTest::test_conv_3d1
154+
NNOpsCorrectnessTest::test_conv_3d10
155+
NNOpsCorrectnessTest::test_conv_3d11
156+
NNOpsCorrectnessTest::test_conv_3d2
157+
NNOpsCorrectnessTest::test_conv_3d3
158+
NNOpsCorrectnessTest::test_conv_3d4
159+
NNOpsCorrectnessTest::test_conv_3d5
160+
NNOpsCorrectnessTest::test_conv_3d6
161+
NNOpsCorrectnessTest::test_conv_3d7
162+
NNOpsCorrectnessTest::test_conv_3d8
163+
NNOpsCorrectnessTest::test_conv_3d9
164+
NNOpsCorrectnessTest::test_ctc_loss
165+
NNOpsCorrectnessTest::test_depthwise_conv_2d0
166+
NNOpsCorrectnessTest::test_depthwise_conv_2d1
167+
NNOpsCorrectnessTest::test_depthwise_conv_2d10
168+
NNOpsCorrectnessTest::test_depthwise_conv_2d11
169+
NNOpsCorrectnessTest::test_depthwise_conv_2d2
170+
NNOpsCorrectnessTest::test_depthwise_conv_2d3
171+
NNOpsCorrectnessTest::test_depthwise_conv_2d4
172+
NNOpsCorrectnessTest::test_depthwise_conv_2d5
173+
NNOpsCorrectnessTest::test_depthwise_conv_2d6
174+
NNOpsCorrectnessTest::test_depthwise_conv_2d7
175+
NNOpsCorrectnessTest::test_depthwise_conv_2d8
176+
NNOpsCorrectnessTest::test_depthwise_conv_2d9
177+
NNOpsCorrectnessTest::test_separable_conv_2d0
178+
NNOpsCorrectnessTest::test_separable_conv_2d1
179+
NNOpsCorrectnessTest::test_separable_conv_2d2
180+
NNOpsCorrectnessTest::test_separable_conv_2d3
181+
NNOpsCorrectnessTest::test_separable_conv_2d4
182+
NNOpsCorrectnessTest::test_separable_conv_2d5
183+
NNOpsCorrectnessTest::test_separable_conv_2d6
184+
NNOpsCorrectnessTest::test_separable_conv_2d7
185+
NumpyOneInputOpsDynamicShapeTest::test_argmax_negative_zero
186+
NumpyOneInputOpsDynamicShapeTest::test_argmin_negative_zero
187+
NumpyTwoInputOpsCorrectnessTest::test_logspace
188+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float32_false_false
189+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float64_false_false
190+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float16_false_false
191+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float32_false_false
192+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float64_false_false
193+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float16_false_false
194+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float32_false_false
195+
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float64_false_false
196+
RandomGaussianBlurTest::test_random_erasing_basic
197+
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_large_scale
198+
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_small_scale
199+
RandomZoomTest::test_random_zoom_out_correctness
200+
RegularizersTest::test_orthogonal_regularizer
201+
RNNTest::test_go_backwards
202+
SeparableConvCorrectnessTest::test_separable_conv1d0
203+
SeparableConvCorrectnessTest::test_separable_conv1d1
204+
SeparableConvCorrectnessTest::test_separable_conv1d2
205+
SeparableConvCorrectnessTest::test_separable_conv2d0
206+
SeparableConvCorrectnessTest::test_separable_conv2d1
207+
SeparableConvCorrectnessTest::test_separable_conv2d2
208+
SimpleRNNTest::test_correctness
209+
SimpleRNNTest::test_correctness
210+
SimpleRNNTest::test_masking
211+
SimpleRNNTest::test_masking
212+
SimpleRNNTest::test_pass_initial_state
213+
SimpleRNNTest::test_pass_initial_state
214+
SimpleRNNTest::test_return_state
215+
SimpleRNNTest::test_statefulness
216+
SimpleRNNTest::test_statefulness
217+
StackedRNNTest::test_correctness_single_state_stack
218+
StackedRNNTest::test_correctness_two_states_stack
219+
StackedRNNTest::test_statefullness_single_state_stack
220+
StackedRNNTest::test_statefullness_two_states_stack
221+
TestFitLRSchedulesFlow::test_fit_lr_correctness
222+
TestJaxLayer::test_flax_layer_training_independent_bound_method
223+
TestJaxLayer::test_flax_layer_training_rng_state_no_method
224+
TestJaxLayer::test_flax_layer_training_rng_unbound_method
225+
TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy
226+
TestJaxLayer::test_jax_layer_training_independent
227+
TestJaxLayer::test_jax_layer_training_state
228+
TestJaxLayer::test_jax_layer_training_state_dtype_policy
229+
TestSpectrogram::test_spectrogram_error
230+
TestTrainer::test_loss_weights
231+
TestTrainer::test_nested_inputs
232+
TestTrainer::test_on_batch_methods_eager
233+
TestTrainer::test_on_batch_methods_graph_fn
234+
TestTrainer::test_on_batch_methods_jit

requirements-common.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,4 @@ onnxruntime
2727
# onnxscript==0.3.2 breaks LSTM model export.
2828
onnxscript!=0.3.2
2929
openvino
30-
# for grain_dataset_adapter_test.py
3130
grain

requirements-jax-tpu.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Tensorflow cpu-only version (needed for testing).
2+
tensorflow-cpu~=2.18.1
3+
tf2onnx
4+
5+
# Torch cpu-only version (needed for testing).
6+
--extra-index-url https://download.pytorch.org/whl/cpu
7+
torch==2.6.0
8+
9+
# Jax with cuda support.
10+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
11+
jax[tpu]
12+
flax
13+
14+
-r requirements-common.txt

requirements-tensorflow-tpu.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html
2+
tensorflow-tpu==2.19.1
3+
4+
tf2onnx
5+
6+
# Torch cpu-only version (needed for testing).
7+
--extra-index-url https://download.pytorch.org/whl/cpu
8+
torch==2.6.0
9+
10+
# Jax cpu-only version (needed for testing).
11+
jax
12+
13+
-r requirements-common.txt

0 commit comments

Comments
 (0)