Axial-Attentionについて調べたことのメモ

Axial-DeepLab (ECCV 2020, Spotlight、arXiv:2003.07853)の論文1を読んでいて、Axial-Attentionが具体的にどんなものなのかよく分からなかったので、色々調べたり著者実装2を自分で実行してみたりして「こんな感じかなぁ」という認識に至ったので内容をメモしておく。

Attention?

そもそも「Attentionが何か」というのは、こちらの記事3が網羅的で分かりやすいと思う。

自分の理解では入力featureからqueryとkeyとvalueを生成し、queryとkeyの行列積で類似度表のようなものを作り、類似度表とvalueの行列積で値を取り出すと思っている。

類似度表がいわゆるAttentionなのだと認識している。(ただし、理解が正しいという自信はない)

Axial-Attentionモジュールの構造

論文より図を引用する。

axial-attentionモジュール
Axial-Attentionモジュール 論文 Fig.1 (right)より引用

 W_q  W_k  W_vはquery、key、value。rq、rk、rvはquery、key、valueのpositional encodingベクトルとのこと。

処理内容を箇条書きすると以下のようになる。

  1. 入力featureを全結合層に通してquery、key、valueを得る
  2. チャネル次元をグループに分割する
  3. queryとkeyの行列積で類似度表をグループの数だけ生成する
  4. query、keyにそれぞれpositional encodingベクトルを行列積して類似度表に加算する
  5. 類似度表とvalueの行列積、類似度表とpositional encodingベクトルの行列積を求めて合算する

※Fig.1中の「softmax」とある辺りが類似度表である

以下に補足する。

query、key、valueの生成、グループ分割

入力featureのshapeを[N, C, H, W]とする。(Nはミニバッチサイズ、Cはチャネル数、HとWのpixel数)

これを[N*W, C, H]に形状変更する。(Height-Axis版の場合、Width-Axis版はHとWが入れ替わる。以下同様)

さらに全結合層に通す。著者実装では全結合層は Conv2d 1x1 で実装されている。queryとkeyはチャネル数がAttentionモジュールの出力チャネル数の半分になる。さらにチャネル次元はグループに分割される。

query: [N*W, C, H] → Conv → [N*W, outs/2, H] → 分割 → [N*W, group, plane/2, H]
key:   [N*W, C, H] → Conv → [N*W, outs/2, H] → 分割 → [N*W, group, plane/2, H]
value: [N*W, C, H] → Conv → [N*W, outs,   H] → 分割 → [N*W, group, plane,   H]

out
sは出力チャネル数
groupはグループ数
planeはouts/group

類似度表の生成

query×keyで類似度表を生成する。

einsum('bgci, bgcj -> bgij', query, key)
[N*W, group, plane/2, H]×[N*W, group, plane/2, H] -> [N*W, group, H, H]

positional encodingベクトルがそれぞれquery、keyに適用されて類似度表に合算される。(要素ごとの和)

einsum('bgci, cij -> bgij', query, embedding)
[N*W, group, plane/2, H]×[plane/2, kernel, kernel] -> [N*W, group, kernel, H]
kernelはHと同じサイズ

kernelというのが、具体的に何を指しているのかは不明。著者実装を動かしてわかったことは、とにかくkernelサイズとHが一致する、ということだけだった。

類似度表はH次元でsoftmaxを取るので以下のようなH vs Hの表になる。

H vs Hのsimilarity表

これはある画素(i, j)の出力値を決める時に、W(行)方向は固定してH(列)方向の各pixelの出力を何%ずつ持ってくるか、という成分表のようなイメージになる。

上図を例に取ると、j=0の出力はH=0が10%、H=1が80%、H=2が10%でミックスした値となる。

なお、similarity表のshapeは[N*W, group, kernel, H]であるので、上図の表は実際には(N*W)×group数個存在している。(W方向はpixelごとに別の表になる)

類似度表から出力feature生成

valueとpositional encodingベクトルをそれぞれ類似度表と行列積して要素ごとの和を取る。

einsum('bgij, bgcj -> bgci', similarity, value)
[N*W, group, H, H]×[N*W, group, plane, H] -> [N*W, group, plane, H]
einsum('bgij, cij -> bgci', similarity, embedding)
[N*W, group, H, H]×[plane, kernel, kernel] -> [N*W, group, plane, H]

行列積の最内ループは内積なので H dot H になっていて、類似度表は0〜1の成分表なので成分表に従ってミックスする操作になっている。

下図は画像の中心点(だいだい色に塗ったpixel)の出力を決める時に縦方向のpixelを類似度表で参照している、という図である。 similarityベクトルのイメージ

類似度表との行列積が終われば、あとは形状変更で[N, C, H, W]にして出力となる。

なお、成分表がグループごとに別々なのでチャネル次元でみた時にグループごとに別々の領域にAttention(注目)してくれる可能性がある。(そのように学習が進めば)

positional encodingベクトルについて

自分はpositional encodingについて正直よくわかっていない。位置ごとに別の値を足した状態で正しく識別できるように学習するのだから、何らかの位置を加味した重みに学習されるに違いないくらいのふんわりした認識しかない。

学習対象のパラメータ

著者実装を見る限り、以下が学習パラメータ(重み)になりそうだった。

  • 全結合層(Conv 1x1)の重み
  • positional encodingベクトル
  • Batch Normalizationのパラメータ

Axial-Attentionの(乱暴な)イメージ

自分の勝手なイメージだとBatch Normalizationは正則化項のような働きをしてくれる何かと認識しているので、実質Axial-Attentionは入力featureのどの部分を取り出すかを positional encodingベクトル、Conv 1x1 で決めているという理解であながち間違っていないのではないかと思う。

類似度表を直接重みにしていないのでとっつきにくいが、入力とConvとpositional encodingベクトルで類似度表が生成される、という点が理解できれば十分な気がした。

Axial-Attentionブロックの構造

Height-Axis版とWidth-Axis版のAxial-Attentionモジュールを組み合わせて縦横全域をカバーするAxial-Attentionブロックになる。

Axial-Attentionブロック
Axial-Attentionブロック 論文 Fig.2 より引用

論文の図でMulti-Head、Concatと記載されている箇所は、チャネルをグループに分割してグループごとに別々の類似度表を適用していることに相当する。著者実装ではshapeの形状変更で対応できているのでConcatしていない。(要素ごとの和を取る際にConcatしているが、Multi-Headの件とは別の実装都合と思われる)

Attentionの可視化

論文よりAttentionの可視化画像を引用する。

Attentionの可視化
Attentionの可視化画像 論文 Fig.7 より引用

最初にこの画像をみた際に、どこに着目して見ればよいのかが分からなかった。

column headとある図は、下図のように青いpixelの出力が縦方向に見て特に赤色になっているpixelの値を使っている、ということである。

可視化画像の解説
画像の解説

row headとある図は、青いpixelの出力が横方向で赤い箇所の出力を多く使っているということになる。

以上。


  1. Wang, Huiyu and Zhu, Yukun and Green, Bradley and Adam, Hartwig and Yuille, Alan and Chen, Liang-Chieh. Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation. European Conference on Computer Vision (ECCV). 2020

  2. csrhddlam. 2020. Axial-DeepLab (ECCV 2020, Spotlight). https://github.com/csrhddlam/axial-deeplab , 2021/02/24閲覧

  3. @halhorn. 2018. 作って理解する Transformer / Attention. https://qiita.com/halhorn/items/c91497522be27bde17ce , 2021/02/24閲覧