How can I redefine a subfunction of a JAX-jitted function?

I have a function foo which is jitted with JAX. foo calls bar.

from jax import jit

def bar(x):
  return x ** 2

@jit
def foo(x):
  return 1 + bar(x)

print(f'foo(4) = {foo(4)}')

This prints foo(4) = 17, as expected.

If I redefine bar at runtime, what is the best way to re-jit foo?

Bonus: Is it possible to just tell JAX that bar changed, so that it can re-jit everything that depends on bar?


Details

If I redefine bar and print again,

def bar(x):
  return 2 * x

print(f'foo(4) = {foo(4)}')

the output is still foo(4) = 17. Clearly, the old bar is still in the jitted foo.
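A minimal sketch of why this happens, assuming a standard JAX install: jit runs the Python body of foo only while tracing, then reuses the compiled result for later calls whose arguments have the same shape and dtype. The global trace_count is an illustrative addition, not part of the original code; its increment is a Python side effect, so it only fires during tracing.

```python
from jax import jit

trace_count = 0  # illustrative counter (not in the original): counts how often foo is traced

def bar(x):
    return x ** 2

@jit
def foo(x):
    global trace_count
    trace_count += 1   # Python side effect: runs only while JAX traces foo
    return 1 + bar(x)  # bar is looked up here, i.e. once, at trace time

print(f'foo(4) = {foo(4)}')  # first call: trace + compile with the current bar
print(f'foo(5) = {foo(5)}')  # same shape/dtype: cached executable reused, no retrace
print(f'traces: {trace_count}')

def bar(x):
    return 2 * x

print(f'foo(4) = {foo(4)}')  # cache hit again: the compiled code still computes x ** 2
print(f'traces: {trace_count}')
```

The counter stays at 1 across all three calls, which is consistent with the old bar having been compiled into foo.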

To get the desired foo with the updated bar, I need to re-jit it. I can do this by redefining foo so that it is re-jitted:

@jit
def foo(x):
  return 1 + bar(x)

print(f'foo(4) = {foo(4)}')

which now prints foo(4) = 9, as desired. But rewriting all of foo is silly. I can also get the desired output by simply re-wrapping foo:

foo = jit(foo)

But this feels dangerous, since we’re passing an already-jitted foo into jit again. I was surprised that it worked, and I suspect this approach has weird corner cases.

Also, both of these approaches require knowing that foo called bar. It’s trivial in this toy example, but in more complex software, there might be many functions which need to be re-jitted. It’s cumbersome to do it individually, and easy to forget some as well. Hence the ‘bonus’ part of my question.

The jit docs and the jit caching docs did not answer my question, AFAICT. Nor did a reasonable search of Stack Overflow or the GitHub discussion pages.

Solution:

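This does not work as expected because foo is not pure: its output depends not only on its inputs but also on global state (here, whatever the global name bar refers to). jit traces the Python function once for each new combination of argument shapes and dtypes, so bar is looked up at trace time and its operations are baked into the compiled computation; later calls with matching arguments reuse that compiled version and never see a redefined bar. JAX transformations like jit only work correctly for pure functions; see JAX Sharp Bits: Pure Functions for more discussion of this.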

The best way to approach this is probably to change the function’s call signature so that all relevant state is explicitly passed to the function. For example, you could pass bar explicitly:

from jax import jit
from functools import partial

def bar(x):
  return x ** 2

# Mark `bar` as a static argument: when it changes it will trigger re-compilation
@partial(jit, static_argnames=['bar'])
def foo(x, bar):
  return 1 + bar(x)

print(f'foo(4) = {foo(4, bar)}')
# foo(4) = 17

def bar(x):
  return 2 * x

print(f'foo(4) = {foo(4, bar)}')
# foo(4) = 9

Now, when you pass a redefined bar, jit sees a new static argument value, recompiles foo, and the output changes accordingly.
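For the bonus part of the question, one possible sketch (assuming the dependency can be threaded through call signatures) is to pass bar as a static argument at every level of the call chain. Each jitted function then keys its cache on the function object it receives, so passing a new function retraces every dependent function without any manual bookkeeping. The names square, double and baz are illustrative, not from the original question.

```python
from functools import partial
from jax import jit

def square(x):
    return x ** 2

def double(x):
    return 2 * x

@partial(jit, static_argnames=['bar'])
def foo(x, bar):
    return 1 + bar(x)

# Hypothetical second function that depends on bar only indirectly, through foo.
# It also takes bar as a static argument and simply forwards it.
@partial(jit, static_argnames=['bar'])
def baz(x, bar):
    return 10 * foo(x, bar)

print(f'baz(4) = {baz(4, square)}')  # traced and compiled for `square`
print(f'baz(4) = {baz(4, double)}')  # new static value: baz and the inner foo are retraced
```

The trade-off is that every distinct function object triggers a fresh compilation, so creating new functions repeatedly (for example, lambdas inside a loop) will recompile each time.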
