gausplat_trainer/range/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
//! Range options module.

pub use burn::config::Config;

/// The range options.
///
/// Describes a stepped, half-open iteration range `[start, end)`: an iteration
/// belongs to it when it is at least `start`, strictly less than `end`, and an
/// exact multiple of `step` past `start` (see [`RangeOptions::has`]).
///
/// The fields are positional arguments of the `Config`-derived constructor,
/// e.g. `RangeOptions::new(start, end, step)`.
// NOTE(review): `Clone` is not listed although `Copy` requires it; presumably
// the `burn` `Config` derive supplies the `Clone` impl — confirm before
// adding `Clone` here, or it may conflict.
#[derive(Config, Copy, Debug, PartialEq)]
pub struct RangeOptions {
    /// The start of the range (inclusive).
    pub start: u64,
    /// The end of the range (exclusive).
    pub end: u64,
    /// The step of the range, measured from `start`.
    ///
    /// A step of `0` is degenerate; see [`RangeOptions::has`] for how it is
    /// treated.
    pub step: u64,
}

impl RangeOptions {
    /// Create a range with the specified `step`.
    ///
    /// Only `step` is set explicitly; `start` and `end` are taken from
    /// [`RangeOptions::default`].
    #[inline]
    pub fn default_with_step(step: u64) -> Self {
        Self {
            step,
            ..Default::default()
        }
    }

    /// Check if the iteration is contained in the range.
    ///
    /// Returns `true` when `iteration` lies in the half-open interval
    /// `[start, end)` and `iteration - start` is a multiple of `step`.
    ///
    /// A `step` of `0` makes the range empty: every iteration yields `false`
    /// rather than panicking on a division by zero.
    pub fn has(
        &self,
        iteration: u64,
    ) -> bool {
        // `contains` is checked first and `&&` short-circuits, so
        // `iteration >= start` holds and the subtraction cannot underflow.
        // `checked_rem` yields `None` for a zero step instead of panicking.
        (self.start..self.end).contains(&iteration)
            && (iteration - self.start).checked_rem(self.step) == Some(0)
    }
}

impl Default for RangeOptions {
    #[inline]
    fn default() -> Self {
        RangeOptions {
            start: 0,
            end: u64::MAX,
            step: 1,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// With `start = 1`, `end = 9`, `step = 2`, exactly the odd iterations
    /// below 9 are contained in the range.
    #[test]
    fn has() {
        let range = RangeOptions::new(1, 9, 2);

        for i in 0..11 {
            let expected = i < 9 && i % 2 == 1;
            assert_eq!(range.has(i), expected, "range.has({i})");
        }
    }
}