// gausplat_renderer/render/gaussian_3d/jit/kernel/segment/mod.rs
pub use super::*;
use burn::tensor::ops::IntTensorOps;
use bytemuck::{Pod, Zeroable};
/// Scalar parameters accepted by [`main`].
///
/// The two counts size the leading axes of the tensor that [`main`]
/// allocates for its output, whose shape is
/// `[tile_count_y, tile_count_x, 2]`.
///
/// The type is `#[repr(C)]` and derives [`Pod`] and [`Zeroable`], so the
/// field order below is part of its layout and must not be rearranged.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct Arguments {
    /// Tile count along x; becomes the second axis of the output tensor.
    pub tile_count_x: u32,
    /// Tile count along y; becomes the first axis of the output tensor.
    pub tile_count_y: u32,
}
/// Tensors consumed by [`main`].
///
/// The compute client and the device used for every allocation and kernel
/// launch in [`main`] are taken from `point_orders`.
#[derive(Clone, Debug)]
pub struct Inputs<R: JitRuntime> {
    /// Bound as the second buffer of the second kernel.
    ///
    /// NOTE(review): presumably the sorted per-point tile keys — confirm
    /// against the caller and `kernel.2.wgsl`.
    pub point_orders: JitTensor<R>,
    /// Bound as the first buffer of both kernels.
    ///
    /// NOTE(review): presumably a single-element count of tile-point pairs —
    /// confirm against the caller.
    pub tile_point_count: JitTensor<R>,
}
/// Tensors produced by [`main`].
#[derive(Clone, Debug)]
pub struct Outputs<R: JitRuntime> {
    /// Integer tensor of shape `[tile_count_y, tile_count_x, 2]`.
    ///
    /// It is allocated zero-filled and then bound as the last buffer of the
    /// second kernel.
    ///
    /// NOTE(review): the trailing axis of length 2 presumably holds a
    /// `(start, end)` pair per tile — confirm in `kernel.2.wgsl`.
    pub tile_point_ranges: JitTensor<R>,
}
/// Number of invocations along x in each cube launched for the second kernel.
pub const GROUP_SIZE: u32 = 256;

/// [`GROUP_SIZE`] squared (65 536).
///
/// NOTE(review): not referenced by the launcher in this file; presumably
/// mirrored by a constant inside the WGSL sources — confirm there.
pub const GROUP_SIZE2: u32 = GROUP_SIZE.pow(2);
pub fn main<R: JitRuntime, F: FloatElement, I: IntElement, B: BoolElement>(
arguments: Arguments,
inputs: Inputs<R>,
) -> Outputs<R> {
impl_kernel_source!(Kernel1, "kernel.1.wgsl");
impl_kernel_source!(Kernel2, "kernel.2.wgsl");
let client = &inputs.point_orders.client;
let device = &inputs.point_orders.device;
let group_count = JitBackend::<R, F, I, B>::int_empty([3].into(), device);
let tile_point_ranges = JitBackend::<R, F, I, B>::int_zeros(
[
arguments.tile_count_y as usize,
arguments.tile_count_x as usize,
2,
]
.into(),
device,
);
client.execute(
Box::new(SourceKernel::new(Kernel1, CubeDim { x: 1, y: 1, z: 1 })),
CubeCount::Static(1, 1, 1),
vec![
inputs.tile_point_count.handle.to_owned().binding(),
group_count.handle.to_owned().binding(),
],
);
client.execute(
Box::new(SourceKernel::new(
Kernel2,
CubeDim {
x: GROUP_SIZE,
y: 1,
z: 1,
},
)),
CubeCount::Dynamic(group_count.handle.binding()),
vec![
inputs.tile_point_count.handle.binding(),
inputs.point_orders.handle.binding(),
tile_point_ranges.handle.to_owned().binding(),
],
);
Outputs { tile_point_ranges }
}