gausplat_renderer/render/gaussian_3d/jit/kernel/rank/
mod.rspub use super::*;
use burn::tensor::ops::IntTensorOps;
use bytemuck::{bytes_of, from_bytes, Pod, Zeroable};
/// Scalar arguments for the rank kernel.
///
/// [`main`] uploads this struct with `client.create(bytes_of(&arguments))` as
/// the first binding (index 0) of the kernel. `#[repr(C)]` plus `Pod` fix the
/// byte layout that is uploaded.
/// NOTE(review): field order/types presumably mirror the argument struct in
/// `kernel.wgsl` — keep them in sync (confirm against the shader).
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct Arguments {
    /// Number of points. [`main`] dispatches
    /// `point_count.div_ceil(GROUP_SIZE)` cubes based on this value.
    pub point_count: u32,
    /// Only forwarded to the kernel; not used on the host.
    /// NOTE(review): presumably the number of tiles along x — confirm.
    pub tile_count_x: u32,
}
/// Device tensors consumed by [`main`].
///
/// All fields except `tile_point_count` are bound to the kernel, in the order
/// `depths`, `point_tile_bounds`, `radii`, `tile_touched_offsets` (binding
/// indices 1 through 4; index 0 is the uploaded `Arguments`).
#[derive(Debug, Clone)]
pub struct Inputs<R>
where
    R: JitRuntime,
{
    /// Kernel binding 1. Its `client` and `device` are also the ones [`main`]
    /// uses for the readback, the output allocations and the dispatch.
    pub depths: JitTensor<R>,
    /// Kernel binding 2.
    pub point_tile_bounds: JitTensor<R>,
    /// Kernel binding 3.
    pub radii: JitTensor<R>,
    /// Not bound to the kernel: [`main`] reads it back to the host as a single
    /// `u32`, and that value becomes the length of both output tensors.
    pub tile_point_count: JitTensor<R>,
    /// Kernel binding 4.
    pub tile_touched_offsets: JitTensor<R>,
}
/// Device tensors returned by [`main`].
///
/// Both are created with `int_empty` (contents not initialized on the host).
/// Their length equals the value read back from `Inputs::tile_point_count`.
/// NOTE(review): assumes the kernel writes every entry — confirm in the shader.
#[derive(Debug, Clone)]
pub struct Outputs<R>
where
    R: JitRuntime,
{
    /// Kernel binding 5.
    /// NOTE(review): presumably a point index per (tile, point) entry — confirm.
    pub point_indices: JitTensor<R>,
    /// Kernel binding 6.
    /// NOTE(review): presumably ordering keys paired with `point_indices` — confirm.
    pub point_orders: JitTensor<R>,
}
/// Width of each dispatched cube (`GROUP_SIZE x 1 x 1`). [`main`] also divides
/// `point_count` by this value, rounding up, to get the cube count.
/// NOTE(review): should match the workgroup size declared in `kernel.wgsl`.
pub const GROUP_SIZE: u32 = 256;
/// 2^16. Not referenced by this module's host code.
/// NOTE(review): by its name, presumably the maximum supported tile count — confirm.
pub const TILE_COUNT_MAX: u32 = 1_u32 << 16;
/// Not referenced by this module's host code.
/// NOTE(review): meaning not visible here (possibly a sizing factor for the
/// per-tile point buffer) — confirm at its use sites.
pub const FACTOR_TILE_POINT_COUNT: u32 = 65;
/// Launches the rank compute kernel (`kernel.wgsl`) and returns its two
/// freshly allocated output tensors.
///
/// Steps:
/// 1. Reads `inputs.tile_point_count` back to the host as a single `u32`.
///    That value is the length of both output tensors.
/// 2. Allocates `point_indices` and `point_orders` with `int_empty`.
///    NOTE(review): assumes the kernel writes every entry — confirm.
/// 3. Dispatches `arguments.point_count.div_ceil(GROUP_SIZE)` cubes of
///    `GROUP_SIZE x 1 x 1` invocations. The bindings are, in order:
///    `arguments`, `depths`, `point_tile_bounds`, `radii`,
///    `tile_touched_offsets`, `point_indices`, `point_orders`.
///
/// `inputs.tile_point_count` is only read on the host; it is not bound to the
/// kernel.
///
/// # Panics
///
/// Panics (through `bytemuck::from_bytes`) if the bytes read back for
/// `inputs.tile_point_count` are not exactly one suitably aligned `u32`.
pub fn main<R, F, I, B>(arguments: Arguments, inputs: Inputs<R>) -> Outputs<R>
where
    R: JitRuntime,
    F: FloatElement,
    I: IntElement,
    B: BoolElement,
{
    impl_kernel_source!(Kernel, "kernel.wgsl");

    // All host-side work goes through the client/device that own `depths`.
    let client = &inputs.depths.client;
    let device = &inputs.depths.device;

    // Synchronous host readback of the total entry count; it sizes the outputs.
    let count_binding = inputs.tile_point_count.handle.to_owned().binding();
    let count_bytes = client.read([count_binding].into());
    let output_len = *from_bytes::<u32>(&count_bytes[0]) as usize;

    let point_indices = JitBackend::<R, F, I, B>::int_empty([output_len].into(), device);
    let point_orders = JitBackend::<R, F, I, B>::int_empty([output_len].into(), device);

    // Round up so the final, partially filled group still covers the last points.
    let cube_count = CubeCount::Static(arguments.point_count.div_ceil(GROUP_SIZE), 1, 1);
    let cube_dim = CubeDim { x: GROUP_SIZE, y: 1, z: 1 };

    // Order is significant: it must line up with the bindings in `kernel.wgsl`.
    let bindings = vec![
        client.create(bytes_of(&arguments)).binding(),
        inputs.depths.handle.binding(),
        inputs.point_tile_bounds.handle.binding(),
        inputs.radii.handle.binding(),
        inputs.tile_touched_offsets.handle.binding(),
        // Bind clones of the output handles so the returned tensors keep theirs.
        point_indices.handle.to_owned().binding(),
        point_orders.handle.to_owned().binding(),
    ];

    client.execute(Box::new(SourceKernel::new(Kernel, cube_dim)), cube_count, bindings);

    Outputs { point_indices, point_orders }
}