From 934304f62f84afc37a312a743d2d623113dba4e1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 13 Nov 2025 20:48:46 -0600 Subject: [PATCH 1/2] feat: profile macro to provide profiling data in the terminal --- CondaPkg.toml | 1 + Project.toml | 2 + .../ReactantPythonCallExt.jl | 17 ++ ext/ReactantPythonCallExt/xprof.jl | 14 ++ src/Profiler.jl | 158 ++++++++++++++++++ 5 files changed, 192 insertions(+) create mode 100644 ext/ReactantPythonCallExt/xprof.jl diff --git a/CondaPkg.toml b/CondaPkg.toml index b1db4f8e75..217a82a82c 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" +xprof = ">= 2.20" diff --git a/Project.toml b/Project.toml index f1d17dd81f..31d374866d 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -89,6 +90,7 @@ Functors = "0.5" GPUArraysCore = "0.2" GPUCompiler = "1.3" HTTP = "1.10.15" +JSON3 = "1.14.3" KernelAbstractions = "0.9.30" LLVM = "9.1" LLVMOpenMP_jll = "18.1.7" diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 1f10630808..e5fca2534f 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -15,6 +15,11 @@ const npptr = Ref{Py}() const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false) +const xprofptr = Ref{Py}() +const xprofconvertptr = Ref{Py}() + +const XPROF_PROFILER_SUPPORTED = Ref{Bool}(false) + const NUMPY_SIMPLE_TYPES = Dict( Bool => :bool, Int8 => :int8, @@ -54,11 +59,23 @@ function __init__() tensorflow SavedModel will not be \ supported." exception = (err, catch_backtrace()) end + + try + xprofptr[] = pyimport("xprof") + xprofconvertptr[] = pyimport("xprof.convert.raw_to_tool_data") + XPROF_PROFILER_SUPPORTED[] = true + catch err + @warn "Failed to import xprof. Xprof support will not be available." exception = ( + err, catch_backtrace() + ) + end + return nothing end include("overlays.jl") include("pycall.jl") include("saved_model.jl") +include("xprof.jl") end diff --git a/ext/ReactantPythonCallExt/xprof.jl b/ext/ReactantPythonCallExt/xprof.jl new file mode 100644 index 0000000000..5d8341c997 --- /dev/null +++ b/ext/ReactantPythonCallExt/xprof.jl @@ -0,0 +1,14 @@ +# Currently prototyping with xprof via python. we should instead add this into +# the C++ API. +function xspace_to_tools_data(filename::String, tool_name::String) + if !XPROF_PROFILER_SUPPORTED[] + error("xprof is not supported...") + end + + return String( + pyconvert( + Vector{UInt8}, + xprofconvertptr[].xspace_to_tool_data(pylist([filename]), tool_name, pydict())[0], + ), + ) +end diff --git a/src/Profiler.jl b/src/Profiler.jl index e63ca273a7..0bc3fbb849 100644 --- a/src/Profiler.jl +++ b/src/Profiler.jl @@ -2,6 +2,8 @@ module Profiler import ..Reactant using Sockets: Sockets +using PrettyTables: PrettyTables +using JSON3: JSON3 """ with_profiler(f, trace_output_dir::String; trace_device=true, trace_host=true, create_perfetto_link=false) @@ -195,4 +197,160 @@ mutable struct ProfileServer end end +function wrap_string(s; width=20) + s_str = string(s) + return join([s_str[i:min(i + width - 1, end)] for i in 1:width:length(s_str)], "\n") +end + +struct KernelStatsProfileResults + data +end + +function KernelStatsProfileResults(data::JSON3.Object) + cols = data["cols"] + rows = data["rows"] + keys = Tuple(Symbol.(get.(cols, "id"))) + table = Vector{NamedTuple}(undef, length(rows)) + + for (i, row) in enumerate(rows) + vals = get.(row["c"], "v") + table[i] = NamedTuple{keys}(vals) + end + + return KernelStatsProfileResults(table) +end + +function Base.show(io::IO, r::KernelStatsProfileResults) + tbl = r.data + + println(io, "╔════════════════╗") + println(io, "║ Kernel Stats ║") + println(io, "╚════════════════╝") + + isempty(tbl) && return nothing + + fields = fieldnames(typeof(tbl[1])) + wrapped = split.(wrap_string.(fields; width=10), "\n") + nrows = maximum(length.(wrapped)) + column_labels = [[get(wrapped[j], i, "") for j in 1:length(wrapped)] for i in 1:nrows] + + PrettyTables.pretty_table( + io, + tbl; + line_breaks=true, + maximum_data_column_widths=10, + auto_wrap=true, + column_labels, + ) + return nothing +end + +struct FrameworkStatsProfileResults + data +end + +function FrameworkStatsProfileResults(data::JSON3.Array{JSON3.Object}) + results = Vector{Vector{NamedTuple}}() + + for table in data + local_result = Vector{NamedTuple}() + + # Extract column information + cols = table["cols"] + col_ids = [Symbol(col["id"]) for col in cols] + + # Extract rows + rows = table["rows"] + + # Parse each row into a NamedTuple + for row in rows + values = [cell["v"] for cell in row["c"]] + nt = NamedTuple{Tuple(col_ids)}(Tuple(values)) + push!(local_result, nt) + end + + push!(results, local_result) + end + + return FrameworkStatsProfileResults(results) +end + +function Base.show(io::IO, r::FrameworkStatsProfileResults) + println(io, "╔══════════════════════════════╗") + println(io, "║ FrameworkOpStatsResults ║") + println(io, "╚══════════════════════════════╝") + + isempty(r.data) && return nothing + + for tbl in r.data + fields = fieldnames(typeof(tbl[1])) + wrapped = split.(wrap_string.(fields; width=10), "\n") + nrows = maximum(length.(wrapped)) + column_labels = [ + [get(wrapped[j], i, "") for j in 1:length(wrapped)] for i in 1:nrows + ] + + PrettyTables.pretty_table( + io, + tbl; + line_breaks=true, + auto_wrap=true, + maximum_data_column_widths=10, + column_labels, + ) + end + + return nothing +end + +struct ReactantProfileResults + kernel_stats::KernelStatsProfileResults + framework_stats::FrameworkStatsProfileResults +end + +function Base.show(io::IO, r::ReactantProfileResults) + println(io, "╔═══════════════════════════════════════════════════════╗") + println(io, "║ Reactant Profile Results ║") + println(io, "╚═══════════════════════════════════════════════════════╝") + show(io, r.kernel_stats) + println(io) + show(io, r.framework_stats) + return nothing +end + +function parse_xprof_profile_data(data) + extmod = Base.get_extension(Reactant, :ReactantPythonCallExt) + if extmod === nothing + error("Currently we require `PythonCall` to be loaded to parse xprof data.") + end + kernel_stats = KernelStatsProfileResults( + JSON3.read(extmod.xspace_to_tools_data(data, "kernel_stats")) + ) + framework_stats = FrameworkStatsProfileResults( + JSON3.read(extmod.xspace_to_tools_data(data, "framework_op_stats")) + ) + return ReactantProfileResults(kernel_stats, framework_stats) + return nothing +end + +macro profile(ex) + profile_dir = joinpath(tempdir(), "reactant_profile") + mkpath(profile_dir) + + quote + # TODO: optionally compile the code first and profile + + Reactant.Profiler.with_profiler($(esc(profile_dir))) do + $(esc(ex)) + end + + trace_output_dir = joinpath($(esc(profile_dir)), "plugins", "profile") + date = maximum(readdir(trace_output_dir)) + traces_path = joinpath(trace_output_dir, date) + + filename = first(f for f in readdir(traces_path) if endswith(f, ".xplane.pb")) + data = $(parse_xprof_profile_data)(joinpath(traces_path, filename)) + end +end + end # module Profiler From 2fbf025b8d8b6449652ea1b688a67e9f535a6caf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 13 Nov 2025 20:53:15 -0600 Subject: [PATCH 2/2] fix: missing dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 31d374866d..a9c24242dd 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" @@ -103,6 +104,7 @@ OneHotArrays = "0.2.10" OrderedCollections = "1" PrecompileTools = "1.2" Preferences = "1.4.3" +PrettyTables = "3.1.0" PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7"