This PR enables graph capture in the WebGPU provider, similar to the JSEP implementation in microsoft#18989.
The limitations are the same as for the JS and CUDA EPs:
1. Models with control-flow ops (i.e. If, Loop, and Scan ops) are not
supported.
2. Graph capture is limited to models in which every op can be
partitioned to the WebGPU EP or the CPU EP with no memory copies
between them.
3. Shapes of inputs/outputs cannot change across inference calls.
4. IOBinding is required, and all inputs/outputs must be pre-allocated
GPU buffers.
When using the graph capture feature, users are expected to pre-process
and post-process the inference inputs and outputs themselves, keeping
the whole pipeline on the GPU and avoiding unnecessary CPU-to-GPU or
GPU-to-CPU copies. Usage looks like this:
```cpp
// Initialize Dawn
{
// 1. Create Dawn instance
...
instance = wgpu::CreateInstance(&instanceDescriptor);
// 2. Create the adapter
...
instance.RequestAdapter
// 3. Create device from adapter
...
adapter.RequestDevice
}
// Create session options
webgpu_options_ = std::make_unique<Ort::SessionOptions>();
std::unordered_map<std::string, std::string> provider_options;
provider_options["dawnProcTable"] = std::to_string(reinterpret_cast<size_t>(&dawn::native::GetProcs()));
provider_options["webgpuInstance"] = std::to_string(reinterpret_cast<size_t>(instance_.Get()));
provider_options["webgpuDevice"] = std::to_string(reinterpret_cast<size_t>(device_.Get()));
provider_options["deviceId"] = "1";
provider_options["enableGraphCapture"] = "1";
// add WebGPU provider
webgpu_options_->AppendExecutionProvider("WebGPU", provider_options);
...
// create webgpu session
webgpu_session_ = std::make_unique<Ort::Session>(*env_, model_path_.c_str(), *webgpu_options_);
...
Ort::MemoryInfo memory_info_gpu("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
Ort::Allocator allocator(*webgpu_session_, memory_info_gpu);
auto input_buffer = allocator.GetAllocation(input_tensor_size_ * sizeof(float));
auto output_buffer = allocator.GetAllocation(output_tensor_size_ * sizeof(float));
// Create IoBinding objects
Ort::IoBinding webgpu_binding(*webgpu_session_);
// Upload cpu data to input_buffer or copy gpu buffer to input_buffer
...
// Create an OrtValue tensor backed by data on gpu memory
Ort::Value bound_x = Ort::Value::CreateTensor(memory_info_gpu, reinterpret_cast<float*>(input_buffer.get()),
                                              input_tensor_size_, input_dims_.data(), input_dims_.size());
Ort::Value bound_y = Ort::Value::CreateTensor(memory_info_gpu, reinterpret_cast<float*>(output_buffer.get()),
                                              output_tensor_size_, output_dims_.data(), output_dims_.size());
webgpu_binding.BindInput("input", bound_x);
webgpu_binding.BindOutput("output", bound_y);
// Run inference
webgpu_session_->Run(Ort::RunOptions{nullptr}, webgpu_binding); // normal run + capturing
...
// post process output_buffer's content
...
// Update input_buffer's content
...
// Run again
webgpu_session_->Run(Ort::RunOptions{nullptr}, webgpu_binding); // replay()
...
// post process output_buffer's content
...
```