# Taken and modified from PyTorch Lightning
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
import logging
import os
import time

import torch
-import torch_tensorrt
+import torch.distributed as dist
from llama3_model import ModelArgs, ParallelTransformer
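+# Helpers that set up and tear down the distributed environment for this example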
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
+
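+# Make sure the distributed process group is initialized before torch_tensorrt is imported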
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+import torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
    initialize_logger,
)

-if not dist.is_initialized():
-    initialize_distributed_env()
-
device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_logger(_rank, "tensor_parallel_llama3")

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
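For reference, below is a minimal sketch of what the `tensor_parallel_initialize_dist` helpers imported above might look like, assuming a `torchrun`-style launcher that sets `RANK`, `WORLD_SIZE`, and `LOCAL_RANK`; the actual helper module shipped with the example may do more (or differ in signature).

# Hypothetical sketch of tensor_parallel_initialize_dist -- not the actual module.
import os

import torch
import torch.distributed as dist


def initialize_distributed_env(backend: str = "nccl") -> None:
    # Assumes RANK / WORLD_SIZE / LOCAL_RANK are provided by the launcher (e.g. torchrun).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)


def cleanup_distributed_env() -> None:
    # Tear down the default process group at the end of the run.
    if dist.is_initialized():
        dist.destroy_process_group()

With helpers like these, the example would typically be launched with something like `torchrun --nproc_per_node=2 tensor_parallel_llama3.py` (script name assumed from the logger name).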