From 06f4411c46b5ec526cddee9a47399b106a9ce085 Mon Sep 17 00:00:00 2001 From: azeredo-e Date: Wed, 24 Sep 2025 20:35:12 -0300 Subject: [PATCH 1/2] Refactor the plotting framework, now using RecipesBase --- Project.toml | 13 ++++++++ src/GAM.jl | 7 +++-- src/Plots.jl | 78 +++++++++++++++++++++++++++++++++++++++++------ t.jl | 17 +++++++++++ test/Project.toml | 4 --- test/runtests.jl | 22 ++++++++++--- 6 files changed, 119 insertions(+), 22 deletions(-) create mode 100644 t.jl delete mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 05f10a1..6050821 100644 --- a/Project.toml +++ b/Project.toml @@ -1,5 +1,6 @@ name = "GAM" uuid = "cc454e9f-ce0f-441e-b193-468e31ddef4b" +license = "MIT" authors = ["Trent Henderson "] version = "0.1.0" @@ -12,4 +13,16 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" + +[compat] +RDatasets = "0.7.7" +RecipesBase = "1.3.4" + +[extras] Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["RDatasets", "Plots", "Test"] diff --git a/src/GAM.jl b/src/GAM.jl index 47eb1ff..e640543 100644 --- a/src/GAM.jl +++ b/src/GAM.jl @@ -1,6 +1,7 @@ module GAM -using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim +using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Optim +using RecipesBase include("Links-Dists.jl") include("GAMData.jl") @@ -23,8 +24,8 @@ export Dists export Dist_Map export Link_Map export GAMData -export PartialDependencePlot -export plotGAM +export partialdependenceplot +export plotgam export gam end diff --git a/src/Plots.jl b/src/Plots.jl index c60f6ad..6d801aa 100644 --- a/src/Plots.jl +++ b/src/Plots.jl @@ -1,20 +1,62 @@ +@recipe function plot(mod::GAMData; var=0) + if typeof(var) != Int + error("var must be an integer") + end + + if var != 0 + x = mod.x[var] + pred = PredictPartial(mod, var) + ord = sortperm(x) + + return @series begin + x[ord], pred[ord] + end + else + n = length(mod.x) + layout := (1, n) + for partial in 1:n + x = mod.x[partial] + pred = PredictPartial(mod, partial) + ord = sortperm(x) + + @series begin + subplot := partial + link := :y + x[ord], pred[ord] + end + end + end +end + """ - PartialDependencePlot(mod, ix) + partialdependenceplot(mod, var) Draw partial dependence plot. Usage: ```julia-repl -PartialDependencePlot(mod, ix) +partialdependenceplot(mod, var) ``` + Arguments: - `mod` : `GAMData` containing the model. -- `ix` : `Int` denoting the variable to plot. +- `var` : `Int` denoting the variable to plot. """ -function PartialDependencePlot(mod, ix) - x = mod.x[ix] - pred = PredictPartial(mod, ix) +@userplot PartialDependencePlot +@recipe function f(p::PartialDependencePlot) + mod, var = p.args + if typeof(mod) != GAMData + error("First argument must be a GAMData object") + end + if typeof(var) != Int + error("Second argument must be an integer") + end + x = mod.x[var] + pred = PredictPartial(mod, var) ord = sortperm(x) - return plot(x[ord], pred[ord]) + + return @series begin + x[ord], pred[ord] + end end """ @@ -28,7 +70,23 @@ plotGAM(mod) Arguments: - `mod` : `GAMData` containing the model. """ -function plotGAM(mod) - partialPlot = map(x -> PartialDependencePlot(mod, x), eachindex(mod.x)) - plot(partialPlot..., layout=(1, length(partialPlot)), link = :y) +@userplot plotGAM +@recipe function f(p::plotGAM) + mod = p.args[1] + if typeof(mod) != GAMData + error("Argument must be a GAMData object") + end + n = length(mod.x) + layout := (1, n) + for partial in 1:n + x = mod.x[partial] + pred = PredictPartial(mod, partial) + ord = sortperm(x) + + @series begin + subplot := partial + link := :y + x[ord], pred[ord] + end + end end diff --git a/t.jl b/t.jl new file mode 100644 index 0000000..634dd66 --- /dev/null +++ b/t.jl @@ -0,0 +1,17 @@ +using GAM, RDatasets, Plots + +df = dataset("datasets", "trees") + +mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) + +p = plotgam(mod) +display(p) +readline() + +# p = plot(mod, var=1) +# display(p) +# readline() + +# p = plot(mod, linecolor=:red) +# display(p) +# readline() \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 47da2f3..0000000 --- a/test/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" diff --git a/test/runtests.jl b/test/runtests.jl index f6e3e69..8cbdaa7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ -using GAM +include("../src/GAM.jl") +using .GAM + using Test using RDatasets, Plots @@ -8,17 +10,27 @@ df = dataset("datasets", "trees"); #-------------------- Run tests ----------------- -@testset "GAM.jl" begin - +@testset "Plotting" begin mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) - p = plotGAM(mod) + p = plot(mod, var=1) + @test p isa Plots.Plot + + p2 = plot(mod) + @test p2 isa Plots.Plot + + p = plotgam(mod) + @test p isa Plots.Plot + + p = partialdependenceplot(mod, 1) @test p isa Plots.Plot +end +@testset "Gamma" begin # Gamma version mod2 = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; Family = "Gamma", Link = "Log") - p1 = plotGAM(mod2) + p1 = plotgam(mod2) @test p1 isa Plots.Plot end \ No newline at end of file From ff12acc039967c764fcb015ca3039e2cafb695ab Mon Sep 17 00:00:00 2001 From: azeredo-e Date: Wed, 24 Sep 2025 21:46:05 -0300 Subject: [PATCH 2/2] rm test file --- t.jl | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 t.jl diff --git a/t.jl b/t.jl deleted file mode 100644 index 634dd66..0000000 --- a/t.jl +++ /dev/null @@ -1,17 +0,0 @@ -using GAM, RDatasets, Plots - -df = dataset("datasets", "trees") - -mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) - -p = plotgam(mod) -display(p) -readline() - -# p = plot(mod, var=1) -# display(p) -# readline() - -# p = plot(mod, linecolor=:red) -# display(p) -# readline() \ No newline at end of file