Pitfall in Lambda

Pitfall in Lambda

2018, Mar 10    

嗨,大家好,這篇又再談談我遇到的另外一個問題:在建置機器學習模型中需要測試多個客製化的量測函式 (實際的問題我留在下一篇再來講),本篇以簡單的例子呈現原始問題的所在:

首先,假設實驗裡面需要測量二維向量的weighted inner product,定義一組函式如下,並賦予三組權重W來測試一下:

def w_metric(X, Y, W):
    return sum(x*y*w for x, y, w in zip(X, Y, W))

print(w_metric([1, 2], [4, 5], [1, 1]))
print(w_metric([1, 2], [4, 5], [1, 2]))
print(w_metric([1, 2], [4, 5], [2, 1]))

# >> 14
# >> 24
# >> 18


Using Lambda

接著,我把三組帶有不同W的量測函式物件打包起來,結果輸出不如預期;三組W確實在list comprehension裡面被展開,哪裡出錯了?

def make_w_metric_list(W):
    return [lambda X, Y: w_metric(X, Y, w) for w in W]

metric_list = make_w_metric_list([[1, 1], [1, 2], [2, 1]])

for metric in metric_list:
    print(metric([1, 2], [4, 5]))

# what I expect
# >> 14
# >> 24
# >> 18

# what actually output
# >> 18
# >> 18
# >> 18


Late Binding - What a Gotcha

吼吼抓到了,透過下面這段程式碼來理解,會發現變數b和函式simple_line之間的關係不像在C/C++的模式,但程式碼可以正常運作,Python的解析變數是當需要用到的時候,透過LEGB (local->enclosing->global->built-in) 規則去作name binding。

def simple_line(x, a):
    return a*x+b

b = 5
print(simple_line(1, 2))
# >> 7

b = 3
print(simple_line(1, 2))
# >> 5

相同的規則,回頭去看看原本程式碼,三組函式物件拿到的w其實指向W最後一個元素[2, 1],透過live object inspection來驗證一下,__code__有編譯後的bytecode資訊,找到w在這函式物件裡面是free variables,被closure閉包住了。

def make_w_metric_list(W):
    return [lambda X, Y: w_metric(X, Y, w) for w in W]

metric_list = make_w_metric_list([[1, 1], [1, 2], [2, 1]])

for metric in metric_list:
    print(metric([1, 2], [4, 5]))
    print(metric.__code__.co_freevars)
    print(metric.__closure__[0].cell_contents)

# >> 18.0
# >> ('w',)
# >> [2.0, 1.0]
# >> 18.0
# >> ('w',)
# >> [2.0, 1.0]
# >> 18.0
# >> ('w',)
# >> [2.0, 1.0]

為了幫助理解lambda的行為,原本的函式定義可以等效以較明確的closure方式來呈現,以同樣方式來驗證看看,確實餵進去三組w_metricw都指向for loop後的最後一個w = [2, 1]。 好,有方法可以解決這種late binding嗎?

def make_w_metric_list(W):
    func_list = []
    for w in W:
        def _w_metric(X, Y):
            return w_metric(X, Y, w)
        func_list.append(_w_metric)
    return func_list

metric_list = make_w_metric_list([[1, 1], [1, 2], [2, 1]])

for metric in metric_list:
    print(metric([1, 2], [4, 5]))
    print(metric.__code__.co_freevars)
    print(metric.__closure__[0].cell_contents)

# >> 18.0
# >> ('w',)
# >> [2.0, 1.0]
# >> 18.0
# >> ('w',)
# >> [2.0, 1.0]
# >> 18.0
# >> ('w',)
# >> [2.0, 1.0]


Solution

Early Binding - Default Argument

那就early binding吧,因為Python的機制是function default argument實際上在definition time就被決定的,所以function的每一次呼叫都是使用同樣的default argument value。

def make_w_metric_list(W):
    return [lambda X, Y, W=w: w_metric(X, Y, W) for w in W]

metric_list = make_w_metric_list([[1, 1], [1, 2], [2, 1]])

for metric in metric_list:
    print(metric([1, 2], [4, 5]))

# >> 14.0
# >> 24.0
# >> 18.0

Functool.Partial

另外,我認為這個方式可讀性比較高,functool.partial可以把callable的物件重新包裝,並且可以預先設定default argument的值固定住。

def make_w_metric_list(W):
    return [partial(w_metric, W=w) for w in W]

metric_list = make_w_metric_list([[1, 1], [1, 2], [2, 1]])

for metric in metric_list:
    print(metric([1, 2], [4, 5]))

# >> 14.0
# >> 24.0
# >> 18.0


Reference

Python-Execution-Model
Python-Scope-of-Variables
Python-Live-Object-Inspection
Python-Function-Definition
Python-Functool-Partial