Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
REF: melt
  • Loading branch information
mroeschke committed Nov 9, 2023
commit fa85039fb3fffc6f13b6445bbb7144116db46ed7
120 changes: 55 additions & 65 deletions pandas/core/reshape/melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@

import pandas.core.algorithms as algos
from pandas.core.arrays import Categorical
import pandas.core.common as com
from pandas.core.indexes.api import (
Index,
MultiIndex,
)
from pandas.core.indexes.api import MultiIndex
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import tile_compat
from pandas.core.shared_docs import _shared_docs
Expand All @@ -26,76 +22,68 @@
if TYPE_CHECKING:
from collections.abc import Hashable

from pandas._typing import AnyArrayLike
from pandas._typing import (
AnyArrayLike,
Scalar,
)

from pandas import DataFrame


def ensure_list_vars(
    arg_vars: Scalar | AnyArrayLike | None, variable: str, columns
) -> list:
    """
    Normalize an ``id_vars``/``value_vars``-style argument into a list.

    Parameters
    ----------
    arg_vars : scalar, list-like, or None
        The value to normalize.
    variable : str
        Argument name interpolated into the error message.
    columns
        Column labels; only inspected to see whether they are a MultiIndex.

    Returns
    -------
    list
        ``[]`` when `arg_vars` is None, ``[arg_vars]`` when it is not
        list-like, otherwise ``list(arg_vars)``.

    Raises
    ------
    ValueError
        If `columns` is a MultiIndex and `arg_vars` is list-like but is not
        a ``list`` instance.
    """
    if arg_vars is None:
        return []

    # A bare label is wrapped so callers always receive a list.
    if not is_list_like(arg_vars):
        return [arg_vars]

    # With MultiIndex columns only a real ``list`` is accepted for list-likes.
    if isinstance(columns, MultiIndex) and not isinstance(arg_vars, list):
        raise ValueError(
            f"{variable} must be a list of tuples when columns are a MultiIndex"
        )

    return list(arg_vars)


@Appender(_shared_docs["melt"] % {"caller": "pd.melt(df, ", "other": "DataFrame.melt"})
def melt(
frame: DataFrame,
id_vars=None,
value_vars=None,
var_name=None,
value_name: Hashable = "value",
col_level=None,
id_vars: Scalar | AnyArrayLike | None = None,
value_vars: Scalar | AnyArrayLike | None = None,
var_name: Scalar | None = None,
value_name: Scalar = "value",
col_level: Scalar | None = None,
ignore_index: bool = True,
) -> DataFrame:
# If multiindex, gather names of columns on all level for checking presence
# of `id_vars` and `value_vars`
if isinstance(frame.columns, MultiIndex):
cols = [x for c in frame.columns for x in c]
else:
cols = list(frame.columns)

if value_name in frame.columns:
raise ValueError(
f"value_name ({value_name}) cannot match an element in "
"the DataFrame columns."
)
id_vars = ensure_list_vars(id_vars, "id_vars", frame.columns)
value_vars = ensure_list_vars(value_vars, "value_vars", frame.columns)

if id_vars is not None:
if not is_list_like(id_vars):
id_vars = [id_vars]
elif isinstance(frame.columns, MultiIndex) and not isinstance(id_vars, list):
raise ValueError(
"id_vars must be a list of tuples when columns are a MultiIndex"
)
else:
# Check that `id_vars` are in frame
id_vars = list(id_vars)
missing = Index(com.flatten(id_vars)).difference(cols)
if not missing.empty:
raise KeyError(
"The following 'id_vars' are not present "
f"in the DataFrame: {list(missing)}"
)
else:
id_vars = []

if value_vars is not None:
if not is_list_like(value_vars):
value_vars = [value_vars]
elif isinstance(frame.columns, MultiIndex) and not isinstance(value_vars, list):
raise ValueError(
"value_vars must be a list of tuples when columns are a MultiIndex"
)
else:
value_vars = list(value_vars)
# Check that `value_vars` are in frame
missing = Index(com.flatten(value_vars)).difference(cols)
if not missing.empty:
raise KeyError(
"The following 'value_vars' are not present in "
f"the DataFrame: {list(missing)}"
)
if id_vars or value_vars:
if col_level is not None:
idx = frame.columns.get_level_values(col_level).get_indexer(
id_vars + value_vars
level = frame.columns.get_level_values(col_level)
else:
level = frame.columns
labels = id_vars + value_vars
idx = level.get_indexer_for(labels)
missing = idx == -1
if missing.any():
missing_labels = [
lab for lab, not_found in zip(labels, missing) if not_found
]
raise KeyError(
"The following 'id_vars' or 'value_vars' are not present in "
f"the DataFrame: {missing_labels}"
)
if value_vars:
frame = frame.iloc[:, algos.unique(idx)]
else:
idx = algos.unique(frame.columns.get_indexer_for(id_vars + value_vars))
frame = frame.iloc[:, idx]
frame = frame.copy()
else:
frame = frame.copy()

Expand All @@ -113,24 +101,26 @@ def melt(
var_name = [
frame.columns.name if frame.columns.name is not None else "variable"
]
if isinstance(var_name, str):
elif is_list_like(var_name):
raise ValueError(f"{var_name=} must be a scalar.")
else:
var_name = [var_name]

N, K = frame.shape
K -= len(id_vars)
num_rows, K = frame.shape
num_cols_adjusted = K - len(id_vars)

mdata: dict[Hashable, AnyArrayLike] = {}
for col in id_vars:
id_data = frame.pop(col)
if not isinstance(id_data.dtype, np.dtype):
# i.e. ExtensionDtype
if K > 0:
mdata[col] = concat([id_data] * K, ignore_index=True)
if num_cols_adjusted > 0:
mdata[col] = concat([id_data] * num_cols_adjusted, ignore_index=True)
else:
# We can't concat empty list. (GH 46044)
mdata[col] = type(id_data)([], name=id_data.name, dtype=id_data.dtype)
else:
mdata[col] = np.tile(id_data._values, K)
mdata[col] = np.tile(id_data._values, num_cols_adjusted)

mcolumns = id_vars + var_name + [value_name]

Expand All @@ -143,12 +133,12 @@ def melt(
else:
mdata[value_name] = frame._values.ravel("F")
for i, col in enumerate(var_name):
mdata[col] = frame.columns._get_level_values(i).repeat(N)
mdata[col] = frame.columns._get_level_values(i).repeat(num_rows)

result = frame._constructor(mdata, columns=mcolumns)

if not ignore_index:
result.index = tile_compat(frame.index, K)
result.index = tile_compat(frame.index, num_cols_adjusted)

return result

Expand Down
8 changes: 4 additions & 4 deletions pandas/core/shared_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,17 @@

Parameters
----------
id_vars : tuple, list, or ndarray, optional
id_vars : scalar, tuple, list, or ndarray, optional
Column(s) to use as identifier variables.
value_vars : tuple, list, or ndarray, optional
value_vars : scalar, tuple, list, or ndarray, optional
Column(s) to unpivot. If not specified, uses all columns that
are not set as `id_vars`.
var_name : scalar
var_name : scalar, default None
Name to use for the 'variable' column. If None it uses
``frame.columns.name`` or 'variable'.
value_name : scalar, default 'value'
Name to use for the 'value' column, can't be an existing column label.
col_level : int or str, optional
col_level : scalar, optional
If columns are a MultiIndex then use this level to melt.
ignore_index : bool, default True
If True, original index is ignored. If False, the original index is retained.
Expand Down
34 changes: 34 additions & 0 deletions pandas/tests/reshape/test_melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,40 @@ def test_melt_preserves_datetime(self):
)
tm.assert_frame_equal(result, expected)

def test_melt_allows_non_scalar_id_vars(self):
df = DataFrame(
data={"a": [1, 2, 3], "b": [4, 5, 6]},
index=["11", "22", "33"],
)
result = df.melt(
id_vars="a",
var_name=0,
value_name=1,
)
expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]})
tm.assert_frame_equal(result, expected)

def test_melt_allows_non_string_var_name(self):
df = DataFrame(
data={"a": [1, 2, 3], "b": [4, 5, 6]},
index=["11", "22", "33"],
)
result = df.melt(
id_vars=["a"],
var_name=0,
value_name=1,
)
expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]})
tm.assert_frame_equal(result, expected)

def test_melt_non_scalar_var_name_raises(self):
    """A list passed as ``var_name`` is expected to raise ``ValueError``."""
    frame = DataFrame(
        {"a": [1, 2, 3], "b": [4, 5, 6]},
        index=["11", "22", "33"],
    )
    msg = r".* must be a scalar."
    with pytest.raises(ValueError, match=msg):
        frame.melt(id_vars=["a"], var_name=[1, 2])


class TestLreshape:
def test_pairs(self):
Expand Down