Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name = "GAM"
name = "GeneralizedAdditiveModels"
uuid = "cc454e9f-ce0f-441e-b193-468e31ddef4b"
authors = ["Trent Henderson <trent.henderson1@outlook.com>"]
version = "0.1.0"
Expand All @@ -13,3 +13,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# GAM.jl
# GeneralizedAdditiveModels.jl
Fit, evaluate, and visualise generalised additive models (GAMs) in native Julia

## Motivation
Expand All @@ -7,7 +7,7 @@ Fit, evaluate, and visualise generalised additive models (GAMs) in native Julia

## Usage

The basic interface to `GAM.jl` is the `gam` function, which is as easy as:
The basic interface to `GeneralizedAdditiveModels.jl` is the `gam` function, which is as easy as:

```{julia}
mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df)
Expand All @@ -28,12 +28,14 @@ Note that currently, the following families are supported:
* `Normal`
* `Poisson`
* `Gamma`
* `Bernoulli`

And the following link functions:

* `Identity`
* `Log`
* `Logit`

## Development notes

`GAM.jl` is very much in active development. Please check back for updates and new features or feel free to contribute yourself! The project to-date has been a collaboration between [Trent Henderson](https://github.com/hendersontrent) and [Mason Yahr](https://github.com/yahrMason).
`GeneralizedAdditiveModels.jl` is very much in active development. Please check back for updates and new features or feel free to contribute yourself! The project to-date has been a collaboration between [Trent Henderson](https://github.com/hendersontrent) and [Mason Yahr](https://github.com/yahrMason).
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GAM = "cc454e9f-ce0f-441e-b193-468e31ddef4b"
GeneralizedAdditiveModels = "cc454e9f-ce0f-441e-b193-468e31ddef4b"
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
8 changes: 4 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ using Documenter

# Make sure the package loads when building from docs/
push!(LOAD_PATH, joinpath(@__DIR__, "..", "src"))
using GAM
using GeneralizedAdditiveModels

# Doctesting
DocMeta.setdocmeta!(GAM, :DocTestSetup, :(using GAM); recursive=true)

makedocs(
sitename = "GAM.jl",
modules = [GAM],
sitename = "GeneralizedAdditiveModels.jl",
modules = [GeneralizedAdditiveModels],
pages = [
"Home" => "index.md",
"API Reference" => "api_reference.md",
Expand All @@ -22,4 +22,4 @@ makedocs(
)

# for later
# deploydocs(repo = "github.com/hendersontrent/GAM.jl.git", devbranch = "main")
# deploydocs(repo = "github.com/hendersontrent/GeneralizedAdditiveModels.jl.git", devbranch = "main")
4 changes: 2 additions & 2 deletions docs/src/api_reference.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# API Reference

```@autodocs
Modules = [GAM]
Modules = [GeneralizedAdditiveModels]
Recursive = true
Public = true
Private = false
Order = [:module, :type, :function, :macro, :constant]
```
```
6 changes: 3 additions & 3 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ To install the package, clone it directly from the repository:

```julia
using Pkg
Pkg.add(url="https://github.com/hendersontrent/GAM.jl")
Pkg.add(url="https://github.com/hendersontrent/GeneralizedAdditiveModels.jl")
```

---
Expand All @@ -44,7 +44,7 @@ Fitting a GAM in GeneralizedAdditiveModels.jl is quick and easy. The syntax of t

```julia
using RDatasets
using GAM
using GeneralizedAdditiveModels

df = dataset("datasets", "trees");

Expand All @@ -55,7 +55,7 @@ mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df)

**GeneralizedAdditiveModels.jl** is under active development, and contributions are very welcome!

- If you’ve found a bug or want to propose a feature, please [open an issue](https://github.com/hendersontrent/GAM.jl/issues).
- If you’ve found a bug or want to propose a feature, please [open an issue](https://github.com/hendersontrent/GeneralizedAdditiveModels.jl/issues).
- If your idea gets positive feedback, feel free to submit a pull request.
- If you’re unsure where to start, you can also browse the open issues and pick one that interests you.

Expand Down
58 changes: 56 additions & 2 deletions src/FitGAM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ Fit generalised additive model.

Usage:
```julia-repl
gam(ModelFormula, Data; Family, Link, Optimizer, maxIter, tol)
# Using string formula (original syntax)
gam("Y ~ s(MPG, k=5, degree=3) + WHT", Data)

# Using @formula macro (new StatsModels.jl syntax)
using StatsModels
gam(@formula(Y ~ s(MPG, 5, 3) + WHT), Data)
```
Arguments:
- `ModelFormula` : `String` containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R, where `s()` has 3 parts: name of column, `k`` (integer denoting number of knots), and `degree` (polynomial degree of the spline). An example expression is `"Y ~ s(MPG, k=5, degree=3) + WHT + s(TRL, k=5, degree=2)"`
- `ModelFormula` : Either a `String` or `FormulaTerm` (@formula macro) containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R. For strings, use `s(var, k=N, degree=D)` syntax. For @formula macro, use `s(var, N, D)` with positional arguments. An example string expression is `"Y ~ s(MPG, k=5, degree=3) + WHT"` and an example @formula is `@formula(Y ~ s(MPG, 5, 3) + WHT)`
- `Data` : `DataFrame` containing the covariates and response variable to use.
- `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal"
- `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity"
Expand Down Expand Up @@ -38,6 +43,55 @@ function gam(ModelFormula::String, Data::DataFrame; Family="Normal", Link="Ident
# Build basis
Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs)

# Fit PIRLS procedure
gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol)
return gam
end

"""
gam(ModelFormula::FormulaTerm, Data; Family, Link, Optimizer, maxIter, tol)
Fit generalised additive model using StatsModels.jl @formula macro.

Usage:
```julia-repl
using StatsModels
f = @formula(Y ~ s(MPG, 5, 3) + WHT)
gam(f, Data; Family="Normal", Link="Identity")
```
Arguments:
- `ModelFormula` : `FormulaTerm` from StatsModels @formula macro
- `Data` : `DataFrame` containing the covariates and response variable to use.
- `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal"
- `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity"
- `Optimizer` : Algorithm to use for optimisation. Defaults to `NelderMead()`.
- `maxIter` : Maximum number of iterations for algorithm. Defaults to 25.
- `tol` : Tolerance for solver. Defaults to 1e-6.
"""
function gam(ModelFormula::FormulaTerm, Data::DataFrame; Family="Normal", Link="Identity", Optimizer = NelderMead(), maxIter = 25, tol = 1e-6)
# Delegate to the String version by parsing the FormulaTerm first
# This allows us to reuse all the existing logic
GAMForm = ParseFormula(ModelFormula)

family_name = Dist_Map[Family]
family_name = Dists[family_name]
link_name = Link_Map[Link]
link_name = Links[link_name]

# Extract response and covariates
y = Data[!, GAMForm.y]

# Validate response for Bernoulli family
if Family == "Bernoulli"
@assert all(y .∈ Ref([0, 1])) "Response must be binary (0 or 1) for Bernoulli family"
end

x = Data[!, GAMForm.covariates.variable]
BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)]
x = [x[!, col] for col in names(x)]

# Build basis
Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs)

# Fit PIRLS procedure
gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol)
return gam
Expand Down
98 changes: 98 additions & 0 deletions src/GAMFormula.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,101 @@ function ParseFormula(formula::String)
outs = GAMFormula(Symbol(lhs), df)
return outs
end

"""
ParseFormula(formula::FormulaTerm)
Parse StatsModels FormulaTerm into GAMFormula structure.

Usage:
```julia-repl
f = @formula(y ~ s(x1, k=10, degree=3) + x2)
ParseFormula(f)
```
Arguments:
- `formula` : `FormulaTerm` from StatsModels.jl @formula macro
"""
function ParseFormula(formula::FormulaTerm)
# Extract response variable
y = formula.lhs.sym

# Process right-hand side terms
rhs = formula.rhs

# Create DataFrame to hold covariate information
df = DataFrame(variable = Symbol[], k = Int[], degree = Int[], smooth = Bool[])

# Extract terms from the right-hand side
terms = extract_terms(rhs)

for term in terms
if isa(term, SmoothTerm)
# Smooth term: extract k, degree, and variable name
push!(df, (term.term.sym, term.k, term.degree, true))
elseif isa(term, Term)
# Linear term: add with default k=0, degree=0, smooth=false
push!(df, (term.sym, 0, 0, false))
elseif isa(term, InterceptTerm) || isa(term, ConstantTerm)
# Intercept term - we handle this separately, skip for now
continue
else
@warn "Unsupported term type in formula: $(typeof(term)). Skipping."
end
end

return GAMFormula(y, df)
end

"""
extract_terms(rhs)
Recursively extract individual terms from the right-hand side of a formula.

Handles different StatsModels term types including tuples, individual terms, and smooth terms.
"""
function extract_terms(rhs)
terms = []

if isa(rhs, Tuple)
# Multiple terms: recursively extract from each
for term in rhs
append!(terms, extract_terms(term))
end
elseif isa(rhs, SmoothTerm) || isa(rhs, Term)
# Single term: add directly
push!(terms, rhs)
elseif isa(rhs, InterceptTerm) || isa(rhs, ConstantTerm)
# Intercept/constant: add directly
push!(terms, rhs)
elseif isa(rhs, StatsModels.FunctionTerm)
# Handle FunctionTerm (from @formula macro)
# Check if it's our s() function
if rhs.exorig.head == :call && rhs.exorig.args[1] == :s
# Extract arguments from the function call
# rhs.exorig.args[2] is the variable name
# rhs.exorig.args[3] is k (if present)
# rhs.exorig.args[4] is degree (if present)
var_sym = rhs.exorig.args[2]
k = length(rhs.exorig.args) >= 3 ? rhs.exorig.args[3] : 10
degree = length(rhs.exorig.args) >= 4 ? rhs.exorig.args[4] : 3

# Create a SmoothTerm
push!(terms, SmoothTerm(Term(var_sym), k, degree))
else
@warn "Unsupported function in formula: $(rhs.exorig.args[1])"
end
else
# Try to handle other StatsModels types
# For composite terms, try to extract nested terms
try
# If it has a .terms field (like CategoricalTerm, InteractionTerm, etc.)
if hasfield(typeof(rhs), :terms)
append!(terms, extract_terms(rhs.terms))
else
@warn "Unable to extract terms from type $(typeof(rhs))"
end
catch e
@warn "Error extracting terms: $e"
end
end

return terms
end
8 changes: 6 additions & 2 deletions src/GAM.jl → src/GeneralizedAdditiveModels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module GAM
module GeneralizedAdditiveModels

using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim
using StatsModels
using StatsModels: FormulaTerm, @formula, Term, ConstantTerm, InterceptTerm, AbstractTerm

include("Links-Dists.jl")
include("GAMData.jl")
Expand All @@ -15,8 +17,9 @@ include("alpha.jl")
include("PIRLS.jl")
include("Predictions.jl")
include("Plots.jl")
include("FitGAM.jl")
include("SmoothTerm.jl")
include("GAMFormula.jl")
include("FitGAM.jl")

export Links
export Dists
Expand All @@ -26,5 +29,6 @@ export GAMData
export PartialDependencePlot
export plotGAM
export gam
export @formula, s, SmoothTerm, ParseFormula

end
64 changes: 64 additions & 0 deletions src/SmoothTerm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#==
SmoothTerm

Extension of StatsModels.jl for GAM smooth terms.

This module defines custom term types for representing smooth functions in GAM formulas,
allowing syntax like: @formula(y ~ s(x1, k=10, degree=3) + x2)
==#


"""
SmoothTerm

Represents a smooth spline term in a GAM formula.

# Fields
- `term::Term`: The variable to be smoothed
- `k::Int`: Number of knots for the spline basis (default: 10)
- `degree::Int`: Polynomial degree of the spline (default: 3)
"""
struct SmoothTerm <: AbstractTerm
term::Term
k::Int
degree::Int
end

# Constructor with default values
SmoothTerm(term::Term; k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree)

# Allow creating from a Symbol
SmoothTerm(sym::Symbol; k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree)

# Pretty printing
Base.show(io::IO, st::SmoothTerm) = print(io, "s($(st.term.sym), k=$(st.k), degree=$(st.degree))")

"""
s(variable, k=10, degree=3)

Create a smooth spline term for use in GAM formulas.

# Arguments
- `variable`: The variable to be smoothed (Symbol or Term)
- `k`: Number of knots for the spline basis (default: 10)
- `degree`: Polynomial degree of the spline (default: 3)

# Examples
```julia
using GeneralizedAdditiveModels, StatsModels

# Using the @formula macro with smooth terms (positional arguments)
f = @formula(y ~ s(x1, 10, 3) + s(x2, 5, 2) + x3)

# Or define smooth terms before the formula
s1 = s(:x1, 10, 3)
s2 = s(:x2, 5, 2)
# Note: You'll need to use the string formula syntax for pre-defined terms

# Fit a GAM with the formula
model = gam(f, data)
```
"""
# Positional argument versions (for use with @formula macro)
s(term::Term, k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree)
s(sym::Symbol, k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree)
2 changes: 1 addition & 1 deletion src/alpha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Calculate alpha.

Usage:
```julia-repl
alpha((y, mu, Dist, Link)
alpha(y, mu, Dist, Link)
```
Arguments:
- `y` : `Vector` containing the response variable.
Expand Down
Loading