functools.wrapsについて調べてみた

flask-peeweeのRestAPIのソースを読んでいたところ、functools.wrapsという関数が使用されていたので調べてみました

ドキュメント
http://www.python.jp/doc/release/library/functools.html?highlight=functools#functools.wraps

functoolsとはなにか

モジュール functools は高階関数、つまり関数に対する関数、あるいは他の関数を返す関数、のためのものです。

functools.wrapsとはなにか

これはラッパ関数を定義するときに partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated) を関数デコレータとして呼び出す便宜関数です。 

partial関数と同様の処理のようです

functools.partialとはなにか

新しい partial オブジェクトを返します。このオブジェクトは呼び出されると位置引数 args とキーワード引数 keywords 付きで呼び出された func のように振る舞います。

わかりにくいですが、partialのサンプルプログラムを見るとちょっとわかります

from functools import partial
basetwo = partial(int, base=2)
basetwo.__doc__ = 'Convert base 2 string to an int.'
basetwo('10010')
# => 18

partial関数は、関数の引数を固定したバージョンとして振る舞うpartialオブジェクトを返します。
上記の例だとbasetwoの呼び出しは、int関数にbase=2を渡した状態で呼び出したのと同じになります

partialの代わりにwrapsを使ってみると以下のような感じです。

import functools
def wrapper(func,base=None):
    @functools.wraps(func)
    def inner(*args, **kwargs):
        kwargs.update(base=base)
        return func(*args, **kwargs)
    return inner

basetwo = wrapper(int,base=2)
basetwo('10010')
# => 18
baseeight = wrapper(int,base=8)
baseeight('10010')
# => 4104

flask-peeweeのRestAPIではどのように使っているのか

flaskext/rest.py

    def require_method(self, func, methods):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            if request.method not in methods:
                return self.response_bad_method()
            return func(*args, **kwargs)
        return inner

    def get_urls(self):
        return (
            ('/', self.require_method(self.api_list, ['GET', 'POST'])),
            ('/<pk>/', self.require_method(self.api_detail, ['GET', 'PUT', 'DELETE'])),
        )

呼び出しurlに適応する関数を、HTTP METHODのチェック処理でラップするのに使用していました
api_listを呼び出すときはGETかPOST、api_detailを呼び出すときはGET、PUT、DELETEのいづれかでないと
response_bad_methodを返す、というわけですね


シンプルにするとちょっとわかりやすいです。

import functools

def wrapper(func):
   @functools.wraps(func)
   def inner(*args, **kwargs):
        print("*** wrapped ***")
        return func(*args, **kwargs)
   return inner

def test_func():
   print("*** test_func ***")

test_func()
# => *** test_func ***

wrapped_func = wrapper(test_func)
wrapped_func()
# => *** wrapped ***
# = *** test_func ***

@wrapper
def test_func2():
    print("*** test_func2 ***")

test_func2()
# => *** wrapped ***
# = *** test_func ***