Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions xprof/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ cc_library(
"@org_xprof//xprof/convert/trace_viewer:trace_events_to_json",
"@org_xprof//xprof/convert/trace_viewer:trace_options",
"@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility",
"@tsl//tsl/platform:path",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@xla//xla/tsl/platform:env",
"@xla//xla/tsl/platform:errors",
Expand All @@ -215,6 +216,33 @@ cc_library(
alwayslink = 1,
)

cc_test(
name = "streaming_trace_viewer_processor_test",
srcs = ["streaming_trace_viewer_processor_test.cc"],
deps = [
":repository",
":streaming_trace_viewer_processor",
":tool_options",
"//file/base:path",
"//file/util:temp_path",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@xla//xla/tsl/lib/core:status_test_util",
"@xla//xla/tsl/platform:env",
"@xla//xla/tsl/platform:errors",
"@xla//xla/tsl/platform:status",
"@xla//xla/tsl/platform:statusor",
"@xla//xla/tsl/platform:test",
"@xla//xla/tsl/profiler/utils:xplane_builder",
"@xla//xla/tsl/profiler/utils:xplane_schema",
"@xla//xla/tsl/profiler/utils:xplane_utils",
],
)

cc_library(
name = "inference_stats_processor",
srcs = ["inference_stats_processor.cc"],
Expand Down
174 changes: 174 additions & 0 deletions xprof/convert/streaming_trace_viewer_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
#include <cmath>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "google/protobuf/arena.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/file_system.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/profiler/utils/timespan.h"
#include "tsl/platform/path.h"
#include "tsl/profiler/protobuf/xplane.pb.h"
#include "xprof/convert/preprocess_single_host_xplane.h"
#include "xprof/convert/process_megascale_dcn.h"
Expand Down Expand Up @@ -196,6 +200,176 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
return absl::OkStatus();
}

absl::StatusOr<std::string> StreamingTraceViewerProcessor::Map(
const std::string& xspace_path) {
std::vector<std::string> xspace_paths = {xspace_path};
TF_ASSIGN_OR_RETURN(
SessionSnapshot session_snapshot,
SessionSnapshot::Create(xspace_paths, /*xspaces=*/std::nullopt));
// get xspace from session snapshot
std::string hostname = session_snapshot.GetHostname(0);
google::protobuf::Arena arena;
TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena));

return Map(session_snapshot, hostname, *xspace);
}

absl::StatusOr<std::string> StreamingTraceViewerProcessor::Map(
const SessionSnapshot& session_snapshot, const std::string& hostname,
const XSpace& xspace) {
XSpace temp_xspace = xspace;
tensorflow::profiler::PreprocessSingleHostXSpace(&temp_xspace,
/*step_grouping=*/true,
/*derived_timeline=*/true);
tensorflow::profiler::ProcessMegascaleDcn(&temp_xspace);

auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::TRACE_LEVELDB, hostname);
auto trace_events_metadata_sstable_path =
session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
hostname);
auto trace_events_prefix_trie_sstable_path =
session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::
TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
hostname);

if (!trace_events_sstable_path.has_value() ||
!trace_events_metadata_sstable_path.has_value() ||
!trace_events_prefix_trie_sstable_path.has_value()) {
return tsl::errors::Unimplemented(
"streaming trace viewer hasn't been supported in Cloud AI");
}

if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) {
TraceEventsContainer trace_container;
tensorflow::profiler::ConvertXSpaceToTraceEventsContainer(
hostname, temp_xspace, &trace_container);
std::unique_ptr<tsl::WritableFile> trace_events_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_sstable_path, &trace_events_file));
std::unique_ptr<tsl::WritableFile> trace_events_metadata_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_metadata_sstable_path, &trace_events_metadata_file));
std::unique_ptr<tsl::WritableFile> trace_events_prefix_trie_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_prefix_trie_sstable_path,
&trace_events_prefix_trie_file));
TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables(
std::move(trace_events_file), std::move(trace_events_metadata_file),
std::move(trace_events_prefix_trie_file)));
}
return *trace_events_sstable_path;
}

namespace {

absl::StatusOr<TraceEventsContainer> LoadTraceContainerForHost(
const SessionSnapshot& session_snapshot,
const std::string& trace_events_sstable_path,
const TraceViewOption& trace_option,
const tensorflow::profiler::TraceOptions& profiler_trace_options) {
absl::string_view filename = tsl::io::Basename(trace_events_sstable_path);
absl::ConsumeSuffix(&filename, ".SSTABLE");
std::string hostname = std::string(filename);

TraceEventsLevelDbFilePaths file_paths;
file_paths.trace_events_file_path = trace_events_sstable_path;
// These should exist as they were created in the Map phase.
auto metadata_path = session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
hostname);
auto trie_path = session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
hostname);
if (!metadata_path || !trie_path) {
return tsl::errors::Internal(
"Could not find metadata or trie file paths for host: ", hostname);
}
file_paths.trace_events_metadata_file_path = *metadata_path;
file_paths.trace_events_prefix_trie_file_path = *trie_path;

TraceEventsContainer trace_container;
if (!trace_option.event_name.empty()) {
TF_RETURN_IF_ERROR(trace_container.ReadFullEventFromLevelDbTable(
file_paths.trace_events_metadata_file_path,
file_paths.trace_events_file_path, trace_option.event_name,
static_cast<uint64_t>(std::round(trace_option.start_time_ms * 1E9)),
static_cast<uint64_t>(std::round(trace_option.duration_ms * 1E9)),
trace_option.unique_id));
} else if (!trace_option.search_prefix.empty()) { // Search Events Request
if (tsl::Env::Default()
->FileExists(file_paths.trace_events_prefix_trie_file_path)
.ok()) {
auto trace_events_filter =
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
TF_RETURN_IF_ERROR(trace_container.SearchInLevelDbTable(
file_paths, trace_option.search_prefix,
std::move(trace_events_filter)));
}
} else {
auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
tsl::profiler::MilliSpan(trace_option.start_time_ms,
trace_option.end_time_ms),
trace_option.resolution, profiler_trace_options);
// Trace smaller than threshold will be disabled from streaming.
constexpr int64_t kDisableStreamingThreshold = 500000;
auto trace_events_filter =
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable(
file_paths, std::move(trace_events_filter),
std::move(visibility_filter), kDisableStreamingThreshold));
}
return trace_container;
}

} // namespace

absl::Status StreamingTraceViewerProcessor::Reduce(
const SessionSnapshot& session_snapshot,
const std::vector<std::string>& map_output_files) {
if (map_output_files.empty()) {
return absl::InvalidArgumentError("map_output_files cannot be empty");
}

TF_ASSIGN_OR_RETURN(TraceViewOption trace_option,
GetTraceViewOption(options_));
tensorflow::profiler::TraceOptions profiler_trace_options =
TraceOptionsFromToolOptions(options_);

TraceEventsContainer merged_trace_container;

for (int i = 0; i < map_output_files.size(); ++i) {
const std::string& trace_events_sstable_path = map_output_files[i];
int host_id = i + 1;

TF_ASSIGN_OR_RETURN(
TraceEventsContainer trace_container,
LoadTraceContainerForHost(session_snapshot, trace_events_sstable_path,
trace_option, profiler_trace_options));

merged_trace_container.Merge(std::move(trace_container), host_id);
}

std::string trace_viewer_json;
JsonTraceOptions json_trace_options;

tensorflow::profiler::TraceDeviceType device_type =
tensorflow::profiler::TraceDeviceType::kUnknownDevice;
if (IsTpuTrace(merged_trace_container.trace())) {
device_type = TraceDeviceType::kTpu;
}
json_trace_options.details =
TraceOptionsToDetails(device_type, profiler_trace_options);
IOBufferAdapter adapter(&trace_viewer_json);
TraceEventsToJson<IOBufferAdapter, TraceEventsContainer, RawData>(
json_trace_options, merged_trace_container, &adapter);

SetOutput(trace_viewer_json, "application/json");
return absl::OkStatus();
}

// NOTE: We use "trace_viewer@" to distinguish from the non-streaming
// trace_viewer. The "@" suffix is used to indicate that this tool
// supports streaming.
Expand Down
22 changes: 14 additions & 8 deletions xprof/convert/streaming_trace_viewer_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,32 @@ namespace xprof {
class StreamingTraceViewerProcessor : public ProfileProcessor {
public:
explicit StreamingTraceViewerProcessor(
const tensorflow::profiler::ToolOptions&) {}
const tensorflow::profiler::ToolOptions& options)
: options_(options) {}

absl::Status ProcessSession(
const tensorflow::profiler::SessionSnapshot& session_snapshot,
const tensorflow::profiler::ToolOptions& options) final;

absl::StatusOr<std::string> Map(const std::string& xspace_path) override;

absl::StatusOr<std::string> Map(
const tensorflow::profiler::SessionSnapshot& session_snapshot,
const std::string& hostname,
const tensorflow::profiler::XSpace& xspace) override {
return absl::UnimplementedError(
"Map not implemented for StreamingTraceViewerProcessor");
}
const tensorflow::profiler::XSpace& xspace) override;

absl::Status Reduce(
const tensorflow::profiler::SessionSnapshot& session_snapshot,
const std::vector<std::string>& map_output_files) override {
return absl::UnimplementedError(
"Reduce not implemented for StreamingTraceViewerProcessor");
const std::vector<std::string>& map_output_files) override;

bool ShouldUseWorkerService(
const tensorflow::profiler::SessionSnapshot& session_snapshot,
const tensorflow::profiler::ToolOptions& options) const override {
return session_snapshot.XSpaceSize() > 1;
}

private:
tensorflow::profiler::ToolOptions options_;
};

} // namespace xprof
Expand Down
Loading