Conversation

@CrabExtra

No description provided.

Signed-off-by: CrabeExtra <abbasgaroosi7@gmail.com>
@devshgraphicsprogramming
Member

@Fletterio wanna try your hand at a review?

auto limits = m_physicalDevice->getLimits();
const uint32_t max_shared_memory_size = limits.maxComputeSharedMemorySize;
const uint32_t max_workgroup_size = limits.maxComputeWorkGroupInvocations; // Get actual GPU limit
const uint32_t bytes_per_elements = sizeof(uint32_t) * 2; // 2 uint32_t per element (key and value)

you should get this with sizeof() of the structs in app_resources/common.hlsl
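
For instance (a sketch; KeyValue is a hypothetical name for whatever struct common.hlsl ends up declaring):

// In app_resources/common.hlsl, shared between the shader and the host (name is an assumption):
struct KeyValue
{
    uint32_t key;
    uint32_t value;
};

// Host side:
const uint32_t bytes_per_element = sizeof(KeyValue);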

I said this 2 months ago, and it's still not done

Contributor

@Fletterio left a comment

Just realized these were all pending; most are outdated, see if any still hold up

}
}

GroupMemoryBarrierWithGroupSync();
Contributor

If compareDistance < waveSize, these barriers serve no purpose; you are overbarriering. In fact, writing to shared memory at the end of every such iteration is also pointless.

The proper way to avoid this overbarriering is to branch behaviour based on whether compareDistance < waveSize or not. All steps with compareDistance < waveSize can be done in one go: threads shuffle their elements around using subgroup intrinsics (shuffleXor, namely), once per every compareDistance value less than the starting one, and then write back to shared memory only once. This is what we do with the FFT, although I don't expect you to infer that from the code since it can be a bit obscure. @ me on discord if you want to figure out the way we handle this with the FFT; I can explain better there since I need to draw diagrams and write a bunch more
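
A sketch of that subgroup path, assuming the usual contiguous invocation-to-lane mapping and that every dist stays below WaveGetLaneCount(); every symbol here (sharedKeys, sharedVals, tid, sequenceLength, firstDistance, WorkgroupSize) is a placeholder, not the example's actual code:

groupshared uint32_t sharedKeys[WorkgroupSize]; // placeholder declarations for the sketch
groupshared uint32_t sharedVals[WorkgroupSize];

// Stand-in for a subgroup shuffle-XOR; when compiling to SPIR-V, DXC lowers
// WaveReadLaneAt with a non-uniform lane index to a subgroup shuffle.
uint32_t shuffleXor(const uint32_t v, const uint32_t mask)
{
    return WaveReadLaneAt(v, WaveGetLaneIndex() ^ mask);
}

// Once compareDistance drops below the wave size, keep the element in registers and
// exchange it with the XOR-partner lane, writing shared memory and barriering only
// once at the end instead of once per compareDistance.
void subgroupBitonicSteps(const uint32_t tid, const uint32_t sequenceLength, const uint32_t firstDistance)
{
    uint32_t key = sharedKeys[tid];
    uint32_t val = sharedVals[tid];

    const bool ascending = (tid & sequenceLength) == 0;

    for (uint32_t dist = firstDistance; dist > 0; dist >>= 1)
    {
        const uint32_t partnerKey = shuffleXor(key, dist);
        const uint32_t partnerVal = shuffleXor(val, dist);

        // The lower lane of each pair keeps the smaller key in an ascending run (and vice versa);
        // equal keys never swap.
        const bool keepSmaller = ascending == ((tid & dist) == 0);
        const bool takePartner = keepSmaller ? (partnerKey < key) : (partnerKey > key);
        if (takePartner)
        {
            key = partnerKey;
            val = partnerVal;
        }
    }

    // Single write-back and single barrier for the whole run of subgroup-sized steps.
    sharedKeys[tid] = key;
    sharedVals[tid] = val;
    GroupMemoryBarrierWithGroupSync();
}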

Comment on lines +37 to +59
struct Accessor
{
    static Accessor create(const uint64_t address)
    {
        Accessor accessor;
        accessor.address = address;
        return accessor;
    }

    template <typename AccessType, typename IndexType>
    void get(const IndexType index, NBL_REF_ARG(AccessType) value)
    {
        value = vk::RawBufferLoad<AccessType>(address + index * sizeof(AccessType));
    }

    template <typename AccessType, typename IndexType>
    void set(const IndexType index, const AccessType value)
    {
        vk::RawBufferStore<AccessType>(address + index * sizeof(AccessType), value);
    }

    uint64_t address;
};

there are readymade BDA accessors you can use AFAIK

add_subdirectory(12_MeshLoaders)
#
#add_subdirectory(13_MaterialCompiler EXCLUDE_FROM_ALL)
add_subdirectory(12_MeshLoaders EXCLUDE_FROM_ALL)

you erroneously added EXCLUDE_FROM_ALL to example 12 and now it's omitted from CI

uint64_t deviceBufferAddress;
};

NBL_CONSTEXPR uint32_t WorkgroupSizeLog2 = 10; // 1024 threads (2^10)

512 is optimal residency on all GPUs

so a WorkgroupSizeLog2 of 9
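
That is, the quoted constant would become:

NBL_CONSTEXPR uint32_t WorkgroupSizeLog2 = 9; // 512 threads (2^9)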

Comment on lines +271 to +298
    std::cout << "(" << key << "," << value << "), ";
    if ((i + 1) % 20 == 0) {
        std::cout << "\n";
    }
}
std::cout << "\nElement count: " << elementCount << "\n";

bool is_sorted = true;
int32_t error_index = -1;
for (uint32_t i = 1; i < elementCount; i++) {
    uint32_t prevKey = data[(i - 1) * 2];
    uint32_t currKey = data[i * 2];
    if (currKey < prevKey) {
        is_sorted = false;
        error_index = i;
        break;
    }
}

if (is_sorted) {
    std::cout << "Array is correctly sorted!\n";
}
else {
    std::cout << "Array is NOT sorted correctly!\n";
    std::cout << "Error at index " << error_index << ":\n";
    std::cout << " Previous key [" << (error_index - 1) << "] = " << data[(error_index - 1) * 2] << "\n";
    std::cout << " Current key [" << error_index << "] = " << data[error_index * 2] << "\n";
    std::cout << " (" << data[error_index * 2] << " < " << data[(error_index - 1) * 2] << " is WRONG!)\n";

please use m_logger instead of std::cout
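
For example (a sketch; this assumes the printf-style m_logger->log(format, level, ...) overload used elsewhere in the examples):

m_logger->log("Element count: %u", ILogger::ELL_INFO, elementCount);
m_logger->log("Array is NOT sorted correctly! Error at index %d: prev key %u, curr key %u",
    ILogger::ELL_ERROR, error_index, data[(error_index - 1) * 2], data[error_index * 2]);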


deviceLocalBufferParams.queueFamilyIndexCount = 1;
deviceLocalBufferParams.queueFamilyIndices = &queueFamilyIndex;
deviceLocalBufferParams.size = sizeof(uint32_t) * elementCount * 2; // *2 because we store (key, value) pairs

have a struct like KeyValue so you can take a sizeof of it then
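
e.g. (a sketch, reusing the hypothetical KeyValue struct from the earlier comment), which also lets you compute the size once up front and reuse it later:

const uint32_t inputSize = sizeof(KeyValue) * elementCount; // compute once, reuse for buffer creation and the upload
deviceLocalBufferParams.size = inputSize;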

};

NBL_CONSTEXPR uint32_t WorkgroupSizeLog2 = 10; // 1024 threads (2^10)
NBL_CONSTEXPR uint32_t ElementsPerThreadLog2 = 2; // 4 elements per thread (2^2) - VIRTUAL THREADING!

I wouldn't really call that virtual threads; virtual threading is when you make a workgroup of size, let's say, 512 behave as if it were 4096.

Processing multiple elements per invocation is an orthogonal extra to that, and it helps with subgroupShuffle utilization and loading from global memory.
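
A minimal sketch of the elements-per-invocation pattern in isolation (the names and the coalesced-load striding are assumptions, not the example's code):

// Sketch only: each invocation handles ElementsPerThread elements, strided by the
// workgroup size so that neighbouring invocations load neighbouring addresses.
void loadOwnedElements(const uint32_t localInvocationIndex) // e.g. SV_GroupIndex
{
    const uint32_t elementsPerThread = 1u << ElementsPerThreadLog2;
    const uint32_t workgroupSize = 1u << WorkgroupSizeLog2;
    for (uint32_t e = 0; e < elementsPerThread; e++)
    {
        const uint32_t index = localInvocationIndex + e * workgroupSize;
        // load the element at `index`, compare-exchange, store it back, etc.
    }
}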


IQueue* const queue = getComputeQueue();

const uint32_t inputSize = sizeof(uint32_t) * elementCount * 2; // *2 because we store (key, value) pairs

compute this before making your m_deviceLocalBuffer and use it throughout

Comment on lines +161 to +162
inputPtr[i * 2] = key;
inputPtr[i * 2 + 1] = value;

again another reason to have a named struct so you're not doing this unreadable type punning
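
i.e. something like (a sketch, again with the hypothetical KeyValue struct):

KeyValue* const input = reinterpret_cast<KeyValue*>(inputPtr); // or map the staging memory as KeyValue* directly
input[i].key = key;
input[i].value = value;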

Comment on lines +163 to +168
    std::cout << "(" << key << "," << value << "), ";
    if ((i + 1) % 20 == 0) {
        std::cout << "\n";
    }
}
std::cout << "\nElement count: " << elementCount << "\n";

use the m_logger instead of std::cout

or an std::ostringstream and then m_logger to print the contents
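
For example (a sketch; same printf-style m_logger->log assumption as above, and it needs <sstream>):

std::ostringstream out;
for (uint32_t i = 0; i < elementCount; i++)
{
    out << "(" << inputPtr[i * 2] << "," << inputPtr[i * 2 + 1] << "), ";
    if ((i + 1) % 20 == 0)
        out << "\n";
}
m_logger->log("Input array: %s\nElement count: %u", ILogger::ELL_INFO, out.str().c_str(), elementCount);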

{
assert(dstOffset == 0 && size == outputSize);

std::cout << "Sorted array: ";

name it differently, e.g. Output Array

"Sorted array" is confusing because it makes me think you're outputting a reference sorted array on the CPU


cmdbuf->pushConstants(m_pipeline->getLayout(), IShader::E_SHADER_STAGE::ESS_COMPUTE, 0u, sizeof(pc), &pc);

cmdbuf->dispatch(1, 1, 1);

I'd consider preparing multiple test data sets (right now you only have one random one) and having 1 workgroup per test (so still one dispatch).

I'd add the following as manual test cases:

  • all keys equal
  • keys already sorted
  • keys in exact reverse

Furthermore, this should be a stable sort (unlike the counting sort), so it's important to check stability: include a test with some neighbouring equal keys in the input array and check that they have not changed places.
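
A sketch of what generating those data sets could look like (KeyValue and the whole layout are assumptions, not existing code; values are set to the original index so stability can be checked afterwards):

#include <cstdint>
#include <random>
#include <vector>

struct KeyValue { uint32_t key; uint32_t value; }; // hypothetical, as suggested earlier

// One vector per test case; with 1 workgroup per test they can be concatenated into a single upload.
static std::vector<KeyValue> makeTestCase(const uint32_t elementCount, const uint32_t mode)
{
    std::vector<KeyValue> data(elementCount);
    std::mt19937 rng(0xdeadbeefu);
    for (uint32_t i = 0; i < elementCount; i++)
    {
        switch (mode)
        {
            case 0: data[i].key = rng() % 64u; break;            // random, with plenty of duplicate keys
            case 1: data[i].key = 42u; break;                    // all keys equal
            case 2: data[i].key = i; break;                      // already sorted
            case 3: data[i].key = elementCount - 1u - i; break;  // exact reverse
        }
        data[i].value = i; // unique, increasing values so instability is detectable
    }
    return data;
}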

Comment on lines +283 to +287
if (currKey < prevKey) {
    is_sorted = false;
    error_index = i;
    break;
}

because Bitonic should be stable, you should also check

else if (currKey==prevKey && currValue>prevValue) // check stability, this is why we've initialized the values in such a particular way
{
   // fail, not due to bad sorting but to instability
}

and also assert(currValue != prevValue); because of how we initialized the values
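
Put together, the verification loop could look roughly like this (a sketch; it assumes values were initialized as the original, increasing indices, so among equal keys instability shows up as a decreasing value, and the comparison flips for the opposite initialization):

bool sorted = true;
bool stable = true;
int32_t error_index = -1;
for (uint32_t i = 1; i < elementCount; i++)
{
    const uint32_t prevKey   = data[(i - 1) * 2];
    const uint32_t prevValue = data[(i - 1) * 2 + 1];
    const uint32_t currKey   = data[i * 2];
    const uint32_t currValue = data[i * 2 + 1];

    assert(currValue != prevValue); // values were initialized to be unique

    if (currKey < prevKey)
    {
        sorted = false; // plain ordering failure
        error_index = i;
        break;
    }
    if (currKey == prevKey && currValue < prevValue)
    {
        stable = false; // keys equal but original relative order lost: instability
        error_index = i;
        break;
    }
}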
