ndarray_layout/transform/
transpose.rs1use crate::ArrayLayout;
2use std::{collections::BTreeSet, iter::zip};
3
4impl<const N: usize> ArrayLayout<N> {
5 pub fn transpose(&self, perm: &[usize]) -> Self {
15 let perm_ = perm.iter().collect::<BTreeSet<_>>();
16 assert_eq!(perm_.len(), perm.len());
17
18 let content = self.content();
19 let shape = content.shape();
20 let strides = content.strides();
21
22 let mut ans = Self::with_ndim(self.ndim);
23 let mut content = ans.content_mut();
24 content.set_offset(self.offset());
25 let mut set = |i, j| {
26 content.set_shape(i, shape[j]);
27 content.set_stride(i, strides[j]);
28 };
29
30 let mut last = 0;
31 for (&i, &j) in zip(perm_, perm) {
32 for i in last..i {
33 set(i, i);
34 }
35 set(i, j);
36 last = i + 1;
37 }
38 for i in last..shape.len() {
39 set(i, i);
40 }
41 ans
42 }
43}
44
#[test]
fn test_transpose() {
    // Swap the first two axes; the last axis is not named and stays in place.
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]);
    assert_eq!(layout.shape(), &[3, 2, 4]);
    assert_eq!(layout.strides(), &[4, 12, 1]);
    assert_eq!(layout.offset(), 0);

    // Swap the outer axes; the middle axis is not named and keeps its position.
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[2, 0]);
    assert_eq!(layout.shape(), &[4, 3, 2]);
    assert_eq!(layout.strides(), &[1, 4, 12]);
    assert_eq!(layout.offset(), 0);
}

/// An empty permutation copies every axis through unchanged.
#[test]
fn test_transpose_empty_perm() {
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[]);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);
}

/// The identity permutation keeps shape and strides, and a non-zero offset is preserved.
#[test]
fn test_transpose_identity_keeps_offset() {
    let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 7).transpose(&[0, 1, 2]);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 7);
}

/// Naming the same axis twice is rejected.
#[test]
#[should_panic]
fn test_transpose_duplicate_axis() {
    let _ = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 1]);
}

/// Naming an axis beyond the layout's dimensionality is rejected.
#[test]
#[should_panic]
fn test_transpose_axis_out_of_range() {
    let _ = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[0, 3]);
}