Tag Archives: JAX

Using JAX and Haiku to build a Graph Neural Network



Last year, I had the opportunity to delve into the world of JAX whilst working at InstaDeep, and my first blopig post seems like an ideal time to share some of that knowledge. JAX is an experimental Python library developed at Google for automatic differentiation and accelerated numerical computing. JAX can differentiate functions written in NumPy or native Python, just-in-time compile and execute functions on GPUs and TPUs via XLA, and automatically vectorise functions over mini-batches. Together, these qualities make JAX an ideal candidate for accelerated deep learning research [1].
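
To make those transformations concrete, here is a minimal sketch (not code from the original post; the toy function and values are illustrative assumptions) showing grad, jit and vmap applied to a plain Python function:

```python
import jax
import jax.numpy as jnp

# A toy loss written in plain NumPy-style Python (illustrative assumption).
def loss(x):
    return jnp.sum(x ** 2)

grad_loss = jax.grad(loss)          # automatic differentiation
fast_loss = jax.jit(loss)           # just-in-time compilation via XLA
batched_grad = jax.vmap(grad_loss)  # vectorise over a leading batch axis

x = jnp.arange(3.0)
print(grad_loss(x))                    # [0. 2. 4.]
print(fast_loss(x))                    # 5.0
print(batched_grad(jnp.ones((4, 3))))  # one gradient per row of the mini-batch
```

Each transformation returns a new function, so they compose freely, e.g. jax.jit(jax.vmap(grad_loss)).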

JAX is inspired by the NumPy API, making it feel familiar to any Python user who has already worked with NumPy [2]. Unlike NumPy arrays, however, JAX arrays are immutable: once assigned in memory, they cannot be changed in place. As such, JAX includes dedicated syntax for indexed updates. In the code below, we create a JAX array and change its first element to a 4:
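
The snippet itself sits behind the read-more link; a minimal sketch of the idiom, assuming JAX's standard .at[...].set(...) indexed-update API (the array values here are illustrative), would be:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# In-place assignment (x[0] = 4) raises an error on an immutable JAX array;
# .at[...].set(...) instead returns a new array with the update applied.
x = x.at[0].set(4)
print(x)  # [4 2 3]
```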

Continue reading

Einops: Powerful library for tensor operations in deep learning

Tobias and I recently gave a talk at the OPIG retreat on tips for using PyTorch. For this we created a tutorial in a Google Colab notebook (the link can be found here). I remember rambling about the advantages of implementing your own models rather than using other people's code. Well, if I convinced you, einops is for you!

Basically, einops lets you perform operations on tensors using Einstein notation. The package comes with a number of advantages, a few of which I will try to summarise here:
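
As a taste of that notation before you follow the link, here is a minimal sketch (not code from the original tutorial; the tensor shapes are illustrative) using einops' rearrange and reduce on a PyTorch tensor:

```python
import torch
from einops import rearrange, reduce

x = torch.randn(8, 3, 32, 32)  # a batch of 8 RGB images, 32x32 pixels

# Flatten each image into a vector; the axis names document the intent.
flat = rearrange(x, 'b c h w -> b (c h w)')   # shape: (8, 3072)

# Global average pooling over the spatial axes.
pooled = reduce(x, 'b c h w -> b c', 'mean')  # shape: (8, 3)
```

Because the axis names appear in the pattern string, a single line documents both the input and output layout, which is a big part of einops' appeal.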

Continue reading