Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -898,23 +898,36 @@ platform(
],
)

cc_library(
name = "hlo_featurizer",
srcs = [
"hlo_featurizer/encoder.cc",
],
hdrs = [
"hlo_featurizer/encoder.h",
"hlo_featurizer/featurizer.h",
"hlo_featurizer/hlo_opcode.h",
],
linkstatic = True,
deps = [
"@com_google_absl//absl/strings",
"@xla//xla/hlo/ir:hlo",
],
alwayslink = True,
)

cc_library(
name = "ReactantExtraLib",
srcs = glob(
[
"*.cpp",
],
) + [
srcs = [
"API.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
# "@xla//xla/service/gpu:backend_configs.pb.cc",
# "@xla//xla:autotuning.pb.cc",
# "@xla//xla:autotune_results.pb.cc",
# "@xla//xla/service:buffer_assignment.pb.cc",
],
hdrs = glob([
"*.h",
]) + [
hdrs = [
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.h",
],
copts = [
Expand Down
5 changes: 0 additions & 5 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,6 @@ load(

cuda_json_init_repository()

load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)
load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
Expand Down
5 changes: 5 additions & 0 deletions deps/ReactantExtra/hlo_featurizer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# HLO Featurizer

This is based on the code from
<https://github.com/google-research-datasets/tpu_graphs/blob/main/tpu_graphs/process_data/xla/>
but adapted to the latest version of XLA.
Empty file.
1 change: 1 addition & 0 deletions deps/ReactantExtra/hlo_featurizer/encoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#pragma once
2 changes: 2 additions & 0 deletions deps/ReactantExtra/hlo_featurizer/featurizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#pragma once

76 changes: 76 additions & 0 deletions deps/ReactantExtra/hlo_featurizer/hlo_opcode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#pragma once

#include <string>

#include "absl/strings/string_view.h"

#include "xla/hlo/ir/hlo_opcode.h"

namespace reactant {
namespace ml_lib {

using namespace xla;

#define FEATURE_WINDOW 1 << 0
#define FEATURE_OP_NON_ZERO 1 << 1
#define FEATURE_OP_ZERO 1 << 2
#define FEATURE_MODULE_NON_ZERO 1 << 3

// Node features.
const constexpr uint16_t kMinimalNodeFeatureCount = 113;
const constexpr uint16_t kOpLevelNonZeroNodeFeatureCount = 27;
const constexpr uint16_t kModuleLevelNonZeroNodeFeatureCount = 29;

// Module features.
const constexpr uint16_t kWindowConfigFeatureCount = 24;

// Config features.
const constexpr uint16_t kFusionConfigFeatureCount = 1;
const constexpr uint16_t kLayoutConfigFeatureCount = 18;
const constexpr uint16_t kDotConfigFeatureCount = 3;
// make sure to compute feature ranges after module features are finalized

inline uint8_t GetIncludeFeatureBits(absl::string_view task) {
if (task == "op_window_cost") {
return FEATURE_OP_NON_ZERO | FEATURE_WINDOW;
}
if (task == "module_fusion_cost" || task == "module_layout_cost" ||
task == "module_dot_cost") {
return FEATURE_OP_NON_ZERO;
}
return 0;
}

inline uint16_t GetNodeFeatureCount(absl::string_view task) {
if (task == "op_window_cost") {
return kMinimalNodeFeatureCount + kOpLevelNonZeroNodeFeatureCount;
}
if (task == "module_fusion_cost" || task == "module_layout_cost" ||
task == "module_dot_cost") {
return kMinimalNodeFeatureCount + kOpLevelNonZeroNodeFeatureCount;
}
return kMinimalNodeFeatureCount;
}

inline uint16_t GetModuleFeatureCount(absl::string_view task) {
if (task == "op_window_cost") {
return kWindowConfigFeatureCount;
}
return 0;
}

inline uint16_t GetConfigFeatureCount(absl::string_view task) {
if (task == "module_fusion_cost") {
return kFusionConfigFeatureCount;
}
if (task == "module_layout_cost") {
return kLayoutConfigFeatureCount;
}
if (task == "module_dot_cost") {
return kDotConfigFeatureCount;
}
return 0;
}

} // namespace ml_lib
} // namespace reactant
Loading