Skip to content

ENH: general concat with ExtensionArrays through find_common_type #33607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 2, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
initial find_common_type/_get_common_type + tests for IntegerDtype
  • Loading branch information
jorisvandenbossche committed Apr 17, 2020
commit 3464e95064ad1c1d4ac9d37e3d381215165a8ffe
13 changes: 11 additions & 2 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numbers
from typing import TYPE_CHECKING, Tuple, Type, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
import warnings

import numpy as np

from pandas._libs import lib, missing as libmissing
from pandas._typing import ArrayLike
from pandas._typing import ArrayLike, DtypeObj
from pandas.compat import set_function_name
from pandas.util._decorators import cache_readonly

Expand Down Expand Up @@ -95,6 +95,15 @@ def construct_array_type(cls) -> Type["IntegerArray"]:
"""
return IntegerArray

def _get_common_type(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be common_type or common_dtype? we've been loose about this distinction so far and i think it has caused amibiguity

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't care that much. I mainly used "type", because it is meant to be used in find_common_type.

(that find_common_type name is inspired on the numpy function, and that one actually handles both dtypes and scalar types, which I assume is the reason for the name. The pandas version, though, doesn't really make the distinction, so could have been named "find_common_dtype")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to "common_dtype" instead of "common_type". The internal function that uses this is still find_common_type, but that name from numpy is actually a misnomer here, since we are only dealing with dtypes, and not scalar types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for indulging me on this nitpick

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this a private method on the Dtype? get_common_type (or get_common_dtype) seems fine

# for now only handle other integer types
if not all(isinstance(t, _IntegerDtype) for t in dtypes):
return None
np_dtype = np.find_common_type([t.numpy_dtype for t in dtypes], [])
if np.issubdtype(np_dtype, np.integer):
return _dtypes[str(np_dtype)]
return None

def __from_arrow__(
self, array: Union["pyarrow.Array", "pyarrow.ChunkedArray"]
) -> "IntegerArray":
Expand Down
27 changes: 27 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np

from pandas._typing import DtypeObj
from pandas.errors import AbstractMethodError

from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
Expand Down Expand Up @@ -322,3 +323,29 @@ def _is_boolean(self) -> bool:
bool
"""
return False

def _get_common_type(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm can we keep the return type as ExtensionDtype? Do you envision cases where we'd like to return a plain NumPy dtype?

Oh... I suppose tz-naive DatetimeArray might break this, since it wants to return a NumPy dtype...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was my first thought as well. But, right now, eg Categorical can end up with any kind of numpy dtype (depending on the dtype of its categories).

As long as not yet all dtypes have a EA version, I don't think it is feasible to require ExtensionDtype here

"""
Return the common dtype, if one exists.

Used in `find_common_type` implementation. This is for example used
to determine the resulting dtype in a concat operation.

If no common dtype exists, return None. If all dtypes in the list
will return None, then the common dtype will be "object" dtype.

Parameters
----------
dtypes : list of dtypes
The dtypes for which to determine a common dtype. This is a list
of np.dtype or ExtensionDtype instances.

Returns
-------
Common dtype (np.dtype or ExtensionDtype) or None
"""
if len(set(dtypes)) == 1:
# only itself
return self
else:
return None
7 changes: 6 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,12 @@ def find_common_type(types):
return first

if any(isinstance(t, ExtensionDtype) for t in types):
return np.object
for t in types:
if isinstance(t, ExtensionDtype):
res = t._get_common_type(types)
if res is not None:
return res
return np.dtype("object")

# take lowest unit
if all(is_datetime64_dtype(t) for t in types):
Expand Down
24 changes: 21 additions & 3 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from pandas.core.dtypes.cast import find_common_type
from pandas.core.dtypes.common import (
is_bool_dtype,
is_categorical_dtype,
Expand All @@ -17,6 +18,9 @@
)
from pandas.core.dtypes.generic import ABCCategoricalIndex, ABCRangeIndex, ABCSeries

from pandas.core.arrays import ExtensionArray
from pandas.core.construction import array


def get_dtype_kinds(l):
"""
Expand Down Expand Up @@ -99,9 +103,23 @@ def is_nonempty(x) -> bool:
single_dtype = len({x.dtype for x in to_concat}) == 1
any_ea = any(is_extension_array_dtype(x.dtype) for x in to_concat)

if any_ea and single_dtype and axis == 0:
cls = type(to_concat[0])
return cls._concat_same_type(to_concat)
if any_ea and axis == 0:
if not single_dtype:
target_dtype = find_common_type([x.dtype for x in to_concat])

def cast(arr, dtype):
if is_extension_array_dtype(dtype):
if isinstance(arr, np.ndarray):
return array(arr, dtype=dtype, copy=False)
return arr.astype(dtype, copy=False)

to_concat = [cast(arr, target_dtype) for arr in to_concat]

if isinstance(to_concat[0], ExtensionArray):
cls = type(to_concat[0])
return cls._concat_same_type(to_concat)
else:
np.concatenate(to_concat)

elif "category" in typs:
# this must be prior to concat_datetime,
Expand Down
26 changes: 26 additions & 0 deletions pandas/tests/arrays/integer/test_concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

import pandas as pd
import pandas._testing as tm


@pytest.mark.parametrize(
"to_concat_dtypes, result_dtype",
[
(["Int64", "Int64"], "Int64"),
(["UInt64", "UInt64"], "UInt64"),
(["Int8", "Int8"], "Int8"),
(["Int8", "Int16"], "Int16"),
(["UInt8", "Int8"], "Int16"),
(["Int32", "UInt32"], "Int64"),
# this still gives object (awaiting float extension dtype)
(["Int64", "UInt64"], "object"),
],
)
def test_concat_series(to_concat_dtypes, result_dtype):

result = pd.concat([pd.Series([1, 2, pd.NA], dtype=t) for t in to_concat_dtypes])
expected = pd.concat([pd.Series([1, 2, pd.NA], dtype=object)] * 2).astype(
result_dtype
)
tm.assert_series_equal(result, expected)