1- #include " nbl/video/CVulkanRayTracingPipeline .h"
1+ #include " nbl/asset/IRayTracingPipeline .h"
22
3+ #include " nbl/video/CVulkanRayTracingPipeline.h"
34#include " nbl/video/CVulkanLogicalDevice.h"
5+ #include " nbl/video/IGPURayTracingPipeline.h"
6+
7+ #include < algorithm>
48
59namespace nbl ::video
610{
@@ -12,11 +16,60 @@ namespace nbl::video
1216 IGPURayTracingPipeline (params),
1317 m_vkPipeline (vk_pipeline),
1418 m_shaders (core::make_refctd_dynamic_array<ShaderContainer>(params.shaders.size())),
19+ m_missStackSizes (core::make_refctd_dynamic_array<GeneralGroupStackSizeContainer>(params.shaderGroups.misses.size())),
20+ m_hitGroupStackSizes (core::make_refctd_dynamic_array<HitGroupStackSizeContainer>(params.shaderGroups.hits.size())),
21+ m_callableStackSizes (core::make_refctd_dynamic_array<GeneralGroupStackSizeContainer>(params.shaderGroups.hits.size())),
1522 m_shaderGroupHandles (std::move(shaderGroupHandles))
1623 {
1724 for (size_t shaderIx = 0 ; shaderIx < params.shaders .size (); shaderIx++)
1825 m_shaders->operator [](shaderIx) = ShaderRef (static_cast <const CVulkanShader*>(params.shaders [shaderIx].shader ));
1926
27+ const auto * vulkanDevice = static_cast <const CVulkanLogicalDevice*>(getOriginDevice ());
28+ auto * vk = vulkanDevice->getFunctionTable ();
29+
30+ auto getVkShaderGroupStackSize = [&](uint32_t baseGroupIx, uint32_t shaderGroupIx, uint32_t shaderIx, VkShaderGroupShaderKHR shaderType) -> uint16_t
31+ {
32+ if (shaderIx == SShaderGroupsParams::SIndex::Unused)
33+ return 0 ;
34+
35+ return vk->vk .vkGetRayTracingShaderGroupStackSizeKHR (
36+ vulkanDevice->getInternalObject (),
37+ m_vkPipeline,
38+ baseGroupIx + shaderGroupIx,
39+ shaderType
40+ );
41+ };
42+
43+ m_raygenStackSize = getVkShaderGroupStackSize (getRaygenIndex (), 0 , params.shaderGroups .raygen .index , VK_SHADER_GROUP_SHADER_GENERAL_KHR);
44+
45+ for (size_t shaderGroupIx = 0 ; shaderGroupIx < params.shaderGroups .misses .size (); shaderGroupIx++)
46+ {
47+ m_missStackSizes->operator [](shaderGroupIx) = getVkShaderGroupStackSize (
48+ getMissBaseIndex (),
49+ shaderGroupIx,
50+ params.shaderGroups .misses [shaderGroupIx].index ,
51+ VK_SHADER_GROUP_SHADER_GENERAL_KHR);
52+ }
53+
54+ for (size_t shaderGroupIx = 0 ; shaderGroupIx < params.shaderGroups .hits .size (); shaderGroupIx++)
55+ {
56+ const auto & hitGroup = params.shaderGroups .hits [shaderGroupIx];
57+ const auto baseIndex = getHitBaseIndex ();
58+ m_hitGroupStackSizes->operator [](shaderGroupIx) = SHitGroupStackSize{
59+ .closestHit = getVkShaderGroupStackSize (baseIndex,shaderGroupIx, hitGroup.closestHit , VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR),
60+ .anyHit = getVkShaderGroupStackSize (baseIndex, shaderGroupIx, hitGroup.anyHit ,VK_SHADER_GROUP_SHADER_ANY_HIT_KHR),
61+ .intersection = getVkShaderGroupStackSize (baseIndex, shaderGroupIx, hitGroup.intersection , VK_SHADER_GROUP_SHADER_INTERSECTION_KHR),
62+ };
63+ }
64+
65+ for (size_t shaderGroupIx = 0 ; shaderGroupIx < params.shaderGroups .callables .size (); shaderGroupIx++)
66+ {
67+ m_callableStackSizes->operator [](shaderGroupIx) = getVkShaderGroupStackSize (
68+ getCallableBaseIndex (),
69+ shaderGroupIx,
70+ params.shaderGroups .callables [shaderGroupIx].index ,
71+ VK_SHADER_GROUP_SHADER_GENERAL_KHR);
72+ }
2073 }
2174
2275 CVulkanRayTracingPipeline::~CVulkanRayTracingPipeline ()
@@ -26,27 +79,86 @@ namespace nbl::video
2679 vk->vk .vkDestroyPipeline (vulkanDevice->getInternalObject (), m_vkPipeline, nullptr );
2780 }
2881
29-
3082 const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getRaygen () const
3183 {
32- return m_shaderGroupHandles->operator [](0 );
84+ return m_shaderGroupHandles->operator [](getRaygenIndex () );
3385 }
3486
3587 const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getMiss (uint32_t index) const
3688 {
37- const auto baseIndex = 1 ; // one raygen group before this groups
89+ const auto baseIndex = getMissBaseIndex ();
3890 return m_shaderGroupHandles->operator [](baseIndex + index);
3991 }
4092
4193 const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getHit (uint32_t index) const
4294 {
43- const auto baseIndex = 1 + getMissGroupCount (); // one raygen group + miss gropus before this groups
95+ const auto baseIndex = getHitBaseIndex ();
4496 return m_shaderGroupHandles->operator [](baseIndex + index);
4597 }
4698
4799 const IGPURayTracingPipeline::SShaderGroupHandle& CVulkanRayTracingPipeline::getCallable (uint32_t index) const
48100 {
49- const auto baseIndex = 1 + getMissGroupCount () + getHitGroupCount (); // one raygen group + miss groups + hit gropus before this groups
101+ const auto baseIndex = getCallableBaseIndex ();
50102 return m_shaderGroupHandles->operator [](baseIndex + index);
51103 }
104+
105+ uint16_t CVulkanRayTracingPipeline::getRaygenStackSize () const
106+ {
107+ return m_raygenStackSize;
108+ }
109+
110+ std::span<const uint16_t > CVulkanRayTracingPipeline::getMissStackSizes () const
111+ {
112+ return std::span (m_missStackSizes->begin (), m_missStackSizes->end ());
113+ }
114+
115+ std::span<const IGPURayTracingPipeline::SHitGroupStackSize> CVulkanRayTracingPipeline::getHitStackSizes () const
116+ {
117+ return std::span (m_hitGroupStackSizes->begin (), m_hitGroupStackSizes->end ());
118+ }
119+
120+ std::span<const uint16_t > CVulkanRayTracingPipeline::getCallableStackSizes () const
121+ {
122+ return std::span (m_callableStackSizes->begin (), m_callableStackSizes->end ());
123+ }
124+
125+ uint16_t CVulkanRayTracingPipeline::getDefaultStackSize () const
126+ {
127+ // calculation follow the formula from
128+ // https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#ray-tracing-pipeline-stack
129+ const auto raygenStackMax = m_raygenStackSize;
130+ const auto closestHitStackMax = std::ranges::max_element (getHitStackSizes (), std::ranges::less{}, &SHitGroupStackSize::closestHit)->closestHit ;
131+ const auto anyHitStackMax = std::ranges::max_element (getHitStackSizes (), std::ranges::less{}, &SHitGroupStackSize::anyHit)->anyHit ;
132+ const auto intersectionStackMax = std::ranges::max_element (getHitStackSizes (), std::ranges::less{}, &SHitGroupStackSize::intersection)->intersection ;
133+ const auto missStackMax = *std::ranges::max_element (getMissStackSizes ());
134+ const auto callableStackMax = *std::ranges::max_element (getCallableStackSizes ());
135+ return raygenStackMax + std::min<uint16_t >(1 , m_params.maxRecursionDepth ) *
136+ std::max (closestHitStackMax, std::max<uint16_t >(missStackMax, intersectionStackMax + anyHitStackMax)) +
137+ std::max<uint16_t >(0 , m_params.maxRecursionDepth - 1 ) * std::max (closestHitStackMax, missStackMax) + 2 *
138+ callableStackMax;
139+ }
140+
141+ uint32_t CVulkanRayTracingPipeline::getRaygenIndex () const
142+ {
143+ return 0 ;
144+ }
145+
146+ uint32_t CVulkanRayTracingPipeline::getMissBaseIndex () const
147+ {
148+ // one raygen group before this groups
149+ return 1 ;
150+ }
151+
152+ uint32_t CVulkanRayTracingPipeline::getHitBaseIndex () const
153+ {
154+ // one raygen group + miss groups before this groups
155+ return 1 + getMissGroupCount ();
156+ }
157+
158+ uint32_t CVulkanRayTracingPipeline::getCallableBaseIndex () const
159+ {
160+ // one raygen group + miss groups + hit groups before this groups
161+ return 1 + getMissGroupCount () + getHitGroupCount ();
162+ }
163+
52164}
0 commit comments