# Optimize pickling disk space for deploying scikit-learn trees to production

*Update 04/2023*

We have created a package that implements the ideas presented in this article and presented it at PyData Berlin 2023. Feel free to check out `slim-trees` on GitHub.

At QuantCo, we use ML models such as random forests to create value from data for our customers. To ship these models to production, we have to include them in the application, e.g., bake them into a Docker container or fetch them from a model registry. To save these models, we pickle them.

If we want to improve the model, we sometimes increase the number of parameters. This comes at the cost of disk space because the model gets more complex. In our case, we were not able to deploy a larger model because of a file size constraint. Independent of such hard limits, it generally makes sense to strive for smaller Docker images to achieve quicker transfer and deployment.

One way to solve this problem would be to use a different format altogether, like ONNX. Instead, we challenged the scikit-learn pickle format to see if there are any gains to be had. We found a 4x improvement over the native scikit-learn format.

Here are the numbers:

| dtype compression | compression algorithm | size | dump time | load time |
|---|---|---|---|---|
| no | no | 496 MB | 1.4 s | 0.7 s |
| yes | no | 128 MB | 1.6 s | 0.9 s |
| no | lzma | 144 MB | 217 s | 8.4 s |
| yes | lzma | 32 MB | 53 s | 2.8 s |

In this article, we present how we did it.

## How a `Tree` is saved

Usually, when pickling an object, its `__reduce__` method is called to serialize it, and `__setstate__` is called on the freshly created object when it is loaded again. The following code illustrates some of what is happening when calling `pickle.dump(obj, file)` and later `pickle.load(file)`.

```
def pickle_obj(obj):
    # __reduce__ describes how the object can be rebuilt
    cls, init_args, state = obj.__reduce__()
    return unpickle_obj, (cls, init_args, state)


def unpickle_obj(cls, init_args, state):
    obj = cls(*init_args)
    obj.__setstate__(state)
    return obj


# save_to_file and load_from_file are placeholders for the
# byte-level (de)serialization that pickle does internally
myobj = ...
unpickle_func, myobj_data = pickle_obj(myobj)
save_to_file("myobj.pkl", unpickle_func, myobj_data)
...
unpickle_func, myobj_data = load_from_file("myobj.pkl")
myobj_unpickled = unpickle_func(*myobj_data)
```
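To see the same protocol with the real `pickle` module, here is a minimal sketch. The `Point` class is a hypothetical toy example (it is not part of the original code); its `__reduce__` returns the same `(cls, init_args, state)` triple as in the illustration above.

```python
import pickle


class Point:
    """Hypothetical toy class used only to illustrate the pickle protocol."""

    def __init__(self, name):
        self.name = name
        self.coords = None

    def __reduce__(self):
        # (callable, constructor arguments, state handed to __setstate__)
        return Point, (self.name,), {"coords": self.coords}

    def __setstate__(self, state):
        self.coords = state["coords"]


p = Point("origin")
p.coords = (0.0, 0.0)

# pickle calls __reduce__ when dumping and __setstate__ when loading
restored = pickle.loads(pickle.dumps(p))
print(restored.name, restored.coords)
```

The constructor argument and the state travel separately, which is what allows us to swap out only the state later on.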

Let's look at the structure of `state` when we call `tree.__reduce__()`.

```
{
    'max_depth': 19,
    'node_count': 469,
    'nodes': np.array(...,
        shape=(469,),
        dtype=np.dtype([
            ('left_child', '<i8'),
            ('right_child', '<i8'),
            ('feature', '<i8'),
            ('threshold', '<f8'),
            ('impurity', '<f8'),
            ('n_node_samples', '<i8'),
            ('weighted_n_node_samples', '<f8')])),
    'values': np.array(...,
        shape=(469, 1, 1), dtype=np.float64)
}
```

Since `max_depth` and `node_count` are only single integers, they can be neglected when trying to minimize the disk space. `nodes` is a NumPy structured array: each element is a record containing the indices of the children, the feature used for the split, etc. for one node. This has a lot of potential for savings. `values` is a separate array because the output can be multidimensional, which doesn't fit into the record structure of `nodes`.

Returning to the `state` of our `Tree`, we can see that all the fields of `nodes` are either of type `<i8` or `<f8`, i.e., `int64` or `float64`. This takes up a lot of disk space. To mitigate this problem, we could simply change how these parameters of the model are stored to make the serialization more efficient.
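As a rough back-of-the-envelope check, the following sketch computes the storage cost per node for the layout above. It only rebuilds the dtype with NumPy; the node count of 469 and the `(469, 1, 1)` shape of `values` are taken from the example state.

```python
import numpy as np

node_dtype = np.dtype([
    ("left_child", "<i8"),
    ("right_child", "<i8"),
    ("feature", "<i8"),
    ("threshold", "<f8"),
    ("impurity", "<f8"),
    ("n_node_samples", "<i8"),
    ("weighted_n_node_samples", "<f8"),
])

node_count = 469
bytes_per_node = node_dtype.itemsize             # seven 8-byte fields
bytes_per_value = np.dtype(np.float64).itemsize  # one float64 output per node

total_bytes = node_count * (bytes_per_node + bytes_per_value)
print(bytes_per_node, total_bytes)
```

Every node costs 64 bytes in this layout, regardless of how small the stored numbers actually are.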

How can we implement changing all `float64` and `int64` numbers to more efficient data types? To do this, we can customize the behaviour of scikit-learn when pickling a `Tree`.

## Custom picklers

We will be creating a custom pickler whose behavior varies when pickling a `Tree` object. Thus, every instance of a `Tree` in our model structure will get pickled in a custom manner while the rest of the pickling behavior stays the same.

To achieve this, we need to customize the dispatch table of said pickler.

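```
import copyreg
import pickle
from typing import Any, BinaryIO

from sklearn.tree._tree import Tree


def dump_dtype_reduction(model: Any, file: BinaryIO):
    p = pickle.Pickler(file)
    # start from the global dispatch table and override only Tree
    p.dispatch_table = copyreg.dispatch_table.copy()
    p.dispatch_table[Tree] = compressed_tree_pickle
    p.dump(model)
```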
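The dispatch table mechanism can be tried out without scikit-learn. The sketch below uses a hypothetical `Payload` class as a stand-in for `Tree`; `reduce_payload`, `rebuild_payload` and `dump_custom` are illustrative names, not part of the original code. Only `Payload` objects are pickled differently, everything around them uses the default behavior.

```python
import copyreg
import io
import pickle


class Payload:
    """Hypothetical stand-in for Tree: holds a list of small integers."""

    def __init__(self, numbers):
        self.numbers = numbers


def rebuild_payload(packed: bytes):
    # inverse of reduce_payload: expand the bytes back into a list
    return Payload(list(packed))


def reduce_payload(obj: Payload):
    # store the integers (0..255) as compact bytes instead of a Python list
    return rebuild_payload, (bytes(obj.numbers),)


def dump_custom(obj, file):
    p = pickle.Pickler(file)
    p.dispatch_table = copyreg.dispatch_table.copy()
    p.dispatch_table[Payload] = reduce_payload
    p.dump(obj)


buffer = io.BytesIO()
dump_custom({"model": Payload([1, 2, 3])}, buffer)

# loading needs no custom unpickler, only an importable rebuild_payload
restored = pickle.loads(buffer.getvalue())
print(restored["model"].numbers)
```

Note that the plain `pickle.loads` is enough for loading: the pickle stream references the rebuild function by name, so that function has to be importable wherever the model is loaded.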

The function `compressed_tree_pickle` takes the role of `pickle_obj` from above.

Since we want to reduce the size of a `Tree`'s `state`, we want to do something like this.

```
def compressed_tree_pickle(tree):
    assert isinstance(tree, Tree)
    cls, init_args, state = tree.__reduce__()
    compressed_state = compress_tree_state(state)
    return compressed_tree_unpickle, (cls, init_args, compressed_state)


def compressed_tree_unpickle(cls, init_args, state):
    tree = cls(*init_args)
    decompressed_state = decompress_tree_state(state)
    tree.__setstate__(decompressed_state)
    return tree
```

Now, all that's left is to implement `compress_tree_state` and `decompress_tree_state`, which convert the state to a more efficient format and back.

## Compressing a `Tree` state

We can see that `left_child` and `right_child` are stored as `<i8`, i.e., `int64` numbers. But since there are not *that* many nodes in a tree, the child indices stored in these fields are relatively small. Thus, we can change the `dtype` to `int16`, for example, as long as the tree has at most 32 768 nodes. The same can be done with the `feature` field, which stores the index of the feature used for the split in each node. The `values` array contains the actual output values of the leaves as `float64` numbers. With a bit of loss of precision, we could use `float32` here.
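Downcasting integers silently wraps around when a value does not fit, so it is worth guarding the cast. The helper below is a minimal sketch; `safe_int_downcast` is a hypothetical name and not part of the original code.

```python
import numpy as np


def safe_int_downcast(arr, dtype):
    """Cast an integer array to a smaller dtype, refusing to overflow."""
    info = np.iinfo(dtype)
    if arr.size and (arr.min() < info.min or arr.max() > info.max):
        raise OverflowError(f"values do not fit into {np.dtype(dtype).name}")
    return arr.astype(dtype)


# child indices of a small tree; -1 marks a leaf
children = np.array([1, 2, -1, -1, 468], dtype=np.int64)
small = safe_int_downcast(children, np.int16)
print(small.dtype, small.nbytes, children.nbytes)
```

For a tree with more nodes than `int16` can address, the same helper can be called with `np.int32` instead.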

There are now four remaining fields of `nodes`: `threshold`, `impurity`, `n_node_samples`, and `weighted_n_node_samples`. The last three are all purely analytical and not used during inference. Thus, we can just leave them out in our custom pickling solution.

Now let's look at `threshold`, i.e., the splitting point at each node that decides whether to go to the left or the right child. We could simply do the same as with the `values` array and call it a day. But unlike with the `values` array, we need to keep in mind that errors can propagate: if a "wrong" turn is made near the root, we could end up at a completely different leaf than we should.
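A small numeric sketch of such a wrong turn, assuming the usual rule that a sample goes to the left child when its feature value is `<=` the threshold. The feature value `x` is a made-up number chosen to sit between the `float64` threshold and its `float32` rounding.

```python
import numpy as np

threshold_64 = np.float64(0.1)
# round the threshold to float32 and convert it back for the comparison
threshold_32 = np.float64(np.float32(threshold_64))

x = np.float64(0.10000000001)  # hypothetical feature value just above 0.1

goes_left_original = bool(x <= threshold_64)
goes_left_float32 = bool(x <= threshold_32)
print(goes_left_original, goes_left_float32)
```

The two comparisons disagree, so this sample would be routed into a different subtree after the lossy cast.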

There are three ways to continue:

- accept the extra loss and continue;
- use `float64` again in the `threshold` array and don't gain disk space in this array;
- examine the structure of the `threshold` array and try to find a way to do lossless compression.

Let's try out the last option.

### Compressing half ints

The following only works if you have many categorical columns in your dataset. If you don't, you may want to go with one of the first two options or look into other methods of lossless compression.

When we looked at the structure of the `threshold` array, we realized that there are a lot of values that are of the form `x.5`, i.e., `3.5`, `0.5`, `9.5`, etc. Let's call them half ints. This is because our dataset has a lot of categorical values which are converted to `int` and then cast to `float64` for fitting by scikit-learn. Since every point still lies on an integer, the optimal splitting point is in the middle of two integers. These optimal splitting points are exactly what the sklearn trees learn. Now what can we do with this information?

We could compress these half ints by multiplying them by `2` and casting them to `int8`.

Because of floating point precision, we need to make sure not only to catch the exact integers but an epsilon neighborhood of the integer values.

```
a2 = np.array([1.9999999999999, 2.0000000000001])
np.minimum(np.abs(a2 % 1 - 1), a2 % 1)
# array([9.99200722e-14, 9.99200722e-14])
```

```
import numpy as np


def compress_half_int_float_array(a, compression_dtype=np.int8):
    info = np.iinfo(compression_dtype)
    a2 = 2.0 * a
    # compressible: 2 * a is (almost) an integer that fits into compression_dtype
    is_compressible = (
        (np.minimum(np.abs(a2 % 1 - 1), a2 % 1) < 1e-12)
        & (a2 >= info.min)
        & (a2 <= info.max)
    )
    not_compressible = np.logical_not(is_compressible)
    # round before casting: astype alone would truncate 6.9999999999999 to 6
    a2_compressible = np.rint(a2[is_compressible]).astype(compression_dtype)
    a_incompressible = a[not_compressible]
    state = {
        "is_compressible": is_compressible,
        "a2_compressible": a2_compressible,
        "a_incompressible": a_incompressible,
    }
    return state


def decompress_half_int_float_array(state):
    is_compressible = state["is_compressible"]
    a = np.zeros(len(is_compressible), dtype="float64")
    a[is_compressible] = state["a2_compressible"] / 2.0
    a[np.logical_not(is_compressible)] = state["a_incompressible"]
    return a
```
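To make the idea concrete, here is a self-contained walk-through of the same steps on a tiny made-up array. It is a simplified sketch rather than the production code; the variable names are chosen for this illustration only.

```python
import numpy as np

# made-up thresholds: four compressible values and two that are not
thresholds = np.array([0.5, 3.5, 9.5, 0.1234, 2.0, 200.5])

doubled = 2.0 * thresholds
frac = doubled % 1
near_integer = np.minimum(frac, 1 - frac) < 1e-12
info = np.iinfo(np.int8)
fits = (doubled >= info.min) & (doubled <= info.max)
mask = near_integer & fits

packed = np.rint(doubled[mask]).astype(np.int8)  # one byte per entry
rest = thresholds[~mask]                         # stays float64

# decompression: put both parts back into their original positions
restored = np.empty(len(thresholds), dtype=np.float64)
restored[mask] = packed / 2.0
restored[~mask] = rest
print(mask.tolist(), packed.tolist(), np.array_equal(restored, thresholds))
```

`0.1234` is not a half int and `200.5` doubles to `401`, which does not fit into `int8`, so both stay in the `float64` part. Whole numbers such as `2.0` are compressed as well.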

This format is very efficient if you have a lot of these half int values in your array. `state["is_compressible"]` is a `bool` array that carries only one bit of information per entry (NumPy stores a `bool` as one byte unless the array is packed, e.g., with `np.packbits`), `state["a2_compressible"]` is an `int8` array that takes up only eight bits per entry, and `state["a_incompressible"]` takes up the full 64 bits per entry.
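If the mask should really occupy one bit per entry on disk, it can be packed explicitly. This is a sketch of one possible approach with `np.packbits` and `np.unpackbits`; it is not part of the original code, and a general-purpose compressor applied afterwards may make this step unnecessary.

```python
import numpy as np

mask = np.array([True, False, True, True, False, False, True, False, True, True])

packed = np.packbits(mask)  # uint8 array, eight flags per byte
# count drops the padding bits that fill up the last byte
unpacked = np.unpackbits(packed, count=len(mask)).astype(bool)

print(mask.nbytes, packed.nbytes, np.array_equal(mask, unpacked))
```

The original length of the mask has to be stored alongside the packed bytes, because the last byte is padded with zeros.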

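In our production example, the `threshold` array contained 4000 elements with 65-70% of them being half integers. With a share of 70% half integers, this leads to a theoretical compression of

$$\frac{1\;\mathrm{Bit} + 0.7 \cdot 8\;\mathrm{Bit} + 0.3 \cdot 64\;\mathrm{Bit}}{64\;\mathrm{Bit}} \approx 0.40$$

(about $0.45$ for a share of 65%) compared to the `float32` compression of $32\;\mathrm{Bit} / 64\;\mathrm{Bit} = 0.5$. So we get better theoretical compression performance while staying lossless.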
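The theoretical ratio follows directly from the per-entry costs described above: one bit for the mask, eight bits for each half int and 64 bits for every other value. The helper below is a small sketch of that arithmetic; `theoretical_ratio` is a name made up for this illustration.

```python
def theoretical_ratio(share_half_ints, mask_bits=1, small_bits=8, full_bits=64):
    """Average stored bits per entry divided by the original 64 bits."""
    bits = (
        mask_bits
        + share_half_ints * small_bits
        + (1 - share_half_ints) * full_bits
    )
    return bits / full_bits


for share in (0.65, 0.70):
    print(share, round(theoretical_ratio(share), 3))
```

Passing `mask_bits=8` models a mask that is stored as an unpacked NumPy `bool` array.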

### Further optimizations

All the arrays that we have talked about have the same length, namely the number of nodes in the tree. But when thinking about it, the leaves don't really need a left/right child, a corresponding feature or a threshold. They only need the value. That's why the values in the `children_left` and `children_right` fields are `-1` there and the values of `features` and `thresholds` are `-2`.

We can just delete these elements and save some space again by creating a `bool` array `is_leaf` that is used as a mask for getting and setting the corresponding values.

```
def compress_tree_state(state: dict):
    ...
    is_leaf = nodes['left_child'] == -1
    is_not_leaf = np.logical_not(is_leaf)
    # feature, threshold and children are irrelevant for leaves
    children_left = nodes['left_child'][is_not_leaf].astype(dtype_child)
    children_right = nodes['right_child'][is_not_leaf].astype(dtype_child)
    features = nodes['feature'][is_not_leaf].astype(dtype_feature)
    # do lossless compression for thresholds by downcasting
    # half ints (e.g. 5.5, 10.5, ...) to dtype_halfint
    thresholds = nodes['threshold'][is_not_leaf].astype(dtype_threshold)
    thresholds = compress_half_int_float_array(thresholds, dtype_halfint)
    # value is irrelevant when node is not a leaf
    values = state['values'][is_leaf].astype(dtype_value)
    ...


def decompress_tree_state(state: dict):
    ...
    # n_edges is the total number of nodes (inner nodes and leaves)
    n_edges = len(is_leaf)
    children_left = np.empty(n_edges, dtype=np.int64)
    children_right = np.empty(n_edges, dtype=np.int64)
    features = np.empty(n_edges, dtype=np.int64)
    thresholds = np.empty(n_edges, dtype=np.float64)
    # same shape as values but with all nodes instead of only the leaves
    values = np.zeros((n_edges, *state['values'].shape[1:]), dtype=np.float64)
    children_left[is_not_leaf] = state['children_left']
    children_left[is_leaf] = -1  # child of leaves is -1
    children_right[is_not_leaf] = state['children_right']
    children_right[is_leaf] = -1  # child of leaves is -1
    features[is_not_leaf] = state['features']
    features[is_leaf] = -2  # feature of leaves is -2
    thresholds[is_not_leaf] = decompress_half_int_float_array(state['thresholds'])
    thresholds[is_leaf] = -2  # threshold of leaves is -2
    values[is_leaf] = state['values']
    ...
```
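The masking trick can be tried on a toy tree with one root and two leaves. The arrays below are made up for this illustration; the sentinel values `-1` and `-2` are the ones described above.

```python
import numpy as np

# toy tree: node 0 is the root, nodes 1 and 2 are leaves
left_child = np.array([1, -1, -1], dtype=np.int64)
feature = np.array([4, -2, -2], dtype=np.int64)

is_leaf = left_child == -1
is_not_leaf = ~is_leaf

# compression: keep only the entries of inner nodes
left_small = left_child[is_not_leaf].astype(np.int16)
feature_small = feature[is_not_leaf].astype(np.int16)

# decompression: fill the leaf positions with the sentinel values
left_restored = np.full(len(is_leaf), -1, dtype=np.int64)
left_restored[is_not_leaf] = left_small
feature_restored = np.full(len(is_leaf), -2, dtype=np.int64)
feature_restored[is_not_leaf] = feature_small

print(left_small.tolist(), np.array_equal(left_restored, left_child))
```

A binary tree with `n` leaves has `n - 1` inner nodes, so dropping the leaf entries roughly halves these arrays.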

## Reconstructing the state of the tree

To be able to reconstruct a `Tree`, we need the exact same state format as in the original, i.e., a dict with `max_depth`, `node_count`, `nodes` and `values`, all in the same format as before. `nodes` is actually a NumPy array with a structured data type, as mentioned before, so we need to specify its structure manually.

```
def decompress_tree_state(state: dict):
    ...
    dtype = np.dtype([('left_child', '<i8'), ('right_child', '<i8'),
                      ('feature', '<i8'), ('threshold', '<f8'),
                      ('impurity', '<f8'), ('n_node_samples', '<i8'),
                      ('weighted_n_node_samples', '<f8')])
    nodes = np.zeros(n_edges, dtype=dtype)
    nodes['left_child'] = children_left
    nodes['right_child'] = children_right
    nodes['feature'] = features
    nodes['threshold'] = thresholds
    return {'max_depth': state['max_depth'],
            'node_count': state['node_count'],
            'nodes': nodes,
            'values': values}
```

## Compression errors

In our production example, we tested the predictions on a dataset with 3 500 000 entries. The average compression error was `5.59e-8`; there were 8 / 3 500 000 elements with a prediction error `> 1e-4`, with the largest error being `2.87e-4`. When we sum up all prediction errors, we get an absolute error of `0.19`. For our use case, these errors are completely fine.
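Metrics like these can be computed with a few lines of NumPy. The function below is a sketch of one way to do it; `prediction_error_summary` is a hypothetical helper, and the arrays are made-up numbers, not our production data.

```python
import numpy as np


def prediction_error_summary(pred_original, pred_compressed, tol=1e-4):
    """Summarize absolute differences between two prediction arrays."""
    err = np.abs(np.asarray(pred_original) - np.asarray(pred_compressed))
    return {
        "mean": float(err.mean()),
        "max": float(err.max()),
        "sum": float(err.sum()),
        "n_above_tol": int((err > tol).sum()),
    }


# made-up numbers purely for illustration
original = np.array([1.0, 2.0, 3.0, 4.0])
compressed = np.array([1.0, 2.0, 3.0005, 4.0])
summary = prediction_error_summary(original, compressed)
print(summary)
```

In practice, `original` and `compressed` would be the outputs of `model.predict` for the original model and for the model that went through the dump and load cycle.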

## Discussion

In our production example, we could shrink down the model size from 496 MB to 128 MB, i.e., compress the model to 26% of its original size.

One can also use a compression algorithm like `gzip` or `lzma` to compress the model even further. We tried out `zstd`, `gzip`, `lzma`, `bz2` and `zlib`. We opted for `lzma` because it is part of the Python stdlib and it yielded the best compression results. We managed to go down to 32 MB. Without the extra `dtype` reduction steps and only `lzma` compression we got to 144 MB, so we still get the model down to 22% of its original "`lzma` size".

These extra compression algorithms do bring additional startup time with them, though. You can find the numbers at the beginning of this blog post.
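A comparison of this kind can be sketched with the compressors that ship with Python. The payload below is a stand-in (a pickled list of integers), so the sizes and timings it prints say nothing about our model; `zstd` is left out because it needs a third-party package.

```python
import bz2
import gzip
import lzma
import pickle
import time
import zlib

# stand-in payload; in practice this would be the pickled model bytes
payload = pickle.dumps(list(range(100_000)))

codecs = {
    "gzip": (gzip.compress, gzip.decompress),
    "zlib": (zlib.compress, zlib.decompress),
    "bz2": (bz2.compress, bz2.decompress),
    "lzma": (lzma.compress, lzma.decompress),
}

results = {}
for name, (compress, decompress) in codecs.items():
    start = time.perf_counter()
    blob = compress(payload)
    elapsed = time.perf_counter() - start
    assert decompress(blob) == payload  # round trip must be lossless
    results[name] = (len(blob), elapsed)
    print(f"{name}: {len(blob)} bytes, {elapsed:.3f} s to compress")
```

Decompression time can be measured the same way and matters more for startup, since the model is loaded far more often than it is dumped.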

This is only the size on disk, though. When you actually load the model, these steps don't help to reduce the size of the model in memory. It still has its use cases, for example when providing a Docker image that is as small as possible.

```
import lzma
import pickle

# dump compressed
with lzma.open("/path/to/model.pickle.lzma", mode="wb") as f:
    dump_dtype_reduction(model, f)

# load compressed
with lzma.open("/path/to/model.pickle.lzma", mode="rb") as f:
    model_dtype_reduction = pickle.load(f)
```

Pickling has several disadvantages, though, such as arbitrary code execution and no version compatibility. Uwe Korn wrote a blog post about the implications of pickling ML models and Ned Batchelder wrote one about pickle's nine flaws.

Another alternative could be using ONNX to compile the model. This could also shrink the model, improve the inference speed and simplify deployment. ONNX brings its own challenges though: you need to create workflows to transform your fitted model to ONNX, which is not trivial for more complicated models. Also, you would need to make sure that the converted ONNX model predicts the same things as your original model.

The code for this blog post can be found on GitHub Gist. The whole procedure was tested on `scikit-learn=0.23.2`; the internal structure of `Tree` may, however, change in upcoming versions.
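Because the approach relies on internals, it can help to fail loudly when the state no longer looks as expected. The check below is a sketch under the assumption that the state has exactly the keys and node fields shown in this post; `check_tree_state` is a hypothetical helper and the state at the bottom is fake data.

```python
import numpy as np

EXPECTED_STATE_KEYS = {"max_depth", "node_count", "nodes", "values"}
EXPECTED_NODE_FIELDS = (
    "left_child", "right_child", "feature", "threshold",
    "impurity", "n_node_samples", "weighted_n_node_samples",
)


def check_tree_state(state: dict) -> None:
    """Raise if the state does not have the layout described in this post."""
    if set(state) != EXPECTED_STATE_KEYS:
        raise ValueError(f"unexpected state keys: {sorted(state)}")
    if state["nodes"].dtype.names != EXPECTED_NODE_FIELDS:
        raise ValueError(f"unexpected node fields: {state['nodes'].dtype.names}")


# fake state with the expected layout, only for demonstration
dtype = np.dtype([(name, "<f8") for name in EXPECTED_NODE_FIELDS])
fake_state = {
    "max_depth": 1,
    "node_count": 3,
    "nodes": np.zeros(3, dtype=dtype),
    "values": np.zeros((3, 1, 1)),
}
check_tree_state(fake_state)  # passes silently
```

Calling such a check at the start of `compress_tree_state` would turn a silent mismatch into an explicit error.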