-
Notifications
You must be signed in to change notification settings - Fork 14
bitonic sort sample #209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
bitonic sort sample #209
Conversation
Signed-off-by: CrabeExtra <abbasgaroosi7@gmail.com>
|
@Fletterio wanna try your hand at a review? |
13_BitonicSort/main.cpp
Outdated
| auto limits = m_physicalDevice->getLimits(); | ||
| const uint32_t max_shared_memory_size = limits.maxComputeSharedMemorySize; | ||
| const uint32_t max_workgroup_size = limits.maxComputeWorkGroupInvocations; // Get actual GPU limit | ||
| const uint32_t bytes_per_elements = sizeof(uint32_t) * 2; // 2 uint32_t per element (key and value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should get this with sizeof() of the structs in app_resources/common.hlsl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I said this 2 months ago, and still not done
Fletterio
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just realized these were all pending, most are outdated see if any still hold up
| } | ||
| } | ||
|
|
||
| GroupMemoryBarrierWithGroupSync(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If compareDistance < waveSize, these barriers serve no purpose, you are overbarriering. In fact writing to shared memory at the end of every such iteration is also pointless.
The proper way to avoid this overbarriering is to branch behaviour based on whether compareDistance < waveSize or not. All steps with compareDistance < waveSize can be done in one go. Threads shuffle their elements around using subgroup intrinsics (shuffleXor, namely), once per every compareDistance value less than the starting one, and then write back to shared memory only once. This is what we do with the FFT, although I don't expect you to infer that from the code since it can be a bit obscure. @ me on discord if you want to figure out the way we handle this with the FFT, I can explain better there since I need to draw diagrams and write a bunch more
| struct Accessor | ||
| { | ||
| static Accessor create(const uint64_t address) | ||
| { | ||
| Accessor accessor; | ||
| accessor.address = address; | ||
| return accessor; | ||
| } | ||
|
|
||
| template <typename AccessType, typename IndexType> | ||
| void get(const IndexType index, NBL_REF_ARG(AccessType) value) | ||
| { | ||
| value = vk::RawBufferLoad<AccessType>(address + index * sizeof(AccessType)); | ||
| } | ||
|
|
||
| template <typename AccessType, typename IndexType> | ||
| void set(const IndexType index, const AccessType value) | ||
| { | ||
| vk::RawBufferStore<AccessType>(address + index * sizeof(AccessType), value); | ||
| } | ||
|
|
||
| uint64_t address; | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's readymade BDA accessors you can use AFAIK
| add_subdirectory(12_MeshLoaders) | ||
| # | ||
| #add_subdirectory(13_MaterialCompiler EXCLUDE_FROM_ALL) | ||
| add_subdirectory(12_MeshLoaders EXCLUDE_FROM_ALL) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you erroneously added EXCLUDE_FROM_ALL to example 12 and now its omitted from CI
| uint64_t deviceBufferAddress; | ||
| }; | ||
|
|
||
| NBL_CONSTEXPR uint32_t WorkgroupSizeLog2 = 10; // 1024 threads (2^10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
512 is optimal residency on all GPUs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so log2 of 9
| std::cout << "(" << key << "," << value << "), "; | ||
| if ((i + 1) % 20 == 0) { | ||
| std::cout << "\n"; | ||
| } | ||
| } | ||
| std::cout << "\nElement count: " << elementCount << "\n"; | ||
|
|
||
| bool is_sorted = true; | ||
| int32_t error_index = -1; | ||
| for (uint32_t i = 1; i < elementCount; i++) { | ||
| uint32_t prevKey = data[(i - 1) * 2]; | ||
| uint32_t currKey = data[i * 2]; | ||
| if (currKey < prevKey) { | ||
| is_sorted = false; | ||
| error_index = i; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (is_sorted) { | ||
| std::cout << "Array is correctly sorted!\n"; | ||
| } | ||
| else { | ||
| std::cout << "Array is NOT sorted correctly!\n"; | ||
| std::cout << "Error at index " << error_index << ":\n"; | ||
| std::cout << " Previous key [" << (error_index - 1) << "] = " << data[(error_index - 1) * 2] << "\n"; | ||
| std::cout << " Current key [" << error_index << "] = " << data[error_index * 2] << "\n"; | ||
| std::cout << " (" << data[error_index * 2] << " < " << data[(error_index - 1) * 2] << " is WRONG!)\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please use m_logger instead of std::cout
|
|
||
| deviceLocalBufferParams.queueFamilyIndexCount = 1; | ||
| deviceLocalBufferParams.queueFamilyIndices = &queueFamilyIndex; | ||
| deviceLocalBufferParams.size = sizeof(uint32_t) * elementCount * 2; // *2 because we store (key, value) pairs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have a struct like KeyValue so you can take a sizeof of it then
| }; | ||
|
|
||
| NBL_CONSTEXPR uint32_t WorkgroupSizeLog2 = 10; // 1024 threads (2^10) | ||
| NBL_CONSTEXPR uint32_t ElementsPerThreadLog2 = 2; // 4 elements per thread (2^2) - VIRTUAL THREADING! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't really call that virtual threads, virtual threads is if you are able to make a workgroup of size lets say 512 behave as if its 4096
processing multiple elements per invocation is an orthogonal extra to that and it helps with subgroupShuffle utilization and loading from global memory
|
|
||
| IQueue* const queue = getComputeQueue(); | ||
|
|
||
| const uint32_t inputSize = sizeof(uint32_t) * elementCount * 2; // *2 because we store (key, value) pairs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
compute this before making your m_deviceLocalBuffer and use it throughout
| inputPtr[i * 2] = key; | ||
| inputPtr[i * 2 + 1] = value; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again another reason to have a named struct so you're not doing this unreadable type punning
| std::cout << "(" << key << "," << value << "), "; | ||
| if ((i + 1) % 20 == 0) { | ||
| std::cout << "\n"; | ||
| } | ||
| } | ||
| std::cout << "\nElement count: " << elementCount << "\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use the m_logger instead of std::cout
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or an std::ostringstream and then m_logger to print the contents
| { | ||
| assert(dstOffset == 0 && size == outputSize); | ||
|
|
||
| std::cout << "Sorted array: "; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name it differently, Output ARray
"Sorted array" is confusing cause it makes me think you're outputting a reference sorted array on the CPU
|
|
||
| cmdbuf->pushConstants(m_pipeline->getLayout(), IShader::E_SHADER_STAGE::ESS_COMPUTE, 0u, sizeof(pc), &pc); | ||
|
|
||
| cmdbuf->dispatch(1, 1, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd consider preparing multiple test data sets, right now you only have one random one, and having 1 workgroup per test (so still one dispatch).
I'd add:
- all keys equal
- keys already sorted
- keys in exact reverse
as manual test cases.
Furthermore this should be a stable sort (unlike the a counting sort), its important to check things like stability so making some test with some neighbouring equal keys in the input array is important to check they have not changed places.
| if (currKey < prevKey) { | ||
| is_sorted = false; | ||
| error_index = i; | ||
| break; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because Bitonic should be stable, you should also check
else if (currKey==prevKey && currValue>prevValue) // check stability, this is why we've initialized the values in such a particular way
{
// fail, not to bad sorting but instability
}and also assert(curValue!=prevValue); because of how we initialized the values
No description provided.