datafusion_common/
test_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions to make testing DataFusion based crates easier
19
20use crate::arrow::util::pretty::pretty_format_batches_with_options;
21use crate::format::DEFAULT_FORMAT_OPTIONS;
22use arrow::array::RecordBatch;
23use std::{error::Error, path::PathBuf};
24
/// Compares the pretty-formatted output of record batches with an expected
/// list of lines, panicking with a readable diff when they differ.
/// This is a macro so errors appear on the correct line.
///
/// Designed so that failure output can be directly copy/pasted
/// into the test code as expected results.
///
/// Expects to be called about like this:
///
/// `assert_batches_eq!(expected_lines: &[&str], batches: &[RecordBatch])`
///
/// The macro expands to a single block, so it can be used both as a
/// statement and in expression position (for example as a closure body).
///
/// # Panics
///
/// Panics if the batches cannot be formatted, or if the formatted lines do
/// not match `expected_lines` exactly (same lines, same order).
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use arrow::record_batch::RecordBatch;
/// # use arrow::array::{ArrayRef, Int32Array};
/// # use datafusion_common::assert_batches_eq;
/// let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let batch = RecordBatch::try_from_iter([("column", col)]).unwrap();
/// // Expected output is a vec of strings
/// let expected = vec![
///     "+--------+",
///     "| column |",
///     "+--------+",
///     "| 1      |",
///     "| 2      |",
///     "+--------+",
/// ];
/// // compare the formatted output of the record batch with the expected output
/// assert_batches_eq!(expected, &[batch]);
/// ```
#[macro_export]
macro_rules! assert_batches_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {{
        let expected_lines: Vec<String> =
            $EXPECTED_LINES.iter().map(|&s| s.into()).collect();

        let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options(
            $CHUNKS,
            &$crate::format::DEFAULT_FORMAT_OPTIONS,
        )
        .unwrap()
        .to_string();

        // Surrounding whitespace is trimmed before the output is split into lines.
        let actual_lines: Vec<&str> = formatted.trim().lines().collect();

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    }};
}
78
79pub fn batches_to_string(batches: &[RecordBatch]) -> String {
80    let actual = pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)
81        .unwrap()
82        .to_string();
83
84    actual.trim().to_string()
85}
86
87pub fn batches_to_sort_string(batches: &[RecordBatch]) -> String {
88    let actual_lines =
89        pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)
90            .unwrap()
91            .to_string();
92
93    let mut actual_lines: Vec<&str> = actual_lines.trim().lines().collect();
94
95    // sort except for header + footer
96    let num_lines = actual_lines.len();
97    if num_lines > 3 {
98        actual_lines.as_mut_slice()[2..num_lines - 1].sort_unstable()
99    }
100
101    actual_lines.join("\n")
102}
103
/// Compares the pretty-formatted output of record batches with an expected
/// list of lines in a way that row order does not matter.
/// This is a macro so errors appear on the correct line.
///
/// Both the expected and the actual lines are sorted, except for the first
/// two lines and the last line, before they are compared.
///
/// See [`assert_batches_eq`] for more details and example.
///
/// Expects to be called about like this:
///
/// `assert_batches_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])`
///
/// The macro expands to a single block, so it can be used both as a
/// statement and in expression position (for example as a closure body).
///
/// # Panics
///
/// Panics if the batches cannot be formatted, or if the sorted lines differ.
#[macro_export]
macro_rules! assert_batches_sorted_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {{
        let mut expected_lines: Vec<String> =
            $EXPECTED_LINES.iter().map(|&s| s.into()).collect();

        // sort except for header + footer
        let num_lines = expected_lines.len();
        if num_lines > 3 {
            expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable()
        }

        let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options(
            $CHUNKS,
            &$crate::format::DEFAULT_FORMAT_OPTIONS,
        )
        .unwrap()
        .to_string();

        let mut actual_lines: Vec<&str> = formatted.trim().lines().collect();

        // sort except for header + footer
        let num_lines = actual_lines.len();
        if num_lines > 3 {
            actual_lines.as_mut_slice()[2..num_lines - 1].sort_unstable()
        }

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    }};
}
148
/// A macro to assert that one string is contained within another with
/// a nice error message if it is not.
///
/// Usage: `assert_contains!(actual, expected)`
///
/// Is a macro so test error
/// messages are on the same line as the failure;
///
/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>)
///
/// The macro expands to a single block, so it can be used both as a
/// statement and in expression position (for example as a closure body).
///
/// # Panics
///
/// Panics if `expected` is not a substring of `actual`; the message shows
/// both values.
#[macro_export]
macro_rules! assert_contains {
    ($ACTUAL: expr, $EXPECTED: expr) => {{
        let actual_value: String = $ACTUAL.into();
        let expected_value: String = $EXPECTED.into();
        assert!(
            actual_value.contains(&expected_value),
            "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
            expected_value,
            actual_value
        );
    }};
}
171
/// A macro to assert that one string is NOT contained within another with
/// a nice error message if it is.
///
/// Usage: `assert_not_contains!(actual, unexpected)`
///
/// Is a macro so test error
/// messages are on the same line as the failure;
///
/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>)
///
/// The macro expands to a single block, so it can be used both as a
/// statement and in expression position (for example as a closure body).
///
/// # Panics
///
/// Panics if `unexpected` is a substring of `actual`; the message shows
/// both values.
#[macro_export]
macro_rules! assert_not_contains {
    ($ACTUAL: expr, $UNEXPECTED: expr) => {{
        let actual_value: String = $ACTUAL.into();
        let unexpected_value: String = $UNEXPECTED.into();
        assert!(
            !actual_value.contains(&unexpected_value),
            "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
            unexpected_value,
            actual_value
        );
    }};
}
194
195/// Returns the datafusion test data directory, which is by default rooted at `datafusion/core/tests/data`.
196///
197/// The default can be overridden by the optional environment
198/// variable `DATAFUSION_TEST_DATA`
199///
200/// panics when the directory can not be found.
201///
202/// Example:
203/// ```
204/// let testdata = datafusion_common::test_util::datafusion_test_data();
205/// let csvdata = format!("{}/window_1.csv", testdata);
206/// assert!(std::path::PathBuf::from(csvdata).exists());
207/// ```
208pub fn datafusion_test_data() -> String {
209    match get_data_dir("DATAFUSION_TEST_DATA", "../../datafusion/core/tests/data") {
210        Ok(pb) => pb.display().to_string(),
211        Err(err) => panic!("failed to get arrow data dir: {err}"),
212    }
213}
214
215/// Returns the arrow test data directory, which is by default stored
216/// in a git submodule rooted at `testing/data`.
217///
218/// The default can be overridden by the optional environment
219/// variable `ARROW_TEST_DATA`
220///
221/// panics when the directory can not be found.
222///
223/// Example:
224/// ```
225/// let testdata = datafusion_common::test_util::arrow_test_data();
226/// let csvdata = format!("{}/csv/aggregate_test_100.csv", testdata);
227/// assert!(std::path::PathBuf::from(csvdata).exists());
228/// ```
229pub fn arrow_test_data() -> String {
230    match get_data_dir("ARROW_TEST_DATA", "../../testing/data") {
231        Ok(pb) => pb.display().to_string(),
232        Err(err) => panic!("failed to get arrow data dir: {err}"),
233    }
234}
235
/// Returns the parquet test data directory, which by default lives in a git
/// submodule rooted at `parquet-testing/data`.
///
/// Set the optional environment variable `PARQUET_TEST_DATA` to override
/// the default location.
///
/// # Panics
///
/// Panics when the directory can not be found.
///
/// # Example
/// ```
/// let testdata = datafusion_common::test_util::parquet_test_data();
/// let filename = format!("{}/binary.parquet", testdata);
/// assert!(std::path::PathBuf::from(filename).exists());
/// ```
#[cfg(feature = "parquet")]
pub fn parquet_test_data() -> String {
    get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data")
        .map(|dir| dir.display().to_string())
        .unwrap_or_else(|err| panic!("failed to get parquet data dir: {err}"))
}
258
259/// Returns a directory path for finding test data.
260///
261/// udf_env: name of an environment variable
262///
263/// submodule_dir: fallback path (relative to CARGO_MANIFEST_DIR)
264///
265///  Returns either:
266/// The path referred to in `udf_env` if that variable is set and refers to a directory
267/// The submodule_data directory relative to CARGO_MANIFEST_PATH
268pub fn get_data_dir(
269    udf_env: &str,
270    submodule_data: &str,
271) -> Result<PathBuf, Box<dyn Error>> {
272    // Try user defined env.
273    if let Ok(dir) = std::env::var(udf_env) {
274        let trimmed = dir.trim().to_string();
275        if !trimmed.is_empty() {
276            let pb = PathBuf::from(trimmed);
277            if pb.is_dir() {
278                return Ok(pb);
279            } else {
280                return Err(format!(
281                    "the data dir `{}` defined by env {} not found",
282                    pb.display(),
283                    udf_env
284                )
285                .into());
286            }
287        }
288    }
289
290    // The env is undefined or its value is trimmed to empty, let's try default dir.
291
292    // env "CARGO_MANIFEST_DIR" is "the directory containing the manifest of your package",
293    // set by `cargo run` or `cargo test`, see:
294    // https://doc.rust-lang.org/cargo/reference/environment-variables.html
295    let dir = env!("CARGO_MANIFEST_DIR");
296
297    let pb = PathBuf::from(dir).join(submodule_data);
298    if pb.is_dir() {
299        Ok(pb)
300    } else {
301        Err(format!(
302            "env `{}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\
303             HINT: try running `git submodule update --init`",
304            udf_env,
305            pb.display(),
306        ).into())
307    }
308}
309
/// Creates an `Arc`-wrapped arrow array of the named type from `$values`.
///
/// Usage: `create_array!(Int32, vec![1, 2, 3])`
///
/// The first argument selects the array type (`Boolean`, `Int8` .. `Int64`,
/// `UInt8` .. `UInt64`, `Float16` .. `Float64` or `Utf8`); the second is
/// handed to that array type's `From` implementation.
///
/// Arrow types are referenced through `$crate`, so the calling crate does
/// not need its own `arrow` dependency to use this macro.
#[macro_export]
macro_rules! create_array {
    (Boolean, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::BooleanArray::from($values))
    };
    (Int8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int8Array::from($values))
    };
    (Int16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int16Array::from($values))
    };
    (Int32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int32Array::from($values))
    };
    (Int64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int64Array::from($values))
    };
    (UInt8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt8Array::from($values))
    };
    (UInt16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt16Array::from($values))
    };
    (UInt32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt32Array::from($values))
    };
    (UInt64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt64Array::from($values))
    };
    (Float16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float16Array::from($values))
    };
    (Float32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float32Array::from($values))
    };
    (Float64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float64Array::from($values))
    };
    (Utf8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::StringArray::from($values))
    };
}
352
/// Creates a record batch from literal slice of values, suitable for rapid
/// testing and development.
///
/// Each column is given as `(name, Type, values)`, where `Type` is one of
/// the type names accepted by [`create_array`]. Every field is declared
/// nullable. A trailing comma after the last column is allowed.
///
/// The macro evaluates to the `Result` returned by `RecordBatch::try_new`,
/// so callers decide how to handle an invalid combination of columns.
///
/// Arrow types are referenced through `$crate`, so the calling crate does
/// not need its own `arrow` dependency to use this macro.
///
/// Example:
/// ```
/// use datafusion_common::record_batch;
/// let batch = record_batch!(
///     ("a", Int32, vec![1, 2, 3]),
///     ("b", Float64, vec![Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, vec!["alpha", "beta", "gamma"])
/// );
/// ```
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, $values: expr)),* $(,)?) => {
        {
            let schema = ::std::sync::Arc::new($crate::arrow::datatypes::Schema::new(vec![
                $(
                    $crate::arrow::datatypes::Field::new(
                        $name,
                        $crate::arrow::datatypes::DataType::$type,
                        true,
                    ),
                )*
            ]));

            $crate::arrow::array::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, $values),
                )*]
            )
        }
    }
}
386
#[cfg(test)]
mod tests {
    use crate::cast::{as_float64_array, as_int32_array, as_string_array};
    use crate::error::Result;

    use super::*;
    use std::env;

    /// Walks `get_data_dir` through every combination of env value and
    /// fallback directory. The fallback is passed as an absolute path, so
    /// joining it onto `CARGO_MANIFEST_DIR` yields the path unchanged.
    #[test]
    fn test_data_dir() {
        let udf_env = "get_data_dir";
        let cwd = env::current_dir().unwrap();

        let existing_pb = cwd.join("..");
        let existing = existing_pb.display().to_string();
        let missing = cwd.join("non-existing-dir").display().to_string();

        // Env points at a missing directory: error, even with a valid fallback.
        env::set_var(udf_env, &missing);
        assert!(get_data_dir(udf_env, &existing).is_err());

        // Empty or blank env values fall through to the fallback; an env
        // value naming an existing directory is returned as-is.
        for value in ["", " ", existing.as_str()] {
            env::set_var(udf_env, value);
            let res = get_data_dir(udf_env, &existing);
            assert!(res.is_ok());
            assert_eq!(res.unwrap(), existing_pb);
        }

        // Env unset: the outcome depends only on the fallback.
        env::remove_var(udf_env);
        assert!(get_data_dir(udf_env, &missing).is_err());

        let res = get_data_dir(udf_env, &existing);
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), existing_pb);
    }

    /// The default arrow and parquet data directories resolve to real
    /// directories.
    #[test]
    #[cfg(feature = "parquet")]
    fn test_happy() {
        let arrow_dir = arrow_test_data();
        assert!(PathBuf::from(arrow_dir).is_dir());

        let parquet_dir = parquet_test_data();
        assert!(PathBuf::from(parquet_dir).is_dir());
    }

    /// `record_batch!` builds the expected columns, values and null mask.
    #[test]
    fn test_create_record_batch() -> Result<()> {
        use arrow::array::Array;

        let batch = record_batch!(
            ("a", Int32, vec![1, 2, 3, 4]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0), None]),
            ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"])
        )?;

        assert_eq!(3, batch.num_columns());
        assert_eq!(4, batch.num_rows());

        let ints: Vec<_> = as_int32_array(batch.column(0))?
            .values()
            .iter()
            .copied()
            .collect();
        assert_eq!(ints, vec![1, 2, 3, 4]);

        let float_col = as_float64_array(batch.column(1))?;

        // The raw value buffer is expected to hold 0.0 in the null slots.
        let floats: Vec<_> = float_col.values().iter().copied().collect();
        assert_eq!(floats, vec![4.0, 0.0, 5.0, 0.0]);

        let validity: Vec<_> = float_col.nulls().unwrap().iter().collect();
        assert_eq!(validity, vec![true, false, true, false]);

        let strings: Vec<_> = as_string_array(batch.column(2))?.iter().flatten().collect();
        assert_eq!(strings, vec!["alpha", "beta", "gamma", "delta"]);

        Ok(())
    }
}