fastaiで学習に使う関数をApache MXNetで真似してみた

f:id:aptpod-tetsu:20200513191117j:plain

はじめに

先進技術調査グループのせとです。本ブログでは、Apache MXNetを用いてfastaiで実装されている実践的な関数を真似てみた結果を紹介します。この試みのゴールは、完全一致の結果を目指すのではなく同じような傾向を得られるかを目指したものになります。完全一致を目指したいところですが、各フレームワークで用意しているモデルの構造が少し違ったり、各関数の計算方法が異なるので結果が等しくなりませんでした。もちろん、他方に併せて関数を自作すればほとんど一致する結果を得ることができますが実装のコストが高かったため、今回は行いませんでした。

モチベーション

弊社のプロジェクトでAI部分をAmazon SageMaker(以下、SageMaker)を使って実装したい要望がありました。しかし、プロジェクトで利用していたフレームワークはfastaiであるために簡単にSageMaker上で実行できないことがわかりました。この課題を解決するために、はじめは単純な学習を行えば達成できると思っていたのですが、実際にためしたところfastaiで達成した精度を再現できませんでした。このため、ファーストステップとしてfastaiで用いた関数をSageMakerのベースに使われているApache MXNetで真似て精度を再現を行えるかを試みました。

Deep Learningライブラリの説明

ここでは利用したDeep Learningライブラリを簡単に説明します。

fastai

fastaiは、PyTorchをベースにしたDeep Learningのフレームワークです。特徴は、実験的に良かった内容に関する論文の成果を実装して使えるようにしているところです。良いノウハウを簡単に利用できて実戦的に使える良いフレームワークだと思います。

Apache MXNet

Apache MXNet(以下、mxnet)はfastaiと同様のDeep Learningのフレームワークで、fastaiとの違いは自作関数を書くような低レベルな書き方や、gluonをつかった高レベルな書き方など柔軟なフレームワークです。MXNet とは | AWSにあるように、SageMakerに利用されているフレームワークです。

実装した機能

プロジェクトで利用した以下の2つの機能をmxnetで実装しました。

以下にそれぞれの機能の概要を説明します。

learning rate finder

Deep Learningのモデル学習では、色々なハイパーパラメータ(bach size、weight decay、learning rate、momentumなど)を設定する必要があります。不適切なパラメータを用いると、学習しない・収束に時間がかかるなどが起こるため、適切なパラメータを探す必要があります。経験則や一般的に良いと言われている値などを用いて仮決めしたあとに、学習の様子をみて微調整して決定することが多いです。lerning rate finderの機能は、良いlerning rateを決めるためにある範囲内のlearning rateごとにlossを算出します。この結果のlossを見て適切なlearning rateを決定します。詳しい説明はここに書かれており、この手法はこの論文で提案されています。

fit one cycle

Deep Learningのモデル学習では、一般的には学習が完了するまでに時間がかかります。この収束を早くするための工夫がone cycle training機能になります。この機能は、設定したlearning rateを基準に基準より小さいlearning rateから学習を開始しイテレーションのたび大きくします。基準まで到達した後、下限に設定したlearning rateまで段々と小さくします。この手法はこの論文で提案されています。

実験条件

実行環境

  • OS: Ubunt18.04(docker on CentOS7)
  • CPU: Intel Core i7-6850K
  • GPU: GeForce 1080Ti
  • Memory: 63GB
  • nvidia-docker

学習条件

  • dataset
    • CIFAR10
  • 画像サイズ
    • 224x224x3
      ※もともと32x32x3ですが、アーキテクチャを変えたくなかったため224にアップサンプリングしています。推論にとってはあまり必要のない行為です。
  • モデル
    • resnet50 + cnn learner module
      ※cnn learner moduleはfastaiの関数で生成される層のことです。
  • 各パラメータ
    • epochs: 3
    • batch size: 32
  • データ拡張
    • fastai: get_transform{max_rotate=5.}で設定できる処理

実験

learning rate finder関数の比較

fastaiでlearning rate finder関数を実行した結果は下図のとおりです。

f:id:aptpod_tech-writer:20200511152951p:plain
fastai: learning rate finder

mxnetで模擬した結果は下図のとおりです。

f:id:aptpod_tech-writer:20200511153040p:plain
mxnet: learning rate finder

上図を比較すると、learning rateごとのlossが似たような形になっていることが見て取れます。このため、関数の実装ができたとしました。

fit one cycle関数の比較

fastaiのfit one cycle関数を実行した結果が下表のとおりです。

f:id:aptpod_tech-writer:20200511153102p:plain
fastai: fit one cycle

mxnetで模擬した結果は下表のとおりです。

f:id:aptpod_tech-writer:20200511153124p:plain
mxnet: fit one cycle

上図を比較すると、1エポックと2エポックでaccuracyとvalid-accが似たような傾向になっています。このため、この関数も実装ができたとしました。

まとめ

上記結果から、fastaiのlearning rate finder関数とfit one cycle関数をmxnetで実現することができました。これらのモジュールを用いてSageMakerへ投げる用のスクリプトを書けば学習・デプロイが可能になると思います。

今回は、プロジェクト起因の課題から手探りで思ったより苦労しましたが、学びが多くありました。fastaiの関数を理解するためにドキュメントや実装を見てみましたが、細かいテクニックなど(バッチ正規化部分だけはフリーズしないなど)が隠蔽化されており、使う側は何も意識しなくて良いですが、他のフレームワークで似たようなことをしようと思うと表面だけ真似ても難しいことがわかりました。フレームワークを自由に選択できるのであればfastaiはベターな結果を短時間で得られるフレームワークだと思いました。

一方、mxnetは低レベル・高レベルの関数が容易されておりかなり簡易に実装を書けることがわかりました。使い方はPytorchに似た部分もあり複数のフレームワークを扱ったことがあればそこまで苦労しないと思います。しかし、mxnet独自の部分はドキュメントや記事などが少なく労力がかかったので、コミュニティがもっと活発になってほしいなと思っています。

今後は、色々な取り組みがある中で、mxnetベースの画像に特化したGluonCV ToolkitやSageMakerの機能などの使ってみた体験を紹介していきたいと思っています。