Numba with CUDA

Numba is an open-source optimizing compiler for Python, developed by Continuum Analytics, that uses LLVM to compile Python code to machine code.

Numba 0.13 introduces CUDA support: Numba can now compile Python functions that operate on NumPy arrays into CUDA code that runs on NVIDIA GPUs.

In [6]:
from numba import cuda
import numpy as np
import pylab

CUDA kernels and device functions are created from Python code with decorators. For example, the following function is compiled into a device function, which can be called from a kernel but cannot be launched directly from the host. The signature is given as a string (return type first, then the argument types in parentheses), and `device=True` marks the function as a device function.

In [2]:
@cuda.jit('int32(float64, float64, int32)', device=True)
def mandel(x, y, max_iters):
    """
    Given the real and imaginary parts of a complex number,
    determine if it is a candidate for membership in the Mandelbrot
    set given a fixed number of iterations.
    """
    c = complex(x, y)
    z = complex(0, 0)
    for i in range(max_iters):
        z = z*z + c
        if z.real * z.real + z.imag * z.imag >= 4:
            return i
    return 255
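Written out, the device function iterates the usual Mandelbrot recurrence starting from zero and returns the index of the first iteration whose squared magnitude reaches 4. Here $N$ stands for `max_iters`:

```latex
z_0 = 0, \qquad z_{n+1} = z_n^2 + c, \qquad c = x + iy

\texttt{mandel}(x, y, N) =
\begin{cases}
  \min\{\, n < N : |z_{n+1}|^2 \ge 4 \,\} & \text{if such an } n \text{ exists}, \\
  255 & \text{otherwise},
\end{cases}
\qquad |z|^2 = (\operatorname{Re} z)^2 + (\operatorname{Im} z)^2
```

Testing $|z|^2 \ge 4$ is the escape-radius check $|z| \ge 2$ without a square root; once $|z_n|$ exceeds 2 the orbit is known to diverge. A point that never escapes is reported as 255 rather than $N$, so the color of points in the set does not depend on `max_iters`.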

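The `cuda.autojit` decorator infers the argument types when the function is first called, so no signature string is needed. The routine writes into `image` instead of returning it, because CUDA kernels cannot have return values.

Note that this kernel loops over every pixel itself and is called below without a launch configuration (the `kernel[blocks, threads](...)` syntax), so in this version of Numba it runs as a single GPU thread and computes the image serially. In current Numba releases, `cuda.autojit` has been superseded by `cuda.jit` without a signature, and launching a kernel requires an explicit configuration.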

In [3]:
@cuda.autojit
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    height = image.shape[0]
    width = image.shape[1]
 
    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    for x in range(width):
        real = min_x + x * pixel_size_x
        for y in range(height):
            imag = min_y + y * pixel_size_y
            color = mandel(real, imag, iters)
            image[y, x] = color
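The kernel above does all of its work inside one thread's loops. A common way to parallelize it is a 2D grid-stride loop: each thread starts at its global index, which Numba exposes as `cuda.grid(2)`, and steps by the total number of threads along each axis, `cuda.gridsize(2)`, so any launch configuration covers every pixel exactly once. The sketch below is a CPU emulation of that index arithmetic in plain Python, not Numba code. `run_kernel`, the `grid`/`block` tuples and the returned write counts are hypothetical names introduced for illustration, and `mandel` is re-declared as an ordinary Python function.

```python
from itertools import product


def mandel(x, y, max_iters):
    """Escape-time test, identical in logic to the device function above."""
    c = complex(x, y)
    z = complex(0, 0)
    for i in range(max_iters):
        z = z * z + c
        if z.real * z.real + z.imag * z.imag >= 4:
            return i
    return 255


def run_kernel(grid, block, min_x, max_x, min_y, max_y, image, iters):
    """Emulate launching a grid-stride Mandelbrot kernel as kernel[grid, block].

    grid and block are (x, y) tuples; image is a list of rows indexed
    image[y][x], like the NumPy array in the notebook. Returns a matrix
    counting how many times each pixel was written.
    """
    height, width = len(image), len(image[0])
    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    # Total threads along each axis: what cuda.gridsize(2) would return.
    stride_x = grid[0] * block[0]
    stride_y = grid[1] * block[1]
    writes = [[0] * width for _ in range(height)]
    # A GPU runs these (block, thread) pairs concurrently; here they run in turn.
    for bx, by, tx, ty in product(range(grid[0]), range(grid[1]),
                                  range(block[0]), range(block[1])):
        # This thread's global index: what cuda.grid(2) would return.
        start_x = bx * block[0] + tx
        start_y = by * block[1] + ty
        # Grid-stride loops: threads that start past the image edge do nothing.
        for x in range(start_x, width, stride_x):
            real = min_x + x * pixel_size_x
            for y in range(start_y, height, stride_y):
                imag = min_y + y * pixel_size_y
                image[y][x] = mandel(real, imag, iters)
                writes[y][x] += 1
    return writes


# Usage: a one-thread launch versus a multi-block launch on a tiny image.
serial = [[0] * 15 for _ in range(10)]
run_kernel((1, 1), (1, 1), -2.0, 1.0, -1.0, 1.0, serial, 50)

parallel = [[0] * 15 for _ in range(10)]
counts = run_kernel((2, 2), (4, 3), -2.0, 1.0, -1.0, 1.0, parallel, 50)

print(parallel == serial)                           # True
print(all(n == 1 for row in counts for n in row))  # True
```

In an actual Numba kernel the two start indices would come from `cuda.grid(2)`, the two strides from `cuda.gridsize(2)`, and a grid-stride version of `create_fractal` would be launched with a configuration such as `[(32, 16), (32, 8)]`; those block and grid sizes are illustrative assumptions, not tuned values.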

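The output image is allocated on the host as a NumPy array and passed to the kernel; Numba copies it to the GPU before the launch and back to the host afterwards. If this is the kernel's first call, the measured time also includes compiling it.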

In [8]:
image = np.zeros((500, 750), dtype=np.uint8)
%time create_fractal(-2.0, 1.0, -1.0, 1.0, image, 100)
CPU times: user 7.71 s, sys: 203 ms, total: 7.91 s
Wall time: 7.78 s

With matplotlib's inline backend, the image can be displayed directly in the notebook.

In [10]:
%matplotlib inline

pylab.imshow(image)
figure = pylab.gcf()
figure.set_size_inches((15., 15.))
pylab.show()