🔌 Toolbox of short, reusable pieces of code and knowledge.
Useful packages/libraries:
JAX (jit, vmap, grad, PRNG, PyTree, NumPy API)
Flax linen.Module (see the API reference).
tree_map maps a function over the elements (leaves) of a PyTree. In this case, it lets us apply an update rule (e.g., a gradient descent step, written as a function) to every leaf of the PyTree params in one line of code.
# credit: flax docs
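# note: the top-level jax.tree_map alias is deprecated in newer JAX releases; jax.tree_util.tree_map is the equivalent call
grads = jax.grad(mse_loss)(params, x_samples, y_samples)
params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)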
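For a version that runs end to end, here is a minimal sketch of the same update. The linear model, the body of mse_loss, the toy data, and the learning rate are all made up for illustration and do not come from the Flax docs. It calls jax.tree_util.tree_map, since the top-level jax.tree_map alias is deprecated in newer JAX releases.

```python
import jax
import jax.numpy as jnp


# Hypothetical loss for a linear model y = x @ w + b (illustration only).
def mse_loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# params is a PyTree: a dict whose leaves are arrays.
params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}

# Toy data generated from known weights (assumed values).
x_samples = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_samples = x_samples @ jnp.array([2.0, -1.0]) + 0.5
learning_rate = 0.01

loss_before = mse_loss(params, x_samples, y_samples)

# jax.grad returns gradients with the same PyTree structure as params.
grads = jax.grad(mse_loss)(params, x_samples, y_samples)

# tree_map walks both PyTrees in parallel and applies the update leaf by leaf.
params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

loss_after = mse_loss(params, x_samples, y_samples)
print(float(loss_before), float(loss_after))
```

Because the gradients share the structure of params, the same one-line update works unchanged for nested dicts, lists, or Flax parameter collections.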
What follows is a reference of cool (or not so cool) functions in JAX, sort of like a mini reference document.