|
77 | 77 | # create a sharding plan based on the given world_size. |
78 | 78 | dp_size = _world_size // tp_size |
79 | 79 |
|
| 80 | +device_type = torch.accelerator.current_accelerator().type |
80 | 81 | # Create a device mesh with 2 dimensions. |
81 | 82 | # First dim is the data parallel dimension |
82 | 83 | # Second dim is the tensor parallel dimension. |
83 | | -device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) |
| 84 | +device_mesh = init_device_mesh(device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")) |
84 | 85 |
|
85 | 86 | rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") |
86 | 87 | tp_mesh = device_mesh["tp"] |
|
92 | 93 | # to mimic the behavior of the dataloader. |
93 | 94 | dp_rank = dp_mesh.get_local_rank() |
94 | 95 |
|
95 | | -# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. |
| 96 | +# create model and move it to GPU - initdevice_type_mesh has already mapped GPU ids. |
96 | 97 | simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000) |
97 | 98 |
|
98 | | -model = Transformer.from_model_args(simple_llama2_config).to("cuda") |
| 99 | +model = Transformer.from_model_args(simple_llama2_config).to(device_type) |
99 | 100 |
|
100 | 101 | # init model weights |
101 | 102 | model.init_weights() |
|
170 | 171 | for i in range(num_iterations): |
171 | 172 | # seeding with dp_rank to ensure identical inputs for TP groups |
172 | 173 | torch.manual_seed(i + dp_rank) |
173 | | - inp = torch.randint(32000, (8, 256), device="cuda") |
| 174 | + inp = torch.randint(32000, (8, 256), device=device_type) |
174 | 175 |
|
175 | 176 | output = sharded_model(inp) |
176 | 177 | output.sum().backward() |
|
0 commit comments