Skip to content

Commit b5cbeac

Browse files
committed
Implement VMAP and transpose rules for cuda primitive
1 parent e2cc68c commit b5cbeac

File tree

4 files changed

+352
-104
lines changed

4 files changed

+352
-104
lines changed

lib/include/cudastreamhandler.hpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
2+
/**
3+
* @file cudastreamhandler.hpp
4+
* @brief Singleton class for managing CUDA streams and events.
5+
*
6+
* This header provides a singleton implementation that encapsulates the creation,
7+
* management, and cleanup of CUDA streams and events. It offers functions to fork
8+
* streams, add new streams, and synchronize (join) streams with a given dependency.
9+
*
10+
* Usage example:
11+
* @code
12+
* #include "cudastreamhandler.hpp"
13+
*
14+
* int main() {
15+
* // Create a handler instance
16+
* CudaStreamHandler handler;
17+
*
18+
* // Fork 4 streams dependent on a given stream 'stream_main'
19+
* handler.Fork(stream_main, 4);
20+
*
21+
* // Do work on the forked streams...
22+
*
23+
* // Join the streams back to 'stream_main'
24+
* handler.join(stream_main);
25+
*
26+
* return 0;
27+
* }
28+
* @endcode
29+
*
30+
* Author: Wassim KABALAN
31+
*/
32+
33+
#ifndef CUDASTREAMHANDLER_HPP
34+
#define CUDASTREAMHANDLER_HPP
35+
36+
#include <algorithm>
37+
#include <atomic>
38+
#include <cuda_runtime.h>
39+
#include <stdexcept>
40+
#include <thread>
41+
#include <vector>
42+
43+
// Singleton class managing CUDA streams and events
44+
class CudaStreamHandlerImpl {
45+
public:
46+
static CudaStreamHandlerImpl &instance() {
47+
static CudaStreamHandlerImpl instance;
48+
return instance;
49+
}
50+
51+
void AddStreams(int numStreams) {
52+
if (numStreams > m_streams.size()) {
53+
int streamsToAdd = numStreams - m_streams.size();
54+
m_streams.resize(numStreams);
55+
std::generate(m_streams.end() - streamsToAdd, m_streams.end(), []() {
56+
cudaStream_t stream;
57+
cudaStreamCreate(&stream);
58+
return stream;
59+
});
60+
}
61+
}
62+
63+
void join(cudaStream_t finalStream) {
64+
std::for_each(m_streams.begin(), m_streams.end(), [this, finalStream](cudaStream_t stream) {
65+
cudaEvent_t event;
66+
cudaEventCreate(&event);
67+
cudaEventRecord(event, stream);
68+
cudaStreamWaitEvent(finalStream, event, 0);
69+
m_events.push_back(event);
70+
});
71+
72+
if (!cleanup_thread.joinable()) {
73+
stop_thread.store(false);
74+
cleanup_thread = std::thread([this]() { this->AsyncEventCleanup(); });
75+
}
76+
}
77+
78+
// Fork function to add streams and set dependency on a given stream
79+
void Fork(cudaStream_t dependentStream, int N) {
80+
AddStreams(N); // Add N streams
81+
82+
// Set dependency on the provided stream
83+
std::for_each(m_streams.end() - N, m_streams.end(), [this, dependentStream](cudaStream_t stream) {
84+
cudaEvent_t event;
85+
cudaEventCreate(&event);
86+
cudaEventRecord(event, dependentStream);
87+
cudaStreamWaitEvent(stream, event, 0); // Set the stream to wait on the event
88+
m_events.push_back(event);
89+
});
90+
}
91+
92+
auto getIterator() { return StreamIterator(m_streams.begin(), m_streams.end()); }
93+
94+
~CudaStreamHandlerImpl() {
95+
stop_thread.store(true);
96+
if (cleanup_thread.joinable()) {
97+
cleanup_thread.join();
98+
}
99+
100+
std::for_each(m_streams.begin(), m_streams.end(), cudaStreamDestroy);
101+
std::for_each(m_events.begin(), m_events.end(), cudaEventDestroy);
102+
}
103+
104+
// Custom Iterator class to iterate over streams
105+
class StreamIterator {
106+
public:
107+
StreamIterator(std::vector<cudaStream_t>::iterator begin, std::vector<cudaStream_t>::iterator end)
108+
: current(begin), end(end) {}
109+
110+
cudaStream_t next() {
111+
if (current == end) {
112+
throw std::out_of_range("No more streams.");
113+
}
114+
return *current++;
115+
}
116+
117+
bool hasNext() const { return current != end; }
118+
119+
private:
120+
std::vector<cudaStream_t>::iterator current;
121+
std::vector<cudaStream_t>::iterator end;
122+
};
123+
124+
private:
125+
CudaStreamHandlerImpl() : stop_thread(false) {}
126+
CudaStreamHandlerImpl(const CudaStreamHandlerImpl &) = delete;
127+
CudaStreamHandlerImpl &operator=(const CudaStreamHandlerImpl &) = delete;
128+
129+
void AsyncEventCleanup() {
130+
while (!stop_thread.load()) {
131+
std::for_each(m_events.begin(), m_events.end(), [this](cudaEvent_t &event) {
132+
if (cudaEventQuery(event) == cudaSuccess) {
133+
cudaEventDestroy(event);
134+
event = nullptr;
135+
}
136+
});
137+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
138+
}
139+
}
140+
141+
std::vector<cudaStream_t> m_streams;
142+
std::vector<cudaEvent_t> m_events;
143+
std::thread cleanup_thread;
144+
std::atomic<bool> stop_thread;
145+
};
146+
147+
// Public class for encapsulating the singleton operations
148+
class CudaStreamHandler {
149+
public:
150+
CudaStreamHandler() = default;
151+
~CudaStreamHandler() = default;
152+
153+
void AddStreams(int numStreams) { CudaStreamHandlerImpl::instance().AddStreams(numStreams); }
154+
155+
void join(cudaStream_t finalStream) { CudaStreamHandlerImpl::instance().join(finalStream); }
156+
157+
void Fork(cudaStream_t cudastream, int N) { CudaStreamHandlerImpl::instance().Fork(cudastream, N); }
158+
159+
// Get the custom iterator for CUDA streams
160+
CudaStreamHandlerImpl::StreamIterator getIterator() {
161+
return CudaStreamHandlerImpl::instance().getIterator();
162+
}
163+
};
164+
165+
#endif // CUDASTREAMHANDLER_HPP

lib/include/s2fft.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@ class s2fftDescriptor {
3131
int64_t nside;
3232
int64_t harmonic_band_limit;
3333
bool reality;
34+
bool adjoint;
3435

3536
bool forward = true;
3637
s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD;
3738
bool shift = true;
3839
bool double_precision = false;
3940

40-
s2fftDescriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool forward = true,
41-
s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD, bool shift = true,
42-
bool double_precision = false)
41+
s2fftDescriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool adjoint,
42+
bool forward = true, s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD,
43+
bool shift = true, bool double_precision = false)
4344
: nside(nside),
4445
harmonic_band_limit(harmonic_band_limit),
4546
reality(reality),
47+
adjoint(adjoint),
4648
norm(norm),
4749
forward(forward),
4850
shift(shift),

0 commit comments

Comments
 (0)