// ndarray_layout/transform/split.rs
use crate::ArrayLayout;
2
3pub struct Split<'a, const N: usize> {
5 src: &'a ArrayLayout<N>,
6 axis: usize,
7 start: usize,
8 parts: &'a [usize],
9}
10
11impl<const N: usize> ArrayLayout<N> {
12 #[inline]
30 pub fn split<'a>(&'a self, axis: usize, parts: &'a [usize]) -> Split<'a, N> {
31 assert_eq!(self.shape()[axis], parts.iter().sum());
32 Split {
33 src: self,
34 axis,
35 start: 0,
36 parts,
37 }
38 }
39}
40
41impl<const N: usize> Iterator for Split<'_, N> {
42 type Item = ArrayLayout<N>;
43
44 #[inline]
45 fn next(&mut self) -> Option<Self::Item> {
46 self.parts.split_first().map(|(&head, tail)| {
47 let start = self.start;
48 self.start += head;
49 self.parts = tail;
50 self.src.slice(self.axis, start, 1, head)
51 })
52 }
53}
54
/// Splits the last axis (length 4) of a `[2, 3, 4]` layout into pieces of
/// length 1 and 3 and checks the shape, strides and offset of each piece.
#[test]
fn test_split() {
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);
    let mut splits = layout.split(2, &[1, 3]);

    // First piece: length 1, starting at the beginning of the axis.
    let head = splits.next().unwrap();
    assert_eq!(head.shape(), &[2, 3, 1]);
    assert_eq!(head.strides(), &[12, 4, 1]);
    assert_eq!(head.offset(), 0);

    // Second piece: length 3, starting right after the first piece.
    let tail = splits.next().unwrap();
    assert_eq!(tail.shape(), &[2, 3, 3]);
    assert_eq!(tail.strides(), &[12, 4, 1]);
    assert_eq!(tail.offset(), 1);

    // Both parts are consumed, so the iterator must now be exhausted.
    assert!(splits.next().is_none());
}