Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CondaPkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4"
jax = ">= 0.6"
tensorflow = ">= 2.17"
numpy = ">= 2"
xprof = ">= 2.20"
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Expand Down Expand Up @@ -89,6 +91,7 @@ Functors = "0.5"
GPUArraysCore = "0.2"
GPUCompiler = "1.3"
HTTP = "1.10.15"
JSON3 = "1.14.3"
KernelAbstractions = "0.9.30"
LLVM = "9.1"
LLVMOpenMP_jll = "18.1.7"
Expand All @@ -101,6 +104,7 @@ OneHotArrays = "0.2.10"
OrderedCollections = "1"
PrecompileTools = "1.2"
Preferences = "1.4.3"
PrettyTables = "3.1.0"
PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
Expand Down
17 changes: 17 additions & 0 deletions ext/ReactantPythonCallExt/ReactantPythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ const npptr = Ref{Py}()

const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)

const xprofptr = Ref{Py}()
const xprofconvertptr = Ref{Py}()

const XPROF_PROFILER_SUPPORTED = Ref{Bool}(false)

const NUMPY_SIMPLE_TYPES = Dict(
Bool => :bool,
Int8 => :int8,
Expand Down Expand Up @@ -54,11 +59,23 @@ function __init__()
tensorflow SavedModel will not be \
supported." exception = (err, catch_backtrace())
end

try
xprofptr[] = pyimport("xprof")
xprofconvertptr[] = pyimport("xprof.convert.raw_to_tool_data")
XPROF_PROFILER_SUPPORTED[] = true
catch err
@warn "Failed to import xprof. Xprof support will not be available." exception = (
err, catch_backtrace()
)
end

return nothing
end

include("overlays.jl")
include("pycall.jl")
include("saved_model.jl")
include("xprof.jl")

end
14 changes: 14 additions & 0 deletions ext/ReactantPythonCallExt/xprof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Currently prototyping with xprof via python. we should instead add this into
# the C++ API.
function xspace_to_tools_data(filename::String, tool_name::String)
if !XPROF_PROFILER_SUPPORTED[]
error("xprof is not supported...")
end

return String(
pyconvert(
Vector{UInt8},
xprofconvertptr[].xspace_to_tool_data(pylist([filename]), tool_name, pydict())[0],
),
)
end
158 changes: 158 additions & 0 deletions src/Profiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module Profiler

import ..Reactant
using Sockets: Sockets
using PrettyTables: PrettyTables
using JSON3: JSON3

"""
with_profiler(f, trace_output_dir::String; trace_device=true, trace_host=true, create_perfetto_link=false)
Expand Down Expand Up @@ -195,4 +197,160 @@ mutable struct ProfileServer
end
end

function wrap_string(s; width=20)
s_str = string(s)
return join([s_str[i:min(i + width - 1, end)] for i in 1:width:length(s_str)], "\n")
end

struct KernelStatsProfileResults
data
end

function KernelStatsProfileResults(data::JSON3.Object)
cols = data["cols"]
rows = data["rows"]
keys = Tuple(Symbol.(get.(cols, "id")))
table = Vector{NamedTuple}(undef, length(rows))

for (i, row) in enumerate(rows)
vals = get.(row["c"], "v")
table[i] = NamedTuple{keys}(vals)
end

return KernelStatsProfileResults(table)
end

function Base.show(io::IO, r::KernelStatsProfileResults)
tbl = r.data

println(io, "╔════════════════╗")
println(io, "║ Kernel Stats ║")
println(io, "╚════════════════╝")

isempty(tbl) && return nothing

fields = fieldnames(typeof(tbl[1]))
wrapped = split.(wrap_string.(fields; width=10), "\n")
nrows = maximum(length.(wrapped))
column_labels = [[get(wrapped[j], i, "") for j in 1:length(wrapped)] for i in 1:nrows]

PrettyTables.pretty_table(
io,
tbl;
line_breaks=true,
maximum_data_column_widths=10,
auto_wrap=true,
column_labels,
)
return nothing
end

struct FrameworkStatsProfileResults
data
end

function FrameworkStatsProfileResults(data::JSON3.Array{JSON3.Object})
results = Vector{Vector{NamedTuple}}()

for table in data
local_result = Vector{NamedTuple}()

# Extract column information
cols = table["cols"]
col_ids = [Symbol(col["id"]) for col in cols]

# Extract rows
rows = table["rows"]

# Parse each row into a NamedTuple
for row in rows
values = [cell["v"] for cell in row["c"]]
nt = NamedTuple{Tuple(col_ids)}(Tuple(values))
push!(local_result, nt)
end

push!(results, local_result)
end

return FrameworkStatsProfileResults(results)
end

function Base.show(io::IO, r::FrameworkStatsProfileResults)
println(io, "╔══════════════════════════════╗")
println(io, "║ FrameworkOpStatsResults ║")
println(io, "╚══════════════════════════════╝")

isempty(r.data) && return nothing

for tbl in r.data
fields = fieldnames(typeof(tbl[1]))
wrapped = split.(wrap_string.(fields; width=10), "\n")
nrows = maximum(length.(wrapped))
column_labels = [
[get(wrapped[j], i, "") for j in 1:length(wrapped)] for i in 1:nrows
]

PrettyTables.pretty_table(
io,
tbl;
line_breaks=true,
auto_wrap=true,
maximum_data_column_widths=10,
column_labels,
)
end

return nothing
end

struct ReactantProfileResults
kernel_stats::KernelStatsProfileResults
framework_stats::FrameworkStatsProfileResults
end

function Base.show(io::IO, r::ReactantProfileResults)
println(io, "╔═══════════════════════════════════════════════════════╗")
println(io, "║ Reactant Profile Results ║")
println(io, "╚═══════════════════════════════════════════════════════╝")
show(io, r.kernel_stats)
println(io)
show(io, r.framework_stats)
return nothing
end

function parse_xprof_profile_data(data)
extmod = Base.get_extension(Reactant, :ReactantPythonCallExt)
if extmod === nothing
error("Currently we require `PythonCall` to be loaded to parse xprof data.")
end
kernel_stats = KernelStatsProfileResults(
JSON3.read(extmod.xspace_to_tools_data(data, "kernel_stats"))
)
framework_stats = FrameworkStatsProfileResults(
JSON3.read(extmod.xspace_to_tools_data(data, "framework_op_stats"))
)
return ReactantProfileResults(kernel_stats, framework_stats)
return nothing
end

macro profile(ex)
profile_dir = joinpath(tempdir(), "reactant_profile")
mkpath(profile_dir)

quote
# TODO: optionally compile the code first and profile

Reactant.Profiler.with_profiler($(esc(profile_dir))) do
$(esc(ex))
end

trace_output_dir = joinpath($(esc(profile_dir)), "plugins", "profile")
date = maximum(readdir(trace_output_dir))
traces_path = joinpath(trace_output_dir, date)

filename = first(f for f in readdir(traces_path) if endswith(f, ".xplane.pb"))
data = $(parse_xprof_profile_data)(joinpath(traces_path, filename))
end
end

end # module Profiler
Loading