@@ -521,36 +521,18 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
521521
522522def run (config , recorder , diagnostic_config ):
523523 """Run the job given hyperparameters and utilities"""
524- with diagnostic .diagnose (diagnostic_config ):
525- with maybe_record_goodput (recorder , GoodputEvent .JOB ):
526- train_loop (config , recorder )
524+ with (
525+ diagnostic .diagnose (diagnostic_config ),
526+ maybe_record_goodput (recorder , GoodputEvent .JOB ),
527+ max_utils .maybe_get_transformer_engine_context (config )
528+ ):
529+ train_loop (config , recorder )
527530
528531
529- @contextmanager
530- def transformer_engine_context ():
531- """If TransformerEngine is available, this context manager will provide
532- the library with MaxText-specific details needed for correcct operation."""
533- try :
534- from transformer_engine .jax .sharding import global_shard_guard , MeshResource # pylint: disable=import-outside-toplevel
535- # Inform TransformerEngine of MaxText's physical mesh resources.
536- mesh_resource = MeshResource ( # pytype: disable=wrong-arg-types
537- dp_resource = "data" ,
538- tp_resource = "tensor" ,
539- # tpsp_resource = "tensor_sequence", #TODO(Phuong): add this back when upstreaming CGEMM
540- fsdp_resource = "fsdp" ,
541- pp_resource = None ,
542- cp_resource = "context" ,
543- )
544- with global_shard_guard (mesh_resource ):
545- yield
546- except (ImportError , AttributeError ):
547- yield
548-
549532
550533def main (argv : Sequence [str ]) -> None :
551- with transformer_engine_context ():
552- config , recorder , diagnostic_config = initialize (argv )
553- run (config , recorder , diagnostic_config )
534+ config , recorder , diagnostic_config = initialize (argv )
535+ run (config , recorder , diagnostic_config )
554536
555537
556538if __name__ == "__main__" :
0 commit comments