Load and Run an ONNX Model

The onnx/models repository stores pre-trained ONNX models. Every ONNX backend should support running these models out of the box. After you download and extract the tarball for a model, you should find:

  • A protobuf file, model.onnx, which is the serialized ONNX model.
  • Several sets of sample input and output files (test_data_*.npz), each of which is a serialized NumPy archive.

In this tutorial, you’ll learn how to use a backend to load and run an ONNX model.

Example: Using the TensorFlow backend

First, install the ONNX TensorFlow backend by following the instructions here.

Then download and extract the tarball of ResNet-50.

Next, we load the necessary R and Python libraries (via reticulate):

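library(onnx)        # R interface to ONNX
library(reticulate)  # R interface to Python
# Import numpy without automatic conversion, so arrays stay Python objects
np <- import("numpy", convert = FALSE)
# Import the ONNX TensorFlow backend module
backend <- import("onnx_tf.backend")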

We can then use the loaded numpy Python library to define a helper function that loads a test sample from a serialized NumPy archive.

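load_npz_samples <- function(npz_path) {
  # normalizePath() expands "~" so that numpy receives an absolute path
  sample <- np$load(normalizePath(npz_path), encoding = 'bytes')
  # sample is a Python object (convert = FALSE), so indices are 0-based.
  # This assumes items() yields the inputs entry first and the outputs entry second.
  list(
    inputs = sample$items()[[0]][[1]][[0]],  # first input array
    outputs = sample$items()[[1]][[1]]       # all expected outputs
  )
}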
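To get a feel for what such an archive holds, the sketch below writes a tiny synthetic archive and reads it back by key. The key names inputs and outputs and the toy arrays are assumptions made for illustration only; on a real test_data_*.npz file, check the files attribute first to see which keys it actually contains.

```r
library(reticulate)
np <- import("numpy", convert = FALSE)

# Write a small synthetic archive. The key names "inputs" and "outputs"
# are assumptions for this sketch, not read from a real model tarball.
npz_path <- tempfile(fileext = ".npz")
np$savez(npz_path,
         inputs = np$array(c(1, 2, 3)),
         outputs = np$array(c(2, 4, 6)))

# List the arrays stored in the archive, then fetch one of them by key
archive <- np$load(npz_path)
keys <- py_to_r(archive$files)
first_output <- py_to_r(py_get_item(archive, "outputs"))
archive$close()
```

Looking entries up by key, rather than by position in items(), avoids depending on the order in which the arrays were written.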

Finally, we can load the ONNX model and the testing samples, and then run the model using the ONNX TensorFlow backend:

# Specify paths to ONNX model and testing samples
onnx_model_dir <- "~/Downloads/resnet50"
model_pb_path <- file.path(onnx_model_dir, "model.onnx")
npz_path <- file.path(onnx_model_dir, "test_data_0.npz")

# Load ONNX model
model <- load_from_file(model_pb_path)

# Load a testing sample from the serialized NumPy archive
samples <- load_npz_samples(npz_path)
inputs <- samples$inputs
expected_outputs <- samples$outputs

# Run the model with the ONNX TensorFlow backend
actual_outputs <- backend$run_model(model, inputs)

We can also use NumPy to verify that the actual outputs agree with the expected outputs, within the tolerance implied by decimal = 6:

np$testing$assert_almost_equal(expected_outputs, actual_outputs, decimal = 6)
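To make the decimal argument concrete, here is a minimal base R sketch of the element-wise tolerance that NumPy documents for assert_almost_equal, namely abs(desired - actual) < 1.5 * 10^(-decimal). The helper name almost_equal and the sample values are hypothetical. It returns TRUE or FALSE instead of raising an error, so it is an illustration of the rule rather than a replacement for the NumPy call.

```r
# Hypothetical helper: TRUE when every element of `actual` lies within
# 1.5 * 10^(-decimal) of the corresponding element of `desired`
almost_equal <- function(desired, actual, decimal = 6) {
  stopifnot(length(desired) == length(actual))
  all(abs(desired - actual) < 1.5 * 10^(-decimal))
}

expected <- c(0.1234567, 0.7654321)

within_tol  <- almost_equal(expected, expected + 1e-7)  # difference below 1.5e-6
outside_tol <- almost_equal(expected, expected + 1e-5)  # difference above 1.5e-6
```

With decimal = 6, differences of about 1e-7 pass and differences of about 1e-5 fail, which is why small floating-point discrepancies between backends usually do not trip the check.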

That’s it! Isn’t it easy? Next, you can try out other ONNX models as well as other ONNX backends, e.g. PyTorch, MXNet, Caffe2, CNTK, Chainer, etc.