Skip to content

Commit 29f631d

Browse files
muditgokhale2 and copybara-github
authored and committed
Leverage trace viewer to use the aggregator-worker setup
PiperOrigin-RevId: 837167925
1 parent d7178ab commit 29f631d

File tree

3 files changed

+125
-69
lines changed

3 files changed

+125
-69
lines changed

xprof/convert/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ cc_library(
206206
"@org_xprof//xprof/convert/trace_viewer:trace_events_to_json",
207207
"@org_xprof//xprof/convert/trace_viewer:trace_options",
208208
"@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility",
209+
"@tsl//tsl/platform:path",
209210
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
210211
"@xla//xla/tsl/platform:env",
211212
"@xla//xla/tsl/platform:errors",

xprof/convert/streaming_trace_viewer_processor.cc

Lines changed: 110 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33
#include <cmath>
44
#include <cstdint>
55
#include <memory>
6+
#include <optional>
67
#include <string>
78
#include <utility>
9+
#include <vector>
810

911
#include "absl/log/log.h"
1012
#include "absl/status/status.h"
1113
#include "absl/strings/numbers.h"
1214
#include "absl/strings/string_view.h"
15+
#include "absl/strings/strip.h"
1316
#include "google/protobuf/arena.h"
1417
#include "xla/tsl/platform/env.h"
1518
#include "xla/tsl/platform/errors.h"
1619
#include "xla/tsl/platform/file_system.h"
1720
#include "xla/tsl/platform/statusor.h"
1821
#include "xla/tsl/profiler/utils/timespan.h"
22+
#include "tsl/platform/path.h"
1923
#include "tsl/profiler/protobuf/xplane.pb.h"
2024
#include "xprof/convert/preprocess_single_host_xplane.h"
2125
#include "xprof/convert/process_megascale_dcn.h"
@@ -83,84 +87,129 @@ absl::StatusOr<TraceViewOption> GetTraceViewOption(const ToolOptions& options) {
8387

8488
// Whole-session entry point. In the post-change code (lines marked `+`) the
// body does nothing except return an UnimplementedError whose message says
// ProcessSession is not used when ShouldUseWorkerService is true for
// "trace_viewer@". Neither `session_snapshot` nor `options` is read.
// The two lines marked `-` below are the start of the removed pre-change body.
absl::Status StreamingTraceViewerProcessor::ProcessSession(
8589
const SessionSnapshot& session_snapshot, const ToolOptions& options) {
86-
TraceEventsContainer merged_trace_container;
87-
std::string tool_name = "trace_viewer@";
90+
return absl::UnimplementedError(
91+
"ProcessSession is not used when ShouldUseWorkerService is true for "
92+
"trace_viewer@.");
93+
}
94+
95+
// Map-phase entry point for a single XSpace file.
//
// Wraps `xspace_path` in a one-element path list, builds a SessionSnapshot
// from it (passing std::nullopt for the pre-loaded xspaces argument), reads
// the hostname and XSpace at index 0, and forwards to the three-argument
// Map() overload.
//
// Returns the result of that overload. A failure from SessionSnapshot::Create
// or GetXSpace is returned early through TF_ASSIGN_OR_RETURN.
absl::StatusOr<std::string> StreamingTraceViewerProcessor::Map(
96+
const std::string& xspace_path) {
97+
std::vector<std::string> xspace_paths = {xspace_path};
98+
TF_ASSIGN_OR_RETURN(
99+
SessionSnapshot session_snapshot,
100+
SessionSnapshot::Create(xspace_paths, /*xspaces=*/std::nullopt));
101+
// Exactly one path was supplied above, so index 0 is the only entry to read.
// The XSpace is obtained through `arena` (presumably allocated on it — TODO
// confirm); `arena` stays alive until this function returns.
102+
std::string hostname = session_snapshot.GetHostname(0);
103+
google::protobuf::Arena arena;
104+
TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena));
105+
106+
return Map(session_snapshot, hostname, *xspace);
107+
}
108+
109+
// Map phase for one host: turns `xspace` into on-disk trace-event tables.
//
// Steps, in the order the code performs them:
//   1. Copy `xspace` into a local and run PreprocessSingleHostXSpace (with
//      step_grouping and derived_timeline both true) and ProcessMegascaleDcn
//      on the copy.
//   2. Ask `session_snapshot` for three per-host file paths: the main trace
//      events table, the events-metadata table, and the prefix-trie table.
//   3. If the main table file does not exist, convert the preprocessed copy
//      into a TraceEventsContainer and store it into the three files.
//
// Returns the path of the main trace events table. Returns Unimplemented when
// any of the three paths is unavailable, or the error from NewWritableFile /
// StoreAsLevelDbTables (propagated by TF_RETURN_IF_ERROR).
absl::StatusOr<std::string> StreamingTraceViewerProcessor::Map(
110+
const SessionSnapshot& session_snapshot, const std::string& hostname,
111+
const XSpace& xspace) {
112+
// `xspace` is const, and the two preprocessing calls below take a mutable
// pointer, so they operate on this local copy instead.
XSpace temp_xspace = xspace;
113+
tensorflow::profiler::PreprocessSingleHostXSpace(&temp_xspace,
114+
/*step_grouping=*/true,
115+
/*derived_timeline=*/true);
116+
tensorflow::profiler::ProcessMegascaleDcn(&temp_xspace);
117+
118+
auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath(
119+
tensorflow::profiler::StoredDataType::TRACE_LEVELDB, hostname);
120+
auto trace_events_metadata_sstable_path =
121+
session_snapshot.MakeHostDataFilePath(
122+
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
123+
hostname);
124+
auto trace_events_prefix_trie_sstable_path =
125+
session_snapshot.MakeHostDataFilePath(
126+
tensorflow::profiler::StoredDataType::
127+
TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
128+
hostname);
129+
130+
// Each path result is tested with `!` and later dereferenced with `*`
// (optional-like); a missing value for any of the three aborts the call.
if (!trace_events_sstable_path || !trace_events_metadata_sstable_path ||
131+
!trace_events_prefix_trie_sstable_path) {
132+
return tsl::errors::Unimplemented(
133+
"streaming trace viewer hasn't been supported in Cloud AI");
134+
}
135+
136+
// Only the main events table is probed. All three tables are written
// together when it is missing and none is touched when it is present.
// NOTE(review): the XSpace copy and both preprocessing calls above have
// already run by this point even when the file exists and nothing is
// written — consider moving them inside this branch.
if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) {
137+
TraceEventsContainer trace_container;
138+
tensorflow::profiler::ConvertXSpaceToTraceEventsContainer(
139+
hostname, temp_xspace, &trace_container);
140+
std::unique_ptr<tsl::WritableFile> trace_events_file;
141+
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
142+
*trace_events_sstable_path, &trace_events_file));
143+
std::unique_ptr<tsl::WritableFile> trace_events_metadata_file;
144+
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
145+
*trace_events_metadata_sstable_path, &trace_events_metadata_file));
146+
std::unique_ptr<tsl::WritableFile> trace_events_prefix_trie_file;
147+
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
148+
*trace_events_prefix_trie_sstable_path,
149+
&trace_events_prefix_trie_file));
150+
// Ownership of the three file handles is moved into the container call.
TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables(
151+
std::move(trace_events_file), std::move(trace_events_metadata_file),
152+
std::move(trace_events_prefix_trie_file)));
153+
}
154+
return *trace_events_sstable_path; // Return the path to the main LevelDB
155+
// file
156+
}
157+
158+
absl::Status StreamingTraceViewerProcessor::Reduce(
159+
const SessionSnapshot& session_snapshot,
160+
const std::vector<std::string>& map_output_files) {
161+
if (map_output_files.empty()) {
162+
return absl::InvalidArgumentError("map_output_files cannot be empty");
163+
}
88164

89165
TF_ASSIGN_OR_RETURN(TraceViewOption trace_option,
90-
GetTraceViewOption(options));
166+
GetTraceViewOption(options_));
91167
tensorflow::profiler::TraceOptions profiler_trace_options =
92-
TraceOptionsFromToolOptions(options);
93-
94-
// TODO: b/452217676 - Optimize this to process hosts in parallel.
95-
for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) {
96-
int host_id = i+1;
97-
google::protobuf::Arena arena;
98-
TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(i, &arena));
99-
PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true,
100-
/*derived_timeline=*/true);
101-
102-
std::string host_name = session_snapshot.GetHostname(i);
103-
auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath(
104-
tensorflow::profiler::StoredDataType::TRACE_LEVELDB, host_name);
105-
auto trace_events_metadata_sstable_path =
106-
session_snapshot.MakeHostDataFilePath(
107-
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
108-
host_name);
109-
auto trace_events_prefix_trie_sstable_path =
110-
session_snapshot.MakeHostDataFilePath(
111-
tensorflow::profiler::StoredDataType::
112-
TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
113-
host_name);
114-
if (!trace_events_sstable_path || !trace_events_metadata_sstable_path ||
115-
!trace_events_prefix_trie_sstable_path) {
116-
return tsl::errors::Unimplemented(
117-
"streaming trace viewer hasn't been supported in Cloud AI");
118-
}
119-
if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) {
120-
ProcessMegascaleDcn(xspace);
121-
TraceEventsContainer trace_container;
122-
ConvertXSpaceToTraceEventsContainer(host_name, *xspace,
123-
&trace_container);
124-
std::unique_ptr<tsl::WritableFile> trace_events_file;
125-
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
126-
*trace_events_sstable_path, &trace_events_file));
127-
std::unique_ptr<tsl::WritableFile> trace_events_metadata_file;
128-
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
129-
*trace_events_metadata_sstable_path, &trace_events_metadata_file));
130-
std::unique_ptr<tsl::WritableFile> trace_events_prefix_trie_file;
131-
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
132-
*trace_events_prefix_trie_sstable_path,
133-
&trace_events_prefix_trie_file));
134-
TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables(
135-
std::move(trace_events_file),
136-
std::move(trace_events_metadata_file),
137-
std::move(trace_events_prefix_trie_file)
138-
));
139-
}
168+
TraceOptionsFromToolOptions(options_);
169+
170+
TraceEventsContainer merged_trace_container;
171+
172+
for (int i = 0; i < map_output_files.size(); ++i) {
173+
const std::string& trace_events_sstable_path = map_output_files[i];
174+
int host_id = i + 1;
175+
176+
absl::string_view filename = tsl::io::Basename(trace_events_sstable_path);
177+
absl::ConsumeSuffix(&filename, ".SSTABLE");
178+
std::string hostname = std::string(filename);
140179

141180
TraceEventsLevelDbFilePaths file_paths;
142-
file_paths.trace_events_file_path = *trace_events_sstable_path;
143-
file_paths.trace_events_metadata_file_path =
144-
*trace_events_metadata_sstable_path;
145-
file_paths.trace_events_prefix_trie_file_path =
146-
*trace_events_prefix_trie_sstable_path;
181+
file_paths.trace_events_file_path = trace_events_sstable_path;
182+
// These should exist as they were created in the Map phase.
183+
auto metadata_path = session_snapshot.MakeHostDataFilePath(
184+
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
185+
hostname);
186+
auto trie_path = session_snapshot.MakeHostDataFilePath(
187+
tensorflow::profiler::StoredDataType::TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
188+
hostname);
189+
if (!metadata_path || !trie_path) {
190+
return tsl::errors::Internal(
191+
"Could not find metadata or trie file paths for host: ", hostname);
192+
}
193+
file_paths.trace_events_metadata_file_path = *metadata_path;
194+
file_paths.trace_events_prefix_trie_file_path = *trie_path;
147195

148196
TraceEventsContainer trace_container;
149197
if (!trace_option.event_name.empty()) {
150198
TF_RETURN_IF_ERROR(trace_container.ReadFullEventFromLevelDbTable(
151-
*trace_events_metadata_sstable_path, *trace_events_sstable_path,
152-
trace_option.event_name,
199+
file_paths.trace_events_metadata_file_path,
200+
file_paths.trace_events_file_path, trace_option.event_name,
153201
static_cast<uint64_t>(std::round(trace_option.start_time_ms * 1E9)),
154202
static_cast<uint64_t>(std::round(trace_option.duration_ms * 1E9)),
155203
trace_option.unique_id));
156204
} else if (!trace_option.search_prefix.empty()) { // Search Events Request
157205
if (tsl::Env::Default()
158-
->FileExists(*trace_events_prefix_trie_sstable_path).ok()) {
206+
->FileExists(file_paths.trace_events_prefix_trie_file_path)
207+
.ok()) {
159208
auto trace_events_filter =
160209
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
161210
TF_RETURN_IF_ERROR(trace_container.SearchInLevelDbTable(
162-
file_paths,
163-
trace_option.search_prefix, std::move(trace_events_filter)));
211+
file_paths, trace_option.search_prefix,
212+
std::move(trace_events_filter)));
164213
}
165214
} else {
166215
auto visibility_filter = std::make_unique<TraceVisibilityFilter>(

xprof/convert/streaming_trace_viewer_processor.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,32 @@ namespace xprof {
1616
// ProfileProcessor implementation for the streaming trace viewer.
//
// In the post-change code (lines marked `+`):
//   - the constructor keeps a copy of the ToolOptions it is given in
//     `options_`;
//   - ProcessSession, both Map overloads, and Reduce are only declared here
//     and defined out of line;
//   - ShouldUseWorkerService always returns true, ignoring both arguments.
// Lines marked `-` are the removed pre-change versions: a constructor that
// discarded its argument, and inline Map/Reduce bodies that returned
// UnimplementedError.
class StreamingTraceViewerProcessor : public ProfileProcessor {
1717
public:
1818
explicit StreamingTraceViewerProcessor(
19-
const tensorflow::profiler::ToolOptions&) {}
19+
const tensorflow::profiler::ToolOptions& options)
20+
: options_(options) {} // Copy the options into the options_ member.
2021

2122
absl::Status ProcessSession(
2223
const tensorflow::profiler::SessionSnapshot& session_snapshot,
2324
const tensorflow::profiler::ToolOptions& options) final;
2425

26+
// Map overload taking the path of a single XSpace file.
absl::StatusOr<std::string> Map(const std::string& xspace_path) override;
27+
2528
// Map overload taking an already-loaded XSpace for the named host.
absl::StatusOr<std::string> Map(
2629
const tensorflow::profiler::SessionSnapshot& session_snapshot,
2730
const std::string& hostname,
28-
const tensorflow::profiler::XSpace& xspace) override {
29-
return absl::UnimplementedError(
30-
"Map not implemented for StreamingTraceViewerProcessor");
31-
}
31+
const tensorflow::profiler::XSpace& xspace) override;
3232

3333
// Reduce step; receives the list of files produced by the Map calls.
absl::Status Reduce(
3434
const tensorflow::profiler::SessionSnapshot& session_snapshot,
35-
const std::vector<std::string>& map_output_files) override {
36-
return absl::UnimplementedError(
37-
"Reduce not implemented for StreamingTraceViewerProcessor");
35+
const std::vector<std::string>& map_output_files) override;
36+
37+
// Unconditionally opts in to the worker service; both parameters are unused.
bool ShouldUseWorkerService(
38+
const tensorflow::profiler::SessionSnapshot& session_snapshot,
39+
const tensorflow::profiler::ToolOptions& options) const override {
40+
return true;
3841
}
42+
43+
private:
44+
// Copy of the options supplied to the constructor.
tensorflow::profiler::ToolOptions options_;
3945
};
4046

4147
} // namespace xprof

0 commit comments

Comments
 (0)