ChainerでChain(モデル)にxpを渡したくない

class Model(Chain):
  def __init__(self):
    I = self.xp.eyes(size)

ってかんじで、selfからxp呼べるみたいです(しらなかった!)。
chanteraさんありがとうございます!

 

以下苦しんでた形跡

 

メモです。

Chainerで学習器を定義するとき、Chainクラスを継承する。
学習器というか、ネットワークのことをモデルと呼ぶことにする。

例えば単位行列を使いたいときとか、モデルの中でnumpyを使いたいことがたまにある。
numpyはGPU環境でcupyになるので、単純にはモデルをnewするときにコンストラクタに引数を渡さないといけない。
慣例として、CPU環境ならxp=numpy、GPU環境ならxp=cupyとしてxpをモデルに渡したりする。

import chainer
class Model(chainer.Chain):
  def __init__(self, xp):
    super().__init__()
    self.I = xp.eye(size).astype('f') # 単位行列

  def forward(self, x):
    # 何か単位行列を使う処理
    y = x + self.I
    return y

上の例はネットワークでも何でもない嘘コードだけど、こんな感じにすることが多かった。
入力は正方行列に限定する。
これだと引数が汚くなるのであまり好きじゃない。

これを以下のようにすると、xpを使わずに済む。

import chainer
from copy import copy
class Model(chainer.Chain):
  def __init__(self):
    super().__init__()
    self.I = None

  def forward(self, x):
    # 最初だけself.Iを作る
    if self.I is None:
      self.setI(x)

    # 何か単位行列を使う処理
    y = x + self.I
    return y

  def setI(self, x):
    # xと同じサイズの単位行列をxp無しで作る
    self.I = copy(x.data)
    chainer.initializers.Identity(1)(self.I)

xはcupy.arrayなので、これをもとに必要な行列を作ってやれば、一回の初期化のためにxpを渡さなくても済む。

前向き計算の処理によっては無理だけど、外から見たクラスがきれいになるので良い。

このへんって皆さんどうしてるんですか。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です