Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name = "GAM"
uuid = "cc454e9f-ce0f-441e-b193-468e31ddef4b"
license = "MIT"
authors = ["Trent Henderson <trent.henderson1@outlook.com>"]
version = "0.1.0"

Expand All @@ -12,4 +13,16 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[compat]
RDatasets = "0.7.7"
RecipesBase = "1.3.4"

[extras]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["RDatasets", "Plots", "Test"]
7 changes: 4 additions & 3 deletions src/GAM.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module GAM

using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim
using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Optim
using RecipesBase

include("Links-Dists.jl")
include("GAMData.jl")
Expand All @@ -23,8 +24,8 @@ export Dists
export Dist_Map
export Link_Map
export GAMData
export PartialDependencePlot
export plotGAM
export partialdependenceplot
export plotgam
export gam

end
78 changes: 68 additions & 10 deletions src/Plots.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,62 @@
@recipe function plot(mod::GAMData; var=0)
if typeof(var) != Int
error("var must be an integer")
end

if var != 0
x = mod.x[var]
pred = PredictPartial(mod, var)
ord = sortperm(x)

return @series begin
x[ord], pred[ord]
end
else
n = length(mod.x)
layout := (1, n)
for partial in 1:n
x = mod.x[partial]
pred = PredictPartial(mod, partial)
ord = sortperm(x)

@series begin
subplot := partial
link := :y
x[ord], pred[ord]
end
end
end
end

"""
PartialDependencePlot(mod, ix)
partialdependenceplot(mod, var)
Draw partial dependence plot.

Usage:
```julia-repl
PartialDependencePlot(mod, ix)
partialdependenceplot(mod, var)
```

Arguments:
- `mod` : `GAMData` containing the model.
- `ix` : `Int` denoting the variable to plot.
- `var` : `Int` denoting the variable to plot.
"""
function PartialDependencePlot(mod, ix)
x = mod.x[ix]
pred = PredictPartial(mod, ix)
@userplot PartialDependencePlot
@recipe function f(p::PartialDependencePlot)
mod, var = p.args
if typeof(mod) != GAMData
error("First argument must be a GAMData object")
end
if typeof(var) != Int
error("Second argument must be an integer")
end
x = mod.x[var]
pred = PredictPartial(mod, var)
ord = sortperm(x)
return plot(x[ord], pred[ord])

return @series begin
x[ord], pred[ord]
end
end

"""
Expand All @@ -28,7 +70,23 @@ plotGAM(mod)
Arguments:
- `mod` : `GAMData` containing the model.
"""
function plotGAM(mod)
partialPlot = map(x -> PartialDependencePlot(mod, x), eachindex(mod.x))
plot(partialPlot..., layout=(1, length(partialPlot)), link = :y)
@userplot plotGAM
@recipe function f(p::plotGAM)
mod = p.args[1]
if typeof(mod) != GAMData
error("Argument must be a GAMData object")
end
n = length(mod.x)
layout := (1, n)
for partial in 1:n
x = mod.x[partial]
pred = PredictPartial(mod, partial)
ord = sortperm(x)

@series begin
subplot := partial
link := :y
x[ord], pred[ord]
end
end
end
4 changes: 0 additions & 4 deletions test/Project.toml

This file was deleted.

22 changes: 17 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using GAM
include("../src/GAM.jl")
using .GAM

using Test
using RDatasets, Plots

Expand All @@ -8,17 +10,27 @@ df = dataset("datasets", "trees");

#-------------------- Run tests -----------------

@testset "GAM.jl" begin

@testset "Plotting" begin
mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df)

p = plotGAM(mod)
p = plot(mod, var=1)
@test p isa Plots.Plot

p2 = plot(mod)
@test p2 isa Plots.Plot

p = plotgam(mod)
@test p isa Plots.Plot

p = partialdependenceplot(mod, 1)
@test p isa Plots.Plot
end

@testset "Gamma" begin
# Gamma version

mod2 = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; Family = "Gamma", Link = "Log")

p1 = plotGAM(mod2)
p1 = plotgam(mod2)
@test p1 isa Plots.Plot
end