Vectorized minimization and root finding in JAX

I have a family of functions parameterized by args

f(x, args)

and I want to find the minimum of f over x for N = 1000 different values of args. I have access to both the function and its derivative. My first attempt was to loop over the values of args and call a scipy.optimize minimizer at each iteration, but it takes too long. I believe vectorization could speed this up. My next attempt was to use jax.vmap together with jax.scipy.optimize.minimize or jaxopt.ScipyMinimize, but I can't seem to pass more than one value for args.
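For reference, the looping approach described above might look like the sketch below. The objective f(x, args) = a + sum((x - b)**2), its derivative grad_f, and the random parameter sets are hypothetical stand-ins, not the asker's actual functions.

```python
import numpy as np
from scipy import optimize

# Hypothetical example family: f(x, args) = a + sum((x - b)**2), minimized at x = b.
def f(x, args):
    a, b = args
    return a + np.sum((x - b) ** 2)

def grad_f(x, args):
    # Derivative of f with respect to x (the asker says this is available).
    _, b = args
    return 2.0 * (x - b)

rng = np.random.default_rng(0)
N = 1000
all_args = [(rng.uniform(), rng.uniform(-5.0, 5.0)) for _ in range(N)]

# One independent SciPy call per parameter set: simple, but the Python loop
# and per-call overhead dominate when N is large.
x_min = np.empty(N)
for i, args in enumerate(all_args):
    res = optimize.minimize(f, x0=np.array([1.0]), args=(args,), jac=grad_f, method="BFGS")
    x_min[i] = res.x[0]
```

Each solve is independent, which is exactly the structure that jax.vmap can batch.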

Alternatively, I could code my own vectorized optimization method, e.g. bisection. By "vectorized" I mean operating on whole arrays for a fixed number of iterations, without stopping early when an individual problem reaches its error tolerance.
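A minimal sketch of that fixed-iteration, array-wide bisection in JAX is shown below. It runs bisection on the derivative (root finding for df/dx = 0), so it assumes each f is unimodal on a known bracket [lo, hi] with df(lo) < 0 < df(hi). The quadratic objective, its derivative df, the bracket [-10, 10], and the name batched_bisection are all hypothetical illustrations.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar example: f(x; a, b) = a + (x - b)**2, so df/dx = 2 (x - b).
def df(x, a, b):
    return 2.0 * (x - b)

def batched_bisection(lo, hi, a, b, num_iters=60):
    """Bisection on df(x) = 0 for every parameter set at once.

    Assumes each f is unimodal on its bracket, i.e. df(lo) < 0 < df(hi).
    All problems run for the same fixed number of iterations (no early stopping).
    """
    def body(_, state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        go_right = df(mid, a, b) < 0.0  # minimum lies to the right of mid
        lo = jnp.where(go_right, mid, lo)
        hi = jnp.where(go_right, hi, mid)
        return lo, hi

    lo, hi = jax.lax.fori_loop(0, num_iters, body, (lo, hi))
    return 0.5 * (lo + hi)

N = 1000
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (N,))
b = jax.random.uniform(key_b, (N,), minval=-5.0, maxval=5.0)
lo = jnp.full((N,), -10.0)
hi = jnp.full((N,), 10.0)

# num_iters is a Python int that controls the loop length, so mark it static.
x_min = jax.jit(batched_bisection, static_argnames="num_iters")(lo, hi, a, b, num_iters=60)
```

Because every array element takes the same number of steps, the loop compiles to a single fused computation; the trade-off is that easy problems keep iterating after they have converged.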


I was hoping to use an already optimized, off-the-shelf algorithm if an implementation is available in JAX. A related thread exists, but in it the args do not change.

Solution:

You can define a function to find the minimum given particular args, and then wrap it in jax.vmap to automatically vectorize it. For example:

import jax
import jax.numpy as jnp
from jax.scipy import optimize

def f(x, args):
  # Objective for one parameter set; its minimum is at x = b.
  a, b = args
  return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
  # Solve a single problem; jax.vmap below batches this over many (a, b) pairs.
  x0 = jnp.array([1.0])
  args = (a, b)
  # minimize calls f(x, *args), so wrap the tuple once more to pass it as one argument.
  return optimize.minimize(f, x0, (args,), method="BFGS")

# 25 parameter sets: every combination of a and b in {0, 1, 2, 3, 4}.
a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))

results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())

print(results.success)
# [ True  True  True  True  True  True  True  True  True  True  True  True
#   True  True  True  True  True  True  True  True  True  True  True  True
#   True]

print(results.x.T)
# [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
#   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]