gausplat_renderer/render/gaussian_3d/jit/kernel/rasterize/mod.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
//! Rasterizing the point to the image.
pub use super::*;
use burn::tensor::ops::{FloatTensorOps, IntTensorOps};
use bytemuck::{bytes_of, Pod, Zeroable};
/// Arguments.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct Arguments {
/// $ \text{im}_x $
pub image_size_x: u32,
/// $ \text{im}_y $
pub image_size_y: u32,
/// $ \frac{\text{im}_x}{\text{t}_x} $
///
/// $ \text{t}_x $ is the tile width.
pub tile_count_x: u32,
/// $ \frac{\text{im}_y}{\text{t}_y} $
///
/// $ \text{t}_y $ is the tile height.
pub tile_count_y: u32,
}
/// Inputs.
#[derive(Clone, Debug)]
pub struct Inputs<R: JitRuntime> {
/// $ C_{rgb} \in \mathbb{R}^{3} $ of $ p $ points.
pub colors_rgb_3d: JitTensor<R>,
/// $ \Sigma^{'-1} \in \mathbb{R}^{2 \times 2} $ of $ p $ points.
///
/// Inverse of the 2D covariance.
///
/// It can be $ \mathbb{R}^{3} $ since it is symmetric.
pub conics: JitTensor<R>,
/// $ \alpha \in \mathbb{R} $ of $ p $ points.
pub opacities_3d: JitTensor<R>,
/// $ i \in [0, p) $.
///
/// It is point index per tile.
pub point_indices: JitTensor<R>,
/// $ P^' \in \mathbb{R}^{2} $ of $ p $ points.
///
/// 2D position in screen space.
pub positions_2d: JitTensor<R>,
/// $ [i_{start}, i_{end}) $ of each tile.
pub tile_point_ranges: JitTensor<R>,
}
/// Outputs.
#[derive(Clone, Debug)]
pub struct Outputs<R: JitRuntime> {
/// $ C_{rgb}^' \in \mathbb{R}^{3} $ of each image pixel.
pub colors_rgb_2d: JitTensor<R>,
/// Rendered point count of each image pixel.
pub point_rendered_counts: JitTensor<R>,
/// $ T_{last} $
///
/// Last transmittance of each image pixel.
pub transmittances: JitTensor<R>,
}
/// $ \text{t}_x $
pub const TILE_SIZE_X: u32 = 16;
/// $ \text{t}_y $
pub const TILE_SIZE_Y: u32 = 16;
/// Rasterize the point to the image.
///
/// For each pixel in each tile, do the following steps:
///
/// 1. Collect the points in the tile onto the shared memory.
///
/// 2. Compute the Gaussian density centered at the pixel position $ P_x $
/// using the parameters of each point $ n $ touched the tile
/// ([$ \Sigma^{'-1} $](Inputs::conics) and [$ P_v^' $](Inputs::positions_2d)):
/// $$ D = P_v^' - P_x \in \mathbb{R}^2 $$
/// $$ \sigma_n = e^{-\frac{1}{2} D^T \Sigma^{'-1} D} $$
///
/// 3. Accumulate the transmittance [$ T_n $](Outputs::transmittances)
/// and [$ C_{rgb}^' $](Outputs::colors_rgb_2d) of each pixel
/// using [$ \alpha_n $](Inputs::opacities_3d)
/// and [$ C_{rgb,n} $](Inputs::colors_rgb_3d) of each point $ n $
/// (Order-dependent transparency blending):
/// $$ \alpha_n^' \leftarrow \alpha_n \sigma_n $$
/// $$ T_{n + 1} \leftarrow T_n (1 - \alpha_n^') $$
/// $$ C_{rgb}^' \leftarrow C_{rgb,n}^' + (C_{rgb} \cdot \alpha_n^' \cdot T_n) $$
pub fn main<R: JitRuntime, F: FloatElement, I: IntElement, B: BoolElement>(
arguments: Arguments,
inputs: Inputs<R>,
) -> Outputs<R> {
impl_kernel_source!(Kernel, "kernel.wgsl");
// Specifying the parameters
let client = &inputs.colors_rgb_3d.client;
let device = &inputs.colors_rgb_3d.device;
// I_x
let image_size_x = arguments.image_size_x as usize;
// I_y
let image_size_y = arguments.image_size_y as usize;
// [I_x, I_y, 3]
let colors_rgb_2d = JitBackend::<R, F, I, B>::float_empty(
[image_size_y, image_size_x, 3].into(),
device,
);
// [I_x, I_y]
let point_rendered_counts =
JitBackend::<R, F, I, B>::int_empty([image_size_y, image_size_x].into(), device);
// [I_x, I_y]
let transmittances = JitBackend::<R, F, I, B>::float_empty(
[image_size_y, image_size_x].into(),
device,
);
// Launching the kernel
client.execute(
Box::new(SourceKernel::new(
Kernel,
CubeDim {
x: TILE_SIZE_X,
y: TILE_SIZE_Y,
z: 1,
},
)),
CubeCount::Static(arguments.tile_count_x, arguments.tile_count_y, 1),
vec![
client.create(bytes_of(&arguments)).binding(),
inputs.colors_rgb_3d.handle.binding(),
inputs.conics.handle.binding(),
inputs.opacities_3d.handle.binding(),
inputs.point_indices.handle.binding(),
inputs.positions_2d.handle.binding(),
inputs.tile_point_ranges.handle.binding(),
colors_rgb_2d.handle.to_owned().binding(),
point_rendered_counts.handle.to_owned().binding(),
transmittances.handle.to_owned().binding(),
],
);
Outputs {
colors_rgb_2d,
point_rendered_counts,
transmittances,
}
}