// gausplat_renderer/render/gaussian_3d/mod.rs

//! 3DGS rendering implementation.
//!
//! For more information, see:
//! 1. [3DGS survey](https://arxiv.org/abs/2401.03890).
//! 2. [JIT kernel API](jit::kernel).

pub mod backward;
pub mod forward;
pub mod jit;

pub use super::view::*;
pub use crate::{
    backend::{autodiff, Autodiff, AutodiffBackend, Backend},
    error::Error,
    spherical_harmonics::SH_DEGREE_MAX,
};
pub use burn::{
    config::Config,
    record::Record,
    tensor::{Int, Tensor},
};

use std::fmt;

/// 3DGS scene renderer.
///
/// An implementor provides the forward and backward rendering passes for
/// the backend `B`. Both passes are associated functions without a `self`
/// receiver, so the implementing type itself carries no per-call state.
pub trait Gaussian3dRenderer<B: Backend>: 'static + Send + Sized + fmt::Debug {
    /// Render the 3DGS scene (forward).
    ///
    /// ## Arguments
    ///
    /// - `input`: The [forward input](forward::RenderInput).
    /// - `view`: The [`View`] to render from.
    /// - `options`: The [rendering options](Gaussian3dRenderOptions).
    ///
    /// ## Returns
    ///
    /// The [forward output](forward::RenderOutput) on success,
    /// or an [`Error`] otherwise.
    fn render_forward(
        input: forward::RenderInput<B>,
        view: &View,
        options: &Gaussian3dRenderOptions,
    ) -> Result<forward::RenderOutput<B>, Error>;

    /// Render the 3DGS scene (backward).
    ///
    /// It computes the gradients from
    /// the [output in forward pass](forward::RenderOutput).
    ///
    /// ## Arguments
    ///
    /// - `state`: The [backward input](backward::RenderInput).
    /// - `colors_rgb_2d_grad`: Presumably the gradient with respect to
    ///   the rendered 2D colors (inferred from the name; confirm against
    ///   the implementations).
    ///
    /// ## Returns
    ///
    /// The [backward output](backward::RenderOutput). Unlike the forward
    /// pass, the signature has no error channel.
    fn render_backward(
        state: backward::RenderInput<B>,
        colors_rgb_2d_grad: B::FloatTensorPrimitive,
    ) -> backward::RenderOutput<B>;
}

/// 3DGS rendering options.
///
/// These options are passed by reference to
/// [`Gaussian3dRenderer::render_forward`].
#[derive(Config, Copy, Debug, PartialEq, Record)]
pub struct Gaussian3dRenderOptions {
    #[config(default = "SH_DEGREE_MAX")]
    /// The maximum degree of color in SH space.
    ///
    /// It should be no more than [`SH_DEGREE_MAX`].
    ///
    /// The configured default is [`SH_DEGREE_MAX`].
    pub colors_sh_degree_max: u32,
}

/// 3DGS rendering output.
#[derive(Clone)]
pub struct Gaussian3dRenderOutput<B: Backend> {
    /// 2D colors in RGB space.
    ///
    /// The shape is `[I_y, I_x, 3]`.
    /// - `I_y`: Image height.
    /// - `I_x`: Image width.
    ///
    /// It is the rendered image.
    pub colors_rgb_2d: Tensor<B, 3>,
    // TODO: THM
}

/// 3DGS rendering output (autodiff enabled).
#[derive(Clone)]
pub struct Gaussian3dRenderOutputAutodiff<AB: AutodiffBackend> {
    /// 2D colors in RGB space.
    ///
    /// The shape is `[I_y, I_x, 3]`.
    /// - `I_y`: Image height.
    /// - `I_x`: Image width.
    ///
    /// It is the rendered image.
    pub colors_rgb_2d: Tensor<AB, 3>,
    /// A reference tensor for retrieving the gradient norm of the 2D positions.
    ///
    /// Its gradient is the gradient norm of the 2D positions.
    ///
    /// The gradient shape is `[P]`.
    /// - `P`: Point count.
    ///
    /// ## Usage
    ///
    /// ```ignore
    /// use burn::backend::autodiff::grads::Gradients;
    ///
    /// let mut grads: Gradients = todo!();
    ///
    /// let positions_2d_grad_norm =
    ///     positions_2d_grad_norm_ref.grad_remove(&mut grads);
    /// ```
    pub positions_2d_grad_norm_ref: Tensor<AB, 1>,
    /// Visible radii of 3D Gaussians.
    ///
    /// The shape is `[P]`.
    /// - `P`: Point count.
    ///
    /// It is stored on `AB::InnerBackend` rather than on `AB`
    /// (presumably so that it is not tracked by autodiff; confirm).
    pub radii: Tensor<AB::InnerBackend, 1, Int>,
}

impl Default for Gaussian3dRenderOptions {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(not(test))]
impl<B: Backend> fmt::Debug for Gaussian3dRenderOutput<B> {
    /// Formats a compact summary of the output.
    ///
    /// Only the backend name and the dimensions of the rendered image are
    /// written; the tensor contents are not.
    fn fmt(
        &self,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        // The label embeds the backend name reported by `B::name()`.
        let type_label = format!("RenderOutput<{}>", B::name());
        let image_dims = self.colors_rgb_2d.dims();

        let mut summary = f.debug_struct(&type_label);
        summary.field("colors_rgb_2d.dims()", &image_dims);
        summary.finish()
    }
}

#[cfg(not(test))]
impl<AB: AutodiffBackend> fmt::Debug for Gaussian3dRenderOutputAutodiff<AB> {
    /// Formats a compact summary of the output.
    ///
    /// Only the backend name and tensor dimensions are written; the tensor
    /// contents are not.
    fn fmt(
        &self,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        let type_label = format!("RenderOutputAutodiff<{}>", AB::name());
        let image_dims = self.colors_rgb_2d.dims();
        // The dimensions of `radii` are also reported for the gradient norm
        // of the 2D positions: both are documented as `[P]`, and the
        // reference tensor itself is not inspected here.
        let point_dims = self.radii.dims();

        let mut summary = f.debug_struct(&type_label);
        summary.field("colors_rgb_2d.dims()", &image_dims);
        summary.field("positions_2d_grad_norm.dims()", &point_dims);
        summary.field("radii.dims()", &point_dims);
        summary.finish()
    }
}