BASEプロダクトチームブログ

ネットショップ作成サービス「BASE ( https://thebase.in )」、ショッピングアプリ「BASE ( https://thebase.in/sp )」のプロダクトチームによるブログです。

機械学習チームで論文読み会を実施してみました(A ConvNet for the 2020s解説)

BASEの機械学習チームで論文読み会を実施してみました

こんにちは。BASEのDataStrategy(DS)チームでエンジニアをしている竹内です。
DSチームではBASEにおける様々なデータ分析業務をはじめ、機械学習技術を利用した検索、推薦機能のサポート、商品のチェックや不正決済の防止などに取り組んでいます。

先日、チーム内で最新の機械学習技術についての知見を相互に深めるための試みとして、各々興味のある機械学習系の論文を持ち寄って紹介し合う、いわゆる論文読み会というものを実施してみました。
この記事では、その会で私が発表した内容の一部を紹介したいと思います。

※ 中身は論文読み会用から本記事用に一部修正を加えています。

A ConvNet for the 2020s

紹介する論文について

タイトル: A ConvNet for the 2020s
著者: Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie
Facebook AI Research (FAIR), UC Berkeley
CVPR 2022

arXivリンク: https://arxiv.org/abs/2201.03545
公式実装: https://github.com/facebookresearch/ConvNeXt

※ 挿入している図(画像)と英文は特に言及がない限り本論文からの引用になります。

TL;DR

  • 直近の画像処理NNのアーキテクチャにおいては、Transformerをベースにしたものがトップクラスの性能を発揮(Swin-T)
  • TransformerのキーとなるモジュールはMulti-Head Self-Attention(MSA)だが、実際にはそれ以外にも従来のConvNetに取り入れられていない様々なテクニックが存在→真にConvNetを上回っているとは言えないのでは
  • そこで従来のResNetに、Transformerに加えられているMSA以外のテクニックを可能な限り盛り込んだ(ConvNeXtと命名)
  • ConvNeXtは従来のTransformerベースのモデルに対して、モデルサイズを抑えながら性能を上回ることができた

    We gradually “modernize” a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way.

f:id:shotakeuchi:20220325085132p:plain
画像分類タスクにおけるConvNeXtとViTの性能比較

画像系NNモデルアーキテクチャの流れ

ConvNeXtに到達するまでのターニングポイント的なアーキテクチャをざっくりと振り返ってみる。(リンクはarXiv)

  • AlexNet(2012)
    • ConvNetの始祖的な存在
    • ImageNetコンテストで圧勝
      →以降ConvNetの層を深くするのがトレンドに
      →層を深くすると以下の二つの問題が浮上
      • Back Propagation時の勾配消失・勾配爆発の問題
      • 精度の飽和、学習時のエラーの上昇(Degradation)の問題
  • ResNet(2015)
    • 勾配消失・勾配爆発問題はBatch Normalizationが有効
    • Degradationに対するアプローチとして、層間のショートカット接続の重要性に注目
    • ショートカットを含んだブロックを含むアーキテクチャを提唱
      f:id:shotakeuchi:20220325085239p:plain
      ResNetのショートカット構造(He, Kaiming, et al. "Deep residual learning for image recognition."より)
  • ResNeXt(2016)
    • ResNetのブロックを並列に並べて集計する仕組みを提唱
      • パラメータ数と性能のトレードオフを改善
    • 並列に並べる数Cardinalityをハイパーパラメータとして導入
    • 少ないパラメータ数、小さいモデルサイズ、シンプルな形でResNetの性能を上回る
      f:id:shotakeuchi:20220325085530p:plain
      ResNeXtの仕組み(Xie, Saining, et al. "Aggregated residual transformations for deep neural networks."より)
  • EfficientNet(2019)
    • ネットワークの深さ、広さ、解像度(画像サイズ)の3つをパラメータとして最適化
    • 扱いやすく画像系の機械学習コンペ等ではよく見るアーキテクチャの一つ
  • VisionTransformer(ViT)(2020)
    • 自然言語処理系NNにおいてはデファクトスタンダードとなったTransformerを画像処理に応用
    • 画像を16x16のパッチに分け、それぞれのEmbeddingを単語に見立てる
    • 画像分類において少ない計算コストでトップ性能を発揮
    • 画像処理における最近の大きなブレークスルーの一つ
      f:id:shotakeuchi:20220325085628p:plain
      ViTで用いられる画像のパッチ化(Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale."より)
  • Swin Transformer(Swin-T)(2021)
    • ViTを物体検知やセマンティックセグメンテーションなど、ピクセル単位の解像度が要求される他のタスクでも効果を発揮できるように改良
    • パッチ化の処理を階層化+1マスずつズラす処理(Shifted Window)を導入することで画像サイズの2乗であったViTの計算量を線形まで落とした
    • 画像処理系のトップカンファであるICCV'21のBest Paper
    • ConvNeXtの論文では、Swin-Tで採用されているWindowをズラす処理がConvと類似しているため、重要な要素であると考えられることが言及されている

      For example, the “sliding window” strategy (e.g. attention within local windows) was reintroduced to Transformers, allowing them to behave more similarly to ConvNets. ... Swin Transformer’s success and rapid adoption also revealed one thing: the essence of convolution is not becoming irrelevant; rather, it remains much desired and has never faded.

f:id:shotakeuchi:20220325085807p:plain
Swin-Tで用いられる階層的なパッチ化(Liu, Ze, et al. "Swin transformer: Hierarchical vision transformer using shifted windows."より)

ResNetがConvNeXtになるまでに加えられた改良

Chapter2以降ではベースとなるResNetに加えられた手法と、それによる精度の改善幅について順に説明されている。

2.1 学習手法

ネットワークのアーキテクチャを弄る前に学習手法をTransformerに倣って改善していく。

  • エポック数を90→300に
  • AdamW optimizer(2019)
    L2正則化とWeight DecayがAdamでは同一視することができないことを示し、AdamのWeight Decayに修正を加えた
  • データ拡張
    • Mixup(2018)
      2種類のデータとラベルをベータ分布からランダム生成された \lambdaを使って以下のように混ぜ合わせる
      データ:  X=\lambda X_1+(1-\lambda)X_2
      ラベル:  y=\lambda y_1+(1-\lambda)y_2
    • Cutmix(2019)
      画像の一部を切り取り、別のラベルの画像を挿入
    • RandAugment(2020)
      グリッドサーチで最適な拡張度合いを見つける
    • RandomErasing(2017)
      画像内にランダムな矩形を追加する

f:id:shotakeuchi:20220325085928p:plain
Cutmix(Yun, Sangdoo, et al. "Cutmix: Regularization strategy to train strong classifiers with localizable features."より)

f:id:shotakeuchi:20220325090020p:plain
Random Erasing(Zhong, Zhun, et al. "Random erasing data augmentation."より)

  • 正則化
    • Stochastic Depth(2016)
      ランダムでResブロックをスキップのみにする(後ろの層になるほどその確率が高くなる)
    • Label Smoothing(2016)
      正解ラベルと不正解ラベルの値を1, 0ではなく0.9, 0.1などとする

これらの追加により性能は76.1%→78.8%に改善

2.2 マクロデザイン

ここからは、ResNetのマクロな構造をTransformerに近づけていく。

ステージごとの計算比率の変更

複数のResブロックからなる1かたまりはステージと名付けられており、ConvNeXtには合計で4つのステージが存在する。
それぞれのステージのブロック数はResNet-50では(3, 4, 6, 3)であったが、これをSwin-Tに合わせて(3, 3, 9, 3)に変更した。
78.8%→79.4%に改善

stemで画像をパッチ化するように変更

入力された画像に対して一番最初に処理を行う部分はstemと名付けられており、ViTなどでは画像のパッチ化を行う部分に相当する。
従来のResNetでは、まず入力の画像を適切な特徴量のサイズにするためにカーネルサイズ7×7ストライド2のConv+Max Poolingを使用することで4倍のダウンサンプルを行なっていた。
一方でViTでは画像を16×16のパッチにする処理を行なっているが、これはカーネルサイズ16×16で重複なし(ストライド16)のConvに相当する。
Swin-Tではより小さい4×4のパッチを作成しているため、これに倣ってstemでカーネルサイズ4×4でストライド4のConvを使用する。
79.4%→79.5%に改善

2.3 ResNeXt化

ResNeXtで取り入れられている、1つのConvを複数に分岐させてあとからまとめることでパラメタ数を削減する手法を適用した。
モデルのキャパシティの減少を抑えつつパラメタ数を効率的に削減できるため、モデルサイズを維持したまま性能を大幅に引き上げることができる。
今回はチャンネル数分の分岐を作成するDepthwise Convを使用する。これはTransformerにおけるAttention層のMulti-Head化に相当すると言及されている。

We note that depthwise convolution is similar to the weighted sum operation in self-attention

79.5%→80.5%に改善

2.4 Inverted Bottleneck

Transformerでは入力の次元より隠れ層MLPの次元の方が4倍大きくなるInverted Bottleneckというデザインを採用している。
このアイデアはMobileNetV2(2018)ですでに利用されており、その後の改良型ConvNetでもしばしば用いられている。
下図の(a)がResNeXtの1ブロックで用いられる通常のBottleneckで、チャンネルサイズを384→96に落としてからDepthwise Convで96→96に畳み込み、最後にチャンネル数を384に戻している。
(b)がInverted Bottleneckで、チャンネル数を逆に96→384に増やしてから384→96に戻している。
これによってDepthwise Convの計算量は増えるものの、入力部分がダウンサンプルされていることによってResブロックのショートカットの1×1Convの計算量が減るため、全体の計算量は減ることになる。

f:id:shotakeuchi:20220325090837p:plain
(a)がResNeXt, (b)がInverted Bottleneck, (c)がDepthwise Convを移動させたもの

80.5%→80.6%に改善(ResNet-200では81.9%→82.6%に改善)

2.5 Large Kernel Sizes

従来のConvNetでは3×3など小さいカーネルサイズ使用するのが主流であったものの、Swin-TのWindowのサイズは小さくとも7×7である点を考慮すると、カーネルサイズは大きい方が有効だと思われる。
これを実現するために以下の二つの手順を踏んでいる。

Depthwise Convの移動

より大きなカーネルサイズを利用するために、Depthwise ConvをResブロックの最初にもってくる。(上の図の(b)→(c)に対応)
これはTransformerのMulti-Head AttentionがMLPの前に配置されていることに対応する。
(一時的に)80.6%→79.9%に悪化

カーネルサイズの増加

Depthwise Convを移動させた後、そのカーネルサイズを3から5, 7, 9, 11と増やしていくと計算量は大体保たれたまま性能が改善され、7で大体性能が飽和する。
サイズの大きいResNet-200でも同じ7で飽和することが確認されている。
79.9%→80.6に改善

この時点でViTで採用されているデザインの大部分を実現できていることになる。

2.6ミクロデザイン

大枠のアーキテクチャは完成したため、ここからはレイヤーレベルで改善していく。

ReLUをGELUで置き換える

活性化関数として使用されているReLUをBERTやGPT-2、ViTでも使われている以下のGELU(2016)に置き換える。


GELU(x) = xP(X)=x\Phi(x) \\\
\space \Phi(x) = P(X\le x),  X \sim N(0,1)

f:id:shotakeuchi:20220325090907p:plain
GELUと他の活性化関数との比較(Hendrycks, Dan, and Kevin Gimpel. "Gaussian error linear units (gelus)."より)

80.6%→80.6%で性能据え置き

活性化関数を減らす

Transformerの1ブロック(入力のKey/Query/ValueをEmbeddingしてMLPに入れる部分)には活性化関数が1回しか使用されていない一方で、ResNetは1ブロックにConv層の数だけ活性化関数が存在する。
これを1×1のConv2つの間のみに絞ることで活性化関数の数を合わせる。
80.6%→81.3%に改善

Normalization層を減らす

これもTransformerに合わせて1×1conv層の前にのみBatch Normalization層を置く。
81.3%→81.4%に改善

BatchNormをLayerNormに変更

これもTransformerで使用されている手法ではあるが、単純なResNetのBNをLNに置き換えるだけでは性能が下がることが確認されている。ここまでの改造を全て加えた上でLNに置き換えると性能の改善が見られる。
81.4%→81.5%に改善

ダウンサンプル層を切り離す

ResNetでは各ステージの最初のResブロックでストライド2のConvによってダウンサンプリングを行なっているが、Swin-Tではこのような処理は各ステージの間で行われている。
これに倣ってConvNeXtでもステージの間にダウンサンプル層とNorm層を追加することで学習を安定化させた。(Norm層なしだと学習が発散した。)
81.5%→82.0%に改善

f:id:shotakeuchi:20220325090935p:plain
ResNetに加えた全ての変更点とそれによる性能および計算量の改善

f:id:shotakeuchi:20220325091016p:plain
最終的なモデルのアーキテクチャ

性能

f:id:shotakeuchi:20220325091045p:plain
画像分類タスクにおける性能比較

ImageNetにおいてConvNeXtは同程度のモデルサイズのSwinTransformerを上回る性能を発揮している。

f:id:shotakeuchi:20220325091140p:plain
物体検出タスクにおける性能比較
COCOデータセットにおける物体検出においてもConvNeXtは同程度のモデルサイズのSwinTransformerを上回る性能を発揮している。

感想など

  • ここ最近の画像処理系の流れを振り返るのにちょうど良い論文で、公式のpytorchによる実装と合わせて内容が非常にわかりやすく良い論文
  • 強いて言えば既存の技術の応用という面が強いため、新しいアイデアや知見、理論的な深掘り(「なぜAttentionよりConvolutionの方が上手くいくのか」など))の面では若干物足りない気もする
  • 昔読んだ深層強化学習系のRainbow: Combining Improvements in Deep Reinforcement Learningという論文になんとなく立ち位置が似ているなと感じた
    • ベースとなるDeep Q-Networkという手法に7つの改善手法を加えたときのパフォーマンスの改善について研究した論文
  • 自分でベースモデルのアーキテクチャや学習手法に改善を加える際の流れとしても参考になる

おわりに

今回のような論文読み会はDSチームとしては初の試みでしたが、新しい知見を取り入れ視野を広げる良い機会に感じたので1Qに1回ぐらいのペースで継続していけたらと思っています。
今後もDSチームでは新しい技術についても積極的に検討し、検証を重ねることで更なるプロダクトの改善、サービスの向上に取り組んでいきます。