Automatic differentiation (autodiff) is the bread and butter of every deep learning framework. In fact, it is not a stretch to say that deep learning as it exists today would not have been possible without the widespread adoption of autodiff tools. You define a neural network, as complex as it may be, and you let it .fit(X, y) itself. Nothing simpler, right?
Strangely, autodiff is only marginally understood (or even misunderstood) by many deep learning practitioners, and its scope and possibilities are rarely mentioned (e.g., ever heard of differentiable programming?).
About these tutorials
The purpose of these tutorials is twofold. Firstly, we want to introduce you to the mysteries behind autodiff: its principles, algorithms, math, and implementations. We would like to show you that there is a vast world hiding behind most deep learning frameworks that, once revealed, opens up a variety of interesting experiments.
Secondly, we complement the theory with a description of several concepts from JAX, a beautiful Python library for high-performance tensor computations, with powerful autodiff and compilation tools. These tutorials will freely mix theory, JAX ideas, and code snippets, with an accompanying Colab notebook if you only care about the code. We hope you find this informative!
What you will need
- Basic understanding of calculus and linear algebra;
- Some experience with neural networks!
You can look at the JAX introduction, if you want, but we will introduce most concepts when we need them.
Have fun! For any feedback, feel free to contact me via GitHub or any other channel.
Content of the tutorials
Part 1: what is autodiff?
In this first part, we explain what autodiff actually is and how to use it in JAX, restricting our attention to scalar values and forward-mode autodiff.
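As a small taste of what this looks like, here is a minimal sketch of forward-mode autodiff on a scalar function using jax.jvp. The function f and the evaluation point are hypothetical, chosen only for illustration; they are not taken from the tutorial itself.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar function, chosen only for illustration.
    return jnp.sin(x) * x ** 2

x = 1.5

# Forward mode: jax.jvp pushes a tangent vector through f alongside the
# primal value. For a scalar input, a tangent of 1.0 yields df/dx.
value, derivative = jax.jvp(f, (x,), (1.0,))

# Hand-derived derivative (product rule), used as a sanity check.
expected = jnp.cos(x) * x ** 2 + 2 * x * jnp.sin(x)

print(value, derivative, expected)
```

A single forward pass returns both f(x) and its derivative, which is why forward mode is a natural starting point for scalar functions.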
Part 2: reverse-mode autodiff
In the second part, we introduce autodiff tools for vector-matrix computations, along with the more powerful reverse-mode technique, which is especially useful in machine learning.
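To give a rough idea of the reverse-mode setting, the sketch below uses jax.grad to differentiate a scalar loss with respect to a weight vector. The least-squares loss and the toy data are assumptions made for illustration only, not the example used in the tutorial.

```python
import jax
import jax.numpy as jnp

def loss(w, X, y):
    # Hypothetical least-squares loss: many inputs (w), one scalar output.
    residual = X @ w - y
    return jnp.sum(residual ** 2)

# Toy data, made up for this sketch.
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])
w = jnp.array([0.5, -0.5])

# jax.grad uses reverse mode and differentiates with respect to the
# first argument (w) by default.
grad_w = jax.grad(loss)(w, X, y)

# Hand-derived gradient of the same loss, used as a sanity check.
analytic = 2.0 * X.T @ (X @ w - y)

print(grad_w, analytic)
```

The whole gradient comes out of one backward pass, regardless of how many entries w has. This is the typical shape of a machine learning problem: many parameters, one scalar loss.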