Skip to content

Commit a3c3e2f

Browse files
fs-eireCopilot
andauthored
[webgpu] a few optimizations to graph capture implementation (microsoft#25305)
### Description 1. rename `SessionState` to `GraphCaptureState`, since there is already one SessionState type in ORT. 2. optimize implementation of `ComputeContext::BufferManager()` --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent fcd448a commit a3c3e2f

File tree

8 files changed

+39
-35
lines changed

8 files changed

+39
-35
lines changed

onnxruntime/core/providers/webgpu/allocator.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ class GpuBufferAllocator : public IAllocator {
2626
void GetStats(AllocatorStats* stats) override;
2727
void OnSessionInitializationEnd();
2828

29-
// Return the associated BufferManager
30-
const BufferManager& GetBufferManager() const { return buffer_manager_; }
31-
3229
private:
3330
AllocatorStats stats_;
3431
const BufferManager& buffer_manager_;

onnxruntime/core/providers/webgpu/buffer_manager.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DisabledCacheManager : public IBufferCacheManager {
3737
wgpuBufferRelease(buffer);
3838
}
3939

40-
void OnRefresh(const SessionState& /*session_status*/) override {
40+
void OnRefresh(GraphCaptureState /*graph_capture_state*/) override {
4141
// no-op
4242
}
4343
};
@@ -59,7 +59,7 @@ class LazyReleaseCacheManager : public IBufferCacheManager {
5959
pending_buffers_.emplace_back(buffer);
6060
}
6161

62-
void OnRefresh(const SessionState& /*session_status*/) override {
62+
void OnRefresh(GraphCaptureState /*graph_capture_state*/) override {
6363
Release();
6464
pending_buffers_.clear();
6565
}
@@ -103,7 +103,7 @@ class SimpleCacheManager : public IBufferCacheManager {
103103
pending_buffers_.emplace_back(buffer);
104104
}
105105

106-
void OnRefresh(const SessionState& /*session_status*/) override {
106+
void OnRefresh(GraphCaptureState /*graph_capture_state*/) override {
107107
for (auto& buffer : pending_buffers_) {
108108
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].emplace_back(buffer);
109109
}
@@ -196,7 +196,7 @@ class BucketCacheManager : public IBufferCacheManager {
196196
pending_buffers_.emplace_back(buffer);
197197
}
198198

199-
void OnRefresh(const SessionState& /*session_status*/) override {
199+
void OnRefresh(GraphCaptureState /*graph_capture_state*/) override {
200200
for (auto& buffer : pending_buffers_) {
201201
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));
202202
auto it = buckets_.find(buffer_size);
@@ -283,7 +283,7 @@ class GraphCacheManager : public IBufferCacheManager {
283283
pending_buffers_.emplace_back(buffer);
284284
}
285285

286-
void OnRefresh(const SessionState& /*session_status*/) override {
286+
void OnRefresh(GraphCaptureState /*graph_capture_state*/) override {
287287
// Initialize buckets if they don't exist yet
288288
if (buckets_.empty()) {
289289
for (const auto& pair : buckets_limit_) {
@@ -363,9 +363,9 @@ class GraphSimpleCacheManager : public IBufferCacheManager {
363363
pending_buffers_.emplace_back(buffer);
364364
}
365365

366-
void OnRefresh(const SessionState& session_status) override {
366+
void OnRefresh(GraphCaptureState graph_capture_state) override {
367367
for (auto& buffer : pending_buffers_) {
368-
if (session_status == SessionState::Default) {
368+
if (graph_capture_state == GraphCaptureState::Default) {
369369
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].emplace_back(buffer);
370370
} else {
371371
captured_buffers_.emplace_back(buffer);
@@ -573,11 +573,11 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
573573
staging_buffer.Unmap();
574574
}
575575

576-
void BufferManager::RefreshPendingBuffers(const SessionState& session_status) const {
577-
storage_cache_->OnRefresh(session_status);
578-
uniform_cache_->OnRefresh(session_status);
579-
query_resolve_cache_->OnRefresh(session_status);
580-
default_cache_->OnRefresh(session_status);
576+
void BufferManager::RefreshPendingBuffers(GraphCaptureState graph_capture_state) const {
577+
storage_cache_->OnRefresh(graph_capture_state);
578+
uniform_cache_->OnRefresh(graph_capture_state);
579+
query_resolve_cache_->OnRefresh(graph_capture_state);
580+
default_cache_->OnRefresh(graph_capture_state);
581581
}
582582

583583
IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const {

onnxruntime/core/providers/webgpu/buffer_manager.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace webgpu {
1515
class WebGpuContext;
1616

1717
// For command capture and replay
18-
enum class SessionState {
18+
enum class GraphCaptureState {
1919
Default,
2020
Capturing,
2121
Replaying
@@ -59,7 +59,7 @@ class IBufferCacheManager {
5959
virtual void ReleaseBuffer(WGPUBuffer buffer) = 0;
6060

6161
// when a stream refresh is requested
62-
virtual void OnRefresh(const SessionState& session_status) = 0;
62+
virtual void OnRefresh(GraphCaptureState graph_capture_state) = 0;
6363
};
6464

6565
//
@@ -77,7 +77,7 @@ class BufferManager {
7777
bool SupportsUMA() const;
7878
void Release(WGPUBuffer buffer) const;
7979
void Download(WGPUBuffer src, void* dst, size_t size) const;
80-
void RefreshPendingBuffers(const SessionState& session_status) const;
80+
void RefreshPendingBuffers(GraphCaptureState graph_capture_state) const;
8181

8282
private:
8383
IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const;

onnxruntime/core/providers/webgpu/compute_context.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
#include "core/providers/webgpu/webgpu_context.h"
88
#include "core/providers/webgpu/allocator.h"
99
#include "core/providers/webgpu/buffer_manager.h"
10+
#include "core/providers/webgpu/webgpu_execution_provider.h"
1011

1112
namespace onnxruntime {
1213
namespace webgpu {
13-
ComputeContext::ComputeContext(OpKernelContext& kernel_context)
14+
ComputeContext::ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep)
1415
: webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())},
15-
kernel_context_{kernel_context} {
16+
kernel_context_{kernel_context},
17+
ep_{ep} {
1618
}
1719

1820
void ComputeContext::PushErrorScope() {
@@ -29,10 +31,7 @@ Status ComputeContext::PopErrorScope() {
2931
}
3032

3133
const webgpu::BufferManager& ComputeContext::BufferManager() const {
32-
OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0);
33-
AllocatorPtr allocator = kernel_context_.GetAllocator(gpu_device);
34-
const GpuBufferAllocator* gpu_allocator = static_cast<const GpuBufferAllocator*>(allocator.get());
35-
return gpu_allocator->GetBufferManager();
34+
return ep_.BufferManager();
3635
}
3736

3837
} // namespace webgpu

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
namespace onnxruntime {
1717

1818
class Tensor;
19+
class WebGpuExecutionProvider;
1920

2021
namespace webgpu {
2122

@@ -24,7 +25,7 @@ class BufferManager;
2425

2526
class ComputeContext {
2627
public:
27-
ComputeContext(OpKernelContext& kernel_context);
28+
ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep);
2829

2930
virtual ~ComputeContext() = default;
3031

@@ -145,6 +146,7 @@ class ComputeContext {
145146
protected:
146147
WebGpuContext& webgpu_context_;
147148
OpKernelContext& kernel_context_;
149+
const WebGpuExecutionProvider& ep_;
148150
};
149151

150152
} // namespace webgpu

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,8 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) {
691691
}
692692
auto command_buffer = current_command_encoder_.Finish();
693693
device_queue_.Submit(1, &command_buffer);
694-
if (session_status_ != SessionState::Replaying) {
695-
buffer_mgr.RefreshPendingBuffers(session_status_);
694+
if (graph_capture_state_ != GraphCaptureState::Replaying) {
695+
buffer_mgr.RefreshPendingBuffers(graph_capture_state_);
696696
}
697697
current_command_encoder_ = nullptr;
698698
num_pending_dispatches_ = 0;
@@ -724,7 +724,7 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
724724
bind_group_desc.label = {program_artifact.name.data(), program_artifact.name.length()};
725725

726726
auto bind_group = wgpuDeviceCreateBindGroup(Device().Get(), &bind_group_desc);
727-
if (session_status_ == SessionState::Capturing) {
727+
if (graph_capture_state_ == GraphCaptureState::Capturing) {
728728
external_captured_commands_->push_back({program_artifact.compute_pipeline,
729729
bind_group,
730730
bind_group_layout,
@@ -754,12 +754,12 @@ void WebGpuContext::CaptureBegin(std::vector<webgpu::CapturedCommandInfo>* captu
754754
// TODO: support profiling with graph capture.
755755
ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode");
756756

757-
session_status_ = SessionState::Capturing;
757+
graph_capture_state_ = GraphCaptureState::Capturing;
758758
}
759759

760760
void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captured_commands, const webgpu::BufferManager& buffer_manager) {
761761
LOGS_DEFAULT(VERBOSE) << "Replay with external storage";
762-
session_status_ = SessionState::Replaying;
762+
graph_capture_state_ = GraphCaptureState::Replaying;
763763
// Replay all captured commands from the provided vector
764764
const size_t command_count = captured_commands.size();
765765
for (size_t i = 0; i < command_count; ++i) {
@@ -784,13 +784,13 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
784784
// Flush any remaining commands
785785
Flush(buffer_manager);
786786

787-
session_status_ = SessionState::Default;
787+
graph_capture_state_ = GraphCaptureState::Default;
788788
}
789789

790790
void WebGpuContext::CaptureEnd() {
791791
LOGS_DEFAULT(VERBOSE) << "CaptureEnd";
792792

793-
session_status_ = SessionState::Default;
793+
graph_capture_state_ = GraphCaptureState::Default;
794794
external_captured_commands_ = nullptr;
795795
}
796796

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ class WebGpuContext final {
254254
uint64_t gpu_timestamp_offset_ = 0;
255255
bool is_profiling_ = false;
256256
bool preserve_device_;
257-
SessionState session_status_{SessionState::Default};
257+
GraphCaptureState graph_capture_state_{GraphCaptureState::Default};
258258

259259
// External vector to store captured commands, owned by EP
260260
std::vector<webgpu::CapturedCommandInfo>* external_captured_commands_ = nullptr;

onnxruntime/core/providers/webgpu/webgpu_kernel.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include "core/framework/op_kernel.h"
1010

1111
namespace onnxruntime {
12+
13+
class WebGpuExecutionProvider;
1214
namespace webgpu {
1315

1416
// -----------------------------------------------------------------------
@@ -17,11 +19,12 @@ namespace webgpu {
1719
class WebGpuKernel : public OpKernel {
1820
public:
1921
explicit WebGpuKernel(const OpKernelInfo& info)
20-
: OpKernel(info) {
22+
: OpKernel(info),
23+
ep_(*static_cast<const WebGpuExecutionProvider*>(info.GetExecutionProvider())) {
2124
}
2225

2326
Status Compute(OpKernelContext* p_op_kernel_context) const override {
24-
ComputeContext context{*p_op_kernel_context};
27+
ComputeContext context{*p_op_kernel_context, ep_};
2528

2629
context.PushErrorScope();
2730
Status s = ComputeInternal(context);
@@ -31,6 +34,9 @@ class WebGpuKernel : public OpKernel {
3134
}
3235

3336
virtual Status ComputeInternal(ComputeContext& context) const = 0;
37+
38+
private:
39+
const WebGpuExecutionProvider& ep_;
3440
};
3541

3642
} // namespace webgpu

0 commit comments

Comments
 (0)