Round Egg 8: Compute shaders - Sphere deformation

2023-12-01

Pedro Burgos, Dominykas Jogela

In Round Egg 6 we introduced compute shaders, and in Round Egg 7 we explained our migration to bevy_app_compute. It is now time to actually use compute shaders to deform our sphere.

Thanks to the new library, we can greatly simplify the management of the compute tasks. We first need to implement bevy_app_compute::ComputeShader, which points to our compute shader file.

// Assumes `use bevy::prelude::*;`, `use bevy::render::render_resource::ShaderRef;`
// and `use bevy_app_compute::prelude::*;` are in scope.
#[derive(TypePath)]
struct SphereDeformationShader;

impl ComputeShader for SphereDeformationShader {
    fn shader() -> ShaderRef {
        "shaders/sphere_deformation.wgsl".into()
    }
}

Then we need to create a worker to handle the compute task. The bevy_app_compute library has a ComputeWorker trait for this.


#[derive(Resource)]
pub struct SphereDeformationWorker;

impl ComputeWorker for SphereDeformationWorker {
    fn build(world: &mut World) -> AppComputeWorker<Self> {
      // ... 
    }
}

For now, let's just forget about this. We know we need a shader that takes each point and moves it along the direction of its normal by a certain random amount. We will write it in WGSL, although WGPU also supports automatic transpilation of GLSL shaders to WGSL.

# The compute shader

We can define our inputs and outputs at the top; these will be provided later by the Rust code. For each input we need to specify the group and binding. We also define our own Vec3 type for convenience; as a side effect, it keeps the data tightly packed (an array<vec3<f32>> in a storage buffer has a 16-byte stride, whereas three plain f32 fields give a 12-byte stride that matches our vertex data). Specifically, we use a storage buffer with read and write access for our vertices, so that we can modify them in place.

Note: WGPU removes unused parameters from the compiled shader, which is slightly annoying when debugging; make sure to always assign the parameters to a variable in the main function.

@group(0) @binding(0)
var<uniform> strength: f32;

@group(0) @binding(1)
var<storage, read_write> vertices: array<Vec3>;

@group(0) @binding(2)
var<storage> normals: array<Vec3>;


struct Vec3 {
    x: f32,
    y: f32,
    z: f32,
}
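
On the Rust side we need a matching, tightly packed type to fill these buffers. The post uses a Vec3Shader type later on but never shows its definition, so the following is a hedged sketch of what it might look like, assuming bevy_app_compute accepts encase-style ShaderType values (as Bevy's own GPU buffers do); the name and new() constructor come from the later snippets, the rest is an assumption.

// Hypothetical reconstruction of the Vec3Shader type used later in this post.
use bevy::render::render_resource::ShaderType;

// Three plain f32 fields, matching the WGSL `Vec3` struct field-for-field.
#[derive(ShaderType, Clone, Copy, Debug, Default)]
pub struct Vec3Shader {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

impl Vec3Shader {
    pub fn new(x: f32, y: f32, z: f32) -> Self {
        Self { x, y, z }
    }
}

// Convenience conversion used when copying the results back into the mesh.
impl From<Vec3Shader> for bevy::math::Vec3 {
    fn from(v: Vec3Shader) -> Self {
        bevy::math::Vec3::new(v.x, v.y, v.z)
    }
}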

Then we have the entry point of the shader. We can access each vertex by its invocation ID. For now, we assign the new position of the vertex to the old position plus the normal vector scaled by strength.

@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
    // Bail out for invocation IDs beyond vmax, a generous upper bound on the vertex count.
    let vmax: u32 = 10000000u;
    if (invocation_id.x >= vmax) {
      return;
    }

    // seed and pos are not used yet; they will feed the noise function below.
    let seed = 1.0;
    let pos = vec3(vertices[invocation_id.x].x, vertices[invocation_id.x].y, vertices[invocation_id.x].z);

    vertices[invocation_id.x].x = vertices[invocation_id.x].x + (normals[invocation_id.x].x * strength);
    vertices[invocation_id.x].y = vertices[invocation_id.x].y + (normals[invocation_id.x].y * strength);
    vertices[invocation_id.x].z = vertices[invocation_id.x].z + (normals[invocation_id.x].z * strength);
}

This works, but the only effect is to enlarge the sphere. In order to deform the sphere we need some randomness. In our case we chose to use Perlin noise, which produces a smooth gradient of values; this kind of noise is often used for procedural generation of textures.

Perlin noise example displayed as a 2D texture

The implementation of Perlin noise is well known; in particular, we found noisy_bevy to be the best available implementation in Rust and WGSL.

let n = fbm_simplex_3d_seeded(pos, 120, 1.5, 0.5, seed);

The function above takes a position (the position of the vertex), a seed, and a few other parameters (octaves, lacunarity and gain) that tweak the smoothness of the noise. It is important to use a seeded noise function so that the vertices are deformed in a consistent way.

We simply add this noise to the position of the vertex, and we have our deformation shader.

@compute @workgroup_size(32) 
fn main(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
    // ...

    let n = fbm_simplex_3d_seeded(pos, 120, 1.5, 0.5, seed);

    vertices[invocation_id.x].x = vertices[invocation_id.x].x + (normals[invocation_id.x].x * n * strength);
    vertices[invocation_id.x].y = vertices[invocation_id.x].y + (normals[invocation_id.x].y * n * strength);
    vertices[invocation_id.x].z = vertices[invocation_id.x].z + (normals[invocation_id.x].z * n * strength);
}

# Putting it all together

Now that our compute shader has taken shape, we can update the Rust code to call it. We only want to run this shader once, so we call one_shot() when building the worker.

    fn build(world: &mut World) -> AppComputeWorker<Self> {
        let size = calc_num_vertices_for_res(crate::SPHERE_RESOLUTION as u32) as usize;
        let vertices_buffer = vec![Vec3Shader::new(0., 0., 0.); size];
        let normals_buffer = vec![Vec3Shader::new(0., 0., 0.); size];

        AppComputeWorkerBuilder::new(world)
            .add_uniform("strength", &0.2f32) // the shader expects an f32
            .add_staging("vertices", &vertices_buffer) // read and write
            .add_storage("normals", &normals_buffer)
            .add_pass::<SphereDeformationShader>(
                [size as u32, 1, 1], // forget about this for now
                &["strength", "vertices", "normals"],
            )
            .one_shot()
            .build()
    }

So far we have defined our shader, but we have not yet invoked it. We decided to make one system to launch the compute shader and another one to check whether it has finished. Although not an optimal solution, it gets the job done quickly. To do this, we first need to define some sort of app state that controls the activation of the systems.

#[derive(Debug, Clone, Eq, PartialEq, Hash, States, Default)]
pub enum AppState {
    #[default]
    Base,               // Start
    AwaitingPositions,  // Waiting for the compute shader to finish
    PositionsFinished,  // The compute shader has finished
}
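
One detail the post does not show is where AppState gets registered; Bevy only tracks states that have been added to the App. Assuming Bevy 0.12, this is a single call during app or plugin setup (the placement below is our assumption, not the post's):

// Hypothetical placement; the post never shows where AppState is registered.
// In Bevy 0.12 this would typically live in a plugin's build() or in main().
app.add_state::<AppState>();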

To run the compute shader, we first populate its parameters. We can obtain the vertices and normals from the mesh itself (Mesh::ATTRIBUTE_POSITION and Mesh::ATTRIBUTE_NORMAL), but we cannot write directly to the mesh. We must remember that GPU and CPU memory are separate: even though Bevy handles the data transfer between the Main World and the Render World, we cannot simply pass a pointer to the GPU.

fn compute_positions(
    sphere_handle: Query<&Handle<Mesh>, With<SphereDeformation>>,
    meshes: ResMut<Assets<Mesh>>,
    mut compute_worker: ResMut<AppComputeWorker<SphereDeformationWorker>>,
    mut next_state: ResMut<NextState<AppState>>,
) {
    // This is a quick and dirty way to convert between types
    let handle: Handle<Mesh> = sphere_handle.get_single().expect("no sphere").clone();
    let mesh = meshes.get(&handle).expect("no mesh");
    let normals = mesh
        .attribute(Mesh::ATTRIBUTE_NORMAL)
        .unwrap()
        .as_float3()
        .unwrap();
    let vertices = mesh
        .attribute(Mesh::ATTRIBUTE_POSITION)
        .unwrap()
        .as_float3()
        .unwrap();

    compute_worker.write_slice("vertices", &vertices);
    compute_worker.write_slice("normals", &normals);

    // Queue the compute shader
    compute_worker.execute();

    // Move to the next stage
    next_state.set(AppState::AwaitingPositions);
}

In order to read the result from the compute shader, we must poll the worker until it finishes. This is probably not ideal; hopefully in the future we will simply have a future that we can await, or some sort of callback. To do this, we create another system, deform_spheres, that will only be active in AppState::AwaitingPositions. Its purpose is to read the new vertices from the compute worker and update the mesh.

pub fn deform_spheres(
    mut sphere_handle: Query<&mut Handle<Mesh>, With<SphereDeformation>>,
    mut meshes: ResMut<Assets<Mesh>>,
    compute_worker: Res<AppComputeWorker<SphereDeformationWorker>>,
    mut next_state: ResMut<NextState<AppState>>,
) {
    if !compute_worker.ready() {
        println!("deform spheres not ready yet");
        return;
    };

    let handle: Handle<Mesh> = sphere_handle.get_single_mut().expect("no sphere").clone();
    let mesh = meshes.get_mut(&handle).expect("no mesh");

    // Read the new vertices from the compute worker
    let new_vertices = compute_worker.read_vec::<Vec3Shader>("vertices");

    // Type conversion, dirty way
    let new_vertices: Vec<Vec3> = new_vertices.into_iter().map(|v| v.into()).collect();

    // Update the mesh and move to the next stage
    mesh.insert_attribute(Mesh::ATTRIBUTE_POSITION, new_vertices);
    next_state.set(AppState::PositionsFinished);
}

Finally, we modify our SphereDeformationPlugin to schedule the systems in the appropriate AppStates.

Note: As of this article, Bevy is implementing better ways to handle one-off systems, but we have not yet tested them.

impl Plugin for SphereDeformationPlugin {
    fn build(&self, app: &mut App) {
        if !app.is_plugin_added::<AppComputePlugin>() {
            app.add_plugins(AppComputePlugin);
        }
        app.add_plugins(AppComputeWorkerPlugin::<SphereDeformationWorker>::default())
            .add_systems(Update, compute_positions.run_if(in_state(AppState::Base)))
            .add_systems(Update, deform_spheres.run_if(in_state(AppState::AwaitingPositions)));
    }
}

Our planet after applying the compute shader

# I saw that "forget about this for now" comment

Yes, we did not explain the AppComputeWorkerBuilder::add_pass function. Its first argument is the number of workgroups to dispatch in each dimension; combined with the @workgroup_size(32) declared in the shader, it determines how many invocations of the compute shader run in parallel. Because we dispatch [size, 1, 1] workgroups of 32 invocations each, we launch many more invocations than we have vertices, which is also why the shader body starts with an index guard. Here is a wonderful visualization of workgroups by Freya Holmer.

Visualization of Compute Shaders by Freya Holmer (https://acegikmo.com)
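
Finally, a rough sketch that is not part of the original code: instead of dispatching [size, 1, 1] workgroups, the dispatch could be derived from the vertex count with a ceiling division, so that each vertex gets roughly one invocation. The constant 32 must match the @workgroup_size(32) declared in the shader.

// Sketch only: size the dispatch so that num_workgroups * WORKGROUP_SIZE >= num_vertices.
const WORKGROUP_SIZE: u32 = 32; // must match @workgroup_size(32) in the WGSL shader

fn dispatch_size(num_vertices: u32) -> [u32; 3] {
    // Ceiling division: add one extra, partially filled workgroup if needed.
    let num_workgroups = (num_vertices + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
    [num_workgroups, 1, 1]
}

Sized like this, the bounds check in the shader only has to discard the few spare invocations of the last workgroup.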