
Defining ONNX graphs with ndonnx


ndonnx is an ONNX-backed Python array library that implements the Array API standard. It helps us take our machine learning models into production at QuantCo by facilitating ONNX export of Array API-compliant code and by providing a high-level API for constructing ONNX graphs.

Motivation

In the previous post in our ONNX blog series, we showed how to export a trained linear regression model to ONNX using the Spox library. You might have noticed that we implemented the prediction path twice: once using NumPy (LinearRegression.predict) for training and experimentation, and a second time using Spox (linear_regression) so that the model could be exported to ONNX.

The two implementations are shown below, lightly edited for readability.

import numpy as np

class LinearRegression:
    _coefficients: np.ndarray
    _intercept: float

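    def fit(self, X, y):
        # Center X and y so that the intercept can be solved for separately
        y_offset = np.average(y, axis=0)
        X_offset = np.average(X, axis=0)
        # Least-squares fit of the coefficients on the centered data
        self._coefficients, *_ = np.linalg.lstsq(X - X_offset, y - y_offset)
        self._intercept = y_offset - X_offset @ self._coefficients
        return self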

    def predict(self, X):
        return X @ self._coefficients + self._intercept

The Spox implementation operates on Var objects rather than NumPy arrays.

import numpy as np
from spox import Var
import spox.opset.ai.onnx.v20 as op

def linear_regression(X: Var, coefficients: np.ndarray, intercept: float) -> Var:
    # Embed the trained parameters as constants, then emit MatMul followed by Add
    mul_result = op.matmul(X, op.const(coefficients))

    return op.add(
        mul_result,
        op.const(intercept),
    )

Implementing the same algorithm twice with two completely different sets of primitives is error-prone and hard to maintain. It forces library developers to learn a second toolchain just to support ONNX export and to continuously maintain two code paths that must remain semantically equivalent, taking time away from feature development. Unfortunately, manually reimplementing code in this style is still the standard approach in the broader ONNX community, with libraries like scikit-learn relying on dedicated converter libraries for ONNX export[1].

We identified that the root cause of this problem is the lack of interoperability between ONNX tooling and existing array libraries.

We turned to the Array API standard for a solution. It provides a well-specified API for common array operations and enables library authors to target multiple array backends while maintaining a single code path[2]. Libraries such as NumPy, JAX and CuPy have adopted it in their primary APIs[3].
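As a rough illustration of what "a single code path" means in practice, the sketch below writes a standardization function once against the Array API, looking up the namespace from its input instead of hard-coding NumPy. It assumes NumPy 2 or later for `__array_namespace__` on NumPy arrays; the `get_namespace` helper and its NumPy fallback are our own hypothetical convenience (the array-api-compat package offers a more complete `array_namespace` function for the same purpose).

```python
import numpy as np


def get_namespace(x):
    """Hypothetical helper: return the Array API namespace of `x`.

    Arrays from compliant libraries expose __array_namespace__; NumPy arrays
    only do so from NumPy 2 onward, so we fall back to NumPy itself.
    """
    getter = getattr(x, "__array_namespace__", None)
    return getter() if getter is not None else np


def standardize(x):
    # Only operators and functions from the Array API standard (mean, std)
    # are used, so any compliant backend can execute this one code path.
    xp = get_namespace(x)
    return (x - xp.mean(x, axis=0)) / xp.std(x, axis=0)


print(standardize(np.asarray([[1.0, 2.0], [3.0, 4.0]])))
```

Passing an array from another compliant library would route the same calls to that library's `mean` and `std`.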

With ndonnx, we developed an ONNX backend for the Array API standard, enabling developers to add ONNX export capabilities to their libraries without the maintenance overhead of a second implementation or the need for expertise in ONNX tooling. Concretely for us, this means that the LinearRegression.predict method we have already seen can be exported to ONNX without a separate Spox implementation.

Using ndonnx

ndonnx will feel very familiar to NumPy users, since the Array API is largely a subset of NumPy's current API. One important difference is that ndonnx arrays can be instantiated with only a shape and a data type. Such arrays represent inputs to an ONNX graph, and operations on them are evaluated abstractly: this is what allows functions like LinearRegression.predict to be traced and serialized to ONNX.

>>> import ndonnx as ndx
>>> ndx.array(shape=("N", 3), dtype=ndx.float64)
Array(dtype=Float64)

Notice that the shape we provided can have symbolic dimensions as well as concrete integer ones, something we discussed in our post about Spox.
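To see why symbolic dimensions are useful, consider how an output shape can be inferred without knowing the value of "N". The sketch below is a hypothetical, simplified shape rule (not ndonnx's implementation) for multiplying an (M, K) matrix by a length-K vector, where each dimension is either an int (concrete) or a string (symbolic).

```python
def matmul_shape(a_shape, b_shape):
    """Hypothetical shape rule for (M, K) @ (K,) -> (M,).

    Dimensions are ints (concrete) or strs (symbolic, e.g. "N").
    """
    if len(a_shape) != 2 or len(b_shape) != 1:
        raise ValueError("this sketch only handles (M, K) @ (K,)")
    m, k_a = a_shape
    (k_b,) = b_shape
    # Concrete contracted dimensions must agree; a symbolic one cannot be
    # checked until runtime, so it is accepted here.
    if isinstance(k_a, int) and isinstance(k_b, int) and k_a != k_b:
        raise ValueError(f"contracted dimensions differ: {k_a} != {k_b}")
    return (m,)


# The symbolic "N" is carried through to the output shape
print(matmul_shape(("N", 3), (3,)))
```

In the linear regression example, an input of shape ("N", 3) multiplied by three coefficients yields predictions of shape ("N",), whatever the batch size turns out to be at inference time.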

Let's use this array to export LinearRegression.predict to ONNX. First, we need to make a small change to the method so that it is Array API-compliant (and can therefore accept input from multiple array backends).

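import numpy as np

class LinearRegression:
    _coefficients: np.ndarray
    _intercept: float

    # `fit` is unchanged from before
    def fit(self, X, y):
        y_offset = np.average(y, axis=0)
        X_offset = np.average(X, axis=0)
        self._coefficients, *_ = np.linalg.lstsq(X - X_offset, y - y_offset)
        self._intercept = y_offset - X_offset @ self._coefficients
        return self

    def predict(self, X):
        # Namespace (module) of X's array library, e.g. numpy or ndonnx
        xp = X.__array_namespace__()
        # Convert the NumPy coefficients to X's array type before multiplying
        return X @ xp.asarray(self._coefficients) + self._intercept

X_np = np.asarray([[0, 0, 0], [1, 1, 1], [2, 2, 2]], np.float64)
y_np = np.asarray([1, 2, 3], np.float64)

model = LinearRegression().fit(X_np, y_np)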

We first extract the namespace xp of the input array X, which is the module containing the top-level functions of the array backend in use. In this simple example, we only need it for xp.asarray, which converts our trained coefficients (a NumPy array) into the same array backend as X. We can now export the method to ONNX using our ndonnx array from before.

import onnx

# The same array from before
x = ndx.array(shape=("N", 3), dtype=ndx.float64)
y = model.predict(x)

# Build and save ONNX model to disk
model_proto = ndx.build({"X": x}, {"y": y})
onnx.save(model_proto, "linear_regression.onnx")
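Conceptually, calling predict with an array that has no data records operations instead of computing them. The toy tracer below illustrates that idea in plain Python; the Graph, ToyArray and ToyNamespace classes are hypothetical and far simpler than ndonnx's actual machinery, which tracks data types and builds real ONNX graphs.

```python
from dataclasses import dataclass, field


@dataclass
class Graph:
    """Flat record of (output_name, op_type, input_names, attributes)."""
    nodes: list = field(default_factory=list)

    def add_node(self, op_type, inputs, **attrs):
        out = f"t{len(self.nodes)}"
        self.nodes.append((out, op_type, list(inputs), attrs))
        return out


class ToyArray:
    """Hypothetical symbolic array: operators record nodes instead of computing."""

    def __init__(self, graph, name, shape):
        self.graph, self.name, self.shape = graph, name, shape

    def __array_namespace__(self):
        return ToyNamespace(self.graph)

    def __matmul__(self, other):
        # (M, K) @ (K,) -> (M,): the contracted dimension is dropped
        out = self.graph.add_node("MatMul", [self.name, other.name])
        return ToyArray(self.graph, out, self.shape[:-1])

    def __add__(self, other):
        if not isinstance(other, ToyArray):
            # Python scalars (e.g. the intercept) become constant nodes
            const = self.graph.add_node("Constant", [], value=float(other))
            other = ToyArray(self.graph, const, ())
        out = self.graph.add_node("Add", [self.name, other.name])
        return ToyArray(self.graph, out, self.shape)  # assumes `other` broadcasts


class ToyNamespace:
    """Stand-in for an Array API namespace whose functions emit graph nodes."""

    def __init__(self, graph):
        self.graph = graph

    def asarray(self, values):
        # Concrete data (e.g. trained coefficients) is embedded as a constant
        data = [float(v) for v in values]  # this sketch only handles 1-D data
        out = self.graph.add_node("Constant", [], value=data)
        return ToyArray(self.graph, out, (len(data),))


def predict(X, coefficients, intercept):
    # Same body as LinearRegression.predict, written as a free function
    xp = X.__array_namespace__()
    return X @ xp.asarray(coefficients) + intercept


graph = Graph()
X = ToyArray(graph, "X", ("N", 3))  # graph input with a symbolic dimension
y = predict(X, [1 / 3, 1 / 3, 1 / 3], 1.0)
for node in graph.nodes:
    print(node)
print("output:", y.name, y.shape)
```

Because predict only touches its input through operators and the namespace, the same function body produces a record of nodes here and concrete numbers when given a NumPy array.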

Using Netron, we can visualize linear_regression.onnx.

[Figure: Netron visualization of linear_regression.onnx]

We successfully exported the LinearRegression.predict method to ONNX without a separate implementation! This is a significant improvement over the traditional approach of writing a second ONNX implementation, and we expect it to be a much more maintainable solution in the long run. We will now explore some other features of ndonnx and see how it compares with existing ONNX conversion tooling.

Eager evaluation

When ndonnx arrays are instantiated with data, operations on them are evaluated eagerly, much like in any regular array library. This is especially valuable for quick prototyping and debugging, since you can inspect intermediate values with standard Python debugging tools. It also gives ndonnx constant folding for free when building ONNX graphs for export: any subexpression whose inputs are all known is computed up front rather than emitted as graph nodes.

import numpy as np
import jax.numpy as jnp
import ndonnx as ndx

pred_np = model.predict(np.asarray([[1, 2, 3]]))
pred_jax = model.predict(jnp.asarray([[1, 2, 3]]))
pred_onnx = model.predict(ndx.asarray([[1, 2, 3]]))

print(pred_np, pred_jax, pred_onnx.to_numpy())
# [3.] [3.] [3.]

Notice that since LinearRegression.predict is Array API compliant, we are able to run it using NumPy, JAX or ndonnx arrays without the need for separate implementations. Eager evaluation in ndonnx works by dispatching each incremental operation to an ONNX backend like onnxruntime.
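The sketch below illustrates how eager evaluation and graph building can share one code path, and why constant folding comes for free: when every operand of an operation is known, it is computed immediately and nothing is recorded. The Value class is hypothetical, and NumPy stands in for the ONNX backend (ndonnx dispatches to a runtime such as onnxruntime instead).

```python
import numpy as np


class Value:
    """Hypothetical value: concrete if `data` is set, symbolic if it is None."""

    def __init__(self, graph, data=None, name=None):
        self.graph = graph  # shared list of recorded (op_type, output, inputs)
        self.data = None if data is None else np.asarray(data, dtype=np.float64)
        self.name = name

    def _ref(self):
        # Concrete data is embedded as a Constant only once a node needs it
        if self.name is None:
            self.name = f"c{len(self.graph)}"
            self.graph.append(("Constant", self.name, self.data))
        return self.name

    def _binary(self, other, op_type, fn):
        if not isinstance(other, Value):
            other = Value(self.graph, other)
        if self.data is not None and other.data is not None:
            # All inputs known: evaluate eagerly (NumPy stands in for the
            # ONNX runtime) and record nothing, i.e. the result is folded
            return Value(self.graph, fn(self.data, other.data))
        inputs = (self._ref(), other._ref())
        out = f"t{len(self.graph)}"
        self.graph.append((op_type, out, inputs))
        return Value(self.graph, name=out)

    def __add__(self, other):
        return self._binary(other, "Add", np.add)

    def __mul__(self, other):
        return self._binary(other, "Mul", np.multiply)


graph = []
x = Value(graph, name="X")         # symbolic graph input
scale = Value(graph, [2.0]) * 3.0  # both operands known: folded to [6.]
y = x * scale + 1.0                # depends on x: recorded as graph nodes
print(scale.data)
print([op for op, *_ in graph])
```

Only the operations that depend on the symbolic input end up in the recorded graph; the product 2.0 * 3.0 appears in it as a single precomputed constant.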

Interoperability with Spox

While the high-level API provided by ndonnx is usually the right choice, it is sometimes necessary to drop down to a lower level of abstraction. Controlling precisely which operators are emitted can be useful for performance optimization, and non-standard ONNX operators can be valuable in niche domains.

ndonnx makes it easy to drop down to the ONNX operator-level primitives provided by Spox for this purpose, and just as easy to move back to ndonnx[4].

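import ndonnx as ndx
import spox.opset.ai.onnx.v20 as op

x = ndx.array(shape=(2, 3), dtype=ndx.float64)

# Freely use Spox: take the underlying Var and apply the LogSoftmax operator
var = op.log_softmax(x.spox_var())

# Return to ndonnx by wrapping the resulting Var
y = ndx.from_spox_var(var)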

The example above demonstrates how to compute the log of the softmax of an ndonnx array using the specialized LogSoftmax ONNX operator, even though ndonnx itself does not expose a dedicated function for it. This lets developers build libraries that take full advantage of the ONNX operator set.
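For reference, here is a NumPy version of what LogSoftmax computes, which can be handy for checking results in tests. It uses the identity log(softmax(x)) = x - m - log(sum(exp(x - m))) with m = max(x), avoiding the overflow and loss of precision that composing a separate exp, sum and log naively can incur; that numerical care is one reason a dedicated operator can be preferable. The axis=-1 default matches LogSoftmax from ONNX opset 13 onward; the function itself is our own illustration.

```python
import numpy as np


def log_softmax(x, axis=-1):
    """NumPy reference for ONNX LogSoftmax along `axis` (default -1)."""
    # Subtract the max so the largest exponent is exp(0) = 1: no overflow
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


# Same (2, 3) shape as the ndonnx example; the second row would overflow
# a naive exp(x) in float64
x = np.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
print(log_softmax(x))
```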

Advanced features

ndonnx has numerous additional features that we use extensively at QuantCo, such as strings, nullability[5], user-defined data types and custom operator integration. This post only covered a small core set of Array API features; the documentation has a complete overview.
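Nullable types are one case where a single ndonnx Array can correspond to more than one underlying Spox Var. The sketch below shows one common way to model a nullable array, as a data array paired with a boolean mask; the NullableArray class and its fill_null method are hypothetical and are not ndonnx's API, which is described in its documentation.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class NullableArray:
    """Hypothetical nullable array: data values plus a boolean null mask.

    One logical array is backed by two underlying arrays.
    """

    values: np.ndarray
    null: np.ndarray  # True where the element is missing

    def __add__(self, other):
        # An element of the result is missing if it is missing in either input
        return NullableArray(self.values + other.values, self.null | other.null)

    def fill_null(self, fill_value):
        # Materialize a plain array, replacing missing entries
        return np.where(self.null, fill_value, self.values)


a = NullableArray(np.array([1, 2, 3]), np.array([False, True, False]))
b = NullableArray(np.array([10, 20, 30]), np.array([False, False, True]))
print((a + b).fill_null(0))
```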

Alternatives for constructing ONNX graphs

There are several alternatives for constructing ONNX graphs and we'll touch on how ndonnx might differ.

Spox and the onnx.helper module stay close to the ONNX standard. They do not provide the conveniences you would expect from a NumPy-like array library, such as type promotion, operator overloading and common array manipulation functions. We expect ndonnx to be more approachable for the majority of developers, who are familiar with NumPy but not with the details of the ONNX standard. Its higher-level API also makes ndonnx a more useful foundation for building the higher-level primitives found in machine learning pipelines, such as dataframes. As mentioned earlier, Spox remains vital where fine-grained control over the ONNX graph is required, or for models that are not made up of tensor operations.

Another library in this space is ONNX Script, which approaches the ONNX conversion challenge partly by translating Python language constructs to ONNX. Like ndonnx, ONNX Script is a fairly high-level library that aims to make ONNX conversion more approachable. The two differ in two ways. First, ONNX Script does not aim to reduce code duplication between libraries and their ONNX implementations. Second, ndonnx very intentionally does not translate Python language constructs to ONNX, for design reasons similar to those behind Spox (we touch on this in our Spox post).

Afterword

In this article we presented ndonnx, an ONNX-backed Python array library. We showed how ndonnx brings the Array API to ONNX and discussed how this enables machine learning library developers to more productively add ONNX support to their libraries. We encourage you to try it out and ask any questions over on GitHub.

Footnotes

  1. Converter libraries target a specific machine learning library and provide semantically-equivalent implementations of functionality from the source library in ONNX. As an example, the sklearn-onnx library provides ONNX conversion for scikit-learn.

  2. We recommend NumPy Enhancement Proposals 47 and 56 for background on the motivation and scope of the standard.

  3. NumPy added it to its main namespace in NumPy 2, and JAX added it to the jax.numpy namespace in the 0.4.32 release.

  4. Note that ndonnx supports nullable types and user-defined types, for which a single Array may correspond to more than one Spox Var. See the "Moving between Spox and ndonnx" section of the documentation for more detail.

  5. ndonnx is a superset of the Array API standard and includes additional data types like nullable variants of standard types as well as a utf8 data type for string data.