11#include "nbl/builtin/hlsl/cpp_compat.hlsl"
2+ #include <nbl/builtin/hlsl/colorspace/encodeCIEXYZ.hlsl>
23
4+ struct SPushConstants
5+ {
6+ uint32_t cascadeCount;
7+ float base;
8+ uint32_t sampleCount;
9+ float minReliableLuma;
10+ float kappa;
11+ };
12+
13+ [[vk::push_constant]] SPushConstants pc;
314[[vk::image_format ("rgba16f" )]] [[vk::binding (0 , 0 )]] RWTexture2D <float32_t4> outImage;
415[[vk::image_format ("rgba16f" )]] [[vk::binding (1 , 0 )]] RWTexture2DArray <float32_t4> cascade;
516
@@ -10,6 +21,145 @@ NBL_CONSTEXPR uint32_t WorkgroupSize = 512;
1021NBL_CONSTEXPR uint32_t MAX_DEPTH_LOG2 = 4 ;
1122NBL_CONSTEXPR uint32_t MAX_SAMPLES_LOG2 = 10 ;
1223
24+ struct RWMCReweightingParameters
25+ {
26+ uint32_t lastCascadeIndex;
27+ float initialEmin; // a minimum image brightness that we always consider reliable
28+ float reciprocalBase;
29+ float reciprocalN;
30+ float reciprocalKappa;
31+ float colorReliabilityFactor;
32+ float NOverKappa;
33+ };
34+
35+ RWMCReweightingParameters computeReweightingParameters (uint32_t cascadeCount, float base, uint32_t sampleCount, float minReliableLuma, float kappa)
36+ {
37+ RWMCReweightingParameters retval;
38+ retval.lastCascadeIndex = cascadeCount - 1u;
39+ retval.initialEmin = minReliableLuma;
40+ retval.reciprocalBase = 1.f / base;
41+ const float N = float (sampleCount);
42+ retval.reciprocalN = 1.f / N;
43+ retval.reciprocalKappa = 1.f / kappa;
44+ // if not interested in exact expected value estimation (kappa!=1.f), can usually accept a bit more variance relative to the image brightness we already have
45+ // allow up to ~<cascadeBase> more energy in one sample to lessen bias in some cases
46+ retval.colorReliabilityFactor = base + (1.f - base) * retval.reciprocalKappa;
47+ retval.NOverKappa = N * retval.reciprocalKappa;
48+
49+ return retval;
50+ }
51+
52+ struct RWMCCascadeSample
53+ {
54+ float32_t3 centerValue;
55+ float normalizedCenterLuma;
56+ float normalizedNeighbourhoodAverageLuma;
57+ };
58+
59+ // TODO: figure out what values should pixels outside have, 0.0f is incorrect
60+ float32_t3 RWMCsampleCascadeTexel (int32_t2 currentCoord, int32_t2 offset, uint32_t cascadeIndex)
61+ {
62+ const int32_t2 texelCoord = currentCoord + offset;
63+ if (any (texelCoord < int32_t2 (0 , 0 )))
64+ return float32_t3 (0.0f , 0.0f , 0.0f );
65+
66+ float32_t4 output = cascade.Load (int32_t3 (texelCoord, int32_t (cascadeIndex)));
67+ return float32_t3 (output.r, output.g, output.b);
68+ }
69+
70+ float32_t calcLuma (in float32_t3 col)
71+ {
72+ return hlsl::dot<float32_t3>(hlsl::transpose (colorspace::scRGBtoXYZ)[1 ], col);
73+ }
74+
75+ RWMCCascadeSample RWMCSampleCascade (in int32_t2 coord, in uint cascadeIndex, in float reciprocalBaseI)
76+ {
77+ float32_t3 neighbourhood[9 ];
78+ neighbourhood[0 ] = RWMCsampleCascadeTexel (coord, int32_t2 (-1 , -1 ), cascadeIndex);
79+ neighbourhood[1 ] = RWMCsampleCascadeTexel (coord, int32_t2 (0 , -1 ), cascadeIndex);
80+ neighbourhood[2 ] = RWMCsampleCascadeTexel (coord, int32_t2 (1 , -1 ), cascadeIndex);
81+ neighbourhood[3 ] = RWMCsampleCascadeTexel (coord, int32_t2 (-1 , 0 ), cascadeIndex);
82+ neighbourhood[4 ] = RWMCsampleCascadeTexel (coord, int32_t2 (0 , 0 ), cascadeIndex);
83+ neighbourhood[5 ] = RWMCsampleCascadeTexel (coord, int32_t2 (1 , 0 ), cascadeIndex);
84+ neighbourhood[6 ] = RWMCsampleCascadeTexel (coord, int32_t2 (-1 , 1 ), cascadeIndex);
85+ neighbourhood[7 ] = RWMCsampleCascadeTexel (coord, int32_t2 (0 , 1 ), cascadeIndex);
86+ neighbourhood[8 ] = RWMCsampleCascadeTexel (coord, int32_t2 (1 , 1 ), cascadeIndex);
87+
88+ // numerical robustness
89+ float32_t3 excl_hood_sum = ((neighbourhood[0 ] + neighbourhood[1 ]) + (neighbourhood[2 ] + neighbourhood[3 ])) +
90+ ((neighbourhood[5 ] + neighbourhood[6 ]) + (neighbourhood[7 ] + neighbourhood[8 ]));
91+
92+ RWMCCascadeSample retval;
93+ retval.centerValue = neighbourhood[4 ];
94+ retval.normalizedNeighbourhoodAverageLuma = retval.normalizedCenterLuma = calcLuma (neighbourhood[4 ]) * reciprocalBaseI;
95+ retval.normalizedNeighbourhoodAverageLuma = (calcLuma (excl_hood_sum) * reciprocalBaseI + retval.normalizedNeighbourhoodAverageLuma) / 9.f ;
96+ return retval;
97+ }
98+
99+ float32_t3 RWMCReweight (in RWMCReweightingParameters params, in int32_t2 coord)
100+ {
101+ float reciprocalBaseI = 1.f ;
102+ RWMCCascadeSample curr = RWMCSampleCascade (coord, 0u, reciprocalBaseI);
103+
104+ float32_t3 accumulation = float32_t3 (0.0f , 0.0f , 0.0f );
105+ float Emin = params.initialEmin;
106+
107+ float prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
108+ for (uint i = 0u; i <= params.lastCascadeIndex; i++)
109+ {
110+ const bool notFirstCascade = i != 0u;
111+ const bool notLastCascade = i != params.lastCascadeIndex;
112+
113+ RWMCCascadeSample next;
114+ if (notLastCascade)
115+ {
116+ reciprocalBaseI *= params.reciprocalBase;
117+ next = RWMCSampleCascade (coord, i + 1u, reciprocalBaseI);
118+ }
119+
120+
121+ float reliability = 1.f ;
122+ // sample counting-based reliability estimation
123+ if (params.reciprocalKappa <= 1.f )
124+ {
125+ float localReliability = curr.normalizedCenterLuma;
126+ // reliability in 3x3 pixel block (see robustness)
127+ float globalReliability = curr.normalizedNeighbourhoodAverageLuma;
128+ if (notFirstCascade)
129+ {
130+ localReliability += prevNormalizedCenterLuma;
131+ globalReliability += prevNormalizedNeighbourhoodAverageLuma;
132+ }
133+ if (notLastCascade)
134+ {
135+ localReliability += next.normalizedCenterLuma;
136+ globalReliability += next.normalizedNeighbourhoodAverageLuma;
137+ }
138+ // check if above minimum sampling threshold (avg 9 sample occurences in 3x3 neighbourhood), then use per-pixel reliability (NOTE: tertiary op is in reverse)
139+ reliability = globalReliability < params.reciprocalN ? globalReliability : localReliability;
140+ {
141+ const float accumLuma = calcLuma (accumulation);
142+ if (accumLuma > Emin)
143+ Emin = accumLuma;
144+
145+ const float colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
146+
147+ reliability += colorReliability;
148+ reliability *= params.NOverKappa;
149+ reliability -= params.reciprocalKappa;
150+ reliability = clamp (reliability * 0.5f , 0.f , 1.f );
151+ }
152+ }
153+ accumulation += curr.centerValue * reliability;
154+
155+ prevNormalizedCenterLuma = curr.normalizedCenterLuma;
156+ prevNormalizedNeighbourhoodAverageLuma = curr.normalizedNeighbourhoodAverageLuma;
157+ curr = next;
158+ }
159+
160+ return accumulation;
161+ }
162+
13163int32_t2 getCoordinates ()
14164{
15165 uint32_t width, height;
@@ -19,7 +169,7 @@ int32_t2 getCoordinates()
19169
20170// this function is for testing purpose
21171// simply adds every cascade buffer, output shoud be nearly the same as output of default accumulator (RWMC off)
22- void sumCascade (in const int32_t2 coords)
172+ float32_t3 sumCascade (in const int32_t2 coords)
23173{
24174 float32_t3 accumulation = float32_t3 (0.0f , 0.0f , 0.0f );
25175
@@ -31,9 +181,7 @@ void sumCascade(in const int32_t2 coords)
31181
32182 accumulation /= 32.0f ;
33183
34- float32_t4 output = float32_t4 (accumulation, 1.0f );
35-
36- outImage[coords] = output;
184+ return accumulation;
37185}
38186
39187[numthreads (WorkgroupSize, 1 , 1 )]
@@ -45,9 +193,11 @@ void main(uint32_t3 threadID : SV_DispatchThreadID)
45193 return ;
46194
47195 const int32_t2 coords = getCoordinates ();
48- sumCascade (coords);
196+ //float32_t3 color = sumCascade(coords);
49197
50- // zero out cascade
51- for (int i = 0 ; i < 6 ; ++i)
52- cascade[uint3 (coords.x, coords.y, i)] = float32_t4 (0.0f , 0.0f , 0.0f , 0.0f );
198+ RWMCReweightingParameters reweightingParameters = computeReweightingParameters (pc.cascadeCount, pc.base, pc.sampleCount, pc.minReliableLuma, pc.kappa);
199+ float32_t3 color = RWMCReweight (reweightingParameters, coords);
200+ color /= pc.sampleCount;
201+
202+ outImage[coords] = float32_t4 (color, 1.0f );
53203}
0 commit comments