1-
21#ifndef S2FFT_H
32#define S2FFT_H
43
1918
2019namespace s2fft {
2120
21+ /* *
22+ * @brief Returns the appropriate cuFFT C2C type for a given complex type.
23+ *
24+ * This function is overloaded for `cufftDoubleComplex` and `cufftComplex`
25+ * to return `CUFFT_Z2Z` (double precision) or `CUFFT_C2C` (single precision)
26+ * respectively.
27+ *
28+ * @param dummy A dummy complex object used for type deduction.
29+ * @return The corresponding cuFFT C2C type.
30+ */
2231static cufftType get_cufft_type_c2c (cufftDoubleComplex) { return CUFFT_Z2Z; }
2332static cufftType get_cufft_type_c2c (cufftComplex) { return CUFFT_C2C; }
2433
/**
 * @brief Converts data from ring-based indexing to nphi-based indexing.
 *
 * Declaration only; the definition lives outside this header.
 * NOTE(review): the single non-const pointer suggests the reordering is done
 * in place on @p data - confirm against the definition.
 *
 * @param data  Pointer to the buffer to be reordered.
 * @param nside The HEALPix Nside parameter.
 */
void s2fft_rings_2_nphi(float *data, int nside);
2644
/**
 * @brief Converts data from nphi-based indexing to ring-based indexing.
 *
 * Declaration only; the definition lives outside this header.
 * NOTE(review): the single non-const pointer suggests the reordering is done
 * in place on @p data - confirm against the definition.
 *
 * @param data  Pointer to the buffer to be reordered.
 * @param nside The HEALPix Nside parameter.
 */
void s2fft_nphi_2_rings(float *data, int nside);
2855
/**
 * @brief Descriptor class for s2fft operations.
 *
 * This class encapsulates the parameters that define a Spherical Harmonic
 * Transform (SHT) operation: Nside, harmonic band limit, reality, adjoint
 * flag, forward/backward transform direction, normalization, shifting, and
 * double-precision usage.
 */
2964class s2fftDescriptor {
3065public:
3166 int64_t nside;
@@ -38,6 +73,18 @@ class s2fftDescriptor {
3873 bool shift = true ;
3974 bool double_precision = false ;
4075
/**
 * @brief Constructs an s2fftDescriptor object.
 *
 * @param nside The HEALPix Nside parameter.
 * @param harmonic_band_limit The harmonic band limit L.
 * @param reality Flag indicating whether the signal is real.
 * @param adjoint Flag indicating whether the adjoint transform is to be performed.
 * @param forward Flag indicating whether this is a forward transform (default: true).
 * @param norm The FFT normalization type (default: BACKWARD).
 * @param shift Flag indicating whether FFT shifting should be applied (default: true).
 * @param double_precision Flag indicating whether double precision should be used (default: false).
 */
4188 s2fftDescriptor (int64_t nside, int64_t harmonic_band_limit, bool reality, bool adjoint,
4289 bool forward = true , s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD,
4390 bool shift = true , bool double_precision = false )
@@ -50,53 +97,130 @@ class s2fftDescriptor {
5097 shift(shift),
5198 double_precision(double_precision) {}
5299
100+ /* *
101+ * @brief Default constructor for s2fftDescriptor.
102+ */
53103 s2fftDescriptor () = default;
104+
105+ /* *
106+ * @brief Destructor for s2fftDescriptor.
107+ */
54108 ~s2fftDescriptor () = default ;
55109
110+ /* *
111+ * @brief Equality operator for s2fftDescriptor.
112+ *
113+ * Compares two s2fftDescriptor objects for equality based on their member values.
114+ *
115+ * @param other The other s2fftDescriptor to compare against.
116+ * @return True if the descriptors are equal, false otherwise.
117+ */
56118 bool operator ==(const s2fftDescriptor &other) const {
57119 return nside == other.nside && harmonic_band_limit == other.harmonic_band_limit &&
58120 reality == other.reality && norm == other.norm && shift == other.shift &&
59121 double_precision == other.double_precision ;
60122 }
61123};
62124
125+ /* *
126+ * @brief Executes Spherical Harmonic Transform (SHT) operations.
127+ *
128+ * This templated class provides methods for initializing FFT plans and executing
129+ * forward and backward SHTs. It manages cuFFT handles and internal offsets
130+ * required for the transforms.
131+ *
132+ * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex) for the FFT operations.
133+ */
63134template <typename Complex>
64135class s2fftExec {
65- friend class PlanCache ;
136+ friend class PlanCache ; // Allows PlanCache to access private members for caching
66137
67138public:
139+ /* *
140+ * @brief Default constructor for s2fftExec.
141+ */
68142 s2fftExec () {}
69- ~s2fftExec () {}
70-
71- HRESULT Initialize (const s2fftDescriptor &descriptor, size_t &worksize);
72143
73- HRESULT Forward (const s2fftDescriptor &desc, cudaStream_t stream, Complex *data);
144+ /* *
145+ * @brief Destructor for s2fftExec.
146+ */
147+ ~s2fftExec () {}
74148
75- HRESULT Backward (const s2fftDescriptor &desc, cudaStream_t stream, Complex *data);
149+ /* *
150+ * @brief Initializes the FFT plans for the SHT.
151+ *
152+ * This method sets up the necessary cuFFT plans for both polar and equatorial
153+ * rings based on the provided descriptor. It also calculates and stores the
154+ * maximum required workspace size (m_work_size).
155+ *
156+ * @param descriptor The s2fftDescriptor containing the parameters for the FFT.
157+ * @return HRESULT indicating success or failure.
158+ */
159+ HRESULT Initialize (const s2fftDescriptor &descriptor);
160+
161+ /* *
162+ * @brief Executes the forward Spherical Harmonic Transform.
163+ *
164+ * This method performs the forward FFT operations on the input data
165+ * across polar and equatorial rings using the pre-initialized cuFFT plans.
166+ *
167+ * @param desc The s2fftDescriptor for the current transform.
168+ * @param stream The CUDA stream to use for execution.
169+ * @param data Pointer to the input/output data on the device.
170+ * @param workspace Pointer to the workspace memory on the device.
171+ * @param callback_params Pointer to device memory containing callback parameters.
172+ * @return HRESULT indicating success or failure.
173+ */
174+ HRESULT Forward (const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
175+ int64 *callback_params);
176+
177+ /* *
178+ * @brief Executes the backward Spherical Harmonic Transform.
179+ *
180+ * This method performs the inverse FFT operations on the input data
181+ * across polar and equatorial rings using the pre-initialized cuFFT plans.
182+ *
183+ * @param desc The s2fftDescriptor for the current transform.
184+ * @param stream The CUDA stream to use for execution.
185+ * @param data Pointer to the input/output data on the device.
186+ * @param workspace Pointer to the workspace memory on the device.
187+ * @param callback_params Pointer to device memory containing callback parameters.
188+ * @return HRESULT indicating success or failure.
189+ */
190+ HRESULT Backward (const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
191+ int64 *callback_params);
76192
77193public:
194+ // cuFFT handles for polar and equatorial FFT plans
78195 std::vector<cufftHandle> m_polar_plans;
79196 cufftHandle m_equator_plan;
80197 std::vector<cufftHandle> m_inverse_polar_plans;
81198 cufftHandle m_inverse_equator_plan;
199+
200+ // Parameters defining the SHT geometry and data layout
82201 int m_nside;
83202 int m_equatorial_ring_num;
84203 int64 m_total_pixels;
85204 int64 m_equatorial_offset_start;
86205 int64 m_equatorial_offset_end;
87206 std::vector<int64> m_upper_ring_offsets;
88207 std::vector<int64> m_lower_ring_offsets;
89-
90- // Callback params stored for cleanup purposes
91- // thrust::device_vector<cb_params> m_cb_params;
208+ size_t m_work_size = 0 ; // Maximum workspace size required for FFT plans
92209};
93210
94211} // namespace s2fft
95212
96213namespace std {
214+ /* *
215+ * @brief Custom hash specialization for s2fftDescriptor.
216+ *
217+ * This specialization allows s2fftDescriptor objects to be used as keys
218+ * in `std::unordered_map` by providing a hash function.
219+ */
97220template <>
98221struct hash <s2fft::s2fftDescriptor> {
99222 std::size_t operator ()(const s2fft::s2fftDescriptor &k) const {
223+ // Combine hash values of individual members
100224 size_t hash = std::hash<int64_t >()(k.nside ) ^ (std::hash<int64_t >()(k.harmonic_band_limit ) << 1 ) ^
101225 (std::hash<bool >()(k.reality ) << 2 ) ^ (std::hash<int >()(k.norm ) << 3 ) ^
102226 (std::hash<bool >()(k.shift ) << 4 ) ^ (std::hash<bool >()(k.double_precision ) << 5 );
0 commit comments