Skip to content

Commit 063491c

Browse files
authored
[webgpu] add usage of ReadonlyAllocator for WebGPU EP (microsoft#25690)
### Description

Add usage of `ReadonlyAllocator` for the WebGPU EP.

### Motivation and Context

`ReadonlyAllocator` was added in microsoft#25348 to allow an EP to register a separate allocator that is used only for initializers. The WebGPU EP already handles initializers and non-initializers differently, and this change makes the WebGPU EP use the preferred way of dealing with that. The allocator's behavior now depends on its `OrtAllocatorType` instead of the session initialization status.
1 parent 0dd71ba commit 063491c

File tree

8 files changed

+44
-59
lines changed

8 files changed

+44
-59
lines changed

onnxruntime/core/providers/webgpu/allocator.cc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,28 @@
88
namespace onnxruntime {
99
namespace webgpu {
1010

11+
GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator)
12+
: IAllocator(
13+
OrtMemoryInfo(WEBGPU_BUFFER,
14+
is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
15+
: OrtAllocatorType::OrtDeviceAllocator,
16+
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
17+
OrtMemTypeDefault)),
18+
buffer_manager_{buffer_manager},
19+
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
20+
}
21+
1122
void* GpuBufferAllocator::Alloc(size_t size) {
1223
if (size == 0) {
1324
return nullptr;
1425
}
1526

1627
stats_.num_allocs++;
1728

18-
// Check if the buffer manager supports UMA and we're not yet in an initialized session
19-
if (!session_initialized_ && buffer_manager_.SupportsUMA()) {
20-
return buffer_manager_.CreateUMA(size);
21-
}
29+
wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
30+
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
2231

23-
return buffer_manager_.Create(size);
32+
return buffer_manager_.Create(size, usage);
2433
}
2534

2635
void GpuBufferAllocator::Free(void* p) {
@@ -34,9 +43,5 @@ void GpuBufferAllocator::GetStats(AllocatorStats* stats) {
3443
*stats = stats_;
3544
}
3645

37-
void GpuBufferAllocator::OnSessionInitializationEnd() {
38-
session_initialized_ = true;
39-
}
40-
4146
} // namespace webgpu
4247
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/allocator.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,16 @@ class BufferManager;
1313

1414
class GpuBufferAllocator : public IAllocator {
1515
public:
16-
GpuBufferAllocator(const BufferManager& buffer_manager)
17-
: IAllocator(
18-
OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator,
19-
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
20-
OrtMemTypeDefault)),
21-
buffer_manager_{buffer_manager} {
22-
}
16+
GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
2317

2418
virtual void* Alloc(size_t size) override;
2519
virtual void Free(void* p) override;
2620
void GetStats(AllocatorStats* stats) override;
27-
void OnSessionInitializationEnd();
2821

2922
private:
3023
AllocatorStats stats_;
3124
const BufferManager& buffer_manager_;
32-
bool session_initialized_ = false;
25+
bool mapped_at_creation_;
3326
};
3427

3528
} // namespace webgpu

onnxruntime/core/providers/webgpu/buffer_manager.cc

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,9 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
508508
wgpu::BufferDescriptor desc{};
509509
desc.size = buffer_size;
510510
desc.usage = usage;
511+
if (usage & wgpu::BufferUsage::MapWrite) {
512+
desc.mappedAtCreation = true; // ensure the buffer is mapped for writing at creation
513+
}
511514
buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle();
512515

513516
ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), ".");
@@ -516,26 +519,6 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
516519
return buffer;
517520
}
518521

519-
WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) const {
520-
ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer.");
521-
auto& cache = GetCacheManager(usage);
522-
auto buffer_size = cache.CalculateBufferSize(size);
523-
524-
// Ensure the buffer is mapped for writing at creation.
525-
usage |= wgpu::BufferUsage::MapWrite;
526-
527-
wgpu::BufferDescriptor desc{};
528-
desc.size = buffer_size;
529-
desc.usage = usage;
530-
desc.mappedAtCreation = true;
531-
auto buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle();
532-
533-
ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), ".");
534-
535-
cache.RegisterBuffer(buffer, size);
536-
return buffer;
537-
}
538-
539522
bool BufferManager::SupportsUMA() const {
540523
#if !defined(__wasm__)
541524
// Check if the device supports the BufferMapExtendedUsages feature

onnxruntime/core/providers/webgpu/buffer_manager.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,8 @@ class BufferManager {
7070
BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode);
7171
void Upload(void* src, WGPUBuffer dst, size_t size) const;
7272
void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const;
73-
WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const;
74-
// Create a buffer mapped for writing.
75-
WGPUBuffer CreateUMA(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const;
76-
// Check if CreateUMA is supported (i.e., the device has BufferMapExtendedUsages feature)
77-
bool SupportsUMA() const;
73+
WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const;
74+
bool SupportsUMA() const; // Check if CreateUMA is supported (i.e., the device has BufferMapExtendedUsages feature)
7875
void Release(WGPUBuffer buffer) const;
7976
void Download(WGPUBuffer src, void* dst, size_t size) const;
8077
void RefreshPendingBuffers(GraphCaptureState graph_capture_state) const;

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
162162
buffer_cache_config.uniform.mode,
163163
buffer_cache_config.query_resolve.mode);
164164

165+
// create initializer buffer manager. cache is always disabled for initializer buffer manager
166+
initializer_buffer_mgr_ = BufferManagerFactory::Create(*this,
167+
BufferCacheMode::Disabled,
168+
BufferCacheMode::Disabled,
169+
BufferCacheMode::Disabled);
170+
165171
// create program manager
166172
program_mgr_ = std::make_unique<ProgramManager>(Device(), DeviceLimits());
167173

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,18 @@ class WebGpuContext final {
132132

133133
void Flush(const webgpu::BufferManager& buffer_mgr);
134134

135+
/**
136+
* Get the buffer manager.
137+
*/
135138
webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; }
136139

140+
/**
141+
* Get the initializer buffer manager.
142+
*
143+
* This buffer manager is used for read-only buffers (e.g. initializers).
144+
*/
145+
webgpu::BufferManager& InitializerBufferManager() const { return *initializer_buffer_mgr_; }
146+
137147
inline webgpu::ValidationMode ValidationMode() const {
138148
return validation_mode_;
139149
}
@@ -236,6 +246,7 @@ class WebGpuContext final {
236246
wgpu::ComputePassEncoder current_compute_pass_encoder_;
237247

238248
std::unique_ptr<webgpu::BufferManager> buffer_mgr_;
249+
std::unique_ptr<webgpu::BufferManager> initializer_buffer_mgr_;
239250
std::unique_ptr<ProgramManager> program_mgr_;
240251

241252
uint32_t num_pending_dispatches_ = 0;

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -801,13 +801,12 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
801801
}
802802

803803
std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
804-
AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) {
805-
return std::make_unique<webgpu::GpuBufferAllocator>(BufferManager());
806-
},
807-
0, false);
808-
auto preferred_allocators = std::vector<AllocatorPtr>{CreateAllocator(gpuBufferAllocatorCreationInfo)};
809-
allocator_ = reinterpret_cast<webgpu::GpuBufferAllocator*>(preferred_allocators[0].get());
810-
return preferred_allocators;
804+
return {
805+
// allocator for initializers
806+
std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), true),
807+
// default allocator
808+
std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false),
809+
};
811810
}
812811

813812
std::vector<std::unique_ptr<ComputeCapability>> WebGpuExecutionProvider::GetCapability(
@@ -912,13 +911,6 @@ std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler() {
912911
return profiler;
913912
}
914913

915-
Status WebGpuExecutionProvider::OnSessionInitializationEnd() {
916-
if (allocator_ != nullptr) {
917-
allocator_->OnSessionInitializationEnd();
918-
}
919-
return Status::OK();
920-
}
921-
922914
Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
923915
if (context_.ValidationMode() >= ValidationMode::Basic) {
924916
context_.PushErrorScope();

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class WebGpuExecutionProvider : public IExecutionProvider {
7171
bool ConcurrentRunSupported() const override { return false; }
7272

7373
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
74-
Status OnSessionInitializationEnd() override;
7574

7675
Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
7776
Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
@@ -100,7 +99,6 @@ class WebGpuExecutionProvider : public IExecutionProvider {
10099
int regular_run_count_before_graph_capture_ = 0;
101100
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
102101
int m_current_graph_annotation_id = 0;
103-
webgpu::GpuBufferAllocator* allocator_ = nullptr;
104102

105103
// Buffer manager specifically for graph capture mode
106104
std::unique_ptr<webgpu::BufferManager> graph_buffer_mgr_ = nullptr;

0 commit comments

Comments
 (0)