44
55#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
66#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
7+ #include "nbl/builtin/hlsl/bda/__ptr.hlsl"
78
89using namespace nbl::hlsl;
910
@@ -24,36 +25,26 @@ float3 unpackNormals3x10(uint32_t v)
2425
2526float3 calculateSmoothNormals (int instID, int primID, SGeomInfo geom, float2 bary)
2627{
27- uint idxOffset = primID * 3 ;
28-
2928 const uint indexType = geom.indexType;
3029 const uint vertexStride = geom.vertexStride;
3130
3231 const uint64_t vertexBufferAddress = geom.vertexBufferAddress;
3332 const uint64_t indexBufferAddress = geom.indexBufferAddress;
3433
35- uint i0, i1, i2 ;
34+ uint32_t3 indices ;
3635 switch (indexType)
3736 {
3837 case 0 : // EIT_16BIT
39- {
40- i0 = uint32_t (vk::RawBufferLoad<uint16_t>(indexBufferAddress + (idxOffset + 0 ) * sizeof (uint16_t), 2u));
41- i1 = uint32_t (vk::RawBufferLoad<uint16_t>(indexBufferAddress + (idxOffset + 1 ) * sizeof (uint16_t), 2u));
42- i2 = uint32_t (vk::RawBufferLoad<uint16_t>(indexBufferAddress + (idxOffset + 2 ) * sizeof (uint16_t), 2u));
43- }
44- break ;
38+ indices = uint32_t3 ((nbl::hlsl::bda::__ptr<uint16_t3>::create (indexBufferAddress)+primID).deref ().load ());
39+ break ;
4540 case 1 : // EIT_32BIT
46- {
47- i0 = vk::RawBufferLoad<uint32_t>(indexBufferAddress + (idxOffset + 0 ) * sizeof (uint32_t));
48- i1 = vk::RawBufferLoad<uint32_t>(indexBufferAddress + (idxOffset + 1 ) * sizeof (uint32_t));
49- i2 = vk::RawBufferLoad<uint32_t>(indexBufferAddress + (idxOffset + 2 ) * sizeof (uint32_t));
50- }
51- break ;
41+ indices = uint32_t3 ((nbl::hlsl::bda::__ptr<uint32_t3>::create (indexBufferAddress)+primID).deref ().load ());
42+ break ;
5243 default : // EIT_NONE
5344 {
54- i0 = idxOffset ;
55- i1 = idxOffset + 1 ;
56- i2 = idxOffset + 2 ;
45+ indices[ 0 ] = primID * 3 ;
46+ indices[ 1 ] = indices[ 0 ] + 1 ;
47+ indices[ 2 ] = indices[ 0 ] + 2 ;
5748 }
5849 }
5950
@@ -62,9 +53,10 @@ float3 calculateSmoothNormals(int instID, int primID, SGeomInfo geom, float2 bar
6253 {
6354 case OT_CUBE:
6455 {
65- uint32_t v0 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + i0 * vertexStride, 2u);
66- uint32_t v1 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + i1 * vertexStride, 2u);
67- uint32_t v2 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + i2 * vertexStride, 2u);
56+ // TODO: document why the alignment is 2 here and nowhere else? isnt the `vertexStride` aligned to more than 2 anyway?
57+ uint32_t v0 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + indices[0 ] * vertexStride, 2u);
58+ uint32_t v1 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + indices[1 ] * vertexStride, 2u);
59+ uint32_t v2 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + indices[2 ] * vertexStride, 2u);
6860
6961 n0 = normalize (nbl::hlsl::spirv::unpackSnorm4x8 (v0).xyz);
7062 n1 = normalize (nbl::hlsl::spirv::unpackSnorm4x8 (v1).xyz);
@@ -76,9 +68,9 @@ float3 calculateSmoothNormals(int instID, int primID, SGeomInfo geom, float2 bar
7668 case OT_ARROW:
7769 case OT_CONE:
7870 {
79- uint32_t v0 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + i0 * vertexStride);
80- uint32_t v1 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + i1 * vertexStride);
81- uint32_t v2 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + i2 * vertexStride);
71+ uint32_t v0 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + indices[ 0 ] * vertexStride);
72+ uint32_t v1 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + indices[ 1 ] * vertexStride);
73+ uint32_t v2 = vk::RawBufferLoad<uint32_t>(vertexBufferAddress + indices[ 2 ] * vertexStride);
8274
8375 n0 = normalize (unpackNormals3x10 (v0));
8476 n1 = normalize (unpackNormals3x10 (v1));
@@ -90,9 +82,9 @@ float3 calculateSmoothNormals(int instID, int primID, SGeomInfo geom, float2 bar
9082 case OT_ICOSPHERE:
9183 default :
9284 {
93- n0 = normalize (vk::RawBufferLoad<float3 >(vertexBufferAddress + i0 * vertexStride));
94- n1 = normalize (vk::RawBufferLoad<float3 >(vertexBufferAddress + i1 * vertexStride));
95- n2 = normalize (vk::RawBufferLoad<float3 >(vertexBufferAddress + i2 * vertexStride));
85+ n0 = normalize (vk::RawBufferLoad<float3 >(vertexBufferAddress + indices[ 0 ] * vertexStride));
86+ n1 = normalize (vk::RawBufferLoad<float3 >(vertexBufferAddress + indices[ 1 ] * vertexStride));
87+ n2 = normalize (vk::RawBufferLoad<float3 >(vertexBufferAddress + indices[ 2 ] * vertexStride));
9688 }
9789 }
9890
@@ -132,6 +124,7 @@ void main(uint32_t3 threadID : SV_DispatchThreadID)
132124 const int instID = spirv::rayQueryGetIntersectionInstanceIdKHR (query, true );
133125 const int primID = spirv::rayQueryGetIntersectionPrimitiveIndexKHR (query, true );
134126
127+ // TODO: candidate for `bda::__ptr<SGeomInfo>`
135128 const SGeomInfo geom = vk::RawBufferLoad<SGeomInfo>(pc.geometryInfoBuffer + instID * sizeof (SGeomInfo));
136129
137130 float3 normals;
0 commit comments