1313from tidy3d .components .base import Tidy3dBaseModel
1414from tidy3d .components .types import TYPE_TAG_STR
1515from tidy3d .plugins .autograd .functions import convolve
16+ from tidy3d .plugins .autograd .primitives import gaussian_filter as autograd_gaussian_filter
1617from tidy3d .plugins .autograd .types import KernelType , PaddingType
1718from tidy3d .plugins .autograd .utilities import get_kernel_size_px , make_kernel
1819
20+ _GAUSSIAN_SIGMA_SCALE = 0.445 # empirically matches conic kernel response in 1D/2D tests
21+ _GAUSSIAN_PADDING_MAP = {
22+ "constant" : "constant" ,
23+ "edge" : "nearest" ,
24+ "reflect" : "reflect" ,
25+ "symmetric" : "mirror" ,
26+ "wrap" : "wrap" ,
27+ }
28+
1929
2030class AbstractFilter (Tidy3dBaseModel , abc .ABC ):
2131 """An abstract class for creating and applying convolution filters."""
@@ -92,9 +102,13 @@ def __call__(self, array: NDArray) -> NDArray:
92102 size_px = tuple (np .atleast_1d (self .kernel_size ))
93103 if len (size_px ) != squeezed_array .ndim :
94104 size_px *= squeezed_array .ndim
105+ filtered_array = self ._apply_filter (squeezed_array , size_px )
106+ return np .reshape (filtered_array , original_shape )
107+
108+ def _apply_filter (self , array : NDArray , size_px : tuple [int , ...]) -> NDArray :
109+ """Apply the concrete filter implementation to the squeezed array."""
95110 kernel = self .get_kernel (size_px , self .normalize )
96- convolved_array = convolve (squeezed_array , kernel , padding = self .padding )
97- return np .reshape (convolved_array , original_shape )
111+ return convolve (array , kernel , padding = self .padding )
98112
99113
100114class ConicFilter (AbstractFilter ):
@@ -127,6 +141,60 @@ def get_kernel(size_px: Iterable[int], normalize: bool) -> NDArray:
127141 return make_kernel (kernel_type = "circular" , size = size_px , normalize = normalize )
128142
129143
144+ class GaussianFilter (AbstractFilter ):
145+ """A Gaussian filter implemented via separable gaussian_filter primitive.
146+
147+ Notes
148+ -----
149+ Padding modes ``'constant'``, ``'edge'``, ``'reflect'``, ``'symmetric'``, and ``'wrap'`` are
150+ supported. Modes ``'edge'`` and ``'symmetric'`` are internally mapped to the SciPy equivalents
151+ ``'nearest'`` and ``'mirror'`` respectively. The default ``sigma_scale`` of 0.445 was tuned to
152+ match the conic kernel when expressed in pixel radius. The ``normalize`` flag inherited from
153+ :class:`AbstractFilter` is ignored because the separable Gaussian implementation always returns
154+ a unit-sum kernel; setting it to ``False`` has no effect.
155+ """
156+
157+ sigma_scale : float = pd .Field (
158+ _GAUSSIAN_SIGMA_SCALE ,
159+ title = "Sigma Scale" ,
160+ description = "Scale factor mapping radius in pixels to Gaussian sigma." ,
161+ ge = 0.0 ,
162+ )
163+ truncate : float = pd .Field (
164+ 2.0 ,
165+ title = "Truncate" ,
166+ description = "Truncation radius in multiples of sigma passed to ``gaussian_filter``." ,
167+ ge = 0.0 ,
168+ )
169+
170+ @staticmethod
171+ def get_kernel (size_px : Iterable [int ], normalize : bool ) -> NDArray :
172+ raise NotImplementedError ("GaussianFilter does not build an explicit kernel." )
173+
174+ def _apply_filter (self , array : NDArray , size_px : tuple [int , ...]) -> NDArray :
175+ radius_px = np .maximum ((np .array (size_px , dtype = float ) - 1.0 ) / 2.0 , 0.0 )
176+ if radius_px .size == 0 :
177+ return array
178+
179+ mode = _GAUSSIAN_PADDING_MAP .get (self .padding )
180+ if mode is None :
181+ raise ValueError (
182+ f"Unsupported padding mode '{ self .padding } ' for gaussian filter; "
183+ f"supported modes are { tuple (_GAUSSIAN_PADDING_MAP )} ."
184+ )
185+
186+ sigma = tuple (float (self .sigma_scale * r ) if r > 0 else 0.0 for r in radius_px )
187+ if not any (sigma ):
188+ return array
189+
190+ kwargs : dict [str , Any ] = {"mode" : mode , "truncate" : float (self .truncate )}
191+ if mode == "constant" :
192+ kwargs ["cval" ] = 0.0
193+
194+ filtered = autograd_gaussian_filter (array , sigma = sigma , ** kwargs )
195+ return filtered
196+
197+
130198def _get_kernel_size (
131199 radius : Union [float , tuple [float , ...]],
132200 dl : Union [float , tuple [float , ...]],
@@ -189,7 +257,7 @@ def make_filter(
189257 padding : PaddingType = "reflect"
190258 The padding mode to use.
191259 filter_type : KernelType
192- The type of kernel to create (``circular`` or ``conic ``).
260+ The type of kernel to create (``circular``, ``conic``, or ``gaussian ``).
193261
194262 Returns
195263 -------
@@ -202,10 +270,12 @@ def make_filter(
202270 filter_class = ConicFilter
203271 elif filter_type == "circular" :
204272 filter_class = CircularFilter
273+ elif filter_type == "gaussian" :
274+ filter_class = GaussianFilter
205275 else :
206276 raise ValueError (
207277 f"Unsupported filter_type: { filter_type } . "
208- "Must be one of `CircularFilter` or `ConicFilter `."
278+ "Must be one of `CircularFilter`, `ConicFilter`, or `GaussianFilter `."
209279 )
210280
211281 filter_instance = filter_class (kernel_size = kernel_size , normalize = normalize , padding = padding )
@@ -221,11 +291,21 @@ def make_filter(
221291"""
222292
223293make_circular_filter = partial (make_filter , filter_type = "circular" )
224- make_circular_filter .__doc__ = """make_filter() with a default filter_type value of `circular`.
294+ make_circular_filter .__doc__ = """make_filter() with a default filter_type value of ``circular``.
295+
296+ See Also
297+ --------
298+ :func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size.
299+ """
300+
301+ make_gaussian_filter = partial (make_filter , filter_type = "gaussian" )
302+ make_gaussian_filter .__doc__ = """make_filter() with a default filter_type value of ``gaussian``.
225303
226304See Also
227305--------
228306:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size.
229307"""
230308
231- FilterType = Annotated [Union [ConicFilter , CircularFilter ], pd .Field (discriminator = TYPE_TAG_STR )]
309+ FilterType = Annotated [
310+ Union [ConicFilter , CircularFilter , GaussianFilter ], pd .Field (discriminator = TYPE_TAG_STR )
311+ ]
0 commit comments