PyTorchモデルをONNX化する際にカスタムオペレータの実装が必要っぽい件

やりたかったこと

  • PyTorchモデルをONNXファイルにexportしたい
  • モデルにはカスタムオペレータを含む
  • カスタムオペレータの実装は存在しない(該当部分は推論不能
  • カスタムオペレータはONNXでそれらしきop_typeを持つノードで出力してくれればOK

やってみたこと

  • register_custom_op_symbolic()で'custom_ops::my_operator'名でONNXオペレータ定義のルーチンを登録
  • forward()メソッドでtorch.ops.custom_ops.my_operator(...)呼び出し
  • torch.onnx.export()でモデルを変換

結果

カスタムオペレータの実装を探しに行って例外発生。

  File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torch/_ops.py", line 61, in __getattr__
    op = torch._C._jit_get_operation(qualified_op_name)
RuntimeError: No such operator custom_ops::batched_nms

考察(推測)

  • ONNX変換時にカスタムオペレータの実装コードが必要(っぽい)
  • register_custom_op_symbolic()は(おそらく)ONNX変換の時に使うだけで変換時のダミー推論には使われない
  • (おそらく)トレースありの状態でダミー推論を実行して、トレース中に呼び出されたオペレータをONNXに変換する、という形でONNX変換が実現されているような気がする
    • なので推論自体が実行できないとダメ、ということだと思われる
    • おそらくカスタムオペレータの実装はそれらしきダミーデータを固定値で返す実装でもONNX変換には影響なさそう

メモ

カスタムオペレータを実装する場合はこちらの公式チュートリアルを参考にすればよさそう。