JAX for option pricing, Introduction

Kevin Givens

Introduction

JAX is a (fairly) new Python library for numerical computation designed for ease of use on modern accelerator hardware (GPU/TPU). Its intended computational domain is large-scale machine learning problems, such as deep learning with neural networks. However, it's a general enough framework to extend to other numerical domains, such as quantitative finance. In fact, there's a growing "JAX ecosystem" of numerical libraries, documented on the JAX site, that implement JAX-enabled numerical methods such as minimization, root finding, interpolation, quadrature, etc. All of these methods are used in the field of option pricing.

My purpose in this post is to explore some of the basics of this library in the context of a simple option pricing problem: European options in the Black Scholes model. In later posts, I hope to extend this framework to other option contracts and models.

NumPy API

JAX implements a NumPy-like high-level interface based on arrays. Essentially, one can write scripts that closely resemble NumPy scripts, with np replaced by jnp.

For example

import jax.numpy as jnp
import matplotlib.pyplot as plt

t = jnp.linspace(0, 10, 100)
plt.plot(jnp.sin(t), label="sin(x)")
plt.legend()
[figure: plot of sin(x) on the interval [0, 10]]

For the most part, JAX arrays look and act like NumPy arrays, with some crucial differences discussed in the JAX documentation.
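
One such difference worth noting is that JAX arrays are immutable, so in-place updates are written functionally via the .at syntax

x = jnp.arange(5.)
x.at[0].set(10.)  # returns a new array; direct assignment like x[0] = 10. raises an error
Array([10., 1., 2., 3., 4.], dtype=float32)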

Broadcasting works in the usual way

x = jnp.arange(5.)
x + 3
Array([3., 4., 5., 6., 7.], dtype=float32)
x + x
Array([0., 2., 4., 6., 8.], dtype=float32)

as do ufuncs

jnp.exp(x)
Array([ 1., 2.7182817, 7.389056, 20.085537, 54.598152], dtype=float32)

Automatic Differentiation

One of the primary features of JAX is its incorporation of automatic differentiation (AD), a foundational set of algorithms in deep learning. For functions producing scalar output, derivatives can be computed via the grad method. For example

from jax import grad, vmap

t = jnp.linspace(0, 10, 100)
plt.plot(jnp.sin(t), label="sin(x)")
plt.plot(vmap(grad(jnp.sin))(t), '*', label="grad(sin(x))")
plt.plot(jnp.cos(t), '-', label="cos(x)")
plt.legend()
[figure: sin(x), its AD derivative via vmap(grad(jnp.sin)), and cos(x) overlaid]

As you can see from the previous plot, vmap(grad(jnp.sin))(t) computes the exact derivative of sin evaluated at each point in the interval t. The vmap function allows us to extend the domain of the derivative over the entire interval t via vectorization, which we discuss below.

As is well documented [1, 2], automatic differentiation has become an extremely important computational technique in the field of option pricing, used to compute greeks/sensitivities of large portfolios of OTC derivatives. This is my primary motivation for studying the use of JAX for option pricing problems. I believe that all modern option pricing libraries should be equipped with AD functionality at their core.

Vectorization

Similar to NumPy, JAX functions can be efficiently vectorized over an additional array axis using the vmap method. This means that a function designed to produce scalar output, for example

lambda x: jnp.sum(x)

can be vectorized over a particular dimension of data. For example

y = jnp.arange(1., 10.).reshape(3,3)
y
Array([[1., 2., 3.],
       [4., 5., 6.],
       [7., 8., 9.]], dtype=float32)
vmap(lambda x: jnp.sum(x))(y)
Array([ 6., 15., 24.], dtype=float32)

We see that the sum has been computed row-wise on the y array.

In the context of machine learning, the vectorized axis is often referred to as the batch dimension, meaning extending a function over a batch of inputs. We'll use the same language with regards to option pricing, where by batch we mean a collection of options in a portfolio or option chain.
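
For example, a toy payoff function (purely illustrative, not part of the pricing code below) can be mapped over a batch of strikes while holding the spot fixed

payoff = lambda s, k: jnp.maximum(s - k, 0.0)
strikes = jnp.array([90., 95., 100., 105.])
vmap(payoff, in_axes=(None, 0))(100.0, strikes)
Array([10., 5., 0., 0.], dtype=float32)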

Just-in-Time Compilation

JAX also includes just-in-time compilation functionality via the jit method. This is a very powerful function whose details I won't explore in this post, but essentially it allows the user to improve the runtime performance of certain JAX-compatible functions via compiler optimization technology from the XLA project.
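
As a minimal illustration

from jax import jit

@jit
def total_payoff(s, k):
    # sum of call payoffs; traced and compiled by XLA on the first call, then cached
    return jnp.sum(jnp.maximum(s - k, 0.0))

total_payoff(jnp.linspace(70., 130., 100), 95.0)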

Parallelization

Finally, JAX directly supports a type of parallelism known as Single Program, Multiple Data (SPMD) through the pmap method. Roughly speaking, this method will jit-compile and distribute copies of certain JAX-compatible functions for parallel execution across hardware acceleration devices such as GPUs and TPUs. Again, I won't dive into this topic in this post, but I will say that it seems to be a natural fit for option pricing problems involving Monte Carlo simulation, as one could imagine distributing the simulation of paths across processing units and combining results to compute expectation values.
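
As a rough sketch of the idea (the lognormal terminal values here are hard-coded purely for illustration)

import jax
from jax import pmap, random

def mc_block(key):
    # each device simulates its own block of draws and returns the block mean
    z = random.normal(key, (100_000,))
    return jnp.mean(jnp.exp(-0.5*0.2**2 + 0.2*z))

keys = random.split(random.PRNGKey(0), jax.local_device_count())
estimate = jnp.mean(pmap(mc_block)(keys))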

Pricing a European Option

Using JAX's array API, we can implement the price function for a European option in the Black Scholes model almost directly as a math equation, as follows

from jax.scipy.stats import norm

def pv(s, k, r, q, t, σ, ω):
    """ present value (price) of a European option using the Black Scholes model
        Parameters
        ----------
        s: ArrayLike
            spot
        k: ArrayLike
            strike
        r: ArrayLike
            discount rate
        q: ArrayLike
            dividend yield
        t: ArrayLike
            time to expiry
        σ: ArrayLike
            volatility
        ω: ArrayLike
            put/call indicator (1 for call, -1 for put)    
        
        Returns
        -------
            Array
    """
    s, k, r, q, t, σ, ω = map(asarray_inexact, (s, k, r, q, t, σ, ω))
    return (ω*s*jnp.exp(-q*t)*norm.cdf(ω*dp(s/k, r, q, t, σ))
            -ω*k*jnp.exp(-r*t)*norm.cdf(ω*dm(s/k, r, q, t, σ)))

dp and dm are the familiar CDF arguments for the Black Scholes model

def dp(κ, r, q, t, σ):
    """ d plus CDF argument 
        Parameters
        ----------
        κ: ArrayLike
            moneyness (spot/strike)
        r: ArrayLike
            discount rate
        q: ArrayLike
            dividend yield
        t: ArrayLike
            time to expiry
        σ: ArrayLike
            volatility
        
        Returns
        -------
            Array
    """
    
    return (jnp.log(κ) + (r - q)*t + 0.5*σ**2*t)/(σ*jnp.sqrt(t))
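
The dm function is analogous, with a minus sign on the 0.5*σ**2*t term (dm = dp - σ*sqrt(t))

def dm(κ, r, q, t, σ):
    """ d minus CDF argument """
    return (jnp.log(κ) + (r - q)*t - 0.5*σ**2*t)/(σ*jnp.sqrt(t))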

The function asarray_inexact wraps jnp.asarray(), similar to NumPy's asarray, and promotes the result to an inexact (floating-point) dtype.
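
Its exact definition is omitted here, but a minimal version might look like

def asarray_inexact(x):
    """ convert x to a JAX array with an inexact (floating point) dtype """
    x = jnp.asarray(x)
    return x.astype(jnp.promote_types(x.dtype, jnp.float32))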

Plotting the PV as a function of spot, holding all other parameters fixed, gives the well-known call price versus spot plot

s = jnp.linspace(70, 130, 100)
k, r, q, t, σ, ω =  95.00, 0.10, 0.05, 0.5, 0.20, 1

plt.plot(s, pv(s, k, r, q, t, σ, ω), label="Call PV, $t = 0.5$")
plt.plot(s, pv(s, k, r, q, 0.05, σ, ω), label="Call PV, $t = 0.05$")
plt.legend()
plt.show()
[figure: call PV vs. spot for t = 0.5 and t = 0.05]

Notice that s is a 1d array of values, while the other terms are scalars. The output of pv is a 1d array via broadcasting. Had we passed in 1d arrays of equal length, the pv function would similarly have returned a 1d array. This is an important use case in option pricing, where PVs are computed over collections of options.

Next, using JAX's vmap and grad methods we can trivially implement vectorized greeks like the following

def delta(s, k, r, q, t, σ, ω):
    s, k, r, q, t, σ, ω = map(asarray_inexact, (s, k, r, q, t, σ, ω))
    return grad(pv, 0)(s, k, r, q, t, σ, ω)

This takes the derivative of the pv function with respect to its first argument, namely s.

We can compare this function with the analytic expression of delta

def delta_analytic(s, k, r, q, t, σ, ω):
    """ put delta via -N(-dp) = N(dp) - 1 """
    return jnp.exp(-q*t)*ω*(norm.cdf(ω*dp(s/k, r, q, t, σ)))
s = jnp.linspace(70, 130, 100)
k, r, q, t, σ, ω =  95.00, 0.10, 0.05, 0.1, 0.20, 1
deltas = vmap(delta, in_axes=[0, None, None, None, None, None, None])(s, k, r, q, t, σ, ω)

plt.plot(s, deltas, label="Call delta AD, $t = 0.1$", linewidth=2.0)
plt.plot(s, delta_analytic(s, k, r, q, t, σ, ω), '--', label="Call delta analytic, $t = 0.1$", linewidth=2.0)
plt.xlabel("Spot")
plt.ylabel("Delta")
plt.legend()
plt.show()
[figure: call delta vs. spot, AD and analytic curves overlaid]

We see that both functions return the same output. Other greeks can be computed similarly, for example

gamma = vmap(grad(grad(pv, 0), 0)) 
rho = vmap(grad(pv, 2))
dividend_rho = vmap(grad(pv, 3))
theta = vmap(grad(pv, 4))
vega = vmap(grad(pv, 5))
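
Since vmap maps over the leading axis of every argument by default, the scalar parameters can be broadcast to arrays matching the spots before calling these, for example

s = jnp.linspace(70., 130., 100)
k, r, q, t, σ, ω = 95.00, 0.10, 0.05, 0.5, 0.20, 1.
args = [jnp.full_like(s, a) for a in (k, r, q, t, σ, ω)]  # broadcast scalars to 1d arrays
vegas = vega(s, *args)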

Conclusion

In this post, I (barely) scratched the surface of what's possible using JAX. I demonstrated that Black Scholes prices and greeks can be computed trivially and accurately in JAX. I hope to revisit this library in the future in the context of other option pricing problems, for instance those involving numerical quadrature, finite difference methods or Monte Carlo simulation. I would also like to carry out performance tests against compiled code.