2 changes: 2 additions & 0 deletions tests/cpp/test_common.cu
@@ -31,6 +31,7 @@ size_t create_seed_from_tensor_name(const std::string& tensor_name) {

std::vector<DType> all_fp_types = {DType::kFloat32,
DType::kFloat16,
DType::kFloat64,
DType::kBFloat16,
DType::kFloat8E5M2,
DType::kFloat8E4M3};
@@ -57,6 +58,7 @@ const std::string &typeName(DType type) {
{DType::kByte, "byte"},
{DType::kInt32, "int32"},
{DType::kInt64, "int64"},
{DType::kFloat64, "float64"},
{DType::kFloat32, "float32"},
{DType::kFloat16, "float16"},
{DType::kBFloat16, "bfloat16"},
5 changes: 3 additions & 2 deletions tests/cpp/test_common.h
@@ -55,6 +55,7 @@ using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp64 = double;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
@@ -84,9 +85,9 @@ struct BitsNumber {
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp64, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp64, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0>;
#endif

template <typename U, DType current>
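The TypeInfo lists above are what make a C++ type visible to the tests' compile-time dispatch; adding fp64 to the tuple is the whole change. A minimal standalone sketch of the tuple-walking idea (the trimmed type_list and the contains<> helper are hypothetical illustrations, not the TE helpers, which recurse with a <typename U, DType current> template as hinted in the diff):

#include <tuple>
#include <type_traits>

using fp32 = float;
using fp64 = double;

// Trimmed stand-in for TypeInfo::types; the real list in test_common.h is longer.
using type_list = std::tuple<int, fp32, fp64>;

// Hypothetical helper: true if T appears anywhere in the tuple.
template <typename T, typename Tuple>
struct contains;

template <typename T>
struct contains<T, std::tuple<>> : std::false_type {};

template <typename T, typename Head, typename... Tail>
struct contains<T, std::tuple<Head, Tail...>>
    : std::conditional_t<std::is_same_v<T, Head>, std::true_type,
                         contains<T, std::tuple<Tail...>>> {};

static_assert(contains<fp64, type_list>::value,
              "fp64 must be part of the compile-time type list");

int main() { return 0; }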
2 changes: 2 additions & 0 deletions transformer_engine/common/common.cu
@@ -33,6 +33,8 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
return CUDA_R_16F;
case DType::kFloat32:
return CUDA_R_32F;
case DType::kFloat64:
return CUDA_R_64F;
case DType::kBFloat16:
return CUDA_R_16BF;
case DType::kFloat8E4M3:
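get_cuda_dtype() bridges TE's runtime DType tag to the cudaDataType_t that CUDA math libraries expect, so fp64 support needs the CUDA_R_64F row added above. A standalone sketch of that mapping, with a hypothetical MiniDType stand-in (the real function lives in common.cu and covers every dtype):

#include <library_types.h>  // cudaDataType_t, CUDA_R_16F, CUDA_R_32F, CUDA_R_64F, CUDA_R_16BF
#include <stdexcept>

// Hypothetical, trimmed stand-in for transformer_engine::DType.
enum class MiniDType { kFloat16, kFloat32, kFloat64, kBFloat16 };

cudaDataType_t to_cuda_dtype(MiniDType t) {
  switch (t) {
    case MiniDType::kFloat16:  return CUDA_R_16F;
    case MiniDType::kFloat32:  return CUDA_R_32F;
    case MiniDType::kFloat64:  return CUDA_R_64F;   // the new 64-bit mapping
    case MiniDType::kBFloat16: return CUDA_R_16BF;
  }
  throw std::runtime_error("unsupported dtype");
}

int main() { return to_cuda_dtype(MiniDType::kFloat64) == CUDA_R_64F ? 0 : 1; }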
17 changes: 12 additions & 5 deletions transformer_engine/common/common.h
@@ -321,6 +321,7 @@ using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp64 = double;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
@@ -349,6 +350,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(double)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
@@ -421,14 +423,15 @@ struct BitsNumber {
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1
using types =
std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1, fp64
#if CUDA_VERSION >= 12080
,
fp8e8m0
,
fp8e8m0
#endif
>;
>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp64
#if CUDA_VERSION >= 12080
,
fp8e8m0
@@ -497,6 +500,10 @@ struct TypeInfo {
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat64: { \
using type = double; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
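The switch macros in common.h are how untemplated entry points turn a runtime DType into a concrete C++ type, which is why each macro needs the new kFloat64 -> double case. A self-contained sketch of the pattern (MY_TYPE_SWITCH and the two-value enum are hypothetical; the real macros follow the same case / using type / break shape shown above):

#include <cstdio>

enum class DType { kFloat32, kFloat64 };

// Hypothetical dispatch macro mirroring the shape of the TE switch macros.
#define MY_TYPE_SWITCH(dtype, type, ...)         \
  switch (dtype) {                               \
    case DType::kFloat32: {                      \
      using type = float;                        \
      { __VA_ARGS__ }                            \
    } break;                                     \
    case DType::kFloat64: {                      \
      using type = double; /* new fp64 branch */ \
      { __VA_ARGS__ }                            \
    } break;                                     \
  }

int main() {
  DType dt = DType::kFloat64;
  // The macro materializes `T` as the concrete C++ type for the runtime tag.
  MY_TYPE_SWITCH(dt, T, std::printf("sizeof(T) = %zu bytes\n", sizeof(T));)
  return 0;
}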
10 changes: 9 additions & 1 deletion transformer_engine/common/fused_router/utils.h
@@ -215,10 +215,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
}
}

// Current TE only support float32/bf16/fp16, float64 probs should be considered in the future
// Currently TE only supports float32/bf16/fp16/fp64
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat64: { \
using type = double; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
@@ -254,6 +258,10 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat64: { \
using type = double; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
@@ -33,6 +33,7 @@ enum NVTEDType {
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
kNVTEFloat64 = 11, /*!< 64-bit float */
kNVTENumTypes /*!< Number of supported types */
};

@@ -418,6 +419,7 @@ enum class DType {
kFloat8E5M2 = 8,
kFloat8E8M0 = 9,
kFloat4E2M1 = 10,
kFloat64 = 11,
kNumTypes
};

@@ -443,7 +445,8 @@ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
* \param[in] DType TE Datatype of interest
*/
inline bool is_high_precision_dtype(const DType t) {
return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
return t == DType::kFloat64 || t == DType::kFloat32 || t == DType::kBFloat16 ||
t == DType::kFloat16;
}

/*! \struct TensorWrapper
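Because NVTEDType (the C API) and DType (the C++ API) must stay numerically interchangeable, the new 64-bit entry gets the value 11 in both enums. A tiny standalone check of that invariant (the trimmed enums below are stand-ins, not the TE headers):

// Trimmed stand-ins, not the TE headers; the values copy the diff above.
enum NVTEDType { kNVTEFloat64 = 11 };
enum class DType { kFloat64 = 11 };

static_assert(static_cast<int>(DType::kFloat64) == static_cast<int>(kNVTEFloat64),
              "the C and C++ dtype enums must stay value-compatible");

int main() { return 0; }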
2 changes: 2 additions & 0 deletions transformer_engine/common/transformer_engine.cpp
@@ -39,6 +39,8 @@ std::string to_string(const DType type) {
return "Float16";
case DType::kFloat32:
return "Float32";
case DType::kFloat64:
return "Float64";
case DType::kFloat8E4M3:
return "Float8E4M3";
case DType::kFloat8E5M2:
1 change: 1 addition & 0 deletions transformer_engine/common/util/pybind_helper.h
@@ -18,6 +18,7 @@
pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
.value("kByte", transformer_engine::DType::kByte) \
.value("kInt32", transformer_engine::DType::kInt32) \
.value("kFloat64", transformer_engine::DType::kFloat64) \
.value("kFloat32", transformer_engine::DType::kFloat32) \
.value("kFloat16", transformer_engine::DType::kFloat16) \
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
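pybind_helper.h is what makes the enum visible from Python, so frameworks can pass the new value across the binding layer. A minimal, self-contained pybind11 module mirroring that registration (the module name demo_dtypes and the two-value enum are hypothetical):

#include <pybind11/pybind11.h>

// Hypothetical two-value enum standing in for transformer_engine::DType.
enum class DType { kFloat32, kFloat64 };

PYBIND11_MODULE(demo_dtypes, m) {
  pybind11::enum_<DType>(m, "DType", pybind11::module_local())
      .value("kFloat32", DType::kFloat32)
      .value("kFloat64", DType::kFloat64);  // the new value mirrored from the diff
}

From Python this would surface as demo_dtypes.DType.kFloat64, in the same way the real bindings expose transformer_engine::DType::kFloat64.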
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -347,6 +347,7 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline size_t typeToNumBits(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt64:
case transformer_engine::DType::kFloat64:
return 64;
case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32:
@@ -376,6 +377,8 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) {
return torch::kInt64;
case transformer_engine::DType::kFloat32:
return at::kFloat;
case transformer_engine::DType::kFloat64:
return at::kDouble;
case transformer_engine::DType::kFloat16:
return at::kHalf;
case transformer_engine::DType::kBFloat16:
@@ -401,6 +404,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kFloat16;
case at::kFloat:
return transformer_engine::DType::kFloat32;
case at::kDouble:
return transformer_engine::DType::kFloat64;
case at::kBFloat16:
return transformer_engine::DType::kBFloat16;
case at::kBool:
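The PyTorch glue needs both directions: GetATenDType() for TE -> torch and GetTransformerEngineDType() for torch -> TE, plus the 64-bit row in typeToNumBits(). A standalone sketch of the round trip using a hypothetical MiniDType (only the float/double cases are shown; the real functions cover every dtype):

#include <ATen/ATen.h>
#include <stdexcept>

// Hypothetical, trimmed stand-in for transformer_engine::DType.
enum class MiniDType { kFloat32, kFloat64 };

at::ScalarType to_aten(MiniDType t) {
  switch (t) {
    case MiniDType::kFloat32: return at::kFloat;
    case MiniDType::kFloat64: return at::kDouble;  // new mapping in this PR
  }
  throw std::runtime_error("unsupported dtype");
}

MiniDType from_aten(at::ScalarType t) {
  switch (t) {
    case at::kFloat:  return MiniDType::kFloat32;
    case at::kDouble: return MiniDType::kFloat64;  // new mapping in this PR
    default: throw std::runtime_error("unsupported dtype");
  }
}

int main() {
  // Round-trip check: an fp64 tensor keeps its scalar type through the mapping.
  at::Tensor x = at::zeros({4}, at::kDouble);
  return from_aten(x.scalar_type()) == MiniDType::kFloat64 ? 0 : 1;
}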