Skip to content

Commit a83dbd1

Browse files
committed
Fix illegal memory access issue
Add comprehensive documentation and fix dependency issues for the CUDA FFT integration. This commit introduces extensive docstrings and inline comments across the C++ and Python codebase, particularly for the CUDA FFT implementation. It also addresses a dependency issue to ensure proper installation and functionality. Key changes include: - No more CUDA malloc: all memory is allocated in Python by XLA. - Added detailed docstrings to C++ header files. - Enhanced inline comments in C++ source files to explain complex logic and algorithms. - Relaxed the JAX version dependency, resolving installation issues. - Refined docstrings and comments in Python files for clarity and consistency. - Cleaned up debug print statements.
1 parent 866d1f2 commit a83dbd1

File tree

10 files changed

+1157
-220
lines changed

10 files changed

+1157
-220
lines changed

lib/include/plan_cache.h

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#ifndef PLAN_CACHE_H
32
#define PLAN_CACHE_H
43

@@ -9,36 +8,84 @@
98
#include "hresult.h"
109
#include "s2fft.h"
1110
#include <unordered_map>
11+
#include <type_traits>
1212

1313
namespace s2fft {
1414

15+
/**
16+
* @brief Manages and caches s2fftExec instances to optimize resource usage.
17+
*
18+
* This class implements the singleton pattern to ensure only one instance
19+
* of the PlanCache exists throughout the application. It stores pre-initialized
20+
* s2fftExec objects based on their descriptors (parameters like nside, L, etc.)
21+
* to avoid redundant initialization, which can be computationally expensive.
22+
*/
1523
class PlanCache {
1624
public:
25+
/**
26+
* @brief Returns the singleton instance of the PlanCache.
27+
*
28+
* @return A reference to the single PlanCache instance.
29+
*/
1730
static PlanCache &GetInstance() {
1831
static PlanCache instance;
1932
return instance;
2033
}
2134

22-
HRESULT GetS2FFTExec(s2fftDescriptor &descriptor, std::shared_ptr<s2fftExec<cufftComplex>> &executor);
35+
/**
36+
* @brief Retrieves an s2fftExec instance from the cache or initializes a new one.
37+
*
38+
* This templated method attempts to find an existing s2fftExec instance
39+
* matching the provided descriptor in its internal cache (m_Descriptors32 or m_Descriptors64)
40+
* based on the Complex type T. If a matching instance is found, it is returned.
41+
* Otherwise, a new s2fftExec instance is created, initialized with the descriptor,
42+
* and then stored in the cache before being returned.
43+
*
44+
* @tparam T The complex type (cufftComplex or cufftDoubleComplex) of the s2fftExec instance.
45+
* @param descriptor The s2fftDescriptor containing the parameters for the FFT.
46+
* @param executor A shared_ptr that will point to the retrieved or newly initialized s2fftExec instance.
47+
* @return HRESULT indicating success (S_OK if new, S_FALSE if from cache) or failure.
48+
*/
49+
template <typename T>
50+
HRESULT GetS2FFTExec(s2fftDescriptor &descriptor, std::shared_ptr<s2fftExec<T>> &executor);
2351

24-
HRESULT GetS2FFTExec(s2fftDescriptor &descriptor,
25-
std::shared_ptr<s2fftExec<cufftDoubleComplex>> &executor);
52+
/**
53+
* @brief Clears all cached s2fftExec instances.
54+
*
55+
* This method is typically called during application shutdown to release
56+
* all resources held by the cached FFT plans.
57+
*/
58+
void Finalize();
2659

27-
~PlanCache() {}
60+
/**
61+
* @brief Destructor for PlanCache.
62+
*
63+
* Ensures that Finalize() is called when the PlanCache instance is destroyed,
64+
* performing necessary cleanup.
65+
*/
66+
~PlanCache();
2867

2968
private:
3069
bool is_initialized = false;
3170

71+
// Unordered maps to store cached s2fftExec instances for double and single precision
3272
std::unordered_map<s2fftDescriptor, std::shared_ptr<s2fftExec<cufftDoubleComplex>>,
3373
std::hash<s2fftDescriptor>, std::equal_to<>>
3474
m_Descriptors64;
3575
std::unordered_map<s2fftDescriptor, std::shared_ptr<s2fftExec<cufftComplex>>, std::hash<s2fftDescriptor>,
3676
std::equal_to<>>
3777
m_Descriptors32;
3878

79+
/**
80+
* @brief Private constructor for PlanCache.
81+
*
82+
* Initializes the PlanCache instance. This constructor is private to enforce
83+
* the singleton pattern.
84+
*/
3985
PlanCache();
4086

4187
public:
88+
// Delete copy constructor and assignment operator to prevent copying
4289
PlanCache(PlanCache const &) = delete;
4390
void operator=(PlanCache const &) = delete;
4491
};

lib/include/s2fft.h

Lines changed: 134 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#ifndef S2FFT_H
32
#define S2FFT_H
43

@@ -19,13 +18,49 @@
1918

2019
namespace s2fft {
2120

21+
/**
22+
* @brief Returns the appropriate cuFFT C2C type for a given complex type.
23+
*
24+
* This function is overloaded for `cufftDoubleComplex` and `cufftComplex`
25+
* to return `CUFFT_Z2Z` (double precision) or `CUFFT_C2C` (single precision)
26+
* respectively.
27+
*
28+
* @param dummy Unnamed tag argument; only its type is used, to select the matching overload.
29+
* @return The corresponding cuFFT C2C type.
30+
*/
2231
static cufftType get_cufft_type_c2c(cufftDoubleComplex) { return CUFFT_Z2Z; }
2332
static cufftType get_cufft_type_c2c(cufftComplex) { return CUFFT_C2C; }
2433

34+
/**
35+
* @brief Transforms data from ring-based indexing to nphi-based indexing.
36+
*
37+
* This function is a placeholder for the actual implementation which would
38+
* reorder data in memory according to the specified indexing scheme.
39+
*
40+
* @param data Pointer to the input/output data.
41+
* @param nside The HEALPix Nside parameter.
42+
*/
2543
void s2fft_rings_2_nphi(float *data, int nside);
2644

45+
/**
46+
* @brief Transforms data from nphi-based indexing to ring-based indexing.
47+
*
48+
* This function is a placeholder for the actual implementation which would
49+
* reorder data in memory according to the specified indexing scheme.
50+
*
51+
* @param data Pointer to the input/output data.
52+
* @param nside The HEALPix Nside parameter.
53+
*/
2754
void s2fft_nphi_2_rings(float *data, int nside);
2855

56+
/**
57+
* @brief Descriptor class for s2fft operations.
58+
*
59+
* This class encapsulates all the necessary parameters to define a unique
60+
* Spherical Harmonic Transform (SHT) operation, including Nside, harmonic
61+
* band limit, reality, adjoint flag, forward/backward transform direction,
62+
* normalization, shifting, and double precision usage.
63+
*/
2964
class s2fftDescriptor {
3065
public:
3166
int64_t nside;
@@ -38,6 +73,18 @@ class s2fftDescriptor {
3873
bool shift = true;
3974
bool double_precision = false;
4075

76+
/**
77+
* @brief Constructs an s2fftDescriptor object.
78+
*
79+
* @param nside The HEALPix Nside parameter.
80+
* @param harmonic_band_limit The harmonic band limit L.
81+
* @param reality Flag indicating if the signal is real.
82+
* @param adjoint Flag indicating if the adjoint transform is to be performed.
83+
* @param forward Flag indicating if it's a forward transform (default: true).
84+
* @param norm The FFT normalization type (default: BACKWARD).
85+
* @param shift Flag indicating if FFT shifting should be applied (default: true).
86+
* @param double_precision Flag indicating if double precision should be used (default: false).
87+
*/
4188
s2fftDescriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool adjoint,
4289
bool forward = true, s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD,
4390
bool shift = true, bool double_precision = false)
@@ -50,53 +97,130 @@ class s2fftDescriptor {
5097
shift(shift),
5198
double_precision(double_precision) {}
5299

100+
/**
101+
* @brief Default constructor for s2fftDescriptor.
102+
*/
53103
s2fftDescriptor() = default;
104+
105+
/**
106+
* @brief Destructor for s2fftDescriptor.
107+
*/
54108
~s2fftDescriptor() = default;
55109

110+
/**
111+
* @brief Equality operator for s2fftDescriptor.
112+
*
113+
* Compares nside, harmonic_band_limit, reality, norm, shift and double_precision; the adjoint and forward constructor arguments are not part of the comparison.
114+
*
115+
* @param other The other s2fftDescriptor to compare against.
116+
* @return True if the descriptors are equal, false otherwise.
117+
*/
56118
bool operator==(const s2fftDescriptor &other) const {
57119
return nside == other.nside && harmonic_band_limit == other.harmonic_band_limit &&
58120
reality == other.reality && norm == other.norm && shift == other.shift &&
59121
double_precision == other.double_precision;
60122
}
61123
};
62124

125+
/**
126+
* @brief Executes Spherical Harmonic Transform (SHT) operations.
127+
*
128+
* This templated class provides methods for initializing FFT plans and executing
129+
* forward and backward SHTs. It manages cuFFT handles and internal offsets
130+
* required for the transforms.
131+
*
132+
* @tparam Complex The complex type (cufftComplex or cufftDoubleComplex) for the FFT operations.
133+
*/
63134
template <typename Complex>
64135
class s2fftExec {
65-
friend class PlanCache;
136+
friend class PlanCache; // Allows PlanCache to access private members for caching
66137

67138
public:
139+
/**
140+
* @brief Default constructor for s2fftExec.
141+
*/
68142
s2fftExec() {}
69-
~s2fftExec() {}
70-
71-
HRESULT Initialize(const s2fftDescriptor &descriptor, size_t &worksize);
72143

73-
HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data);
144+
/**
145+
* @brief Destructor for s2fftExec.
146+
*/
147+
~s2fftExec() {}
74148

75-
HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data);
149+
/**
150+
* @brief Initializes the FFT plans for the SHT.
151+
*
152+
* This method sets up the necessary cuFFT plans for both polar and equatorial
153+
* rings based on the provided descriptor. It also calculates and stores the
154+
* maximum required workspace size (m_work_size).
155+
*
156+
* @param descriptor The s2fftDescriptor containing the parameters for the FFT.
157+
* @return HRESULT indicating success or failure.
158+
*/
159+
HRESULT Initialize(const s2fftDescriptor &descriptor);
160+
161+
/**
162+
* @brief Executes the forward Spherical Harmonic Transform.
163+
*
164+
* This method performs the forward FFT operations on the input data
165+
* across polar and equatorial rings using the pre-initialized cuFFT plans.
166+
*
167+
* @param desc The s2fftDescriptor for the current transform.
168+
* @param stream The CUDA stream to use for execution.
169+
* @param data Pointer to the input/output data on the device.
170+
* @param workspace Pointer to the workspace memory on the device.
171+
* @param callback_params Pointer to device memory containing callback parameters.
172+
* @return HRESULT indicating success or failure.
173+
*/
174+
HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
175+
int64 *callback_params);
176+
177+
/**
178+
* @brief Executes the backward Spherical Harmonic Transform.
179+
*
180+
* This method performs the inverse FFT operations on the input data
181+
* across polar and equatorial rings using the pre-initialized cuFFT plans.
182+
*
183+
* @param desc The s2fftDescriptor for the current transform.
184+
* @param stream The CUDA stream to use for execution.
185+
* @param data Pointer to the input/output data on the device.
186+
* @param workspace Pointer to the workspace memory on the device.
187+
* @param callback_params Pointer to device memory containing callback parameters.
188+
* @return HRESULT indicating success or failure.
189+
*/
190+
HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
191+
int64 *callback_params);
76192

77193
public:
194+
// cuFFT handles for polar and equatorial FFT plans
78195
std::vector<cufftHandle> m_polar_plans;
79196
cufftHandle m_equator_plan;
80197
std::vector<cufftHandle> m_inverse_polar_plans;
81198
cufftHandle m_inverse_equator_plan;
199+
200+
// Parameters defining the SHT geometry and data layout
82201
int m_nside;
83202
int m_equatorial_ring_num;
84203
int64 m_total_pixels;
85204
int64 m_equatorial_offset_start;
86205
int64 m_equatorial_offset_end;
87206
std::vector<int64> m_upper_ring_offsets;
88207
std::vector<int64> m_lower_ring_offsets;
89-
90-
// Callback params stored for cleanup purposes
91-
// thrust::device_vector<cb_params> m_cb_params;
208+
size_t m_work_size = 0; // Maximum workspace size required for FFT plans
92209
};
93210

94211
} // namespace s2fft
95212

96213
namespace std {
214+
/**
215+
* @brief Custom hash specialization for s2fftDescriptor.
216+
*
217+
* This specialization allows s2fftDescriptor objects to be used as keys
218+
* in `std::unordered_map` by providing a hash function.
219+
*/
97220
template <>
98221
struct hash<s2fft::s2fftDescriptor> {
99222
std::size_t operator()(const s2fft::s2fftDescriptor &k) const {
223+
// Combine hash values of individual members
100224
size_t hash = std::hash<int64_t>()(k.nside) ^ (std::hash<int64_t>()(k.harmonic_band_limit) << 1) ^
101225
(std::hash<bool>()(k.reality) << 2) ^ (std::hash<int>()(k.norm) << 3) ^
102226
(std::hash<bool>()(k.shift) << 4) ^ (std::hash<bool>()(k.double_precision) << 5);

lib/include/s2fft_callbacks.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,44 @@
1212
typedef long long int int64;
1313

1414
namespace s2fftKernels {
15+
/**
16+
* @brief Defines the normalization types for FFT operations.
17+
*/
1518
enum fft_norm { FORWARD = 1, BACKWARD = 2, ORTHO = 3, NONE = 4 };
1619

17-
HRESULT setCallback(cufftHandle forwardPlan, cufftHandle backwardPlan, int64 *params_dev, bool shift,
18-
bool equator, bool doublePrecision, fft_norm norm);
20+
/**
21+
* @brief Sets cuFFT callbacks specifically for a forward FFT plan.
22+
*
23+
* This function configures the cuFFT library to use custom callbacks
24+
* for normalization and shifting operations during forward FFT execution.
25+
*
26+
* @param plan The cuFFT handle for the forward FFT plan.
27+
* @param params_dev Pointer to device memory containing parameters for the callbacks.
28+
* @param shift Boolean flag indicating whether to apply FFT shifting.
29+
* @param equator Boolean flag indicating if the current operation is for the equatorial ring.
30+
* @param doublePrecision Boolean flag indicating if double precision is used.
31+
* @param norm The FFT normalization type to apply.
32+
* @return HRESULT indicating success or failure.
33+
*/
34+
HRESULT setForwardCallback(cufftHandle plan, int64 *params_dev, bool shift, bool equator,
35+
bool doublePrecision, fft_norm norm);
36+
37+
/**
38+
* @brief Sets cuFFT callbacks specifically for a backward FFT plan.
39+
*
40+
* This function configures the cuFFT library to use custom callbacks
41+
* for normalization and shifting operations during backward FFT execution.
42+
*
43+
* @param plan The cuFFT handle for the inverse FFT plan.
44+
* @param params_dev Pointer to device memory containing parameters for the callbacks.
45+
* @param shift Boolean flag indicating whether to apply FFT shifting.
46+
* @param equator Boolean flag indicating if the current operation is for the equatorial ring.
47+
* @param doublePrecision Boolean flag indicating if double precision is used.
48+
* @param norm The FFT normalization type to apply.
49+
* @return HRESULT indicating success or failure.
50+
*/
51+
HRESULT setBackwardCallback(cufftHandle plan, int64 *params_dev, bool shift, bool equator,
52+
bool doublePrecision, fft_norm norm);
1953
} // namespace s2fftKernels
2054

2155
#endif

0 commit comments

Comments
 (0)