Pythonでflatten(多次元リストを一次元に平坦化)

Modified: | Tags: Python, リスト

Pythonで多次元のリスト(リストのリスト、ネストしたリスト)を一次元に平坦化するには、itertools.chain.from_iterable()sum()、リスト内包表記などを使う。

NumPy配列ndarrayの場合はflatten()またはravel()を使う。

反対に、一次元のNumPy配列ndarrayやリストを二次元に変換する方法については以下の記事を参照。

2次元のリストを平坦化

itertools.chain.from_iterable()

リストを要素として持つリスト(2次元リスト)を平坦化する場合、標準ライブラリのitertoolsのitertools.chain.from_iterable()を使う方法がある。

import itertools

l_2d = [[0, 1], [2, 3]]

print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]

itertools.chain.from_iterable()はイテレータを返すので、リストに変換したい場合は上のサンプルコードのようにlist()を使う。for文で使う場合はリスト化する必要はない。

タプルも同様に処理できる。ここでは結果をtuple()でタプルにしている。リストにしたい場合はlist()を使えばよい。

t_2d = ((0, 1), (2, 3))

print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)

itertools.chain.from_iterable()で平坦化できるのは2次元の場合のみ。3次元以上の場合(ネストが深い場合)は以下のような結果となる。

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

また、要素の中にイテラブルオブジェクトではないものが含まれている場合はエラーとなる。

l_mix = [[0, 1], [2, 3], 4]

# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable

3次元以上の場合や要素の型が不規則な場合については後述。

sum()

組み込み関数のsum()を使う方法もある。

sum()の第二引数には初期値を指定できる。ここに空のリスト[]を指定すると、リストの+演算によって、要素のリストが連結される。

第二引数のデフォルト値は0なので、省略すると整数intとリストの+演算となってしまいエラーとなる。

l_2d = [[0, 1], [2, 3]]

print(sum(l_2d, []))
# [0, 1, 2, 3]

# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'

タプルでも同様に処理可能。

t_2d = ((0, 1), (2, 3))

print(sum(t_2d, ()))
# (0, 1, 2, 3)

itertools.chain.from_iterable()と同じように、3次元以上の場合や要素の型が不規則な場合はうまくいかない。

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
l_mix = [[0, 1], [2, 3], 4]

# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list

リスト内包表記

ネストしたリスト内包表記を用いる方法もある。

matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

flat = [x for row in matrix for x in row]
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]

以下のネストしたforループと等価。

flat = []
for row in matrix:
    for x in row:
        flat.append(x)

print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]

結果は省略するが、上の例のリスト内包表記の場合は、他の方法と同じく1階層しか平坦化できず、要素の中にイテラブルオブジェクトではないものが含まれているとエラーになる。

ネストを深くして3次元以上に対応させたり、要素の型によって条件分岐したりすることも可能だが、複雑になるのであまりおすすめはできない。

リスト内包表記についての詳細は以下の記事を参照。

処理速度の差

sum()はお手軽だが、行数(内部のリストの数)が多い場合はitertools.chain.from_iterable()やリスト内包表記よりも遥かに遅いので注意。行数が多く、かつ、処理速度やメモリ効率が重要な場面ではsum()は避けたほうがよい。

itertoolsをインポートする手間はあるが、itertools.chain.from_iterable()のほうがリスト内包表記よりも高速。

以下のコードはJupyter Notebook上でマジックコマンド%%timeitを使って計測したもの。Pythonスクリプトとして実行しても計測されないので注意。

5行の場合。

l_2d_5 = [[0, 1, 2] for i in range(5)]
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]

%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 537 ns ± 4.59 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
sum(l_2d_5, [])
# 319 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
[x for row in l_2d_5 for x in row]
# 764 ns ± 32.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

100行。

l_2d_100 = [[0, 1, 2] for i in range(100)]

%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 6.94 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
sum(l_2d_100, [])
# 35.5 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
[x for row in l_2d_100 for x in row]
# 13.5 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

10000行。

l_2d_10000 = [[0, 1, 2] for i in range(10000)]

%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 552 µs ± 79.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sum(l_2d_10000, [])
# 343 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
[x for row in l_2d_10000 for x in row]
# 1.11 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

3次元以上のリストや不規則なリストを平坦化

3次元以上のリストや不規則なリストを平坦化するには関数を定義する。

以下を参考にした。

import collections

def flatten(l):
    for el in l:
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

isinstance()で要素elの型をチェックして再帰的に処理している。

collections.abc.Iterableでイテラブルかどうかを判断。標準ライブラリのcollectionsをインポートする必要がある。

文字列strやバイト列bytesもイテラブルであるため除外している。除外しないと文字ごとに分解されてしまう。

この関数を使うと、あらゆる場合に対応できる。

l_2d = [[0, 1], [2, 3]]

print(list(flatten(l_2d)))
# [0, 1, 2, 3]
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
l_mix = [[0, 1], [2, 3], 4]

print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]

リストやタプル、rangeなど様々なイテラブルオブジェクトが含まれていても問題ない。

l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]

print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

対象をリストに限定すれば、collectionsをインポートしなくてもよい。タプルやrangeはそのままになってしまうが、多くの場合はこれで十分だろう。

def flatten_list(l):
    for el in l:
        if isinstance(el, list):
            yield from flatten_list(el)
        else:
            yield el

print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]

print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]

print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]

isinstance()の第二引数にはタプルで複数の型を指定できるので、必要な型のみ対象としてもよい。

def flatten_list_tuple_range(l):
    for el in l:
        if isinstance(el, (list, tuple, range)):
            yield from flatten_list_tuple_range(el)
        else:
            yield el

print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

もちろん、汎用的なのはcollections.abc.Iterableを使う方法。お好みで。

関連カテゴリー

関連記事