ONNXモデルの中間ノード出力を抽出するライブラリを作ってみた
先日記事にしたこれを実装してみた。 maminus.hatenadiary.org
ソースコード
GitHubに上げました。まともな?ライブラリを作るの初めてなので何か間違ってるかも・・・(特にPipenv自信なし。pip install も直接gitを指定できなかった。たぶんどこか間違ってる)
やってること
- ONNXモデル(onnx.ModelProto)を引数に受け取り、編集後のONNXモデルを戻り値で返す関数 add_middle_output() を実装
- 各ノードのoutput[]からIdentityノードを生やす
- model.graph.outputに中間層出力用のValueInfoProtoを追加
- Identityノードの出力がValueInfoProto
- 出力名は
OMMLE.node_output_name.middle_0
の形式- node_output_nameがノードの出力名
- 上記の名前で既存の名前と衝突する場合は最後の0をインクリメントする
- 引数cast_typeにonnx.TensorProtoのデータ型を指定した場合、Identityノードの代わりにCastノードを使う
- 指定したデータ型にキャストした中間ノード出力が得られる
- 事前に中間ノード出力のデータ型が不明な場合に使う想定
- cast_typeを省略した場合、onnx.TensorProto.FLOAT(numpy.float32相当)である前提で動くので注意
- 引数exclude_op_typesリストにノードのop_type名を入れると該当するノードから中間出力を出さない
- op_type名は完全一致で比較する
- 引数exclude_output_namesリストにノードのoutput名を入れると該当するoutput名から中間出力を出さない
- 同じく完全一致で比較
- output[]名はNetronなり何なりで調べて指定する
メモ
Identityノードを挟まずに直接ノードの出力をONNX出力に指定すると推論時に以下のようにエラーになる。(ONNX仕様なのか、ONNXRuntineの実装都合なのかは未調査)
> self._sess.load_model(providers) E onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Graph output (conv_1) does not exist in the graph.