gausplat_renderer/render/gaussian_3d/jit/kernel/rank/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
//! Ranking the points.

pub use super::*;

use burn::tensor::ops::IntTensorOps;
use bytemuck::{bytes_of, from_bytes, Pod, Zeroable};

/// Arguments.
///
/// Scalar parameters of the ranking kernel. The struct is `#[repr(C)]` and
/// [`Pod`], and [`main`] uploads it as raw bytes (`bytes_of`) as the first
/// kernel binding, so the field order and types define the byte layout seen
/// by the kernel.
// NOTE(review): the layout presumably has to mirror the arguments struct
// declared in `kernel.wgsl` — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct Arguments {
    /// `P`
    ///
    /// The point count. [`main`] also derives the number of launched groups
    /// from it: `point_count.div_ceil(GROUP_SIZE)`.
    pub point_count: u32,
    /// `I_x / T_x`
    ///
    /// The tile count along the x axis.
    pub tile_count_x: u32,
}

/// Inputs.
///
/// Tensors consumed by [`main`]. `P` is the point count
/// (see [`Arguments::point_count`]) and `T` is the value stored in
/// [`Inputs::tile_point_count`].
///
/// The client and the device of [`Inputs::depths`] are the ones used for
/// reading back, allocating the outputs, and launching the kernel.
// NOTE(review): all tensors are presumably expected to live on that same
// client and device — confirm against the callers.
#[derive(Clone, Debug)]
pub struct Inputs<R: JitRuntime> {
    /// `[P]`
    pub depths: JitTensor<R>,
    /// `[P, 4]`
    pub point_tile_bounds: JitTensor<R>,
    /// `[P]`
    pub radii: JitTensor<R>,
    /// `T`
    ///
    /// This tensor is not bound to the kernel. [`main`] reads its first
    /// buffer back as a single `u32` and uses that value as the length of
    /// both outputs.
    // NOTE(review): presumably the total count of point-tile pairs — confirm.
    pub tile_point_count: JitTensor<R>,
    /// `[P]`
    pub tile_touched_offsets: JitTensor<R>,
}

/// Outputs.
///
/// Tensors produced by [`main`]. Both are integer tensors allocated with
/// `int_empty` with `T` elements, where `T` is the value read back from
/// [`Inputs::tile_point_count`], and both are bound to the kernel after the
/// input tensors.
// NOTE(review): the element at index `i` of one tensor presumably pairs with
// the element at index `i` of the other (key and value) — confirm in
// `kernel.wgsl`.
#[derive(Clone, Debug)]
pub struct Outputs<R: JitRuntime> {
    /// `[T]`
    pub point_indices: JitTensor<R>,
    /// `[T]`
    pub point_orders: JitTensor<R>,
}

/// Group size.
///
/// Used by [`main`] both as the `x` extent of the cube dimension and as the
/// divisor that turns the point count into the number of launched groups.
// NOTE(review): presumably has to match the workgroup size declared in
// `kernel.wgsl` — confirm.
pub const GROUP_SIZE: u32 = 256;
/// Maximum of `(I_y / T_y) * (I_x / T_x)`
///
/// That is, the upper bound of the total tile count: `1 << 16 = 65536`.
// NOTE(review): not enforced anywhere in this module; presumably checked by
// the caller — confirm.
pub const TILE_COUNT_MAX: u32 = 1 << 16;
/// `E[T / P]`
///
/// The expected ratio of `T` to the point count `P`.
// NOTE(review): not referenced in this module; presumably used elsewhere to
// estimate `T` from `P` — confirm.
pub const FACTOR_TILE_POINT_COUNT: u32 = 65;

/// Rank the points by their tile index and depth.
///
/// The steps are:
///
/// 1. Reading the value `T` of [`Inputs::tile_point_count`] back
///    through the client of [`Inputs::depths`].
/// 2. Allocating both `[T]` integer outputs on the device of
///    [`Inputs::depths`].
/// 3. Launching `kernel.wgsl` with `point_count.div_ceil(GROUP_SIZE)`
///    groups of [`GROUP_SIZE`] along `x`.
///
/// The kernel bindings are, in order: the arguments, `depths`,
/// `point_tile_bounds`, `radii`, `tile_touched_offsets`,
/// `point_indices`, and `point_orders`.
///
/// # Panics
///
/// `bytemuck::from_bytes` panics if the bytes read back from
/// [`Inputs::tile_point_count`] cannot be viewed as one `u32`.
pub fn main<R: JitRuntime, F: FloatElement, I: IntElement, B: BoolElement>(
    arguments: Arguments,
    inputs: Inputs<R>,
) -> Outputs<R> {
    impl_kernel_source!(Kernel, "kernel.wgsl");

    // Specifying the parameters

    let client = &inputs.depths.client;
    let device = &inputs.depths.device;

    // The handle is copied, because the input tensor is only read back here
    let count_binding = inputs.tile_point_count.handle.to_owned().binding();
    let count_bytes = client.read([count_binding].into());
    let tile_point_count = *from_bytes::<u32>(&count_bytes[0]) as usize;

    // Both outputs share the same shape `[T]`
    let allocate_output =
        || JitBackend::<R, F, I, B>::int_empty([tile_point_count].into(), device);
    // [T]
    let point_indices = allocate_output();
    // [T]
    let point_orders = allocate_output();

    // Launching the kernel

    let cube_dim = CubeDim {
        x: GROUP_SIZE,
        y: 1,
        z: 1,
    };
    let group_count = arguments.point_count.div_ceil(GROUP_SIZE);
    let cube_count = CubeCount::Static(group_count, 1, 1);

    // The order of the bindings should not be changed.
    // The output handles are copied, because the tensors are returned below.
    let bindings = vec![
        client.create(bytes_of(&arguments)).binding(),
        inputs.depths.handle.binding(),
        inputs.point_tile_bounds.handle.binding(),
        inputs.radii.handle.binding(),
        inputs.tile_touched_offsets.handle.binding(),
        point_indices.handle.to_owned().binding(),
        point_orders.handle.to_owned().binding(),
    ];

    client.execute(
        Box::new(SourceKernel::new(Kernel, cube_dim)),
        cube_count,
        bindings,
    );

    Outputs {
        point_indices,
        point_orders,
    }
}