최근 tensorflow로 학습한 모델인 .pb 파일을 OpenCV에서 실행해 보려 했다. 그런데 지속적으로 문제가 발생했다. 결론은 플랫폼 간 호환이 완벽하지 않기 때문이었다. 해결책으로 ONNX로 모델을 저장하고 불러오라는 답변이 있었다.
본 글은 Pytorch를 중심으로 ONNX 사용 방법에 대해 알아본다. OpenCV 문제를 해결한 코드는 마지막에 링크로 첨부되어 있다.
ONNX란
ONNX는 Open Neural Network eXchange의 약자로, 플랫폼 간 충돌을 줄일 수 있도록 만들어진 오픈소스 프로젝트이다. 예를 들어, tensorflow로 학습한 모델은 tensorflow 위에서 사용해야 한다. 하지만 학습한 모델을 다른 플랫폼에서 사용해야 하는 상황이 생긴다. 이때 ONNX를 이용한다.
Pytorch는 자체적으로 export 기능을 제공한다.
Pytorch to ONNX
학습된 CNN 모델을 예로 들어보자.
import torch
import torch.onnx
model = TrainedCNN() # 예시
model.eval()
model.cpu()
GPU가 없는 환경에서 불러올 때 문제가 발생하지 않도록 cpu로 옮겨준다.
onnx_path = "cnn.onnx"
# batch: 1, shape: (1, 28, 28), dtype: float32
dummy_input = torch.empty(1, 1, 28, 28, dtype=torch.float32)
torch.onnx.export(
model, dummy_input, onnx_path
)
ONNX 저장 전, 입력 크기를 미리 입력해 준다. ONNX는 정적 그래프를 사용하기 때문이다. 여기까지 하면 모델 구조와 입출력 크기가 export 된다. 이게 끝이다.
ONNX 추가 기능
pip install onnx onnxruntime
import onnxruntime
from PIL import Image
from torchvision import transforms
# Read a digit image(MNIST).
image = Image.open("sample.png")
image = image.convert("L")
preprocess = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
input_tensor = preprocess(image)
numpy_array = input_tensor.numpy().reshape(1, 1, 28, 28)
입력 데이터는 torch.tensor가 아닌 numpy.ndarray 형식으로 준비해야 한다. 위 코드는 MNIST 이미지를 가져오는 예시다.
# ONNX 모델 불러오기
session = onnxruntime.InferenceSession("cnn.onnx")
# 입출력 메타 데이터
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
numpy_array # 입력(np.ndarray)
outputs = session.run([output_name], {input_name: numpy_array})
print(f"Model Output: {outputs[0].squeeze(0).argmax(-1)}")
InferenceSession으로 세션을 생성하고, run으로 모델 추론을 실행한다.
모델 검증
import onnx
onnx_model = onnx.load(onnx_path)
onnx.save(onnx.shape_inference.infer_shapes(onnx_model), onnx_path)
각 layer 입출력 크기를 저장하려면 infer_shapes를 거치면 된다.
import numpy as np
onnx_path = "cnn.onnx"
onnx_model = onnx.load(onnx_path)
# 모델 상태 확인
onnx.checker.check_model(onnx_model)
# Pytorch와 ONNX 모델 비교
actual # Pytorch 모델 출력
desired # ONNX 모델 출력
# abs(actual - desired) <= atol + rtol * abs(desired)
np.testing.assert_allclose(actual, desired, rtol=1e-7, atol=0)
check_model은 자체적으로 모델을 검증해 준다. assert_allclose는 pytorch 모델 출력과 onnx 모델 출력을 비교한다.
print(onnx.helper.printable_graph(onnx_model.graph))
"""
graph main_graph (
%input[FLOAT, 1x1x28x28]
) initializers (
%conv1.0.weight[FLOAT, 16x1x3x3]
%conv1.0.bias[FLOAT, 16]
%conv2.0.weight[FLOAT, 32x16x3x3]
%conv2.0.bias[FLOAT, 32]
%classifier.1.weight[FLOAT, 32x1568]
%classifier.1.bias[FLOAT, 32]
%classifier.4.weight[FLOAT, 10x32]
%classifier.4.bias[FLOAT, 10]
) {
%/conv1/conv1.0/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input, %conv1.0.weight, %conv1.0.bias)
%/conv1/conv1.1/Relu_output_0 = Relu(%/conv1/conv1.0/Conv_output_0)
%/conv1/conv1.2/MaxPool_output_0 = MaxPool[ceil_mode = 0, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 1, 1, 1], strides = [2, 2]](%/conv1/conv1.1/Relu_output_0)
%/conv2/conv2.0/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%/conv1/conv1.2/MaxPool_output_0, %conv2.0.weight, %conv2.0.bias)
%/conv2/conv2.1/Relu_output_0 = Relu(%/conv2/conv2.0/Conv_output_0)
%/conv2/conv2.2/MaxPool_output_0 = MaxPool[ceil_mode = 0, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 1, 1, 1], strides = [2, 2]](%/conv2/conv2.1/Relu_output_0)
%/classifier/classifier.0/Flatten_output_0 = Flatten[axis = 1](%/conv2/conv2.2/MaxPool_output_0)
%/classifier/classifier.1/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%/classifier/classifier.0/Flatten_output_0, %classifier.1.weight, %classifier.1.bias)
%/classifier/classifier.2/Relu_output_0 = Relu(%/classifier/classifier.1/Gemm_output_0)
%output = Gemm[alpha = 1, beta = 1, transB = 1](%/classifier/classifier.2/Relu_output_0, %classifier.4.weight, %classifier.4.bias)
return %output
}
"""
printable_graph는 불러온 모델 구조를 출력한다.
관련 코드
- CNN to ONNX: denev6/deep-learning-codes/cnn_onnx.ipynb
- ONNX 모델 검증: denev6/deep-learning-codes/verify_onnx.py
- OpenCV에서 불러오기: denev6/deep-learning-codes/cnn_mnist.cpp