diff --git a/end_to_end/gpu/te/README.md b/end_to_end/gpu/te/README.md
deleted file mode 100644
index ce9f4d66a..000000000
--- a/end_to_end/gpu/te/README.md
+++ /dev/null
@@ -1,62 +0,0 @@
-# MaxText + Transformer Engine E2E Benchmarking
-
-This directory contains scripts for testing MaxText with Transformer Engine (TE) integration across different parallelization strategies and quantization recipes.
-
-Requirements:
-- NVIDIA MaxText image with Transformer Engine (TE) installed. Using the latest version of `ghcr.io/nvidia/jax:maxtext` is suggested.
-- The `test-maxtext.sh` script, which is available in the suggested image. Otherwise, you can get it [here](https://github.com/NVIDIA/JAX-Toolbox/blob/main/.github/container/test-maxtext.sh).
-- NVIDIA GPU(s) with compute capability 9.0 or higher for FP8 quantization, and 10.0 or higher for MXFP8 quantization.
-
-## Quick Start
-
-### 1. Run Individual Tests
-
-#### MaxText Baseline with FP8
-```bash
-MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=fp8 --model llama3.1-8b --steps 100
-```
-
-#### TE with DelayedScaling FP8
-```bash
-MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_delayedscaling --model llama3.1-8b --steps 100
-```
-
-#### TE with CurrentScaling FP8
-```bash
-MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_currentscaling --model llama3.1-8b --steps 100
-```
-
-#### TE with MXFP8 Block Scaling
-```bash
-MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_mxfp8 --model llama3.1-8b --steps 100
-```
-
-#### Enable Profiling/Tracing
-Add profiling arguments to collect XPlane traces (only the last step is traced):
-```bash
-MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_delayedscaling --model llama3.1-8b --steps 100 --additional-args="profiler=xplane skip_first_n_steps_for_profiler=99 profiler_steps=1"
-```
-
-### 2. Run Comprehensive Benchmarking
-
-The `run_single_node_model_parallel.sh` script automatically tests all quantization recipes across multiple parallelization strategies:
-
-#### Basic Usage
-```bash
-bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100
-```
-
-#### With Tracing Enabled
-```bash
-bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --trace true
-```
-
-#### Collect Traces with a Custom Number of Decoder Layers
-```bash
-bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --trace true --num-decoder-layers 4
-```
-
-#### Skip Single GPU Tests
-```bash
-bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --single-gpu-run false
-```
\ No newline at end of file
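For context on the removed tooling: the sweep script deleted below wrote its measurements to a tab-separated `raw_results.csv` with columns `test`, `dp`, `tpsp`, `fsdp`, `mean`, and `stddev`. A minimal sketch of loading such a file for ad-hoc inspection (the path is illustrative; the script placed it under a timestamped output directory):

```python
import csv

# Hypothetical path; the sweep script wrote raw_results.csv into its output dir.
with open("output/llama3.1-8b_20250101_000000/raw_results.csv", encoding="utf-8") as f:
  # Tab-separated despite the .csv extension, matching normalize.py's reader below.
  for row in csv.DictReader(f, delimiter="\t"):
    print(row["test"], row["dp"], row["tpsp"], row["fsdp"], row["mean"], row["stddev"])
```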
diff --git a/end_to_end/gpu/te/normalize.py b/end_to_end/gpu/te/normalize.py
deleted file mode 100644
index 437f8a088..000000000
--- a/end_to_end/gpu/te/normalize.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright 2023–2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Normalize the raw results to get the percentage difference from the baseline."""
-
-# Usage: python normalize.py input_raw_results.csv output_summary.{csv|txt} format
-#        format = 'csv' for comma-separated, 'txt' or 'tsv' for tab-separated
-
-import csv
-import sys
-
-if len(sys.argv) < 4:
-  print("Usage: normalize.py input_raw_results.csv output_summary.{csv|txt} format")
-  print("       format = 'csv' for comma-separated, 'txt' or 'tsv' for tab-separated")
-  sys.exit(1)
-
-input_csv = sys.argv[1]
-output_file = sys.argv[2]
-format_type = sys.argv[3].lower()
-
-data = {}
-key_order = []  # preserve order of keys
-
-# Read input TSV
-with open(input_csv, encoding="utf-8") as f:
-  reader = csv.DictReader(f, delimiter="\t")
-  for row in reader:
-    key = tuple(row.get(k) if row.get(k) not in [None, ""] else "NA" for k in ["dp", "tpsp", "fsdp"])
-    if not row.get("test"):
-      continue
-    if key not in data:
-      data[key] = {}
-      key_order.append(key)  # remember when first seen
-    data[key][row["test"]] = row
-
-header = ["test", "dp", "tpsp", "fsdp", "mean", "stddev", "normalized"]
-rows = []
-
-# iterate keys in first-seen order
-for key in key_order:
-  rowset = data[key]
-  baseline = rowset.get("fp8", {})
-  base_mean = baseline.get("mean", "NA")
-  try:
-    base_mean_val = float(base_mean)
-    has_baseline = True
-  except ValueError:
-    base_mean_val = 1.0  # dummy value for pylint
-    has_baseline = False
-
-  # iterate tests in first-seen order
-  for testname in rowset:
-    row = rowset[testname]
-    mean = row["mean"]
-    stddev = row["stddev"]
-    if mean == "NA":
-      normalized = "-"
-    elif testname == "fp8":
-      testname = "maxtext_fp8"
-      normalized = "0.00%" if has_baseline else "-"
-    elif has_baseline and mean != "NA":
-      try:
-        normalized = f"{(float(mean) / base_mean_val - 1) * 100:.2f}%"
-      except ValueError:
-        normalized = "-"
-    else:
-      normalized = "-"
-    rows.append(
-        [
-            testname,
-            row["dp"],
-            row["tpsp"],
-            row["fsdp"],
-            mean,
-            stddev,
-            normalized,
-        ]
-    )
-
-if format_type in ("csv",):
-  with open(output_file, "a", newline="", encoding="utf-8") as out:
-    writer = csv.writer(out)
-    writer.writerow(header)
-    writer.writerows(rows)
-elif format_type in ("txt", "tsv"):
-  with open(output_file, "a", encoding="utf-8") as out:
-    out.write("\t".join(header) + "\n")
-    for r in rows:
-      out.write("\t".join(r) + "\n")
-else:
-  print("Invalid format type! Use 'csv' or 'txt'/'tsv'.")
-  sys.exit(2)
-
-print(f"Done. Wrote summary to {output_file} as {format_type}.")
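To make the removed script's summary semantics concrete: each recipe's mean throughput is reported as a percentage delta against the `fp8` baseline with the same parallelism configuration. A toy example of the formula above (throughput numbers invented, not measured):

```python
# Toy numbers: fp8 baseline vs. one TE recipe, in Tokens/s/device.
base_mean_val = 1000.0  # mean for the MaxText fp8 baseline
mean = 1100.0           # mean for, e.g., te_mxfp8 in the same dp/tpsp/fsdp config
normalized = f"{(mean / base_mean_val - 1) * 100:.2f}%"
print(normalized)  # -> "10.00%", i.e. 10% faster than the baseline
```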
diff --git a/end_to_end/gpu/te/plot_loss_curves.py b/end_to_end/gpu/te/plot_loss_curves.py
deleted file mode 100644
index 4d91c9b5e..000000000
--- a/end_to_end/gpu/te/plot_loss_curves.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright 2023–2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Plot loss curves from training logs."""
-
-# Usage: python plot_loss_curves.py logdir
-
-import re
-import os
-import sys
-import argparse
-
-import matplotlib.pyplot as plt
-
-
-def parse_loss_data(file_path):
-  """
-  Parses a text file for lines matching the pattern:
-  completed step: <step>, seconds: <seconds>, TFLOP/s/device: <tflops>,
-  Tokens/s/device: <tokens>, total_weights: <weights>, loss: <loss>
-  Returns a list of tuples with the extracted values.
-  """
-  pattern = re.compile(
-      r"completed step: (\d+), seconds: ([\d.]+), TFLOP/s/device: ([\d.]+), Tokens/s/device: ([\d.]+), total_weights: (\d+), loss: ([\d.]+)"  # pylint: disable=line-too-long
-  )
-  results = []
-  with open(file_path, "r", encoding="utf-8") as f:
-    for line in f:
-      match = pattern.search(line)
-      if match:
-        step = int(match.group(1))
-        seconds = float(match.group(2))
-        tflops = float(match.group(3))
-        tokens_per_sec = float(match.group(4))
-        total_weights = int(match.group(5))
-        loss = float(match.group(6))
-        results.append((step, seconds, tflops, tokens_per_sec, total_weights, loss))
-  return results
-
-
-def main(args):
-  parser = argparse.ArgumentParser(description="Plot training loss curve from log files.")
-  parser.add_argument("logdir", type=str, help="Directory containing training log files.")
-  parsed_args = parser.parse_args(args)
-
-  logdir = parsed_args.logdir
-  log_files = [
-      os.path.join(logdir, f)
-      for f in os.listdir(logdir)
-      if os.path.isfile(os.path.join(logdir, f)) and f.endswith(".log")
-  ]
-
-  # Extract parallelism configs from filenames
-  config_pattern = re.compile(r"dp(\d+)_tpsp(\d+)_fsdp(\d+)")
-  configs = {}
-  for log_file in log_files:
-    fname = os.path.basename(log_file)
-    match = config_pattern.search(fname)
-    if match:
-      dp, tpsp, fsdp = match.groups()
-      key = (int(dp), int(tpsp), int(fsdp))
-      configs.setdefault(key, []).append(log_file)
-
-  # Plot for each config
-  for (dp, tpsp, fsdp), files in configs.items():
-    plt.figure(figsize=(8, 5))
-    for log_file in files:
-      data = parse_loss_data(log_file)
-      if not data:
-        continue
-      steps = [item[0] for item in data]
-      losses = [item[5] for item in data]
-      plt.plot(
-          steps,
-          losses,
-          marker="",
-          linestyle="-",
-          label=os.path.basename(log_file),
-      )
-    plt.legend()
-    plt.xlabel("Step")
-    plt.ylabel("Loss")
-    plt.title(f"Loss Curves (dp={dp}, tpsp={tpsp}, fsdp={fsdp})")
-    plt.grid(True)
-    plt.tight_layout()
-    out_image_path = f"loss_curves_dp{dp}_tpsp{tpsp}_fsdp{fsdp}.png"
-    plt.savefig(out_image_path)
-    print(f"Saved plot to {out_image_path}")
-    plt.close()
-
-
-if __name__ == "__main__":
-  main(sys.argv[1:])
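As a sanity check of the log format `parse_loss_data` expected, here is a sketch against a fabricated MaxText-style log line (all numbers invented):

```python
import re

# Same pattern as parse_loss_data above; the log line is fabricated for illustration.
pattern = re.compile(
    r"completed step: (\d+), seconds: ([\d.]+), TFLOP/s/device: ([\d.]+), "
    r"Tokens/s/device: ([\d.]+), total_weights: (\d+), loss: ([\d.]+)"
)
line = "completed step: 42, seconds: 1.87, TFLOP/s/device: 412.30, Tokens/s/device: 9050.12, total_weights: 1048576, loss: 2.413"
match = pattern.search(line)
assert match is not None
step, loss = int(match.group(1)), float(match.group(6))
print(step, loss)  # -> 42 2.413
```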
diff --git a/end_to_end/gpu/te/run_single_node_model_parallel.sh b/end_to_end/gpu/te/run_single_node_model_parallel.sh
deleted file mode 100755
index 607219fef..000000000
--- a/end_to_end/gpu/te/run_single_node_model_parallel.sh
+++ /dev/null
@@ -1,211 +0,0 @@
-# Copyright 2023–2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Run MaxText with Transformer Engine (TE) across different parallelization strategies and quantization recipes
-# Usage: bash run_single_node_model_parallel.sh --model MODEL --output-dir-tag OUTPUT_DIR_TAG --trace true|false --steps STEPS --single-gpu-run true|false --num-decoder-layers N_LAYERS
-
-#!/bin/bash
-set -euo pipefail
-
-# Default values
-MODEL="llama3.1-8b"
-OUTPUT_DIR_TAG=""
-STEPS=50
-TRACE=false
-SINGLE_GPU_RUNS=true
-NUM_DECODER_LAYERS=""  # unset
-
-# Parse keyword-style arguments
-while [[ $# -gt 0 ]]; do
-  case "$1" in
-    --model)
-      MODEL="$2"
-      shift 2
-      ;;
-    --output-dir-tag)
-      OUTPUT_DIR_TAG="$2"
-      shift 2
-      ;;
-    --trace)
-      TRACE="$2"
-      shift 2
-      ;;
-    --steps)
-      STEPS="$2"
-      shift 2
-      ;;
-    --single-gpu-run)
-      SINGLE_GPU_RUNS="$2"
-      shift 2
-      ;;
-    --num-decoder-layers)
-      NUM_DECODER_LAYERS="$2"
-      shift 2
-      ;;
-    -h|--help)
-      echo "Usage: $0 [--model MODEL] [--output-dir-tag OUTPUT_DIR_TAG] [--trace true|false] [--steps STEPS] [--single-gpu-run true|false] [--num-decoder-layers N_LAYERS]"
-      exit 0
-      ;;
-    *)
-      echo "Unknown argument: $1"
-      echo "Usage: $0 [--model MODEL] [--output-dir-tag OUTPUT_DIR_TAG] [--trace true|false] [--steps STEPS] [--single-gpu-run true|false] [--num-decoder-layers N_LAYERS]"
-      exit 1
-      ;;
-  esac
-done
-
-if [[ "$TRACE" == "true" ]]; then
-  OUTPUT_DIR_TAG="trace${OUTPUT_DIR_TAG:+_$OUTPUT_DIR_TAG}"
-fi
-
-# Print the effective configuration
-echo "MODEL=$MODEL"
-echo "OUTPUT_DIR_TAG=$OUTPUT_DIR_TAG"
-echo "TRACE=$TRACE"
-echo "STEPS=$STEPS"
-echo "SINGLE_GPU_RUNS=$SINGLE_GPU_RUNS"
-
-WARMUP_STEPS=10
-if (( STEPS <= WARMUP_STEPS )); then
-  echo "ERROR: STEPS ($STEPS) must be greater than WARMUP_STEPS ($WARMUP_STEPS)"
-  exit 1
-fi
-
-TIMESTAMP=$(date +%Y%m%d_%H%M%S)
-SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-MAXTEXT_DIR="$(realpath "$SCRIPT_DIR/../../../")"
-OUTPUT_DIR="${SCRIPT_DIR}/output/${MODEL}${NUM_DECODER_LAYERS:+_${NUM_DECODER_LAYERS}_layers}${OUTPUT_DIR_TAG:+_$OUTPUT_DIR_TAG}_${TIMESTAMP}"
-mkdir -p "$OUTPUT_DIR"
-
-n_gpus=$(nvidia-smi -L | wc -l)
-half_gpus=$((n_gpus / 2))
-# List of experiments: "dp tpsp fsdp"
-experiments=(
-  "1 1 1"           # Single GPU
-  "$n_gpus 1 1"     # Full DP
-  "1 $n_gpus 1"     # Full TPSP
-  "2 $half_gpus 1"  # DP=2, TPSP=half GPUs
-  "1 1 $n_gpus"     # Full FSDP
-  "1 $half_gpus 2"  # FSDP=2, TPSP=half GPUs
-)
-
-CSV="$OUTPUT_DIR/raw_results.csv"
-echo -e "test\tdp\ttpsp\tfsdp\tmean\tstddev" > "$CSV"
-
-run_and_parse() {
-  local test="$1"
-  local dp="$2"
-  local tpsp="$3"
-  local fsdp="$4"
-  set +e
-  local cmd="$5"
-  set -e
-  local stdout="$OUTPUT_DIR/run_${test}_dp${dp}_tpsp${tpsp}_fsdp${fsdp}.log"
-  echo "===== Executing ${test}\t${dp}\t${tpsp}\t${fsdp} ====="
-  eval "$cmd" 2>&1 | tee "$stdout"
-  # Exclude the first WARMUP_STEPS steps (warm-up) and the last step (tracing)
-  ths=$(grep 'Tokens/s/device:' "$stdout" | sed '1,'"${WARMUP_STEPS}"'d;$d' | awk -F'Tokens/s/device: ' '{print $2}' | awk -F',' '{print $1}')
-
-  if [ -z "$ths" ]; then
-    mean="NA"
-    stddev="NA"
-  else
-    mean_stddev=$(echo "$ths" | python3 -c "import sys; import numpy as np
-arr = [float(l.strip()) for l in sys.stdin if l.strip()]
-if arr:
-  print(f'{np.mean(arr):.2f}\t{np.std(arr, ddof=1):.2f}')
-else:
-  print('NA\tNA')
-"
-    )
-    mean=$(echo "$mean_stddev" | cut -f1)
-    stddev=$(echo "$mean_stddev" | cut -f2)
-  fi
-  echo -e "${test}\t${dp}\t${tpsp}\t${fsdp}\t${mean}\t${stddev}" >> "$CSV"
-
-  if [[ "$TRACE" == "true" ]]; then
-    TRACE_SRC=$(grep -oE '/tmp/tmp\.[^ ]+' "$stdout" | head -n1)
-    if [[ -n "$TRACE_SRC" && -e "$TRACE_SRC" ]]; then
-      TRACE_DEST="${OUTPUT_DIR}/trace_${test}_dp${dp}_tpsp${tpsp}_fsdp${fsdp}"
-      mv "$TRACE_SRC" "$TRACE_DEST"
-      echo " === Trace moved: $TRACE_SRC -> $TRACE_DEST"
-    else
-      echo "=== No trace file found for $test, dp=$dp, tpsp=$tpsp, fsdp=$fsdp"
-    fi
-  fi
-}
-
-PROFILE_SKIP_STEPS=$(($STEPS-1))
-PROFILE_ARG=""
-original_num_decoder_layers=1
-if [[ "$TRACE" == "true" ]]; then
-  PROFILE_ARG="profiler=xplane skip_first_n_steps_for_profiler=${PROFILE_SKIP_STEPS} profiler_steps=1"
-fi
-# Update the model config file, since base_num_decoder_layers can't be passed in additional-args
-if [ -n "$NUM_DECODER_LAYERS" ]; then
-  MODEL_CONFIG="$MAXTEXT_DIR/MaxText/configs/models/$MODEL.yml"
-  original_num_decoder_layers=$(grep "base_num_decoder_layers" "$MODEL_CONFIG" | awk -F': ' '{print $2}')
-  sed -i "s/base_num_decoder_layers: .*/base_num_decoder_layers: $NUM_DECODER_LAYERS/" "$MODEL_CONFIG"
-  echo "=== Setting base_num_decoder_layers=$NUM_DECODER_LAYERS in $MODEL_CONFIG"
-fi
-
-# Restore the model config file if it was modified
-restore_model_config_file() {
-  if [ -n "$NUM_DECODER_LAYERS" ]; then
-    sed -i "s/base_num_decoder_layers: .*/base_num_decoder_layers: ${original_num_decoder_layers}/" "$MODEL_CONFIG"
-    echo "=== Restoring base_num_decoder_layers back to ${original_num_decoder_layers} in $MODEL_CONFIG"
-  fi
-}
-trap restore_model_config_file EXIT
-
-BASE_ARGS="--model $MODEL --steps $STEPS"
-# The additional args must be wrapped in escaped triple quotes to survive the eval in run_and_parse
-OTHER_ARGS="--additional-args=\"\"\"${PROFILE_ARG}\"\"\""
-TRAINING_RECIPES=("fp8" "te_fp8_delayedscaling" "te_fp8_currentscaling" "te_mxfp8" "te_nvfp4")  # fp8 is the MaxText baseline
-
-export NVTE_JAX_CUSTOM_CALLS='NormFwdPrimitive=false,NormBwdPrimitive=false'
-
-if [[ "$SINGLE_GPU_RUNS" == "false" ]]; then
-  start_index=1
-else
-  start_index=0
-fi
-for ((i = start_index; i < ${#experiments[@]}; i++)); do
-  exp="${experiments[$i]}"
-  echo "Running experiment: $exp"
-  read dp tpsp fsdp <<< "$exp"
-
-  n_used_gpus=$((dp * tpsp * fsdp))
-  if (( n_used_gpus > n_gpus )); then
-    echo "Error: requested $n_used_gpus GPUs, but only $n_gpus are available."
-    exit 1
-  fi
-  CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((n_used_gpus - 1)))
-  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
-  echo "=== Using GPUs: $CUDA_VISIBLE_DEVICES"
-
-  args="--data-parallel=$dp --tensor-sequence-parallel=$tpsp --fsdp=$fsdp"
-
-  for recipe in "${TRAINING_RECIPES[@]}"; do
-    test="${recipe}"
-    run_and_parse "$test" "$dp" "$tpsp" "$fsdp" \
-      "MAXTEXT_DIR=${MAXTEXT_DIR} bash test-maxtext.sh $args --quantization=${recipe} $BASE_ARGS ${OTHER_ARGS}"
-  done
-done
-
-
-OUTPUT_FORMAT="txt"  # txt or csv
-echo "=== Experiments finished. Raw CSV at $CSV"
-python3 $SCRIPT_DIR/normalize.py "$CSV" "${OUTPUT_DIR}/summary.$OUTPUT_FORMAT" "$OUTPUT_FORMAT"
-cat "${OUTPUT_DIR}/summary.$OUTPUT_FORMAT"
diff --git a/src/MaxText/layers/linears.py b/src/MaxText/layers/linears.py
index 338cafaaf..2808d6af3 100644
--- a/src/MaxText/layers/linears.py
+++ b/src/MaxText/layers/linears.py
@@ -185,8 +185,7 @@ def __init__(
       quant_dot_general = nnx_wrappers.ToNNX(dot_general_linen, rngs=rngs)
       self._quant_dot_general_name = f"{type(dot_general_linen).__name__}_0"
       setattr(self, self._quant_dot_general_name, quant_dot_general)
-      block_size = getattr(quant, "get_block_size", lambda: 1)()  # needed for TE MXFP8
-      dummy_inputs = jnp.zeros((block_size, *self.in_features_shape), dtype=self.dtype)
+      dummy_inputs = jnp.zeros((1, *self.in_features_shape), dtype=self.dtype)
       self(dummy_inputs, _initializing=True)
     else:
       self._quant_dot_general_name = None
diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py
index 7e343bf11..b89ef7bf3 100644
--- a/src/MaxText/layers/quantizations.py
+++ b/src/MaxText/layers/quantizations.py
@@ -520,8 +520,6 @@ def _get_quant_config(config):
     return _get_aqt_fp8_quant_config(config)
   if config.quantization == "aqt_fp8_full":
     return _get_aqt_fp8_default_config(config)
-  if config.quantization.startswith("te_"):
-    return config.quantization
   raise ValueError(f"Invalid value configured for quantization {config.quantization}.")
@@ -557,8 +555,6 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"):
     return Fp8Quantization()
   elif quant_cfg == "nanoo_fp8":
     return NANOOFp8Quantization()
-  elif isinstance(quant_cfg, str) and quant_cfg.startswith("te_"):
-    return TransformerEngineQuantization(config)
   quant_mode = get_quant_mode(quant_mode_str)
   replicate_scale = config.replicate_quant_scale if config.replicate_quant_scale else False
   return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode, replicate_scale=replicate_scale)
@@ -719,124 +715,3 @@ def maybe_quantize_model(model, config):
   if quantization_provider:
     model = qwix.quantize_model(model, quantization_provider)
   return model
-
-
-class TransformerEngineQuantization(Quantization):
-  """Class for TransformerEngine quantization recipes."""
-
-  def __init__(self, config):
-    """Initialize TransformerEngine quantization."""
-
-    self.quant_mode = "train"
-
-    if not config.quantization.startswith("te_"):
-      raise ValueError(f"Invalid TransformerEngine quantization config: {config.quantization}")
-
-    self._recipe = TransformerEngineQuantization._get_recipe(config.quantization)
-
-  def __hash__(self):
-    return hash((self.quant_mode, self._recipe))
-
-  def __eq__(self, other):
-    if not isinstance(other, TransformerEngineQuantization):
-      return False
-    return (self.quant_mode, self._recipe) == (other.quant_mode, other._recipe)
-
-  @staticmethod
-  def _get_recipe(recipe_name: str):
-    """Get the TransformerEngine recipe based on the name."""
-    from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error
-
-    RECIPES = {
-        "te_fp8_delayedscaling": recipe.DelayedScaling,
-        "te_fp8_currentscaling": recipe.Float8CurrentScaling,
-        "te_mxfp8": recipe.MXFP8BlockScaling,
-        "te_nvfp4": recipe.NVFP4BlockScaling,  # pytype: disable=module-attr
-    }
-    if recipe_name not in RECIPES:
-      raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
-    return RECIPES[recipe_name]()
-
-  def get_block_size(self):
-    """Get the block size for quantization for recipes that require blocks.
-
-    If there is no block requirement for the current recipe, returns 1.
-    """
-    from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error
-
-    if isinstance(self._recipe, recipe.MXFP8BlockScaling):
-      return 32
-    if isinstance(self._recipe, recipe.NVFP4BlockScaling):  # pytype: disable=module-attr
-      return 128  # TODO: set this to 16 when unfused RHT is supported
-    return 1
-
-  def _wrap(self, f, name=None):
-    """Wraps the given function `f` to support TransformerEngine quantization.
-
-    This method does a few things:
-
-    1. Wraps the given function in a context that specifies MaxText's physical mesh axes to
-       TransformerEngine. This ensures our collective operations in TransformerEngine are using
-       the correct axes.
-
-    2. Wraps the given function in a Flax linen module. This module does not store any Flax
-       parameters but can store Flax variables for quantizers if required by the recipe.
-
-    3. When the wrapper is called, it provides an additional argument to the given function `f`,
-       'generate_quantizer_set', as the first argument. 'generate_quantizer_set' is a function that
-       can be called to generate a TransformerEngine/JAX quantizer set object used in
-       TransformerEngine/JAX APIs. 'generate_quantizer_set' will generate quantizers based on the
-       recipe of this TransformerEngineQuantization object.
-
-    Args:
-      f: The function to wrap. The first argument must be 'generate_quantizer_set'.
-      name: The name of this wrapped operation. If unspecified, will use `f.__name__`.
-
-    Returns:
-      A Flax linen module that wraps the given function.
-    """
-
-    import transformer_engine.jax as te  # pylint: disable=import-outside-toplevel # pytype: disable=import-error
-
-    fp8_recipe = self._recipe
-
-    class TEWrapper(te.flax.module.TransformerEngineBase):
-      """Wrapper module for TransformerEngine quantization."""
-
-      def generate_quantizer_set(self, postfix: str = ""):
-        OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
-        return super().generate_quantizer_set(  # pytype: disable=wrong-keyword-args
-            postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=fp8_recipe
-        )
-
-      @nn.compact
-      def __call__(self, *args, **kwargs):
-        return f(self.generate_quantizer_set, *args, **kwargs)
-
-    TEWrapper.__name__ = f"TEWrapper_{name if name else f.__name__}"
-
-    return TEWrapper
-
-  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
-    """Returns a TransformerEngine-backed dot_general wrapper module."""
-    import transformer_engine.jax as te  # pylint: disable=import-outside-toplevel # pytype: disable=import-error
-
-    def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs):
-      contracting_dims, batch_dims = dims
-      assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot."
-
-      quantizer_set = generate_quantizer_set()
-      return te.dense.dense(
-          x,
-          kernel,
-          contracting_dims=contracting_dims,
-          quantizer_set=quantizer_set,
-      )
-
-    return self._wrap(te_dot_general, "dot_general")
-
-  def einsum(self, dtype: DType = jnp.float32):
-    """Not supported for TransformerEngine quantization."""
-    # quant.einsum is only required for MoE or for inference with KVCache.
-    raise ValueError("Einsum is not yet supported for TransformerEngine quantization.")
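The `linears.py` change above reverts the block-size-aware dummy input that the removed `get_block_size` hook fed: MXFP8 quantizes values in blocks of 32, so initialization used a dummy batch whose leading dimension matched the recipe's block size. A minimal sketch of the reverted pattern (the function name and feature shape are illustrative):

```python
import jax.numpy as jnp

def make_dummy_inputs(in_features_shape, quant=None, dtype=jnp.bfloat16):
  # Mirrors the reverted linears.py logic: ask the quantizer for its block size,
  # defaulting to 1 when the quantizer does not define get_block_size.
  block_size = getattr(quant, "get_block_size", lambda: 1)()
  return jnp.zeros((block_size, *in_features_shape), dtype=dtype)

# With no quantizer the leading dim is 1; a TE MXFP8 quantizer would return 32.
print(make_dummy_inputs((4096,)).shape)  # -> (1, 4096)
```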
- raise ValueError("Einsum is not yet supported for TransformerEngine quantization.") diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 62d896d3d..5b4c0df85 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -344,18 +344,7 @@ def validate_constant_bound(keys): def validate_quantization_methods(keys): """Validate quantization methods""" - valid_quant_methods = ( - "", - "int8", - "fp8", - "fp8_full", - "fp8_gpu", - "fp8_nanoo", - "te_fp8_delayedscaling", - "te_fp8_currentscaling", - "te_mxfp8", - "te_nvfp4", - ) + valid_quant_methods = ("", "int8", "fp8", "fp8_full", "fp8_gpu", "fp8_nanoo") if keys["use_qwix_quantization"]: if keys["quantization"] not in valid_quant_methods: raise ValueError(f"Invalid quantization method {keys['quantization']}. Valid options are {valid_quant_methods}") diff --git a/src/MaxText/train.py b/src/MaxText/train.py index bb48a2c72..48550a1d9 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -19,7 +19,6 @@ # See github.com/google/maxtext/issues/20 for more from typing import Any, Sequence -from contextlib import contextmanager import datetime import functools import os @@ -526,31 +525,9 @@ def run(config, recorder, diagnostic_config): train_loop(config, recorder) -@contextmanager -def transformer_engine_context(): - """If TransformerEngine is available, this context manager will provide - the library with MaxText-specific details needed for correcct operation.""" - try: - from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pylint: disable=import-outside-toplevel - # Inform TransformerEngine of MaxText's physical mesh resources. - mesh_resource = MeshResource( # pytype: disable=wrong-arg-types - dp_resource="data", - tp_resource="tensor", - # tpsp_resource = "tensor_sequence", #TODO(Phuong): add this back when upstreaming CGEMM - fsdp_resource="fsdp", - pp_resource=None, - cp_resource="context", - ) - with global_shard_guard(mesh_resource): - yield - except ImportError: - yield - - def main(argv: Sequence[str]) -> None: - with transformer_engine_context(): - config, recorder, diagnostic_config = initialize(argv) - run(config, recorder, diagnostic_config) + config, recorder, diagnostic_config = initialize(argv) + run(config, recorder, diagnostic_config) if __name__ == "__main__": diff --git a/tests/integration_tests/train_tests.py b/tests/integration_tests/train_tests.py index d29338e50..942112765 100644 --- a/tests/integration_tests/train_tests.py +++ b/tests/integration_tests/train_tests.py @@ -109,42 +109,6 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", ], - "te_fp8_delayedscaling": [ # tests base config with te_fp8_delayedscaling - None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), - "base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_test", - "dataset_path=gs://maxtext-dataset", - "quantization=te_fp8_delayedscaling", - "steps=2", - "enable_checkpointing=False", - "enable_goodput_recording=False", - rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", - ], - "te_fp8_currentscaling": [ # tests base config with te_fp8_currentscaling - None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), - "base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_test", - "dataset_path=gs://maxtext-dataset", - "quantization=te_fp8_currentscaling", - "steps=2", - "enable_checkpointing=False", - "enable_goodput_recording=False", - 
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", - ], - "te_mxfp8": [ # tests base config with te_mxfp8 - None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), - "base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_test", - "dataset_path=gs://maxtext-dataset", - "quantization=te_mxfp8", - "steps=2", - "enable_checkpointing=False", - "enable_goodput_recording=False", - rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", - ], "dropout": [ # tests base config with dropout None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), @@ -234,30 +198,6 @@ def test_gpu_fp8(self): def test_gpu_nanoo_fp8(self): train_main(TrainTests.CONFIGS["nanoo_fp8"] + ["attention=dot_product"]) - @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") - @pytest.mark.integration_test - @pytest.mark.gpu_only - def test_gpu_te_fp8_delayedscaling(self): - train_main(TrainTests.CONFIGS["te_fp8_delayedscaling"] + ["attention=dot_product"]) - - @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") - @pytest.mark.integration_test - @pytest.mark.gpu_only - def test_gpu_te_fp8_currentscaling(self): - train_main(TrainTests.CONFIGS["te_fp8_currentscaling"] + ["attention=dot_product"]) - - @pytest.mark.skip(reason="No runner with GPU arch >= 100 is available") - @pytest.mark.integration_test - @pytest.mark.gpu_only - def test_gpu_te_mxfp8(self): - train_main(TrainTests.CONFIGS["te_mxfp8"] + ["attention=dot_product"]) - - @pytest.mark.skip(reason="No runner with GPU arch >= 100 is available") - @pytest.mark.integration_test - @pytest.mark.gpu_only - def test_gpu_te_nvfp4(self): - train_main(TrainTests.CONFIGS["te_nvfp4"] + ["attention=dot_product"]) - @pytest.mark.integration_test @pytest.mark.tpu_only def test_tpu_dropout(self): diff --git a/tests/quantizations_test.py b/tests/quantizations_test.py index 002649a62..993734100 100644 --- a/tests/quantizations_test.py +++ b/tests/quantizations_test.py @@ -383,26 +383,6 @@ def test_fp8_gpu_quantization(self): def test_fp8_nanoo_quantization(self): self.quantization_config("fp8_nanoo", grad_tolerance=1.0) - @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") - @pytest.mark.gpu_only - def test_fp8_te_fp8_delayedscaling_quantization(self): - self.quantization_config("te_fp8_delayedscaling", grad_tolerance=1.0) - - @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") - @pytest.mark.gpu_only - def test_fp8_te_fp8_currentscaling_quantization(self): - self.quantization_config("te_fp8_currentscaling", grad_tolerance=1.0) - - @pytest.mark.skip(reason="No runner with GPU arch >= 100 is available") - @pytest.mark.gpu_only - def test_fp8_te_mxfp8_quantization(self): - self.quantization_config("te_mxfp8", grad_tolerance=1.0) - - @pytest.mark.skip(reason="No runner with GPU arch >= 100 is available") - @pytest.mark.gpu_only - def test_fp8_te_nvfp4_quantization(self): - self.quantization_config("te_nvfp4", grad_tolerance=1.0) - @pytest.mark.parametrize( "group_sizes,k,n,tiling,dtype",