diff --git a/Project.toml b/Project.toml index 05f10a1..6050821 100644 --- a/Project.toml +++ b/Project.toml @@ -1,5 +1,6 @@ name = "GAM" uuid = "cc454e9f-ce0f-441e-b193-468e31ddef4b" +license = "MIT" authors = ["Trent Henderson "] version = "0.1.0" @@ -12,4 +13,16 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" + +[compat] +RDatasets = "0.7.7" +RecipesBase = "1.3.4" + +[extras] Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["RDatasets", "Plots", "Test"] diff --git a/src/GAM.jl b/src/GAM.jl index 47eb1ff..e640543 100644 --- a/src/GAM.jl +++ b/src/GAM.jl @@ -1,6 +1,7 @@ module GAM -using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim +using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Optim +using RecipesBase include("Links-Dists.jl") include("GAMData.jl") @@ -23,8 +24,8 @@ export Dists export Dist_Map export Link_Map export GAMData -export PartialDependencePlot -export plotGAM +export partialdependenceplot +export plotgam export gam end diff --git a/src/Plots.jl b/src/Plots.jl index c60f6ad..6d801aa 100644 --- a/src/Plots.jl +++ b/src/Plots.jl @@ -1,20 +1,62 @@ +@recipe function plot(mod::GAMData; var=0) + if typeof(var) != Int + error("var must be an integer") + end + + if var != 0 + x = mod.x[var] + pred = PredictPartial(mod, var) + ord = sortperm(x) + + return @series begin + x[ord], pred[ord] + end + else + n = length(mod.x) + layout := (1, n) + for partial in 1:n + x = mod.x[partial] + pred = PredictPartial(mod, partial) + ord = sortperm(x) + + @series begin + subplot := partial + link := :y + x[ord], pred[ord] + end + end + end +end + """ - PartialDependencePlot(mod, ix) + partialdependenceplot(mod, var) Draw partial dependence plot. Usage: ```julia-repl -PartialDependencePlot(mod, ix) +partialdependenceplot(mod, var) ``` + Arguments: - `mod` : `GAMData` containing the model. -- `ix` : `Int` denoting the variable to plot. +- `var` : `Int` denoting the variable to plot. """ -function PartialDependencePlot(mod, ix) - x = mod.x[ix] - pred = PredictPartial(mod, ix) +@userplot PartialDependencePlot +@recipe function f(p::PartialDependencePlot) + mod, var = p.args + if typeof(mod) != GAMData + error("First argument must be a GAMData object") + end + if typeof(var) != Int + error("Second argument must be an integer") + end + x = mod.x[var] + pred = PredictPartial(mod, var) ord = sortperm(x) - return plot(x[ord], pred[ord]) + + return @series begin + x[ord], pred[ord] + end end """ @@ -28,7 +70,23 @@ plotGAM(mod) Arguments: - `mod` : `GAMData` containing the model. """ -function plotGAM(mod) - partialPlot = map(x -> PartialDependencePlot(mod, x), eachindex(mod.x)) - plot(partialPlot..., layout=(1, length(partialPlot)), link = :y) +@userplot plotGAM +@recipe function f(p::plotGAM) + mod = p.args[1] + if typeof(mod) != GAMData + error("Argument must be a GAMData object") + end + n = length(mod.x) + layout := (1, n) + for partial in 1:n + x = mod.x[partial] + pred = PredictPartial(mod, partial) + ord = sortperm(x) + + @series begin + subplot := partial + link := :y + x[ord], pred[ord] + end + end end diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 47da2f3..0000000 --- a/test/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" diff --git a/test/runtests.jl b/test/runtests.jl index f6e3e69..8cbdaa7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ -using GAM +include("../src/GAM.jl") +using .GAM + using Test using RDatasets, Plots @@ -8,17 +10,27 @@ df = dataset("datasets", "trees"); #-------------------- Run tests ----------------- -@testset "GAM.jl" begin - +@testset "Plotting" begin mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) - p = plotGAM(mod) + p = plot(mod, var=1) + @test p isa Plots.Plot + + p2 = plot(mod) + @test p2 isa Plots.Plot + + p = plotgam(mod) + @test p isa Plots.Plot + + p = partialdependenceplot(mod, 1) @test p isa Plots.Plot +end +@testset "Gamma" begin # Gamma version mod2 = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; Family = "Gamma", Link = "Log") - p1 = plotGAM(mod2) + p1 = plotgam(mod2) @test p1 isa Plots.Plot end \ No newline at end of file