I have a function foo which is jitted with JAX. foo calls bar.
```python
from jax import jit

def bar(x):
    return x ** 2

@jit
def foo(x):
    return 1 + bar(x)

print(f'foo(4) = {foo(4)}')
```
This prints foo(4) = 17, as expected.
If I redefine bar at runtime, what is the best way to re-jit foo?
Bonus: Is it possible to just tell jax bar changed, so that it can re-jit everything dependent upon bar?
Details
If I redefine bar and print again,
```python
def bar(x):
    return 2 * x

print(f'foo(4) = {foo(4)}')
```
the output is still foo(4) = 17. Clearly, the old bar is still in the jitted foo.
To get the desired foo with the updated bar, I need to re-jit. I can do this by rewriting foo so that it is re-jitted:
```python
@jit
def foo(x):
    return 1 + bar(x)

print(f'foo(4) = {foo(4)}')
```
which now prints foo(4) = 9, as desired. But rewriting all of foo is silly. I can also get the desired output by just re-wrapping foo:
```python
foo = jit(foo)
```
But this feels dangerous, since we’re passing an already-jitted foo into the jit compiler again. I was surprised to see that it worked, and I suspect weird corner cases with this approach.
Also, both of these approaches require knowing that foo called bar. It’s trivial in this toy example, but in more complex software, there might be many functions which need to be re-jitted. It’s cumbersome to do it individually, and easy to forget some as well. Hence the ‘bonus’ part of my question.
The jit docs and jit caching docs did not answer my question, AFAICT. Nor did a reasonable search of stackoverflow or the github discussion pages.
Solution:
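This doesn't work as expected because foo is not pure: its output depends not only on its inputs but also on global state (here, whatever the name bar refers to). JAX transforms like jit only work correctly for pure functions; see JAX Sharp Bits: Pure Functions for more discussion. Concretely, jit traces foo the first time it is called with a given combination of input shapes and dtypes and caches the compiled result; later calls with matching inputs reuse that cache and never look up bar again.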
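The following sketch makes the caching visible. jit traces foo once per input signature (shapes and dtypes) and caches the result, and global lookups such as bar happen only during tracing. The counter trace_count is a hypothetical addition purely for illustration; it increments only when JAX traces foo. It also exposes a subtle corner case of the impure version: the stale bar survives only until something forces a retrace, so the result can depend on whether you pass an int or a float.

```python
import jax

trace_count = 0  # hypothetical counter: how many times foo has been traced

def bar(x):
    return x ** 2

@jax.jit
def foo(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces foo
    return 1 + bar(x)

r1 = foo(4)    # first call: traces foo, baking in bar(x) = x ** 2 -> 17
r2 = foo(5)    # same shape/dtype as before: cache hit, no retrace -> 26

def bar(x):    # rebinds the global name; the cached trace still uses x ** 2
    return 2 * x

r3 = foo(4)    # cache hit: still 17, the new bar is ignored
r4 = foo(4.0)  # float input is a new signature: retrace picks up the new bar -> 9.0
print(int(r1), int(r2), int(r3), float(r4), trace_count)
```

Here foo(4) and foo(4.0) disagree about which bar is in effect, which is exactly the kind of inconsistency impure functions produce under jit.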
The best way to approach this is probably to change the function’s call signature so that all relevant state is explicitly passed to the function. For example, you could pass bar explicitly:
```python
from functools import partial

from jax import jit

def bar(x):
    return x ** 2

# Mark `bar` as a static argument: when it changes, it triggers recompilation
@partial(jit, static_argnames=['bar'])
def foo(x, bar):
    return 1 + bar(x)

print(f'foo(4) = {foo(4, bar)}')
# foo(4) = 17

def bar(x):
    return 2 * x

print(f'foo(4) = {foo(4, bar)}')
# foo(4) = 9
```
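Now, when you change bar, the output of the function changes: the new bar is a different function object, so it counts as a new static argument and triggers recompilation.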
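For the bonus part of the question (many jitted functions depending on the same helpers), one possible extension of the static-argument idea is to bundle all swappable helpers into a single hashable object and pass it as one static argument everywhere. This is a sketch under assumptions, not an official JAX mechanism: the names Helpers, baz, square, and double are hypothetical. A NamedTuple of functions is hashable and compares element-wise, so swapping any helper produces a new static value and every function that receives the bundle recompiles.

```python
from functools import partial
from typing import Callable, NamedTuple

import jax

class Helpers(NamedTuple):
    # Hypothetical bundle of swappable helper functions; add more fields as needed.
    bar: Callable

@partial(jax.jit, static_argnames=['helpers'])
def foo(x, helpers):
    return 1 + helpers.bar(x)

@partial(jax.jit, static_argnames=['helpers'])
def baz(x, helpers):
    # Hypothetical caller of foo: forwards the same static bundle.
    return 10 * foo(x, helpers)

def square(x):
    return x ** 2

def double(x):
    return 2 * x

h = Helpers(bar=square)
print(int(foo(4, h)), int(baz(4, h)))  # 17 170
h = Helpers(bar=double)                # swapping a helper changes the static key
print(int(foo(4, h)), int(baz(4, h)))  # 9 90
```

One caveat with any static function argument: it is matched by hash and equality, and functions compare by identity, so passing a freshly created lambda on every call defeats the cache and recompiles each time. Define helpers once and reuse the same objects.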