Skip to content

Commit 2bc40c4

Browse files
pavanbalajimeta-codesync[bot]
authored andcommitted
Add support for additional datatypes
Summary: We were missing some datatypes supported by NCCLX in TorchComms. Reviewed By: fduwjj, mlunar-meta Differential Revision: D86467916 fbshipit-source-id: 0d0684fecb732b9ce95e2a756a29822f935b217b
1 parent 72ef0df commit 2bc40c4

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,34 @@ namespace {
1515

1616
ncclDataType_t getNcclDataTypeInternal(const at::Tensor& tensor) {
1717
switch (tensor.scalar_type()) {
18+
case at::ScalarType::Byte:
19+
return ncclUint8;
20+
case at::ScalarType::Char:
21+
return ncclInt8;
22+
case at::ScalarType::Int:
23+
return ncclInt32;
24+
case at::ScalarType::Long:
25+
return ncclInt64;
26+
case at::ScalarType::Half:
27+
return ncclFloat16;
1828
case at::ScalarType::Float:
1929
return ncclFloat32;
2030
case at::ScalarType::Double:
2131
return ncclFloat64;
22-
case at::ScalarType::Half:
23-
return ncclFloat16;
32+
case at::ScalarType::Bool:
33+
return ncclUint8;
2434
case at::ScalarType::BFloat16:
2535
return ncclBfloat16;
26-
case at::ScalarType::Int:
27-
return ncclInt32;
28-
case at::ScalarType::Long:
29-
return ncclInt64;
30-
case at::ScalarType::Char:
31-
return ncclInt8;
32-
case at::ScalarType::Byte:
33-
return ncclUint8;
36+
case at::ScalarType::Float8_e5m2:
37+
return ncclFloat8e5m2;
38+
case at::ScalarType::Float8_e4m3fn:
39+
return ncclFloat8e4m3;
40+
case at::ScalarType::UInt32:
41+
return ncclUint32;
42+
case at::ScalarType::UInt64:
43+
return ncclUint64;
3444
default:
35-
throw std::runtime_error("Unsupported tensor data type for NCCL");
45+
throw std::runtime_error("Unsupported tensor data type for NCCLX");
3646
}
3747
}
3848

0 commit comments

Comments
 (0)