Refactor botorch/sampling/pathwise and add support for product kernels #2838
@@ -122,28 +122,42 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
    shape = self.raw_output_shape
    ndim = len(shape)
    for feature_map in self:
        # Collect/scale individual feature blocks
        block = feature_map(x, **kwargs).to_dense()
        block_ndim = len(feature_map.output_shape)

        # Handle broadcasting for lower-dimensional feature maps
        if block_ndim < ndim:
            # Determine how the tiling/broadcasting works for lower-dimensional feature maps
            tile_shape = shape[-ndim:-block_ndim]
            num_copies = prod(tile_shape)

            # Scale down by sqrt of number of copies to maintain proper variance
            if num_copies > 1:
                block = block * (num_copies**-0.5)

            # Create multi-index for broadcasting: add None dimensions for tiling
            # This expands the block to match the target dimensionality
            multi_index = (
                ...,
                *repeat(None, ndim - block_ndim),  # Add new axes for tiling
                *repeat(slice(None), block_ndim),  # Keep existing dimensions
            )
            # Apply the multi-index and expand to tile across the new dimensions
            block = block[multi_index].expand(
                *block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
            )
        blocks.append(block)

    # Concatenate all blocks along the last dimension
    return torch.concat(blocks, dim=-1)
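For intuition, here is a standalone sketch of the tiling trick used above (not the PR's code; the shapes are made up for illustration): a lower-dimensional feature block gets extra leading axes via `None` indexing, is expanded to the tile shape, and is scaled by `num_copies ** -0.5` so the tiled copies keep the overall variance in check.

```python
import torch
from itertools import repeat
from math import prod

# Toy stand-in for a block produced by a 1-D feature map: shape (n, d)
block = torch.randn(5, 3)
block_ndim = 1          # the feature map's own output has 1 trailing dim
ndim = 2                # the direct sum's raw_output_shape has 2 trailing dims
tile_shape = (4,)       # extra output dimension contributed by the other maps
num_copies = prod(tile_shape)

# Scale so that tiling the block num_copies times does not inflate the variance
if num_copies > 1:
    block = block * (num_copies ** -0.5)

# Insert new axes before the block's own trailing dims, then expand (no copy) to tile
multi_index = (
    ...,
    *repeat(None, ndim - block_ndim),
    *repeat(slice(None), block_ndim),
)
tiled = block[multi_index].expand(
    *block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
)
print(tiled.shape)  # torch.Size([5, 4, 3])
```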
@property
def raw_output_shape(self) -> Size:
    # Handle the empty DirectSumFeatureMap case, which can occur when we:
    # 1. purposely start with an empty container and plan to append feature maps later, or
    # 2. delete the last entry so the list is now length-zero.
    # Returning Size([]) keeps the object in a queryable state until real feature maps are added.
    if not self:
Contributor: When does this occur?

Author: We hit this in unit tests that build

Contributor: Terrific, thanks!
        return Size([])
@@ -203,17 +217,25 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
    for feature_map in self:
        block = feature_map(x, **kwargs)
        block_ndim = len(feature_map.output_shape)

        # Handle blocks that match the target dimensionality
        if block_ndim == ndim:
            # Convert LinearOperator to dense tensor if needed
            block = block.to_dense() if isinstance(block, LinearOperator) else block
            # Ensure block is in sparse format for efficient block diagonal construction
            block = block if block.is_sparse else block.to_sparse()
        else:
            # For lower-dimensional blocks, we need to expand dimensions
            # but keep them dense since sparse tensor broadcasting is limited
            multi_index = (
Contributor: A comment or two would be really helpful here too!

Author: Added comments to the sparse branch!
                ...,
                *repeat(None, ndim - block_ndim),  # Add new axes for expansion
                *repeat(slice(None), block_ndim),  # Keep existing dimensions
            )
            block = block.to_dense()[multi_index]
        blocks.append(block)

    # Construct sparse block diagonal matrix from all blocks
    return sparse_block_diag(blocks, base_ndim=ndim)
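`sparse_block_diag` is the PR's own helper (its signature is taken from the diff above). As a dense analogue of the layout it produces, the sketch below uses `torch.block_diag` on two made-up 2-D feature blocks; the real helper additionally handles batch dimensions via `base_ndim`.

```python
import torch

# Two toy feature blocks, as if produced by two feature maps evaluated at n = 4 points
block_a = torch.ones(4, 3)   # 3 features
block_b = torch.ones(4, 2)   # 2 features

# Dense analogue of the block-diagonal layout: each block occupies its own
# row/column slice, with zeros everywhere else.
dense_bd = torch.block_diag(block_a, block_b)
print(dense_bd.shape)  # torch.Size([8, 5])

# Stored sparsely, only the entries of the two blocks are kept, which is what
# makes the block-diagonal representation cheap when there are many feature maps.
sparse_bd = dense_bd.to_sparse()
print(sparse_bd.values().numel())  # 4*3 + 4*2 = 20 stored values out of 8*5 = 40
```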
@@ -150,11 +150,20 @@ def _draw_kernel_feature_paths_MultiTaskGP(
    )

    # Extract kernels from the product kernel structure
    # model.covar_module is a ProductKernel by definition for MTGPs,
    # containing data_covar_module * task_covar_module
    from gpytorch.kernels import ProductKernel

    if not isinstance(model.covar_module, ProductKernel):
        # Fallback for non-ProductKernel cases (legacy support)
        import warnings

        warnings.warn(
            f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
            "Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
            UserWarning,
        )
        combined_kernel = model.covar_module
    else:
        # Get the individual kernels from the product kernel
        kernels = model.covar_module.kernels

@@ -169,7 +178,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                else:
                    data_kernel = deepcopy(kernel)
            else:
                # If no active_dims on the data kernel, add them so downstream helpers don't error
                data_kernel = deepcopy(kernel)
                data_kernel.active_dims = torch.LongTensor(
                    [

@@ -180,7 +189,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                    device=data_kernel.device,
                )

        # If the task kernel can't be found, create it based on the structure
        if task_kernel is None:
            from gpytorch.kernels import IndexKernel

@@ -190,14 +199,15 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                active_dims=[task_index],
            ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)

        # Set task kernel active dims correctly
        task_kernel.active_dims = torch.tensor([task_index], device=task_kernel.device)
        # Ensure the data kernel was found
        if data_kernel is None:
            raise ValueError(
                f"Could not identify data kernel from ProductKernel. "
                "MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
            )

        # Use the existing product kernel structure
        combined_kernel = data_kernel * task_kernel

    return _draw_kernel_feature_paths_fallback(
        mean_module=model.mean_module,
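For context, here is a sketch of the kernel structure this extraction logic expects (the kernel choices and names below are illustrative assumptions, not taken from the PR): a MultiTaskGP's covar_module is a ProductKernel whose factors are a data kernel over the input columns and an IndexKernel over the task column, distinguishable by their active_dims.

```python
import torch
from gpytorch.kernels import IndexKernel, MaternKernel, ProductKernel

num_inputs, task_index, num_tasks = 3, 2, 2  # last column holds the task id

# Data kernel acts on the non-task columns; task kernel acts on the task column.
data_kernel = MaternKernel(
    nu=2.5, active_dims=[i for i in range(num_inputs) if i != task_index]
)
task_kernel = IndexKernel(num_tasks=num_tasks, rank=1, active_dims=[task_index])
covar_module = ProductKernel(data_kernel, task_kernel)

# The extraction logic above walks covar_module.kernels and tells the factors
# apart by whether task_index appears in their active_dims.
for kernel in covar_module.kernels:
    role = "task" if task_index in kernel.active_dims else "data"
    print(role, type(kernel).__name__, kernel.active_dims)
```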
@@ -174,11 +174,21 @@ def _draw_kernel_feature_paths_MultiTaskGP(
    )

    # Extract kernels from the product kernel structure
    # model.covar_module is a ProductKernel by definition for MTGPs,
    # containing data_covar_module * task_covar_module
    from gpytorch.kernels import ProductKernel

    if not isinstance(model.covar_module, ProductKernel):
        # Fallback for non-ProductKernel cases (legacy support)
        # This should be rare as MTGPs typically use ProductKernels by definition
        import warnings

        warnings.warn(
            f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
            "Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
            UserWarning,
        )
        combined_kernel = model.covar_module
    else:
        # Get the individual kernels from the product kernel
        kernels = model.covar_module.kernels

@@ -193,7 +203,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                else:
                    data_kernel = deepcopy(kernel)
            else:
                # If no active_dims on the data kernel, add them so downstream helpers don't error
                data_kernel = deepcopy(kernel)
                data_kernel.active_dims = torch.LongTensor(
                    [index for index in range(num_inputs) if index != task_index],

@@ -210,16 +220,15 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                active_dims=[task_index],
            ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)

        # Set task kernel active dims correctly
        task_kernel.active_dims = torch.LongTensor(
            [task_index], device=task_kernel.device
        )
        # Ensure data kernel was found
        if data_kernel is None:
            raise ValueError(
                f"Could not identify data kernel from ProductKernel. "
                "MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
            )

        # Use the existing product kernel structure
        combined_kernel = data_kernel * task_kernel

    # Return exact update using product kernel
    return _gaussian_update_exact(
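Tying the two pieces together, a hypothetical end-to-end usage sketch (the synthetic data, shapes, and the assumption that pathwise sampling works for MultiTaskGP all depend on this PR; none of it is taken from the diff): draw pathwise feature-path samples from a MultiTaskGP whose covariance is the standard ProductKernel of a data kernel and an IndexKernel.

```python
import torch
from botorch.models import MultiTaskGP
from botorch.sampling.pathwise import draw_kernel_feature_paths

# Synthetic training data: two input features plus a trailing task-id column
n, d, num_tasks = 20, 2, 2
X = torch.rand(n, d, dtype=torch.double)
task_ids = (torch.arange(n) % num_tasks).to(X).unsqueeze(-1)
train_X = torch.cat([X, task_ids], dim=-1)
train_Y = torch.sin(6.0 * train_X[:, :1]) + 0.1 * torch.randn(n, 1, dtype=torch.double)

model = MultiTaskGP(train_X, train_Y, task_feature=-1)

# Draw prior sample paths via kernel feature maps
paths = draw_kernel_feature_paths(model, sample_shape=torch.Size([16]))

# Evaluate the sampled paths at new points (task column included)
test_X = torch.cat(
    [torch.rand(5, d, dtype=torch.double), torch.zeros(5, 1, dtype=torch.double)],
    dim=-1,
)
samples = paths(test_X)
print(samples.shape)  # expected to cover the 16 sample paths at the 5 test points
```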
Contributor: So if I'm reading this correctly, raw_output_shape returns the largest broadcastable shape across all sub-kernels? How does it differ from torch.broadcast_shapes?

Author: raw_output_shape does two things in one go: it broadcasts the leading dimensions that every block must share, which forward relies on (e.g. when a 1-D feature map has to be expanded to align with a 2-D one), and it grows the trailing feature dimension to account for the blocks being glued together. torch.broadcast_shapes by itself only answers "what common shape can all these tensors be viewed as without copying?". It never alters the size of any dimension, especially not the last one. In our case we need to grow the last dimension because we're gluing feature vectors together; that's why raw_output_shape can't be expressed as a single torch.broadcast_shapes call.
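To make the distinction concrete, a small sketch with toy shapes (direct_sum_output_shape is a hypothetical helper written for illustration, not the PR's implementation): the leading dimensions are broadcast exactly as torch.broadcast_shapes would do, while the last dimension is the sum of the blocks' feature sizes.

```python
import torch

def direct_sum_output_shape(*shapes: torch.Size) -> torch.Size:
    # Hypothetical illustration: broadcast everything except the last dim,
    # then sum the last (feature) dims, since blocks are concatenated there.
    batch = torch.broadcast_shapes(*[s[:-1] for s in shapes])
    return torch.Size([*batch, sum(s[-1] for s in shapes)])

a = torch.Size([4, 3])   # a 2-D feature map: leading dim 4, 3 features
b = torch.Size([3])      # a 1-D feature map: 3 features

print(torch.broadcast_shapes(a, b))    # torch.Size([4, 3]) -- last dim is not grown
print(direct_sum_output_shape(a, b))   # torch.Size([4, 6]) -- feature dims are summed
```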