Python Rust Binding
Context
For a project related to https://spritesmods.com/?art=magicbrush, which prints images using a commercial printer cartridge, I need to convert images with 256 levels per channel (classic RGB or CMYK formats) into images with only 2 levels per channel (spray a color droplet or not). This process is called dithering, and the result is shown below.
I’ve been using a dithering algorithm found on the net (https://scipython.com/blog/floyd-steinberg-dithering/) that gives satisfying results. But it is very, very slow.
So I looked for a way to make it faster, and a classic one is to use a compiled language (I suppose other things could be tried, like using numba or implementing the algorithm in a better way, but I couldn’t easily find how).
I could have written the algorithm in C or C++ and used swig or ctypes to call it, but I recently found that Rust provides really nice tools to create Python packages.
Creating a python package calling rust code
I’m using a cool tool to create a Python package, called maturin: https://www.maturin.rs/. It makes it easy to create a Python package and lets you choose between different bindings. I’ve chosen pyo3. The steps to get something running are quite simple:
- Create a virtual environment (see https://www.maturin.rs/tutorial)
- maturin new fs_dithering to create the app template
- maturin develop or maturin develop --release to create the package
- Now you can launch python and do import fs_dithering, then access the functions declared in the fs_dithering/src/lib.rs file
That’s it.
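If the import fails after the steps above, a quick check like the following can tell whether the package is visible to the interpreter at all. This is only a sketch: `is_installed` is a helper name made up here, and it relies solely on the standard library's `importlib.util.find_spec`.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if the current interpreter can locate `module_name`."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    name = "fs_dithering"  # the name given to `maturin new` in the steps above
    if is_installed(name):
        print(f"{name} is importable from this environment")
    else:
        # Typical cause: the virtual environment used for `maturin develop`
        # is not the one currently activated.
        print(f"{name} not found: activate the virtual environment and rerun maturin develop")
```

Running it inside and outside the virtual environment should make it obvious where the package was installed.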
Implementing the Floyd-Steinberg Dithering algorithm
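Before the full implementations, here is a minimal one-dimensional sketch of the error-diffusion idea behind Floyd-Steinberg. It is not the algorithm used in the rest of this post: the helper name `dither_row` is made up for illustration, and the whole quantization error is pushed to the next pixel instead of being split over four neighbours.

```python
def dither_row(row, levels=2):
    """Quantize values in [0, 1] to `levels` values, carrying the error forward."""
    out = list(row)
    for i in range(len(out)):
        old = out[i]
        # Closest representable value (same formula as in the 2-D version).
        new = round(old * (levels - 1)) / (levels - 1)
        out[i] = new
        # Push the quantization error to the next pixel, if there is one.
        if i + 1 < len(out):
            out[i + 1] += old - new
    return out


flat_grey = [0.25] * 8

# Plain thresholding loses the grey level entirely: every pixel becomes 0.
print([float(round(v)) for v in flat_grey])

# Error diffusion keeps the average: 2 pixels out of 8 are "on" (8 * 0.25 = 2).
print(dither_row(flat_grey))  # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
```

The two-dimensional algorithm does the same thing, but spreads the error over the right, bottom-left, bottom and bottom-right neighbours with weights 7/16, 3/16, 5/16 and 1/16.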
Python version
import numpy as np

def legacy_dithering(img: np.ndarray) -> np.ndarray:
    """
    From https://scipython.com/blog/floyd-steinberg-dithering/
    This one is really, really slow.
    """
    def get_new_val(old_val, nc):
        """
        Get the "closest" colour to old_val in the range [0,1] per channel divided
        into nc values.
        """
        return np.round(old_val * (nc - 1)) / (nc - 1)
    nc = 2
    arr = np.array(img, dtype=float) / 255
    new_height, new_width, channels = arr.shape
    for ir in range(new_height):
        for ic in range(new_width):
            # NB need to copy here for RGB arrays otherwise err will be (0,0,0)!
            old_val = arr[ir, ic].copy()
            new_val = get_new_val(old_val, nc)
            arr[ir, ic] = new_val
            err = old_val - new_val
            # In this simple example, we will just ignore the border pixels.
            if ic < new_width - 1:
                arr[ir, ic+1] += err * 7/16
            if ir < new_height - 1:
                if ic > 0:
                    arr[ir+1, ic-1] += err * 3/16
                arr[ir+1, ic] += err * 5/16
                if ic < new_width - 1:
                    arr[ir+1, ic+1] += err / 16
    carr = np.array(arr/np.max(arr, axis=(0,1)) * 255, dtype=np.uint8)
    return carr
As you can see, the interface is quite simple: we pass in a numpy array and a new, modified one is computed. Since I assume that the slow part is the double for loop, I won’t try to reproduce the whole function but only that part, so I wrote the following code for my new function:
import numpy as np
import fs_dithering as fsd  # the Rust package generated by maturin

def new_dithering(img: np.ndarray) -> np.ndarray:
    arr = np.array(img, dtype=float) / 255
    arr = fsd.fs_dither(arr, 2)  # the double loop, now running in Rust
    carr = np.array(arr/np.max(arr, axis=(0,1)) * 255, dtype=np.uint8)
    return carr
As you can see, I kept the efficient numpy functions outside my Rust code, to avoid having to dig through the documentation to translate those lines to Rust.
Using numpy arrays in rust
Luckily, I found the following repository doing something similar (another type of dithering): https://github.com/BackyardML/dithering.
In this repo you can see how to use the numpy Rust crate. It helped me create the simple Rust interface needed in the lib.rs file.
use numpy::{PyArrayDyn, PyArrayMethods, PyReadonlyArrayDyn, PyUntypedArrayMethods, ToPyArray};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Applies Floyd-Steinberg dithering to a (height, width, channels) image,
/// quantizing each channel to `nc` levels.
#[pyfunction]
pub fn fs_dither<'py>(
    py: Python<'py>,
    image: PyReadonlyArrayDyn<f64>,
    nc: i32,
) -> PyResult<Bound<'py, PyArrayDyn<f64>>> {
    let img_shape = image.shape();
    if img_shape.len() != 3 {
        // Report the problem as a Python ValueError instead of panicking.
        return Err(PyValueError::new_err("Image must have 3 dimensions"));
    }
    let mut result = image.to_owned_array();
    let dithering_count = (nc - 1) as f64; // allow float multiplication
    // x is the row index, y the column index.
    for x in 0..img_shape[0] {
        for y in 0..img_shape[1] {
            for channel in 0..img_shape[2] {
                let old_pixel = result[[x, y, channel]];
                let new_value = (old_pixel * dithering_count).round_ties_even() / dithering_count;
                let err = old_pixel - new_value;
                // Diffuse the error to the neighbours, ignoring the border pixels.
                if y < (img_shape[1] - 1) {
                    result[[x, y + 1, channel]] += err * 7f64 / 16f64;
                }
                if x < (img_shape[0] - 1) {
                    if y > 0 {
                        result[[x + 1, y - 1, channel]] += err * 3f64 / 16f64;
                    }
                    result[[x + 1, y, channel]] += err * 5f64 / 16f64;
                    if y < (img_shape[1] - 1) {
                        result[[x + 1, y + 1, channel]] += err / 16f64;
                    }
                }
                result[[x, y, channel]] = new_value;
            }
        }
    }
    Ok(result.to_pyarray_bound(py))
}

/// A Python module implemented in Rust.
#[pymodule]
fn fs_dithering(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fs_dither, m)?)?;
    Ok(())
}
This is the exact same algorithm, translated to Rust, using the numpy crate and pyo3 bindings.
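One detail of the translation deserves a note: np.round (like Python's built-in round) sends ties to the nearest even integer, whereas Rust's plain f64::round sends ties away from zero, which is presumably why round_ties_even is used on the Rust side. The sketch below illustrates the difference in Python; `round_half_away` and `quantize` are names made up for this illustration, and `round_half_away` only approximates f64::round well enough for the values shown.

```python
import math


def round_half_away(x: float) -> float:
    """Round to nearest, ties away from zero (the behaviour of Rust's f64::round)."""
    return math.copysign(math.floor(abs(x) + 0.5), x)


def quantize(old_val: float, nc: int, rounder) -> float:
    """Same formula as get_new_val, with a pluggable rounding function."""
    return rounder(old_val * (nc - 1)) / (nc - 1)


tie = 0.5  # exactly halfway between the two levels when nc = 2

print(quantize(tie, 2, round))            # ties to even: 0.0
print(quantize(tie, 2, round_half_away))  # ties away from zero: 1.0

# Away from ties, both rounding modes agree.
print(quantize(0.3, 2, round), quantize(0.3, 2, round_half_away))
```

With a different rounding mode the two implementations could disagree on such a pixel, and the diffused error would then propagate the difference to its neighbours.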
Validating the implementation
I wrote a small test to verify that both implementations were giving similar results:
import time

import numpy as np

def test_dithering():
    img = np.random.randint(0, 255, (1000, 2000, 3))
    start = time.time()
    legacy_dithered = legacy_dithering(img)
    print("Legacy dithering took", time.time() - start)
    start = time.time()
    new_dithered = new_dithering(img)
    print("New dithering took", time.time() - start)
    assert np.allclose(new_dithered, legacy_dithered)
Which gives the following output:
(.venv) toto@tata:/home/fs_dithering$ pytest -s tests
tests/test_dithering.py Legacy dithering took 53.773738622665405
New dithering took 0.40010643005371094
======= 1 passed in 54.89s =======
So the performance gain is more than a factor of 100 (about 134 on this run: 53.8 s versus 0.40 s).
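For reference, the ratio can be worked out directly from the two timings printed above. The snippet below only redoes that arithmetic, plus a per-pixel cost based on the 1000 x 2000 test image; it is a single run, so the figures are indicative only.

```python
legacy_seconds = 53.773738622665405  # "Legacy dithering took" in the output above
rust_seconds = 0.40010643005371094   # "New dithering took" in the output above

speedup = legacy_seconds / rust_seconds
pixels = 1000 * 2000  # height * width of the random test image

print(f"speedup: {speedup:.0f}x")
print(f"legacy:  {legacy_seconds / pixels * 1e6:.1f} microseconds per pixel")
print(f"rust:    {rust_seconds / pixels * 1e6:.2f} microseconds per pixel")
```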
Conclusion
In a very reasonable time (a few hours), I was able to create and call an equivalent Python function implemented in Rust, with a considerable performance gain of more than a factor of 100.