Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ features = [
"x11",
"bevy_winit",
"bevy_window",
"tonemapping_luts",
"ktx2",
"zstd",
]

[[example]]
Expand All @@ -81,3 +84,6 @@ name = "one_shot"

[[example]]
name = "boids"

[[example]]
name = "shared_storage"
17 changes: 17 additions & 0 deletions assets/shaders/shared_storage_compute.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@group(0) @binding(0)
var<uniform> time: f32;

@group(0) @binding(1)
var<storage, read_write> heights: array<f32>;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let idx = invocation_id.x;
if (idx < 121u) {
let grid_size = 11u;
let x = f32(idx % grid_size);
let z = f32(idx / grid_size);

heights[idx] = sin(x * 0.5 + time) * cos(z * 0.5 + time) * 0.1;
}
}
58 changes: 58 additions & 0 deletions assets/shaders/shared_storage_vertex.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#import bevy_pbr::{
mesh_functions,
view_transformations::position_world_to_clip,
forward_io::{Vertex, VertexOutput},
}

@group(2) @binding(100)
var<storage, read> heights: array<f32>;

@vertex
fn vertex(vertex: Vertex) -> VertexOutput {
var out: VertexOutput;

let local_pos = vertex.position;

let grid_size = 11u;
let x_index = u32(vertex.uv.x * f32(grid_size - 1u));
let z_index = u32(vertex.uv.y * f32(grid_size - 1u));
let height_index = z_index * grid_size + x_index;

let height = heights[height_index];

var modified_position = local_pos;
modified_position.y += height;

let world_from_local = mesh_functions::get_world_from_local(vertex.instance_index);
out.world_position = mesh_functions::mesh_position_local_to_world(
world_from_local,
vec4<f32>(modified_position, 1.0)
);
out.position = position_world_to_clip(out.world_position.xyz);

// compute normals from adjacent heights
let hL = heights[max(0u, height_index - 1u)];
let hR = heights[min(120u, height_index + 1u)];
let hD = heights[max(0u, height_index - grid_size)];
let hU = heights[min(120u, height_index + grid_size)];

let normal = normalize(vec3<f32>(hL - hR, 2.0, hD - hU));

out.world_normal = mesh_functions::mesh_normal_local_to_world(
normal,
vertex.instance_index
);

out.uv = vertex.uv;
out.instance_index = vertex.instance_index;

#ifdef VERTEX_TANGENTS
out.world_tangent = mesh_functions::mesh_tangent_local_to_world(
world_from_local,
vertex.tangent,
vertex.instance_index
);
#endif

return out;
}
1 change: 1 addition & 0 deletions examples/boids/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl ComputeWorker for BoidWorker {
.add_pass::<BoidsShader>(
[NUM_BOIDS / 64, 1, 1],
&["params", "boids_src", "boids_dst"],
&[],
)
.add_swap("boids_src", "boids_dst")
.build()
Expand Down
1 change: 1 addition & 0 deletions examples/game_of_life/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ impl ComputeWorker for GameOfLifeWorker {
1,
],
&[SETTINGS_BUFFER, CELLS_IN_BUFFER, CELLS_OUT_BUFFER],
&[]
)
.add_swap(CELLS_IN_BUFFER, CELLS_OUT_BUFFER)
.build()
Expand Down
4 changes: 2 additions & 2 deletions examples/multi_pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ impl ComputeWorker for SimpleComputeWorker {
.add_uniform("value", &3.)
.add_storage("input", &[1., 2., 3., 4.])
.add_staging("output", &[0f32; 4])
.add_pass::<FirstPassShader>([4, 1, 1], &["value", "input", "output"]) // add each item + `value` from `input` to `output`
.add_pass::<SecondPassShader>([4, 1, 1], &["output"]) // multiply each element of `output` by itself
.add_pass::<FirstPassShader>([4, 1, 1], &["value", "input", "output"], &[]) // add each item + `value` from `input` to `output`
.add_pass::<SecondPassShader>([4, 1, 1], &["output"], &[]) // multiply each element of `output` by itself
.build();

// [1. + 3., 2. + 3., 3. + 3., 4. + 3.] = [4., 5., 6., 7.]
Expand Down
2 changes: 1 addition & 1 deletion examples/one_shot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl ComputeWorker for SimpleComputeWorker {
let worker = AppComputeWorkerBuilder::new(world)
.add_uniform("uni", &5.)
.add_staging("values", &[1., 2., 3., 4.])
.add_pass::<SimpleShader>([4, 1, 1], &["uni", "values"])
.add_pass::<SimpleShader>([4, 1, 1], &["uni", "values"], &[])
.one_shot()
.build();

Expand Down
116 changes: 116 additions & 0 deletions examples/shared_storage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use bevy::{
pbr::{ExtendedMaterial, MaterialExtension, StandardMaterial},
prelude::*,
render::{render_resource::AsBindGroup, storage::ShaderStorageBuffer},
};
use bevy_app_compute::prelude::*;
use tracing::info;

#[derive(Asset, AsBindGroup, Reflect, Debug, Clone)]
struct SharedStorageMaterial {
#[storage(100, read_only)]
heights: Handle<ShaderStorageBuffer>,
}

impl MaterialExtension for SharedStorageMaterial {
fn vertex_shader() -> ShaderRef {
"shaders/shared_storage_vertex.wgsl".into()
}
}

#[derive(TypePath)]
struct SharedStorageComputeShader;

impl ComputeShader for SharedStorageComputeShader {
fn shader() -> ShaderRef {
"shaders/shared_storage_compute.wgsl".into()
}
}

#[derive(Resource)]
struct SimpleSharedStorageComputeWorker;

impl ComputeWorker for SimpleSharedStorageComputeWorker {
fn build(world: &mut World) -> AppComputeWorker<Self> {
let heights = vec![0.0f32; 121];
let worker = AppComputeWorkerBuilder::new(world)
.add_uniform("time", &0.0f32)
// This allows us to retrieve a `Handle<ShaderStorageBuffer>` which gets shared between the compute shader and the material on the GPU.
.add_rw_storage_asset("heights", &heights)
.add_pass::<SharedStorageComputeShader>([121, 1, 1], &["time"], &["heights"])
.build();
worker
}
}

fn main() {
App::new()
.add_plugins(DefaultPlugins)
.add_plugins(MaterialPlugin::<
ExtendedMaterial<StandardMaterial, SharedStorageMaterial>,
>::default())
.add_plugins(AppComputePlugin)
.add_plugins(AppComputeWorkerPlugin::<SimpleSharedStorageComputeWorker>::default())
.add_systems(Startup, startup_system)
.add_systems(
Update,
(
spawn_mesh_system
.run_if(resource_added::<AppComputeWorker<SimpleSharedStorageComputeWorker>>),
update_shared_storage_time_system
.run_if(resource_exists::<AppComputeWorker<SimpleSharedStorageComputeWorker>>),
),
)
.run();
}

fn startup_system(mut commands: Commands) {
commands.spawn((
Camera3d::default(),
Transform::from_xyz(5.0, 5.0, 5.0).looking_at(Vec3::ZERO, Vec3::Y),
));

commands.spawn((
DirectionalLight {
shadows_enabled: true,
..default()
},
Transform::from_rotation(Quat::from_euler(EulerRot::XYZ, -0.5, -0.5, 0.0)),
));
}

fn spawn_mesh_system(
mut commands: Commands,
worker: Res<AppComputeWorker<SimpleSharedStorageComputeWorker>>,
mut meshes: ResMut<Assets<Mesh>>,
mut materials: ResMut<Assets<ExtendedMaterial<StandardMaterial, SharedStorageMaterial>>>,
) {
if let Some(heights_handle) = worker.get_storage_buffer_asset_handle("heights") {
let mesh = meshes.add(Plane3d::default().mesh().size(4.0, 4.0).subdivisions(10));

let material = materials.add(ExtendedMaterial {
base: StandardMaterial {
base_color: Color::srgb(0.3, 0.7, 0.3),
..default()
},
extension: SharedStorageMaterial {
heights: heights_handle.clone(),
},
});

commands.spawn((
Mesh3d(mesh),
MeshMaterial3d(material),
Transform::from_xyz(0.0, 0.0, 0.0),
));

info!("Terrain mesh spawned");
}
}

fn update_shared_storage_time_system(
time: Res<Time>,
mut worker: ResMut<AppComputeWorker<SimpleSharedStorageComputeWorker>>,
) {
worker.write("time", &time.elapsed_secs());
}
2 changes: 1 addition & 1 deletion examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl ComputeWorker for SimpleComputeWorker {
let worker = AppComputeWorkerBuilder::new(world)
.add_uniform("uni", &5.)
.add_staging("values", &[1., 2., 3., 4.])
.add_pass::<SimpleShader>([4, 1, 1], &["uni", "values"])
.add_pass::<SimpleShader>([4, 1, 1], &["uni", "values"], &[])
.build();

worker
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub type Result<T> = std::result::Result<T, Error>;
pub enum Error {
BufferNotFound(String),
StagingBufferNotFound(String),
StorageAssetHandleNotFound(String),
InvalidStep(String),
PipelinesEmpty,
PipelineNotReady,
Expand All @@ -17,6 +18,9 @@ impl std::fmt::Display for Error {
match self {
Error::BufferNotFound(name) => write!(f, "Buffer {name} not found."),
Error::StagingBufferNotFound(name) => write!(f, "Staging buffer {name} not found."),
Error::StorageAssetHandleNotFound(name) => {
write!(f, "Storage buffer asset handle {name} not found.")
}
Error::PipelinesEmpty => {
write!(f, "Missing pipelines. Have you added your shader plugins?")
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod plugin;
mod traits;
mod worker;
mod worker_builder;
mod storage_buffers;

/// Helper module to import most used elements.
pub mod prelude {
Expand Down
11 changes: 9 additions & 2 deletions src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use std::marker::PhantomData;

use bevy::{
prelude::*,
render::renderer::{RenderAdapter, RenderDevice},
render::{render_asset::RenderAssets, renderer::{RenderAdapter, RenderDevice}, storage::GpuShaderStorageBuffer, RenderApp},
};

use crate::{
extract_shaders, pipeline_cache::PipelineCache, traits::ComputeWorker, worker::AppComputeWorker,
extract_shaders, pipeline_cache::PipelineCache, storage_buffers::extract_gpu_storage_buffers_to_main, traits::ComputeWorker, worker::AppComputeWorker
};

/// The main plugin. Always include it if you want to use `bevy_app_compute`
Expand All @@ -29,6 +29,13 @@ impl Plugin for AppComputePlugin {
PipelineCache::process_pipeline_queue_system
.in_set(BevyEasyComputeSet::ExtractPipelines),
);

if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
render_app.add_systems(
ExtractSchedule,
extract_gpu_storage_buffers_to_main.run_if(resource_changed::<RenderAssets<GpuShaderStorageBuffer>>),
);
}
}
}

Expand Down
Loading