Optimize pickling disk space for deploying scikit-learn trees to production

Update 04/2023

We have created a package, slim-trees, that implements the ideas from this article and presented it at PyData Berlin 2023. Feel free to check it out on GitHub.


At QuantCo, we use ML models such as random forests to create value from data for our customers. To ship these models to production, we have to include them in the application, e.g., bake them into a Docker container or fetch them from a model registry. To save these models, we pickle them.

If we want to improve the model, we sometimes increase the number of parameters. This comes at the cost of disk space because the model is getting more complex. In our case, we were not able to deploy a larger model because of a file size constraint. Nevertheless, in general it makes sense to strive for smaller Docker images to achieve quicker transfer and deployment.

One way to solve this problem would be to use a different format altogether, like ONNX. Instead, we challenged the scikit-learn pickle format to see if there are any gains to be had. We found a 4x improvement over the native scikit-learn format.

Here are the numbers:

| dtype compression | compression algorithm | size   | dump time | load time |
| ----------------- | ---------------------- | ------ | --------- | --------- |
| no                | no                     | 496 MB | 1.4 s     | 0.7 s     |
| yes               | no                     | 128 MB | 1.6 s     | 0.9 s     |
| no                | lzma                   | 144 MB | 217 s     | 8.4 s     |
| yes               | lzma                   | 32 MB  | 53 s      | 2.8 s     |

In this article, we present how we did it.

How a Tree is saved

Usually, when pickling an object, the method __reduce__ gets called to serialize it; on unpickling, __setstate__ restores its state. The following code illustrates some of what happens when calling pickle.dump(obj).

def pickle_obj(obj):
    cls, init_args, state = obj.__reduce__()
    return unpickle_obj, (cls, init_args, state)


def unpickle_obj(cls, init_args, state):
    obj = cls(*init_args)
    obj.__setstate__(state)
    return obj


myobj = ...
unpickle_func, myobj_data = pickle_obj(myobj)
save_to_file("myobj.pkl", unpickle_func, myobj_data)
...
unpickle_func, myobj_data = load_from_file("myobj.pkl")
myobj_unpickled = unpickle_func(*myobj_data)

Let's look at the structure of state when we call tree.__reduce__().

{
    'max_depth': 19,
    'node_count': 469,
    'nodes': np.array(..., 
        shape=(469,),
        dtype=np.dtype([
            ('left_child', '<i8'), 
            ('right_child', '<i8'), 
            ('feature', '<i8'), 
            ('threshold', '<f8'), 
            ('impurity', '<f8'), 
            ('n_node_samples', '<i8'), 
            ('weighted_n_node_samples', '<f8')])),
    'values': np.array(..., 
        shape=(469, 1, 1), dtype=float64)
}

Since max_depth and node_count are single integers, they can be neglected when minimizing disk space. nodes is a structured NumPy array: each element is a record that contains, per node, the indices of the children, the feature used for the split, the threshold, and so on. This has a lot of potential for savings. values is kept as a separate array because the output can be multidimensional, which doesn't fit into the structured dtype of nodes.
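If you want to reproduce this state yourself, a minimal sketch (using a toy model, not our production one) looks like this:

from sklearn.tree import DecisionTreeRegressor
import numpy as np

# toy model, just to get hold of a fitted Tree object
X, y = np.random.rand(1000, 5), np.random.rand(1000)
model = DecisionTreeRegressor().fit(X, y)

cls, init_args, state = model.tree_.__reduce__()
print(state.keys())          # dict_keys(['max_depth', 'node_count', 'nodes', 'values'])
print(state["nodes"].dtype)  # the structured dtype shown above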

Returning to the state of our Tree, we can see that all elements of nodes are either of type <i8 or <f8, i.e., int64 or float64. This takes up a lot of disk space. To mitigate this, we can change how these parameters are stored to make the serialization more efficient.

How can we convert all these float64 and int64 numbers to more efficient datatypes? We can customize the behavior of pickling when it encounters a Tree.

Custom picklers

We will create a custom pickler whose behavior differs when it encounters a Tree object. That way, every instance of a Tree in our model gets pickled in a custom manner while the rest of the pickling behavior stays the same.

To achieve this, we need to customize the dispatch table of said pickler.

import copyreg
import pickle
from typing import Any, BinaryIO
from sklearn.tree._tree import Tree

def dump_dtype_reduction(model: Any, file: BinaryIO):
    p = pickle.Pickler(file)
    p.dispatch_table = copyreg.dispatch_table.copy()
    p.dispatch_table[Tree] = compressed_tree_pickle
    p.dump(model)

The function compressed_tree_pickle replaces pickle_obj from above.

Since we want to reduce the size of a Tree's state, we do something like this:

def compressed_tree_pickle(tree):
    assert isinstance(tree, Tree)
    cls, init_args, state = tree.__reduce__()
    compressed_state = compress_tree_state(state)
    return compressed_tree_unpickle, (cls, init_args, compressed_state)


def compressed_tree_unpickle(cls, init_args, state):
    tree = cls(*init_args)
    decompressed_state = decompress_tree_state(state)
    tree.__setstate__(decompressed_state)
    return tree

Now, all that's left is to implement compress_tree_state and decompress_tree_state, which convert the state to a more efficient format and back.

Compressing a Tree state

We can see that left_child and right_child are stored as <i8, i.e., as int64 numbers. But since a tree doesn't have that many nodes, the possible values in the arrays holding each node's children are relatively small, so we can change the dtype of these arrays to int16, for example. The same can be done with the feature array, which stores the index of the feature used at each node. The values array contains the actual output values of the leaves as float64 numbers; with a small loss of precision, we can use float32 here.
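As a quick illustration of the effect (with made-up values, not the real model), downcasting alone already cuts the raw array sizes considerably:

import numpy as np

children_left = np.array([1, 3, -1, 5, -1, -1, -1], dtype=np.int64)
print(children_left.nbytes)                   # 56 bytes
print(children_left.astype(np.int16).nbytes)  # 14 bytes

values = np.random.rand(469, 1, 1)            # float64, like the values array above
print(values.nbytes)                          # 3752 bytes
print(values.astype(np.float32).nbytes)       # 1876 bytes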

There are now four remaining attributes of state: threshold, impurity, n_node_samples, and weighted_n_node_samples. The last three are purely analytical and not used during inference, so we can simply leave them out in our custom pickling solution.

Now let's look at threshold, i.e., the splitting points that decide at each node whether to go to the left or the right child. We could simply do the same as with the values array and call it a day. But unlike with the values array, we have to keep in mind that errors can propagate: if a "wrong" turn is taken near the root, we might end up in a completely different leaf than we should.

There are three ways to continue:

  1. accept the extra loss and continue;
  2. use float64 again in the threshold array and don't gain disk space in this array;
  3. examine the structure of the threshold array and try to find a way to do lossless compression.

Let's try the third option.

Compressing half ints

The following only works if you have many categorical columns in your dataset. If you don't, you may want to go with option 1 or 2, or look into other lossless compression methods.

When we looked at the structure of the threshold array, we realized that a lot of values are of the form x.5, e.g., 3.5, 0.5, 9.5; let's call them half ints. This is because our dataset contains many categorical features that are encoded as integers and then cast to float64 by scikit-learn for fitting. Since every data point lies exactly on an integer, the optimal splitting point lies exactly halfway between two integers, and these are the splitting points the scikit-learn trees learn. Now what can we do with this information?
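A toy example (with random integer features, not our production data) shows the effect: the learned thresholds all lie exactly between two integers.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.random.randint(0, 10, size=(1000, 3)).astype(np.float64)
y = 2 * X[:, 0] + X[:, 1]
model = DecisionTreeRegressor(max_depth=3).fit(X, y)

thresholds = model.tree_.threshold
print(thresholds[thresholds != -2])  # e.g. [4.5 2.5 6.5 ...] -- leaves are marked with -2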

We can compress these half ints by multiplying them by 2 and casting them to int8.

Because of floating point precision, we have to catch not only the exact half ints but also values within a small epsilon neighborhood of them.

import numpy as np

# values that are "almost" integers due to floating point noise
a2 = np.array([1.9999999999999, 2.0000000000001])
np.minimum(np.abs(a2 % 1 - 1), a2 % 1)
# array([9.99200722e-14, 9.99200722e-14])

def compress_half_int_float_array(a, compression_dtype=np.int8):
    info = np.iinfo(compression_dtype)
    a2 = 2. * a
    is_compressible = (np.minimum(np.abs(a2 % 1 - 1), a2 % 1) < 1e-12) & \
                      (a2 >= info.min) & (a2 <= info.max)
    not_compressible = np.logical_not(is_compressible)

    a2_compressible = a2[is_compressible].astype(compression_dtype)
    a_incompressible = a[not_compressible]

    state = {
        "is_compressible": is_compressible,
        "a2_compressible": a2_compressible,
        "a_incompressible": a_incompressible,
    }

    return state


def decompress_half_int_float_array(state):
    is_compressible = state["is_compressible"]
    a = np.zeros(len(is_compressible), dtype="float64")
    a[is_compressible] = state["a2_compressible"] / 2.
    a[np.logical_not(is_compressible)] = state["a_incompressible"]
    return a

This format is very efficient if your array contains many of these half int values: state["is_compressible"] is a bool array which takes up only one bit per entry, state["a2_compressible"] is an int8 array which takes up eight bits per entry, and state["a_incompressible"] takes up the full 64 bits per entry.
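A quick round-trip check (with illustrative values) shows that half ints survive the compression exactly while the remaining values keep their full float64 precision:

a = np.array([3.5, 0.123456789, 10.5, -2.5, 7.0])
state = compress_half_int_float_array(a)
restored = decompress_half_int_float_array(state)
assert np.array_equal(a, restored)  # lossless round trip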

In our production example, the threshold array contained 4000 elements, 65-70% of which were half ints. This leads to a theoretical compression ratio of

$$\frac{4000 \cdot 0.65 \cdot 8\;\mathrm{Bit} + 4000 \cdot 0.35 \cdot 64\;\mathrm{Bit} + 4000 \cdot 1\;\mathrm{Bit}}{4000 \cdot 64\;\mathrm{Bit}} = \frac{0.65 \cdot 8 + 0.35 \cdot 64 + 1}{64} \approx 0.45$$

compared to the plain float32 compression ratio of $32\;\mathrm{Bit} / 64\;\mathrm{Bit} = 0.5$. So we get a better theoretical compression ratio while staying lossless.

Further optimizations

All the arrays that we have talked about have the same length, namely the number of nodes in the tree. But a leaf doesn't actually need a left/right child, a feature, or a threshold; it only needs a value. That's why the left_child and right_child fields are -1 for leaves and feature and threshold are -2.

We can simply drop these elements and save some more space by creating a bool array is_leaf that is used as a mask for getting and setting the corresponding values.

def compress_tree_state(state: dict):
    ...
    is_leaf = nodes['left_child'] == -1
    is_not_leaf = np.logical_not(is_leaf)

    # feature, threshold and children are irrelevant for leaves
    children_left = nodes['left_child'][is_not_leaf].astype(dtype_child)
    children_right = nodes['right_child'][is_not_leaf].astype(dtype_child)
    features = nodes['feature'][is_not_leaf].astype(dtype_feature)
    # do lossless compression for thresholds by downcasting 
    # half ints (e.g. 5.5, 10.5, ...) to dtype_halfint
    thresholds = nodes['threshold'][is_not_leaf].astype(dtype_threshold)
    thresholds = compress_half_int_float_array(thresholds, dtype_halfint)

    # value is irrelevant when node is not a leaf
    values = state['values'][is_leaf].astype(dtype_value)
    ...


def decompress_tree_state(state: dict):
    ...
    n_edges = len(is_leaf)

    children_left = np.empty(n_edges, dtype=np.int64)
    children_right = np.empty(n_edges, dtype=np.int64)
    features = np.empty(n_edges, dtype=np.int64)
    thresholds = np.empty(n_edges, dtype=np.float64)

    # same shape as values but with all edges instead of only the leaves
    values = np.zeros((n_edges, *state['values'].shape[1:]), dtype=np.float64)

    children_left[is_not_leaf] = state['children_left']
    children_left[is_leaf] = -1  # child of leaves is -1
    children_right[is_not_leaf] = state['children_right']
    children_right[is_leaf] = -1  # child of leaves is -1
    features[is_not_leaf] = state['features']
    features[is_leaf] = -2  # feature of leaves is -2
    thresholds[is_not_leaf] = decompress_half_int_float_array(state['thresholds'])
    thresholds[is_leaf] = -2  # threshold of leaves is -2
    values[is_leaf] = state['values']
    ...

Reconstructing the state of the tree

To be able to reconstruct a Tree, we need the exact same state format as the original, i.e., a dict with max_depth, node_count, nodes, and values, all in the same format as before. As mentioned, nodes is a structured NumPy array, so we have to specify its dtype manually.

def decompress_tree_state(state: dict):
    ...
    dtype = np.dtype([('left_child', '<i8'), ('right_child', '<i8'), 
                      ('feature', '<i8'), ('threshold', '<f8'),
                      ('impurity', '<f8'), ('n_node_samples', '<i8'), 
                      ('weighted_n_node_samples', '<f8')])
    nodes = np.zeros(n_edges, dtype=dtype)
    nodes['left_child'] = children_left
    nodes['right_child'] = children_right
    nodes['feature'] = features
    nodes['threshold'] = thresholds

    return {'max_depth': state['max_depth'],
            'node_count': state['node_count'],
            'nodes': nodes,
            'values': values}

Compression errors

In our production example, we tested the predictions on a dataset with 3,500,000 entries. The average compression error was 5.59e-8; 8 out of the 3,500,000 predictions had an error > 1e-4, with the largest error being 2.87e-4. Summing all prediction errors yields a cumulative absolute error of 0.19. For our use case, these errors are completely acceptable.
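A sketch of how such an error analysis can be done (model, model_dtype_reduction, and X_test are placeholders for the original model, the re-loaded compressed model, and a hold-out dataset):

import numpy as np

pred_original = model.predict(X_test)
pred_compressed = model_dtype_reduction.predict(X_test)

errors = np.abs(pred_original - pred_compressed)
print(errors.mean())          # average compression error
print((errors > 1e-4).sum())  # number of predictions with error > 1e-4
print(errors.max())           # largest single error
print(errors.sum())           # cumulative absolute error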

Discussion

In our production example, we could shrink down the model size from 496 MB to 128 MB, i.e., compress the model to 26% of its original size.

One can also use a general-purpose compression algorithm like gzip or lzma to compress the pickle even further. We tried zstd, gzip, lzma, bz2, and zlib and opted for lzma because it is part of the Python standard library and yielded the best compression results. With lzma on top of the dtype reduction we got down to 32 MB; with lzma alone (no dtype reduction) we only got to 144 MB, so the dtype reduction still shrinks the model to 22% of its "lzma-only" size.
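If you want to compare compressors yourself, a minimal sketch (file names and the compressor list are illustrative) could look like this:

import bz2
import gzip
import lzma
import os

for name, opener in [("gzip", gzip.open), ("bz2", bz2.open), ("lzma", lzma.open)]:
    path = f"model.pickle.{name}"
    with opener(path, mode="wb") as f:
        dump_dtype_reduction(model, f)
    print(name, os.path.getsize(path), "bytes")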

These extra compression algorithms do add dump and load time, though. You can find the numbers at the beginning of this blog post.

However, this only reduces the size on disk. When you actually load the model, these steps don't reduce its size in memory. It still has its use cases, for example when shipping a Docker image that should be as small as possible.

import lzma
import pickle

# dump compressed
with lzma.open("/path/to/model.pickle.lzma", mode="wb") as f:
    dump_dtype_reduction(model, f)

# load compressed
with lzma.open("/path/to/model.pickle.lzma", mode="rb") as f:
    model_dtype_reduction = pickle.load(f)

Pickling has several disadvantages, though, such as arbitrary code execution and the lack of version compatibility. Uwe Korn wrote a blog post about the implications of pickling ML models, and Ned Batchelder wrote one about pickle's nine flaws.

Another alternative is compiling the model to ONNX. This could also shrink the model, improve inference speed, and simplify deployment. ONNX brings its own challenges, though: you need to create workflows to convert your fitted model to ONNX, which is not trivial for more complicated models. You also have to make sure that the converted ONNX model produces the same predictions as your original model.

The code for this blog post can be found on GitHub Gist. The whole procedure was tested with scikit-learn==0.23.2; the internal structure of Tree may, however, change in upcoming versions.