3 | 3 | import copy |
4 | 4 | import torch |
5 | 5 | import torchao |
| 6 | +import os |
6 | 7 |
|
7 | 8 | from torch.testing._internal import common_utils |
8 | 9 | from torchao.dtypes import AffineQuantizedTensor |
9 | 10 | from torchao.dtypes import to_affine_quantized_intx |
10 | 11 | from torchao.quantization.quant_primitives import MappingType |
| 12 | +from torchao.quantization import quantize_, int8_weight_only |
| 13 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 |
11 | 14 |
|
12 | 15 | """ |
13 | 16 | How to use: |
@@ -213,10 +216,122 @@ def test_linear_compile(self, device, dtype): |
213 | 216 | lp_res = torch.compile(l)(hp_act_tensor) |
214 | 217 | self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) |
215 | 218 |
|
| 219 | +import torch.distributed as dist |
| 220 | +from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh |
| 221 | +from torch.testing._internal.distributed._tensor.common_dtensor import ( |
| 222 | + DTensorTestBase, |
| 223 | + with_comms, |
| 224 | + NUM_DEVICES, |
| 225 | +) |
| 226 | + |
| 227 | +class TorchAOTensorParallelTestCase(DTensorTestBase): |
| 228 | + """Basic test case for tensor subclasses |
| 229 | + """ |
| 230 | + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] |
| 231 | + |
| 232 | + TENSOR_SUBCLASS = AffineQuantizedTensor |
| 233 | + QUANT_METHOD_FN = staticmethod(int8_weight_only) |
| 234 | + QUANT_METHOD_KWARGS = {} |
216 | 235 |
|
| 236 | + @staticmethod |
| 237 | + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: |
| 238 | + """ |
| 239 | + Shard linear layer of the model in column-wise fashion |
| 240 | + """ |
| 241 | + # Column-wise is w.r.t. A^T, so for A it is row-wise. |
| 242 | + # Number of rows per rank |
| 243 | + orig_weight = m.linear.weight |
| 244 | + n_local_rows = orig_weight.size(0) // mesh.size() |
| 245 | + rank = mesh.get_local_rank() |
| 246 | + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] |
| 247 | + # Construct DTensor from local shard |
| 248 | + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) |
| 249 | + # Replace parameter in module |
| 250 | + m.linear.weight = torch.nn.Parameter( |
| 251 | + dtensor, requires_grad=False |
| 252 | + ) |
| 253 | + return m |
| 254 | + |
| 255 | + @staticmethod |
| 256 | + def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: |
| 257 | + """ |
| 258 | + Shard linear layer of the model in row-wise fashion |
| 259 | + """ |
| 260 | + # Row-wise is w.r.t. A^T, so for A it is column-wise. |
| 261 | + # Number of columns per rank |
| 262 | + orig_weight = m.linear.weight |
| 263 | + n_local_cols = orig_weight.size(1) // mesh.size() |
| 264 | + rank = mesh.get_local_rank() |
| 265 | + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] |
| 266 | + # Construct DTensor from local shard |
| 267 | + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) |
| 268 | + # Replace parameter in module |
| 269 | + m.linear.weight = torch.nn.Parameter( |
| 270 | + dtensor, requires_grad=False |
| 271 | + ) |
| 272 | + return m |
| 273 | + |
| 274 | + def quantize(self, m: torch.nn.Module) -> torch.nn.Module: |
| 275 | + """ |
| 276 | + Quantize the model |
| 277 | + """ |
| 278 | + quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) |
| 279 | + return m |
| 280 | + |
| 281 | + @common_utils.parametrize("dtype", COMMON_DTYPES) |
| 282 | + @with_comms |
| 283 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 284 | + def test_tp(self, dtype): |
| 285 | + device = "cuda" |
| 286 | + # To make sure different ranks create the same module |
| 287 | + torch.manual_seed(5) |
| 288 | + |
| 289 | + class M(torch.nn.Module): |
| 290 | + def __init__(self, in_features, out_features, **kwargs) -> None: |
| 291 | + super().__init__(**kwargs) |
| 292 | + self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") |
| 293 | + |
| 294 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 295 | + return self.linear(x) |
| 296 | + |
| 297 | + # Get rank and device |
| 298 | + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") |
| 299 | + |
| 300 | + # Original model |
| 301 | + proj_up = M(1024, 2048).to(device).to(dtype) |
| 302 | + proj_dn = M(2048, 1024).to(device).to(dtype) |
| 303 | + example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) |
| 304 | + y = proj_dn(proj_up(example_input)) |
| 305 | + |
| 306 | + # Quantize the model |
| 307 | + up_quant = self.quantize(proj_up) |
| 308 | + dn_quant = self.quantize(proj_dn) |
| 309 | + y_q = dn_quant(up_quant(example_input)) |
| 310 | + |
| 311 | + mesh = self.build_device_mesh() |
| 312 | + # Shard the models |
| 313 | + up_dist = self.colwise_shard(up_quant, mesh) |
| 314 | + dn_dist = self.rowwise_shard(dn_quant, mesh) |
| 315 | + |
| 316 | + # We need to turn inputs into DTensor form as well -- just a format change |
| 317 | + input_dtensor = DTensor.from_local( |
| 318 | + example_input, mesh, [Replicate()] |
| 319 | + ) |
| 320 | + |
| 321 | + y_d = dn_dist(up_dist(input_dtensor)) |
| 322 | + |
| 323 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 324 | + # Need torch 2.5 to support compiled tensor parallelism |
| 325 | + return |
| 326 | + |
| 327 | + up_compiled = torch.compile(up_dist) |
| 328 | + y_up = up_compiled(input_dtensor) |
| 329 | + dn_compiled = torch.compile(dn_dist) |
| 330 | + y_dn = dn_compiled(y_up) |
217 | 331 |
|
218 | 332 | common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) |
219 | 333 | common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) |
| 334 | +common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase) |
220 | 335 |
|
221 | 336 | if __name__ == "__main__": |
222 | 337 | unittest.main() |
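Note on the sharding scheme exercised by test_tp: colwise_shard splits the up-projection weight along dim 0 and rowwise_shard splits the down-projection weight along dim 1, the standard column-/row-parallel pairing for two consecutive linear layers. Below is a minimal single-process sketch of that math only, not of the DTensor code in the diff; world_size and the tensor sizes are illustrative, and the final sum over partial outputs stands in for the all-reduce that the distributed run performs.

import torch

torch.manual_seed(0)
world_size = 2                     # simulated number of ranks (illustrative)
x = torch.randn(4, 8)              # [batch, in_features]
w_up = torch.randn(16, 8)          # up projection:   h = x @ w_up.T
w_dn = torch.randn(8, 16)          # down projection: y = h @ w_dn.T
reference = x @ w_up.T @ w_dn.T

# colwise_shard: "column-wise w.r.t. A^T" == split w_up along dim 0.
up_shards = w_up.chunk(world_size, dim=0)
# rowwise_shard: "row-wise w.r.t. A^T" == split w_dn along dim 1.
dn_shards = w_dn.chunk(world_size, dim=1)

# Each simulated rank holds one shard of each weight, sees only its slice of
# the intermediate activation, and produces a partial output; summing the
# partials plays the role of the all-reduce in the distributed run.
partials = [(x @ up.T) @ dn.T for up, dn in zip(up_shards, dn_shards)]
torch.testing.assert_close(torch.stack(partials).sum(dim=0), reference)

In test_tp itself this bookkeeping is delegated to DTensor: the weight shards are wrapped with DTensor.from_local and the input is marked Replicate(), so the sharding propagation and the required reduction happen inside DTensor dispatch rather than by hand.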