File tree Expand file tree Collapse file tree 2 files changed +6
-0
lines changed
end_to_end/tpu/gpt_oss/20b Expand file tree Collapse file tree 2 files changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -49,6 +49,8 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
4949# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
5050export DATASET_PATH=gs://maxtext-dataset
5151
52+ export LIBTPU_INIT_ARGS=' --xla_tpu_scoped_vmem_limit_kib=81920'
53+
5254# Test whether the forward pass logits match the golden logits
5355# default golden_logits_path=/deps/src/MaxText/test_assets/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
5456python3 -m tests.forward_pass_logit_checker " ${MAXTEXT_PKG_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ MaxText} /" configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4
Original file line number Diff line number Diff line change @@ -61,6 +61,10 @@ class MatmulPrecision(str, Enum):
6161 DEFAULT = "default"
6262 HIGH = "high"
6363 HIGHEST = "highest"
64+ # same as default
65+ BFLOAT16 = "bfloat16"
66+ # same as highest
67+ FLOAT32 = "float32"
6468
6569
6670class QuantizationType (str , Enum ):
You can’t perform that action at this time.
0 commit comments