ONNXRuntimeのカスタムオペレータを実装してみた

ソースコードはここ↓ github.com

背景とか

TensorRTを使おうとしてモデルの一部をONNXのカスタムオペレータにすることがある。(TensorRTのプラグインを使うケース)

ただカスタムオペレータを含むモデルはそのままではONNXRuntimeで推論できない。ということはONNX Simplifierのようなツールを使うこともできない。 カスタムオペレータを実装することで推論が可能になってツール類も使える。

# ONNXのカスタムオペレータはノードの'domain' attributeに独自ドメイン名を指定すれば作れる
nodes = [onnx.helper.make_node('Fma', ['A', 'B', 'C'], ['out'], domain='ai.onnx.contrib')]

カスタムオペレータの実装手段

少し調べたところ、大きく分けて2つのやり方がある。

  1. Pythonでカスタムオペレータを実装する
  2. C++でカスタムオペレータを実装する

Pythonで実装すると推論コードもPythonで記載できてPythonのみで実現できるので楽ちん。ただし、onnxruntime-extensionsパッケージが必要になるのと、強めの制約がある。(後述)

C++は実装が面倒だがPythonよりは制約を緩められる。

Pythonで実装する方法

onnx_opデコレータをカスタムオペレータ実装ルーチンにつけるだけ。引数のテンソルはnumpy.ndarrayが渡される。戻り値のテンソルもndarrayを返せばOK。

# 引数はすべてfloat32型、戻り値もfloat32型。op_typeは'Fma'
@onnx_op(op_type='Fma', inputs=[PyOp.dt_float, PyOp.dt_float, PyOp.dt_float], outputs=[PyOp.dt_float])
def fma(a, b, c):
    return a * b + c

ただし、float32バージョンfloat64バージョン、のように扱うデータ型のバリデーションを作ることができない。 これはop_typeに対するルーチンが1つしか登録できない作りになっているためと思われる。

推論は以下のようにする。以下のコードでmodel_func()の呼び出しがONNXモデル全体の推論実行処理になっている。

model_func = PyOrtFunction.from_model(_ONNX_FILE_NAME)
result = model_func(A, B, C)

C++で実装する方法

C++で実装する場合は主に以下の要素を用意すればよい。

  • void Compute(OrtKernelContext* context)メソッドを持つkernelクラス
    • 計算処理本体を実装する
  • Ort::CustomOpBase<Op, Kernel>クラスを継承し必要なメソッドを実装したクラス
    • void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const
    • const char* GetName() const
      • op_type名を返す
    • ONNXTensorElementDataType GetInputType(size_t index) const
      • index番目の入力データ型を返す
    • size_t GetInputTypeCount() const
      • オペレータ引数の数を返す
    • OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const
      • 省略可能な引数を持つ場合に実装する
      • index番目の引数が必須か省略可能かを返す
    • ONNXTensorElementDataType GetOutputType(size_t index) const
      • index番目の戻り値のデータ型を返す
    • size_t GetOutputTypeCount() const
      • オペレータの戻り値の数を返す
    • OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const
      • 省略可能な戻り値を持つ場合に実装する
      • index番目の戻り値が必須か省略可能かを返す
  • RegisterCustomOps()関数

ソースコードのビルドはONNXRuntimeの3つのヘッダファイルがあればOK。 CMakeLists.txtではfind_path()でヘッダファイルの場所を探すようにしているので参考に。

kernelクラス

kernelクラスは大雑把に以下の構造になるように実装する。

struct FmaKernel {
    FmaKernel(OrtApi api):api_(api), ort_(api_) {}

    void Compute(OrtKernelContext* context) {
        // ... カスタムオペレータの計算処理
    }
private:
    OrtApi api_;
    Ort::CustomOpApi ort_;
};

入力データへのポインタなどの計算に必要なデータはOrt::CustomOpApiでアクセスできる。ただし、CustomOpApiクラスはコンストラクタ引数のOrtApiインスタンスを参照で保持するためOrtApiインスタンスのコピーをkernelクラスのメンバに保持する必要があるとのこと。

// 0番目の引数(float型)へのポインタをもらう例
const auto input_a = ort_.KernelContext_GetInput(context, 0);
auto ptr_a = ort_.GetTensorData<float>(input_a);

// 出力0を[1, 3, 224, 224]のshapeで作ってポインタをもらう例
size_t shape_dim = 4;
const int64_t shape[shape_dim] = {1, 3, 224, 224};
auto output_0 = ort_.KernelContext_GetOutput(context, 0, shape, shape_dim);
auto ptr_0 = ort_.GetTensorMutableData<float>(output_0);

もし出力shapeが入力と同じならGetTensorTypeAndShape()GetTensorShape()を呼ぶと入力shapeをもらえるので出力shapeの指定にそのまま使えばよい。

オペレータクラス

オペレータクラスは以下のようにCustomOpBaseクラスを継承する。CustomOpBaseクラスのテンプレート引数には自分自身とkernelクラスを指定する。

struct CustomOpFma : Ort::CustomOpBase<CustomOpFma, FmaKernel> {
    // 対応するkernelクラスのインスタンスをnewして返す
    void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
        return new FmaKernel(api);
    }
    // 単純に対応するop_type名を返すだけでOK
    const char* GetName() const {
        return "Fma";
    }
    // ...
};

後のI/Fは特筆すべき内容は無いのでinput側だけ掲載する。具体的な実装コードはGitHubにpushした実装コードを参照のこと。

 // index番目の引数のデータ型をONNXTensorElementDataType(onnxruntime_c_api.hで定義されている)で返す
    ONNXTensorElementDataType GetInputType(size_t index) const {
        if (index > 0) {
            // "T"を指定可能なのは1つの引数のみ。残りはFLOATなどの具体的な型を返す必要がある
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
        }
        // UNDEFINEDを返すとデータ型は"T"(任意型)として扱われる
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    }
    // 引数の数を返す。Fmaの例なら引数は3つなので3を返せばOK
    size_t GetInputTypeCount() const {
        return 3;
    }
    // index番目の引数が省略可能ならOPTIONAL、必須ならREQUIREDを返す
    OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
        if (index > 1) {
            // 3つ目(index == 2)の引数を省略可能とする例
            return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
        }
        return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
    }

注意点

float32版オペレータfloat64版オペレータなどと複数データ型に対応させたい場合に問題点がある。2つのやり方がある。

  1. データ型をONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINEDで報告する
  2. データ型ごとに別々のドメイン名に分ける

1の方法はデータ型を"T"(任意型)に指定する方法だが、任意型をとる引数が1つのみのケースしか使えない。FMAのように3つの引数を"T"と指定することができない。

2の方法はONNXモデル側でドメインを分ける必要が発生してしまうかわりに各データ型に対応する実装を作ることができる。もしデータ型ごとにop_type名を変えることができるならドメイン名は同じでop_type名で分ける方法もある。 (カスタムオペレータはop_type名につき1種類の実装しか登録できないためop_type名かドメイン名を分ける必要がある)

RegisterCustomOps関数

このルーチンはDLLのロード時(正確にはregister_custom_ops_library()呼び出し時)に呼ばれる。

constexpr std::string_view domain_name = "my_ops";
CustomOpFma op_fma;

extern "C" {
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
    const OrtApi* ort_api = api_base->GetApi(ORT_API_VERSION);
    OrtStatus* status;

    // ドメインオブジェクトを作る
    OrtCustomOpDomain* domain = nullptr;
    if (status = ort_api->CreateCustomOpDomain(domain_name.data(), &domain)) {
        return status;
    }

    // ドメインオブジェクトを(後で解放処理を実行するために)覚えておく
    register_domain(domain, ort_api);

    // ドメインオブジェクトにカスタムオペレータを登録する
    if (status = ort_api->CustomOpDomain_Add(domain, &op_fma)) {
        return status;
    }

    // ドメインオブジェクトをONNXRuntimeに登録する
    status = ort_api->AddCustomOpDomain(options, domain);
    return status;
}
}   // extern "C"

register_domain()は何かのI/Fとかではなく、今回独自に実装した。やっていることはドメインオブジェクトのポインタをdeleter付きのunique_ptrでくるんでvectorに登録しているだけ。 ドメインオブジェクトだけ?はインスタンスの削除をこちらで実施する必要があるっぽい。そしてdeleteで削除するのではなくRelease系メソッド呼び出しが必要らしい。

具体的な処理内容は実装コードを参照のこと。

推論処理

Pythonで推論するには以下のようにInferenceSessionの引数にSessionOptionsを渡せばよい。(ONNXモデルのカスタムオペレータのドメイン名、op_type名とC++実装コードのドメイン名、CustomOpFma::GetName()が返す名前が一致している必要がある)

import onnx
import numpy as np
import onnxruntime as ort

option = ort.SessionOptions()
option.register_custom_ops_library('./libmy_custom_t.so')

model = onnx.load(_ONNX_FILE_NAME)
sess = ort.InferenceSession(model.SerializeToString(), option)

A = np.ones([1], dtype=np.float32)
B = np.ones([1], dtype=np.float32)
results = sess.run(None, {'A': A, 'B': B})