// ndarray_layout/lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
/// An array layout that stores the metadata of up to `N` dimensions inline.
///
/// A layout consists of an `offset` plus `ndim` shape entries and `ndim`
/// strides. When `ndim <= N` this metadata lives inside the value itself;
/// when `ndim > N` it lives in a heap buffer owned by the layout.
pub struct ArrayLayout<const N: usize> {
    // Number of dimensions; also selects which field of `content` is active.
    ndim: usize,
    // Inline metadata words (`ndim <= N`) or pointer to the heap buffer (`ndim > N`).
    content: Union<N>,
}

// SAFETY: an `ArrayLayout` uniquely owns its heap buffer (deep-copied in
// `Clone`, freed in `Drop`) and only hands out mutable access through
// `&mut self`, so it is as thread-safe as an owned `Box<[usize]>`.
unsafe impl<const N: usize> Send for ArrayLayout<N> {}
// SAFETY: see the `Send` impl above; `&ArrayLayout` only permits reads.
unsafe impl<const N: usize> Sync for ArrayLayout<N> {}

/// Storage for an `ArrayLayout`'s metadata words.
///
/// Which field is meaningful is decided by `ArrayLayout::ndim`:
/// - `ndim > N`: `ptr` points to a heap buffer of `1 + 2 * ndim` words;
/// - otherwise: `_inlined` reserves room for `1 + 2 * N` words in place.
///
/// In both cases the words are read through a raw `usize` pointer as
/// `[offset, shape[0..ndim], strides[0..ndim]]`; the tuple type only reserves space.
union Union<const N: usize> {
    ptr: NonNull<usize>,
    _inlined: (isize, [usize; N], [isize; N]),
}
17
18impl<const N: usize> Clone for ArrayLayout<N> {
19    #[inline]
20    fn clone(&self) -> Self {
21        Self::new(self.shape(), self.strides(), self.offset())
22    }
23}
24
25impl<const N: usize> PartialEq for ArrayLayout<N> {
26    #[inline]
27    fn eq(&self, other: &Self) -> bool {
28        self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice()
29    }
30}
31
32impl<const N: usize> Eq for ArrayLayout<N> {}
33
34impl<const N: usize> Drop for ArrayLayout<N> {
35    fn drop(&mut self) {
36        if let Some(ptr) = self.ptr_allocated() {
37            unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) }
38        }
39    }
40}
41
/// Storage order of the metadata: which dimensions come first.
///
/// With `BigEndian`, [`ArrayLayout::new_contiguous`] gives the last dimension
/// the smallest stride (row-major / C order); with `LittleEndian`, the first
/// dimension gets the smallest stride (column-major / Fortran order). The same
/// choice decides which dimension varies fastest in [`ArrayLayout::element_offset`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Endian {
    /// Big endian: dimensions with a larger span come earlier in the metadata.
    BigEndian,
    /// Little endian: dimensions with a smaller span come earlier in the metadata.
    LittleEndian,
}
50
51impl<const N: usize> ArrayLayout<N> {
52    /// Creates a new Layout with the given shape, strides, and offset.
53    ///
54    /// ```rust
55    /// # use ndarray_layout::ArrayLayout;
56    /// let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
57    /// assert_eq!(layout.offset(), 20);
58    /// assert_eq!(layout.shape(), &[2, 3, 4]);
59    /// assert_eq!(layout.strides(), &[12, -4, 1]);
60    /// ```
61    pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self {
62        // check
63        assert_eq!(
64            shape.len(),
65            strides.len(),
66            "shape and strides must have the same length"
67        );
68
69        let mut ans = Self::with_ndim(shape.len());
70        let mut content = ans.content_mut();
71        content.set_offset(offset);
72        content.copy_shape(shape);
73        content.copy_strides(strides);
74        ans
75    }
76
77    /// Creates a new contiguous Layout with the given shape.
78    ///
79    /// ```rust
80    /// # use ndarray_layout::{Endian, ArrayLayout};
81    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
82    /// assert_eq!(layout.offset(), 0);
83    /// assert_eq!(layout.shape(), &[2, 3, 4]);
84    /// assert_eq!(layout.strides(), &[4, 8, 24]);
85    /// ```
86    pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self {
87        let mut ans = Self::with_ndim(shape.len());
88        let mut content = ans.content_mut();
89        content.set_offset(0);
90        content.copy_shape(shape);
91        let mut mul = element_size as isize;
92        let push = |i| {
93            content.set_stride(i, mul);
94            mul *= shape[i] as isize;
95        };
96        // 大端小端区别在于是否反转
97        match endian {
98            Endian::BigEndian => (0..shape.len()).rev().for_each(push),
99            Endian::LittleEndian => (0..shape.len()).for_each(push),
100        }
101        ans
102    }
103
104    /// Gets offset.
105    #[inline]
106    pub const fn ndim(&self) -> usize {
107        self.ndim
108    }
109
110    /// Gets offset.
111    #[inline]
112    pub fn offset(&self) -> isize {
113        self.content().offset()
114    }
115
116    /// Gets shape.
117    #[inline]
118    pub fn shape(&self) -> &[usize] {
119        self.content().shape()
120    }
121
122    /// Gets strides.
123    #[inline]
124    pub fn strides(&self) -> &[isize] {
125        self.content().strides()
126    }
127
128    /// Copy data to another `ArrayLayout` with inline size `M`.
129    ///
130    /// ```rust
131    /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout};
132    /// let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], BigEndian, 0);
133    /// assert_eq!(size_of_val(&layout), (2 * 4 + 2) * size_of::<usize>());
134    ///
135    /// let layout = layout.to_inline_size::<2>();
136    /// assert_eq!(size_of_val(&layout), (2 * 2 + 2) * size_of::<usize>());
137    /// ```
138    pub fn to_inline_size<const M: usize>(&self) -> ArrayLayout<M> {
139        ArrayLayout::new(self.shape(), self.strides(), self.offset())
140    }
141
142    /// Calculates the number of elements in the array.
143    ///
144    /// ```rust
145    /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout};
146    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 20);
147    /// assert_eq!(layout.num_elements(), 24);
148    /// ```
149    #[inline]
150    pub fn num_elements(&self) -> usize {
151        self.shape().iter().product()
152    }
153
154    /// Calculates the offset of element at the given `index`.
155    ///
156    /// ```rust
157    /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout};
158    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 4);
159    /// assert_eq!(layout.element_offset(22, BigEndian), 88); // 88 <- (22 % 4 * 4) + (22 / 4 % 3 * 16) + (22 / 4 / 3 % 2 * 48)
160    /// ```
161    pub fn element_offset(&self, index: usize, endian: Endian) -> isize {
162        fn offset_forwards(
163            mut rem: usize,
164            shape: impl IntoIterator<Item = usize>,
165            strides: impl IntoIterator<Item = isize>,
166        ) -> isize {
167            let mut ans = 0;
168            for (d, s) in zip(shape, strides) {
169                ans += s * (rem % d) as isize;
170                rem /= d
171            }
172            ans
173        }
174
175        let shape = self.shape().iter().cloned();
176        let strides = self.strides().iter().cloned();
177        self.offset()
178            + match endian {
179                Endian::BigEndian => offset_forwards(index, shape.rev(), strides.rev()),
180                Endian::LittleEndian => offset_forwards(index, shape, strides),
181            }
182    }
183
184    /// Calculates the range of data in bytes to determine the location of the memory area that the array needs to access.
185    ///
186    /// ```rust
187    /// # use ndarray_layout::ArrayLayout;
188    /// let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 1], 20);
189    /// let range = layout.data_range();
190    /// assert_eq!(range, 12..=35);
191    /// ```
192    pub fn data_range(&self) -> RangeInclusive<isize> {
193        let content = self.content();
194        let mut start = content.offset();
195        let mut end = content.offset();
196        for (&d, s) in zip(content.shape(), content.strides()) {
197            use std::cmp::Ordering::{Equal, Greater, Less};
198            let i = d as isize - 1;
199            match s.cmp(&0) {
200                Equal => {}
201                Less => start += s * i,
202                Greater => end += s * i,
203            }
204        }
205        start..=end
206    }
207}
208
209mod fmt;
210mod transform;
211pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg};
212
213use std::{
214    alloc::{Layout, alloc, dealloc},
215    iter::zip,
216    ops::RangeInclusive,
217    ptr::{NonNull, copy_nonoverlapping},
218    slice::from_raw_parts,
219};
220
221impl<const N: usize> ArrayLayout<N> {
222    #[inline]
223    fn ptr_allocated(&self) -> Option<NonNull<usize>> {
224        const { assert!(N > 0) }
225        // ndim > N 则 content 是 ptr,否则是元组。
226        if self.ndim > N {
227            Some(unsafe { self.content.ptr })
228        } else {
229            None
230        }
231    }
232
233    #[inline]
234    fn content(&self) -> Content<false> {
235        Content {
236            ptr: self
237                .ptr_allocated()
238                .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
239            ndim: self.ndim,
240        }
241    }
242
243    #[inline]
244    fn content_mut(&mut self) -> Content<true> {
245        Content {
246            ptr: self
247                .ptr_allocated()
248                .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
249            ndim: self.ndim,
250        }
251    }
252
253    /// Create a new ArrayLayout with the given dimensions.
254    #[inline]
255    fn with_ndim(ndim: usize) -> Self {
256        Self {
257            ndim,
258            content: if ndim <= N {
259                Union {
260                    _inlined: (0, [0; N], [0; N]),
261                }
262            } else {
263                Union {
264                    ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) },
265                }
266            },
267        }
268    }
269}
270
/// Raw view over an array layout's metadata words.
///
/// The `1 + 2 * ndim` words starting at `ptr` are laid out as
/// `[offset, shape[0], .., shape[ndim - 1], strides[0], .., strides[ndim - 1]]`.
/// `MUT` controls whether the setter methods are available.
struct Content<const MUT: bool> {
    ptr: NonNull<usize>,
    ndim: usize,
}

impl<const MUT: bool> Content<MUT> {
    /// All metadata words (offset, shape, strides) as raw `usize` values.
    #[inline]
    fn as_slice(&self) -> &[usize] {
        let word_count = 1 + self.ndim * 2;
        unsafe { from_raw_parts(self.ptr.as_ptr(), word_count) }
    }

    /// Word 0, read as a signed offset.
    #[inline]
    fn offset(&self) -> isize {
        let slot = self.ptr.as_ptr().cast::<isize>();
        unsafe { slot.read() }
    }

    /// Words `1..=ndim`. The lifetime is unbounded on purpose so that callers
    /// can tie it to the owning `ArrayLayout` instead of this temporary view.
    #[inline]
    fn shape<'a>(&self) -> &'a [usize] {
        unsafe {
            let first_dim = self.ptr.as_ptr().add(1);
            from_raw_parts(first_dim, self.ndim)
        }
    }

    /// Words `ndim + 1..=2 * ndim`, read as signed strides (unbounded lifetime, see `shape`).
    #[inline]
    fn strides<'a>(&self) -> &'a [isize] {
        unsafe {
            let first_stride = self.ptr.as_ptr().add(1 + self.ndim).cast::<isize>();
            from_raw_parts(first_stride, self.ndim)
        }
    }
}

impl Content<true> {
    /// Writes the offset into word 0.
    #[inline]
    fn set_offset(&mut self, val: isize) {
        let slot = self.ptr.as_ptr().cast::<isize>();
        unsafe { slot.write(val) }
    }

    /// Writes the length of dimension `idx`; panics if `idx` is out of range.
    #[inline]
    fn set_shape(&mut self, idx: usize, val: usize) {
        assert!(idx < self.ndim);
        let slot = self.ptr.as_ptr().wrapping_add(1 + idx);
        unsafe { slot.write(val) }
    }

    /// Writes the stride of dimension `idx`; panics if `idx` is out of range.
    #[inline]
    fn set_stride(&mut self, idx: usize, val: isize) {
        assert!(idx < self.ndim);
        let slot = self.ptr.as_ptr().wrapping_add(1 + idx + self.ndim).cast::<isize>();
        unsafe { slot.write(val) }
    }

    /// Copies a full shape; panics unless `val` has exactly `ndim` entries.
    #[inline]
    fn copy_shape(&mut self, val: &[usize]) {
        assert!(val.len() == self.ndim);
        unsafe {
            let dst = self.ptr.as_ptr().add(1);
            copy_nonoverlapping(val.as_ptr(), dst, self.ndim)
        }
    }

    /// Copies full strides; panics unless `val` has exactly `ndim` entries.
    #[inline]
    fn copy_strides(&mut self, val: &[isize]) {
        assert!(val.len() == self.ndim);
        unsafe {
            let dst = self.ptr.as_ptr().add(1 + self.ndim).cast::<isize>();
            copy_nonoverlapping(val.as_ptr(), dst, self.ndim)
        }
    }
}
334
/// Memory layout of a heap metadata buffer holding `1 + 2 * ndim` words
/// (the offset, `ndim` shape entries and `ndim` strides).
///
/// # Panics
///
/// Panics if the word count or the byte size overflows.
#[inline]
fn layout(ndim: usize) -> Layout {
    // Checked arithmetic: in release builds `1 + ndim * 2` would otherwise wrap
    // and under-allocate a buffer that `1 + 2 * ndim` words are later written into.
    ndim.checked_mul(2)
        .and_then(|n| n.checked_add(1))
        .and_then(|words| Layout::array::<usize>(words).ok())
        .expect("array layout metadata size overflows")
}
339
#[test]
fn test_new() {
    let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
    assert_eq!(layout.offset(), 20);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[12, -4, 1]);
    assert_eq!(layout.ndim(), 3);
}

#[test]
fn test_new_contiguous_little_endian() {
    // Little endian: the first dimension is contiguous.
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
    assert_eq!(layout.offset(), 0);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[4, 8, 24]);
}

#[test]
fn test_new_contiguous_big_endian() {
    // Big endian: the last dimension is contiguous (previously this test
    // accidentally used `LittleEndian` and duplicated the test above).
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4);
    assert_eq!(layout.offset(), 0);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[48, 16, 4]);
}
364
#[test]
#[should_panic(expected = "shape and strides must have the same length")]
fn test_new_invalid_shape_strides_length() {
    // Two shape entries but three strides: `new` must reject the mismatch.
    let shape: [usize; 2] = [2, 3];
    let strides: [isize; 3] = [12, -4, 1];
    let _ = ArrayLayout::<4>::new(&shape, &strides, 20);
}

#[test]
fn test_to_inline_size() {
    let word = size_of::<usize>();

    let wide = ArrayLayout::<4>::new_contiguous(&[3, 4], Endian::BigEndian, 0);
    assert_eq!(size_of_val(&wide), (2 * 4 + 2) * word);

    let narrow = wide.to_inline_size::<2>();
    assert_eq!(size_of_val(&narrow), (2 * 2 + 2) * word);
}
378
#[test]
fn test_num_elements() {
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 20);
    let count = layout.num_elements();
    assert_eq!(count, 2 * 3 * 4);
}

#[test]
fn test_element_offset_little_endian() {
    // strides [4, 8, 24]; 22 -> indices (0, 2, 3) -> 0 * 4 + 2 * 8 + 3 * 24 = 88
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
    let offset = layout.element_offset(22, Endian::LittleEndian);
    assert_eq!(offset, 88);
}

#[test]
fn test_element_offset_big_endian() {
    // strides [48, 16, 4]; 22 -> indices (1, 2, 2) -> 1 * 48 + 2 * 16 + 2 * 4 = 88
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4);
    let offset = layout.element_offset(22, Endian::BigEndian);
    assert_eq!(offset, 88);
}
396
#[test]
fn test_data_range_positive_strides() {
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
    let range = layout.data_range();
    // end = 0 + (2 - 1) * 4 + (3 - 1) * 8 + (4 - 1) * 24 = 92
    assert_eq!(range, 0..=92);
}

#[test]
fn test_data_range_mixed_strides() {
    let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 0], 20);
    let range = layout.data_range();
    // start = 20 + (3 - 1) * -4 = 12, end = 20 + (2 - 1) * 12 = 32; stride 0 adds nothing
    assert_eq!(range, 12..=32);
}
410
#[test]
fn test_clone_and_eq() {
    let layout1 = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
    let layout2 = layout1.clone();
    assert!(layout1.eq(&layout2));

    // Heap-stored metadata (ndim > N) must clone and compare the same way.
    let heap1 = ArrayLayout::<2>::new(&[2, 3, 4], &[12, -4, 1], 20);
    let heap2 = heap1.clone();
    assert_eq!(heap2.offset(), 20);
    assert_eq!(heap2.shape(), &[2, 3, 4]);
    assert_eq!(heap2.strides(), &[12, -4, 1]);
    assert!(heap1 == heap2);
    assert!(heap1 != ArrayLayout::<2>::new(&[2, 3, 4], &[12, -4, 1], 0));
}

#[test]
fn test_drop() {
    let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
    drop(layout);
    // Also exercise the deallocation path for heap-stored metadata.
    let layout = ArrayLayout::<2>::new(&[2, 3, 4], &[12, -4, 1], 20);
    drop(layout);
}