diff --git a/xprof/convert/BUILD b/xprof/convert/BUILD index 2a9ec624..dc35a99e 100644 --- a/xprof/convert/BUILD +++ b/xprof/convert/BUILD @@ -206,6 +206,7 @@ cc_library( "@org_xprof//xprof/convert/trace_viewer:trace_events_to_json", "@org_xprof//xprof/convert/trace_viewer:trace_options", "@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility", + "@tsl//tsl/platform:path", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@xla//xla/tsl/platform:env", "@xla//xla/tsl/platform:errors", @@ -215,6 +216,33 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "streaming_trace_viewer_processor_test", + srcs = ["streaming_trace_viewer_processor_test.cc"], + deps = [ + ":repository", + ":streaming_trace_viewer_processor", + ":tool_options", + "//file/base:path", + "//file/util:temp_path", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@xla//xla/tsl/lib/core:status_test_util", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform:test", + "@xla//xla/tsl/profiler/utils:xplane_builder", + "@xla//xla/tsl/profiler/utils:xplane_schema", + "@xla//xla/tsl/profiler/utils:xplane_utils", + ], +) + cc_library( name = "inference_stats_processor", srcs = ["inference_stats_processor.cc"], diff --git a/xprof/convert/streaming_trace_viewer_processor.cc b/xprof/convert/streaming_trace_viewer_processor.cc index ff422d29..f3d1e5ec 100644 --- a/xprof/convert/streaming_trace_viewer_processor.cc +++ b/xprof/convert/streaming_trace_viewer_processor.cc @@ -3,19 +3,23 @@ #include #include #include +#include #include #include +#include #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "google/protobuf/arena.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/file_system.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/profiler/utils/timespan.h" +#include "tsl/platform/path.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/convert/preprocess_single_host_xplane.h" #include "xprof/convert/process_megascale_dcn.h" @@ -196,6 +200,176 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession( return absl::OkStatus(); } +absl::StatusOr StreamingTraceViewerProcessor::Map( + const std::string& xspace_path) { + std::vector xspace_paths = {xspace_path}; + TF_ASSIGN_OR_RETURN( + SessionSnapshot session_snapshot, + SessionSnapshot::Create(xspace_paths, /*xspaces=*/std::nullopt)); + // get xspace from session snapshot + std::string hostname = session_snapshot.GetHostname(0); + google::protobuf::Arena arena; + TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena)); + + return Map(session_snapshot, hostname, *xspace); +} + +absl::StatusOr StreamingTraceViewerProcessor::Map( + const SessionSnapshot& session_snapshot, const std::string& hostname, + const XSpace& xspace) { + XSpace temp_xspace = xspace; + tensorflow::profiler::PreprocessSingleHostXSpace(&temp_xspace, + /*step_grouping=*/true, + /*derived_timeline=*/true); + tensorflow::profiler::ProcessMegascaleDcn(&temp_xspace); + + auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType::TRACE_LEVELDB, hostname); + auto trace_events_metadata_sstable_path = + session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB, + hostname); + auto trace_events_prefix_trie_sstable_path = + session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType:: + TRACE_EVENTS_PREFIX_TRIE_LEVELDB, + hostname); + + if (!trace_events_sstable_path.has_value() || + !trace_events_metadata_sstable_path.has_value() || + !trace_events_prefix_trie_sstable_path.has_value()) { + return tsl::errors::Unimplemented( + "streaming trace viewer hasn't been supported in Cloud AI"); + } + + if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) { + TraceEventsContainer trace_container; + tensorflow::profiler::ConvertXSpaceToTraceEventsContainer( + hostname, temp_xspace, &trace_container); + std::unique_ptr trace_events_file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( + *trace_events_sstable_path, &trace_events_file)); + std::unique_ptr trace_events_metadata_file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( + *trace_events_metadata_sstable_path, &trace_events_metadata_file)); + std::unique_ptr trace_events_prefix_trie_file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( + *trace_events_prefix_trie_sstable_path, + &trace_events_prefix_trie_file)); + TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables( + std::move(trace_events_file), std::move(trace_events_metadata_file), + std::move(trace_events_prefix_trie_file))); + } + return *trace_events_sstable_path; +} + +namespace { + +absl::StatusOr LoadTraceContainerForHost( + const SessionSnapshot& session_snapshot, + const std::string& trace_events_sstable_path, + const TraceViewOption& trace_option, + const tensorflow::profiler::TraceOptions& profiler_trace_options) { + absl::string_view filename = tsl::io::Basename(trace_events_sstable_path); + absl::ConsumeSuffix(&filename, ".SSTABLE"); + std::string hostname = std::string(filename); + + TraceEventsLevelDbFilePaths file_paths; + file_paths.trace_events_file_path = trace_events_sstable_path; + // These should exist as they were created in the Map phase. + auto metadata_path = session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB, + hostname); + auto trie_path = session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType::TRACE_EVENTS_PREFIX_TRIE_LEVELDB, + hostname); + if (!metadata_path || !trie_path) { + return tsl::errors::Internal( + "Could not find metadata or trie file paths for host: ", hostname); + } + file_paths.trace_events_metadata_file_path = *metadata_path; + file_paths.trace_events_prefix_trie_file_path = *trie_path; + + TraceEventsContainer trace_container; + if (!trace_option.event_name.empty()) { + TF_RETURN_IF_ERROR(trace_container.ReadFullEventFromLevelDbTable( + file_paths.trace_events_metadata_file_path, + file_paths.trace_events_file_path, trace_option.event_name, + static_cast(std::round(trace_option.start_time_ms * 1E9)), + static_cast(std::round(trace_option.duration_ms * 1E9)), + trace_option.unique_id)); + } else if (!trace_option.search_prefix.empty()) { // Search Events Request + if (tsl::Env::Default() + ->FileExists(file_paths.trace_events_prefix_trie_file_path) + .ok()) { + auto trace_events_filter = + CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); + TF_RETURN_IF_ERROR(trace_container.SearchInLevelDbTable( + file_paths, trace_option.search_prefix, + std::move(trace_events_filter))); + } + } else { + auto visibility_filter = std::make_unique( + tsl::profiler::MilliSpan(trace_option.start_time_ms, + trace_option.end_time_ms), + trace_option.resolution, profiler_trace_options); + // Trace smaller than threshold will be disabled from streaming. + constexpr int64_t kDisableStreamingThreshold = 500000; + auto trace_events_filter = + CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); + TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( + file_paths, std::move(trace_events_filter), + std::move(visibility_filter), kDisableStreamingThreshold)); + } + return trace_container; +} + +} // namespace + +absl::Status StreamingTraceViewerProcessor::Reduce( + const SessionSnapshot& session_snapshot, + const std::vector& map_output_files) { + if (map_output_files.empty()) { + return absl::InvalidArgumentError("map_output_files cannot be empty"); + } + + TF_ASSIGN_OR_RETURN(TraceViewOption trace_option, + GetTraceViewOption(options_)); + tensorflow::profiler::TraceOptions profiler_trace_options = + TraceOptionsFromToolOptions(options_); + + TraceEventsContainer merged_trace_container; + + for (int i = 0; i < map_output_files.size(); ++i) { + const std::string& trace_events_sstable_path = map_output_files[i]; + int host_id = i + 1; + + TF_ASSIGN_OR_RETURN( + TraceEventsContainer trace_container, + LoadTraceContainerForHost(session_snapshot, trace_events_sstable_path, + trace_option, profiler_trace_options)); + + merged_trace_container.Merge(std::move(trace_container), host_id); + } + + std::string trace_viewer_json; + JsonTraceOptions json_trace_options; + + tensorflow::profiler::TraceDeviceType device_type = + tensorflow::profiler::TraceDeviceType::kUnknownDevice; + if (IsTpuTrace(merged_trace_container.trace())) { + device_type = TraceDeviceType::kTpu; + } + json_trace_options.details = + TraceOptionsToDetails(device_type, profiler_trace_options); + IOBufferAdapter adapter(&trace_viewer_json); + TraceEventsToJson( + json_trace_options, merged_trace_container, &adapter); + + SetOutput(trace_viewer_json, "application/json"); + return absl::OkStatus(); +} + // NOTE: We use "trace_viewer@" to distinguish from the non-streaming // trace_viewer. The "@" suffix is used to indicate that this tool // supports streaming. diff --git a/xprof/convert/streaming_trace_viewer_processor.h b/xprof/convert/streaming_trace_viewer_processor.h index 9d7f3c77..dfb80136 100644 --- a/xprof/convert/streaming_trace_viewer_processor.h +++ b/xprof/convert/streaming_trace_viewer_processor.h @@ -16,26 +16,32 @@ namespace xprof { class StreamingTraceViewerProcessor : public ProfileProcessor { public: explicit StreamingTraceViewerProcessor( - const tensorflow::profiler::ToolOptions&) {} + const tensorflow::profiler::ToolOptions& options) + : options_(options) {} absl::Status ProcessSession( const tensorflow::profiler::SessionSnapshot& session_snapshot, const tensorflow::profiler::ToolOptions& options) final; + absl::StatusOr Map(const std::string& xspace_path) override; + absl::StatusOr Map( const tensorflow::profiler::SessionSnapshot& session_snapshot, const std::string& hostname, - const tensorflow::profiler::XSpace& xspace) override { - return absl::UnimplementedError( - "Map not implemented for StreamingTraceViewerProcessor"); - } + const tensorflow::profiler::XSpace& xspace) override; absl::Status Reduce( const tensorflow::profiler::SessionSnapshot& session_snapshot, - const std::vector& map_output_files) override { - return absl::UnimplementedError( - "Reduce not implemented for StreamingTraceViewerProcessor"); + const std::vector& map_output_files) override; + + bool ShouldUseWorkerService( + const tensorflow::profiler::SessionSnapshot& session_snapshot, + const tensorflow::profiler::ToolOptions& options) const override { + return session_snapshot.XSpaceSize() > 1; } + + private: + tensorflow::profiler::ToolOptions options_; }; } // namespace xprof diff --git a/xprof/convert/streaming_trace_viewer_processor_test.cc b/xprof/convert/streaming_trace_viewer_processor_test.cc new file mode 100644 index 00000000..5b90139d --- /dev/null +++ b/xprof/convert/streaming_trace_viewer_processor_test.cc @@ -0,0 +1,332 @@ +#include "xprof/convert/streaming_trace_viewer_processor.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "file/base/path.h" +#include "file/util/temp_path.h" +#include "testing/base/public/gmock.h" +#include "" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xprof/convert/repository.h" +#include "xprof/convert/tool_options.h" + +namespace xprof { +using ::tensorflow::profiler::GetParamWithDefault; +using ::tensorflow::profiler::SessionSnapshot; +using ::tensorflow::profiler::ToolOptions; +using ::tensorflow::profiler::XEvent; +using ::tensorflow::profiler::XEventMetadata; +using ::tensorflow::profiler::XLine; +using ::tensorflow::profiler::XPlane; +using ::tensorflow::profiler::XSpace; +using ::tensorflow::profiler::XStat; +using ::tensorflow::profiler::XStatMetadata; +using ::testing::HasSubstr; +using ::testing::status::StatusIs; +using ::tsl::profiler::kHostThreadsPlaneName; // Standard plane name + +// This struct and function are copied from the .cc file for test purposes. +struct TraceViewOption { + uint64_t resolution = 0; + double start_time_ms = 0.0; + double end_time_ms = 0.0; + std::string event_name = ""; + std::string search_prefix = ""; + double duration_ms = 0.0; + uint64_t unique_id = 0; +}; + +absl::StatusOr GetTraceViewOption(const ToolOptions& options) { + TraceViewOption trace_options; + auto start_time_ms_opt = + GetParamWithDefault(options, "start_time_ms", "0.0"); + auto end_time_ms_opt = + GetParamWithDefault(options, "end_time_ms", "0.0"); + auto resolution_opt = + GetParamWithDefault(options, "resolution", "0"); + trace_options.event_name = + GetParamWithDefault(options, "event_name", ""); + trace_options.search_prefix = + GetParamWithDefault(options, "search_prefix", ""); + auto duration_ms_opt = + GetParamWithDefault(options, "duration_ms", "0.0"); + auto unique_id_opt = + GetParamWithDefault(options, "unique_id", "0"); + + if (!absl::SimpleAtoi(resolution_opt, &trace_options.resolution) || + !absl::SimpleAtod(start_time_ms_opt, &trace_options.start_time_ms) || + !absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms) || + !absl::SimpleAtoi(unique_id_opt, &trace_options.unique_id) || + !absl::SimpleAtod(duration_ms_opt, &trace_options.duration_ms)) { + return tsl::errors::InvalidArgument("wrong arguments"); + } + return trace_options; +} + +// Helper function to create a simple XSpace for testing +XSpace CreateTestXSpace(int num_events) { + XSpace space; + XPlane* plane = space.add_planes(); + plane->set_name(kHostThreadsPlaneName); // Use standard plane name + + // Setup Event Metadata + int64_t event1_id = + static_cast(tsl::profiler::HostEventType::kTraceContext); + XEventMetadata& event1_metadata = + (*plane->mutable_event_metadata())[event1_id]; + event1_metadata.set_id(event1_id); + event1_metadata.set_name( + GetHostEventTypeStr(tsl::profiler::HostEventType::kTraceContext)); + + int64_t event2_id = + static_cast(tsl::profiler::HostEventType::kSessionRun); + XEventMetadata& event2_metadata = + (*plane->mutable_event_metadata())[event2_id]; + event2_metadata.set_id(event2_id); + event2_metadata.set_name( + GetHostEventTypeStr(tsl::profiler::HostEventType::kSessionRun)); + + // Setup Stat Metadata + const int64_t kGroupIdType = + static_cast(tsl::profiler::StatType::kGroupId); + XStatMetadata& group_id_metadata = + (*plane->mutable_stat_metadata())[kGroupIdType]; + group_id_metadata.set_id(kGroupIdType); + group_id_metadata.set_name(GetStatTypeStr(tsl::profiler::StatType::kGroupId)); + + XLine* line = plane->add_lines(); + line->set_id(1); + line->set_name("Test Line"); + + if (num_events > 0) { + XEvent* event = line->add_events(); + event->set_metadata_id(event1_id); + event->set_offset_ps(1000000000); + event->set_duration_ps(100000000); + XStat* stat = event->add_stats(); + stat->set_metadata_id(kGroupIdType); + stat->set_int64_value(123); + } + if (num_events > 1) { + XEvent* event2 = line->add_events(); + event2->set_metadata_id(event2_id); + event2->set_offset_ps(1200000000); + event2->set_duration_ps(50000000); + XStat* stat2 = event2->add_stats(); + stat2->set_metadata_id(kGroupIdType); + stat2->set_int64_value(456); + } + return space; +} + +class StreamingTraceViewerProcessorTest : public ::testing::Test { + protected: + void SetUp() override { + // Create a temporary directory for test files. + temp_path_ = std::make_unique(TempPath::Local); + session_dir_ = file::JoinPath(temp_path_->path(), "session"); + TF_CHECK_OK(tsl::Env::Default()->CreateDir(session_dir_)); + } + + // Helper to create a SessionSnapshot by writing XSpaces to temp files + absl::StatusOr CreateSnapshot( + const absl::flat_hash_map& host_xspaces) { + std::vector xspace_paths; + for (const auto& pair : host_xspaces) { + const std::string& host_name = pair.first; + const XSpace& xspace = pair.second; + std::string xspace_path = + file::JoinPath(session_dir_, host_name + ".xspace"); + TF_RETURN_IF_ERROR( + tsl::WriteBinaryProto(tsl::Env::Default(), xspace_path, xspace)); + xspace_paths.push_back(xspace_path); + } + std::sort(xspace_paths.begin(), xspace_paths.end()); + return SessionSnapshot::Create(std::move(xspace_paths), + /*xspaces=*/std::nullopt); + } + + std::string session_dir_; + std::unique_ptr temp_path_; +}; + +namespace { + +// GET_TRACE_VIEW_OPTION TESTS +TEST_F(StreamingTraceViewerProcessorTest, GetTraceViewOptionValid) { + ToolOptions options; + options["start_time_ms"] = "100.5"; + options["end_time_ms"] = "200.0"; + options["resolution"] = "1000"; + options["event_name"] = "test_event"; + options["search_prefix"] = "prefix"; + options["duration_ms"] = "10.0"; + options["unique_id"] = "12345"; + + TF_ASSERT_OK_AND_ASSIGN(TraceViewOption trace_option, + GetTraceViewOption(options)); + + EXPECT_DOUBLE_EQ(trace_option.start_time_ms, 100.5); + EXPECT_DOUBLE_EQ(trace_option.end_time_ms, 200.0); + EXPECT_EQ(trace_option.resolution, 1000); + EXPECT_EQ(trace_option.event_name, "test_event"); + EXPECT_EQ(trace_option.search_prefix, "prefix"); + EXPECT_DOUBLE_EQ(trace_option.duration_ms, 10.0); + EXPECT_EQ(trace_option.unique_id, 12345); +} + +TEST_F(StreamingTraceViewerProcessorTest, GetTraceViewOptionDefaults) { + ToolOptions options; + TF_ASSERT_OK_AND_ASSIGN(TraceViewOption trace_option, + GetTraceViewOption(options)); + + EXPECT_DOUBLE_EQ(trace_option.start_time_ms, 0.0); + EXPECT_DOUBLE_EQ(trace_option.end_time_ms, 0.0); + EXPECT_EQ(trace_option.resolution, 0); + EXPECT_EQ(trace_option.event_name, ""); + EXPECT_EQ(trace_option.search_prefix, ""); + EXPECT_DOUBLE_EQ(trace_option.duration_ms, 0.0); + EXPECT_EQ(trace_option.unique_id, 0); +} + +TEST_F(StreamingTraceViewerProcessorTest, GetTraceViewOptionInvalidNumber) { + ToolOptions options; + options["resolution"] = "not_a_number"; + EXPECT_THAT(GetTraceViewOption(options), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("wrong arguments"))); +} + +// MAP TESTS +TEST_F(StreamingTraceViewerProcessorTest, MapCreatesFiles) { + XSpace space = CreateTestXSpace(2); + absl::flat_hash_map host_xspaces = {{"host1", space}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + + ToolOptions empty_options; + StreamingTraceViewerProcessor processor(empty_options); + + TF_ASSERT_OK_AND_ASSIGN(std::string map_output, + processor.Map(snapshot, "host1", space)); +} + +// REDUCE TESTS +TEST_F(StreamingTraceViewerProcessorTest, ReduceEmptyMapOutput) { + absl::flat_hash_map host_xspaces = { + {"host1", CreateTestXSpace(0)}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + ToolOptions empty_options; + StreamingTraceViewerProcessor processor(empty_options); + EXPECT_THAT(processor.Reduce(snapshot, {}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map_output_files cannot be empty"))); +} + +TEST_F(StreamingTraceViewerProcessorTest, ReduceSingleHost) { + XSpace space = CreateTestXSpace(2); + absl::flat_hash_map host_xspaces = {{"host1", space}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + + ToolOptions tool_options; + tool_options["end_time_ms"] = "2000.0"; // 2s + tool_options["resolution"] = "1"; + StreamingTraceViewerProcessor processor(tool_options); + + TF_ASSERT_OK_AND_ASSIGN(std::string map_output, + processor.Map(snapshot, "host1", space)); + + TF_EXPECT_OK(processor.Reduce(snapshot, {map_output})); +} + +TEST_F(StreamingTraceViewerProcessorTest, ReduceMultiHost) { + XSpace space1 = CreateTestXSpace(1); + XSpace space2 = CreateTestXSpace(1); + absl::flat_hash_map host_xspaces = {{"host1", space1}, + {"host2", space2}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + + ToolOptions tool_options; + tool_options["end_time_ms"] = "2000.0"; + tool_options["resolution"] = "1"; + StreamingTraceViewerProcessor processor(tool_options); + + TF_ASSERT_OK_AND_ASSIGN(std::string map_output1, + processor.Map(snapshot, "host1", space1)); + TF_ASSERT_OK_AND_ASSIGN(std::string map_output2, + processor.Map(snapshot, "host2", space2)); + + TF_EXPECT_OK(processor.Reduce(snapshot, {map_output1, map_output2})); +} + +TEST_F(StreamingTraceViewerProcessorTest, ReduceWithEventNameSearch) { + XSpace space = CreateTestXSpace(2); + absl::flat_hash_map host_xspaces = {{"host1", space}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + + ToolOptions tool_options; + tool_options["event_name"] = "SessionRun"; // Updated name + tool_options["start_time_ms"] = "1100.0"; + tool_options["duration_ms"] = "100.0"; + tool_options["unique_id"] = "0"; + StreamingTraceViewerProcessor processor(tool_options); + + TF_ASSERT_OK_AND_ASSIGN(std::string map_output, + processor.Map(snapshot, "host1", space)); +} + +TEST_F(StreamingTraceViewerProcessorTest, ReduceWithSearchPrefix) { + XSpace space = CreateTestXSpace(2); + absl::flat_hash_map host_xspaces = {{"host1", space}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + + ToolOptions tool_options; + tool_options["search_prefix"] = "Sess"; // Updated prefix + StreamingTraceViewerProcessor processor(tool_options); + + TF_ASSERT_OK_AND_ASSIGN(std::string map_output, + processor.Map(snapshot, "host1", space)); + TF_EXPECT_OK(processor.Reduce(snapshot, {map_output})); +} + +// PROCESS SESSION TEST +TEST_F(StreamingTraceViewerProcessorTest, ProcessSessionEndToEnd) { + XSpace space1 = CreateTestXSpace(1); + XSpace space2 = CreateTestXSpace(1); + absl::flat_hash_map host_xspaces = {{"host1", space1}, + {"host2", space2}}; + TF_ASSERT_OK_AND_ASSIGN(SessionSnapshot snapshot, + CreateSnapshot(host_xspaces)); + + ToolOptions tool_options; + tool_options["end_time_ms"] = "2000.0"; + tool_options["resolution"] = "1"; + StreamingTraceViewerProcessor processor(tool_options); + + TF_EXPECT_OK(processor.ProcessSession(snapshot, tool_options)); +} + +} // namespace +} // namespace xprof