Skip to content

Commit f72e32a

Browse files
cliveverghesecopybara-github
authored andcommitted
Add support for Viewing custom call text in Graph Viewer.
PiperOrigin-RevId: 831447585
1 parent bd2bf8a commit f72e32a

File tree

6 files changed

+45
-16
lines changed

6 files changed

+45
-16
lines changed

xprof/convert/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,11 @@ cc_library(
149149
"@org_xprof//plugin/xprof/protobuf:roofline_model_proto_cc",
150150
"@org_xprof//plugin/xprof/protobuf:tf_data_stats_proto_cc",
151151
"@org_xprof//plugin/xprof/protobuf:tf_stats_proto_cc",
152+
"@org_xprof//xprof/utils:custom_call_utils",
153+
"@org_xprof//xprof/utils:hlo_module_utils",
154+
"@org_xprof//xprof/utils:hlo_proto_to_module",
152155
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
156+
"@xla//xla/hlo/ir:hlo",
153157
"@xla//xla/tsl/platform:statusor",
154158
],
155159
alwayslink = 1,

xprof/convert/graph_viewer_processor.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ limitations under the License.
1515

1616
#include "xprof/convert/graph_viewer_processor.h"
1717

18+
#include <memory>
1819
#include <optional>
1920
#include <string>
2021

2122
#include "absl/log/log.h"
2223
#include "absl/status/status.h"
2324
#include "absl/strings/string_view.h"
25+
#include "xla/hlo/ir/hlo_instruction.h"
2426
#include "xla/tsl/platform/statusor.h"
2527
#include "tsl/profiler/protobuf/xplane.pb.h"
2628
#include "xprof/convert/hlo_proto_to_graph_view.h"
@@ -39,6 +41,9 @@ limitations under the License.
3941
#include "plugin/xprof/protobuf/roofline_model.pb.h"
4042
#include "plugin/xprof/protobuf/tf_data_stats.pb.h"
4143
#include "plugin/xprof/protobuf/tf_stats.pb.h"
44+
#include "xprof/utils/custom_call_utils.h"
45+
#include "xprof/utils/hlo_module_utils.h"
46+
#include "xprof/utils/hlo_proto_to_module.h"
4247

4348
namespace xprof {
4449

@@ -49,6 +54,7 @@ using ::tensorflow::profiler::GetHloProtoByModuleName;
4954
using ::tensorflow::profiler::GetParam;
5055
using ::tensorflow::profiler::GraphViewerParams;
5156
using ::tensorflow::profiler::kAdjacentNodes;
57+
using ::tensorflow::profiler::kCustomCallGraphTypeName;
5258
using ::tensorflow::profiler::kGraphTypeName;
5359
using ::tensorflow::profiler::ParseGraphViewerParams;
5460
using ::tensorflow::profiler::SessionSnapshot;
@@ -64,6 +70,16 @@ absl::StatusOr<std::string> ConvertHloProtoToGraphViewer(
6470
return ConvertHloProtoToGraph(hlo_proto, params.node_name,
6571
params.graph_width, params.render_options,
6672
params.format);
73+
} else if (params.type == kCustomCallGraphTypeName) {
74+
TF_ASSIGN_OR_RETURN(
75+
std::unique_ptr<xla::HloModule> hlo_module,
76+
tensorflow::profiler::ConvertHloProtoToModule(hlo_proto));
77+
const xla::HloInstruction* hlo_instruction =
78+
tensorflow::profiler::FindInstruction(*hlo_module, params.node_name);
79+
if (hlo_instruction == nullptr) {
80+
return absl::InvalidArgumentError("Hlo Instruction not found.");
81+
}
82+
return GetCustomCallText(*hlo_instruction);
6783
} else if (params.type == kAdjacentNodes) {
6884
return GetAdjacentNodes(hlo_proto, params.node_name);
6985
} else {

xprof/convert/hlo_proto_to_graph_view.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,10 @@ absl::StatusOr<GraphViewerParams> ParseGraphViewerParams(
300300
return InvalidArgument("Graph viewer must provide a type option.");
301301
}
302302
auto valid_types = {
303-
kGraphTypeName, kJsonTypeName, kProtoTypeName, kProtoTextTypeName,
304-
kShortTxtTypeName, kLongTxtTypeName, kAdjacentNodes,
303+
kGraphTypeName, kJsonTypeName,
304+
kProtoTypeName, kProtoTextTypeName,
305+
kShortTxtTypeName, kLongTxtTypeName,
306+
kAdjacentNodes, kCustomCallGraphTypeName,
305307
};
306308
if (std::find(valid_types.begin(), valid_types.end(), type.value()) ==
307309
valid_types.end()) {
@@ -334,6 +336,13 @@ absl::StatusOr<GraphViewerParams> ParseGraphViewerParams(
334336
params.node_name = node_name.value();
335337
}
336338
return params;
339+
} else if (type == kCustomCallGraphTypeName) {
340+
params.type = type.value();
341+
if (std::optional<std::string> node_name =
342+
GetParam<std::string>(options, "node_name")) {
343+
params.node_name = node_name.value();
344+
}
345+
return params;
337346
}
338347

339348
// For txt type.

xprof/convert/tool_options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace profiler {
2828

2929
// Default parameter constants for graph viewer.
3030
static constexpr char kGraphTypeName[] = "graph";
31+
static constexpr char kCustomCallGraphTypeName[] = "custom_call";
3132
static constexpr char kAdjacentNodes[] = "adj_nodes";
3233
static constexpr char kShortTxtTypeName[] = "short_txt";
3334
static constexpr char kLongTxtTypeName[] = "long_txt";

xprof/utils/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,15 +650,15 @@ cc_library(
650650
hdrs = ["custom_call_utils.h"],
651651
deps = [
652652
":backend_configs_proto_cc",
653-
"//third_party/llvm/llvm-project/mlir:IR",
654-
"//third_party/protobuf/json",
655-
"//util/task:status",
656653
"@com_google_absl//absl/status",
657654
"@com_google_absl//absl/status:statusor",
658-
"@com_google_protobuf//:json",
655+
"@llvm-project//mlir:IR",
656+
"@tsl//tsl/platform:protobuf",
659657
"@xla//xla/hlo/ir:hlo",
660658
"@xla//xla/pjrt:mlir_to_hlo",
661659
"@xla//xla/service/llvm_ir:llvm_util",
660+
"@xla//xla/tsl/platform:errors",
661+
"@xla//xla/tsl/platform:statusor",
662662
],
663663
)
664664

xprof/utils/custom_call_utils.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22

33
#include <string>
44

5-
#include "xprof/utils/backend_configs.pb.h"
65
#include "absl/status/status.h"
76
#include "absl/status/statusor.h"
8-
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinOps.h"
9-
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
10-
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OwningOpRef.h"
11-
#include "google/protobuf/json/json.h"
12-
#include "google/protobuf/json/json.h"
7+
#include "mlir/include/mlir/IR/BuiltinOps.h"
8+
#include "mlir/include/mlir/IR/MLIRContext.h"
9+
#include "mlir/include/mlir/IR/OwningOpRef.h"
1310
#include "xla/hlo/ir/hlo_instruction.h"
1411
#include "xla/pjrt/mlir_to_hlo.h"
1512
#include "xla/service/llvm_ir/llvm_util.h"
16-
#include "util/task/status_macros.h"
13+
#include "xla/tsl/platform/errors.h"
14+
#include "xla/tsl/platform/statusor.h"
15+
#include "xprof/utils/backend_configs.pb.h"
1716

1817
namespace xprof {
1918

@@ -22,18 +21,18 @@ absl::StatusOr<std::string> GetCustomCallText(
2221
if (!hlo_instruction.has_backend_config()) {
2322
return absl::NotFoundError("Backend config not found");
2423
}
25-
google::protobuf::json::ParseOptions options;
24+
tsl::protobuf::util::JsonParseOptions options;
2625
options.ignore_unknown_fields = true;
2726
BackendConfig config;
28-
RETURN_IF_ERROR(google::protobuf::util::JsonStringToMessage(
27+
TF_RETURN_IF_ERROR(tsl::protobuf::util::JsonStringToMessage(
2928
hlo_instruction.raw_backend_config_string(), &config, options));
3029
if (!config.has_custom_call_config()) {
3130
return absl::NotFoundError("Custom call config not found");
3231
}
3332
CustomCallConfig custom_call_config = config.custom_call_config();
3433
mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
3534
context.allowUnregisteredDialects(true);
36-
ASSIGN_OR_RETURN(
35+
TF_ASSIGN_OR_RETURN(
3736
mlir::OwningOpRef<mlir::ModuleOp> mlir_op,
3837
xla::ParseMlirModuleString(
3938
static_cast<std::string>(custom_call_config.body()), context));

0 commit comments

Comments
 (0)