// gausplat_renderer/render/gaussian_3d/jit/kernel/transform_backward/mod.rs
//! Transforming the points (backward).

pub use super::*;

use burn::tensor::ops::FloatTensorOps;
use bytemuck::{bytes_of, Pod, Zeroable};

/// Arguments.
///
/// Plain-old-data block of scalar parameters for the kernel. In `main` it is
/// turned into bytes with `bytes_of` and uploaded as the first binding.
///
/// ## Layout
///
/// `repr(C)` with 16-byte alignment:
///
/// | Bytes       | Fields                                             |
/// |-------------|----------------------------------------------------|
/// | `0..32`     | `colors_sh_degree_max` through `view_bound_y`      |
/// | `32..44`    | `view_position`                                    |
/// | `44..48`    | `_padding_1`                                       |
/// | `48..112`   | `view_transform`                                   |
///
/// The explicit padding field leaves no implicit padding bytes, and the total
/// size (112 bytes) is a multiple of the alignment.
///
/// NOTE(review): the field order presumably has to agree with the arguments
/// struct declared in `kernel.wgsl` — confirm before reordering anything.
#[repr(C, align(16))]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct Arguments {
    /// Maximum degree of the `colors_sh` coefficients: `(0 ~ 3)`
    pub colors_sh_degree_max: u32,
    /// Focal length along x: `f_x <- I_x / tan(Fov_x / 2) / 2`
    pub focal_length_x: f32,
    /// Focal length along y: `f_y <- I_y / tan(Fov_y / 2) / 2`
    pub focal_length_y: f32,
    /// Half of the image size along x: `I_x / 2`
    pub image_size_half_x: f32,
    /// Half of the image size along y: `I_y / 2`
    pub image_size_half_y: f32,
    /// Point count `P`.
    ///
    /// `main` uses it both for the output tensor shapes and for the number of
    /// groups to dispatch.
    pub point_count: u32,
    /// View bound along x: `tan(Fov_x / 2) * (C_f + 1)`
    pub view_bound_x: f32,
    /// View bound along y: `tan(Fov_y / 2) * (C_f + 1)`
    pub view_bound_y: f32,
    /// View position: `[3]`
    pub view_position: [f32; 3],
    /// Padding (4 bytes) that moves `view_transform` to byte offset 48.
    pub _padding_1: [u32; 1],
    /// View transform: `[3 (+ 1), 3 + 1]`
    pub view_transform: [[f32; 4]; 4],
}

/// Inputs.
///
/// Tensors read by the kernel. `main` binds their handles in exactly the
/// order of the fields below, right after the arguments buffer, and takes the
/// client and the device from `colors_rgb_3d_grad`.
///
/// `P` is the point count; `M` is the number of `colors_sh` coefficients per
/// point and per channel.
#[derive(Clone, Debug)]
pub struct Inputs<R: JitRuntime> {
    /// Gradient with respect to `colors_rgb_3d`: `[P, 3]`
    pub colors_rgb_3d_grad: JitTensor<R>,
    /// `[P, M * 3]` <- `[P, M, 3]`
    pub colors_sh: JitTensor<R>,
    /// `[P, 3]`
    pub conics: JitTensor<R>,
    /// Gradient with respect to `conics`: `[P, 3]`
    pub conics_grad: JitTensor<R>,
    /// `[P]`
    pub depths: JitTensor<R>,
    /// `[P, 3]`
    pub is_colors_rgb_3d_not_clamped: JitTensor<R>,
    /// Gradient with respect to `positions_2d`: `[P, 2]`
    pub positions_2d_grad: JitTensor<R>,
    /// `[P, 3]`
    pub positions_3d: JitTensor<R>,
    /// `[P, 2]`
    pub positions_3d_in_normalized: JitTensor<R>,
    /// `[P]`
    pub radii: JitTensor<R>,
    /// `[P, 4]`
    pub rotations: JitTensor<R>,
    /// `[P, 3, 3]`
    pub rotations_matrix: JitTensor<R>,
    /// `[P, 3]`
    pub scalings: JitTensor<R>,
}

/// Outputs.
///
/// Gradient tensors returned by `main`. Each one is allocated there filled
/// with zeros, and its handle is bound after the inputs in the order of the
/// fields below.
#[derive(Clone, Debug)]
pub struct Outputs<R: JitRuntime> {
    /// `[P, M * 3]` <- `[P, M, 3]`
    ///
    /// `main` allocates it as `[P, 48]`.
    pub colors_sh_grad: JitTensor<R>,
    /// `[P]`
    pub positions_2d_grad_norm: JitTensor<R>,
    /// `[P, 3]`
    pub positions_3d_grad: JitTensor<R>,
    /// `[P, 4]`
    pub rotations_grad: JitTensor<R>,
    /// `[P, 3]`
    pub scalings_grad: JitTensor<R>,
}

/// `C_f`
///
/// The constant that appears in the documented formulas of
/// `Arguments::view_bound_x` and `Arguments::view_bound_y`.
pub const FILTER_LOW_PASS: f64 = 0.3;
/// Group size.
///
/// Used by `main` as the `x` extent of the `CubeDim` and as the divisor when
/// computing how many groups cover all the points.
/// NOTE(review): presumably has to equal the workgroup size declared in
/// `kernel.wgsl` — confirm before changing.
pub const GROUP_SIZE: u32 = 256;

/// Transforming the points (backward).
///
/// Allocates the gradient tensors filled with zeros, launches `kernel.wgsl`
/// over `ceil(P / GROUP_SIZE)` groups of `GROUP_SIZE` along x, and returns the
/// gradient tensors as [`Outputs`].
///
/// The client and the device are those of `inputs.colors_rgb_3d_grad`.
pub fn main<R: JitRuntime, F: FloatElement, I: IntElement, B: BoolElement>(
    arguments: Arguments,
    inputs: Inputs<R>,
) -> Outputs<R> {
    impl_kernel_source!(Kernel, "kernel.wgsl");

    // Specifying the parameters

    let client = &inputs.colors_rgb_3d_grad.client;
    let device = &inputs.colors_rgb_3d_grad.device;
    // P
    let point_count = arguments.point_count as usize;

    // Allocating the outputs (all zeros)

    // [P, M * 3] <- [P, M, 3]
    let colors_sh_grad =
        JitBackend::<R, F, I, B>::float_zeros([point_count, 48].into(), device);
    // [P]
    let positions_2d_grad_norm =
        JitBackend::<R, F, I, B>::float_zeros([point_count].into(), device);
    // [P, 3]
    let positions_3d_grad =
        JitBackend::<R, F, I, B>::float_zeros([point_count, 3].into(), device);
    // [P, 4]
    let rotations_grad =
        JitBackend::<R, F, I, B>::float_zeros([point_count, 4].into(), device);
    // [P, 3]
    let scalings_grad =
        JitBackend::<R, F, I, B>::float_zeros([point_count, 3].into(), device);

    // Launching the kernel

    let kernel = Box::new(SourceKernel::new(
        Kernel,
        CubeDim {
            x: GROUP_SIZE,
            y: 1,
            z: 1,
        },
    ));
    // ceil(P / GROUP_SIZE)
    let group_count = arguments.point_count.div_ceil(GROUP_SIZE);

    // The order is: the arguments, the inputs, the outputs.
    // The output handles are cloned, since the tensors are returned below.
    let bindings = vec![
        client.create(bytes_of(&arguments)).binding(),
        inputs.colors_rgb_3d_grad.handle.binding(),
        inputs.colors_sh.handle.binding(),
        inputs.conics.handle.binding(),
        inputs.conics_grad.handle.binding(),
        inputs.depths.handle.binding(),
        inputs.is_colors_rgb_3d_not_clamped.handle.binding(),
        inputs.positions_2d_grad.handle.binding(),
        inputs.positions_3d.handle.binding(),
        inputs.positions_3d_in_normalized.handle.binding(),
        inputs.radii.handle.binding(),
        inputs.rotations.handle.binding(),
        inputs.rotations_matrix.handle.binding(),
        inputs.scalings.handle.binding(),
        colors_sh_grad.handle.clone().binding(),
        positions_2d_grad_norm.handle.clone().binding(),
        positions_3d_grad.handle.clone().binding(),
        rotations_grad.handle.clone().binding(),
        scalings_grad.handle.clone().binding(),
    ];

    client.execute(kernel, CubeCount::Static(group_count, 1, 1), bindings);

    Outputs {
        colors_sh_grad,
        positions_2d_grad_norm,
        positions_3d_grad,
        rotations_grad,
        scalings_grad,
    }
}