1+ #include "nbl/builtin/hlsl/cpp_compat.hlsl"
2+ #include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
3+ #include "nbl/builtin/hlsl/workgroup/basic.hlsl"
4+ #include "nbl/builtin/hlsl/workgroup/arithmetic.hlsl"
5+ #include "nbl/builtin/hlsl/workgroup/scratch_size.hlsl"
6+ #include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
7+ #include "nbl/builtin/hlsl/enums.hlsl"
8+
9+ namespace nbl
10+ {
11+ namespace hlsl
12+ {
13+ namespace box_blur
14+ {
15+
16+ template<
17+ typename DataAccessor,
18+ typename SharedAccessor,
19+ typename ScanSharedAccessor,
20+ typename Sampler,
21+ uint16_t WorkgroupSize,
22+ class device_capabilities=void > // TODO: define concepts for the Box1D and apply constraints
23+ struct Box1D
24+ {
25+ // TODO: Generalize later on when Francesco enforces accessor-concepts in `workgroup` and adds a `SharedMemoryAccessor` concept
26+ struct ScanSharedAccessorWrapper
27+ {
28+ void get (const uint16_t ix, NBL_REF_ARG (float32_t) val)
29+ {
30+ val = base.template get<float32_t, uint16_t>(ix);
31+ }
32+
33+ void set (const uint16_t ix, const float32_t val)
34+ {
35+ base.template set<float32_t, uint16_t>(ix, val);
36+ }
37+
38+ void workgroupExecutionAndMemoryBarrier ()
39+ {
40+ base.workgroupExecutionAndMemoryBarrier ();
41+ }
42+
43+ ScanSharedAccessor base;
44+ };
45+
46+ void operator ()(
47+ NBL_REF_ARG (DataAccessor) data,
48+ NBL_REF_ARG (SharedAccessor) scratch,
49+ NBL_REF_ARG (ScanSharedAccessor) scanScratch,
50+ NBL_REF_ARG (Sampler) boxSampler,
51+ const uint16_t channel)
52+ {
53+ const uint16_t end = data.linearSize ();
54+ const uint16_t localInvocationIndex = workgroup::SubgroupContiguousIndex ();
55+
56+ // prefix sum
57+ // note the dynamically uniform loop condition
58+ for (uint16_t baseIx = 0 ; baseIx < end;)
59+ {
60+ const uint16_t ix = localInvocationIndex + baseIx;
61+ float32_t input = data.template get<float32_t>(channel, ix);
62+ // dynamically uniform condition
63+ if (baseIx != 0 )
64+ {
65+ // take result of previous prefix sum and add it to first element here
66+ if (localInvocationIndex == 0 )
67+ input += scratch.template get<float32_t>(baseIx - 1 );
68+ }
69+ // need to copy-in / copy-out the accessor cause no references in HLSL - yay!
70+ ScanSharedAccessorWrapper scanScratchWrapper;
71+ scanScratchWrapper.base = scanScratch;
72+ const float32_t sum = workgroup::inclusive_scan<plus<float32_t>, WorkgroupSize, device_capabilities>::template __call (input, scanScratchWrapper);
73+ scanScratch = scanScratchWrapper.base;
74+ // loop increment
75+ baseIx += WorkgroupSize;
76+ // if doing the last prefix sum, we need to barrier to stop aliasing of temporary scratch for `inclusive_scan` and our scanline
77+ // TODO: might be worth adding a non-aliased mode as NSight says nr 1 hotspot is barrier waiting in this code
78+ if (end + ScanSharedAccessor::Size > SharedAccessor::Size)
79+ scratch.workgroupExecutionAndMemoryBarrier ();
80+ // save prefix sum results
81+ if (ix < end)
82+ scratch.template set<float32_t>(ix, sum);
83+ // previous prefix sum must have finished before we ask for results
84+ scratch.workgroupExecutionAndMemoryBarrier ();
85+ }
86+
87+ const float32_t last = end - 1 ;
88+ const float32_t normalizationFactor = 1.f / (2.f * radius + 1.f );
89+
90+ for (float32_t ix = localInvocationIndex; ix < end; ix += WorkgroupSize)
91+ {
92+ const float32_t result = boxSampler (scratch, ix, radius, borderColor[channel]);
93+ data.template set<float32_t>(channel, uint16_t (ix), result * normalizationFactor);
94+ }
95+ }
96+
97+ vector <float32_t, DataAccessor::Channels> borderColor;
98+ float32_t radius;
99+ };
100+
101+ template<typename PrefixSumAccessor, typename T>
102+ struct BoxSampler
103+ {
104+ uint16_t wrapMode;
105+ uint16_t linearSize;
106+
107+ T operator ()(NBL_REF_ARG (PrefixSumAccessor) prefixSumAccessor, float32_t ix, float32_t radius, float32_t borderColor)
108+ {
109+ const float32_t alpha = radius - floor (radius);
110+ const float32_t lastIdx = linearSize - 1 ;
111+ const float32_t rightIdx = float32_t (ix) + radius;
112+ const float32_t leftIdx = float32_t (ix) - radius;
113+ const int32_t rightFlIdx = (int32_t)floor (rightIdx);
114+ const int32_t rightClIdx = (int32_t)ceil (rightIdx);
115+ const int32_t leftFlIdx = (int32_t)floor (leftIdx);
116+ const int32_t leftClIdx = (int32_t)ceil (leftIdx);
117+
118+ T result = 0 ;
119+ if (rightFlIdx < linearSize)
120+ {
121+ result += lerp (prefixSumAccessor.template get<T, uint32_t>(rightFlIdx), prefixSumAccessor.template get<T, uint32_t>(rightClIdx), alpha);
122+ }
123+ else
124+ {
125+ switch (wrapMode) {
126+ case ETC_REPEAT:
127+ {
128+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
129+ const T floored = prefixSumAccessor.template get<T, uint32_t>(rightFlIdx % linearSize) + ceil (float32_t (rightFlIdx % lastIdx) / linearSize) * last;
130+ const T ceiled = prefixSumAccessor.template get<T, uint32_t>(rightClIdx % linearSize) + ceil (float32_t (rightClIdx % lastIdx) / linearSize) * last;
131+ result += lerp (floored, ceiled, alpha);
132+ break ;
133+ }
134+ case ETC_CLAMP_TO_BORDER:
135+ {
136+ result += prefixSumAccessor.template get<T, uint32_t>(lastIdx) + (rightIdx - lastIdx) * borderColor;
137+ break ;
138+ }
139+ case ETC_CLAMP_TO_EDGE:
140+ {
141+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
142+ const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1 );
143+ result += (rightIdx - lastIdx) * (last - lastMinusOne) + last;
144+ break ;
145+ }
146+ case ETC_MIRROR:
147+ {
148+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
149+ T floored, ceiled;
150+ int32_t d = rightFlIdx - lastIdx;
151+
152+ if (d % (2 * linearSize) == linearSize)
153+ floored = ((d + linearSize) / linearSize) * last;
154+ else
155+ {
156+ const uint32_t period = uint32_t (ceil (float32_t (d) / linearSize));
157+ if ((period & 0x1u) == 1 )
158+ floored = period * last + last - prefixSumAccessor.template get<T, uint32_t>(lastIdx - uint32_t (d % linearSize));
159+ else
160+ floored = period * last + prefixSumAccessor.template get<T, uint32_t>((d - 1 ) % linearSize);
161+ }
162+
163+ d = rightClIdx - lastIdx;
164+ if (d % (2 * linearSize) == linearSize)
165+ ceiled = ((d + linearSize) / linearSize) * last;
166+ else
167+ {
168+ const uint32_t period = uint32_t (ceil (float32_t (d) / linearSize));
169+ if ((period & 0x1u) == 1 )
170+ ceiled = period * last + last - prefixSumAccessor.template get<T, uint32_t>(lastIdx - uint32_t (d % linearSize));
171+ else
172+ ceiled = period * last + prefixSumAccessor.template get<T, uint32_t>((d - 1 ) % linearSize);
173+ }
174+
175+ result += lerp (floored, ceiled, alpha);
176+ break ;
177+ }
178+ case ETC_MIRROR_CLAMP_TO_EDGE:
179+ {
180+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
181+ const T first = prefixSumAccessor.template get<T, uint32_t>(0 );
182+ const T firstPlusOne = prefixSumAccessor.template get<T, uint32_t>(1 );
183+ result += (rightIdx - lastIdx) * (firstPlusOne - first) + last;
184+ break ;
185+ }
186+ }
187+ }
188+
189+ if (leftFlIdx >= 0 )
190+ {
191+ result -= lerp (prefixSumAccessor.template get<T, uint32_t>(leftFlIdx), prefixSumAccessor.template get<T, uint32_t>(leftClIdx), alpha);
192+ }
193+ else
194+ {
195+ switch (wrapMode) {
196+ case ETC_REPEAT:
197+ {
198+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
199+ const T floored = prefixSumAccessor.template get<T, uint32_t>(abs (leftFlIdx) % linearSize) + ceil (T (leftFlIdx) / linearSize) * last;
200+ const T ceiled = prefixSumAccessor.template get<T, uint32_t>(abs (leftClIdx) % linearSize) + ceil (float32_t (leftClIdx) / linearSize) * last;
201+ result -= lerp (floored, ceiled, alpha);
202+ break ;
203+ }
204+ case ETC_CLAMP_TO_BORDER:
205+ {
206+ result -= prefixSumAccessor.template get<T, uint32_t>(0 ) + leftIdx * borderColor;
207+ break ;
208+ }
209+ case ETC_CLAMP_TO_EDGE:
210+ {
211+ result -= leftIdx * prefixSumAccessor.template get<T, uint32_t>(0 );
212+ break ;
213+ }
214+ case ETC_MIRROR:
215+ {
216+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
217+ T floored, ceiled;
218+
219+ if (abs (leftFlIdx + 1 ) % (2 * linearSize) == 0 )
220+ floored = -(abs (leftFlIdx + 1 ) / linearSize) * last;
221+ else
222+ {
223+ const uint32_t period = uint32_t (ceil (float32_t (abs (leftFlIdx + 1 )) / linearSize));
224+ if ((period & 0x1u) == 1 )
225+ floored = -(period - 1 ) * last - prefixSumAccessor.template get<T, uint32_t>((abs (leftFlIdx + 1 ) - 1 ) % linearSize);
226+ else
227+ floored = -(period - 1 ) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftFlIdx + 1 ) % linearSize - 1 ));
228+ }
229+
230+ if (leftClIdx == 0 ) // Special case, wouldn't be possible for `floored` above
231+ ceiled = 0 ;
232+ else if (abs (leftClIdx + 1 ) % (2 * linearSize) == 0 )
233+ ceiled = -(abs (leftClIdx + 1 ) / linearSize) * last;
234+ else
235+ {
236+ const uint32_t period = uint32_t (ceil (float32_t (abs (leftClIdx + 1 )) / linearSize));
237+ if ((period & 0x1u) == 1 )
238+ ceiled = -(period - 1 ) * last - prefixSumAccessor.template get<T, uint32_t>((abs (leftClIdx + 1 ) - 1 ) % linearSize);
239+ else
240+ ceiled = -(period - 1 ) * last - (last - prefixSumAccessor.template get<T, uint32_t>((leftClIdx + 1 ) % linearSize - 1 ));
241+ }
242+
243+ result -= lerp (floored, ceiled, alpha);
244+ break ;
245+ }
246+ case ETC_MIRROR_CLAMP_TO_EDGE:
247+ {
248+ const T last = prefixSumAccessor.template get<T, uint32_t>(lastIdx);
249+ const T lastMinusOne = prefixSumAccessor.template get<T, uint32_t>(lastIdx - 1 );
250+ result -= leftIdx * (last - lastMinusOne);
251+ break ;
252+ }
253+ }
254+ }
255+
256+ return result;
257+ }
258+ };
259+
260+ }
261+ }
262+ }
0 commit comments