ndarray_layout/transform/
transpose.rs

1use crate::ArrayLayout;
2use std::{collections::BTreeSet, iter::zip};
3
4impl<const N: usize> ArrayLayout<N> {
5    /// 转置变换允许调换张量的维度顺序,但不改变元素的存储顺序。
6    ///
7    /// ```rust
8    /// # use ndarray_layout::ArrayLayout;
9    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]);
10    /// assert_eq!(layout.shape(), &[3, 2, 4]);
11    /// assert_eq!(layout.strides(), &[4, 12, 1]);
12    /// assert_eq!(layout.offset(), 0);
13    /// ```
14    pub fn transpose(&self, perm: &[usize]) -> Self {
15        let perm_ = perm.iter().collect::<BTreeSet<_>>();
16        assert_eq!(perm_.len(), perm.len());
17
18        let content = self.content();
19        let shape = content.shape();
20        let strides = content.strides();
21
22        let mut ans = Self::with_ndim(self.ndim);
23        let mut content = ans.content_mut();
24        content.set_offset(self.offset());
25        let mut set = |i, j| {
26            content.set_shape(i, shape[j]);
27            content.set_stride(i, strides[j]);
28        };
29
30        let mut last = 0;
31        for (&i, &j) in zip(perm_, perm) {
32            for i in last..i {
33                set(i, i);
34            }
35            set(i, j);
36            last = i + 1;
37        }
38        for i in last..shape.len() {
39            set(i, i);
40        }
41        ans
42    }
43}
44
#[test]
fn test_transpose() {
    // Every case starts from the same contiguous (row-major) 2×3×4 layout.
    let base = || ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);

    // Swap the two leading dimensions; the trailing one is untouched.
    let layout = base().transpose(&[1, 0]);
    assert_eq!(layout.shape(), &[3, 2, 4]);
    assert_eq!(layout.strides(), &[4, 12, 1]);
    assert_eq!(layout.offset(), 0);

    // Swap non-adjacent dimensions 0 and 2; dimension 1 stays in place.
    let layout = base().transpose(&[2, 0]);
    assert_eq!(layout.shape(), &[4, 3, 2]);
    assert_eq!(layout.strides(), &[1, 4, 12]);
    assert_eq!(layout.offset(), 0);

    // Swap the two trailing dimensions; dimension 0 stays in place.
    let layout = base().transpose(&[2, 1]);
    assert_eq!(layout.shape(), &[2, 4, 3]);
    assert_eq!(layout.strides(), &[12, 1, 4]);
    assert_eq!(layout.offset(), 0);

    // Full-rank cyclic permutation: new dim k takes old dim perm[k].
    let layout = base().transpose(&[2, 0, 1]);
    assert_eq!(layout.shape(), &[4, 2, 3]);
    assert_eq!(layout.strides(), &[1, 12, 4]);
    assert_eq!(layout.offset(), 0);

    // An empty permutation leaves the layout unchanged.
    let layout = base().transpose(&[]);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);
}

#[test]
#[should_panic]
fn test_transpose_duplicate_index() {
    // A repeated index is not a permutation and must be rejected.
    let _ = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[0, 0]);
}