1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
/// An N-dimensional array layout: an offset plus a length and a stride for every dimension.
///
/// Layouts with at most `N` dimensions keep their data inline; layouts with more dimensions
/// store it in a heap buffer that the layout owns and frees when dropped.
pub struct ArrayLayout<const N: usize> {
    /// Number of dimensions; `ndim <= N` selects inline storage, `ndim > N` the heap buffer.
    ndim: usize,
    /// Packed words `[offset, shape[0..ndim], strides[0..ndim]]`, inline or behind a pointer.
    content: Union<N>,
}

// SAFETY: the only field that is not automatically `Send` is the heap pointer in `Union`.
// That buffer is allocated per instance, never shared between instances (`Clone` builds a
// fresh layout), freed exactly once in `Drop`, and only written through `&mut self`.
unsafe impl<const N: usize> Send for ArrayLayout<N> {}
// SAFETY: `&self` methods only read the packed words; there is no interior mutability.
unsafe impl<const N: usize> Sync for ArrayLayout<N> {}

/// Storage for an [`ArrayLayout`]: either inline words or a pointer to a heap buffer.
///
/// `#[repr(C)]` guarantees that every field starts at offset 0. `ArrayLayout::content`
/// relies on this when it treats the union's own address as the first packed word.
#[repr(C)]
union Union<const N: usize> {
    /// Heap buffer of `1 + 2 * ndim` words; active when `ndim > N`.
    ptr: NonNull<usize>,
    /// Inline buffer of `1 + 2 * N` words; active when `ndim <= N`. It is only ever accessed
    /// as raw, word-aligned storage, so the tuple's internal field order does not matter.
    _inlined: (isize, [usize; N], [isize; N]),
}
17
18impl<const N: usize> Clone for ArrayLayout<N> {
19 #[inline]
20 fn clone(&self) -> Self {
21 Self::new(self.shape(), self.strides(), self.offset())
22 }
23}
24
25impl<const N: usize> PartialEq for ArrayLayout<N> {
26 #[inline]
27 fn eq(&self, other: &Self) -> bool {
28 self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice()
29 }
30}
31
32impl<const N: usize> Eq for ArrayLayout<N> {}
33
34impl<const N: usize> Drop for ArrayLayout<N> {
35 fn drop(&mut self) {
36 if let Some(ptr) = self.ptr_allocated() {
37 unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) }
38 }
39 }
40}
41
/// Order in which dimensions are traversed when a layout is treated as a flat sequence.
///
/// Used by [`ArrayLayout::new_contiguous`] to decide which dimension is contiguous, and by
/// [`ArrayLayout::element_offset`] to decompose a linear index.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Endian {
    /// The first dimension is the most significant: the last dimension varies fastest and
    /// is the contiguous one (row-major order).
    BigEndian,
    /// The first dimension is the least significant: it varies fastest and is the
    /// contiguous one (column-major order).
    LittleEndian,
}
50
51impl<const N: usize> ArrayLayout<N> {
52 pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self {
62 assert_eq!(
64 shape.len(),
65 strides.len(),
66 "shape and strides must have the same length"
67 );
68
69 let mut ans = Self::with_ndim(shape.len());
70 let mut content = ans.content_mut();
71 content.set_offset(offset);
72 content.copy_shape(shape);
73 content.copy_strides(strides);
74 ans
75 }
76
77 pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self {
87 let mut ans = Self::with_ndim(shape.len());
88 let mut content = ans.content_mut();
89 content.set_offset(0);
90 content.copy_shape(shape);
91 let mut mul = element_size as isize;
92 let push = |i| {
93 content.set_stride(i, mul);
94 mul *= shape[i] as isize;
95 };
96 match endian {
98 Endian::BigEndian => (0..shape.len()).rev().for_each(push),
99 Endian::LittleEndian => (0..shape.len()).for_each(push),
100 }
101 ans
102 }
103
104 #[inline]
106 pub const fn ndim(&self) -> usize {
107 self.ndim
108 }
109
110 #[inline]
112 pub fn offset(&self) -> isize {
113 self.content().offset()
114 }
115
116 #[inline]
118 pub fn shape(&self) -> &[usize] {
119 self.content().shape()
120 }
121
122 #[inline]
124 pub fn strides(&self) -> &[isize] {
125 self.content().strides()
126 }
127
128 pub fn to_inline_size<const M: usize>(&self) -> ArrayLayout<M> {
139 ArrayLayout::new(self.shape(), self.strides(), self.offset())
140 }
141
142 #[inline]
150 pub fn num_elements(&self) -> usize {
151 self.shape().iter().product()
152 }
153
154 pub fn element_offset(&self, index: usize, endian: Endian) -> isize {
162 fn offset_forwards(
163 mut rem: usize,
164 shape: impl IntoIterator<Item = usize>,
165 strides: impl IntoIterator<Item = isize>,
166 ) -> isize {
167 let mut ans = 0;
168 for (d, s) in zip(shape, strides) {
169 ans += s * (rem % d) as isize;
170 rem /= d
171 }
172 ans
173 }
174
175 let shape = self.shape().iter().cloned();
176 let strides = self.strides().iter().cloned();
177 self.offset()
178 + match endian {
179 Endian::BigEndian => offset_forwards(index, shape.rev(), strides.rev()),
180 Endian::LittleEndian => offset_forwards(index, shape, strides),
181 }
182 }
183
184 pub fn data_range(&self) -> RangeInclusive<isize> {
193 let content = self.content();
194 let mut start = content.offset();
195 let mut end = content.offset();
196 for (&d, s) in zip(content.shape(), content.strides()) {
197 use std::cmp::Ordering::{Equal, Greater, Less};
198 let i = d as isize - 1;
199 match s.cmp(&0) {
200 Equal => {}
201 Less => start += s * i,
202 Greater => end += s * i,
203 }
204 }
205 start..=end
206 }
207}
208
209mod fmt;
210mod transform;
211pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg};
212
213use std::{
214 alloc::{Layout, alloc, dealloc},
215 iter::zip,
216 ops::RangeInclusive,
217 ptr::{NonNull, copy_nonoverlapping},
218 slice::from_raw_parts,
219};
220
221impl<const N: usize> ArrayLayout<N> {
222 #[inline]
223 fn ptr_allocated(&self) -> Option<NonNull<usize>> {
224 const { assert!(N > 0) }
225 if self.ndim > N {
227 Some(unsafe { self.content.ptr })
228 } else {
229 None
230 }
231 }
232
233 #[inline]
234 fn content(&self) -> Content<false> {
235 Content {
236 ptr: self
237 .ptr_allocated()
238 .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
239 ndim: self.ndim,
240 }
241 }
242
243 #[inline]
244 fn content_mut(&mut self) -> Content<true> {
245 Content {
246 ptr: self
247 .ptr_allocated()
248 .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
249 ndim: self.ndim,
250 }
251 }
252
253 #[inline]
255 fn with_ndim(ndim: usize) -> Self {
256 Self {
257 ndim,
258 content: if ndim <= N {
259 Union {
260 _inlined: (0, [0; N], [0; N]),
261 }
262 } else {
263 Union {
264 ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) },
265 }
266 },
267 }
268 }
269}
270
/// Unowned view of a packed layout buffer of `1 + 2 * ndim` words, arranged as
/// `[offset, shape[0], .., shape[ndim - 1], strides[0], .., strides[ndim - 1]]`.
///
/// The pointer carries no lifetime. Whoever builds the view must keep the buffer alive for
/// as long as the view, or any slice obtained from it, is used. `MUT = true` unlocks the
/// setters, and the pointer must then permit writes.
struct Content<const MUT: bool> {
    ptr: NonNull<usize>,
    ndim: usize,
}

impl<const MUT: bool> Content<MUT> {
    /// All packed words; offset and strides appear reinterpreted as `usize`.
    #[inline]
    fn as_slice(&self) -> &[usize] {
        let len = 2 * self.ndim + 1;
        // SAFETY: the buffer holds `1 + 2 * ndim` initialised words.
        unsafe { from_raw_parts(self.ptr.as_ptr(), len) }
    }

    /// Word 0, read as the signed offset.
    #[inline]
    fn offset(&self) -> isize {
        let first = self.ptr.cast::<isize>();
        // SAFETY: word 0 is initialised, and `isize` has the size and alignment of `usize`.
        unsafe { first.read() }
    }

    /// Words `1 ..= ndim`: the length of each dimension.
    #[inline]
    fn shape<'a>(&self) -> &'a [usize] {
        // SAFETY: words `1 ..= ndim` lie inside the buffer and are initialised.
        unsafe {
            let start = self.ptr.as_ptr().add(1);
            from_raw_parts(start, self.ndim)
        }
    }

    /// Words `1 + ndim ..= 2 * ndim`, read as signed strides.
    #[inline]
    fn strides<'a>(&self) -> &'a [isize] {
        // SAFETY: the trailing `ndim` words lie inside the buffer and are initialised.
        unsafe {
            let start = self.ptr.as_ptr().add(1 + self.ndim).cast::<isize>();
            from_raw_parts(start, self.ndim)
        }
    }
}

impl Content<true> {
    /// Stores the offset in word 0.
    #[inline]
    fn set_offset(&mut self, val: isize) {
        let first = self.ptr.cast::<isize>();
        // SAFETY: word 0 is inside the writable buffer.
        unsafe { first.write(val) }
    }

    /// Stores the length of dimension `idx`.
    #[inline]
    fn set_shape(&mut self, idx: usize, val: usize) {
        assert!(idx < self.ndim);
        let slot = 1 + idx;
        // SAFETY: `idx < ndim`, so `slot` addresses a shape word inside the buffer.
        unsafe { self.ptr.as_ptr().add(slot).write(val) }
    }

    /// Stores the stride of dimension `idx`.
    #[inline]
    fn set_stride(&mut self, idx: usize, val: isize) {
        assert!(idx < self.ndim);
        let slot = 1 + idx + self.ndim;
        // SAFETY: `idx < ndim`, so `slot` addresses a stride word inside the buffer.
        unsafe { self.ptr.as_ptr().add(slot).cast::<isize>().write(val) }
    }

    /// Overwrites every dimension length with `val`.
    #[inline]
    fn copy_shape(&mut self, val: &[usize]) {
        assert!(val.len() == self.ndim);
        // SAFETY: the destination spans exactly the `ndim` shape words, and `val` is a
        // separate borrowed slice, so the regions cannot overlap.
        unsafe {
            let dst = self.ptr.as_ptr().add(1);
            copy_nonoverlapping(val.as_ptr(), dst, val.len())
        }
    }

    /// Overwrites every stride with `val`.
    #[inline]
    fn copy_strides(&mut self, val: &[isize]) {
        assert!(val.len() == self.ndim);
        // SAFETY: the destination spans exactly the `ndim` stride words, and `val` is a
        // separate borrowed slice, so the regions cannot overlap.
        unsafe {
            let dst = self.ptr.as_ptr().add(1 + self.ndim).cast::<isize>();
            copy_nonoverlapping(val.as_ptr(), dst, val.len())
        }
    }
}
334
/// Allocation layout for the heap buffer of an `ndim`-dimensional layout:
/// `1 + 2 * ndim` `usize` words (offset, shape, strides).
///
/// Panics if the total byte size does not fit in `isize`.
#[inline]
fn layout(ndim: usize) -> Layout {
    let words = 1 + ndim * 2;
    Layout::array::<usize>(words).unwrap()
}
339
#[test]
fn test_new() {
    let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
    assert_eq!(layout.offset(), 20);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[12, -4, 1]);
    assert_eq!(layout.ndim(), 3);
}

#[test]
fn test_new_spills_to_heap() {
    // Three dimensions exceed the inline capacity of two, so the words live on the heap.
    let layout = ArrayLayout::<2>::new(&[2, 3, 4], &[12, -4, 1], 20);
    assert_eq!(layout.ndim(), 3);
    assert_eq!(layout.offset(), 20);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[12, -4, 1]);
    let copy = layout.clone();
    assert!(copy == layout);
}

#[test]
fn test_new_contiguous_little_endian() {
    // The first dimension is contiguous, so strides grow from front to back.
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
    assert_eq!(layout.offset(), 0);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[4, 8, 24]);
}

#[test]
fn test_new_contiguous_big_endian() {
    // The last dimension is contiguous, so strides grow from back to front.
    let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4);
    assert_eq!(layout.offset(), 0);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[48, 16, 4]);
}
364
#[test]
#[should_panic(expected = "shape and strides must have the same length")]
fn test_new_invalid_shape_strides_length() {
    // Two lengths but three strides: construction must be rejected.
    let _ = ArrayLayout::<4>::new(&[2, 3], &[12, -4, 1], 20);
}

#[test]
fn test_to_inline_size() {
    let word = size_of::<usize>();
    let wide = ArrayLayout::<4>::new_contiguous(&[3, 4], Endian::BigEndian, 0);
    assert_eq!(size_of_val(&wide), (2 * 4 + 2) * word);
    let narrow = wide.to_inline_size::<2>();
    assert_eq!(size_of_val(&narrow), (2 * 2 + 2) * word);
}
378
#[test]
fn test_num_elements() {
    let dense = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 20);
    assert_eq!(dense.num_elements(), 2 * 3 * 4);
}

#[test]
fn test_element_offset_little_endian() {
    let dense = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
    let offset = dense.element_offset(22, Endian::LittleEndian);
    assert_eq!(offset, 88);
}

#[test]
fn test_element_offset_big_endian() {
    let dense = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4);
    let offset = dense.element_offset(22, Endian::BigEndian);
    assert_eq!(offset, 88);
}
396
#[test]
fn test_data_range_positive_strides() {
    let dense = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
    assert_eq!(dense.data_range(), 0..=92);
}

#[test]
fn test_data_range_mixed_strides() {
    // Negative stride lowers the start, positive raises the end, zero is ignored.
    let mixed = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 0], 20);
    assert_eq!(mixed.data_range(), 12..=32);
}
410
#[test]
fn test_clone_and_eq() {
    let original = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
    let copy = original.clone();
    assert!(original == copy);
}

#[test]
fn test_drop() {
    {
        let _layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
        // `_layout` is dropped at the end of this scope.
    }
}