gausplat_renderer/render/gaussian_3d/backward.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
//! 3DGS rendering context (backward).
pub use super::*;
/// Rendering inputs (backward).
#[derive(Clone, Debug)]
pub struct RenderInput<B: Backend> {
/// The shape is `[P, 3]`
pub colors_rgb_3d: B::FloatTensorPrimitive,
/// The shape is `[P, M * 3]` <- `[P, M, 3]`
pub colors_sh: B::FloatTensorPrimitive,
/// `(0 ~ 3)`
pub colors_sh_degree_max: u32,
/// The shape is `[P, 3]`
pub conics: B::FloatTensorPrimitive,
/// The shape is `[P]`
pub depths: B::FloatTensorPrimitive,
/// `f_x <- I_x / tan(Fov_x / 2) / 2`
pub focal_length_x: f32,
/// `f_y <- I_y / tan(Fov_y / 2) / 2`
pub focal_length_y: f32,
/// `I_x / 2`
pub image_size_half_x: f32,
/// `I_y / 2`
pub image_size_half_y: f32,
/// `I_x`
pub image_size_x: u32,
/// `I_y`
pub image_size_y: u32,
/// The shape is `[P, 3]`
pub is_colors_rgb_3d_not_clamped: B::FloatTensorPrimitive,
/// The shape is `[P, 1]`
pub opacities_3d: B::FloatTensorPrimitive,
/// `P`
pub point_count: u32,
/// The shape is `[T]`
pub point_indices: B::IntTensorPrimitive,
/// The shape is `[I_y, I_x]`
pub point_rendered_counts: B::IntTensorPrimitive,
/// The shape is `[P, 2]`
pub positions_2d: B::FloatTensorPrimitive,
/// The shape is `[P, 3]`
pub positions_3d: B::FloatTensorPrimitive,
/// The shape is `[P, 2]`
pub positions_3d_in_normalized: B::FloatTensorPrimitive,
/// The shape is `[P]`
pub radii: B::IntTensorPrimitive,
/// The shape is `[P, 4]`
pub rotations: B::FloatTensorPrimitive,
/// The shape is `[P, 3, 3]`
pub rotations_matrix: B::FloatTensorPrimitive,
/// The shape is `[P, 3]`
pub scalings: B::FloatTensorPrimitive,
/// `I_x / T_x`
pub tile_count_x: u32,
/// `I_y / T_y`
pub tile_count_y: u32,
/// The shape is `[I_y / T_y, I_x / T_x, 2]`
pub tile_point_ranges: B::IntTensorPrimitive,
/// The shape is `[I_y, I_x]`
pub transmittances: B::FloatTensorPrimitive,
/// `tan(Fov_x / 2) * (C_f + 1)`
pub view_bound_x: f32,
/// `tan(Fov_y / 2) * (C_f + 1)`
pub view_bound_y: f32,
/// `[3]`
pub view_position: [f32; 3],
/// `[3 (+ 1), 3 + 1]`
pub view_transform: [[f32; 4]; 4],
}
/// Outputs for rendering (backward).
#[derive(Clone, Debug)]
pub struct RenderOutput<B: Backend> {
/// The shape is `[P, M * 3]` <- `[P, M, 3]`
///
/// It is the gradient of `colors_rgb_2d` with respect to `colors_sh`.
pub colors_sh_grad: B::FloatTensorPrimitive,
/// The shape is `[P, 1]`
///
/// It is the gradient of `colors_rgb_2d` with respect to `colors_sh_degree_max`.
pub opacities_grad: B::FloatTensorPrimitive,
/// The shape is `[P]`
///
/// It is the gradient norm of the 2D positions.
pub positions_2d_grad_norm: B::FloatTensorPrimitive,
/// The shape is `[P, 3]`
///
/// It is the gradient of `colors_rgb_2d` with respect to `positions_3d`.
pub positions_grad: B::FloatTensorPrimitive,
/// The shape is `[P, 4]`
///
/// It is the gradient of `colors_rgb_2d` with respect to `rotations`.
pub rotations_grad: B::FloatTensorPrimitive,
/// The shape is `[P, 3]`
///
/// It is the gradient of `colors_rgb_2d` with respect to `scalings`.
pub scalings_grad: B::FloatTensorPrimitive,
}