Skip to content

Commit bdd0901

Browse files
muditgokhale2copybara-github
authored andcommitted
Add support in trace_viewer to leverage the aggregator-worker setup
PiperOrigin-RevId: 837167925
1 parent 01633ff commit bdd0901

File tree

4 files changed

+548
-8
lines changed

4 files changed

+548
-8
lines changed

xprof/convert/BUILD

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ cc_library(
206206
"@org_xprof//xprof/convert/trace_viewer:trace_events_to_json",
207207
"@org_xprof//xprof/convert/trace_viewer:trace_options",
208208
"@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility",
209+
"@tsl//tsl/platform:path",
209210
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
210211
"@xla//xla/tsl/platform:env",
211212
"@xla//xla/tsl/platform:errors",
@@ -215,6 +216,33 @@ cc_library(
215216
alwayslink = 1,
216217
)
217218

219+
cc_test(
220+
name = "streaming_trace_viewer_processor_test",
221+
srcs = ["streaming_trace_viewer_processor_test.cc"],
222+
deps = [
223+
":repository",
224+
":streaming_trace_viewer_processor",
225+
":tool_options",
226+
"//file/base:path",
227+
"//file/util:temp_path",
228+
"@com_google_absl//absl/container:flat_hash_map",
229+
"@com_google_absl//absl/status",
230+
"@com_google_absl//absl/status:statusor",
231+
"@com_google_absl//absl/strings",
232+
"@com_google_absl//absl/strings:string_view",
233+
"@com_google_googletest//:gtest_main",
234+
"@xla//xla/tsl/lib/core:status_test_util",
235+
"@xla//xla/tsl/platform:env",
236+
"@xla//xla/tsl/platform:errors",
237+
"@xla//xla/tsl/platform:status",
238+
"@xla//xla/tsl/platform:statusor",
239+
"@xla//xla/tsl/platform:test",
240+
"@xla//xla/tsl/profiler/utils:xplane_builder",
241+
"@xla//xla/tsl/profiler/utils:xplane_schema",
242+
"@xla//xla/tsl/profiler/utils:xplane_utils",
243+
],
244+
)
245+
218246
cc_library(
219247
name = "inference_stats_processor",
220248
srcs = ["inference_stats_processor.cc"],

xprof/convert/streaming_trace_viewer_processor.cc

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33
#include <cmath>
44
#include <cstdint>
55
#include <memory>
6+
#include <optional>
67
#include <string>
78
#include <utility>
9+
#include <vector>
810

911
#include "absl/log/log.h"
1012
#include "absl/status/status.h"
1113
#include "absl/strings/numbers.h"
1214
#include "absl/strings/string_view.h"
15+
#include "absl/strings/strip.h"
1316
#include "google/protobuf/arena.h"
1417
#include "xla/tsl/platform/env.h"
1518
#include "xla/tsl/platform/errors.h"
1619
#include "xla/tsl/platform/file_system.h"
1720
#include "xla/tsl/platform/statusor.h"
1821
#include "xla/tsl/profiler/utils/timespan.h"
22+
#include "tsl/platform/path.h"
1923
#include "tsl/profiler/protobuf/xplane.pb.h"
2024
#include "xprof/convert/preprocess_single_host_xplane.h"
2125
#include "xprof/convert/process_megascale_dcn.h"
@@ -196,6 +200,176 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
196200
return absl::OkStatus();
197201
}
198202

203+
absl::StatusOr<std::string> StreamingTraceViewerProcessor::Map(
204+
const std::string& xspace_path) {
205+
std::vector<std::string> xspace_paths = {xspace_path};
206+
TF_ASSIGN_OR_RETURN(
207+
SessionSnapshot session_snapshot,
208+
SessionSnapshot::Create(xspace_paths, /*xspaces=*/std::nullopt));
209+
// get xspace from session snapshot
210+
std::string hostname = session_snapshot.GetHostname(0);
211+
google::protobuf::Arena arena;
212+
TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena));
213+
214+
return Map(session_snapshot, hostname, *xspace);
215+
}
216+
217+
absl::StatusOr<std::string> StreamingTraceViewerProcessor::Map(
218+
const SessionSnapshot& session_snapshot, const std::string& hostname,
219+
const XSpace& xspace) {
220+
XSpace temp_xspace = xspace;
221+
tensorflow::profiler::PreprocessSingleHostXSpace(&temp_xspace,
222+
/*step_grouping=*/true,
223+
/*derived_timeline=*/true);
224+
tensorflow::profiler::ProcessMegascaleDcn(&temp_xspace);
225+
226+
auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath(
227+
tensorflow::profiler::StoredDataType::TRACE_LEVELDB, hostname);
228+
auto trace_events_metadata_sstable_path =
229+
session_snapshot.MakeHostDataFilePath(
230+
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
231+
hostname);
232+
auto trace_events_prefix_trie_sstable_path =
233+
session_snapshot.MakeHostDataFilePath(
234+
tensorflow::profiler::StoredDataType::
235+
TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
236+
hostname);
237+
238+
if (!trace_events_sstable_path.has_value() ||
239+
!trace_events_metadata_sstable_path.has_value() ||
240+
!trace_events_prefix_trie_sstable_path.has_value()) {
241+
return tsl::errors::Unimplemented(
242+
"streaming trace viewer hasn't been supported in Cloud AI");
243+
}
244+
245+
if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) {
246+
TraceEventsContainer trace_container;
247+
tensorflow::profiler::ConvertXSpaceToTraceEventsContainer(
248+
hostname, temp_xspace, &trace_container);
249+
std::unique_ptr<tsl::WritableFile> trace_events_file;
250+
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
251+
*trace_events_sstable_path, &trace_events_file));
252+
std::unique_ptr<tsl::WritableFile> trace_events_metadata_file;
253+
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
254+
*trace_events_metadata_sstable_path, &trace_events_metadata_file));
255+
std::unique_ptr<tsl::WritableFile> trace_events_prefix_trie_file;
256+
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
257+
*trace_events_prefix_trie_sstable_path,
258+
&trace_events_prefix_trie_file));
259+
TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables(
260+
std::move(trace_events_file), std::move(trace_events_metadata_file),
261+
std::move(trace_events_prefix_trie_file)));
262+
}
263+
return *trace_events_sstable_path;
264+
}
265+
266+
namespace {
267+
268+
absl::StatusOr<TraceEventsContainer> LoadTraceContainerForHost(
269+
const SessionSnapshot& session_snapshot,
270+
const std::string& trace_events_sstable_path,
271+
const TraceViewOption& trace_option,
272+
const tensorflow::profiler::TraceOptions& profiler_trace_options) {
273+
absl::string_view filename = tsl::io::Basename(trace_events_sstable_path);
274+
absl::ConsumeSuffix(&filename, ".SSTABLE");
275+
std::string hostname = std::string(filename);
276+
277+
TraceEventsLevelDbFilePaths file_paths;
278+
file_paths.trace_events_file_path = trace_events_sstable_path;
279+
// These should exist as they were created in the Map phase.
280+
auto metadata_path = session_snapshot.MakeHostDataFilePath(
281+
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
282+
hostname);
283+
auto trie_path = session_snapshot.MakeHostDataFilePath(
284+
tensorflow::profiler::StoredDataType::TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
285+
hostname);
286+
if (!metadata_path || !trie_path) {
287+
return tsl::errors::Internal(
288+
"Could not find metadata or trie file paths for host: ", hostname);
289+
}
290+
file_paths.trace_events_metadata_file_path = *metadata_path;
291+
file_paths.trace_events_prefix_trie_file_path = *trie_path;
292+
293+
TraceEventsContainer trace_container;
294+
if (!trace_option.event_name.empty()) {
295+
TF_RETURN_IF_ERROR(trace_container.ReadFullEventFromLevelDbTable(
296+
file_paths.trace_events_metadata_file_path,
297+
file_paths.trace_events_file_path, trace_option.event_name,
298+
static_cast<uint64_t>(std::round(trace_option.start_time_ms * 1E9)),
299+
static_cast<uint64_t>(std::round(trace_option.duration_ms * 1E9)),
300+
trace_option.unique_id));
301+
} else if (!trace_option.search_prefix.empty()) { // Search Events Request
302+
if (tsl::Env::Default()
303+
->FileExists(file_paths.trace_events_prefix_trie_file_path)
304+
.ok()) {
305+
auto trace_events_filter =
306+
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
307+
TF_RETURN_IF_ERROR(trace_container.SearchInLevelDbTable(
308+
file_paths, trace_option.search_prefix,
309+
std::move(trace_events_filter)));
310+
}
311+
} else {
312+
auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
313+
tsl::profiler::MilliSpan(trace_option.start_time_ms,
314+
trace_option.end_time_ms),
315+
trace_option.resolution, profiler_trace_options);
316+
// Trace smaller than threshold will be disabled from streaming.
317+
constexpr int64_t kDisableStreamingThreshold = 500000;
318+
auto trace_events_filter =
319+
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
320+
TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable(
321+
file_paths, std::move(trace_events_filter),
322+
std::move(visibility_filter), kDisableStreamingThreshold));
323+
}
324+
return trace_container;
325+
}
326+
327+
} // namespace
328+
329+
absl::Status StreamingTraceViewerProcessor::Reduce(
330+
const SessionSnapshot& session_snapshot,
331+
const std::vector<std::string>& map_output_files) {
332+
if (map_output_files.empty()) {
333+
return absl::InvalidArgumentError("map_output_files cannot be empty");
334+
}
335+
336+
TF_ASSIGN_OR_RETURN(TraceViewOption trace_option,
337+
GetTraceViewOption(options_));
338+
tensorflow::profiler::TraceOptions profiler_trace_options =
339+
TraceOptionsFromToolOptions(options_);
340+
341+
TraceEventsContainer merged_trace_container;
342+
343+
for (int i = 0; i < map_output_files.size(); ++i) {
344+
const std::string& trace_events_sstable_path = map_output_files[i];
345+
int host_id = i + 1;
346+
347+
TF_ASSIGN_OR_RETURN(
348+
TraceEventsContainer trace_container,
349+
LoadTraceContainerForHost(session_snapshot, trace_events_sstable_path,
350+
trace_option, profiler_trace_options));
351+
352+
merged_trace_container.Merge(std::move(trace_container), host_id);
353+
}
354+
355+
std::string trace_viewer_json;
356+
JsonTraceOptions json_trace_options;
357+
358+
tensorflow::profiler::TraceDeviceType device_type =
359+
tensorflow::profiler::TraceDeviceType::kUnknownDevice;
360+
if (IsTpuTrace(merged_trace_container.trace())) {
361+
device_type = TraceDeviceType::kTpu;
362+
}
363+
json_trace_options.details =
364+
TraceOptionsToDetails(device_type, profiler_trace_options);
365+
IOBufferAdapter adapter(&trace_viewer_json);
366+
TraceEventsToJson<IOBufferAdapter, TraceEventsContainer, RawData>(
367+
json_trace_options, merged_trace_container, &adapter);
368+
369+
SetOutput(trace_viewer_json, "application/json");
370+
return absl::OkStatus();
371+
}
372+
199373
// NOTE: We use "trace_viewer@" to distinguish from the non-streaming
200374
// trace_viewer. The "@" suffix is used to indicate that this tool
201375
// supports streaming.

xprof/convert/streaming_trace_viewer_processor.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,32 @@ namespace xprof {
1616
class StreamingTraceViewerProcessor : public ProfileProcessor {
1717
public:
1818
explicit StreamingTraceViewerProcessor(
19-
const tensorflow::profiler::ToolOptions&) {}
19+
const tensorflow::profiler::ToolOptions& options)
20+
: options_(options) {}
2021

2122
absl::Status ProcessSession(
2223
const tensorflow::profiler::SessionSnapshot& session_snapshot,
2324
const tensorflow::profiler::ToolOptions& options) final;
2425

26+
absl::StatusOr<std::string> Map(const std::string& xspace_path) override;
27+
2528
absl::StatusOr<std::string> Map(
2629
const tensorflow::profiler::SessionSnapshot& session_snapshot,
2730
const std::string& hostname,
28-
const tensorflow::profiler::XSpace& xspace) override {
29-
return absl::UnimplementedError(
30-
"Map not implemented for StreamingTraceViewerProcessor");
31-
}
31+
const tensorflow::profiler::XSpace& xspace) override;
3232

3333
absl::Status Reduce(
3434
const tensorflow::profiler::SessionSnapshot& session_snapshot,
35-
const std::vector<std::string>& map_output_files) override {
36-
return absl::UnimplementedError(
37-
"Reduce not implemented for StreamingTraceViewerProcessor");
35+
const std::vector<std::string>& map_output_files) override;
36+
37+
bool ShouldUseWorkerService(
38+
const tensorflow::profiler::SessionSnapshot& session_snapshot,
39+
const tensorflow::profiler::ToolOptions& options) const override {
40+
return session_snapshot.XSpaceSize() > 1;
3841
}
42+
43+
private:
44+
tensorflow::profiler::ToolOptions options_;
3945
};
4046

4147
} // namespace xprof

0 commit comments

Comments
 (0)