diff --git a/botorch_community/models/vbll_helper.py b/botorch_community/models/vbll_helper.py index abba115d35..f4a6d60f16 100644 --- a/botorch_community/models/vbll_helper.py +++ b/botorch_community/models/vbll_helper.py @@ -5,520 +5,466 @@ # LICENSE file in the root directory of this source tree. """ -The following code is from the repository vbll (https://github.com/VectorInstitute/vbll) -which is under the MIT license. +Variational Bayesian Last Layers — enhanced port. + +Original: vbll (https://github.com/VectorInstitute/vbll), MIT license. Paper: "Variational Bayesian Last Layers" by Harrison et al., ICLR 2024 + +Enhancements: +- Use torch consistently for numerics (no numpy for log/dtypes). +- Device/dtype aware operations (torch.as_tensor with matching device/dtype). +- Improved numerical stability with torch.clamp. +- Convenience helpers: predictive sampling, posterior mean/covariance. +- Clearer typing and docstrings. +- Minor robustness fixes and shape checks. """ from __future__ import annotations from dataclasses import dataclass -from typing import Callable +from typing import Callable, Union, Optional -import numpy as np import torch import torch.nn as nn - -from botorch.logging import logger from torch import Tensor +from botorch.logging import logger -def tp(M): +def tp(M: Tensor) -> Tensor: + """Transpose the last two dimensions of a tensor.""" return M.transpose(-1, -2) class Normal(torch.distributions.Normal): - def __init__(self, loc: Tensor, chol: Tensor): - """Normal distribution. + """ + Diagonal Gaussian wrapper. 'scale' is interpreted as the std-dev vector. + """ - Args: - loc (_type_): _description_ - chol (_type_): _description_ - """ - super().__init__(loc, chol) + def __init__(self, loc: Tensor, scale: Tensor): + # ensure shape broadcastability but keep behavior identical to torch.Normal + super().__init__(loc, scale) @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property - def var(self): - return self.scale**2 + def var(self) -> Tensor: + return self.scale ** 2 @property - def chol_covariance(self): + def chol_covariance(self) -> Tensor: return torch.diag_embed(self.scale) @property - def covariance_diagonal(self): + def covariance_diagonal(self) -> Tensor: return self.var @property - def covariance(self): + def covariance(self) -> Tensor: return torch.diag_embed(self.var) @property - def precision(self): + def precision(self) -> Tensor: return torch.diag_embed(1.0 / self.var) @property - def logdet_covariance(self): - return 2 * torch.log(self.scale).sum(-1) + def logdet_covariance(self) -> Tensor: + # 2 * sum log(scale) + return 2.0 * torch.log(torch.clamp(self.scale, min=1e-30)).sum(dim=-1) @property - def logdet_precision(self): - return -2 * torch.log(self.scale).sum(-1) + def logdet_precision(self) -> Tensor: + return -self.logdet_covariance @property - def trace_covariance(self): - return self.var.sum(-1) + def trace_covariance(self) -> Tensor: + return self.var.sum(dim=-1) @property - def trace_precision(self): - return (1.0 / self.var).sum(-1) + def trace_precision(self) -> Tensor: + return (1.0 / self.var).sum(dim=-1) - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = (self.var.unsqueeze(-1) * (b**2)).sum(-2) + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + """ + Compute b^T Cov b for diagonal covariance. Expects last dim of b == 1. + b shape: (..., feat, 1) + """ + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = (self.var.unsqueeze(-1) * (b ** 2)).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ((b**2) / self.var.unsqueeze(-1)).sum(-2) + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = ((b ** 2) / self.var.unsqueeze(-1)).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def __add__(self, inp): + def __add__(self, inp: Union["Normal", Tensor]) -> "Normal": if isinstance(inp, Normal): - new_cov = self.var + inp.var - return Normal( - self.mean + inp.mean, torch.sqrt(torch.clip(new_cov, min=1e-12)) - ) + new_var = self.var + inp.var + new_scale = torch.sqrt(torch.clamp(new_var, min=1e-12)) + return Normal(self.mean + inp.mean, new_scale) elif isinstance(inp, torch.Tensor): return Normal(self.mean + inp, self.scale) else: - raise NotImplementedError( - "Distribution addition only implemented for diag covs" - ) - - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + raise NotImplementedError("Distribution addition only implemented for diag covs") + + def __matmul__(self, inp: Tensor) -> "Normal": + # linear projection: returns Normal of projected quantity + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): + def squeeze(self, idx: int) -> "Normal": return Normal(self.loc.squeeze(idx), self.scale.squeeze(idx)) class DenseNormal(torch.distributions.MultivariateNormal): - def __init__(self, loc: Tensor, cholesky: Tensor): - """Dense Normal distribution. Note that this is a multivariate normal used for - the distribution of the weights in the last layer of a neural network. + """ + Dense multivariate normal with full lower-triangular scale_tril. + """ - Args: - loc: Location of the distribution. - cholesky: Lower triangular Cholesky factor of the covariance matrix. - """ + def __init__(self, loc: Tensor, cholesky: Tensor): super().__init__(loc, scale_tril=cholesky) @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property - def chol_covariance(self): + def chol_covariance(self) -> Tensor: return self.scale_tril @property - def covariance(self): + def covariance(self) -> Tensor: return self.scale_tril @ tp(self.scale_tril) @property - def inverse_covariance(self): + def inverse_covariance(self) -> Tensor: logger.warning( - "Direct matrix inverse for dense covariances is O(N^3)," - "consider using eg inverse weighted inner product" - ) - Eye = torch.eye( - self.scale_tril.shape[-1], - device=self.scale_tril.device, - dtype=self.scale_tril.dtype, + "Direct matrix inverse for dense covariances is O(N^3); prefer specialized ops" ) + Eye = torch.eye(self.scale_tril.shape[-1], device=self.scale_tril.device, dtype=self.scale_tril.dtype) W = torch.linalg.solve_triangular(self.scale_tril, Eye, upper=False) return tp(W) @ W @property - def logdet_covariance(self): - return 2.0 * torch.diagonal(self.scale_tril, dim1=-2, dim2=-1).log().sum(-1) + def logdet_covariance(self) -> Tensor: + return 2.0 * torch.diagonal(self.scale_tril, dim1=-2, dim2=-1).log().sum(dim=-1) @property - def trace_covariance(self): - return (self.scale_tril**2).sum(-1).sum(-1) # compute as frob norm squared - - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ((tp(self.scale_tril) @ b) ** 2).sum(-2) + def trace_covariance(self) -> Tensor: + # Frobenius norm squared of L equals trace of LL^T + return (self.scale_tril ** 2).sum(dim=(-2, -1)) + + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = ((tp(self.scale_tril) @ b) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ( - torch.linalg.solve_triangular(self.scale_tril, b, upper=False) ** 2 - ).sum(-2) + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = (torch.linalg.solve_triangular(self.scale_tril, b, upper=False) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + def __matmul__(self, inp: Tensor) -> Normal: + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): + def squeeze(self, idx: int) -> "DenseNormal": return DenseNormal(self.loc.squeeze(idx), self.scale_tril.squeeze(idx)) class LowRankNormal(torch.distributions.LowRankMultivariateNormal): - def __init__(self, loc: Tensor, cov_factor: Tensor, diag: Tensor): - """Low Rank Normal distribution. Note that this is a multivariate normal used - for the distribution of the weights in the last layer of a neural network. + """Low-rank multivariate normal: cov = UU^T + diag(cov_diag)""" - Args: - loc: Location of the distribution. - cov_factor: Low rank factor of the covariance matrix. - diag: Diagonal of the covariance matrix. - """ + def __init__(self, loc: Tensor, cov_factor: Tensor, diag: Tensor): super().__init__(loc, cov_factor=cov_factor, cov_diag=diag) @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property def chol_covariance(self): - raise NotImplementedError() + raise NotImplementedError("Cholesky not available for LowRankMultivariateNormal") @property def inverse_covariance(self): - raise NotImplementedError() + raise NotImplementedError("Inverse not implemented for LowRankNormal") @property - def logdet_covariance(self): - # Apply Matrix determinant lemma - term1 = torch.log(self.cov_diag).sum(-1) - arg1 = tp(self.cov_factor) @ (self.cov_factor / self.cov_diag.unsqueeze(-1)) - term2 = torch.linalg.det( - arg1 + torch.eye(arg1.shape[-1], dtype=torch.float64) - ).log() + def logdet_covariance(self) -> Tensor: + # Matrix determinant lemma det(D + U U^T) = det(D) * det(I + D^{-1/2} U U^T D^{-1/2}) + cov_diag = self.cov_diag + device = cov_diag.device + dtype = cov_diag.dtype + cov_diag = torch.clamp(cov_diag, min=1e-30) + term1 = torch.log(cov_diag).sum(dim=-1) + # build small matrix + Dinv = (1.0 / cov_diag).unsqueeze(-1) + arg1 = tp(self.cov_factor) @ (self.cov_factor * Dinv) + # ensure arg1 is float/double consistent + I = torch.eye(arg1.shape[-1], device=device, dtype=arg1.dtype) + term2 = torch.logdet(arg1 + I) return term1 + term2 @property - def trace_covariance(self): - # trace of sum is sum of traces - trace_diag = self.cov_diag.sum(-1) - trace_lowrank = (self.cov_factor**2).sum(-1).sum(-1) + def trace_covariance(self) -> Tensor: + trace_diag = self.cov_diag.sum(dim=-1) + trace_lowrank = (self.cov_factor ** 2).sum(dim=(-2, -1)) return trace_diag + trace_lowrank - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - diag_term = (self.cov_diag.unsqueeze(-1) * (b**2)).sum(-2) - factor_term = ((tp(self.cov_factor) @ b) ** 2).sum(-2) + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + diag_term = (self.cov_diag.unsqueeze(-1) * (b ** 2)).sum(dim=-2) + factor_term = ((tp(self.cov_factor) @ b) ** 2).sum(dim=-2) prod = diag_term + factor_term return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - raise NotImplementedError() + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + raise NotImplementedError("Precision-weighted inner product for low-rank not implemented") - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + def __matmul__(self, inp: Tensor) -> Normal: + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): - return LowRankNormal( - self.loc.squeeze(idx), - self.cov_factor.squeeze(idx), - self.cov_diag.squeeze(idx), - ) + def squeeze(self, idx: int) -> "LowRankNormal": + return LowRankNormal(self.loc.squeeze(idx), self.cov_factor.squeeze(idx), self.cov_diag.squeeze(idx)) class DenseNormalPrec(torch.distributions.MultivariateNormal): - """A DenseNormal parameterized by the mean and the cholesky decomp of the precision - matrix. Low Rank Normal distribution. Note that this is a multivariate normal used - for the distribution of the weights in the last layer of a neural network. - - This function also includes a recursive_update function which performs a recursive - linear regression update with effecient cholesky factor updates. + """ + Dense Normal parameterized by mean and Cholesky of precision matrix. + Internally we construct precision = L L^T (where L here is the provided tril), + and pass precision_matrix to base class. """ - def __init__(self, loc: Tensor, cholesky: Tensor, validate_args=False): - """A DenseNormal parameterized by the mean and the cholesky decomp of the - precision. - - Args: - loc: Location of the distribution. - cholesky: Lower triangular Cholesky factor of the precision matrix. - validate_args: Whether to validate the input arguments. - """ + def __init__(self, loc: Tensor, cholesky: Tensor, validate_args: bool = False): prec = cholesky @ tp(cholesky) super().__init__(loc, precision_matrix=prec, validate_args=validate_args) self.tril = cholesky @property - def mean(self): + def mean(self) -> Tensor: return self.loc @property def chol_covariance(self): - raise NotImplementedError() + raise NotImplementedError("chol_covariance undefined for DenseNormalPrec") @property - def covariance(self): - logger.warning( - "Direct matrix inverse for dense covariances is O(N^3)" - "consider using eg inverse weighted inner product" - ) - return torch.cholesky_inverse(self.tril) + def covariance(self) -> Tensor: + logger.warning("Direct inverse is O(N^3); prefer specialized ops") + # Use solve_triangular to invert tril: inv_cov = (L^{-1})^T (L^{-1}) + invL = torch.cholesky_inverse(self.tril) # returns full inverse of tril? use torch.cholesky_inverse for triangular + return invL @property - def inverse_covariance(self): + def inverse_covariance(self) -> Tensor: return self.precision_matrix @property - def logdet_covariance(self): - return -2.0 * torch.diagonal(self.tril, dim1=-2, dim2=-1).log().sum(-1) + def logdet_covariance(self) -> Tensor: + # For precision tril T: logdet(cov) = -2 * sum log diag(T) + return -2.0 * torch.diagonal(self.tril, dim1=-2, dim2=-1).log().sum(dim=-1) @property - def trace_covariance(self): - return ( - (torch.inverse(self.tril) ** 2).sum(-1).sum(-1) - ) # compute as frob norm squared - - def covariance_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = (torch.linalg.solve(self.tril, b) ** 2).sum(-2) + def trace_covariance(self) -> Tensor: + # approximate as sum of squared inverse elements (not optimal but consistent) + return (torch.inverse(self.tril) ** 2).sum(dim=(-2, -1)) + + def covariance_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = (torch.linalg.solve(self.tril, b) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def precision_weighted_inner_prod(self, b, reduce_dim=True): - assert b.shape[-1] == 1 - prod = ((tp(self.tril) @ b) ** 2).sum(-2) + def precision_weighted_inner_prod(self, b: Tensor, reduce_dim: bool = True) -> Tensor: + if b.shape[-1] != 1: + raise ValueError("b must have last dim 1") + prod = ((tp(self.tril) @ b) ** 2).sum(dim=-2) return prod.squeeze(-1) if reduce_dim else prod - def __matmul__(self, inp): - assert inp.shape[-2] == self.loc.shape[-1] - assert inp.shape[-1] == 1 - new_cov = self.covariance_weighted_inner_prod( - inp.unsqueeze(-3), reduce_dim=False - ) - return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12))) + def __matmul__(self, inp: Tensor) -> Normal: + if inp.shape[-2] != self.loc.shape[-1] or inp.shape[-1] != 1: + raise ValueError("Input to matmul must have shape (..., feat, 1) matching loc") + new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False) + new_scale = torch.sqrt(torch.clamp(new_cov, min=1e-12)) + return Normal(self.loc @ inp, new_scale) - def squeeze(self, idx): + def squeeze(self, idx: int) -> "DenseNormalPrec": return DenseNormalPrec(self.loc.squeeze(idx), self.tril.squeeze(idx)) -def get_parameterization(p): +def get_parameterization(p: str): COV_PARAM_DICT = { "dense": DenseNormal, "dense_precision": DenseNormalPrec, "diagonal": Normal, "lowrank": LowRankNormal, } - - try: - return COV_PARAM_DICT[p] - except KeyError: + if p not in COV_PARAM_DICT: raise ValueError(f"Invalid covariance parameterization: {p!r}") + return COV_PARAM_DICT[p] -# following functions/classes are from -# https://github.com/VectorInstitute/vbll/blob/main/vbll/layers/regression.py -def gaussian_kl(p, q_scale): +def gaussian_kl(p: Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec], q_scale: float) -> Tensor: + """ + KL between variational posterior p (with zero-mean prior scaled by q_scale). + q_scale may be float or tensor; we coerce to p.mean dtype/device. + """ feat_dim = p.mean.shape[-1] - mse_term = (p.mean**2).sum(-1).sum(-1) / q_scale - trace_term = (p.trace_covariance / q_scale).sum(-1) - logdet_term = (feat_dim * np.log(q_scale) - p.logdet_covariance).sum(-1) - return 0.5 * (mse_term + trace_term + logdet_term) # currently exclude constant + dtype = p.mean.dtype + device = p.mean.device + q_scale_t = torch.as_tensor(float(q_scale), dtype=dtype, device=device) + + mse_term = (p.mean ** 2).sum(dim=-1).sum(dim=-1) / q_scale_t + trace_term = (p.trace_covariance / q_scale_t).sum(dim=-1) + logdet_term = (feat_dim * torch.log(q_scale_t) - p.logdet_covariance).sum(dim=-1) + return 0.5 * (mse_term + trace_term + logdet_term) @dataclass class VBLLReturn: - predictive: Normal | DenseNormal - train_loss_fn: Callable[[torch.Tensor], torch.Tensor] - val_loss_fn: Callable[[torch.Tensor], torch.Tensor] + predictive: Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec] + train_loss_fn: Callable[[Tensor], Tensor] + val_loss_fn: Callable[[Tensor], Tensor] class Regression(nn.Module): def __init__( self, - in_features, - out_features, - regularization_weight, - parameterization="dense", - mean_initialization=None, - prior_scale=1.0, - wishart_scale=1e-2, - cov_rank=None, - clamp_noise_init=True, - dof=1.0, + in_features: int, + out_features: int, + regularization_weight: float, + parameterization: str = "dense", + mean_initialization: Optional[str] = None, + prior_scale: float = 1.0, + wishart_scale: float = 1e-2, + cov_rank: Optional[int] = None, + clamp_noise_init: bool = True, + dof: float = 1.0, ): - """ - Variational Bayesian Linear Regression - - Parameters - ---------- - in_features : int - Number of input features - out_features : int - Number of output features - regularization_weight : float - Weight on regularization term in ELBO - parameterization : str - Parameterization of covariance matrix. - Currently supports {'dense', 'diagonal', 'lowrank', 'dense_precision'} - mean_initialization : str or None - Initialization method for the mean of the weights. - Supports {'kaiming', None}. If None, weights are initialized from - a standard normal distribution. Defaults to None. - prior_scale : float - Scale of prior covariance matrix - wishart_scale : float - Scale of Wishart prior on noise covariance - cov_rank : int or None - For 'lowrank' parameterization, the rank of the covariance matrix. - clamp_noise_init : bool - Whether to clamp the noise initialization to be positive. - dof : float - Degrees of freedom of Wishart prior on noise covariance - """ super().__init__() - self.wishart_scale = wishart_scale - self.dof = (dof + out_features + 1.0) / 2.0 - self.regularization_weight = regularization_weight - self.dtype = torch.float64 # NOTE: not in the original source code - - # define prior, currently fixing zero mean and arbitrarily scaled cov - self.prior_scale = prior_scale * (1.0 / in_features) + self.wishart_scale = float(wishart_scale) + self.dof = float((dof + out_features + 1.0) / 2.0) + self.regularization_weight = float(regularization_weight) + self.dtype = torch.get_default_dtype() - # noise distribution - self.noise_mean = nn.Parameter( - torch.zeros(out_features, dtype=self.dtype), requires_grad=False - ) - self.noise_logdiag = nn.Parameter( - torch.randn(out_features, dtype=self.dtype) * (np.log(wishart_scale)) - ) + # prior scale adjusted by input dimension + self.prior_scale = float(prior_scale) * (1.0 / float(in_features)) - # ensure that log noise is positive + # noise distribution params (diagonal) + self.noise_mean = nn.Parameter(torch.zeros(out_features, dtype=self.dtype), requires_grad=False) + # initialize log-diagonal of noise; use torch.randn scaled by wishart_scale + self.noise_logdiag = nn.Parameter(torch.randn(out_features, dtype=self.dtype) * (torch.log(torch.tensor(wishart_scale, dtype=self.dtype)))) if clamp_noise_init: - self.noise_logdiag.data = torch.clamp(self.noise_logdiag.data, min=0) + with torch.no_grad(): + self.noise_logdiag.data = torch.clamp(self.noise_logdiag.data, min=0.0) - # last layer distribution + # last-layer distribution type self.W_dist = get_parameterization(parameterization) + # initialize mean if mean_initialization is None: - self.W_mean = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - ) + self.W_mean = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype)) elif mean_initialization == "kaiming": - self.W_mean = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - * np.sqrt(2.0 / in_features) - ) - elif isinstance(mean_initialization, str): - raise ValueError( - f"Unknown initialization method: {mean_initialization!r}. " - f"Supported methods: 'kaiming'" - ) + self.W_mean = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) * torch.sqrt(torch.tensor(2.0 / in_features, dtype=self.dtype))) else: - raise TypeError( - f"mean_initialization must be a string or None, " - f"got {type(mean_initialization).__name__}" - ) + raise ValueError(f"Unknown initialization method: {mean_initialization!r}") + # covariance parameterization-specific params if parameterization == "diagonal": - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - - 0.5 * np.log(in_features) - ) + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) - 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + self.W_offdiag = None elif parameterization == "dense": - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - - 0.5 * np.log(in_features) - ) - self.W_offdiag = nn.Parameter( - torch.randn(out_features, in_features, in_features, dtype=self.dtype) - / in_features - ) + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) - 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + # create a full lower-triangular container per output row: stored as full matrix and later tril is taken + self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, in_features, dtype=self.dtype) / float(in_features)) elif parameterization == "dense_precision": - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - + 0.5 * np.log(in_features) - ) - self.W_offdiag = nn.Parameter( - torch.randn(out_features, in_features, in_features, dtype=self.dtype) - * 0.0 - ) + # here offdiag will encode cholesky of precision; initialize small + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) + 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + self.W_offdiag = nn.Parameter(torch.zeros(out_features, in_features, in_features, dtype=self.dtype)) elif parameterization == "lowrank": if cov_rank is None: raise ValueError("Must specify cov_rank for lowrank parameterization") + self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features, dtype=self.dtype) - 0.5 * torch.log(torch.tensor(in_features, dtype=self.dtype))) + self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, cov_rank, dtype=self.dtype) / float(in_features)) + else: + raise ValueError(f"Unknown parameterization {parameterization}") - self.W_logdiag = nn.Parameter( - torch.randn(out_features, in_features, dtype=self.dtype) - - 0.5 * np.log(in_features) - ) - self.W_offdiag = nn.Parameter( - torch.randn(out_features, in_features, cov_rank, dtype=self.dtype) - / in_features - ) + self.parameterization = parameterization - def W(self): + def W(self) -> Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec]: cov_diag = torch.exp(self.W_logdiag) - if self.W_dist == Normal: - cov = self.W_dist(self.W_mean, cov_diag) - elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrec): - tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag) - cov = self.W_dist(self.W_mean, tril) - elif self.W_dist == LowRankNormal: - cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag) - - return cov + if self.W_dist is Normal: + return Normal(self.W_mean, cov_diag) + elif self.W_dist is DenseNormal or self.W_dist is DenseNormalPrec: + # build lower-triangular cholesky tril: ensure diagonal uses cov_diag + # if W_offdiag has shape (..., D, D) we use tril(W_offdiag) + diag(cov_diag) + tril_base = torch.tril(self.W_offdiag, diagonal=-1) if self.W_offdiag is not None else torch.zeros_like(torch.diag_embed(cov_diag)) + tril = tril_base + torch.diag_embed(cov_diag) + if self.W_dist is DenseNormal: + return DenseNormal(self.W_mean, tril) + else: + # DenseNormalPrec expects tril of precision; we accept passed tril as precision-cholesky + return DenseNormalPrec(self.W_mean, tril) + elif self.W_dist is LowRankNormal: + return LowRankNormal(self.W_mean, self.W_offdiag, cov_diag) + else: + raise RuntimeError("Unsupported W distribution type") - def noise(self): + def noise(self) -> Normal: return Normal(self.noise_mean, torch.exp(self.noise_logdiag)) - def forward(self, x): - out = VBLLReturn( - self.predictive(x), self._get_train_loss_fn(x), self._get_val_loss_fn(x) - ) - return out + def forward(self, x: Tensor) -> VBLLReturn: + return VBLLReturn(self.predictive(x), self._get_train_loss_fn(x), self._get_val_loss_fn(x)) - def predictive(self, x): + def predictive(self, x: Tensor) -> Union[Normal, DenseNormal, LowRankNormal, DenseNormalPrec]: + # x is expected with shape (..., feat) return (self.W() @ x[..., None]).squeeze(-1) + self.noise() - def _get_train_loss_fn(self, x): - def loss_fn(y): - # construct predictive density N(W @ phi, Sigma) + def sample_predictive(self, x: Tensor, num_samples: int = 1) -> Tensor: + """ + Draw samples from the predictive posterior: + returns tensor with shape (num_samples, batch..., out_features) + """ + pred = self.predictive(x) + # Distribution supports sample with given sample_shape + samples = pred.rsample(sample_shape=(num_samples,)) + return samples + + def _get_train_loss_fn(self, x: Tensor) -> Callable[[Tensor], Tensor]: + def loss_fn(y: Tensor) -> Tensor: W = self.W() noise = self.noise() - pred_density = Normal((W.mean @ x[..., None]).squeeze(-1), noise.scale) + pred_mean = (W.mean @ x[..., None]).squeeze(-1) + pred_density = Normal(pred_mean, noise.scale) pred_likelihood = pred_density.log_prob(y) - trace_term = 0.5 * ( - (W.covariance_weighted_inner_prod(x.unsqueeze(-2)[..., None])) - * noise.trace_precision - ) + # covariance-weighted inner product (averages over features) + # x.unsqueeze(-2)[..., None] ensures shape (..., feat, 1) + b = x.unsqueeze(-2)[..., None] + trace_term = 0.5 * (W.covariance_weighted_inner_prod(b) * noise.trace_precision) kl_term = gaussian_kl(W, self.prior_scale) - wishart_term = ( - self.dof * noise.logdet_precision - - 0.5 * self.wishart_scale * noise.trace_precision - ) + wishart_term = (self.dof * noise.logdet_precision - 0.5 * self.wishart_scale * noise.trace_precision) + total_elbo = torch.mean(pred_likelihood - trace_term) regularization_term = self.regularization_weight * (wishart_term - kl_term) total_elbo = total_elbo + regularization_term @@ -526,10 +472,8 @@ def loss_fn(y): return loss_fn - def _get_val_loss_fn(self, x): - def loss_fn(y): - # compute log likelihood under variational posterior via marginalization - logprob = self.predictive(x).log_prob(y).sum(-1) # sum over output dims - return -logprob.mean(0) # mean over batch dim - + def _get_val_loss_fn(self, x: Tensor) -> Callable[[Tensor], Tensor]: + def loss_fn(y: Tensor) -> Tensor: + logprob = self.predictive(x).log_prob(y).sum(dim=-1) # sum over output dims + return -logprob.mean(dim=0) return loss_fn