TechFULの中の人

TechFULスタッフ・エンジニアによる技術ブログ。IT関連のことやTechFULコーディングバトルの超難問の深掘り・解説などを毎週更新

PyTorchのLazy modulesを使ってみる

こんにちは。TechFULでアルバイトをしているberryberryです。 TechFUL PROで機械学習に関する問題を作成しています。 今回は、ディープラーニングを実装できるライブラリであるPyTorchに実装されている、Lazy modulesを紹介したいと思います。

Lazy modulesは、パラメータを後から初期化してくれる機能です。Lazy〇〇という名前の層が該当し、例えばLazyLinear、LazyConv2d、LazyBatchNorm2dなどがあります。 TensorFlowやChainerなどにも似たような機能がありますが、PyTorchでは最近になって導入されました。 Lazy modulesは、ネットワークの実装を少しだけ楽にしてくれます。今回は、このLazy modulesの使い方と注意点、役に立ちそうな場合について紹介したいと思います。

使い方

画像分類などに使われる、2次元の畳み込みニューラルネットワークについて考えます。 このネットワークで使われる畳み込み層はConv2d、そのLazy moduleはLazyConv2dです。これらの層は、それぞれ次のように定義します。

import torch
import torch.nn as nn

in_channels = 3  # 入力チャネル数
out_channels = 32  # 出力チャネル数

# 通常の畳み込み層
conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

# Lazy module
lazy_conv = nn.LazyConv2d(out_channels, kernel_size=3, stride=1, padding=1)

LazyConv2dはConv2dと同じような使い方が可能です。

# オプティマイザの設定
optimizer = torch.optim.SGD(lazy_conv.parameters(), lr=0.01)

# 状態の読み込み
state = lazy_conv.state_dict()
lazy_conv.load_state_dict(state)

# 通常のConv2dの状態をLazyConv2dに読み込むことも可能
state = conv.state_dict()
lazy_conv.load_state_dict(state)

Lazy modulesを使ってみる

さて、最初に挙げた例でそれぞれの引数を見比べるとLazyConv2dは入力のチャネル数を受け取っていません。入力のチャネル数が分からないと、畳み込み層の重みは定義できないはずですが、内部はどうなっているのでしょうか。そこで、それぞれの層の重み(weight)の型を見てみましょう。

# 出力:<class 'torch.nn.parameter.Parameter'>
print(type(conv.weight))  
# 出力:<class 'torch.nn.parameter.UninitializedParameter'>
print(type(lazy_conv.weight))  

見るとわかるように、Lazy moduleの場合重みの型がUninitializedParameterになっています。また、biasも同様にUninitializedParameterです。これがパラメータを後から初期化するという言葉の意味で、Lazy modulesは定義された段階ではまだパラメータを初期化することはせず、ただUninitializedParameterという型にしておいて、後から必要になる情報を保持しています。そして、入力が渡されたときに初めてパラメータを初期化します。なので、次のコードのようにlazy_convに入力を渡してから、再度重みの型を見るとUninitializedParameterがParameterになっています。

import torch
input = torch.zeros([1, 3, 32, 32]).float()
lazy_conv(input)  # 1回入力を渡す
print(type(lazy_conv.weight))  # 出力:<class 'torch.nn.parameter.Parameter'>

さて、実際に畳み込み層を使うときは、何層か重ねて畳み込みニューラルネットワークとして扱うことが多いと思います。この場合、それぞれの層に対して個別に入力を渡す必要はなく、畳み込みニューラルネットワークに対して入力を渡せば一度にすべての畳み込み層を初期化してくれます。余程変わったことをしなければ、ニューラルネットワークを学習するときは自然に入力を渡すので、その場合 1イテレーション目にLazy moduleが自動的に初期化されます。 そのため、ユーザ側でLazy moduleを明示的に初期化する必要はほとんどありません。便利ですね。

lazy_cnn = nn.Sequential(
    nn.LazyConv2d(32, kernel_size=3, stride=1, padding=1),
    nn.LazyConv2d(64, kernel_size=3, stride=1, padding=1),
    nn.LazyConv2d(128, kernel_size=3, stride=1, padding=1),
)

input = torch.zeros([1, 3, 32, 32]).float()
lazy_cnn(input)  # すべてのLazyConv2dが自動的に初期化される。

注意点

しかし、Lazy modulesに対してパラメータの初期値を全部0にするなどしようとすると、少し困ったことが起きます。 まず、例を見てみましょう。

conv = nn.Conv2d(3, 32, 3, 1, 1)
lazy_conv = nn.LazyConv2d(32, 3, 1, 1)

# 重みをすべて0で初期化
nn.init.zeros_(conv.weight)
print(float(conv.weight.sum()))  # 出力:0.0

# LazyConv2dをそのまま初期化しようとするとValueErrorが出る。
nn.init.zeros_(lazy_conv.weight) 

通常、PyTorchではnn.initを使ってパラメータの初期値を変更します。しかし、Lazy modules(上の例だとLazyConv2d)の場合UninitializedParameterに対して初期値を変更しようとするとValueErrorが返されてしまいます。 現状(PyTorch ver 1.9.1)、このエラーを避けるためには事前にLazy modulesを初期化する必要があるようです。ここの挙動は、今後のアップデートで改善されるかもしれません。

Lazy modulesが便利な場面

最後に、Lazy modulesが便利な場合を挙げます。

モデルをベタ書きするとき

Lazy modulesは、ネットワークの構造をべた書きしているときに便利だと思います。例えば、以下のような2層の畳み込みニューラルネットワークを考えます。例えば、次のようなニューラルネットワークを書いたとして、

cnn = nn.Sequential(
    nn.Conv2d(3, 32, 3, 1, 1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, 1, 1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

このネットワークの2層目に新しい層を加えたくなったとしましょう。 その場合、通常のConv2dで実装しているときは以下のように書き換えることになります。

cnn = nn.Sequential(
    nn.Conv2d(3, 32, 3, 1, 1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32, 48, 3, 1, 1),  # add
    nn.BatchNorm2d(48),  # add
    nn.ReLU()  # add
    nn.Conv2d(48, 64, 3, 1, 1),  # rewrite
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

この時、最後のConv2dの入力チャネル数はついついそのままにしがちです。 するとエラーが出て、修正する必要が出てくるわけですが、LazyConv2dを使っていればその修正が必要無くなります。 個人的にここのミスはやりがちなので、そのミスを防いでくれるのは結構うれしいです。

入力のサイズがプーリング層などで変わるとき

画像を扱うときの話にはなってしまいますが、よくある畳み込みニューラルネットワークの構造として、畳み込み層とプーリング層を繰り返したあと全結合層に渡すことがあります。 コードでみると次のようなネットワークです。

nn.Sequential(
    nn.Conv2d(3, 32, 3, 1, 1),
    nn.ReLU(),
    nn.MaxPool2d(3, 2, 1),
    nn.Conv2d(32, 64, 3, 1, 1),
    nn.ReLU(),
    nn.MaxPool2d(3, 2, 1),
    nn.Flatten(),  # Linear に入力するために形を整える。
    nn.Linear(???, 32),  # 入力のサイズを事前に計算する必要がある。
)

このようなネットワークを使うとき、最後の全結合層に入力される次元数を考えるのは少し面倒です。 なぜなら、入力される画像のサイズを事前に知っておかなければならないのはもちろん、プーリング層を通るとサイズが半分になってしまうのでそこも考える必要があるからです。 また、少し精度が悪いからもう一層追加しよう、とかプーリング層抜こうとか考え始めるといちいち設定しなおす必要が出てきます。 このような場合、Linearの代わりにLazyLinearを使うことで修正の手間が一気になくなり、非常に楽になります。

最後に

今回はLazy modulesについて説明しました。Lazy modulesを使うことで、実装の手間を少し減らすことができより快適なPyTorchライフが実現できるかもしれません。 注意点として、この機能はまだ開発段階であるため挙動が将来的に変わってしまう可能性があります。そのため、まだLazy modulesを使うのは個人的なプロジェクトに留めておくのが無難だと思います。 TechFUL PROでは、ディープラーニングでも取り組めるような機械学習の問題を提供しています。ぜひ、TechFUL PROが提供している問題を使って、Lazy modulesを試してみてください。