Skip to content

Commit 5d366c3

Browse files
committed
feat: profile macro to provide profiling data in the terminal
1 parent 9528846 commit 5d366c3

File tree

6 files changed

+194
-0
lines changed

6 files changed

+194
-0
lines changed

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4"
55
jax = ">= 0.6"
66
tensorflow = ">= 2.17"
77
numpy = ">= 2"
8+
xprof = ">= 2.20"

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1313
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1515
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
16+
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1617
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1718
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
1819
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -89,6 +90,7 @@ Functors = "0.5"
8990
GPUArraysCore = "0.2"
9091
GPUCompiler = "1.3"
9192
HTTP = "1.10.15"
93+
JSON3 = "1.14.3"
9294
KernelAbstractions = "0.9.30"
9395
LLVM = "9.1"
9496
LLVMOpenMP_jll = "18.1.7"

deps/ReactantExtra/API.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3504,3 +3504,5 @@ REACTANT_ABI void EstimateRunTimeForInstruction(void *gpu_performance_model,
35043504
}
35053505

35063506
#endif
3507+
3508+

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ const npptr = Ref{Py}()
1515

1616
const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)
1717

18+
# Handles to the Python `xprof` and `xprof.convert.raw_to_tool_data` modules.
# Both are assigned inside `__init__` via `pyimport`.
const xprofptr = Ref{Py}()
const xprofconvertptr = Ref{Py}()

# Flipped to `true` in `__init__` only after both imports above succeeded.
const XPROF_PROFILER_SUPPORTED = Ref{Bool}(false)
22+
1823
const NUMPY_SIMPLE_TYPES = Dict(
1924
Bool => :bool,
2025
Int8 => :int8,
@@ -54,11 +59,23 @@ function __init__()
5459
tensorflow SavedModel will not be \
5560
supported." exception = (err, catch_backtrace())
5661
end
62+
63+
try
64+
xprofptr[] = pyimport("xprof")
65+
xprofconvertptr[] = pyimport("xprof.convert.raw_to_tool_data")
66+
XPROF_PROFILER_SUPPORTED[] = true
67+
catch err
68+
@warn "Failed to import xprof. Xprof support will not be available." exception = (
69+
err, catch_backtrace()
70+
)
71+
end
72+
5773
return nothing
5874
end
5975

6076
include("overlays.jl")
6177
include("pycall.jl")
6278
include("saved_model.jl")
79+
include("xprof.jl")
6380

6481
end

ext/ReactantPythonCallExt/xprof.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Currently prototyping with xprof via Python. We should instead add this into
2+
# the C++ API.
3+
"""
    xspace_to_tools_data(filename::String, tool_name::String)

Call the Python function `xspace_to_tool_data` of the imported
`xprof.convert.raw_to_tool_data` module on `[filename]` for the tool `tool_name`,
and return the first element of its result as a Julia `String`.

Throws an error when `XPROF_PROFILER_SUPPORTED[]` is `false`.
"""
function xspace_to_tools_data(filename::String, tool_name::String)
    XPROF_PROFILER_SUPPORTED[] || error("xprof is not supported...")

    converter = xprofconvertptr[]
    result = converter.xspace_to_tool_data(pylist([filename]), tool_name, pydict())

    # Element 0 of the Python result is converted to raw bytes, then to a `String`.
    payload = pyconvert(Vector{UInt8}, result[0])
    return String(payload)
end

src/Profiler.jl

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module Profiler
22

33
import ..Reactant
44
using Sockets: Sockets
5+
using PrettyTables: PrettyTables
6+
using JSON3: JSON3
57

68
"""
79
with_profiler(f, trace_output_dir::String; trace_device=true, trace_host=true, create_perfetto_link=false)
@@ -195,4 +197,160 @@ mutable struct ProfileServer
195197
end
196198
end
197199

200+
"""
    wrap_string(s; width=20)

Convert `s` to a string and break it into lines of at most `width` characters,
joined with `"\\n"`.

# Arguments
- `s`: any value; it is converted with `string(s)`.

# Keywords
- `width`: maximum number of characters per line.

# Returns
A `String`. An empty input yields an empty string.
"""
function wrap_string(s; width=20)
    # Operate on characters, not byte offsets. Slicing a `String` with `i:j`
    # indexes bytes, so stepping by character count throws `StringIndexError`
    # as soon as the text contains a multi-byte (non-ASCII) character.
    chars = collect(string(s))
    chunks = [String(chars[i:min(i + width - 1, end)]) for i in 1:width:length(chars)]
    return join(chunks, "\n")
end
204+
205+
struct KernelStatsProfileResults
206+
data
207+
end
208+
209+
"""
    KernelStatsProfileResults(data::JSON3.Object)

Build the results from decoded JSON. The `"id"` of every entry in
`data["cols"]` becomes a field name, and every entry of `data["rows"]` becomes
one `NamedTuple` holding the `"v"` value of each cell in that row's `"c"` list.
"""
function KernelStatsProfileResults(data::JSON3.Object)
    columns = data["cols"]
    records = data["rows"]
    col_names = Tuple(Symbol.(get.(columns, "id")))

    # A typed comprehension keeps the element type at `NamedTuple` even when
    # there are no rows.
    table = NamedTuple[
        NamedTuple{col_names}(get.(record["c"], "v")) for record in records
    ]

    return KernelStatsProfileResults(table)
end
222+
223+
function Base.show(io::IO, r::KernelStatsProfileResults)
224+
tbl = r.data
225+
226+
println(io, "╔════════════════╗")
227+
println(io, "║ Kernel Stats ║")
228+
println(io, "╚════════════════╝")
229+
230+
isempty(tbl) && return nothing
231+
232+
fields = fieldnames(typeof(tbl[1]))
233+
wrapped = split.(wrap_string.(fields; width=10), "\n")
234+
nrows = maximum(length.(wrapped))
235+
column_labels = [[get(wrapped[j], i, "") for j in 1:length(wrapped)] for i in 1:nrows]
236+
237+
PrettyTables.pretty_table(
238+
io,
239+
tbl;
240+
line_breaks=true,
241+
maximum_data_column_widths=10,
242+
auto_wrap=true,
243+
column_labels,
244+
)
245+
return nothing
246+
end
247+
248+
struct FrameworkStatsProfileResults
249+
data
250+
end
251+
252+
"""
    FrameworkStatsProfileResults(data::JSON3.Array{JSON3.Object})

Build the results from a decoded JSON array of tables. For each table, the
`"id"` of every entry in `table["cols"]` becomes a field name and every entry of
`table["rows"]` becomes one `NamedTuple` holding the `"v"` value of each cell in
that row's `"c"` list. The result stores one vector of rows per input table.
"""
function FrameworkStatsProfileResults(data::JSON3.Array{JSON3.Object})
    tables = Vector{NamedTuple}[
        begin
            # Column ids are fixed per table, so build the name tuple once.
            col_names = Tuple([Symbol(col["id"]) for col in table["cols"]])
            NamedTuple[
                NamedTuple{col_names}(Tuple([cell["v"] for cell in record["c"]])) for
                record in table["rows"]
            ]
        end for table in data
    ]

    return FrameworkStatsProfileResults(tables)
end
277+
278+
"""
    Base.show(io::IO, r::FrameworkStatsProfileResults)

Print a "FrameworkOpStatsResults" banner followed by every table in `r.data`,
each rendered with `PrettyTables.pretty_table`. Tables without rows are skipped;
only the banner is printed when there are no tables at all.
"""
function Base.show(io::IO, r::FrameworkStatsProfileResults)
    println(io, "╔══════════════════════════════╗")
    println(io, "║ FrameworkOpStatsResults ║")
    println(io, "╚══════════════════════════════╝")

    isempty(r.data) && return nothing

    for tbl in r.data
        # The column names are read from the first row, so a table with zero
        # rows has nothing to inspect. Skipping it avoids a `BoundsError` on
        # `tbl[1]` and matches how the kernel-stats `show` handles empty data.
        isempty(tbl) && continue

        # Split each field name into chunks of at most 10 characters so long
        # names become multi-line column headers.
        fields = fieldnames(typeof(tbl[1]))
        wrapped = split.(wrap_string.(fields; width=10), "\n")
        nrows = maximum(length.(wrapped))

        # One label row per header line; names with fewer chunks are padded
        # with empty strings.
        column_labels = [[get(parts, i, "") for parts in wrapped] for i in 1:nrows]

        PrettyTables.pretty_table(
            io,
            tbl;
            line_breaks=true,
            auto_wrap=true,
            maximum_data_column_widths=10,
            column_labels,
        )
    end

    return nothing
end
305+
306+
"""
    ReactantProfileResults

Pairs the kernel statistics (`kernel_stats`) with the framework-op statistics
(`framework_stats`) so both can be displayed together.
"""
struct ReactantProfileResults
    kernel_stats::KernelStatsProfileResults
    framework_stats::FrameworkStatsProfileResults
end
310+
311+
"""
    Base.show(io::IO, r::ReactantProfileResults)

Print a "Reactant Profile Results" banner, then the kernel statistics, a blank
line, and the framework statistics, each through its own `show` method.
"""
function Base.show(io::IO, r::ReactantProfileResults)
    banner = (
        "╔═══════════════════════════════════════════════════════╗",
        "║ Reactant Profile Results ║",
        "╚═══════════════════════════════════════════════════════╝",
    )
    foreach(line -> println(io, line), banner)

    show(io, r.kernel_stats)
    println(io)
    show(io, r.framework_stats)
    return nothing
end
320+
321+
"""
    parse_xprof_profile_data(data)

Convert the profile at `data` into a [`ReactantProfileResults`](@ref).

`data` is forwarded unchanged to `xspace_to_tools_data` of the
`ReactantPythonCallExt` extension, once for the `"kernel_stats"` tool and once
for the `"framework_op_stats"` tool. Each returned JSON string is parsed with
`JSON3.read`.

Throws an error when the `ReactantPythonCallExt` extension is not loaded.
"""
function parse_xprof_profile_data(data)
    extmod = Base.get_extension(Reactant, :ReactantPythonCallExt)
    if extmod === nothing
        error("Currently we require `PythonCall` to be loaded to parse xprof data.")
    end

    kernel_stats = KernelStatsProfileResults(
        JSON3.read(extmod.xspace_to_tools_data(data, "kernel_stats"))
    )
    framework_stats = FrameworkStatsProfileResults(
        JSON3.read(extmod.xspace_to_tools_data(data, "framework_op_stats"))
    )

    # The original had a second, unreachable `return nothing` after this line.
    return ReactantProfileResults(kernel_stats, framework_stats)
end
335+
336+
"""
    @profile ex

Run `ex` under `Reactant.Profiler.with_profiler`, writing the trace to
`joinpath(tempdir(), "reactant_profile")`. Then locate the first `.xplane.pb`
file in the lexicographically greatest subdirectory of `plugins/profile` and
return the result of `parse_xprof_profile_data` on it.

The value of `ex` itself is discarded. Throws an error if the trace directory
contains no `.xplane.pb` file.
"""
macro profile(ex)
    return quote
        # TODO: optionally compile the code first and profile

        # Resolve and create the output directory when the code runs, not when
        # the macro is expanded. Otherwise the path (and the `mkpath` side
        # effect) is frozen at expansion/precompilation time and may be missing
        # or wrong in the process that executes the profiled code.
        profile_dir = joinpath(tempdir(), "reactant_profile")
        mkpath(profile_dir)

        Reactant.Profiler.with_profiler(profile_dir) do
            $(esc(ex))
        end

        trace_output_dir = joinpath(profile_dir, "plugins", "profile")
        # Pick the lexicographically greatest subdirectory name.
        date = maximum(readdir(trace_output_dir))
        traces_path = joinpath(trace_output_dir, date)

        xplane_files = filter(endswith(".xplane.pb"), readdir(traces_path))
        if isempty(xplane_files)
            error("No `.xplane.pb` trace file found in $(traces_path).")
        end

        $(parse_xprof_profile_data)(joinpath(traces_path, first(xplane_files)))
    end
end
355+
198356
end # module Profiler

0 commit comments

Comments
 (0)