Pythonでflatten(多次元リストを一次元に平坦化)
Pythonで多次元のリスト(リストのリスト、ネストしたリスト)を一次元に平坦化するには、itertools.chain.from_iterable()
やsum()
、リスト内包表記などを使う。
NumPy配列ndarray
の場合はflatten()
またはravel()
を使う。
反対に、一次元のNumPy配列ndarray
やリストを二次元に変換する方法については以下の記事を参照。
2次元のリストを平坦化
itertools.chain.from_iterable()
リストを要素として持つリスト(2次元リスト)を平坦化する場合、標準ライブラリのitertoolsのitertools.chain.from_iterable()
を使う方法がある。
import itertools
l_2d = [[0, 1], [2, 3]]
print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]
itertools.chain.from_iterable()
はイテレータを返すので、リストに変換したい場合は上のサンプルコードのようにlist()
を使う。for文で使う場合はリスト化する必要はない。
タプルも同様に処理できる。ここでは結果をtuple()
でタプルにしている。リストにしたい場合はlist()
を使えばよい。
t_2d = ((0, 1), (2, 3))
print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)
itertools.chain.from_iterable()
で平坦化できるのは2次元の場合のみ。3次元以上の場合(ネストが深い場合)は以下のような結果となる。
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
また、要素の中にイテラブルオブジェクトではないものが含まれている場合はエラーとなる。
l_mix = [[0, 1], [2, 3], 4]
# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable
3次元以上の場合や要素の型が不規則な場合については後述。
sum()
組み込み関数のsum()
を使う方法もある。
sum()
の第二引数には初期値を指定できる。ここに空のリスト[]
を指定すると、リストの+
演算によって、要素のリストが連結される。
第二引数のデフォルト値は0
なので、省略すると整数int
とリストの+
演算となってしまいエラーとなる。
l_2d = [[0, 1], [2, 3]]
print(sum(l_2d, []))
# [0, 1, 2, 3]
# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
タプルでも同様に処理可能。
t_2d = ((0, 1), (2, 3))
print(sum(t_2d, ()))
# (0, 1, 2, 3)
itertools.chain.from_iterable()
と同じように、3次元以上の場合や要素の型が不規則な場合はうまくいかない。
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
l_mix = [[0, 1], [2, 3], 4]
# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list
リスト内包表記
ネストしたリスト内包表記を用いる方法もある。
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = [x for row in matrix for x in row]
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
以下のネストしたforループと等価。
flat = []
for row in matrix:
for x in row:
flat.append(x)
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
結果は省略するが、上の例のリスト内包表記の場合は、他の方法と同じく1階層しか平坦化できず、要素の中にイテラブルオブジェクトではないものが含まれているとエラーになる。
ネストを深くして3次元以上に対応させたり、要素の型によって条件分岐したりすることも可能だが、複雑になるのであまりおすすめはできない。
リスト内包表記についての詳細は以下の記事を参照。
- 関連記事: Pythonリスト内包表記の使い方
処理速度の差
sum()
はお手軽だが、行数(内部のリストの数)が多い場合はitertools.chain.from_iterable()
やリスト内包表記よりも遥かに遅いので注意。行数が多く、かつ、処理速度やメモリ効率が重要な場面ではsum()
は避けたほうがよい。
itertoolsをインポートする手間はあるが、itertools.chain.from_iterable()
のほうがリスト内包表記よりも高速。
以下のコードはJupyter Notebook上でマジックコマンド%%timeit
を使って計測したもの。Pythonスクリプトとして実行しても計測されないので注意。
5行の場合。
l_2d_5 = [[0, 1, 2] for i in range(5)]
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]
%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 537 ns ± 4.59 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%%timeit
sum(l_2d_5, [])
# 319 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%%timeit
[x for row in l_2d_5 for x in row]
# 764 ns ± 32.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
100行。
l_2d_100 = [[0, 1, 2] for i in range(100)]
%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 6.94 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
sum(l_2d_100, [])
# 35.5 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
[x for row in l_2d_100 for x in row]
# 13.5 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10000行。
l_2d_10000 = [[0, 1, 2] for i in range(10000)]
%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 552 µs ± 79.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
sum(l_2d_10000, [])
# 343 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
[x for row in l_2d_10000 for x in row]
# 1.11 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3次元以上のリストや不規則なリストを平坦化
3次元以上のリストや不規則なリストを平坦化するには関数を定義する。
以下を参考にした。
import collections
def flatten(l):
for el in l:
if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
yield from flatten(el)
else:
yield el
isinstance()
で要素el
の型をチェックして再帰的に処理している。
collections.abc.Iterable
でイテラブルかどうかを判断。標準ライブラリのcollectionsをインポートする必要がある。
- collections.abc.Iterable --- コレクションの抽象基底クラス — Python 3.7.3 ドキュメント
- 関連記事: Pythonのhasattr(), 抽象基底クラスABCによるダックタイピング
文字列str
やバイト列bytes
もイテラブルであるため除外している。除外しないと文字ごとに分解されてしまう。
この関数を使うと、あらゆる場合に対応できる。
l_2d = [[0, 1], [2, 3]]
print(list(flatten(l_2d)))
# [0, 1, 2, 3]
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
l_mix = [[0, 1], [2, 3], 4]
print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]
リストやタプル、range
など様々なイテラブルオブジェクトが含まれていても問題ない。
l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]
print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
対象をリストに限定すれば、collectionsをインポートしなくてもよい。タプルやrange
はそのままになってしまうが、多くの場合はこれで十分だろう。
def flatten_list(l):
for el in l:
if isinstance(el, list):
yield from flatten_list(el)
else:
yield el
print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]
print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]
print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]
isinstance()
の第二引数にはタプルで複数の型を指定できるので、必要な型のみ対象としてもよい。
def flatten_list_tuple_range(l):
for el in l:
if isinstance(el, (list, tuple, range)):
yield from flatten_list_tuple_range(el)
else:
yield el
print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
もちろん、汎用的なのはcollections.abc.Iterable
を使う方法。お好みで。