ONNXモデルの中間層出力を取得するアイデア

忘れないように自分用のメモ。

やりたいこと

  • ONNXモデルを使って推論をする時に最終出力ではなく中間ノードの出力がほしい

イデア

  • 2フェーズに分けて処理する
  • 第一フェーズでは中間ノードの出力にShapeノード+出力ノードを接続する
    • make_node('Shape', inputs=[middle_node.output[0]], output='shape1')
    • make_tensor_value_info('shape1', TensorProto.INT64, ['unknown'])
    • ValueInfoProtoのshapeは文字列もOKなので事前に値がわからなくてOK
    • Shapeの出力は1次元なのでValueInfoのdimは1固定にできる
    • たぶんこの状態で推論できるはず
    • 推論によって中間ノードのshapeが取得できる
  • 第二フェーズで中間ノードの出力にIdentityノード+出力ノードを接続する
    • make_node('Identity', inputs=[middle_node.output[0]], output='middle_output1')
    • make_tensor_value_info('middle_output1', TensorProto.FLOAT, shape1)
    • shape1は第一フェーズの出力
    • 第二フェーズの推論によって中間出力が得られる
    • データ型は自動取得する手段が思い浮かばない
    • IdentityノードではなくCastノードで強制キャストする手もありそう
    • 出力データ型がstring以外であればfloatへキャストとかでとりあえず取得できそう

あくまでアイデアなので、実際に実装して試してみたい。おそらくこのアイデアは複数の中間ノードにいっせいに実施できるはず。(outputに上限とかがなければ)

2020/02/15 追記
  • make_tensor_value_infoの引数shapeはNoneを指定できた
  • Noneにするとshape不一致で推論時エラーにならないので第一フェーズは不要になる
  • データ型は回避方法が見つからなかった
  • onnx.TensorProto.UNDEFINEDというのがあるが、指定しても推論時エラーになるだけだった
  • 事前にデータ型が不明な場合はCastノードを挟んで強制的に出力データ型を変更するしかなさそう。
    • 複素数型や文字列型にも対応できる汎用的な方法は見つかっていない