Skip to content

Commit e63e053

Browse files
authored
[webgpu] Enable graph capture (microsoft#24900)
This PR enables graph capture capabilities in the WebGPU provider, similar to the JSEP one (microsoft#18989). All limitations are similar to those of the JS/CUDA EPs: 1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not supported. 2. Usage of graph capture is limited to models wherein all ops in the model can be partitioned to the WebGPU EP or CPU EP with no memory copy between them. 3. Shapes of inputs/outputs cannot change across inference calls. 4. IOBinding is required, and all inputs/outputs are pre-allocated GPU buffers. When users use the graph capture feature, we expect them to do some pre-processing and post-processing of the inference's inputs and outputs in order to keep the whole pipeline on the GPU and avoid unnecessary CPU-to-GPU or GPU-to-CPU copying. The usage will be like below: ``` // Initialize Dawn { // 1. Create Dawn instance ... instance = wgpu::CreateInstance(&instanceDescriptor); // 2. Create the adapter ... instance.RequestAdapter // 3. Create device from adapter ... adapter.RequestDevice } // Create session options webgpu_options_ = std::make_unique<Ort::SessionOptions>(); std::unordered_map<std::string, std::string> provider_options; provider_options["dawnProcTable"] = std::to_string(reinterpret_cast<size_t>(&dawn::native::GetProcs())); provider_options["webgpuInstance"] = std::to_string(reinterpret_cast<size_t>(instance_.Get())); provider_options["webgpuDevice"] = std::to_string(reinterpret_cast<size_t>(device_.Get())); provider_options["deviceId"] = "1"; provider_options["enableGraphCapture"] = "1"; // add WebGPU provider webgpu_options_->AppendExecutionProvider("WebGPU", provider_options); ... // create webgpu session webgpu_session_ = std::make_unique<Ort::Session>(*env_, model_path_.c_str(), *webgpu_options_); ... 
Ort::MemoryInfo memory_info_gpu("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); Ort::Allocator allocator(*webgpu_session_, memory_info_gpu); auto input_buffer = allocator.GetAllocation(input_tensor_size_ * sizeof(float)); auto output_buffer = allocator.GetAllocation(output_tensor_size_ * sizeof(float)); // Create IoBinding objects Ort::IoBinding webgpu_binding(*webgpu_session_); // Upload cpu data to input_buffer or copy gpu buffer to input_buffer ... // Create an OrtValue tensor backed by data on gpu memory Ort::Value bound_x = Ort::Value::CreateTensor(memory_info_gpu, reinterpret_cast<float*>(input_buffer.get()), input_tensor_size_, input_dims_.data(), input_dims_.size()); Ort::Value bound_y = Ort::Value::CreateTensor(memory_info_gpu, reinterpret_cast<float*>(output_buffer.get()), output_tensor_size_, output_dims_.data(), output_dims_.size()); webgpu_binding.BindInput("input", bound_x); webgpu_binding.BindOutput("output", bound_y); // Run inference webgpu_session_->Run(Ort::RunOptions{nullptr}, webgpu_binding); // normal run + capturing ... // post process output_buffer's content ... // Update input_buffer's content ... // Run again webgpu_session_->Run(Ort::RunOptions{nullptr}, webgpu_binding); // replay() ... // post process output_buffer's content ... ```
1 parent 5e4d8dc commit e63e053

16 files changed

+546
-129
lines changed

onnxruntime/core/providers/webgpu/allocator.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include "core/framework/session_state.h"
55
#include "core/providers/webgpu/allocator.h"
6-
#include "core/providers/webgpu/webgpu_context.h"
6+
#include "core/providers/webgpu/buffer_manager.h"
77

88
namespace onnxruntime {
99
namespace webgpu {
@@ -15,18 +15,17 @@ void* GpuBufferAllocator::Alloc(size_t size) {
1515

1616
stats_.num_allocs++;
1717

18-
#if !defined(__wasm__)
19-
if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) {
20-
return context_.BufferManager().CreateUMA(size);
18+
// Check if the buffer manager supports UMA and we're not yet in an initialized session
19+
if (!session_initialized_ && buffer_manager_.SupportsUMA()) {
20+
return buffer_manager_.CreateUMA(size);
2121
}
22-
#endif // !defined(__wasm__)
2322

24-
return context_.BufferManager().Create(size);
23+
return buffer_manager_.Create(size);
2524
}
2625

2726
void GpuBufferAllocator::Free(void* p) {
2827
if (p != nullptr) {
29-
context_.BufferManager().Release(static_cast<WGPUBuffer>(p));
28+
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
3029
stats_.num_allocs--;
3130
}
3231
}

onnxruntime/core/providers/webgpu/allocator.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,29 @@
99
namespace onnxruntime {
1010
namespace webgpu {
1111

12-
class WebGpuContext;
12+
class BufferManager;
1313

1414
class GpuBufferAllocator : public IAllocator {
1515
public:
16-
GpuBufferAllocator(const WebGpuContext& context)
16+
GpuBufferAllocator(const BufferManager& buffer_manager)
1717
: IAllocator(
1818
OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator,
1919
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
2020
OrtMemTypeDefault)),
21-
context_{context} {
21+
buffer_manager_{buffer_manager} {
2222
}
2323

2424
virtual void* Alloc(size_t size) override;
2525
virtual void Free(void* p) override;
2626
void GetStats(AllocatorStats* stats) override;
27-
2827
void OnSessionInitializationEnd();
2928

29+
// Return the associated BufferManager
30+
const BufferManager& GetBufferManager() const { return buffer_manager_; }
31+
3032
private:
3133
AllocatorStats stats_;
32-
const WebGpuContext& context_;
34+
const BufferManager& buffer_manager_;
3335
bool session_initialized_ = false;
3436
};
3537

onnxruntime/core/providers/webgpu/buffer_manager.cc

Lines changed: 186 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DisabledCacheManager : public IBufferCacheManager {
3737
wgpuBufferRelease(buffer);
3838
}
3939

40-
void OnRefresh() override {
40+
void OnRefresh(const SessionState& /*session_status*/) override {
4141
// no-op
4242
}
4343
};
@@ -59,7 +59,7 @@ class LazyReleaseCacheManager : public IBufferCacheManager {
5959
pending_buffers_.emplace_back(buffer);
6060
}
6161

62-
void OnRefresh() override {
62+
void OnRefresh(const SessionState& /*session_status*/) override {
6363
Release();
6464
pending_buffers_.clear();
6565
}
@@ -103,7 +103,7 @@ class SimpleCacheManager : public IBufferCacheManager {
103103
pending_buffers_.emplace_back(buffer);
104104
}
105105

106-
void OnRefresh() override {
106+
void OnRefresh(const SessionState& /*session_status*/) override {
107107
for (auto& buffer : pending_buffers_) {
108108
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].emplace_back(buffer);
109109
}
@@ -196,12 +196,9 @@ class BucketCacheManager : public IBufferCacheManager {
196196
pending_buffers_.emplace_back(buffer);
197197
}
198198

199-
void OnRefresh() override {
200-
// TODO: consider graph capture. currently not supported
201-
199+
void OnRefresh(const SessionState& /*session_status*/) override {
202200
for (auto& buffer : pending_buffers_) {
203201
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));
204-
205202
auto it = buckets_.find(buffer_size);
206203
if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) {
207204
it->second.emplace_back(buffer);
@@ -249,6 +246,155 @@ class BucketCacheManager : public IBufferCacheManager {
249246
std::vector<size_t> buckets_keys_;
250247
};
251248

249+
class GraphCacheManager : public IBufferCacheManager {
250+
public:
251+
GraphCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} {
252+
Initialize();
253+
}
254+
GraphCacheManager(std::unordered_map<size_t, size_t>&& buckets_limit) : buckets_limit_{buckets_limit} {
255+
Initialize();
256+
}
257+
258+
size_t CalculateBufferSize(size_t request_size) override {
259+
// binary search size
260+
auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size);
261+
if (it == buckets_keys_.end()) {
262+
return NormalizeBufferSize(request_size);
263+
} else {
264+
return *it;
265+
}
266+
}
267+
268+
WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
269+
auto it = buckets_.find(buffer_size);
270+
if (it != buckets_.end() && !it->second.empty()) {
271+
auto buffer = it->second.back();
272+
it->second.pop_back();
273+
return buffer;
274+
}
275+
return nullptr;
276+
}
277+
278+
void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
279+
// no-op
280+
}
281+
282+
void ReleaseBuffer(WGPUBuffer buffer) override {
283+
pending_buffers_.emplace_back(buffer);
284+
}
285+
286+
void OnRefresh(const SessionState& /*session_status*/) override {
287+
// Initialize buckets if they don't exist yet
288+
if (buckets_.empty()) {
289+
for (const auto& pair : buckets_limit_) {
290+
buckets_.emplace(pair.first, std::vector<WGPUBuffer>());
291+
}
292+
}
293+
294+
for (auto& buffer : pending_buffers_) {
295+
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));
296+
auto it = buckets_.find(buffer_size);
297+
if (it != buckets_.end()) {
298+
it->second.emplace_back(buffer);
299+
} else {
300+
// insert a new bucket if it doesn't exist
301+
buckets_[buffer_size] = std::vector<WGPUBuffer>{buffer};
302+
}
303+
}
304+
305+
pending_buffers_.clear();
306+
}
307+
308+
~GraphCacheManager() {
309+
for (auto& buffer : pending_buffers_) {
310+
wgpuBufferRelease(buffer);
311+
}
312+
for (auto& pair : buckets_) {
313+
for (auto& buffer : pair.second) {
314+
wgpuBufferRelease(buffer);
315+
}
316+
}
317+
}
318+
319+
protected:
320+
void Initialize() {
321+
buckets_keys_.reserve(buckets_limit_.size());
322+
for (const auto& pair : buckets_limit_) {
323+
buckets_keys_.push_back(pair.first);
324+
}
325+
std::sort(buckets_keys_.begin(), buckets_keys_.end());
326+
327+
#ifndef NDEBUG // if debug build
328+
ORT_ENFORCE(std::all_of(buckets_keys_.begin(), buckets_keys_.end(), [](size_t size) { return size % 16 == 0; }),
329+
"Bucket sizes must be multiples of 16.");
330+
331+
for (size_t i = 1; i < buckets_keys_.size(); ++i) {
332+
ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order.");
333+
}
334+
#endif
335+
}
336+
std::unordered_map<size_t, size_t> buckets_limit_;
337+
std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_;
338+
std::vector<WGPUBuffer> pending_buffers_;
339+
std::vector<size_t> buckets_keys_;
340+
};
341+
342+
class GraphSimpleCacheManager : public IBufferCacheManager {
343+
size_t CalculateBufferSize(size_t request_size) override {
344+
return NormalizeBufferSize(request_size);
345+
}
346+
347+
WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
348+
auto it = buffers_.find(buffer_size);
349+
if (it != buffers_.end() && !it->second.empty()) {
350+
auto buffer = it->second.back();
351+
it->second.pop_back();
352+
return buffer;
353+
}
354+
355+
return nullptr;
356+
}
357+
358+
void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
359+
// no-op
360+
}
361+
362+
void ReleaseBuffer(WGPUBuffer buffer) override {
363+
pending_buffers_.emplace_back(buffer);
364+
}
365+
366+
void OnRefresh(const SessionState& session_status) override {
367+
for (auto& buffer : pending_buffers_) {
368+
if (session_status == SessionState::Default) {
369+
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].emplace_back(buffer);
370+
} else {
371+
captured_buffers_.emplace_back(buffer);
372+
}
373+
}
374+
pending_buffers_.clear();
375+
}
376+
377+
public:
378+
~GraphSimpleCacheManager() {
379+
for (auto& buffer : pending_buffers_) {
380+
wgpuBufferRelease(buffer);
381+
}
382+
for (auto& pair : buffers_) {
383+
for (auto& buffer : pair.second) {
384+
wgpuBufferRelease(buffer);
385+
}
386+
}
387+
for (auto& buffer : captured_buffers_) {
388+
wgpuBufferRelease(buffer);
389+
}
390+
}
391+
392+
protected:
393+
std::map<size_t, std::vector<WGPUBuffer>> buffers_;
394+
std::vector<WGPUBuffer> pending_buffers_;
395+
std::vector<WGPUBuffer> captured_buffers_;
396+
};
397+
252398
std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode cache_mode) {
253399
switch (cache_mode) {
254400
case BufferCacheMode::Disabled:
@@ -259,6 +405,10 @@ std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode ca
259405
return std::make_unique<SimpleCacheManager>();
260406
case BufferCacheMode::Bucket:
261407
return std::make_unique<BucketCacheManager>();
408+
case BufferCacheMode::Graph:
409+
return std::make_unique<GraphCacheManager>();
410+
case BufferCacheMode::GraphSimple:
411+
return std::make_unique<GraphSimpleCacheManager>();
262412
default:
263413
ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode");
264414
}
@@ -278,6 +428,12 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) {
278428
case BufferCacheMode::Bucket:
279429
os << "Bucket";
280430
break;
431+
case BufferCacheMode::Graph:
432+
os << "Graph";
433+
break;
434+
case BufferCacheMode::GraphSimple:
435+
os << "GraphSimple";
436+
break;
281437
default:
282438
os << "Unknown(" << static_cast<int>(mode) << ")";
283439
}
@@ -292,7 +448,7 @@ BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buf
292448
default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} {
293449
}
294450

295-
void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) {
451+
void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const {
296452
// If the buffer is mapped, we can directly write to it.
297453
void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE); // ensure the buffer is mapped
298454
if (mapped_data) {
@@ -317,10 +473,10 @@ void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) {
317473
auto& command_encoder = context_.GetCommandEncoder();
318474
context_.EndComputePass();
319475
command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size);
320-
context_.Flush();
476+
context_.Flush(*this);
321477
}
322478

323-
void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) {
479+
void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const {
324480
ORT_ENFORCE(src != dst, "Source and destination buffers must be different.");
325481
EnforceBufferUnmapped(context_, src);
326482
EnforceBufferUnmapped(context_, dst);
@@ -337,7 +493,7 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) {
337493
command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size);
338494
}
339495

340-
WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) {
496+
WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
341497
auto& cache = GetCacheManager(usage);
342498
auto buffer_size = cache.CalculateBufferSize(size);
343499

@@ -358,7 +514,7 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) {
358514
return buffer;
359515
}
360516

361-
WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) {
517+
WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) const {
362518
ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer.");
363519
auto& cache = GetCacheManager(usage);
364520
auto buffer_size = cache.CalculateBufferSize(size);
@@ -378,12 +534,21 @@ WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) {
378534
return buffer;
379535
}
380536

381-
void BufferManager::Release(WGPUBuffer buffer) {
537+
bool BufferManager::SupportsUMA() const {
538+
#if !defined(__wasm__)
539+
// Check if the device supports the BufferMapExtendedUsages feature
540+
return context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages);
541+
#else
542+
return false;
543+
#endif // !defined(__wasm__)
544+
}
545+
546+
void BufferManager::Release(WGPUBuffer buffer) const {
382547
EnforceBufferUnmapped(context_, buffer);
383548
GetCacheManager(buffer).ReleaseBuffer(buffer);
384549
}
385550

386-
void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) {
551+
void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
387552
EnforceBufferUnmapped(context_, src);
388553
auto buffer_size = NormalizeBufferSize(size);
389554

@@ -395,7 +560,7 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) {
395560
auto& command_encoder = context_.GetCommandEncoder();
396561
context_.EndComputePass();
397562
command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size);
398-
context_.Flush();
563+
context_.Flush(*this);
399564

400565
// TODO: revise wait in whole project
401566

@@ -405,13 +570,14 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) {
405570

406571
auto mapped_data = staging_buffer.GetConstMappedRange();
407572
memcpy(dst, mapped_data, size);
573+
staging_buffer.Unmap();
408574
}
409575

410-
void BufferManager::RefreshPendingBuffers() {
411-
storage_cache_->OnRefresh();
412-
uniform_cache_->OnRefresh();
413-
query_resolve_cache_->OnRefresh();
414-
default_cache_->OnRefresh();
576+
void BufferManager::RefreshPendingBuffers(const SessionState& session_status) const {
577+
storage_cache_->OnRefresh(session_status);
578+
uniform_cache_->OnRefresh(session_status);
579+
query_resolve_cache_->OnRefresh(session_status);
580+
default_cache_->OnRefresh(session_status);
415581
}
416582

417583
IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const {

0 commit comments

Comments
 (0)