Skip to content

Commit 389248c

Browse files
committed
Implemented reweighting
1 parent 57a6a0f commit 389248c

File tree

5 files changed

+210
-16
lines changed

5 files changed

+210
-16
lines changed

31_HLSLPathTracer/app_resources/hlsl/pathtracer.hlsl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,19 @@ struct Unidirectional
314314
uint32_t base;
315315
};
316316

317-
// tmp
318-
float calculateLumaRec709(float32_t4 color)
317+
/**
318+
* @brief Resets all buffers in the cascade to 0 at the given pixel coordinates.
319+
*
320+
* This function writes zero values to every buffer in the cascade
321+
* for the specified 2D pixel location.
322+
*
323+
* @param coords Integer 2D coordinates of the pixel to reset.
324+
* @param cascadeSize number of buffers in the cascade to clear.
325+
*/
326+
void resetCascade(NBL_CONST_REF_ARG(int32_t2) coords, uint32_t cascadeSize)
{
    // Clear exactly `cascadeSize` layers rather than a hard-coded count, so the
    // function honours its parameter when the caller changes the cascade depth.
    const float32_t4 zero = float32_t4(0.0f, 0.0f, 0.0f, 0.0f);
    for (uint32_t layer = 0u; layer < cascadeSize; ++layer)
        cascade[uint3(coords.x, coords.y, layer)] = zero;
}
322331

323332
void generateCascade(int32_t2 coords, uint32_t numSamples, uint32_t depth, NBL_CONST_REF_ARG(RWMCCascadeSettings) cascadeSettings, NBL_CONST_REF_ARG(scene_type) scene)
@@ -331,7 +340,6 @@ struct Unidirectional
331340
measure_type accumulation = getSingleSampleMeasure(i, depth, scene);
332341

333342
const float luma = getLuma(accumulation);
334-
//const float luma = calculateLumaRec709(float32_t4(accumulation, 1.0f));
335343

336344
uint32_t lowerCascadeIndex = 0u;
337345
while (!(luma < upperScale) && lowerCascadeIndex < cascadeSettings.size - 2)

31_HLSLPathTracer/app_resources/hlsl/render.comp.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ void main(uint32_t3 threadID : SV_DispatchThreadID)
231231
cascadeSettings.start = 1u;
232232
cascadeSettings.base = 8u;
233233

234+
pathtracer.resetCascade(coords, 6u);
234235
pathtracer.generateCascade(coords, pc.sampleCount, pc.depth, cascadeSettings, scene);
235236
}
236237

31_HLSLPathTracer/app_resources/hlsl/render_common.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ struct SPushConstants
66
float32_t4x4 invMVP;
77
int sampleCount;
88
int depth;
9+
uint32_t rwmcCascadeSize;
10+
uint32_t rwmcCascadeStart;
11+
uint32_t rwmcCascadeBase;
912
};
1013

1114
[[vk::push_constant]] SPushConstants pc;

31_HLSLPathTracer/app_resources/hlsl/reweighting.hlsl

Lines changed: 158 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
2+
#include <nbl/builtin/hlsl/colorspace/encodeCIEXYZ.hlsl>
23

4+
// Push-constant block read by the reweighting shader. Every field is forwarded
// unchanged to computeReweightingParameters() from main().
struct SPushConstants
5+
{
6+
uint32_t cascadeCount; // number of cascade layers; the last processed layer index is cascadeCount - 1
7+
float base; // cascade base; its reciprocal scales the luma of each successive layer
8+
uint32_t sampleCount; // sample count; main() also divides the reweighted colour by it
9+
float minReliableLuma; // starting value of Emin in RWMCReweight()
10+
float kappa; // only its reciprocal is used by the reliability computation
11+
};
12+
13+
[[vk::push_constant]] SPushConstants pc;
314
[[vk::image_format("rgba16f")]] [[vk::binding(0, 0)]] RWTexture2D<float32_t4> outImage;
415
[[vk::image_format("rgba16f")]] [[vk::binding(1, 0)]] RWTexture2DArray<float32_t4> cascade;
516

@@ -10,6 +21,145 @@ NBL_CONSTEXPR uint32_t WorkgroupSize = 512;
1021
NBL_CONSTEXPR uint32_t MAX_DEPTH_LOG2 = 4;
1122
NBL_CONSTEXPR uint32_t MAX_SAMPLES_LOG2 = 10;
1223

24+
// Values precomputed once by computeReweightingParameters() and consumed by RWMCReweight().
struct RWMCReweightingParameters
25+
{
26+
uint32_t lastCascadeIndex; // cascadeCount - 1
27+
float initialEmin; // a minimum image brightness that we always consider reliable
28+
float reciprocalBase; // 1 / base
29+
float reciprocalN; // 1 / sampleCount
30+
float reciprocalKappa; // 1 / kappa
31+
float colorReliabilityFactor; // base + (1 - base) / kappa
32+
float NOverKappa; // sampleCount / kappa
33+
};
34+
35+
// Builds the set of constants RWMCReweight() needs from the raw reweighting settings.
// No validation is done here: `base`, `sampleCount` and `kappa` are all used as divisors.
RWMCReweightingParameters computeReweightingParameters(uint32_t cascadeCount, float base, uint32_t sampleCount, float minReliableLuma, float kappa)
{
    const float sampleCountF = float(sampleCount);
    const float rcpKappa = 1.f / kappa;

    RWMCReweightingParameters params;
    params.lastCascadeIndex = cascadeCount - 1u;
    params.initialEmin = minReliableLuma;
    params.reciprocalBase = 1.f / base;
    params.reciprocalN = 1.f / sampleCountF;
    params.reciprocalKappa = rcpKappa;
    // When exact expected-value estimation is not required (kappa != 1.f) a bit more
    // variance relative to the brightness already accumulated is acceptable: allow up
    // to roughly <cascadeBase> more energy in one sample to lessen bias in some cases.
    params.colorReliabilityFactor = base + (1.f - base) * rcpKappa;
    params.NOverKappa = sampleCountF * rcpKappa;
    return params;
}
51+
52+
// Result of RWMCSampleCascade() for one pixel in one cascade layer.
struct RWMCCascadeSample
53+
{
54+
float32_t3 centerValue; // RGB of the texel at the queried coordinate
55+
float normalizedCenterLuma; // calcLuma(centerValue) * reciprocalBaseI
56+
float normalizedNeighbourhoodAverageLuma; // scaled luma of the 3x3 block around the texel (centre included), divided by 9
57+
};
58+
59+
// Reads the RGB value of the cascade texel at `currentCoord + offset` in layer `cascadeIndex`.
// Coordinates that fall outside the image on any side yield black instead of being loaded.
// TODO: figure out what values should pixels outside have, 0.0f is incorrect
float32_t3 RWMCsampleCascadeTexel(int32_t2 currentCoord, int32_t2 offset, uint32_t cascadeIndex)
{
    const int32_t2 texelCoord = currentCoord + offset;

    // The right and bottom edges must be rejected as well as the left and top ones:
    // an out-of-range Load is not guaranteed to return a defined value on every target.
    uint32_t width, height, layers;
    cascade.GetDimensions(width, height, layers);
    const int32_t2 extent = int32_t2(int32_t(width), int32_t(height));

    if (any(texelCoord < int32_t2(0, 0)) || any(texelCoord >= extent))
        return float32_t3(0.0f, 0.0f, 0.0f);

    const float32_t4 texel = cascade.Load(int32_t3(texelCoord, int32_t(cascadeIndex)));
    return float32_t3(texel.r, texel.g, texel.b);
}
69+
70+
// Scalar brightness of `col`: the dot product of `col` with vector [1] of the
// transposed scRGB -> XYZ matrix.
// NOTE(review): confirm that transposing and then indexing [1] selects the Y
// (luminance) weights for the way `colorspace::scRGBtoXYZ` is laid out.
float32_t calcLuma(in float32_t3 col)
{
    const float32_t3 lumaWeights = hlsl::transpose(colorspace::scRGBtoXYZ)[1];
    return hlsl::dot<float32_t3>(lumaWeights, col);
}
74+
75+
// Samples the 3x3 block of texels centred on `coord` in cascade layer `cascadeIndex`.
// Returns the centre colour, its luma scaled by `reciprocalBaseI`, and the scaled luma
// of the whole block divided by 9.
RWMCCascadeSample RWMCSampleCascade(in int32_t2 coord, in uint cascadeIndex, in float reciprocalBaseI)
{
    // Taps are stored row by row: index = (dy + 1) * 3 + (dx + 1), so index 4 is the centre.
    float32_t3 taps[9];
    for (int32_t dy = -1; dy <= 1; ++dy)
    {
        for (int32_t dx = -1; dx <= 1; ++dx)
            taps[(dy + 1) * 3 + (dx + 1)] = RWMCsampleCascadeTexel(coord, int32_t2(dx, dy), cascadeIndex);
    }

    // Pairwise summation of the eight taps around the centre (numerical robustness).
    const float32_t3 firstFour = (taps[0] + taps[1]) + (taps[2] + taps[3]);
    const float32_t3 lastFour = (taps[5] + taps[6]) + (taps[7] + taps[8]);
    const float32_t3 surroundingSum = firstFour + lastFour;

    const float centreLuma = calcLuma(taps[4]) * reciprocalBaseI;

    RWMCCascadeSample result;
    result.centerValue = taps[4];
    result.normalizedCenterLuma = centreLuma;
    result.normalizedNeighbourhoodAverageLuma = (calcLuma(surroundingSum) * reciprocalBaseI + centreLuma) / 9.f;
    return result;
}
98+
99+
// Combines cascade layers 0..params.lastCascadeIndex at pixel `coord` into one colour.
// Each layer's centre texel is added to the result scaled by a reliability factor that
// is either 1 or a value clamped to [0, 1]. The returned sum is not divided by the
// sample count here.
float32_t3 RWMCReweight(in RWMCReweightingParameters params, in int32_t2 coord)
100+
{
101+
float reciprocalBaseI = 1.f;
102+
RWMCCascadeSample curr = RWMCSampleCascade(coord, 0u, reciprocalBaseI);
103+
104+
float32_t3 accumulation = float32_t3(0.0f, 0.0f, 0.0f);
105+
float Emin = params.initialEmin;
106+
107+
// deliberately left unassigned: only read when notFirstCascade, i.e. after the
// previous iteration stored them at the bottom of the loop
float prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
108+
for (uint i = 0u; i <= params.lastCascadeIndex; i++)
109+
{
110+
const bool notFirstCascade = i != 0u;
111+
const bool notLastCascade = i != params.lastCascadeIndex;
112+
113+
// `next` is only assigned when a following layer exists; on the last iteration it
// stays unassigned and is copied into `curr` below, but the loop ends before it is read
RWMCCascadeSample next;
114+
if (notLastCascade)
115+
{
116+
// reciprocalBaseI becomes reciprocalBase^(i+1); it scales layer i+1 and is also
// the value used for colorReliability further down in this iteration
reciprocalBaseI *= params.reciprocalBase;
117+
next = RWMCSampleCascade(coord, i + 1u, reciprocalBaseI);
118+
}
119+
120+
121+
// stays 1 (layer added unweighted) when reciprocalKappa > 1
float reliability = 1.f;
122+
// sample counting-based reliability estimation
123+
if (params.reciprocalKappa <= 1.f)
124+
{
125+
float localReliability = curr.normalizedCenterLuma;
126+
// reliability in 3x3 pixel block (see robustness)
127+
float globalReliability = curr.normalizedNeighbourhoodAverageLuma;
128+
// add the contribution of the adjacent lower layer, if any
if (notFirstCascade)
129+
{
130+
localReliability += prevNormalizedCenterLuma;
131+
globalReliability += prevNormalizedNeighbourhoodAverageLuma;
132+
}
133+
// add the contribution of the adjacent higher layer, if any
if (notLastCascade)
134+
{
135+
localReliability += next.normalizedCenterLuma;
136+
globalReliability += next.normalizedNeighbourhoodAverageLuma;
137+
}
138+
// check if above minimum sampling threshold (avg 9 sample occurrences in 3x3 neighbourhood), then use per-pixel reliability (NOTE: ternary op is in reverse)
139+
reliability = globalReliability < params.reciprocalN ? globalReliability : localReliability;
140+
{
141+
// Emin only ever grows: it tracks the brightest luma accumulated so far
const float accumLuma = calcLuma(accumulation);
142+
if (accumLuma > Emin)
143+
Emin = accumLuma;
144+
145+
const float colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
146+
147+
reliability += colorReliability;
148+
reliability *= params.NOverKappa;
149+
reliability -= params.reciprocalKappa;
150+
reliability = clamp(reliability * 0.5f, 0.f, 1.f);
151+
}
152+
}
153+
accumulation += curr.centerValue * reliability;
154+
155+
// shift the window of layers: current becomes previous, next becomes current
prevNormalizedCenterLuma = curr.normalizedCenterLuma;
156+
prevNormalizedNeighbourhoodAverageLuma = curr.normalizedNeighbourhoodAverageLuma;
157+
curr = next;
158+
}
159+
160+
return accumulation;
161+
}
162+
13163
int32_t2 getCoordinates()
14164
{
15165
uint32_t width, height;
@@ -19,7 +169,7 @@ int32_t2 getCoordinates()
19169

20170
// this function is for testing purpose
21171
// simply adds every cascade buffer; the output should be nearly the same as the output of the default accumulator (RWMC off)
22-
void sumCascade(in const int32_t2 coords)
172+
float32_t3 sumCascade(in const int32_t2 coords)
23173
{
24174
float32_t3 accumulation = float32_t3(0.0f, 0.0f, 0.0f);
25175

@@ -31,9 +181,7 @@ void sumCascade(in const int32_t2 coords)
31181

32182
accumulation /= 32.0f;
33183

34-
float32_t4 output = float32_t4(accumulation, 1.0f);
35-
36-
outImage[coords] = output;
184+
return accumulation;
37185
}
38186

39187
[numthreads(WorkgroupSize, 1, 1)]
@@ -45,9 +193,11 @@ void main(uint32_t3 threadID : SV_DispatchThreadID)
45193
return;
46194

47195
const int32_t2 coords = getCoordinates();
48-
sumCascade(coords);
196+
//float32_t3 color = sumCascade(coords);
49197

50-
// zero out cascade
51-
for (int i = 0; i < 6; ++i)
52-
cascade[uint3(coords.x, coords.y, i)] = float32_t4(0.0f, 0.0f, 0.0f, 0.0f);
198+
RWMCReweightingParameters reweightingParameters = computeReweightingParameters(pc.cascadeCount, pc.base, pc.sampleCount, pc.minReliableLuma, pc.kappa);
199+
float32_t3 color = RWMCReweight(reweightingParameters, coords);
200+
color /= pc.sampleCount;
201+
202+
outImage[coords] = float32_t4(color, 1.0f);
53203
}

31_HLSLPathTracer/main.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,24 @@ using namespace asset;
1515
using namespace ui;
1616
using namespace video;
1717

18-
struct PTPushConstant {
18+
static constexpr uint32_t CascadeSize = 6u;
19+
// Host-side push-constant data for the path tracer. The field order matches the
// shader-side SPushConstants in render_common.hlsl.
struct PTPushConstant
20+
{
1921
matrix4SIMD invMVP; // filled each frame with the inverse of the camera's concatenated matrix
2022
int sampleCount; // set from `spp` each frame
2123
int depth; // set from `depth` each frame
24+
// fixed to CascadeSize; being a const member, it makes this struct non-copy-assignable
const uint32_t rwmcCascadeSize = CascadeSize;
25+
uint32_t rwmcCascadeStart;
26+
uint32_t rwmcCascadeBase;
27+
};
28+
29+
// Host-side push-constant data for the reweighting pipeline. It is pushed with
// sizeof(RWMCPushConstants) at offset 0 for the compute stage. The field order and
// types match the SPushConstants struct in reweighting.hlsl, where the first field
// is named cascadeCount.
struct RWMCPushConstants
30+
{
31+
// fixed to CascadeSize; being a const member, it makes this struct non-copy-assignable
const uint32_t cascadeSize = CascadeSize;
32+
float base;
33+
uint32_t sampleCount; // refreshed from `spp` every frame
34+
float minReliableLuma;
35+
float kappa;
2236
};
2337

2438
// TODO: Add a QueryPool for timestamping once its ready
@@ -509,8 +523,14 @@ class HLSLComputePathtracer final : public examples::SimpleWindowedApplication,
509523

510524
// Create reweighting pipeline
511525
{
526+
const nbl::asset::SPushConstantRange pcRange = {
527+
.stageFlags = IShader::E_SHADER_STAGE::ESS_COMPUTE,
528+
.offset = 0,
529+
.size = sizeof(RWMCPushConstants)
530+
};
531+
512532
auto pipelineLayout = m_device->createPipelineLayout(
513-
{},
533+
{ &pcRange, 1 },
514534
core::smart_refctd_ptr(gpuDescriptorSetLayout0)
515535
);
516536

@@ -1098,6 +1118,15 @@ class HLSLComputePathtracer final : public examples::SimpleWindowedApplication,
10981118
m_oracle.reportBeginFrameRecord();
10991119
m_camera.mapKeysToWASD();
11001120

1121+
// set initial push constants contents
1122+
rwmcPushConstants.base = 8.0f;
1123+
rwmcPushConstants.sampleCount = spp;
1124+
rwmcPushConstants.minReliableLuma = 1.0f;
1125+
rwmcPushConstants.kappa = 5.0f;
1126+
1127+
pc.rwmcCascadeStart = 1.0;
1128+
pc.rwmcCascadeBase = 8.0f;
1129+
11011130
return true;
11021131
}
11031132

@@ -1162,11 +1191,12 @@ class HLSLComputePathtracer final : public examples::SimpleWindowedApplication,
11621191
cmdbuf->reset(IGPUCommandBuffer::RESET_FLAGS::NONE);
11631192
// disregard surface/swapchain transformation for now
11641193
const auto viewProjectionMatrix = m_camera.getConcatenatedMatrix();
1165-
PTPushConstant pc;
11661194
viewProjectionMatrix.getInverseTransform(pc.invMVP);
11671195
pc.sampleCount = spp;
11681196
pc.depth = depth;
11691197

1198+
rwmcPushConstants.sampleCount = spp;
1199+
11701200
// safe to proceed
11711201
// upload buffer data
11721202
cmdbuf->beginDebugMarker("ComputeShaderPathtracer IMGUI Frame");
@@ -1293,6 +1323,7 @@ class HLSLComputePathtracer final : public examples::SimpleWindowedApplication,
12931323

12941324
cmdbuf->bindComputePipeline(pipeline);
12951325
cmdbuf->bindDescriptorSets(EPBP_COMPUTE, pipeline->getLayout(), 0u, 1u, &m_descriptorSet0.get());
1326+
cmdbuf->pushConstants(pipeline->getLayout(), IShader::E_SHADER_STAGE::ESS_COMPUTE, 0, sizeof(RWMCPushConstants), &rwmcPushConstants);
12961327
cmdbuf->dispatch(1 + (WindowDimensions.x * WindowDimensions.y - 1) / DefaultWorkGroupSize, 1u, 1u);
12971328
}
12981329

@@ -1573,7 +1604,6 @@ class HLSLComputePathtracer final : public examples::SimpleWindowedApplication,
15731604
smart_refctd_ptr<IGPUImageView> m_envMapView, m_scrambleView;
15741605
smart_refctd_ptr<IGPUBufferView> m_sequenceBufferView;
15751606
smart_refctd_ptr<IGPUImageView> m_outImgView;
1576-
static constexpr uint32_t CascadeSize = 6u;
15771607
smart_refctd_ptr<IGPUImageView> m_cascadeView;
15781608

15791609
// sync
@@ -1610,6 +1640,8 @@ class HLSLComputePathtracer final : public examples::SimpleWindowedApplication,
16101640
int spp = 32;
16111641
int depth = 3;
16121642
bool usePersistentWorkGroups = false;
1643+
RWMCPushConstants rwmcPushConstants;
1644+
PTPushConstant pc;
16131645

16141646
bool m_firstFrame = true;
16151647
IGPUCommandBuffer::SClearColorValue clearColor = { .float32 = {0.f,0.f,0.f,1.f} };

0 commit comments

Comments
 (0)