ONNXモデルの中間ノード出力を抽出するライブラリを作ってみた

先日記事にしたこれを実装してみた。 maminus.hatenadiary.org

ソースコード

GitHub上げました。まともな?ライブラリを作るの初めてなので何か間違ってるかも・・・(特にPipenv自信なし。pip install も直接gitを指定できなかった。たぶんどこか間違ってる)

やってること

  • ONNXモデル(onnx.ModelProto)を引数に受け取り、編集後のONNXモデルを戻り値で返す関数 add_middle_output() を実装
  • 各ノードのoutput[]からIdentityノードを生やす
  • model.graph.outputに中間層出力用のValueInfoProtoを追加
    • Identityノードの出力がValueInfoProto
    • 出力名はOMMLE.node_output_name.middle_0の形式
      • node_output_nameがノードの出力名
      • 上記の名前で既存の名前と衝突する場合は最後の0をインクリメントする
  • 引数cast_typeにonnx.TensorProtoのデータ型を指定した場合、Identityノードの代わりにCastノードを使う
    • 指定したデータ型にキャストした中間ノード出力が得られる
    • 事前に中間ノード出力のデータ型が不明な場合に使う想定
    • cast_typeを省略した場合、onnx.TensorProto.FLOAT(numpy.float32相当)である前提で動くので注意
  • 引数exclude_op_typesリストにノードのop_type名を入れると該当するノードから中間出力を出さない
    • op_type名は完全一致で比較する
  • 引数exclude_output_namesリストにノードのoutput名を入れると該当するoutput名から中間出力を出さない
    • 同じく完全一致で比較
    • output[]名はNetronなり何なりで調べて指定する

メモ

Identityノードを挟まずに直接ノードの出力をONNX出力に指定すると推論時に以下のようにエラーになる。(ONNX仕様なのか、ONNXRuntineの実装都合なのかは未調査)

>       self._sess.load_model(providers)
E       onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Graph output (conv_1) does not exist in the graph.