ndarray_conv/lib.rs
1//! `ndarray-conv` provides N-dimensional convolution operations for `ndarray` arrays.
2//!
3//! This crate extends the `ndarray` library with both standard and
4//! FFT-accelerated convolution methods.
5//!
6//! # Getting Started
7//!
8//! To start performing convolutions, you'll interact with the following:
9//!
10//! 1. **Input Arrays:** Use `ndarray`'s [`Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
11//! or [`ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)
12//! as your input data and convolution kernel.
//! 2. **Convolution Methods:** Call `array.conv(...)` or `array.conv_fft(...)`.
//!    These methods ([`ConvExt::conv`] and [`ConvFFTExt::conv_fft`]) are added
//!    to `ArrayBase` types via the traits [`ConvExt`] and [`ConvFFTExt`].
16//! 3. **Convolution Mode:** [`ConvMode`] specifies the size of the output.
17//! 4. **Padding Mode:** [`PaddingMode`] specifies how to handle array boundaries.
18//!
19//! # Basic Example:
20//!
21//! Here's a simple example of how to perform a 2D convolution using `ndarray-conv`:
22//!
23//! ```rust
24//! use ndarray::prelude::*;
25//! use ndarray_conv::{ConvExt, ConvFFTExt, ConvMode, PaddingMode};
26//!
27//! // Input data
28//! let input = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
29//!
30//! // Convolution kernel
31//! let kernel = array![[1, 1], [1, 1]];
32//!
33//! // Perform standard convolution with "same" output size and zero padding
34//! let output = input.conv(
35//! &kernel,
36//! ConvMode::Same,
37//! PaddingMode::Zeros,
38//! ).unwrap();
39//!
40//! println!("Standard Convolution Output:\n{:?}", output);
41//!
42//! // Perform FFT-accelerated convolution with "same" output size and zero padding
43//! let output_fft = input.map(|&x| x as f32).conv_fft(
44//! &kernel.map(|&x| x as f32),
45//! ConvMode::Same,
46//! PaddingMode::Zeros,
47//! ).unwrap();
48//!
49//! println!("FFT Convolution Output:\n{:?}", output_fft);
50//! ```
51//!
52//! # Choosing a convolution method
53//!
//! * Use [`ConvExt::conv`] for standard convolution.
//! * Use [`ConvFFTExt::conv_fft`] for FFT-accelerated convolution.
//!   FFT-accelerated convolution is generally faster for larger kernels, but
//!   standard convolution may be faster for smaller kernels.
58//!
59//! # Key Structs, Enums and Traits
60//!
61//! * [`ConvMode`]: Specifies how to determine the size of the convolution output (e.g., `Full`, `Same`, `Valid`).
62//! * [`PaddingMode`]: Specifies how to handle array boundaries (e.g., `Zeros`, `Reflect`, `Replicate`). You can also use `PaddingMode::Custom` or `PaddingMode::Explicit` to combine different [`BorderType`] strategies for each dimension or for each side of each dimension.
63//! * [`BorderType`]: Used with [`PaddingMode`] for `Custom` and `Explicit`, specifies the padding strategy (e.g., `Zeros`, `Reflect`, `Replicate`, `Circular`).
64//! * [`ConvExt`]: The trait that adds the `conv` method, extending `ndarray` arrays with standard convolution functionality.
65//! * [`ConvFFTExt`]: The trait that adds the `conv_fft` method, extending `ndarray` arrays with FFT-accelerated convolution functionality.
66
67mod conv;
68mod conv_fft;
69mod dilation;
70mod padding;
71
72pub(crate) use padding::ExplicitPadding;
73
74pub use conv::ConvExt;
75pub use conv_fft::{ConvFFTExt, Processor as FftProcessor};
76pub use dilation::{WithDilation, ReverseKernel};
77
/// Specifies the convolution mode, which determines the output size.
///
/// `N` is the number of dimensions: the `padding` and `strides` arrays of the
/// [`ConvMode::Custom`] and [`ConvMode::Explicit`] variants hold one entry per
/// dimension.
///
/// The type is `Copy` and supports equality comparison, so modes can be
/// checked with `==` or `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvMode<const N: usize> {
    /// The output has the largest size, including all positions where
    /// the kernel and input overlap at least partially.
    Full,
    /// The output has the same size as the input.
    Same,
    /// The output has the smallest size, including only positions
    /// where the kernel and input fully overlap.
    Valid,
    /// Specifies custom padding and strides, with one padding amount per
    /// dimension.
    Custom {
        /// The padding to use for each dimension.
        // NOTE(review): presumably the same amount is applied to both sides
        // of a dimension; confirm against the padding implementation.
        padding: [usize; N],
        /// The strides to use for each dimension.
        strides: [usize; N],
    },
    /// Specifies explicit padding and strides, with a separate padding amount
    /// for each side of each dimension.
    Explicit {
        /// The padding to use for each side of each dimension.
        // NOTE(review): the order of the two entries (which side comes
        // first) is not visible here; confirm against the padding implementation.
        padding: [[usize; 2]; N],
        /// The strides to use for each dimension.
        strides: [usize; N],
    },
}
104/// Specifies the padding mode, which determines how to handle borders.
105///
106/// The padding mode can be either a single `BorderType` applied on all sides
107/// or a custom tuple of two `BorderTypes` for each dimension or a `BorderType`
108/// for each side of each dimension.
109#[derive(Debug, Clone, Copy)]
110pub enum PaddingMode<const N: usize, T: num::traits::NumAssign + Copy> {
111 /// Pads with zeros.
112 Zeros,
113 /// Pads with a constant value.
114 Const(T),
115 /// Reflects the input at the borders.
116 Reflect,
117 /// Replicates the edge values.
118 Replicate,
119 /// Treats the input as a circular buffer.
120 Circular,
121 /// Specifies a different `BorderType` for each dimension.
122 Custom([BorderType<T>; N]),
123 /// Specifies a different `BorderType` for each side of each dimension.
124 Explicit([[BorderType<T>; 2]; N]),
125}
126
127/// Used with [`PaddingMode`]. Specifies the padding mode for a single dimension
128/// or a single side of a dimension.
129#[derive(Debug, Clone, Copy)]
130pub enum BorderType<T: num::traits::NumAssign + Copy> {
131 /// Pads with zeros.
132 Zeros,
133 /// Pads with a constant value.
134 Const(T),
135 /// Reflects the input at the borders.
136 Reflect,
137 /// Replicates the edge values.
138 Replicate,
139 /// Treats the input as a circular buffer.
140 Circular,
141}
142
143use thiserror::Error;
144
145/// Error type for convolution operations.
146#[derive(Error, Debug)]
147pub enum Error<const N: usize> {
148 /// Indicates that the input data array has a dimension with zero size.
149 #[error("Data shape shouldn't have ZERO. {0:?}")]
150 DataShape(ndarray::Dim<[ndarray::Ix; N]>),
151 /// Indicates that the kernel array has a dimension with zero size.
152 #[error("Kernel shape shouldn't have ZERO. {0:?}")]
153 KernelShape(ndarray::Dim<[ndarray::Ix; N]>),
154 /// Indicates that the shape of the kernel with dilation is not compatible with the chosen `ConvMode`.
155 #[error("ConvMode {0:?} does not match KernelWithDilation Size {1:?}")]
156 MismatchShape(ConvMode<N>, [ndarray::Ix; N]),
157}