Skip to content

Commit ef64c73

Browse files
Merge pull request #2670 from AI-Hypercomputer:mattdavidow-separate-gpu-context
PiperOrigin-RevId: 831832203
2 parents 78fbeca + 1d4c8aa commit ef64c73

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

src/MaxText/max_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from etils import epath
2929
import flax
3030
import jax
31+
from contextlib import contextmanager
3132
from jax.experimental import mesh_utils
3233
from jax.sharding import PartitionSpec as P
3334
import jax.numpy as jnp
@@ -989,3 +990,38 @@ def get_batch_seq_len_for_mode(config, model_mode):
989990
raise ValueError(f"Unknown model_mode: {model_mode}")
990991

991992
return batch_size, seq_len
993+
994+
@contextmanager
995+
def maybe_get_transformer_engine_context(config):
996+
""" Runs a transformer engine context engine manager for GPUs only. """
997+
if config.hardware in ['gpu', 'gpu_multiprocess']:
998+
with transformer_engine_context():
999+
yield
1000+
else:
1001+
with dummy_context_manager():
1002+
yield
1003+
1004+
@contextmanager
1005+
def dummy_context_manager():
1006+
"""A context manager that does nothing."""
1007+
yield
1008+
1009+
@contextmanager
1010+
def transformer_engine_context():
1011+
"""If TransformerEngine is available, this context manager will provide
1012+
the library with MaxText-specific details needed for correcct operation."""
1013+
try:
1014+
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pylint: disable=import-outside-toplevel
1015+
# Inform TransformerEngine of MaxText's physical mesh resources.
1016+
mesh_resource = MeshResource( # pytype: disable=wrong-arg-types
1017+
dp_resource="data",
1018+
tp_resource="tensor",
1019+
# tpsp_resource = "tensor_sequence", #TODO(Phuong): add this back when upstreaming CGEMM
1020+
fsdp_resource="fsdp",
1021+
pp_resource=None,
1022+
cp_resource="context",
1023+
)
1024+
with global_shard_guard(mesh_resource):
1025+
yield
1026+
except (ImportError, AttributeError):
1027+
yield

src/MaxText/train.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -521,36 +521,18 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
521521

522522
def run(config, recorder, diagnostic_config):
523523
"""Run the job given hyperparameters and utilities"""
524-
with diagnostic.diagnose(diagnostic_config):
525-
with maybe_record_goodput(recorder, GoodputEvent.JOB):
526-
train_loop(config, recorder)
524+
with (
525+
diagnostic.diagnose(diagnostic_config),
526+
maybe_record_goodput(recorder, GoodputEvent.JOB),
527+
max_utils.maybe_get_transformer_engine_context(config)
528+
):
529+
train_loop(config, recorder)
527530

528531

529-
@contextmanager
530-
def transformer_engine_context():
531-
"""If TransformerEngine is available, this context manager will provide
532-
the library with MaxText-specific details needed for correcct operation."""
533-
try:
534-
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pylint: disable=import-outside-toplevel
535-
# Inform TransformerEngine of MaxText's physical mesh resources.
536-
mesh_resource = MeshResource( # pytype: disable=wrong-arg-types
537-
dp_resource="data",
538-
tp_resource="tensor",
539-
# tpsp_resource = "tensor_sequence", #TODO(Phuong): add this back when upstreaming CGEMM
540-
fsdp_resource="fsdp",
541-
pp_resource=None,
542-
cp_resource="context",
543-
)
544-
with global_shard_guard(mesh_resource):
545-
yield
546-
except (ImportError, AttributeError):
547-
yield
548-
549532

550533
def main(argv: Sequence[str]) -> None:
551-
with transformer_engine_context():
552-
config, recorder, diagnostic_config = initialize(argv)
553-
run(config, recorder, diagnostic_config)
534+
config, recorder, diagnostic_config = initialize(argv)
535+
run(config, recorder, diagnostic_config)
554536

555537

556538
if __name__ == "__main__":

0 commit comments

Comments
 (0)