ONNXRuntimeのカスタムオペレータを実装してみた
ソースコードはここ↓ github.com
背景とか
TensorRTを使おうとしてモデルの一部をONNXのカスタムオペレータにすることがある。(TensorRTのプラグインを使うケース)
ただカスタムオペレータを含むモデルはそのままではONNXRuntimeで推論できない。ということはONNX Simplifierのようなツールを使うこともできない。 カスタムオペレータを実装することで推論が可能になってツール類も使える。
# ONNXのカスタムオペレータはノードの'domain' attributeに独自ドメイン名を指定すれば作れる nodes = [onnx.helper.make_node('Fma', ['A', 'B', 'C'], ['out'], domain='ai.onnx.contrib')]
カスタムオペレータの実装手段
少し調べたところ、大きく分けて2つのやり方がある。
Pythonで実装すると推論コードもPythonで記載できてPythonのみで実現できるので楽ちん。ただし、onnxruntime-extensionsパッケージが必要になるのと、強めの制約がある。(後述)
Pythonで実装する方法
onnx_op
デコレータをカスタムオペレータ実装ルーチンにつけるだけ。引数のテンソルはnumpy.ndarrayが渡される。戻り値のテンソルもndarrayを返せばOK。
# 引数はすべてfloat32型、戻り値もfloat32型。op_typeは'Fma' @onnx_op(op_type='Fma', inputs=[PyOp.dt_float, PyOp.dt_float, PyOp.dt_float], outputs=[PyOp.dt_float]) def fma(a, b, c): return a * b + c
ただし、float32バージョン
、float64バージョン
、のように扱うデータ型のバリデーションを作ることができない。
これはop_typeに対するルーチンが1つしか登録できない作りになっているためと思われる。
推論は以下のようにする。以下のコードでmodel_func()
の呼び出しがONNXモデル全体の推論実行処理になっている。
model_func = PyOrtFunction.from_model(_ONNX_FILE_NAME) result = model_func(A, B, C)
C++で実装する方法
C++で実装する場合は主に以下の要素を用意すればよい。
void Compute(OrtKernelContext* context)
メソッドを持つkernelクラス- 計算処理本体を実装する
Ort::CustomOpBase<Op, Kernel>
クラスを継承し必要なメソッドを実装したクラスvoid* CreateKernel(OrtApi api, const OrtKernelInfo* info) const
- kernelクラスのインスタンスを生成して返す
const char* GetName() const
- op_type名を返す
ONNXTensorElementDataType GetInputType(size_t index) const
- index番目の入力データ型を返す
size_t GetInputTypeCount() const
- オペレータ引数の数を返す
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const
- 省略可能な引数を持つ場合に実装する
- index番目の引数が必須か省略可能かを返す
ONNXTensorElementDataType GetOutputType(size_t index) const
- index番目の戻り値のデータ型を返す
size_t GetOutputTypeCount() const
- オペレータの戻り値の数を返す
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const
- 省略可能な戻り値を持つ場合に実装する
- index番目の戻り値が必須か省略可能かを返す
RegisterCustomOps()
関数
ソースコードのビルドはONNXRuntimeの3つのヘッダファイルがあればOK。
CMakeLists.txtではfind_path()
でヘッダファイルの場所を探すようにしているので参考に。
kernelクラス
kernelクラスは大雑把に以下の構造になるように実装する。
struct FmaKernel { FmaKernel(OrtApi api):api_(api), ort_(api_) {} void Compute(OrtKernelContext* context) { // ... カスタムオペレータの計算処理 } private: OrtApi api_; Ort::CustomOpApi ort_; };
入力データへのポインタなどの計算に必要なデータはOrt::CustomOpApiでアクセスできる。ただし、CustomOpApiクラスはコンストラクタ引数のOrtApiインスタンスを参照で保持するためOrtApiインスタンスのコピーをkernelクラスのメンバに保持する必要があるとのこと。
// 0番目の引数(float型)へのポインタをもらう例 const auto input_a = ort_.KernelContext_GetInput(context, 0); auto ptr_a = ort_.GetTensorData<float>(input_a); // 出力0を[1, 3, 224, 224]のshapeで作ってポインタをもらう例 size_t shape_dim = 4; const int64_t shape[shape_dim] = {1, 3, 224, 224}; auto output_0 = ort_.KernelContext_GetOutput(context, 0, shape, shape_dim); auto ptr_0 = ort_.GetTensorMutableData<float>(output_0);
もし出力shapeが入力と同じならGetTensorTypeAndShape()
とGetTensorShape()
を呼ぶと入力shapeをもらえるので出力shapeの指定にそのまま使えばよい。
オペレータクラス
オペレータクラスは以下のようにCustomOpBaseクラスを継承する。CustomOpBaseクラスのテンプレート引数には自分自身とkernelクラスを指定する。
struct CustomOpFma : Ort::CustomOpBase<CustomOpFma, FmaKernel> { // 対応するkernelクラスのインスタンスをnewして返す void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { return new FmaKernel(api); } // 単純に対応するop_type名を返すだけでOK const char* GetName() const { return "Fma"; } // ... };
後のI/Fは特筆すべき内容は無いのでinput側だけ掲載する。具体的な実装コードはGitHubにpushした実装コードを参照のこと。
// index番目の引数のデータ型をONNXTensorElementDataType(onnxruntime_c_api.hで定義されている)で返す ONNXTensorElementDataType GetInputType(size_t index) const { if (index > 0) { // "T"を指定可能なのは1つの引数のみ。残りはFLOATなどの具体的な型を返す必要がある return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // UNDEFINEDを返すとデータ型は"T"(任意型)として扱われる return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } // 引数の数を返す。Fmaの例なら引数は3つなので3を返せばOK size_t GetInputTypeCount() const { return 3; } // index番目の引数が省略可能ならOPTIONAL、必須ならREQUIREDを返す OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const { if (index > 1) { // 3つ目(index == 2)の引数を省略可能とする例 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; } return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; }
注意点
float32版オペレータ
、float64版オペレータ
などと複数データ型に対応させたい場合に問題点がある。2つのやり方がある。
- データ型を
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
で報告する - データ型ごとに別々のドメイン名に分ける
1の方法はデータ型を"T"(任意型)に指定する方法だが、任意型をとる引数が1つのみのケースしか使えない。FMAのように3つの引数を"T"と指定することができない。
2の方法はONNXモデル側でドメインを分ける必要が発生してしまうかわりに各データ型に対応する実装を作ることができる。もしデータ型ごとにop_type名を変えることができるならドメイン名は同じでop_type名で分ける方法もある。 (カスタムオペレータはop_type名につき1種類の実装しか登録できないためop_type名かドメイン名を分ける必要がある)
RegisterCustomOps関数
このルーチンはDLLのロード時(正確にはregister_custom_ops_library()
呼び出し時)に呼ばれる。
constexpr std::string_view domain_name = "my_ops"; CustomOpFma op_fma; extern "C" { OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { const OrtApi* ort_api = api_base->GetApi(ORT_API_VERSION); OrtStatus* status; // ドメインオブジェクトを作る OrtCustomOpDomain* domain = nullptr; if (status = ort_api->CreateCustomOpDomain(domain_name.data(), &domain)) { return status; } // ドメインオブジェクトを(後で解放処理を実行するために)覚えておく register_domain(domain, ort_api); // ドメインオブジェクトにカスタムオペレータを登録する if (status = ort_api->CustomOpDomain_Add(domain, &op_fma)) { return status; } // ドメインオブジェクトをONNXRuntimeに登録する status = ort_api->AddCustomOpDomain(options, domain); return status; } } // extern "C"
register_domain()
は何かのI/Fとかではなく、今回独自に実装した。やっていることはドメインオブジェクトのポインタをdeleter付きのunique_ptrでくるんでvectorに登録しているだけ。
ドメインオブジェクトだけ?はインスタンスの削除をこちらで実施する必要があるっぽい。そしてdeleteで削除するのではなくRelease系メソッド呼び出しが必要らしい。
具体的な処理内容は実装コードを参照のこと。
推論処理
Pythonで推論するには以下のようにInferenceSessionの引数にSessionOptionsを渡せばよい。(ONNXモデルのカスタムオペレータのドメイン名、op_type名とC++実装コードのドメイン名、CustomOpFma::GetName()が返す名前が一致している必要がある)
import onnx import numpy as np import onnxruntime as ort option = ort.SessionOptions() option.register_custom_ops_library('./libmy_custom_t.so') model = onnx.load(_ONNX_FILE_NAME) sess = ort.InferenceSession(model.SerializeToString(), option) A = np.ones([1], dtype=np.float32) B = np.ones([1], dtype=np.float32) results = sess.run(None, {'A': A, 'B': B})