どこかに向かうらしい話

迷走エンジニアの放浪記

Python3におけるmap/filterの使い方

はじめに

map()/filter()の使いどころがよくわからないので、自分なりに調査をしてみた。
この2つの関数は、Python2系とPython3系では挙動が異なるので、まずはその話から。

まずは、具体的に実行内容を見てみたい。

python2系での実行結果

>>> print(range(10))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> print(map(lambda x: x*2, range(10)))
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> print(filter(lambda x: x % 2, range(10)))
[1, 3, 5, 7, 9]

python3系での実行結果

>>> print(range(10))
range(0, 10)
>>> print(map(lambda x: x*2, range(10)))
<map object at 0xb7bdd74c>
>>> print(filter(lambda x: x % 2, range(10)))
<filter object at 0xb7bdd74c>

Python2系では,上に挙げたそれぞれの関数、メソッドはそれぞれリストを返している。
しかし,Python3系ではリストではなく、iterator objectを返すように挙動が変更された。
iteratorを返すことで、一度に全要素を計算せず、要素が必要とされる時に計算するようになり、使用するメモリ量を低く抑えているとのこと。

2系のように、iteratorをリスト化したい場合はiteratorをlist関数に渡してリストを生成すればよい。

>>> print(list(map(lambda x: x*2, range(10))))
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

>>> print(list(filter(lambda x: x % 2, range(10))))
[1, 3, 5, 7, 9]

set関数に渡せば、iteratorより集合も生成可能。

>>> print(set(map(lambda x: x*2, range(10))))
{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}

>>> print(set(filter(lambda x: x % 2, range(10))))
{1, 9, 3, 5, 7}

map()

map()とは、リストの要素に演算を適用してくれる関数。

map(func, iterable, ...)

第2引数以降のiterableを順にfuncに渡した結果をyieldで返す(generator)。
引数iteratbleは第1引数の関数の引数に順に渡される。

基本的な使い方としては、以下のように各要素の2乗されたリストを生成などがある。

>>> def square(x):
...     return x * x
...

>>> map(square, range(1, 10))
<map at 0x7f474ad63e80>

>>> list(map(square, range(1, 10)))
[1, 4, 9, 16, 25, 36, 49, 64, 81]

ラムダ式を用いれば、1行で同様の結果を得ることができる。

>>> map(lambda x: x*x, range(1, 10))
<map at 0x7f474ad67128>

>>> list(map(lambda x: x*x, range(1, 10)))
[1, 4, 9, 16, 25, 36, 49, 64, 81]

内包記法を使えば、直接リストを生成することができる。

>>> [x*x for x in range(1, 10)]
[1, 4, 9, 16, 25, 36, 49, 64, 81]

filter()

filter()とは、リストの要素を抽出してくれる関数。

filter(func, iterable)

iterableのオブジェクトを順にfuncに渡し、funcがTrueで返したオブジェクトをieldで返す(generator)。

>>> def is_mod(x):
...     return x % 3 == 1
...

基本的な使い方としては、以下のように各要素の3で割ったときに1余る要素のリストを生成などがある。

>>> filter(is_mod, range(1 ,10))
<filter at 0x7f474ad77f28>

>>> list(filter(is_mod, range(1 ,10)))
[1, 4, 7]

ラムダ式を用いれば、1行で同様の結果を得ることができる。

>>> filter(lambda x:x % 3 == 1, range(1, 10))
<filter at 0x7f474ad7d400>

>>> list(filter(lambda x:x % 3 == 1, range(1, 10)))
[1, 4, 7]

内包記法を使えば、直接リストを生成することができる。

>>> [x for x in range(1 ,10) if x % 3 == 1]
[1, 4, 7]

filter() と map()の組み合わせ

リストの要素を抽出して演算を適用する。
例えば、3で割ったとき余りが1になる要素を取得して、取得した各要素を2倍するとする。
通常の通りにやろうとすると、以下の通りになる。

>>> result = []

>>> for x in range(1, 10):
...     if x % 3 == 1:
...         result += [x * 2]
... print(result)
...
[2, 8, 14]

filter()、map()を用いて書くと、以下のようになる。

>>> map(lambda x: x * 2, filter(lambda x: x % 3 == 1, range(1, 10)))
 <map at 0x7f474ad772e8>

>>> list(map(lambda x: x * 2, filter(lambda x: x % 3 == 1, range(1, 10))))
[2, 8, 14]

内包記法を使ってのリスト生成は以下の通り。

>>> [x * 2 for x in range(1, 10) if x % 3 == 1]
[2, 8, 14]

最後に

今回のネタは以下のサイトを参考にしました。

特に「何気にPythonでつかっていた関数型プログラミング技法いろいろ」の情報は2系ではよくまとまっていると思う。
reduce()についてもどっかのタイミングで調べておきたいなと思った。