ndarray_layout/transform/
split.rs

1use crate::ArrayLayout;
2
3/// 切分变换参数。
4pub struct Split<'a, const N: usize> {
5    src: &'a ArrayLayout<N>,
6    axis: usize,
7    start: usize,
8    parts: &'a [usize],
9}
10
11impl<const N: usize> ArrayLayout<N> {
12    /// 切分变换讲单个张量沿某个维度切分成多个张量,因此可以支持不均匀的切分。
13    ///
14    /// ```rust
15    /// # use ndarray_layout::ArrayLayout;
16    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);
17    /// let mut splits = layout.split(2, &[1, 3]);
18    ///
19    /// let layout = splits.next().unwrap();
20    /// assert_eq!(layout.shape(), &[2, 3, 1]);
21    /// assert_eq!(layout.strides(), &[12, 4, 1]);
22    /// assert_eq!(layout.offset(), 0);
23    ///
24    /// let layout = splits.next().unwrap();
25    /// assert_eq!(layout.shape(), &[2, 3, 3]);
26    /// assert_eq!(layout.strides(), &[12, 4, 1]);
27    /// assert_eq!(layout.offset(), 1);
28    /// ```
29    #[inline]
30    pub fn split<'a>(&'a self, axis: usize, parts: &'a [usize]) -> Split<'a, N> {
31        assert_eq!(self.shape()[axis], parts.iter().sum());
32        Split {
33            src: self,
34            axis,
35            start: 0,
36            parts,
37        }
38    }
39}
40
41impl<const N: usize> Iterator for Split<'_, N> {
42    type Item = ArrayLayout<N>;
43
44    #[inline]
45    fn next(&mut self) -> Option<Self::Item> {
46        self.parts.split_first().map(|(&head, tail)| {
47            let start = self.start;
48            self.start += head;
49            self.parts = tail;
50            self.src.slice(self.axis, start, 1, head)
51        })
52    }
53}
54
#[test]
fn test_split() {
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);
    let mut splits = layout.split(2, &[1, 3]);

    // First part: length 1 at index 0 of the last axis.
    let part = splits.next().unwrap();
    assert_eq!(part.shape(), &[2, 3, 1]);
    assert_eq!(part.strides(), &[12, 4, 1]);
    assert_eq!(part.offset(), 0);

    // Second part: length 3, starting right after the first part.
    let part = splits.next().unwrap();
    assert_eq!(part.shape(), &[2, 3, 3]);
    assert_eq!(part.strides(), &[12, 4, 1]);
    assert_eq!(part.offset(), 1);

    // Every part has been yielded, so the iterator is exhausted.
    assert!(splits.next().is_none());
}

#[test]
#[should_panic]
fn test_split_mismatched_parts() {
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);
    // 1 + 2 != 4, so `split` must reject the partition.
    let _ = layout.split(2, &[1, 2]);
}