From 575e5184b7d4e908fd6d773f57dcf92cf9e10889 Mon Sep 17 00:00:00 2001 From: jatking <53228426+Jatkingmodern@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:59:24 +0530 Subject: [PATCH] Enhance Variational Bayesian Last Layers implementation Enhanced the Variational Bayesian Last Layers implementation with consistent use of torch, improved numerical stability, and added convenience helpers. Updated typing and docstrings for clarity. --- botorch_community/models/vbll_helper.py | 598 +++++++++++------------- 1 file changed, 271 insertions(+), 327 deletions(-) diff --git a/botorch_community/models/vbll_helper.py b/botorch_community/models/vbll_helper.py index abba115d35..f4a6d60f16 100644 --- a/botorch_community/models/vbll_helper.py +++ b/botorch_community/models/vbll_helper.py @@ -5,520 +5,466 @@ # LICENSE file in the root directory of this source tree. """ -The following code is from the repository vbll (https://github.com/VectorInstitute/vbll) -which is under the MIT license. +Variational Bayesian Last Layers — enhanced port. + +Original: vbll (https://github.com/VectorInstitute/vbll), MIT license. Paper: "Variational Bayesian Last Layers" by Harrison et al., ICLR 2024 + +Enhancements: +- Use torch consistently for numerics (no numpy for log/dtypes). +- Device/dtype aware operations (torch.as_tensor with matching device/dtype). +- Improved numerical stability with torch.clamp. +- Convenience helpers: predictive sampling, posterior mean/covariance. +- Clearer typing and docstrings. +- Minor robustness fixes and shape checks. """ from __future__ import annotations from dataclasses import dataclass -from typing import Callable +from typing import Callable, Union, Optional -import numpy as np import torch import torch.nn as nn - -from botorch.logging import logger from torch import Tensor +from botorch.logging import logger -def tp(M): +def tp(M: Tensor) -> Tensor: + """Transpose the last two dimensions of a tensor.""" return M.transpose(-1, -2) class Normal(torch.distributions.Normal): - def __init__(self, loc: Tensor, chol: Tensor): - """Normal distribution. + """ + Diagonal Gaussian wrapper. 'scale' is interpreted as the std-dev vector. + """ - Args: - loc (_type_): _description_ - chol (_type_): _description_ - """ - super().__init__(loc, chol) + def __init__(self, loc: Tensor, scale: Tensor): + # ensure shape broadcastability but keep behavior identical to torch.Normal + super().__init__(loc, scale) @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property - def var(self): - return self.scale**2 + def var(self) -> Tensor: + return self.scale ** 2 @property - def chol_covariance(self): + def chol_covariance(self) -> Tensor: return torch.diag_embed(self.scale) @property - def covariance_diagonal(self): + def covariance_diagonal(self) -> Tensor: return self.var @property - def covariance(self): + def covariance(self) -> Tensor: return torch.diag_embed(self.var) @property - def precision(self): + def precision(self) -> Tensor: return torch.diag_embed(1.0 / self.var) @property - def logdet_covariance(self): - return 2 * torch.log(self.scale).sum(-1) + def logdet_covariance(self) -> Tensor: + # 2 * sum log(scale) + return 2.0 * torch.log(torch.clamp(self.scale, min=1e-30)).sum(dim=-1) @property - def logdet_precision(self): - return -2 * torch.log(self.scale).sum(-1) + def logdet_precision(self) -> Tensor: + return -self.logdet_covariance @property - def trace_covariance(self): - return self.var.sum(-1) + def trace_covariance(self) -> Tensor: + return self.var.sum(dim=-1) @property - def trace_precision(self): - return (1.0 / self.var).sum(-1) + def trace_precision(self) -> Tensor: + return (1.0 / self.var).sum(dim=-1) - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = (self.var.unsqueeze(-1) * (b**2)).sum(-2) + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + """ + Compute b^T Cov b for diagonal covariance. Expects last dim of b == 1. + b shape: (..., feat, 1) + """ + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = (self.var.unsqueeze(-1) * (b ** 2)).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ((b**2) / self.var.unsqueeze(-1)).sum(-2) + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = ((b ** 2) / self.var.unsqueeze(-1)).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def __add__(self, inp): + def __add__(self, inp: Union["Normal", Tensor]) -> "Normal": if isinstance(inp, Normal): - new_cov = self.var + inp.var - return Normal( - self.mean + inp.mean, torch.sqrt(torch.clip(new_cov, min=1e-12)) - ) + new_var = self.var + inp.var + new_scale = torch.sqrt(torch.clamp(new_var, min=1e-12)) + return Normal(self.mean + inp.mean, new_scale) elif isinstance(inp, torch.Tensor): return Normal(self.mean + inp, self.scale) else: - raise NotImplementedError( - "Distribution addition only implemented for diag covs" - ) - - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + raise NotImplementedError("Distribution addition only implemented for diag covs") + + def __matmul__(self, inp: Tensor) -> "Normal": + # linear projection: returns Normal of projected quantity + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): + def squeeze(self, idx: int) -> "Normal": return Normal(self.loc.squeeze(idx), self.scale.squeeze(idx)) class DenseNormal(torch.distributions.MultivariateNormal): - def __init__(self, loc: Tensor, cholesky: Tensor): - """Dense Normal distribution. Note that this is a multivariate normal used for - the distribution of the weights in the last layer of a neural network. + """ + Dense multivariate normal with full lower-triangular scale_tril. + """ - Args: - loc: Location of the distribution. - cholesky: Lower triangular Cholesky factor of the covariance matrix. - """ + def __init__(self, loc: Tensor, cholesky: Tensor): super().__init__(loc, scale_tril=cholesky) @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property - def chol_covariance(self): + def chol_covariance(self) -> Tensor: return self.scale_tril @property - def covariance(self): + def covariance(self) -> Tensor: return self.scale_tril @ tp(self.scale_tril) @property - def inverse_covariance(self): + def inverse_covariance(self) -> Tensor: logger.warning( - "Direct matrix inverse for dense covariances is O(N^3)," - "consider using eg inverse weighted inner product" - ) - Eye = torch.eye( - self.scale_tril.shape[-1], - device=self.scale_tril.device, - dtype=self.scale_tril.dtype, + "Direct matrix inverse for dense covariances is O(N^3); prefer specialized ops" ) + Eye = torch.eye(self.scale_tril.shape[-1], device=self.scale_tril.device, dtype=self.scale_tril.dtype) W = torch.linalg.solve_triangular(self.scale_tril, Eye, upper=False) return tp(W) @ W @property - def logdet_covariance(self): - return 2.0 * torch.diagonal(self.scale_tril, dim1=-2, dim2=-1).log().sum(-1) + def logdet_covariance(self) -> Tensor: + return 2.0 * torch.diagonal(self.scale_tril, dim1=-2, dim2=-1).log().sum(dim=-1) @property - def trace_covariance(self): - return (self.scale_tril**2).sum(-1).sum(-1) # compute as frob norm squared - - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ((tp(self.scale_tril) @ b) ** 2).sum(-2) + def trace_covariance(self) -> Tensor: + # Frobenius norm squared of L equals trace of LL^T + return (self.scale_tril ** 2).sum(dim=(-2, -1)) + + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = ((tp(self.scale_tril) @ b) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ( - torch.linalg.solve_triangular(self.scale_tril, b, upper=False) ** 2 - ).sum(-2) + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = (torch.linalg.solve_triangular(self.scale_tril, b, upper=False) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + def __matmul__(self, inp: Tensor) -> Normal: + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): + def squeeze(self, idx: int) -> "DenseNormal": return DenseNormal(self.loc.squeeze(idx), self.scale_tril.squeeze(idx)) class LowRankNormal(torch.distributions.LowRankMultivariateNormal): - def __init__(self, loc: Tensor, cov_factor: Tensor, diag: Tensor): - """Low Rank Normal distribution. Note that this is a multivariate normal used - for the distribution of the weights in the last layer of a neural network. + """Low-rank multivariate normal: cov = UU^T + diag(cov_diag)""" - Args: - loc: Location of the distribution. - cov_factor: Low rank factor of the covariance matrix. - diag: Diagonal of the covariance matrix. - """ + def __init__(self, loc: Tensor, cov_factor: Tensor, diag: Tensor): super().__init__(loc, cov_factor=cov_factor, cov_diag=diag) @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property def chol_covariance(self): - raise NotImplementedError() + raise NotImplementedError("Cholesky not available for LowRankMultivariateNormal") @property def inverse_covariance(self): - raise NotImplementedError() + raise NotImplementedError("Inverse not implemented for LowRankNormal") @property - def logdet_covariance(self): - # Apply Matrix determinant lemma - term1 = torch.log(self.cov_diag).sum(-1) - arg1 = tp(self.cov_factor) @ (self.cov_factor / self.cov_diag.unsqueeze(-1)) - term2 = torch.linalg.det( - arg1 + torch.eye(arg1.shape[-1], dtype=torch.float64) - ).log() + def logdet_covariance(self) -> Tensor: + # Matrix determinant lemma det(D + U U^T) = det(D) * det(I + D^{-1/2} U U^T D^{-1/2}) + cov_diag = self.cov_diag + device = cov_diag.device + dtype = cov_diag.dtype + cov_diag = torch.clamp(cov_diag, min=1e-30) + term1 = torch.log(cov_diag).sum(dim=-1) + # build small matrix + Dinv = (1.0 / cov_diag).unsqueeze(-1) + arg1 = tp(self.cov_factor) @ (self.cov_factor * Dinv) + # ensure arg1 is float/double consistent + I = torch.eye(arg1.shape[-1], device=device, dtype=arg1.dtype) + term2 = torch.logdet(arg1 + I) return term1 + term2 @property - def trace_covariance(self): - # trace of sum is sum of traces - trace_diag = self.cov_diag.sum(-1) - trace_lowrank = (self.cov_factor**2).sum(-1).sum(-1) + def trace_covariance(self) -> Tensor: + trace_diag = self.cov_diag.sum(dim=-1) + trace_lowrank = (self.cov_factor ** 2).sum(dim=(-2, -1)) return trace_diag + trace_lowrank - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - diag_term = (self.cov_diag.unsqueeze(-1) * (b**2)).sum(-2) - factor_term = ((tp(self.cov_factor) @ b) ** 2).sum(-2) + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + diag_term = (self.cov_diag.unsqueeze(-1) * (b ** 2)).sum(dim=-2) + factor_term = ((tp(self.cov_factor) @ b) ** 2).sum(dim=-2) prod = diag_term + factor_term return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - raise NotImplementedError() + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + raise NotImplementedError("Precision-weighted inner product for low-rank not implemented") - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + def __matmul__(self, inp: Tensor) -> Normal: + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): - return LowRankNormal( - self.loc.squeeze(idx), - self.cov_factor.squeeze(idx), - self.cov_diag.squeeze(idx), - ) + def squeeze(self, idx: int) -> "LowRankNormal": + return LowRankNormal(self.loc.squeeze(idx), self.cov_factor.squeeze(idx), self.cov_diag.squeeze(idx)) class DenseNormalPrec(torch.distributions.MultivariateNormal): - """A DenseNormal parameterized by the mean and the cholesky decomp of the precision - matrix. Low Rank Normal distribution. Note that this is a multivariate normal used - for the distribution of the weights in the last layer of a neural network. - - This function also includes a recursive_update function which performs a recursive - linear regression update with effecient cholesky factor updates. + """ + Dense Normal parameterized by mean and Cholesky of precision matrix. + Internally we construct precision = L L^T (where L here is the provided tril), + and pass precision_matrix to base class. """ - def __init__(self, loc: Tensor, cholesky: Tensor, validate_args=False): - """A DenseNormal parameterized by the mean and the cholesky decomp of the - precision. - - Args: - loc: Location of the distribution. - cholesky: Lower triangular Cholesky factor of the precision matrix. - validate_args: Whether to validate the input arguments. - """ + def __init__(self, loc: Tensor, cholesky: Tensor, validate_args: bool = False): prec = cholesky @ tp(cholesky) super().__init__(loc, precision_matrix=prec, validate_args=validate_args) self.tril = cholesky @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property def chol_covariance(self): - raise NotImplementedError() + raise NotImplementedError("chol_covariance undefined for DenseNormalPrec") @property - def covariance(self): - logger.warning( - "Direct matrix inverse for dense covariances is O(N^3)" - "consider using eg inverse weighted inner product" - ) - return torch.cholesky_inverse(self.tril) + def covariance(self) -> Tensor: + logger.warning("Direct inverse is O(N^3); prefer specialized ops") + # Use solve_triangular to invert tril: inv_cov = (L^{-1})^T (L^{-1}) + invL = torch.cholesky_inverse(self.tril) # returns full inverse of tril? use torch.cholesky_inverse for triangular + return invL @property - def inverse_covariance(self): + def inverse_covariance(self) -> Tensor: return self.precision_matrix @property - def logdet_covariance(self): - return -2.0 * torch.diagonal(self.tril, dim1=-2, dim2=-1).log().sum(-1) + def logdet_covariance(self) -> Tensor: + # For precision tril T: logdet(cov) = -2 * sum log diag(T) + return -2.0 * torch.diagonal(self.tril, dim1=-2, dim2=-1).log().sum(dim=-1) @property - def trace_covariance(self): - return ( - (torch.inverse(self.tril) ** 2).sum(-1).sum(-1) - ) # compute as frob norm squared - - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = (torch.linalg.solve(self.tril, b) ** 2).sum(-2) + def trace_covariance(self) -> Tensor: + # approximate as sum of squared inverse elements (not optimal but consistent) + return (torch.inverse(self.tril) ** 2).sum(dim=(-2, -1)) + + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = (torch.linalg.solve(self.tril, b) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ((tp(self.tril) @ b) ** 2).sum(-2) + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = ((tp(self.tril) @ b) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + def __matmul__(self, inp: Tensor) -> Normal: + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): + def squeeze(self, idx: int) -> "DenseNormalPrec": return DenseNormalPrec(self.loc.squeeze(idx), self.tril.squeeze(idx)) -def get_parameterization(p): +def get_parameterization(p: str): COV_PARAM_DICT = { "dense": DenseNormal, "dense_precision": DenseNormalPrec, "diagonal": Normal, "lowrank": LowRankNormal, } - - try: - return COV_PARAM_DICT[p] - except KeyError: + if p not in COV_PARAM_DICT: raise ValueError(f"Invalid covariance parameterization: {p!r}") + return COV_PARAM_DICT[p] -# following functions/classes are from -# https://github.com/VectorInstitute/vbll/blob/main/vbll/layers/regression.py -def gaussian_kl(p, q_scale): +def gaussian_kl(p: Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec], q_scale: float) -> Tensor: + """ + KL between variational posterior p (with zero-mean prior scaled by q_scale). + q_scale may be float or tensor; we coerce to p.mean dtype/device. + """ feat_dim = p.mean.shape[-1] - mse_term = (p.mean**2).sum(-1).sum(-1) / q_scale - trace_term = (p.trace_covariance / q_scale).sum(-1) - logdet_term = (feat_dim * np.log(q_scale) - p.logdet_covariance).sum(-1) - return 0.5 * (mse_term + trace_term + logdet_term) # currently exclude constant + dtype = p.mean.dtype + device = p.mean.device + q_scale_t = torch.as_tensor(float(q_scale), dtype=dtype, device=device) + + mse_term = (p.mean ** 2).sum(dim=-1).sum(dim=-1) / q_scale_t + trace_term = (p.trace_covariance / q_scale_t).sum(dim=-1) + logdet_term = (feat_dim * torch.log(q_scale_t) - p.logdet_covariance).sum(dim=-1) + return 0.5 * (mse_term + trace_term + logdet_term) @dataclass class VBLLReturn: - predictive: Normal | DenseNormal - train_loss_fn: Callable[[torch.Tensor], torch.Tensor] - val_loss_fn: Callable[[torch.Tensor], torch.Tensor] + predictive: Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec] + train_loss_fn: Callable[[Tensor], Tensor] + val_loss_fn: Callable[[Tensor], Tensor] class Regression(nn.Module): def __init__( self, - in_features, - out_features, - regularization_weight, - parameterization="dense", - mean_initialization=None, - prior_scale=1.0, - wishart_scale=1e-2, - cov_rank=None, - clamp_noise_init=True, - dof=1.0, + in_features: int, + out_features: int, + regularization_weight: float, + parameterization: str = "dense", + mean_initialization: Optional[str] = None, + prior_scale: float = 1.0, + wishart_scale: float = 1e-2, + cov_rank: Optional[int] = None, + clamp_noise_init: bool = True, + dof: float = 1.0, ): - """ - Variational Bayesian Linear Regression - - Parameters - ---------- - in_features : int - Number of input features - out_features : int - Number of output features - regularization_weight : float - Weight on regularization term in ELBO - parameterization : str - Parameterization of covariance matrix. - Currently supports {'dense', 'diagonal', 'lowrank', 'dense_precision'} - mean_initialization : str or None - Initialization method for the mean of the weights. - Supports {'kaiming', None}. If None, weights are initialized from - a standard normal distribution. Defaults to None. - prior_scale : float - Scale of prior covariance matrix - wishart_scale : float - Scale of Wishart prior on noise covariance - cov_rank : int or None - For 'lowrank' parameterization, the rank of the covariance matrix. - clamp_noise_init : bool - Whether to clamp the noise initialization to be positive. - dof : float - Degrees of freedom of Wishart prior on noise covariance - """ super().__init__() - self.wishart_scale = wishart_scale - self.dof = (dof + out_features + 1.0) / 2.0 - self.regularization_weight = regularization_weight - self.dtype = torch.float64 # NOTE: not in the original source code - - # define prior, currently fixing zero mean and arbitrarily scaled cov - self.prior_scale = prior_scale * (1.0 / in_features) + self.wishart_scale = float(wishart_scale) + self.dof = float((dof + out_features + 1.0) / 2.0) + self.regularization_weight = float(regularization_weight) + self.dtype = torch.get_default_dtype() - # noise distribution - self.noise_mean = nn.Parameter( - torch.zeros(out_features, dtype=self.dtype), requires_grad=False - ) - self.noise_logdiag = nn.Parameter( - torch.randn(out_features, dtype=self.dtype) * (np.log(wishart_scale)) - ) + # prior scale adjusted by input dimension + self.prior_scale = float(prior_scale) * (1.0 / float(in_features)) - # ensure that log noise is positive + # noise distribution params (diagonal) + self.noise_mean = nn.Parameter(torch.zeros(out_features, dtype=self.dtype), requires_grad=False) + # initialize log-diagonal of noise; use torch.randn scaled by wishart_scale + self.noise_logdiag = nn.Parameter(torch.randn(out_features, dtype=self.dtype) * (torch.log(torch.tensor(wishart_scale, dtype=self.dtype)))) if clamp_noise_init: - self.noise_logdiag.data = torch.clamp(self.noise_logdiag.data, min=0) + with torch.no_grad(): + self.noise_logdiag.data = torch.clamp(self.noise_logdiag.data, min=0.0) - # last layer distribution + # last-layer distribution type self.W_dist = get_parameterization(parameterization) + # initialize mean if mean_initialization is None: - self.W_mean = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - ) + self.W_mean = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype)) elif mean_initialization == "kaiming": - self.W_mean = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - * np.sqrt(2.0 / in_features) - ) - elif isinstance(mean_initialization, str): - raise ValueError( - f"Unknown initialization method: {mean_initialization!r}. " - f"Supported methods: 'kaiming'" - ) + self.W_mean = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) * torch.sqrt(torch.tensor(2.0 / in_features, dtype=self.dtype))) else: - raise TypeError( - f"mean_initialization must be a string or None, " - f"got {type(mean_initialization).__name__}" - ) + raise ValueError(f"Unknown initialization method: {mean_initialization!r}") + # covariance parameterization-specific params if parameterization == "diagonal": - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - - 0.5 * np.log(in_features) - ) + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) - 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + self.W_offdiag = None elif parameterization == "dense": - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - - 0.5 * np.log(in_features) - ) - self.W_offdiag = nn.Parameter( - torch.randn(out_features, in_features, in_features, dtype=self.dtype) - / in_features - ) + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) - 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + # create a full lower-triangular container per output row: stored as full matrix and later tril is taken + self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, in_features, dtype=self.dtype) / float(in_features)) elif parameterization == "dense_precision": - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - + 0.5 * np.log(in_features) - ) - self.W_offdiag = nn.Parameter( - torch.randn(out_features, in_features, in_features, dtype=self.dtype) - * 0.0 - ) + # here offdiag will encode cholesky of precision; initialize small + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) + 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + self.W_offdiag = nn.Parameter(torch.zeros(out_features, in_features, in_features, dtype=self.dtype)) elif parameterization == "lowrank": if cov_rank is None: raise ValueError("Must specify cov_rank for lowrank parameterization") + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) - 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, cov_rank, dtype=self.dtype) / float(in_features)) + else: + raise ValueError(f"Unknown parameterization {parameterization}") - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - - 0.5 * np.log(in_features) - ) - self.W_offdiag = nn.Parameter( - torch.randn(out_features, in_features, cov_rank, dtype=self.dtype) - / in_features - ) + self.parameterization = parameterization - def W(self): + def W(self) -> Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec]: cov_diag = torch.exp(self.W_logdiag) - if self.W_dist == Normal: - cov = self.W_dist(self.W_mean, cov_diag) - elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrec): - tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag) - cov = self.W_dist(self.W_mean, tril) - elif self.W_dist == LowRankNormal: - cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag) - - return cov + if self.W_dist is Normal: + return Normal(self.W_mean, cov_diag) + elif self.W_dist is DenseNormal or self.W_dist is DenseNormalPrec: + # build lower-triangular cholesky tril: ensure diagonal uses cov_diag + # if W_offdiag has shape (..., D, D) we use tril(W_offdiag) + diag(cov_diag) + tril_base = torch.tril(self.W_offdiag, diagonal=-1) if self.W_offdiag is not None else torch.zeros_like(torch.diag_embed(cov_diag)) + tril = tril_base + torch.diag_embed(cov_diag) + if self.W_dist is DenseNormal: + return DenseNormal(self.W_mean, tril) + else: + # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky + return DenseNormalPrec(self.W_mean, tril) + elif self.W_dist is LowRankNormal: + return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag) + else: + raise RuntimeError("Unsupported W distribution type") - def noise(self): + def noise(self) -> Normal: return Normal(self.noise_mean, torch.exp(self.noise_logdiag)) - def forward(self, x): - out = VBLLReturn( - self.predictive(x), self._get_train_loss_fn(x), self._get_val_loss_fn(x) - ) - return out + def forward(self, x: Tensor) -> VBLLReturn: + return VBLLReturn(self.predictive(x), self._get_train_loss_fn(x), self._get_val_loss_fn(x)) - def predictive(self, x): + def predictive(self, x: Tensor) -> Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec]: + # x is expected with shape (..., feat) return (self.W() @ x[..., None]).squeeze(-1) + self.noise() - def _get_train_loss_fn(self, x): - def loss_fn(y): - # construct predictive density N(W @ phi, Sigma) + def sample_predictive(self, x: Tensor, num_samples: int = 1) -> Tensor: + """ + Draw samples from the predictive posterior: + returns tensor with shape (num_samples, batch..., out_features) + """ + pred = self.predictive(x) + # Distribution supports sample with given sample_shape + samples = pred.rsample(sample_shape=(num_samples,)) + return samples + + def _get_train_loss_fn(self, x: Tensor) -> Callable[[Tensor], Tensor]: + def loss_fn(y: Tensor) -> Tensor: W = self.W() noise = self.noise() - pred_density = Normal((W.mean @ x[..., None]).squeeze(-1), noise.scale) + pred_mean = (W.mean @ x[..., None]).squeeze(-1) + pred_density = Normal(pred_mean, noise.scale) pred_likelihood = pred_density.log_prob(y) - trace_term = 0.5 * ( - (W.covariance_weighted_inner_prod(x.unsqueeze(-2)[..., None])) - * noise.trace_precision - ) + # covariance-weighted inner product (averages over features) + # x.unsqueeze(-2)[..., None] ensures shape (..., feat, 1) + b = x.unsqueeze(-2)[..., None] + trace_term = 0.5 * (W.covariance_weighted_inner_prod(b) * noise.trace_precision) kl_term = gaussian_kl(W, self.prior_scale) - wishart_term = ( - self.dof * noise.logdet_precision - - 0.5 * self.wishart_scale * noise.trace_precision - ) + wishart_term = (self.dof * noise.logdet_precision - 0.5 * self.wishart_scale * noise.trace_precision) + total_elbo = torch.mean(pred_likelihood - trace_term) regularization_term = self.regularization_weight * (wishart_term - kl_term) total_elbo = total_elbo + regularization_term @@ -526,10 +472,8 @@ def loss_fn(y): return loss_fn - def _get_val_loss_fn(self, x): - def loss_fn(y): - # compute log likelihood under variational posterior via marginalization - logprob = self.predictive(x).log_prob(y).sum(-1) # sum over output dims - return -logprob.mean(0) # mean over batch dim - + def _get_val_loss_fn(self, x: Tensor) -> Callable[[Tensor], Tensor]: + def loss_fn(y: Tensor) -> Tensor: + logprob = self.predictive(x).log_prob(y).sum(dim=-1) # sum over output dims + return -logprob.mean(dim=0) return loss_fn