Skip to content

Commit c886d9e

Browse files
Matt-Hurd and copybara-github
authored and committed
Add all_hosts information to the session_snapshot and move the device collision logic for trace_viewer to CreateTraceEventsContainer.
PiperOrigin-RevId: 831997964
1 parent 58941a6 commit c886d9e

19 files changed

+99
-215
lines changed

plugin/xprof/convert/raw_to_tool_data.py

Lines changed: 22 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -41,13 +41,12 @@ def process_raw_trace(raw_trace):
4141
return ''.join(trace_events_json.TraceEventsJsonStream(trace))
4242

4343

44-
def xspace_to_tools_data_from_byte_string(xspace_byte_list, all_hosts,
45-
filenames, tool, params):
44+
def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
45+
params):
4646
"""Helper function for getting an XSpace tool from a bytes string.
4747
4848
Args:
4949
xspace_byte_list: A list of byte strings read from a XSpace proto file.
50-
all_hosts: A list of all hosts in the session.
5150
filenames: Names of the read files.
5251
tool: A string of tool name.
5352
params: user input parameters.
@@ -58,7 +57,7 @@ def xspace_to_tools_data_from_byte_string(xspace_byte_list, all_hosts,
5857
# pylint:disable=dangerous-default-value
5958
def xspace_wrapper_func(xspace_arg, tool_arg, params={}):
6059
return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
61-
xspace_arg, all_hosts, filenames, tool_arg, params)
60+
xspace_arg, filenames, tool_arg, params)
6261
# pylint:enable=dangerous-default-value
6362

6463
return xspace_to_tool_data(xspace_byte_list, tool, params,
@@ -74,26 +73,22 @@ def xspace_to_tool_names(xspace_paths):
7473
Returns:
7574
Returns a list of tool names.
7675
"""
77-
# xspace_to_tools_data expects all_hosts as the second argument, passing an
78-
# empty list.
7976
raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data(
80-
xspace_paths, [], 'tool_names', {})
77+
xspace_paths, 'tool_names')
8178
if success:
8279
return [tool for tool in raw_data.decode().split(',')]
8380
return []
8481

8582

8683
def xspace_to_tool_data(
8784
xspace_paths,
88-
all_hosts,
8985
tool,
9086
params,
9187
xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
9288
"""Converts XSpace to tool data string.
9389
9490
Args:
9591
xspace_paths: A list of XSpace paths.
96-
all_hosts: A list of all hosts in the session.
9792
tool: A string of tool name.
9893
params: user input parameters.
9994
xspace_wrapper_func: A callable that takes a list of strings and a tool and
@@ -117,31 +112,26 @@ def xspace_to_tool_data(
117112
if tool == 'trace_viewer':
118113
# Trace viewer handles one host at a time.
119114
assert len(xspace_paths) == 1
120-
raw_data, success = xspace_wrapper_func(
121-
xspace_paths, all_hosts, tool, options)
115+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
122116
if success:
123117
data = process_raw_trace(raw_data)
124118
elif tool == 'trace_viewer@':
125119
options = params.get('trace_viewer_options', {})
126120
options['use_saved_result'] = params.get('use_saved_result', True)
127-
options['hosts'] = all_hosts
128-
raw_data, success = xspace_wrapper_func(
129-
xspace_paths, all_hosts, tool, options)
121+
options['hosts'] = params.get('hosts', [])
122+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
130123
if success:
131124
data = raw_data
132125
elif tool == 'overview_page':
133-
json_data, success = xspace_wrapper_func(
134-
xspace_paths, all_hosts, tool, options)
126+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
135127
if success:
136128
data = json_data
137129
elif tool == 'input_pipeline_analyzer':
138-
json_data, success = xspace_wrapper_func(
139-
xspace_paths, all_hosts, tool, options)
130+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
140131
if success:
141132
data = json_data
142133
elif tool == 'framework_op_stats':
143-
json_data, success = xspace_wrapper_func(
144-
xspace_paths, all_hosts, tool, options)
134+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
145135
if success:
146136
if tqx == 'out:csv':
147137
data = csv_writer.json_to_csv(json_data)
@@ -152,16 +142,15 @@ def xspace_to_tool_data(
152142
# TODO(b/419013992): Remove this tool completely as it has been deprecated
153143
legacy_tool = 'tensorflow_stats'
154144
json_data, success = xspace_wrapper_func(
155-
xspace_paths, all_hosts, legacy_tool, options
145+
xspace_paths, legacy_tool, options
156146
)
157147
if success:
158148
if tqx == 'out:csv':
159149
data = csv_writer.json_to_csv(json_data)
160150
else:
161151
data = json_data
162152
elif tool == 'kernel_stats':
163-
json_data, success = xspace_wrapper_func(
164-
xspace_paths, all_hosts, tool, options)
153+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
165154
if success:
166155
if tqx == 'out:csv':
167156
data = csv_writer.json_to_csv(json_data)
@@ -170,44 +159,37 @@ def xspace_to_tool_data(
170159
elif tool == 'memory_profile':
171160
# Memory profile handles one host at a time.
172161
assert len(xspace_paths) == 1
173-
raw_data, success = xspace_wrapper_func(
174-
xspace_paths, all_hosts, tool, options)
162+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
175163
if success:
176164
data = raw_data
177165
elif tool == 'pod_viewer':
178-
raw_data, success = xspace_wrapper_func(
179-
xspace_paths, all_hosts, tool, options)
166+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
180167
if success:
181168
data = raw_data
182169
elif tool == 'op_profile':
183170
options['group_by'] = params.get('group_by', 'program')
184-
raw_data, success = xspace_wrapper_func(
185-
xspace_paths, all_hosts, tool, options)
171+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
186172
if success:
187173
data = raw_data
188174
elif tool == 'hlo_op_profile':
189175
options['group_by'] = params.get('group_by', 'program')
190-
raw_data, success = xspace_wrapper_func(
191-
xspace_paths, all_hosts, tool, options)
176+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
192177
if success:
193178
data = raw_data
194179
elif tool == 'hlo_stats':
195-
json_data, success = xspace_wrapper_func(
196-
xspace_paths, all_hosts, tool, options)
180+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
197181
if success:
198182
data = json_data
199183
elif tool == 'roofline_model':
200-
json_data, success = xspace_wrapper_func(
201-
xspace_paths, all_hosts, tool, options)
184+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
202185
if success:
203186
data = json_data
204187
elif tool == 'graph_viewer':
205188
download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
206189
graph_html_type = 'graph'
207190
options = params.get('graph_viewer_options', {})
208191
options['use_saved_result'] = params.get('use_saved_result', True)
209-
raw_data, success = xspace_wrapper_func(
210-
xspace_paths, all_hosts, tool, options)
192+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
211193
if success:
212194
data = raw_data
213195
content_type = 'text/plain'
@@ -231,21 +213,18 @@ def xspace_to_tool_data(
231213
'view_memory_allocation_timeline': view_memory_allocation_timeline,
232214
'memory_space': params.get('memory_space', ''),
233215
}
234-
raw_data, success = xspace_wrapper_func(
235-
xspace_paths, all_hosts, tool, options)
216+
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
236217
if success:
237218
data = raw_data
238219
if view_memory_allocation_timeline:
239220
content_type = 'text/html'
240221
elif tool == 'megascale_stats':
241222
options = {'host_name': params.get('host')}
242-
json_data, success = xspace_wrapper_func(
243-
xspace_paths, all_hosts, tool, options)
223+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
244224
if success:
245225
data = json_data
246226
elif tool == 'inference_profile':
247-
json_data, success = xspace_wrapper_func(
248-
xspace_paths, all_hosts, tool, options)
227+
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
249228
if success:
250229
data = json_data
251230
else:

plugin/xprof/convert/raw_to_tool_data_test.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -27,11 +27,7 @@ def test_using_old_tool_format_maps_to_new_format(self):
2727
xspace_paths=["/path/to/xspace"],
2828
tool="trace_viewer@^",
2929
params={},
30-
all_hosts=[],
31-
xspace_wrapper_func=lambda paths, hosts, tool, options: (
32-
tool.encode(),
33-
True,
34-
),
30+
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
3531
)
3632

3733
self.assertEqual(data, b"trace_viewer@")
@@ -42,11 +38,7 @@ def test_using_new_tool_format_does_not_map_to_old_format(self):
4238
xspace_paths=["/path/to/xspace"],
4339
tool="trace_viewer@",
4440
params={},
45-
all_hosts=[],
46-
xspace_wrapper_func=lambda paths, hosts, tool, options: (
47-
tool.encode(),
48-
True,
49-
),
41+
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
5042
)
5143

5244
self.assertEqual(data, b"trace_viewer@")

plugin/xprof/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ def test_tools_are_in_list(self):
114114

115115
def test_overview_page(self):
116116
xspace_filenames = self._get_session_snapshot()
117-
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, [],
117+
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames,
118118
'overview_page', {})
119119
result = json.loads(result)
120120
run_environment = result[2]
@@ -123,9 +123,7 @@ def test_overview_page(self):
123123

124124
def test_overview_page_creates_cache(self):
125125
xspace_filenames = self._get_session_snapshot()
126-
raw_to_tool_data.xspace_to_tool_data(
127-
xspace_filenames, [], 'overview_page', {}
128-
)
126+
raw_to_tool_data.xspace_to_tool_data(xspace_filenames, 'overview_page', {})
129127
profile_plugin_root = os.path.join(log_dir, 'plugins/profile')
130128
# The session exists under a directory whose name is time-dependent.
131129
cache_glob = os.path.join(profile_plugin_root, '*', '*.op_stats.pb')
@@ -134,7 +132,7 @@ def test_overview_page_creates_cache(self):
134132
def test_op_profile(self):
135133
xspace_filenames = self._get_session_snapshot()
136134
result, _ = raw_to_tool_data.xspace_to_tool_data(
137-
xspace_filenames, [], 'op_profile', {'group_by': 'category'}
135+
xspace_filenames, 'op_profile', {'group_by': 'category'}
138136
)
139137
result = json.loads(result)
140138
logging.info(result)
@@ -153,7 +151,7 @@ def test_op_profile(self):
153151
def test_device_trace_contains_threads(self):
154152
xspace_filenames = self._get_session_snapshot()
155153
result, _ = raw_to_tool_data.xspace_to_tool_data(
156-
xspace_filenames, [], 'trace_viewer', {}
154+
xspace_filenames, 'trace_viewer', {}
157155
)
158156
result = json.loads(result)
159157
thread_names = []

plugin/xprof/profile_plugin.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -709,7 +709,7 @@ def hlo_module_list_route(
709709

710710
def _get_valid_hosts(
711711
self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
712-
) -> tuple[List[str], List[epath.Path], List[str]]:
712+
) -> tuple[List[str], List[epath.Path]]:
713713
"""Retrieves and validates the hosts and asset paths for a run and tool.
714714
715715
Args:
@@ -720,7 +720,7 @@ def _get_valid_hosts(
720720
host: The single host parameter.
721721
722722
Returns:
723-
A tuple containing (selected_hosts, asset_paths, all_hosts).
723+
A tuple containing (selected_hosts, asset_paths).
724724
725725
Raises:
726726
FileNotFoundError: If a required xplane file for the specified host(s)
@@ -786,9 +786,7 @@ def _get_valid_hosts(
786786
'Host must be specified for tool %s in run %s' % (tool, run)
787787
)
788788

789-
all_hosts = list(all_xplane_files.keys())
790-
791-
return selected_hosts, asset_paths, all_hosts
789+
return selected_hosts, asset_paths
792790

793791
def data_impl(
794792
self, request: wrappers.Request
@@ -877,7 +875,7 @@ def data_impl(
877875

878876
_, content_encoding = None, None
879877
if use_xplane(tool):
880-
selected_hosts, asset_paths, all_hosts = self._get_valid_hosts(
878+
selected_hosts, asset_paths = self._get_valid_hosts(
881879
run_dir, run, tool, hosts_param, host
882880
)
883881
if not asset_paths:
@@ -886,7 +884,7 @@ def data_impl(
886884
params['hosts'] = selected_hosts
887885
try:
888886
data, content_type = convert.xspace_to_tool_data(
889-
asset_paths, all_hosts, tool, params)
887+
asset_paths, tool, params)
890888
except AttributeError as e:
891889
logger.warning('Error generating analysis results due to %s', e)
892890
raise AttributeError(

plugin/xprof/profile_plugin_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -465,7 +465,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data):
465465
)
466466

467467
mock_xspace_to_tool_data.assert_called_once_with(
468-
[mock.ANY], ['host0', 'host1'], 'trace_viewer@', expected_params
468+
[mock.ANY], 'trace_viewer@', expected_params
469469
)
470470
args, _ = mock_xspace_to_tool_data.call_args
471471
actual_path_list = args[0]

xprof/convert/BUILD

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,6 @@ cc_library(
195195
":repository",
196196
":tool_options",
197197
":xplane_to_trace_container",
198-
"@com_google_absl//absl/container:flat_hash_map",
199198
"@com_google_absl//absl/log",
200199
"@com_google_absl//absl/status",
201200
"@com_google_absl//absl/status:statusor",

xprof/convert/repository.cc

Lines changed: 2 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,7 @@ static auto* kHostDataSuffixes =
5858

5959
absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
6060
std::vector<std::string> xspace_paths,
61-
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
62-
std::optional<std::vector<std::string>> all_hosts) {
61+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces) {
6362
if (xspace_paths.empty()) {
6463
return absl::InvalidArgumentError("Can not find XSpace path.");
6564
}
@@ -86,26 +85,7 @@ absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
8685
}
8786
}
8887

89-
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces),
90-
std::move(all_hosts));
91-
}
92-
93-
SessionSnapshot::SessionSnapshot(
94-
std::vector<std::string> xspace_paths,
95-
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
96-
std::optional<std::vector<std::string>> all_hosts)
97-
: xspace_paths_(std::move(xspace_paths)),
98-
all_hosts_(std::move(all_hosts)),
99-
// If the snapshot was initialized by xspaces, the file path and run dir
100-
// is a path tensorflow can't read from or write to so any file IO
101-
// encapsulated in this class will be disabled in this mode.
102-
has_accessible_run_dir_(!xspaces.has_value()),
103-
xspaces_(std::move(xspaces)) {
104-
session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
105-
for (size_t i = 0; i < xspace_paths_.size(); ++i) {
106-
std::string host_name = GetHostname(i);
107-
hostname_map_[host_name] = i;
108-
}
88+
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces));
10989
}
11090

11191
absl::StatusOr<XSpace*> SessionSnapshot::GetXSpace(size_t index,
@@ -146,10 +126,6 @@ std::string SessionSnapshot::GetHostname(size_t index) const {
146126
return GetHostnameByPath(xspace_paths_.at(index));
147127
}
148128

149-
std::optional<std::vector<std::string>> SessionSnapshot::GetAllHosts() const {
150-
return all_hosts_;
151-
}
152-
153129
std::optional<std::string> SessionSnapshot::GetFilePath(
154130
absl::string_view toolname, absl::string_view hostname) const {
155131
if (!has_accessible_run_dir_) return std::nullopt;

0 commit comments

Comments
 (0)