@@ -15,12 +15,14 @@ limitations under the License.
1515
1616#include " xprof/convert/graph_viewer_processor.h"
1717
18+ #include < memory>
1819#include < optional>
1920#include < string>
2021
2122#include " absl/log/log.h"
2223#include " absl/status/status.h"
2324#include " absl/strings/string_view.h"
25+ #include " xla/hlo/ir/hlo_instruction.h"
2426#include " xla/tsl/platform/statusor.h"
2527#include " tsl/profiler/protobuf/xplane.pb.h"
2628#include " xprof/convert/hlo_proto_to_graph_view.h"
@@ -39,6 +41,9 @@ limitations under the License.
3941#include " plugin/xprof/protobuf/roofline_model.pb.h"
4042#include " plugin/xprof/protobuf/tf_data_stats.pb.h"
4143#include " plugin/xprof/protobuf/tf_stats.pb.h"
44+ #include " xprof/utils/custom_call_utils.h"
45+ #include " xprof/utils/hlo_module_utils.h"
46+ #include " xprof/utils/hlo_proto_to_module.h"
4247
4348namespace xprof {
4449
@@ -49,6 +54,7 @@ using ::tensorflow::profiler::GetHloProtoByModuleName;
4954using ::tensorflow::profiler::GetParam;
5055using ::tensorflow::profiler::GraphViewerParams;
5156using ::tensorflow::profiler::kAdjacentNodes ;
57+ using ::tensorflow::profiler::kCustomCallGraphTypeName ;
5258using ::tensorflow::profiler::kGraphTypeName ;
5359using ::tensorflow::profiler::ParseGraphViewerParams;
5460using ::tensorflow::profiler::SessionSnapshot;
@@ -64,6 +70,16 @@ absl::StatusOr<std::string> ConvertHloProtoToGraphViewer(
6470 return ConvertHloProtoToGraph (hlo_proto, params.node_name ,
6571 params.graph_width , params.render_options ,
6672 params.format );
73+ } else if (params.type == kCustomCallGraphTypeName ) {
74+ TF_ASSIGN_OR_RETURN (
75+ std::unique_ptr<xla::HloModule> hlo_module,
76+ tensorflow::profiler::ConvertHloProtoToModule (hlo_proto));
77+ const xla::HloInstruction* hlo_instruction =
78+ tensorflow::profiler::FindInstruction (*hlo_module, params.node_name );
79+ if (hlo_instruction == nullptr ) {
80+ return absl::InvalidArgumentError (" Hlo Instruction not found." );
81+ }
82+ return GetCustomCallText (*hlo_instruction);
6783 } else if (params.type == kAdjacentNodes ) {
6884 return GetAdjacentNodes (hlo_proto, params.node_name );
6985 } else {
0 commit comments