DirectML使ってみた

冬は寒いのでDNNの学習を回すのにぴったり!GPUの廃熱で暖房費節約だぜ!などと思ったけれど、メインマシンはメインOSはWindowsで運用していてビデオカードAMDなのでDNNフレームワークを動かすのはしんどい。 調べたらDirectMLってのでWindows+AMDでもいけそうじゃん!ということで試しに使ってみた。
結論としてはGPU使用率があまり高くならず暖かくならなかった

環境構築(共通)

環境はAnacodaを使わずにWindowsPythonをインストールしてvenvで構築する。Anacondaはライセンス変わっちゃったからね

Pythonのインストール

python.orgからインストーラをダウンロードする。自分はこちらのページから3.8.10の「Windows installer (64-bit)」を選んだ。

インストール時にはpipが一緒にインストールされるようにオプションがOnになっていることを確認した。(自分の環境ではデフォルトでOnになっていた)

あと、Pythonのパッケージを入れる際に git.exe と cl.exe も必要になるのでインストールしてパスに追加する。

gitは「Git for Windows」を入れたような気がするが、ずいぶん前の事なので詳細は不明。

cl.exeはMicrosoft C++ Build Toolsページから「Build Tools のダウンロード」ボタンを押してインストーラを取得、インストーラからMicrosoft Visual C++だか何だかを選択して入れたような気がする。(こちらもうろ覚え)

venv

Pythonをインストールした時点でvenvも入っているのでそのままvenv環境を作れる。
構築直後はpipのバージョンが古いのでバージョンアップしておく。この時、直接pipコマンドでバージョンアップしようとすると環境を壊してしまうので注意。python -m pipでバージョンアップする。

>python -m venv env_top_dir
>env_top_dir\Scripts\activate
>python -m pip install --upgrade pip

ONNXRuntime

環境構築

venv環境にonnxとonnxruntime-directmlパッケージをインストールすればOK。

>pip install onnx onnxruntime-directml

お試し

基本的にはDirectMLのExecution Provider(DmlExecutionProvider)を指定するだけだが、2点注意点がある。

  1. opsetバージョンはv17まで
  2. セッションのオプションでenable_mem_patternを無効化しておく必要がある

どちらもDirectML版ONNXRuntimeが対応してないっぽい。ちなみにenable_mem_patternの方は無効化しなくても以下の警告が表示されて自動的に無効化されるっぽい。

[W:onnxruntime:, inference_session.cc:491 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.

以下お試しコード。モデルはConv1個だけのなんちゃってモデル。

import onnx
import onnx.numpy_helper
import numpy as np
import onnxruntime as ort


# Conv1個だけのモデル
inputs  = [onnx.helper.make_tensor_value_info('input' , onnx.TensorProto.FLOAT, [1, 3, 4, 4])]
outputs = [onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [1, 1, 4, 4])]
nodes   = [onnx.helper.make_node('Conv', ['input', 'weight'], ['output'])]
inits   = [onnx.numpy_helper.from_array(1.0 / 4.0 * np.ones([1, 3, 1, 1], dtype=np.float32), 'weight')]
model = onnx.helper.make_model(onnx.helper.make_graph(nodes, 'conv', inputs, outputs, inits), opset_imports=[onnx.helper.make_opsetid('', 17)])

# Onだと警告が出るのであらかじめOff設定を入れておく
options = ort.SessionOptions()
options.enable_mem_pattern = False

# ExecutionProviderにDML版を指定して実行する
sess = ort.InferenceSession(model.SerializeToString(), options, ['DmlExecutionProvider', 'CPUExecutionProvider'])
sess.run(None, {'input': np.ones([1, 3, 4, 4], dtype=np.float32)})

TensorFlow

環境構築

venv環境にpipで入れるだけ。ほかに必要なパッケージは依存パッケージとして自動で入った。

>pip install tensorflow-directml-plugin

お試し

これで普通に動いた。'/job:localhost/replica:0/task:0/device:GPU:0'などと表示されたのでたぶん動いてる。

import tensorflow as tf
a = tf.constant([1.5])
b = tf.constant([0.5])
(a + b).device

あと公式のサンプルをそのまま書かれている通りに実行してみたら普通に動いた。データセットのダウンロードも自動で実行してくれてとても楽だった。

PyTorch

環境構築

同じくvenv環境にpipで入れる。

>pip install torchvision==0.14.0
>pip install torch==1.13
>pip install torch-directml

お試し

torch.deviceをDirectMLのもので指定すれば良いらしい。Tensor.to()には文字列を指定できずtorch.deviceを渡す必要がある。

あと、torch.Tensorをrepr()などで表示しようとするとエラーになる。(CPUに転送すれば表示できる)

import torch
import torch_directml


dml = torch_directml.device()

a = torch.tensor([1.5]).to(dml)
b = torch.tensor([0.5]).to(dml)
c = a + b
c.to('cpu')

簡単なモデルを作って動かしてみたがConv2d、BatchNorm2d、ReLU、Linearあたりは普通に動きそうだった。

mmdetectionでDETRの学習

PyTorchで動く物体検出向けフレームワーク?のMMDetectionを使ってDETR実装で学習を回すところまで改造してみた。

結論を先に言っておくとGPU使用率は上がらず温まらなかったtouch.deviceを入れ替えるだけでは動かなかった。

まだまだCPU実行時と同じ動きをしてくれないオペレーションがあるので既存のフレームワークなんかをそのまま使うのは厳しい、ということが分かった。
今はまだ公式のサンプルを使うのがよさそうに思える。サンプルのyolov3を試そうとしたらデータセットのダウンロード方法がよくわからず面倒になってやめてしまった

環境構築

このあたりを参考にしつつ以下の手順でvenv環境にインストールした。

>pip install mmcv-full==1.7.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html
>git clone https://github.com/open-mmlab/mmdetection.git
>cd mmdetection
>pip install -v -e .
>pip install opencv-python

※DirectMLで動かすためにgit cloneしたmmdetectionリポジトリソースコードを改造して無理やり動かしている

さらにDETRの定義ファイルと重みデータをダウンロードする。

>pip install -U openmim
>mim download mmdet --config yolov3_mobilenetv2_320_300e_coco --dest checkpoints

データセットのダウンロード方法を見ながらMS COCOデータセットを用意して、学習の実行方法を参考にした。

最終的にはデータセットの置き場所をEドライブのdatasetsに変更していたので以下の感じで実行した。

>set "MMDET_DATASETS=E:/datasets/coco/"
>python source_packages\mmdetection\tools\train.py checkpoints\detr_r50_8x2_150e_coco.py --cfg-options data.samples_per_gpu=4

samples_per_gpuは1枚のビデオカードで一度に読み出すデータ数らしくてビデオカードが1枚しか存在しない環境ならそのままバッチサイズになるらしい。(たぶん。↑だとバッチサイズ4ということ)

困ったこと

DirectMLで動かそうとして遭遇したことは以下の通り。

  • VRAMが足りなくなるとブルースクリーンでOSごと落ちる(正確にはPCが再起動する)
  • DirectMLが対応していないオペレーションがある
    • エラーになるケース(Pythonの例外が送出される)とエラーにならず実行結果がCPU実行時と異なるケースの2パターンある
    • どちらのケースも該当箇所の処理をCPUデバイスで実行するようにすればとりあえず動くようになる

DirectMLが対応していなかった箇所(DETRで通過する箇所のみ)

mmdet/core/bbox/match_costs/match_cost.py

torch.cdist()で例外になる。

RuntimeError: The size of tensor a (2) must match the size of tensor b (100) at non-singleton dimension 0

CPU実行時はバッチ次元が異なっても問題なく実行できるがDirectML実行時はエラーになる。

@@ -47,8 +47,8 @@ class BBoxL1Cost:
             gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
         elif self.box_format == 'xyxy':
             bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
-        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
-        return bbox_cost * self.weight
+        bbox_cost = torch.cdist(bbox_pred.to('cpu'), gt_bboxes.to('cpu'), p=1)
+        return bbox_cost.to(gt_bboxes.device) * self.weight
mmdet/core/bbox/samplers/pseudo_sampler.py

unique()でエラーになる。(※エラーの内容はメモり忘れてた…)

@@ -33,9 +33,9 @@ class PseudoSampler(BaseSampler):
             :obj:`SamplingResult`: sampler results
         """
         pos_inds = torch.nonzero(
-            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).cpu().unique().to(gt_bboxes.device)
         neg_inds = torch.nonzero(
-            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).cpu().unique().to(gt_bboxes.device)
         gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
         sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                          assign_result, gt_flags)
mmdet/models/dense_heads/detr_head.py

このファイルは2か所あって、1つ目はテンソルの一部をSlice指定で上書きするコードがDirectMLだとなぜか上書きされないという挙動になる。 2つ目はバッチ次元が0(教師データのBBox数が0個)の時に [0, 4] shape との演算にDirectMLが対応していなくて例外になる。

RuntimeError: self must have at least one element!

@@ -244,10 +244,11 @@ class DETRHead(AnchorFreeHead):
         # ignored positions, while zero values means valid positions.
         batch_size = x.size(0)
         input_img_h, input_img_w = img_metas[0]['batch_input_shape']
-        masks = x.new_ones((batch_size, input_img_h, input_img_w))
+        masks = x.new_ones((batch_size, input_img_h, input_img_w)).cpu()
         for img_id in range(batch_size):
             img_h, img_w, _ = img_metas[img_id]['img_shape']
             masks[img_id, :img_h, :img_w] = 0
+        masks = masks.to(x.device)

         x = self.input_proj(x)
@@ -537,8 +538,8 @@ class DETRHead(AnchorFreeHead):
         # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
         factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                        img_h]).unsqueeze(0)
-        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
-        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor if len(sampling_result.pos_gt_bboxes) else sampling_result.pos_gt_bboxes
+        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) if len(sampling_result.pos_gt_bboxes) else sampling_result.pos_gt_bboxes
         bbox_targets[pos_inds] = pos_gt_bboxes_targets
         return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                 neg_inds)

torch.device関連コード(参考)

説明が面倒になってきたのでそのままソースコードの差分だけ貼っておきます。

ちゃんと対応するにはmmcv側から改造が必要になるのと、mmdetection内ではtorch.deviceを使わずデバイス名のstrを受け取る形で実装されているので、以下の箇所以外に色々修正しないとダメだったりで不完全なので。

mmdet/apis/inference.py
@@ -151,6 +151,7 @@ def inference_detector(model, imgs):
             assert not isinstance(
                 m, RoIPool
             ), 'CPU inference with RoIPool is not supported currently.'
+        data['img'] = [cpu_tensor.to(device) for cpu_tensor in data['img']]

     # forward the model
     with torch.no_grad():
mmdet/apis/train.py
@@ -41,6 +41,10 @@ def init_random_seed(seed=None, device='cuda'):
     if world_size == 1:
         return seed

+    if device == 'dml':
+        import torch_directml
+        device = torch_directml.device()
+
     if rank == 0:
         random_num = torch.tensor(seed, dtype=torch.int32, device=device)
     else:
mmdet/utils/util_distribution.py
@@ -33,6 +33,12 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
         from mmcv.device.mlu import MLUDataParallel
         dp_factory['mlu'] = MLUDataParallel
         model = model.mlu()
+    elif device == 'dml':
+        import torch_directml
+        from mmdet.device.dml import DMLDataParallel
+        dp_factory['dml'] = DMLDataParallel
+        dml = torch_directml.device()
+        model = model.to(dml)

     return dp_factory[device](model, dim=dim, *args, **kwargs)
@@ -55,7 +61,7 @@ def build_ddp(model, device='cuda', *args, **kwargs):
                      DistributedDataParallel.html
     """
     assert device in ['cuda', 'mlu',
-                      'npu'], 'Only available for cuda or mlu or npu devices.'
+                      'npu', 'dml'], 'Only available for cuda or mlu or npu devices.'
     if device == 'npu':
         from mmcv.device.npu import NPUDistributedDataParallel
         torch.npu.set_compile_mode(jit_compile=False)

@@ -81,9 +93,18 @@ def is_mlu_available():
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


+def is_dml_available():
+    try:
+        import torch_directml
+        return torch_directml.is_available()
+    except ImportError as e:
+        return False
+
+
 def get_device():
     """Returns an available device, cpu, cuda or mlu."""
     is_device_available = {
+        'dml': is_dml_available(),
         'npu': is_npu_available(),
         'cuda': torch.cuda.is_available(),
         'mlu': is_mlu_available()
mmdet/device/dml 配下

こちらは本来はmmcvに入れるべきコード。面倒なのでmmdetection配下に入れた。

# __init__.py
from ._functions import scatter, scatter_kwargs
from .data_parallel import DMLDataParallel
from .distributed import DMLDistributedDataParallel


__all__ = ['scatter', 'scatter_kwargs', 'DMLDataParallel', 'DMLDistributedDataParallel']


# _functions.py
import torch
import torch_directml
from typing import Union, List
from mmcv.parallel.data_container import DataContainer
from mmcv.device._functions import Scatter


def _scatter_core(current_device: torch.device, obj: Union[List, torch.Tensor]):
    if isinstance(obj, list):
        return [_scatter_core(current_device, elem) for elem in obj]
    elif isinstance(obj, torch.Tensor):
        return obj.to(current_device)
    else:
        raise RuntimeError(f'obj is unsupported type {type(obj)}')


def _scatter_data_container(current_device: torch.device, obj: DataContainer):
    outputs = _scatter_core(current_device, obj.data)
    return tuple(outputs) if isinstance(outputs, list) else (outputs,)


def scatter(inputs, target_devices, dim=0):
    device_id = next(iter(target_devices), torch_directml.default_device())
    current_device = torch_directml.device(device_id)

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_devices, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return _scatter_data_container(current_device, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for _ in target_devices]

    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    inputs = scatter(inputs, target_devices, dim) if inputs else []
    kwargs = scatter(kwargs, target_devices, dim) if kwargs else []

    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])

    inputs = tuple(inputs)
    kwargs = tuple(kwargs)

    return inputs, kwargs


# data_parallel.py
import torch_directml
from mmcv.parallel import MMDataParallel
from ._functions import scatter_kwargs


class DMLDataParallel(MMDataParallel):
    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)

        self.device_ids = kwargs.get('device_ids', [torch_directml.default_device()])
        self.src_device_obj = torch_directml.device(self.device_ids[0])

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)