Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 22 additions & 43 deletions plugin/xprof/convert/raw_to_tool_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ def process_raw_trace(raw_trace):
return ''.join(trace_events_json.TraceEventsJsonStream(trace))


def xspace_to_tools_data_from_byte_string(xspace_byte_list, all_hosts,
filenames, tool, params):
def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
params):
"""Helper function for getting an XSpace tool from a bytes string.

Args:
xspace_byte_list: A list of byte strings read from a XSpace proto file.
all_hosts: A list of all hosts in the session.
filenames: Names of the read files.
tool: A string of tool name.
params: user input parameters.
Expand All @@ -58,7 +57,7 @@ def xspace_to_tools_data_from_byte_string(xspace_byte_list, all_hosts,
# pylint:disable=dangerous-default-value
def xspace_wrapper_func(xspace_arg, tool_arg, params={}):
return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
xspace_arg, all_hosts, filenames, tool_arg, params)
xspace_arg, filenames, tool_arg, params)
# pylint:enable=dangerous-default-value

return xspace_to_tool_data(xspace_byte_list, tool, params,
Expand All @@ -74,26 +73,22 @@ def xspace_to_tool_names(xspace_paths):
Returns:
Returns a list of tool names.
"""
# xspace_to_tools_data expects all_hosts as the second argument, passing an
# empty list.
raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data(
xspace_paths, [], 'tool_names', {})
xspace_paths, 'tool_names')
if success:
return [tool for tool in raw_data.decode().split(',')]
return []


def xspace_to_tool_data(
xspace_paths,
all_hosts,
tool,
params,
xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
"""Converts XSpace to tool data string.

Args:
xspace_paths: A list of XSpace paths.
all_hosts: A list of all hosts in the session.
tool: A string of tool name.
params: user input parameters.
xspace_wrapper_func: A callable that takes a list of strings and a tool and
Expand All @@ -117,31 +112,26 @@ def xspace_to_tool_data(
if tool == 'trace_viewer':
# Trace viewer handles one host at a time.
assert len(xspace_paths) == 1
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = process_raw_trace(raw_data)
elif tool == 'trace_viewer@':
options = params.get('trace_viewer_options', {})
options['use_saved_result'] = params.get('use_saved_result', True)
options['hosts'] = all_hosts
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
options['hosts'] = params.get('hosts', [])
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'overview_page':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'input_pipeline_analyzer':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'framework_op_stats':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
if tqx == 'out:csv':
data = csv_writer.json_to_csv(json_data)
Expand All @@ -152,16 +142,15 @@ def xspace_to_tool_data(
# TODO(b/419013992): Remove this tool completely as it has been deprecated
legacy_tool = 'tensorflow_stats'
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, legacy_tool, options
xspace_paths, legacy_tool, options
)
if success:
if tqx == 'out:csv':
data = csv_writer.json_to_csv(json_data)
else:
data = json_data
elif tool == 'kernel_stats':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
if tqx == 'out:csv':
data = csv_writer.json_to_csv(json_data)
Expand All @@ -170,44 +159,37 @@ def xspace_to_tool_data(
elif tool == 'memory_profile':
# Memory profile handles one host at a time.
assert len(xspace_paths) == 1
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'pod_viewer':
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'op_profile':
options['group_by'] = params.get('group_by', 'program')
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'hlo_op_profile':
options['group_by'] = params.get('group_by', 'program')
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'hlo_stats':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'roofline_model':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'graph_viewer':
download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
graph_html_type = 'graph'
options = params.get('graph_viewer_options', {})
options['use_saved_result'] = params.get('use_saved_result', True)
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
content_type = 'text/plain'
Expand All @@ -231,21 +213,18 @@ def xspace_to_tool_data(
'view_memory_allocation_timeline': view_memory_allocation_timeline,
'memory_space': params.get('memory_space', ''),
}
raw_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
if view_memory_allocation_timeline:
content_type = 'text/html'
elif tool == 'megascale_stats':
options = {'host_name': params.get('host')}
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'inference_profile':
json_data, success = xspace_wrapper_func(
xspace_paths, all_hosts, tool, options)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
else:
Expand Down
12 changes: 2 additions & 10 deletions plugin/xprof/convert/raw_to_tool_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ def test_using_old_tool_format_maps_to_new_format(self):
xspace_paths=["/path/to/xspace"],
tool="trace_viewer@^",
params={},
all_hosts=[],
xspace_wrapper_func=lambda paths, hosts, tool, options: (
tool.encode(),
True,
),
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
)

self.assertEqual(data, b"trace_viewer@")
Expand All @@ -42,11 +38,7 @@ def test_using_new_tool_format_does_not_map_to_old_format(self):
xspace_paths=["/path/to/xspace"],
tool="trace_viewer@",
params={},
all_hosts=[],
xspace_wrapper_func=lambda paths, hosts, tool, options: (
tool.encode(),
True,
),
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
)

self.assertEqual(data, b"trace_viewer@")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_tools_are_in_list(self):

def test_overview_page(self):
xspace_filenames = self._get_session_snapshot()
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, [],
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames,
'overview_page', {})
result = json.loads(result)
run_environment = result[2]
Expand All @@ -123,9 +123,7 @@ def test_overview_page(self):

def test_overview_page_creates_cache(self):
xspace_filenames = self._get_session_snapshot()
raw_to_tool_data.xspace_to_tool_data(
xspace_filenames, [], 'overview_page', {}
)
raw_to_tool_data.xspace_to_tool_data(xspace_filenames, 'overview_page', {})
profile_plugin_root = os.path.join(log_dir, 'plugins/profile')
# The session exists under a director whose name is time-dependent.
cache_glob = os.path.join(profile_plugin_root, '*', '*.op_stats.pb')
Expand All @@ -134,7 +132,7 @@ def test_overview_page_creates_cache(self):
def test_op_profile(self):
xspace_filenames = self._get_session_snapshot()
result, _ = raw_to_tool_data.xspace_to_tool_data(
xspace_filenames, [], 'op_profile', {'group_by': 'category'}
xspace_filenames, 'op_profile', {'group_by': 'category'}
)
result = json.loads(result)
logging.info(result)
Expand All @@ -153,7 +151,7 @@ def test_op_profile(self):
def test_device_trace_contains_threads(self):
xspace_filenames = self._get_session_snapshot()
result, _ = raw_to_tool_data.xspace_to_tool_data(
xspace_filenames, [], 'trace_viewer', {}
xspace_filenames, 'trace_viewer', {}
)
result = json.loads(result)
thread_names = []
Expand Down
12 changes: 5 additions & 7 deletions plugin/xprof/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def hlo_module_list_route(

def _get_valid_hosts(
self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
) -> tuple[List[str], List[epath.Path], List[str]]:
) -> tuple[List[str], List[epath.Path]]:
"""Retrieves and validates the hosts and asset paths for a run and tool.

Args:
Expand All @@ -720,7 +720,7 @@ def _get_valid_hosts(
host: The single host parameter.

Returns:
A tuple containing (selected_hosts, asset_paths, all_hosts).
A tuple containing (selected_hosts, asset_paths).

Raises:
FileNotFoundError: If a required xplane file for the specified host(s)
Expand Down Expand Up @@ -786,9 +786,7 @@ def _get_valid_hosts(
'Host must be specified for tool %s in run %s' % (tool, run)
)

all_hosts = list(all_xplane_files.keys())

return selected_hosts, asset_paths, all_hosts
return selected_hosts, asset_paths

def data_impl(
self, request: wrappers.Request
Expand Down Expand Up @@ -877,7 +875,7 @@ def data_impl(

_, content_encoding = None, None
if use_xplane(tool):
selected_hosts, asset_paths, all_hosts = self._get_valid_hosts(
selected_hosts, asset_paths = self._get_valid_hosts(
run_dir, run, tool, hosts_param, host
)
if not asset_paths:
Expand All @@ -886,7 +884,7 @@ def data_impl(
params['hosts'] = selected_hosts
try:
data, content_type = convert.xspace_to_tool_data(
asset_paths, all_hosts, tool, params)
asset_paths, tool, params)
except AttributeError as e:
logger.warning('Error generating analysis results due to %s', e)
raise AttributeError(
Expand Down
2 changes: 1 addition & 1 deletion plugin/xprof/profile_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data):
)

mock_xspace_to_tool_data.assert_called_once_with(
[mock.ANY], ['host0', 'host1'], 'trace_viewer@', expected_params
[mock.ANY], 'trace_viewer@', expected_params
)
args, _ = mock_xspace_to_tool_data.call_args
actual_path_list = args[0]
Expand Down
1 change: 0 additions & 1 deletion xprof/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ cc_library(
":repository",
":tool_options",
":xplane_to_trace_container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
28 changes: 2 additions & 26 deletions xprof/convert/repository.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ static auto* kHostDataSuffixes =

absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
std::vector<std::string> xspace_paths,
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
std::optional<std::vector<std::string>> all_hosts) {
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces) {
if (xspace_paths.empty()) {
return absl::InvalidArgumentError("Can not find XSpace path.");
}
Expand All @@ -86,26 +85,7 @@ absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
}
}

return SessionSnapshot(std::move(xspace_paths), std::move(xspaces),
std::move(all_hosts));
}

SessionSnapshot::SessionSnapshot(
std::vector<std::string> xspace_paths,
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
std::optional<std::vector<std::string>> all_hosts)
: xspace_paths_(std::move(xspace_paths)),
all_hosts_(std::move(all_hosts)),
// If the snapshot was initialized by xspaces, the file path and run dir
// is a path tensorflow can't read from or write to so any file IO
// encapsulated in this class will be disabled in this mode.
has_accessible_run_dir_(!xspaces.has_value()),
xspaces_(std::move(xspaces)) {
session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
for (size_t i = 0; i < xspace_paths_.size(); ++i) {
std::string host_name = GetHostname(i);
hostname_map_[host_name] = i;
}
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces));
}

absl::StatusOr<XSpace*> SessionSnapshot::GetXSpace(size_t index,
Expand Down Expand Up @@ -146,10 +126,6 @@ std::string SessionSnapshot::GetHostname(size_t index) const {
return GetHostnameByPath(xspace_paths_.at(index));
}

std::optional<std::vector<std::string>> SessionSnapshot::GetAllHosts() const {
return all_hosts_;
}

std::optional<std::string> SessionSnapshot::GetFilePath(
absl::string_view toolname, absl::string_view hostname) const {
if (!has_accessible_run_dir_) return std::nullopt;
Expand Down
Loading