Context

For a project related to https://spritesmods.com/?art=magicbrush, which prints images using a commercial printer cartridge, I need to convert images with 256 levels per channel (classic 8-bit RGB or CMYK formats) into images with only 2 levels per channel (spray a color droplet or not). This process is called dithering, and the result is shown below.

[Figure: original image (left) and dithered image (right)]
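To make the idea concrete, here is a minimal sketch on a single-channel image, assuming pixel values in [0, 1]. The helper names `threshold` and `dither_gray` are made up for this illustration. It compares plain rounding with Floyd-Steinberg error diffusion on a uniform 25% gray patch.

```python
import numpy as np


def threshold(gray: np.ndarray) -> np.ndarray:
    """Naive 2-level quantization: each pixel is rounded independently."""
    return np.round(gray)


def dither_gray(gray: np.ndarray) -> np.ndarray:
    """Floyd-Steinberg error diffusion on one channel with values in [0, 1]."""
    arr = gray.astype(float).copy()
    height, width = arr.shape
    for r in range(height):
        for c in range(width):
            old = arr[r, c]
            new = np.round(old)
            arr[r, c] = new
            err = old - new
            # Push the quantization error to the not-yet-visited neighbours.
            if c < width - 1:
                arr[r, c + 1] += err * 7 / 16
            if r < height - 1:
                if c > 0:
                    arr[r + 1, c - 1] += err * 3 / 16
                arr[r + 1, c] += err * 5 / 16
                if c < width - 1:
                    arr[r + 1, c + 1] += err / 16
    return arr


flat = np.full((32, 32), 0.25)      # uniform 25% gray patch
print(threshold(flat).mean())       # 0.0: the gray is lost entirely
print(dither_gray(flat).mean())     # close to 0.25: the average survives
```

Plain rounding turns the whole patch black, while error diffusion switches on roughly one pixel in four, so the patch still reads as gray from a distance.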

I’ve been using a dithering algorithm found on the net (https://scipython.com/blog/floyd-steinberg-dithering/) that gives satisfying results. But it is very, very slow.

So I looked for a way to make it faster, and a classic one is to use a compiled language (I suppose other things could be tried, like using numba or implementing the algorithm in a smarter way, but I couldn’t easily find a ready-made solution).

I could have written the algorithm in C or C++ and used swig or ctypes to call it, but I recently found out that Rust provides really nice tools to create Python packages.

Creating a python package calling rust code

I’m using a cool tool called maturin (https://www.maturin.rs/) to create the Python package. It makes creating a Python package easy and lets you choose between different bindings. I’ve chosen pyo3. The steps to get something running are quite simple:

  • Create a virtual environment (see https://www.maturin.rs/tutorial)
  • Run maturin new fs_dithering to create the app template
  • Run maturin develop or maturin develop --release to build and install the package
  • Now you can launch python and do import fs_dithering, then access the functions declared in the fs_dithering/src/lib.rs file
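As a quick sanity check after the build step, a sketch like the following can confirm that the package is visible from the active virtual environment. It assumes the module is named fs_dithering, as in the steps above, and it only reports the situation instead of failing when the package has not been built yet.

```python
import importlib.util

# Look the module up without importing it, so a missing build is not an error.
spec = importlib.util.find_spec("fs_dithering")

if spec is None:
    print("fs_dithering not found: run `maturin develop` in the virtual environment first")
else:
    import fs_dithering

    # List the public names exported by the Rust module.
    print([name for name in dir(fs_dithering) if not name.startswith("_")])
```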

That’s it.

Implementing the Floyd-Steinberg Dithering algorithm

Python version

import numpy as np


def legacy_dithering(img: np.ndarray) -> np.ndarray:
    """
    From https://scipython.com/blog/floyd-steinberg-dithering/
    This one is really, really slow.
    """

    def get_new_val(old_val, nc):
        """
        Get the "closest" colour to old_val in the range [0,1] per channel divided
        into nc values.

        """

        return np.round(old_val * (nc - 1)) / (nc - 1)
    nc = 2
    arr = np.array(img, dtype=float) / 255
    new_height, new_width, channels = arr.shape
    for ir in range(new_height):
        for ic in range(new_width):
            # NB need to copy here for RGB arrays otherwise err will be (0,0,0)!
            old_val = arr[ir, ic].copy()
            new_val = get_new_val(old_val, nc)
            arr[ir, ic] = new_val
            err = old_val - new_val
            # In this simple example, we will just ignore the border pixels.
            if ic < new_width - 1:
                arr[ir, ic+1] += err * 7/16
            if ir < new_height - 1:
                if ic > 0:
                    arr[ir+1, ic-1] += err * 3/16
                arr[ir+1, ic] += err * 5/16
                if ic < new_width - 1:
                    arr[ir+1, ic+1] += err / 16

    carr = np.array(arr/np.max(arr, axis=(0,1)) * 255, dtype=np.uint8)
    return carr

As you can see, the interface is quite simple: we pass in a numpy array and get back a new, modified one. Since I assume that the slow part is the double for loop, I won’t port the whole function to Rust, only the loop. So I wrote the following code for my new function:

import fs_dithering as fsd  # the Rust package generated with maturin


def new_dithering(img: np.ndarray) -> np.ndarray:
    arr = np.array(img, dtype=float) / 255
    arr = fsd.fs_dither(arr, 2)  # the double loop now runs in Rust
    carr = np.array(arr/np.max(arr, axis=(0,1)) * 255, dtype=np.uint8)
    return carr

As you can see, I kept the efficient numpy calls outside my Rust code, which saved me from digging through the documentation to translate those lines into Rust.
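One thing to keep in mind about the normalization line that stays in Python: it divides each channel by its maximum, so a channel that ends up with no droplet at all has a maximum of 0 and the division produces NaN values. A possible guard is sketched below. The helper name `to_uint8` is hypothetical and not part of the code above.

```python
import numpy as np


def to_uint8(arr: np.ndarray) -> np.ndarray:
    """Scale each channel by its maximum and convert to 8-bit values.

    Hypothetical helper: same normalization as in the functions above,
    plus a guard for channels whose maximum is 0.
    """
    peak = np.max(arr, axis=(0, 1))
    peak = np.where(peak > 0, peak, 1.0)  # empty channels simply stay at 0
    return np.array(arr / peak * 255, dtype=np.uint8)


dithered = np.zeros((2, 2, 3))
dithered[0, 0, 0] = 1.0          # a single droplet in the first channel
print(to_uint8(dithered)[0, 0])  # first channel is 255, the other two are 0
```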

Using numpy arrays in rust

Luckily, I found the following repository doing something similar (another type of dithering): https://github.com/BackyardML/dithering.

This repo shows how to use the numpy Rust crate. It helped me write the simple Rust interface needed in the lib.rs file.

use core::panic;

use numpy::{PyUntypedArrayMethods, PyArrayDyn, PyReadonlyArrayDyn, PyArrayMethods, ToPyArray};

use pyo3::prelude::*;

/// Applies Floyd-Steinberg dithering to a 3D (height, width, channels) image
/// with values in [0, 1], quantizing each channel to `nc` levels.
#[pyfunction]
pub fn fs_dither<'py>(
    py: Python<'py>,
    image: PyReadonlyArrayDyn<f64>,
    nc: i32,
) -> PyResult<Bound<'py, PyArrayDyn<f64>>> {
    let img_shape = image.shape();
    if img_shape.len() != 3
    {
        panic!("Image must have 3 dimensions");
    }
    let mut result = image.to_owned_array();
    (0..img_shape[0]).for_each(|x| {
        (0..img_shape[1]).for_each(|y| {
            (0..img_shape[2]).for_each(|channel| {
                let old_pixel = result[[x, y, channel]];
                let dithering_count = (nc - 1) as f64; // allow float multiplication
                let new_value = (old_pixel * dithering_count).round_ties_even() / dithering_count;
                let err = old_pixel - new_value;
                if y < (img_shape[1] - 1)
                {
                    result[[x, y+1, channel]] += err * 7f64/16f64;
                }
                if x < (img_shape[0] - 1)
                {
                    if y > 0
                    {
                        result[[x+1, y-1, channel]] += err * 3f64/16f64;
                    }
                    result[[x+1, y, channel]] += err * 5f64/16f64;
                    if y < (img_shape[1] - 1)
                    {
                        result[[x+1, y+1, channel]] += err / 16f64;
                    }
                }
                result[[x, y, channel]] = new_value;
            });
        });
    });
    return Ok(result.to_pyarray_bound(py));
}

/// A Python module implemented in Rust.
#[pymodule]
fn fs_dithering(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fs_dither, m)?)?;
    Ok(())
}

This is the exact same algorithm, translated into Rust, using the numpy crate and pyo3 bindings.
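One detail matters for the two versions to agree on exact halves: np.round rounds ties to the nearest even integer, which matches round_ties_even in the Rust code, whereas Rust's plain f64::round rounds ties away from zero. The sketch below illustrates the difference. The "away from zero" line is only a NumPy emulation written for this comparison.

```python
import numpy as np

halves = np.array([0.5, 1.5, 2.5])

# NumPy rounds exact halves to the nearest even integer.
ties_even = np.round(halves)

# Emulation of "round half away from zero", the behaviour of Rust's f64::round.
ties_away = np.copysign(np.floor(np.abs(halves) + 0.5), halves)

print(ties_even)  # [0. 2. 2.]
print(ties_away)  # [1. 2. 3.]
```

With nc = 2, a channel value of exactly 0.5 therefore becomes 0 under ties-to-even and would become 1 under ties-away-from-zero.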

Validating the implementation

I wrote a small test to verify that both implementations give similar results:

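import time

import numpy as np


def test_dithering():
    # randint's upper bound is exclusive, so 256 covers the full 0..255 range
    img = np.random.randint(0, 256, (1000, 2000, 3))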
    start = time.time()
    legacy_dithered = legacy_dithering(img)
    print("Legacy dithering took", time.time() - start)
    start = time.time()
    new_dithered = new_dithering(img)
    print("New dithering took", time.time() - start)
    assert np.allclose(new_dithered, legacy_dithered)

Which gives the following output:


(.venv) toto@tata:/home/fs_dithering$ pytest -s tests

tests/test_dithering.py Legacy dithering took 53.773738622665405
New dithering took 0.40010643005371094
======= 1 passed in 54.89s =======

So the performance gain is a factor of about 134 (53.8 s versus 0.40 s for a 1000 x 2000 RGB image).
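The ratio can be worked out directly from the two timings printed above. The snippet below is plain arithmetic on those numbers, with the pixel count taken from the 1000 x 2000 test image.

```python
legacy_seconds = 53.773738622665405   # "Legacy dithering took" line
rust_seconds = 0.40010643005371094    # "New dithering took" line
pixels = 1000 * 2000                  # image size used in the test

speedup = legacy_seconds / rust_seconds
print(f"speedup: {speedup:.0f}x")  # speedup: 134x
print(f"legacy: {pixels / legacy_seconds:,.0f} pixels/s")
print(f"rust:   {pixels / rust_seconds:,.0f} pixels/s")
```

Note that the second timing covers the whole new_dithering call, so it includes the numpy conversion steps as well as the Rust loop.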

Conclusion

In a very reasonable time (a few hours), I was able to create and call an equivalent Python function implemented in Rust, with a considerable performance gain of more than a factor of 100.