From 409f1cefdcedf24aab47c650421abda031bc90a2 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Mon, 24 Nov 2025 11:34:37 -0500 Subject: [PATCH 1/2] Rename GAM.jl to GeneralizedAdditiveModels.jl Renamed the package and all references from GAM.jl to GeneralizedAdditiveModels.jl across source, documentation, tests, and configuration files for clarity and consistency. Updated supported families and links in documentation to reflect the new package name. --- Project.toml | 2 +- README.md | 8 +++++--- docs/Project.toml | 2 +- docs/make.jl | 8 ++++---- docs/src/api_reference.md | 4 ++-- docs/src/index.md | 6 +++--- src/{GAM.jl => GeneralizedAdditiveModels.jl} | 2 +- src/alpha.jl | 2 +- test/runtests.jl | 6 +++--- 9 files changed, 21 insertions(+), 19 deletions(-) rename src/{GAM.jl => GeneralizedAdditiveModels.jl} (94%) diff --git a/Project.toml b/Project.toml index 05f10a1..78fdffc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,4 +1,4 @@ -name = "GAM" +name = "GeneralizedAdditiveModels" uuid = "cc454e9f-ce0f-441e-b193-468e31ddef4b" authors = ["Trent Henderson "] version = "0.1.0" diff --git a/README.md b/README.md index 24799f1..45e64c2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# GAM.jl +# GeneralizedAdditiveModels.jl Fit, evaluate, and visualise generalised additive models (GAMs) in native Julia ## Motivation @@ -7,7 +7,7 @@ Fit, evaluate, and visualise generalised additive models (GAMs) in native Julia ## Usage -The basic interface to `GAM.jl` is the `gam` function, which is as easy as: +The basic interface to `GeneralizedAdditiveModels.jl` is the `gam` function, which is as easy as: ```{julia} mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) @@ -28,12 +28,14 @@ Note that currently, the following families are supported: * `Normal` * `Poisson` * `Gamma` +* `Bernoulli` And the following link functions: * `Identity` * `Log` +* `Logit` ## Development notes -`GAM.jl` is very much in active development. Please check back for updates and new features or feel free to contribute yourself! The project to-date has been a collaboration between [Trent Henderson](https://github.com/hendersontrent) and [Mason Yahr](https://github.com/yahrMason). +`GeneralizedAdditiveModels.jl` is very much in active development. Please check back for updates and new features or feel free to contribute yourself! The project to-date has been a collaboration between [Trent Henderson](https://github.com/hendersontrent) and [Mason Yahr](https://github.com/yahrMason). diff --git a/docs/Project.toml b/docs/Project.toml index 7ea483c..9253d65 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,5 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -GAM = "cc454e9f-ce0f-441e-b193-468e31ddef4b" +GeneralizedAdditiveModels = "cc454e9f-ce0f-441e-b193-468e31ddef4b" LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" diff --git a/docs/make.jl b/docs/make.jl index 67be5df..2d8bf1e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,14 +3,14 @@ using Documenter # Make sure the package loads when building from docs/ push!(LOAD_PATH, joinpath(@__DIR__, "..", "src")) -using GAM +using GeneralizedAdditiveModels # Doctesting DocMeta.setdocmeta!(GAM, :DocTestSetup, :(using GAM); recursive=true) makedocs( - sitename = "GAM.jl", - modules = [GAM], + sitename = "GeneralizedAdditiveModels.jl", + modules = [GeneralizedAdditiveModels], pages = [ "Home" => "index.md", "API Reference" => "api_reference.md", @@ -22,4 +22,4 @@ makedocs( ) # for later -# deploydocs(repo = "github.com/hendersontrent/GAM.jl.git", devbranch = "main") +# deploydocs(repo = "github.com/hendersontrent/GeneralizedAdditiveModels.jl.git", devbranch = "main") diff --git a/docs/src/api_reference.md b/docs/src/api_reference.md index f7d1277..5dc7207 100644 --- a/docs/src/api_reference.md +++ b/docs/src/api_reference.md @@ -1,9 +1,9 @@ # API Reference ```@autodocs -Modules = [GAM] +Modules = [GeneralizedAdditiveModels] Recursive = true Public = true Private = false Order = [:module, :type, :function, :macro, :constant] -``` \ No newline at end of file +``` diff --git a/docs/src/index.md b/docs/src/index.md index 00b312d..8a06501 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -33,7 +33,7 @@ To install the package, clone it directly from the repository: ```julia using Pkg -Pkg.add(url="https://github.com/hendersontrent/GAM.jl") +Pkg.add(url="https://github.com/hendersontrent/GeneralizedAdditiveModels.jl") ``` --- @@ -44,7 +44,7 @@ Fitting a GAM in GeneralizedAdditiveModels.jl is quick and easy. The syntax of t ```julia using RDatasets -using GAM +using GeneralizedAdditiveModels df = dataset("datasets", "trees"); @@ -55,7 +55,7 @@ mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) **GeneralizedAdditiveModels.jl** is under active development, and contributions are very welcome! -- If you’ve found a bug or want to propose a feature, please [open an issue](https://github.com/hendersontrent/GAM.jl/issues). +- If you’ve found a bug or want to propose a feature, please [open an issue](https://github.com/hendersontrent/GeneralizedAdditiveModels.jl/issues). - If your idea gets positive feedback, feel free to submit a pull request. - If you’re unsure where to start, you can also browse the open issues and pick one that interests you. diff --git a/src/GAM.jl b/src/GeneralizedAdditiveModels.jl similarity index 94% rename from src/GAM.jl rename to src/GeneralizedAdditiveModels.jl index 47eb1ff..726aa2b 100644 --- a/src/GAM.jl +++ b/src/GeneralizedAdditiveModels.jl @@ -1,4 +1,4 @@ -module GAM +module GeneralizedAdditiveModels using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim diff --git a/src/alpha.jl b/src/alpha.jl index e4b6f3f..c2af46b 100644 --- a/src/alpha.jl +++ b/src/alpha.jl @@ -4,7 +4,7 @@ Calculate alpha. Usage: ```julia-repl -alpha((y, mu, Dist, Link) +alpha(y, mu, Dist, Link) ``` Arguments: - `y` : `Vector` containing the response variable. diff --git a/test/runtests.jl b/test/runtests.jl index d430a90..644f0c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using GAM +using GeneralizedAdditiveModels using Test using RDatasets, Plots using Distributions @@ -9,7 +9,7 @@ df = dataset("datasets", "trees"); #-------------------- Run tests ----------------- -@testset "GAM.jl" begin +@testset "GeneralizedAdditiveModels.jl" begin mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) @@ -168,7 +168,7 @@ end # Check that predictions are reasonable x_test = [-1.0, 0.0, 1.0] for xi in x_test - pred_mat = GAM.BuildPredictionMatrix([xi], mod.Basis[1], mod.ColMeans[1]) + pred_mat = GeneralizedAdditiveModels.BuildPredictionMatrix([xi], mod.Basis[1], mod.ColMeans[1]) pred_eta = mod.Coef[1] .+ pred_mat * mod.Coef[mod.CoefIndex[1]] pred_p = 1 / (1 + exp(-pred_eta[1])) From b32695ef53af4905b0a2110f5f306a41fc160567 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Fri, 28 Nov 2025 11:12:34 -0500 Subject: [PATCH 2/2] Add StatsModels.jl formula macro support for GAMs Introduces support for StatsModels.jl @formula macro and FormulaTerm parsing for generalized additive models (GAMs). Adds SmoothTerm type and s() constructor for smooth spline terms, updates FitGAM and GAMFormula to handle FormulaTerm input, and extends tests for formula macro usage and smooth term parsing. Project.toml updated to include StatsModels dependency. --- Project.toml | 1 + src/FitGAM.jl | 58 ++++++++++++++- src/GAMFormula.jl | 98 +++++++++++++++++++++++++ src/GeneralizedAdditiveModels.jl | 6 +- src/SmoothTerm.jl | 64 +++++++++++++++++ test/runtests.jl | 120 +++++++++++++++++++++++++++++++ 6 files changed, 344 insertions(+), 3 deletions(-) create mode 100644 src/SmoothTerm.jl diff --git a/Project.toml b/Project.toml index 78fdffc..c8a972b 100644 --- a/Project.toml +++ b/Project.toml @@ -13,3 +13,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" diff --git a/src/FitGAM.jl b/src/FitGAM.jl index f2e069f..b5d1285 100644 --- a/src/FitGAM.jl +++ b/src/FitGAM.jl @@ -4,10 +4,15 @@ Fit generalised additive model. Usage: ```julia-repl -gam(ModelFormula, Data; Family, Link, Optimizer, maxIter, tol) +# Using string formula (original syntax) +gam("Y ~ s(MPG, k=5, degree=3) + WHT", Data) + +# Using @formula macro (new StatsModels.jl syntax) +using StatsModels +gam(@formula(Y ~ s(MPG, 5, 3) + WHT), Data) ``` Arguments: -- `ModelFormula` : `String` containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R, where `s()` has 3 parts: name of column, `k`` (integer denoting number of knots), and `degree` (polynomial degree of the spline). An example expression is `"Y ~ s(MPG, k=5, degree=3) + WHT + s(TRL, k=5, degree=2)"` +- `ModelFormula` : Either a `String` or `FormulaTerm` (@formula macro) containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R. For strings, use `s(var, k=N, degree=D)` syntax. For @formula macro, use `s(var, N, D)` with positional arguments. An example string expression is `"Y ~ s(MPG, k=5, degree=3) + WHT"` and an example @formula is `@formula(Y ~ s(MPG, 5, 3) + WHT)` - `Data` : `DataFrame` containing the covariates and response variable to use. - `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal" - `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity" @@ -38,6 +43,55 @@ function gam(ModelFormula::String, Data::DataFrame; Family="Normal", Link="Ident # Build basis Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs) + # Fit PIRLS procedure + gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol) + return gam +end + +""" + gam(ModelFormula::FormulaTerm, Data; Family, Link, Optimizer, maxIter, tol) +Fit generalised additive model using StatsModels.jl @formula macro. + +Usage: +```julia-repl +using StatsModels +f = @formula(Y ~ s(MPG, 5, 3) + WHT) +gam(f, Data; Family="Normal", Link="Identity") +``` +Arguments: +- `ModelFormula` : `FormulaTerm` from StatsModels @formula macro +- `Data` : `DataFrame` containing the covariates and response variable to use. +- `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal" +- `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity" +- `Optimizer` : Algorithm to use for optimisation. Defaults to `NelderMead()`. +- `maxIter` : Maximum number of iterations for algorithm. Defaults to 25. +- `tol` : Tolerance for solver. Defaults to 1e-6. +""" +function gam(ModelFormula::FormulaTerm, Data::DataFrame; Family="Normal", Link="Identity", Optimizer = NelderMead(), maxIter = 25, tol = 1e-6) + # Delegate to the String version by parsing the FormulaTerm first + # This allows us to reuse all the existing logic + GAMForm = ParseFormula(ModelFormula) + + family_name = Dist_Map[Family] + family_name = Dists[family_name] + link_name = Link_Map[Link] + link_name = Links[link_name] + + # Extract response and covariates + y = Data[!, GAMForm.y] + + # Validate response for Bernoulli family + if Family == "Bernoulli" + @assert all(y .∈ Ref([0, 1])) "Response must be binary (0 or 1) for Bernoulli family" + end + + x = Data[!, GAMForm.covariates.variable] + BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)] + x = [x[!, col] for col in names(x)] + + # Build basis + Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs) + # Fit PIRLS procedure gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol) return gam diff --git a/src/GAMFormula.jl b/src/GAMFormula.jl index d095e1c..d3b2d11 100644 --- a/src/GAMFormula.jl +++ b/src/GAMFormula.jl @@ -74,3 +74,101 @@ function ParseFormula(formula::String) outs = GAMFormula(Symbol(lhs), df) return outs end + +""" + ParseFormula(formula::FormulaTerm) +Parse StatsModels FormulaTerm into GAMFormula structure. + +Usage: +```julia-repl +f = @formula(y ~ s(x1, k=10, degree=3) + x2) +ParseFormula(f) +``` +Arguments: +- `formula` : `FormulaTerm` from StatsModels.jl @formula macro +""" +function ParseFormula(formula::FormulaTerm) + # Extract response variable + y = formula.lhs.sym + + # Process right-hand side terms + rhs = formula.rhs + + # Create DataFrame to hold covariate information + df = DataFrame(variable = Symbol[], k = Int[], degree = Int[], smooth = Bool[]) + + # Extract terms from the right-hand side + terms = extract_terms(rhs) + + for term in terms + if isa(term, SmoothTerm) + # Smooth term: extract k, degree, and variable name + push!(df, (term.term.sym, term.k, term.degree, true)) + elseif isa(term, Term) + # Linear term: add with default k=0, degree=0, smooth=false + push!(df, (term.sym, 0, 0, false)) + elseif isa(term, InterceptTerm) || isa(term, ConstantTerm) + # Intercept term - we handle this separately, skip for now + continue + else + @warn "Unsupported term type in formula: $(typeof(term)). Skipping." + end + end + + return GAMFormula(y, df) +end + +""" + extract_terms(rhs) +Recursively extract individual terms from the right-hand side of a formula. + +Handles different StatsModels term types including tuples, individual terms, and smooth terms. +""" +function extract_terms(rhs) + terms = [] + + if isa(rhs, Tuple) + # Multiple terms: recursively extract from each + for term in rhs + append!(terms, extract_terms(term)) + end + elseif isa(rhs, SmoothTerm) || isa(rhs, Term) + # Single term: add directly + push!(terms, rhs) + elseif isa(rhs, InterceptTerm) || isa(rhs, ConstantTerm) + # Intercept/constant: add directly + push!(terms, rhs) + elseif isa(rhs, StatsModels.FunctionTerm) + # Handle FunctionTerm (from @formula macro) + # Check if it's our s() function + if rhs.exorig.head == :call && rhs.exorig.args[1] == :s + # Extract arguments from the function call + # rhs.exorig.args[2] is the variable name + # rhs.exorig.args[3] is k (if present) + # rhs.exorig.args[4] is degree (if present) + var_sym = rhs.exorig.args[2] + k = length(rhs.exorig.args) >= 3 ? rhs.exorig.args[3] : 10 + degree = length(rhs.exorig.args) >= 4 ? rhs.exorig.args[4] : 3 + + # Create a SmoothTerm + push!(terms, SmoothTerm(Term(var_sym), k, degree)) + else + @warn "Unsupported function in formula: $(rhs.exorig.args[1])" + end + else + # Try to handle other StatsModels types + # For composite terms, try to extract nested terms + try + # If it has a .terms field (like CategoricalTerm, InteractionTerm, etc.) + if hasfield(typeof(rhs), :terms) + append!(terms, extract_terms(rhs.terms)) + else + @warn "Unable to extract terms from type $(typeof(rhs))" + end + catch e + @warn "Error extracting terms: $e" + end + end + + return terms +end diff --git a/src/GeneralizedAdditiveModels.jl b/src/GeneralizedAdditiveModels.jl index 726aa2b..3b0ad78 100644 --- a/src/GeneralizedAdditiveModels.jl +++ b/src/GeneralizedAdditiveModels.jl @@ -1,6 +1,8 @@ module GeneralizedAdditiveModels using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim +using StatsModels +using StatsModels: FormulaTerm, @formula, Term, ConstantTerm, InterceptTerm, AbstractTerm include("Links-Dists.jl") include("GAMData.jl") @@ -15,8 +17,9 @@ include("alpha.jl") include("PIRLS.jl") include("Predictions.jl") include("Plots.jl") -include("FitGAM.jl") +include("SmoothTerm.jl") include("GAMFormula.jl") +include("FitGAM.jl") export Links export Dists @@ -26,5 +29,6 @@ export GAMData export PartialDependencePlot export plotGAM export gam +export @formula, s, SmoothTerm, ParseFormula end diff --git a/src/SmoothTerm.jl b/src/SmoothTerm.jl new file mode 100644 index 0000000..72ca520 --- /dev/null +++ b/src/SmoothTerm.jl @@ -0,0 +1,64 @@ +#== + SmoothTerm + +Extension of StatsModels.jl for GAM smooth terms. + +This module defines custom term types for representing smooth functions in GAM formulas, +allowing syntax like: @formula(y ~ s(x1, k=10, degree=3) + x2) +==# + + +""" + SmoothTerm + +Represents a smooth spline term in a GAM formula. + +# Fields +- `term::Term`: The variable to be smoothed +- `k::Int`: Number of knots for the spline basis (default: 10) +- `degree::Int`: Polynomial degree of the spline (default: 3) +""" +struct SmoothTerm <: AbstractTerm + term::Term + k::Int + degree::Int +end + +# Constructor with default values +SmoothTerm(term::Term; k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree) + +# Allow creating from a Symbol +SmoothTerm(sym::Symbol; k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree) + +# Pretty printing +Base.show(io::IO, st::SmoothTerm) = print(io, "s($(st.term.sym), k=$(st.k), degree=$(st.degree))") + +""" + s(variable, k=10, degree=3) + +Create a smooth spline term for use in GAM formulas. + +# Arguments +- `variable`: The variable to be smoothed (Symbol or Term) +- `k`: Number of knots for the spline basis (default: 10) +- `degree`: Polynomial degree of the spline (default: 3) + +# Examples +```julia +using GeneralizedAdditiveModels, StatsModels + +# Using the @formula macro with smooth terms (positional arguments) +f = @formula(y ~ s(x1, 10, 3) + s(x2, 5, 2) + x3) + +# Or define smooth terms before the formula +s1 = s(:x1, 10, 3) +s2 = s(:x2, 5, 2) +# Note: You'll need to use the string formula syntax for pre-defined terms + +# Fit a GAM with the formula +model = gam(f, data) +``` +""" +# Positional argument versions (for use with @formula macro) +s(term::Term, k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree) +s(sym::Symbol, k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree) diff --git a/test/runtests.jl b/test/runtests.jl index 644f0c1..e1f98e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -178,4 +178,124 @@ end @test abs(pred_p - true_p) < 0.4 end end +end + +@testset "Formula Macro Tests" begin + @testset "SmoothTerm construction" begin + # Test creating smooth terms with positional arguments + st1 = s(:x1, 10, 3) + @test st1 isa SmoothTerm + @test st1.term.sym == :x1 + @test st1.k == 10 + @test st1.degree == 3 + + # Test default values + st2 = s(:x2) + @test st2.k == 10 # default + @test st2.degree == 3 # default + end + + @testset "Formula parsing from FormulaTerm" begin + # Test parsing a simple formula with smooth terms + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 5, 2)) + + gam_formula = ParseFormula(f) + @test gam_formula.y == :Volume + @test nrow(gam_formula.covariates) == 2 + @test gam_formula.covariates.variable[1] == :Girth + @test gam_formula.covariates.k[1] == 10 + @test gam_formula.covariates.degree[1] == 3 + @test gam_formula.covariates.smooth[1] == true + + @test gam_formula.covariates.variable[2] == :Height + @test gam_formula.covariates.k[2] == 5 + @test gam_formula.covariates.degree[2] == 2 + @test gam_formula.covariates.smooth[2] == true + end + + @testset "Formula with mixed smooth and linear terms" begin + # Test formula with both smooth and linear terms + f = @formula(Volume ~ s(Girth, 10, 3) + Height) + + gam_formula = ParseFormula(f) + @test gam_formula.y == :Volume + @test nrow(gam_formula.covariates) == 2 + + # First term is smooth + @test gam_formula.covariates.variable[1] == :Girth + @test gam_formula.covariates.smooth[1] == true + + # Second term is linear + @test gam_formula.covariates.variable[2] == :Height + @test gam_formula.covariates.smooth[2] == false + @test gam_formula.covariates.k[2] == 0 + @test gam_formula.covariates.degree[2] == 0 + end + + @testset "GAM fitting with @formula macro" begin + # Test fitting a GAM using the @formula macro + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3)) + + mod = gam(f, df) + @test mod isa GAMData + @test length(mod.Fitted) == nrow(df) + + # Compare with string formula version + mod_string = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) + + # Results should be very similar (allowing for numerical precision) + @test isapprox(mod.Fitted, mod_string.Fitted, rtol=1e-6) + end + + @testset "GAM with @formula and different families" begin + # Test with Gamma family + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3)) + + mod_gamma = gam(f, df; Family="Gamma", Link="Log") + @test mod_gamma isa GAMData + @test mod_gamma.Family[:Name] == "Gamma" + @test mod_gamma.Link[:Name] == "Log" + + # Compare with string formula version + mod_gamma_string = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; + Family="Gamma", Link="Log") + @test isapprox(mod_gamma.Fitted, mod_gamma_string.Fitted, rtol=1e-6) + end + + @testset "Bernoulli GAM with @formula" begin + # Create binary data + n = 200 + x1 = range(-2, 2, length=n) + x2 = randn(n) + + # Create true nonlinear effect + f1 = sin.(x1 * π/2) + f2 = x2.^2 .- 1 + eta = f1 + f2 + p = 1 ./ (1 .+ exp.(-eta)) + y = rand.(Bernoulli.(p)) + + df_bern = DataFrame(y=y, x1=x1, x2=x2) + + # Fit using @formula + f = @formula(y ~ s(x1, 8, 3) + s(x2, 8, 3)) + mod = gam(f, df_bern; Family="Bernoulli", Link="Logit") + + @test mod isa GAMData + @test mod.Family[:Name] == "Bernoulli" + @test all(0 .<= mod.Fitted .<= 1) + + # Compare with string version + mod_string = gam("y ~ s(x1, k=8, degree=3) + s(x2, k=8, degree=3)", df_bern; + Family="Bernoulli", Link="Logit") + @test isapprox(mod.Fitted, mod_string.Fitted, rtol=1e-6) + end + + @testset "Plotting GAM fitted with @formula" begin + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3)) + mod = gam(f, df) + + p = plotGAM(mod) + @test p isa Plots.Plot + end end \ No newline at end of file