pub use super::*;
use crate::error::Error;
use burn::{
module::Param,
nn::{conv, PaddingConfig2d},
tensor::Int,
};
/// Mean structural similarity (SSIM) metric for backend `B`.
///
/// The const parameter `C` is the channel count: `init` uses it for the
/// filter's input channels, output channels and group count, and `evaluate`
/// expands its inputs to `C` channels before filtering.
#[derive(Clone, Debug)]
pub struct MeanStructuralSimilarity<B: Backend, const C: usize> {
/// Convolution applied by `evaluate` to the inputs and to their products.
/// `init` builds it without bias or padding and gives it a normalized
/// 11x11 Gaussian weight that does not require gradients.
pub filter: conv::Conv2d<B>,
}
impl<B: Backend, const C: usize> MeanStructuralSimilarity<B, C> {
    /// Creates the metric with its smoothing filter on `device`.
    ///
    /// The filter is a convolution with `C` input channels, `C` output
    /// channels and `C` groups, no bias and no padding. Its weight is an
    /// 11x11 Gaussian window (standard deviation 1.5), normalized so that it
    /// sums to one, expanded to `[C, 1, 11, 11]` and registered with
    /// `require_grad` set to `false`.
    pub fn init(device: &B::Device) -> Self {
        const KERNEL_SIZE: usize = 11;
        const KERNEL_RADIUS: usize = KERNEL_SIZE >> 1;
        const SIGMA: f64 = 1.5;
        const SIGMA_SQUARED: f64 = SIGMA * SIGMA;

        let config = conv::Conv2dConfig::new([C, C], [KERNEL_SIZE, KERNEL_SIZE])
            .with_bias(false)
            .with_groups(C)
            .with_padding(PaddingConfig2d::Valid);
        let mut filter = config.init(device);

        // Replace the configured weight with a fixed Gaussian window.
        filter.weight = Param::uninitialized(
            Default::default(),
            move |device, is_required_grad| {
                let radius = KERNEL_RADIUS as i64;
                // Integer offsets -radius..=radius from the window centre.
                let offsets = Tensor::<B, 1, Int>::arange(-radius..radius + 1, device);
                // Negated squared offsets as a row, and the same values as a column.
                let neg_sq_row = offsets.powi_scalar(2).neg().float().unsqueeze::<2>();
                let neg_sq_col = neg_sq_row.clone().transpose();
                // Row + column gives -(x^2 + y^2) for every window position
                // (relies on the addition broadcasting the two operands).
                let neg_sq_dist = neg_sq_row + neg_sq_col;
                // exp(-(x^2 + y^2) / (2 * sigma^2)).
                let gaussian = neg_sq_dist.div_scalar(2.0 * SIGMA_SQUARED).exp();
                // Normalize so the window sums to one.
                let total = gaussian.clone().sum().unsqueeze::<2>();
                gaussian
                    .div(total)
                    .expand([C, 1, KERNEL_SIZE, KERNEL_SIZE])
                    .set_require_grad(is_required_grad)
            },
            device.clone(),
            false,
        );

        Self { filter }
    }
}
impl<B: Backend, const C: usize> Metric<B> for MeanStructuralSimilarity<B, C> {
fn evaluate<const D: usize>(
&self,
value: Tensor<B, D>,
target: Tensor<B, D>,
) -> Tensor<B, 1> {
const K1: f64 = 0.01;
const K2: f64 = 0.03;
const L: f64 = 1.0;
const C1: f64 = (K1 * L) * (K1 * L);
const C2: f64 = (K2 * L) * (K2 * L);
const FRAC_C1_2: f64 = C1 / 2.0;
const FRAC_C2_2: f64 = C2 / 2.0;
let input = (
value.unsqueeze::<4>().expand([-1, C as i64, -1, -1]),
target.unsqueeze::<4>().expand([-1, C as i64, -1, -1]),
);
let input_0_shape = input.0.shape().dims;
let input_1_shape = input.1.shape().dims;
if input_0_shape != input_1_shape {
panic!(
"assertion `left == right` failed: {}",
Error::MismatchedTensorShape(input_0_shape, input_1_shape)
);
}
let filter = &self.filter;
let mean = (
filter.forward(input.0.to_owned()),
filter.forward(input.1.to_owned()),
);
let mean2 = (
mean.0.to_owned().powf_scalar(2.0),
mean.1.to_owned().powf_scalar(2.0),
);
let std2 = (
filter
.forward(input.0.to_owned().powf_scalar(2.0))
.sub(mean2.0.to_owned()),
filter
.forward(input.1.to_owned().powf_scalar(2.0))
.sub(mean2.1.to_owned()),
);
let mean_01 = mean.0.mul(mean.1);
let std_01 = filter.forward(input.0.mul(input.1)).sub(mean_01.to_owned());
let indexes = (mean_01 + FRAC_C1_2) * (std_01 + FRAC_C2_2) * 4.0
/ ((mean2.0 + mean2.1 + C1) * (std2.0 + std2.1 + C2));
indexes.mean()
}
}
impl<B: Backend, const C: usize> Default for MeanStructuralSimilarity<B, C> {
    /// Builds the metric with `init` on the backend's default device.
    #[inline]
    fn default() -> Self {
        let device: B::Device = Default::default();
        Self::init(&device)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{backend::NdArray, tensor::Distribution};

    type TestBackend = NdArray<f32>;
    type Image = Tensor<TestBackend, 4>;

    /// Shape shared by every test input: one 3-channel 32x32 image.
    const SHAPE: [usize; 4] = [1, 3, 32, 32];

    #[test]
    fn evaluate() {
        let device = Default::default();
        let metric = MeanStructuralSimilarity::<TestBackend, 3>::init(&device);
        let ssim = |value: Image, target: Image| metric.evaluate(value, target).into_scalar();

        // Identical all-zero images score exactly one.
        let score = ssim(Image::zeros(SHAPE, &device), Image::zeros(SHAPE, &device));
        assert_eq!(score, 1.0);

        // Identical all-one images score exactly one.
        let score = ssim(Image::ones(SHAPE, &device), Image::ones(SHAPE, &device));
        assert_eq!(score, 1.0);

        // All-zero against all-one scores just above zero.
        let score = ssim(Image::zeros(SHAPE, &device), Image::ones(SHAPE, &device));
        assert!(score > 0.0 && score < 1e-4, "score: {:?}", score);

        // Random noise against its inversion (1 - x) scores below zero.
        let noise = Image::random(SHAPE, Distribution::Uniform(0.01, 0.99), &device);
        let inverted = noise.clone().neg().add_scalar(1.0);
        let score = ssim(noise, inverted);
        assert!(score < 0.0, "score: {:?}", score);
    }

    #[test]
    fn default() {
        MeanStructuralSimilarity::<TestBackend, 3>::default();
    }
}