diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 7254ceb3bc..758aa1342a 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -898,13 +898,28 @@ platform(
],
)
+cc_library(
+ name = "hlo_featurizer",
+ srcs = [
+ "hlo_featurizer/encoder.cc",
+ ],
+ hdrs = [
+ "hlo_featurizer/encoder.h",
+ "hlo_featurizer/featurizer.h",
+ "hlo_featurizer/hlo_opcode.h",
+ ],
+ linkstatic = True,
+ deps = [
+ "@com_google_absl//absl/strings",
+ "@xla//xla/hlo/ir:hlo",
+ ],
+ alwayslink = True,
+)
+
cc_library(
name = "ReactantExtraLib",
- srcs = glob(
- [
- "*.cpp",
- ],
- ) + [
+ srcs = [
+ "API.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
# "@xla//xla/service/gpu:backend_configs.pb.cc",
@@ -912,9 +927,7 @@ cc_library(
# "@xla//xla:autotune_results.pb.cc",
# "@xla//xla/service:buffer_assignment.pb.cc",
],
- hdrs = glob([
- "*.h",
- ]) + [
+ hdrs = [
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.h",
],
copts = [
diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 90496d5ba6..544d81f95e 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -226,11 +226,6 @@ load(
cuda_json_init_repository()
-load(
- "@cuda_redist_json//:distributions.bzl",
- "CUDA_REDISTRIBUTIONS",
- "CUDNN_REDISTRIBUTIONS",
-)
load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
diff --git a/deps/ReactantExtra/hlo_featurizer/README.md b/deps/ReactantExtra/hlo_featurizer/README.md
new file mode 100644
index 0000000000..b81d5abd23
--- /dev/null
+++ b/deps/ReactantExtra/hlo_featurizer/README.md
@@ -0,0 +1,5 @@
+# HLO Featurizer
+
+This is based on the code from
+
+but adapted to the latest version of XLA.
diff --git a/deps/ReactantExtra/hlo_featurizer/encoder.cc b/deps/ReactantExtra/hlo_featurizer/encoder.cc
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/deps/ReactantExtra/hlo_featurizer/encoder.h b/deps/ReactantExtra/hlo_featurizer/encoder.h
new file mode 100644
index 0000000000..7b9637ef9c
--- /dev/null
+++ b/deps/ReactantExtra/hlo_featurizer/encoder.h
@@ -0,0 +1 @@
+#pragma once
\ No newline at end of file
diff --git a/deps/ReactantExtra/hlo_featurizer/featurizer.h b/deps/ReactantExtra/hlo_featurizer/featurizer.h
new file mode 100644
index 0000000000..3f59c932d3
--- /dev/null
+++ b/deps/ReactantExtra/hlo_featurizer/featurizer.h
@@ -0,0 +1,2 @@
+#pragma once
+
diff --git a/deps/ReactantExtra/hlo_featurizer/hlo_opcode.h b/deps/ReactantExtra/hlo_featurizer/hlo_opcode.h
new file mode 100644
index 0000000000..4f0d3bdf5e
--- /dev/null
+++ b/deps/ReactantExtra/hlo_featurizer/hlo_opcode.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include
+
+#include "absl/strings/string_view.h"
+
+#include "xla/hlo/ir/hlo_opcode.h"
+
+namespace reactant {
+namespace ml_lib {
+
+using namespace xla;
+
+#define FEATURE_WINDOW 1 << 0
+#define FEATURE_OP_NON_ZERO 1 << 1
+#define FEATURE_OP_ZERO 1 << 2
+#define FEATURE_MODULE_NON_ZERO 1 << 3
+
+// Node features.
+const constexpr uint16_t kMinimalNodeFeatureCount = 113;
+const constexpr uint16_t kOpLevelNonZeroNodeFeatureCount = 27;
+const constexpr uint16_t kModuleLevelNonZeroNodeFeatureCount = 29;
+
+// Module features.
+const constexpr uint16_t kWindowConfigFeatureCount = 24;
+
+// Config features.
+const constexpr uint16_t kFusionConfigFeatureCount = 1;
+const constexpr uint16_t kLayoutConfigFeatureCount = 18;
+const constexpr uint16_t kDotConfigFeatureCount = 3;
+// make sure to compute feature ranges after module features are finalized
+
+inline uint8_t GetIncludeFeatureBits(absl::string_view task) {
+ if (task == "op_window_cost") {
+ return FEATURE_OP_NON_ZERO | FEATURE_WINDOW;
+ }
+ if (task == "module_fusion_cost" || task == "module_layout_cost" ||
+ task == "module_dot_cost") {
+ return FEATURE_OP_NON_ZERO;
+ }
+ return 0;
+}
+
+inline uint16_t GetNodeFeatureCount(absl::string_view task) {
+ if (task == "op_window_cost") {
+ return kMinimalNodeFeatureCount + kOpLevelNonZeroNodeFeatureCount;
+ }
+ if (task == "module_fusion_cost" || task == "module_layout_cost" ||
+ task == "module_dot_cost") {
+ return kMinimalNodeFeatureCount + kOpLevelNonZeroNodeFeatureCount;
+ }
+ return kMinimalNodeFeatureCount;
+}
+
+inline uint16_t GetModuleFeatureCount(absl::string_view task) {
+ if (task == "op_window_cost") {
+ return kWindowConfigFeatureCount;
+ }
+ return 0;
+}
+
+inline uint16_t GetConfigFeatureCount(absl::string_view task) {
+ if (task == "module_fusion_cost") {
+ return kFusionConfigFeatureCount;
+ }
+ if (task == "module_layout_cost") {
+ return kLayoutConfigFeatureCount;
+ }
+ if (task == "module_dot_cost") {
+ return kDotConfigFeatureCount;
+ }
+ return 0;
+}
+
+} // namespace ml_lib
+} // namespace reactant