chainerのlinks.CRF1dとfunctions.crf1dの使い方メモ。
crfと同様に系列を扱うNStepLSTMとかと同じ気分で使うと事故る。
概要
functions.crf1d(cost, xs, ts)
遷移スコア、タグスコア、正解ラベルをそれぞれ受け取って、最適パスをVariableのlistで返す。
遷移スコアは
\ 1,2,3 1 . . . 2 . . . 3 . . .
って感じの行列で、例えばタグ1から2への遷移スコアを5にするなら以下のようにする。
\ 1,2,3 1 . 5 . 2 . . . 3 . . .
他のスコアも適当に埋めて渡せばいい。
links.CRF1d(xs, ts)
最適な遷移スコアをヒューリスティックに探すのはしんどいので遷移スコアを学習する。
中でfunctions.crf1dを呼ぶので入出力は同じ形。
入力(xsの形)
# xs for crf1d # batch size = 2 # batch = [(x,x,x), (x,x)] [ 1st :((1,0,0),(1,0,0)) 2nd :((0,1,0),(0,1,0)) 3rd :((0,0,1) ]
[]はリスト、()はVariableかnumpyのarray。
リストのインデックスが各時刻、行列がバッチ内の各タグスコア。
実際のコードだと以下のような感じ。
links.CRF1dを使っているが、functionsの方を使うなら上で述べた行列を第一引数に渡せばいい。
>>> import chainer >>> from chainer import links as L >>> import numpy as xp >>> crfL = L.CRF1d(3) #tag sizeを3に設定 #縦方向が時刻 >>> xs = [xp.array([[1,0,0]], xp.float32), xp.array([[0,1,0]], xp.float32), xp.array([[0,0,1]], xp.float32)] >>> crfL.argmax(xs)[1] #.argmaxはscoreとpathのタプルを返す [array([0], dtype=int32), array([1], dtype=int32), array([2], dtype=int32)]
# batch処理の場合はこう。 # 長さ3の系列を二つまとめて入力。 # タグサイズは3 >>> xs = [xp.array([[1,0,0],[0,0,1]], xp.float32), ... xp.array([[0,1,0],[0,1,0]], xp.float32), ... xp.array([[0,0,1],[1,0,0]], xp.float32)] >>> crfL.argmax(xs)[1] [array([0, 2], dtype=int32), array([1, 1], dtype=int32), array([2, 0], dtype=int32)] #可変長なbatch入力をするときは長い順にソートする必要がある。 >>> xs = [xp.array([[1,0,0],[1,0,0]], xp.float32), ... xp.array([[0,1,0],[0,1,0]], xp.float32), ... xp.array([[0,1,0]], xp.float32)] >>> crfL.argmax(xs)[1] [array([0, 0], dtype=int32), array([1, 1], dtype=int32), array([1], dtype=int32)]
入力(tsの形)
教師データもxsと同じで、縦にして大きい順に横に並べる。
# ts for crf1d # batch size = 2 # batch = [(x,x,x), (x,x)] 1st :[(0,0)] 2nd :[(1,1)] 3rd :[(2)]
# input >>> xs = [xp.array([[1,0,0],[1,0,0]], xp.float32), ... xp.array([[0,1,0],[0,1,0]], xp.float32), ... xp.array([[0,1,0]], xp.float32)] # teacher >>> ts = [xp.array([0,0],xp.int32), ... xp.array([1,1],xp.int32), ... xp.array([2],xp.int32)] >>> crfL(xs, ts) <variable at 0x10ea6cac8> # variableがかえってくる >>> crfL(xs, ts).data array(1.87861168384552, dtype=float32)
ちなみに僕が混乱したのは、同じ系列を扱うNStepLSTMの入力が以下のようだから。
xs = [(1,2,3), (1,2,3,4,5), (1,2,3,4)] len(xs): batch Size
BiLSTMでのリストがバッチサイズを表しているのに対して、crf1dのリストは各時刻の入力を表している。
超紛らわしい。
(余談)
CRFを使うときにsoftmaxを噛ませると学習できないので、pre-trainingとかしてるひとは気をつけてください。
functions使うときにスコアをある程度正規化してから渡したいなぁ〜とか思ってると、普通に学習できなくて無理。
わかりづらいとか、間違いとかあったらコメントでお願いします。