diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 7254ceb3bc..758aa1342a 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -898,13 +898,28 @@ platform( ], ) +cc_library( + name = "hlo_featurizer", + srcs = [ + "hlo_featurizer/encoder.cc", + ], + hdrs = [ + "hlo_featurizer/encoder.h", + "hlo_featurizer/featurizer.h", + "hlo_featurizer/hlo_opcode.h", + ], + linkstatic = True, + deps = [ + "@com_google_absl//absl/strings", + "@xla//xla/hlo/ir:hlo", + ], + alwayslink = True, +) + cc_library( name = "ReactantExtraLib", - srcs = glob( - [ - "*.cpp", - ], - ) + [ + srcs = [ + "API.cpp", "@enzyme_ad//src/enzyme_ad/jax:gpu.cc", "@enzyme_ad//src/enzyme_ad/jax:cpu.cc", # "@xla//xla/service/gpu:backend_configs.pb.cc", @@ -912,9 +927,7 @@ cc_library( # "@xla//xla:autotune_results.pb.cc", # "@xla//xla/service:buffer_assignment.pb.cc", ], - hdrs = glob([ - "*.h", - ]) + [ + hdrs = [ "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.h", ], copts = [ diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 90496d5ba6..544d81f95e 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -226,11 +226,6 @@ load( cuda_json_init_repository() -load( - "@cuda_redist_json//:distributions.bzl", - "CUDA_REDISTRIBUTIONS", - "CUDNN_REDISTRIBUTIONS", -) load( "@cuda_redist_json//:distributions.bzl", "CUDA_REDISTRIBUTIONS", diff --git a/deps/ReactantExtra/hlo_featurizer/README.md b/deps/ReactantExtra/hlo_featurizer/README.md new file mode 100644 index 0000000000..b81d5abd23 --- /dev/null +++ b/deps/ReactantExtra/hlo_featurizer/README.md @@ -0,0 +1,5 @@ +# HLO Featurizer + +This is based on the code from + +but adapted to the latest version of XLA. diff --git a/deps/ReactantExtra/hlo_featurizer/encoder.cc b/deps/ReactantExtra/hlo_featurizer/encoder.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deps/ReactantExtra/hlo_featurizer/encoder.h b/deps/ReactantExtra/hlo_featurizer/encoder.h new file mode 100644 index 0000000000..7b9637ef9c --- /dev/null +++ b/deps/ReactantExtra/hlo_featurizer/encoder.h @@ -0,0 +1 @@ +#pragma once \ No newline at end of file diff --git a/deps/ReactantExtra/hlo_featurizer/featurizer.h b/deps/ReactantExtra/hlo_featurizer/featurizer.h new file mode 100644 index 0000000000..3f59c932d3 --- /dev/null +++ b/deps/ReactantExtra/hlo_featurizer/featurizer.h @@ -0,0 +1,2 @@ +#pragma once + diff --git a/deps/ReactantExtra/hlo_featurizer/hlo_opcode.h b/deps/ReactantExtra/hlo_featurizer/hlo_opcode.h new file mode 100644 index 0000000000..4f0d3bdf5e --- /dev/null +++ b/deps/ReactantExtra/hlo_featurizer/hlo_opcode.h @@ -0,0 +1,76 @@ +#pragma once + +#include + +#include "absl/strings/string_view.h" + +#include "xla/hlo/ir/hlo_opcode.h" + +namespace reactant { +namespace ml_lib { + +using namespace xla; + +#define FEATURE_WINDOW 1 << 0 +#define FEATURE_OP_NON_ZERO 1 << 1 +#define FEATURE_OP_ZERO 1 << 2 +#define FEATURE_MODULE_NON_ZERO 1 << 3 + +// Node features. +const constexpr uint16_t kMinimalNodeFeatureCount = 113; +const constexpr uint16_t kOpLevelNonZeroNodeFeatureCount = 27; +const constexpr uint16_t kModuleLevelNonZeroNodeFeatureCount = 29; + +// Module features. +const constexpr uint16_t kWindowConfigFeatureCount = 24; + +// Config features. +const constexpr uint16_t kFusionConfigFeatureCount = 1; +const constexpr uint16_t kLayoutConfigFeatureCount = 18; +const constexpr uint16_t kDotConfigFeatureCount = 3; +// make sure to compute feature ranges after module features are finalized + +inline uint8_t GetIncludeFeatureBits(absl::string_view task) { + if (task == "op_window_cost") { + return FEATURE_OP_NON_ZERO | FEATURE_WINDOW; + } + if (task == "module_fusion_cost" || task == "module_layout_cost" || + task == "module_dot_cost") { + return FEATURE_OP_NON_ZERO; + } + return 0; +} + +inline uint16_t GetNodeFeatureCount(absl::string_view task) { + if (task == "op_window_cost") { + return kMinimalNodeFeatureCount + kOpLevelNonZeroNodeFeatureCount; + } + if (task == "module_fusion_cost" || task == "module_layout_cost" || + task == "module_dot_cost") { + return kMinimalNodeFeatureCount + kOpLevelNonZeroNodeFeatureCount; + } + return kMinimalNodeFeatureCount; +} + +inline uint16_t GetModuleFeatureCount(absl::string_view task) { + if (task == "op_window_cost") { + return kWindowConfigFeatureCount; + } + return 0; +} + +inline uint16_t GetConfigFeatureCount(absl::string_view task) { + if (task == "module_fusion_cost") { + return kFusionConfigFeatureCount; + } + if (task == "module_layout_cost") { + return kLayoutConfigFeatureCount; + } + if (task == "module_dot_cost") { + return kDotConfigFeatureCount; + } + return 0; +} + +} // namespace ml_lib +} // namespace reactant