JAX Hands-on I
JAX is NumPy on the CPU, GPU, and TPU, with built-in automatic differentiation. It brings together Autograd and XLA for high-performance numerical computing and machine learning research, and it provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, just-in-time (JIT) compile to GPU/TPU, and more. This tutorial is the first episode of the JAX hands-on series I have developed for the Artificial Neural Network course at the University of Tehran.
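Here is a minimal sketch of how these transformations compose. It assumes a recent `jax` install, and the function `f` and the input values are illustrative only, not taken from the course material. The sketch differentiates a small function with `jax.grad`, vectorizes that gradient over a batch with `jax.vmap`, and JIT-compiles the result with `jax.jit`. Parallelization (`jax.pmap`) is left out to keep it short.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar-valued function written with the NumPy-like jax.numpy API.
def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# Differentiate: jax.grad(f) returns a new function that computes df/dx.
df = jax.grad(f)

# Vectorize: jax.vmap maps df over the leading (batch) axis of its input.
batched_df = jax.vmap(df)

# Compile: jax.jit traces the composed function and compiles it with XLA.
fast_batched_df = jax.jit(batched_df)

# A batch of two 3-element vectors; the output has the same shape (2, 3).
xs = jnp.arange(6.0).reshape(2, 3)
print(fast_batched_df(xs))
```

Each transformation takes a Python function and returns a new one, which is why they can be nested freely. For this `f`, each row of the output should match the analytic gradient `2 * tanh(x) * (1 - tanh(x) ** 2)`.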