PyTorchモデルをONNX化する際にカスタムオペレータの実装が必要っぽい件
やりたかったこと
- PyTorchモデルをONNXファイルにexportしたい
- モデルにはカスタムオペレータを含む
- カスタムオペレータの実装は存在しない(該当部分は推論不能)
- カスタムオペレータはONNXでそれらしきop_typeを持つノードで出力してくれればOK
やってみたこと
- register_custom_op_symbolic()で'custom_ops::my_operator'名でONNXオペレータ定義のルーチンを登録
- forward()メソッドでtorch.ops.custom_ops.my_operator(...)呼び出し
- torch.onnx.export()でモデルを変換
結果
カスタムオペレータの実装を探しに行って例外発生。
File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torch/_ops.py", line 61, in __getattr__ op = torch._C._jit_get_operation(qualified_op_name) RuntimeError: No such operator custom_ops::batched_nms
考察(推測)
- ONNX変換時にカスタムオペレータの実装コードが必要(っぽい)
- register_custom_op_symbolic()は(おそらく)ONNX変換の時に使うだけで変換時のダミー推論には使われない
- (おそらく)トレースありの状態でダミー推論を実行して、トレース中に呼び出されたオペレータをONNXに変換する、という形でONNX変換が実現されているような気がする
- なので推論自体が実行できないとダメ、ということだと思われる
- おそらくカスタムオペレータの実装はそれらしきダミーデータを固定値で返す実装でもONNX変換には影響なさそう