gausplat_trainer/metric/
psnr.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
//! Peak signal-to-noise ratio (PSNR) metric.

pub use super::*;

/// Computing the peak signal-to-noise ratio (PSNR) between the inputs:
///
/// `10 * log10(1 / MSE) = -10 / log(10) * log(MSE)`
///
/// ## Details
///
/// It relies on [`MSE`](MeanSquareError).
#[derive(Clone, Debug)]
pub struct Psnr<B: Backend> {
    /// Precomputed constant factor `-10 / ln(10)`, held as a rank-1 tensor
    /// with a single element.
    pub coefficient: Tensor<B, 1>,
    /// Inner metric that produces the mean squared error of the inputs.
    pub mse: MeanSquareError,
}

impl<B: Backend> Psnr<B> {
    /// Creates the metric with its tensors allocated on `device`.
    ///
    /// The stored coefficient is `-10 / ln(10)`, evaluated with tensor
    /// operations so it lives on the same device as the scored inputs.
    pub fn init(device: &B::Device) -> Self {
        let base = Tensor::<B, 1>::from_floats([10.0], device);
        // Natural log of the base; the PSNR formula divides by it.
        let ln_base = base.clone().log();
        Self {
            coefficient: base.neg().div(ln_base),
            mse: MeanSquareError::init(),
        }
    }
}

impl<B: Backend> Metric<B> for Psnr<B> {
    /// Scores `value` against `target` as `coefficient * ln(MSE)`.
    ///
    /// ## Returns
    ///
    /// The peak signal-to-noise ratio (PSNR) with shape `[1]`.
    /// Identical inputs have zero MSE, so the result is `+inf` for them.
    #[inline]
    fn evaluate<const D: usize>(
        &self,
        value: Tensor<B, D>,
        target: Tensor<B, D>,
    ) -> Tensor<B, 1> {
        let log_mse = self.mse.evaluate(value, target).log();
        log_mse.mul(self.coefficient.clone())
    }
}

impl<B: Backend> Default for Psnr<B> {
    /// Builds the metric on the backend's default device.
    fn default() -> Self {
        let device: B::Device = Default::default();
        Self::init(&device)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    /// Asserts that a finite score agrees with `expected` within a small
    /// absolute tolerance.
    ///
    /// Scores that pass through the MSE reduction and `ln` can differ in
    /// their last bits depending on the math library and summation order,
    /// so exact `f32` equality against a hand-copied literal is brittle.
    fn assert_close(actual: f32, expected: f32) {
        const TOLERANCE: f32 = 1e-5;
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn default() {
        let target = -10.0 / 10.0_f32.ln();
        let output = Psnr::<NdArray>::default().coefficient.into_scalar();
        assert_eq!(output, target);
    }

    #[test]
    fn evaluate() {
        let device = Default::default();
        let metric = Psnr::<NdArray>::init(&device);

        // Identical inputs: MSE is 0, so PSNR is +inf.
        let input_0 = Tensor::<NdArray, 4>::zeros([1, 3, 256, 256], &device);
        let input_1 = Tensor::<NdArray, 4>::zeros([1, 3, 256, 256], &device);
        let score = metric.evaluate(input_0, input_1).into_scalar();
        assert_eq!(score, f32::INFINITY);

        let input_0 = Tensor::<NdArray, 4>::ones([1, 3, 256, 256], &device);
        let input_1 = Tensor::<NdArray, 4>::ones([1, 3, 256, 256], &device);
        let score = metric.evaluate(input_0, input_1).into_scalar();
        assert_eq!(score, f32::INFINITY);

        // Every element differs by 1: MSE is exactly 1, so PSNR is 0.
        let input_0 = Tensor::<NdArray, 4>::zeros([1, 3, 256, 256], &device);
        let input_1 = Tensor::<NdArray, 4>::ones([1, 3, 256, 256], &device);
        let score = metric.evaluate(input_0, input_1).into_scalar();
        assert_eq!(score, 0.0);

        // Every element differs by 0.5: MSE = 0.25, PSNR = 10 * log10(4).
        let input_0 = Tensor::<NdArray, 2>::from_floats(
            [[0.0, 0.1, 0.2], [0.5, 0.4, 0.3]],
            &device,
        );
        let input_1 = Tensor::<NdArray, 2>::from_floats(
            [[0.5, 0.6, 0.7], [0.0, 0.9, 0.8]],
            &device,
        );
        let score = metric.evaluate(input_0, input_1).into_scalar();
        assert_close(score, 10.0 * 4.0_f32.log10());

        // Three of six elements differ by 0.5: MSE = 0.125,
        // PSNR = 10 * log10(8).
        let input_0 = Tensor::<NdArray, 2>::from_floats(
            [[0.0, 0.1, 0.2], [0.5, 0.4, 0.3]],
            &device,
        );
        let input_1 = Tensor::<NdArray, 2>::from_floats(
            [[0.0, 0.6, 0.7], [0.0, 0.4, 0.3]],
            &device,
        );
        let score = metric.evaluate(input_0, input_1).into_scalar();
        assert_close(score, 10.0 * 8.0_f32.log10());
    }
}