30 changes: 26 additions & 4 deletions botorch/sampling/pathwise/features/maps.py
@@ -122,28 +122,42 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
shape = self.raw_output_shape
ndim = len(shape)
for feature_map in self:
# Collect/scale individual feature blocks
block = feature_map(x, **kwargs).to_dense()
block_ndim = len(feature_map.output_shape)

# Handle broadcasting for lower-dimensional feature maps
if block_ndim < ndim:
# Determine how the tiling/broadcasting works for lower-dimensional feature maps
tile_shape = shape[-ndim:-block_ndim]
num_copies = prod(tile_shape)

# Scale down by sqrt of number of copies to maintain proper variance
if num_copies > 1:
block = block * (num_copies**-0.5)

# Create multi-index for broadcasting: add None dimensions for tiling
# This expands the block to match the target dimensionality
multi_index = (
...,
*repeat(None, ndim - block_ndim),  # Add new axes for tiling
*repeat(slice(None), block_ndim),  # Keep existing dimensions
)
# Apply the multi-index and expand to tile across the new dimensions
block = block[multi_index].expand(
*block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
)
blocks.append(block)

# Concatenate all blocks along the last dimension
return torch.concat(blocks, dim=-1)

@property
def raw_output_shape(self) -> Size:
Contributor:
So if I'm reading this correctly, raw_output_shape returns the largest broadcastable shape across all sub-kernels? How does it differ from torch.broadcast_shapes?

Author:
raw_output_shape does two things in one go:

  1. For every sub-map it first figures out the shape you would get after tiling / broadcasting inside forward (e.g. when a 1-D feature map has to be expanded to align with a 2-D one).
  2. It then concatenates those shapes along the feature dimension, so the final size of the last axis is the sum of the sub-maps’ feature counts.

torch.broadcast_shapes by itself only answers “what common shape can all these tensors be viewed as without copying?”. It never alters the size of any dimension, especially not the last one.

In our case we need to grow the last dimension because we’re gluing feature vectors together; that’s why raw_output_shape can’t be expressed as a single torch.broadcast_shapes call.
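For concreteness, here is a small standalone sketch of that difference (the shapes (2, 3) and (3,) are made-up illustrative values, not taken from the PR; the scaling mirrors the num_copies**-0.5 factor in forward):

```python
import torch

# Two sub-map output shapes: a 2-D block and a 1-D block.
shape_a, shape_b = torch.Size([2, 3]), torch.Size([3])

# broadcast_shapes only finds a common view shape; the last dim stays 3.
print(torch.broadcast_shapes(shape_a, shape_b))  # torch.Size([2, 3])

# The direct sum instead tiles the 1-D block up to (2, 3), scales it by
# num_copies**-0.5 to preserve variance, and concatenates along the feature
# axis, so the last dim grows to 3 + 3 = 6.
block_a, block_b = torch.randn(*shape_a), torch.randn(*shape_b)
tiled_b = block_b[None, :].expand(2, 3) * (2 ** -0.5)
print(torch.concat([block_a, tiled_b], dim=-1).shape)  # torch.Size([2, 6])
```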

# Handle the empty DirectSumFeatureMap case, which can occur when:
# 1. The container was purposely created empty, with feature maps to be appended later, or
# 2. The last entry was deleted and the list is now length-zero.
# Returning Size([]) keeps the object in a queryable state until real feature maps are added.
if not self:
Contributor:
When does this occur?

Author:
not self only happens when the DirectSumFeatureMap is empty. That can be because you:

  • Purposely start with an empty container and plan to append feature maps later, or
  • Deleted the last entry and the list is now length-zero.

We hit this in unit tests that build DirectSumFeatureMap([]) to make sure the class doesn’t crash in that edge case. Returning Size([]) just keeps the object in a sane, queryable state (so output_shape, batch_shape, etc. still work) until real feature maps are added. I’ve added a clarifying comment!

Contributor:
Terrific, thanks!
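As a purely illustrative sketch of that edge case, a toy stand-in (not the real DirectSumFeatureMap, which carries more logic) would behave like this:

```python
from torch import Size, nn

class ToyDirectSum(nn.ModuleList):
    """Toy stand-in used only to illustrate the empty-container branch."""

    @property
    def raw_output_shape(self) -> Size:
        if not self:  # True only when the container holds no feature maps
            return Size([])
        # Real logic: tile/broadcast sub-map shapes, then sum their last dims.
        raise NotImplementedError

print(ToyDirectSum([]).raw_output_shape)  # torch.Size([]) -- no crash
```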

return Size([])

@@ -203,17 +217,25 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
for feature_map in self:
block = feature_map(x, **kwargs)
block_ndim = len(feature_map.output_shape)

# Handle blocks that match the target dimensionality
if block_ndim == ndim:
# Convert LinearOperator to dense tensor if needed
block = block.to_dense() if isinstance(block, LinearOperator) else block
# Ensure block is in sparse format for efficient block diagonal construction
block = block if block.is_sparse else block.to_sparse()
else:
# For lower-dimensional blocks, we need to expand dimensions
# but keep them dense since sparse tensor broadcasting is limited
multi_index = (
Contributor:
A comment or two would be really helpful here too!

Author:
Added comments to the sparse branch!

...,
*repeat(None, ndim - block_ndim),  # Add new axes for expansion
*repeat(slice(None), block_ndim),  # Keep existing dimensions
)
block = block.to_dense()[multi_index]
blocks.append(block)

# Construct sparse block diagonal matrix from all blocks
return sparse_block_diag(blocks, base_ndim=ndim)
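For intuition, the block-diagonal layout being assembled here can be sketched with dense tensors and torch.block_diag; the block shapes below are made up, and sparse_block_diag is the pathwise helper the real code relies on:

```python
import torch

# Each feature map contributes one block; the result places the blocks on the
# diagonal and leaves everything else as (implicit) zeros.
block_1 = torch.randn(2, 3)  # features from a hypothetical first sub-map
block_2 = torch.randn(4, 5)  # features from a hypothetical second sub-map

dense = torch.block_diag(block_1, block_2)
print(dense.shape)  # torch.Size([6, 8])

# A sparse layout only stores the block entries, which is why the branch above
# converts dense blocks with to_sparse() before building the block diagonal.
print(dense.to_sparse().values().numel())  # 2*3 + 4*5 = 26 stored values
```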


28 changes: 19 additions & 9 deletions botorch/sampling/pathwise/prior_samplers.py
@@ -150,11 +150,20 @@ def _draw_kernel_feature_paths_MultiTaskGP(
)

# Extract kernels from the product kernel structure
# model.covar_module is a ProductKernel by definition for MTGPs
# containing data_covar_module * task_covar_module
from gpytorch.kernels import ProductKernel

if not isinstance(model.covar_module, ProductKernel):
# Fallback for non-ProductKernel cases (legacy support)
import warnings
warnings.warn(
f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
"Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
Contributor:
It doesn't look like the ordering matters for the implementation, but we reverse the order of the two: ProductKernel(SomeKernel, IndexKernel)

UserWarning,
)
combined_kernel = model.covar_module
else:
# Get the individual kernels from the product kernel
kernels = model.covar_module.kernels

@@ -169,7 +178,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
else:
data_kernel = deepcopy(kernel)
else:
# If no active_dims on data kernel, add them so downstream helpers don't error
data_kernel = deepcopy(kernel)
data_kernel.active_dims = torch.LongTensor(
[
@@ -180,7 +189,7 @@
device=data_kernel.device,
)

# If the task kernel can't be found, create it based on the structure
if task_kernel is None:
from gpytorch.kernels import IndexKernel

@@ -190,14 +199,15 @@
active_dims=[task_index],
).to(device=model.covar_module.device, dtype=model.covar_module.dtype)

# Set task kernel active dims correctly
task_kernel.active_dims = torch.tensor([task_index], device=task_kernel.device)
# Ensure the data kernel was found
if data_kernel is None:
raise ValueError(
f"Could not identify data kernel from ProductKernel. "
"MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
)

# Use the existing product kernel structure
combined_kernel = data_kernel * task_kernel

return _draw_kernel_feature_paths_fallback(
mean_module=model.mean_module,
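To make the expected kernel structure concrete, here is a hedged sketch with made-up dimensions (two data columns plus a trailing task column), using the ordering the reviewer asks for, ProductKernel(SomeKernel, IndexKernel):

```python
from gpytorch.kernels import IndexKernel, ProductKernel, RBFKernel

# Hypothetical MultiTaskGP covariance: data columns 0-1, task column 2.
data_kernel = RBFKernel(active_dims=[0, 1])
task_kernel = IndexKernel(num_tasks=3, rank=1, active_dims=[2])

# Data kernel first, IndexKernel second, as the reviewer suggests.
covar_module = ProductKernel(data_kernel, task_kernel)

# The extraction loop above walks covar_module.kernels and identifies each
# factor by its type and active_dims.
for kernel in covar_module.kernels:
    print(type(kernel).__name__, kernel.active_dims)
```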
29 changes: 19 additions & 10 deletions botorch/sampling/pathwise/update_strategies.py
@@ -174,11 +174,21 @@ def _draw_kernel_feature_paths_MultiTaskGP(
)

# Extract kernels from the product kernel structure
# model.covar_module is a ProductKernel by definition for MTGPs
# containing data_covar_module * task_covar_module
from gpytorch.kernels import ProductKernel

if not isinstance(model.covar_module, ProductKernel):
# Fallback for non-ProductKernel cases (legacy support)
# This should be rare as MTGPs typically use ProductKernels by definition
import warnings
warnings.warn(
f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
"Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
Contributor:
Same here. Can you change the order of the kernels in the error MSG to avoid confusion?

UserWarning,
)
combined_kernel = model.covar_module
else:
# Get the individual kernels from the product kernel
kernels = model.covar_module.kernels

@@ -193,7 +203,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
else:
data_kernel = deepcopy(kernel)
else:
# If no active_dims on data kernel, add them so downstream helpers don't error
data_kernel = deepcopy(kernel)
data_kernel.active_dims = torch.LongTensor(
[index for index in range(num_inputs) if index != task_index],
@@ -210,16 +220,15 @@ def _draw_kernel_feature_paths_MultiTaskGP(
active_dims=[task_index],
).to(device=model.covar_module.device, dtype=model.covar_module.dtype)

# Set task kernel active dims correctly
task_kernel.active_dims = torch.LongTensor(
[task_index], device=task_kernel.device
)
# Ensure data kernel was found
if data_kernel is None:
raise ValueError(
f"Could not identify data kernel from ProductKernel. "
"MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
Contributor:
And here

)

# Use the existing product kernel structure
combined_kernel = data_kernel * task_kernel

# Return exact update using product kernel
return _gaussian_update_exact(
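Finally, a minimal sketch of the active_dims split both files compute when they rebuild the data kernel (num_inputs and task_index are assumed example values):

```python
import torch

num_inputs, task_index = 4, 3  # e.g. 3 data columns plus a trailing task column

# The data kernel sees every column except the task column ...
data_dims = torch.tensor(
    [index for index in range(num_inputs) if index != task_index],
    dtype=torch.long,
)
# ... while the task (index) kernel sees only the task column.
task_dims = torch.tensor([task_index], dtype=torch.long)

print(data_dims.tolist(), task_dims.tolist())  # [0, 1, 2] [3]
```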