Skip to content

Commit 32380ea

Browse files
Merge pull request #2712 from AI-Hypercomputer:shuningjin-fix
PiperOrigin-RevId: 834025121
2 parents 745d09d + 5248e89 commit 32380ea

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
4949
# Non-Googlers, please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
5050
export DATASET_PATH=gs://maxtext-dataset
5151

52+
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=81920'
53+
5254
# Test whether the forward pass logits match the golden logits
5355
# default golden_logits_path=/deps/src/MaxText/test_assets/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
5456
python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class MatmulPrecision(str, Enum):
6161
DEFAULT = "default"
6262
HIGH = "high"
6363
HIGHEST = "highest"
64+
# same as DEFAULT
65+
BFLOAT16 = "bfloat16"
66+
# same as HIGHEST
67+
FLOAT32 = "float32"
6468

6569

6670
class QuantizationType(str, Enum):

0 commit comments

Comments
 (0)