Tag Archives: graph neural networks

Using JAX and Haiku to build a Graph Neural Network


JAX

Last year, I had an opportunity to delve into the world of JAX whilst working at InstaDeep. My first blopig post seems like an ideal time to share some of that knowledge. JAX is an experimental Python library developed at Google for accelerated numerical computing and automatic differentiation. JAX can differentiate functions written with a NumPy-like API or in native Python, just-in-time compile and execute functions on GPUs and TPUs with XLA, and apply functions across mini-batches using automatic vectorization. Collectively, these qualities make JAX a strong candidate for accelerated deep learning research [1].

JAX is inspired by the NumPy API, making usage very familiar for any Python user who has already worked with NumPy [2]. However, unlike NumPy, JAX arrays are immutable; once they are assigned in memory they cannot be changed. As such, JAX includes specific syntax for index manipulation. In the code below, we create a JAX array and change the first element to a 4:
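A minimal sketch of such an update, assuming that "first element" means index 0 and using arbitrary placeholder values. It relies on the `.at[...].set(...)` syntax that JAX provides for functional index updates:

```python
import jax.numpy as jnp

# Placeholder values chosen for illustration only.
x = jnp.array([1, 2, 3])

# JAX arrays are immutable, so an in-place assignment such as `x[0] = 4`
# raises an error. `.at[0].set(4)` instead returns a new array containing
# the update and leaves `x` untouched.
y = x.at[0].set(4)

print(x)  # the original array is unchanged
print(y)  # the new array holds the updated value
```

Because the update returns a new array rather than mutating the old one, the result has to be bound to a name (here `y`) to be used afterwards.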

Continue reading

Universal graph pooling for GNNs

Graph neural networks (GNNs) have quickly become one of the most important tools in computational chemistry and molecular machine learning. GNNs are a type of deep learning architecture designed for the adaptive extraction of vectorial features directly from graph-shaped input data, such as low-level molecular graphs. The feature-extraction mechanism of most modern GNNs can be decomposed into two phases:

  • Message-passing: In this phase the node feature vectors of the graph are iteratively updated following a trainable local neighbourhood-aggregation scheme often referred to as message-passing. Each iteration delivers a set of updated node feature vectors which is then imagined to form a new “layer” on top of all the previous sets of node feature vectors.
  • Global graph pooling: After a sufficient number of layers has been computed, the updated node feature vectors are used to generate a single vectorial representation of the entire graph. This step is known as global graph readout or global graph pooling. Usually only the top layer (i.e. the final set of updated node feature vectors) is used for global graph pooling, but variations of this are possible that involve all computed graph layers and even the set of initial node feature vectors. Commonly employed global graph pooling strategies include taking the sum or the average of the node features in the top graph layer.
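The two phases above can be sketched in a few lines of NumPy. This is only a toy illustration under stated assumptions: the aggregation rule (sum over neighbours plus self, followed by a linear map and ReLU), the function names, and the random weights standing in for trained parameters are all hypothetical and not taken from any particular GNN.

```python
import numpy as np


def message_passing_step(adjacency, node_features, weight):
    """One simplified message-passing iteration (hypothetical scheme).

    Each node sums its own feature vector with those of its neighbours,
    then applies a linear map followed by a ReLU.
    """
    # Self-loops let every node keep its own features during aggregation.
    a_hat = adjacency + np.eye(adjacency.shape[0])
    aggregated = a_hat @ node_features
    return np.maximum(aggregated @ weight, 0.0)


def global_pool(node_features, mode="sum"):
    """Collapse a set of node feature vectors into one graph-level vector."""
    if mode == "sum":
        return node_features.sum(axis=0)
    if mode == "mean":
        return node_features.mean(axis=0)
    raise ValueError(f"unknown pooling mode: {mode}")


# Toy graph: a path of three nodes, 0 - 1 - 2.
adjacency = np.array([[0.0, 1.0, 0.0],
                      [1.0, 0.0, 1.0],
                      [0.0, 1.0, 0.0]])
initial_features = np.array([[1.0, 0.0],
                             [0.0, 1.0],
                             [1.0, 1.0]])

# Random matrices stand in for trainable weights (one per layer).
rng = np.random.default_rng(0)
weights = [rng.normal(size=(2, 2)) for _ in range(2)]

# Phase 1: message-passing, keeping every layer of node features.
layers = [initial_features]
for w in weights:
    layers.append(message_passing_step(adjacency, layers[-1], w))

# Phase 2: global pooling over the top layer only.
graph_vector_sum = global_pool(layers[-1], mode="sum")
graph_vector_mean = global_pool(layers[-1], mode="mean")
```

Keeping all layers in a list makes it straightforward to try the variations mentioned above, for example pooling over every layer instead of only the last one.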

While a lot of research attention has been focused on designing novel and more powerful message-passing schemes for GNNs, the global graph pooling step has often been treated with relative neglect. As mentioned in my previous post on the issues of GNNs, I believe this to be problematic. Naive global pooling methods (such as simply summing up all final node feature vectors) can potentially form dangerous information bottlenecks within the neural graph learning pipeline. In the worst case, such information bottlenecks pose the risk of largely cancelling out the information signal delivered by the message-passing step, no matter how sophisticated the message-passing scheme.
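A small, contrived NumPy example of such a bottleneck, with made-up feature values: two graphs whose final node feature vectors differ, yet sum pooling maps both to the same graph-level vector.

```python
import numpy as np

# Two hypothetical sets of final node feature vectors (one row per node).
graph_a = np.array([[1.0, -1.0],
                    [-1.0, 1.0]])
graph_b = np.array([[0.0, 0.0],
                    [0.0, 0.0]])

# Naive global pooling: sum over all nodes.
pooled_a = graph_a.sum(axis=0)
pooled_b = graph_b.sum(axis=0)

# The node-level information differs, but the pooled vectors coincide,
# so anything downstream of the pooling step cannot tell the graphs apart.
print(np.array_equal(graph_a, graph_b))    # False
print(np.array_equal(pooled_a, pooled_b))  # True
```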

Continue reading

Entering a Stable Relationship with your Neural Network

Over the past year, I have been working on building a graph-based paratope (antibody binding site) prediction tool – Paragraph. Fortunately, I have had moderate success with this and you can now check out the preprint of this work here.

However, for a long time, I struggled with a highly unstable network, where different random seeds yielded very different results. I believe this instability was largely due to the high class imbalance in my data – only ~10% of all residues in the Fv (variable region of the antibody) belong to the paratope.

I tried many different things in an attempt to stabilise my training, most of which failed. I will share all of these ideas with you though – successful or not – as what works for one person/network is never guaranteed to work for another. I hope that the ideas below give others facing similar issues something to try out. Where possible, I also provide some example hyperparameter values that could act as sensible starting points.

Continue reading

How to turn a SMILES string into a molecular graph for Pytorch Geometric

Despite some of their technical issues, graph neural networks (GNNs) are quickly being adopted as one of the state-of-the-art methods for molecular property prediction. The differentiable extraction of molecular features from low-level molecular graphs has become a viable (although not always superior) alternative to classical molecular representation techniques such as Morgan fingerprints and molecular descriptor vectors.

But molecular data usually comes in the sequential form of labeled SMILES strings. It is not obvious to beginners how best to transform a SMILES string into a structured molecular graph object that can be used as an input for a GNN. In this post, we show how to convert a SMILES string into a molecular graph object which can subsequently be used for graph-based machine learning. We do so within the framework of PyTorch Geometric, which is currently one of the best and most commonly used Python-based GNN libraries.

We divide our task into three high-level steps:

  1. We define a function that maps an RDKit atom object to a suitable atom feature vector.
  2. We define a function that maps an RDKit bond object to a suitable bond feature vector.
  3. We define a function that takes as its input a list of SMILES strings and associated labels and then uses the functions from steps 1 and 2 to create a list of labeled PyTorch Geometric graph objects as its output.
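As a rough idea of what steps 1 and 2 could look like, here is a minimal sketch. The feature choices, vocabularies, and function names are hypothetical and much smaller than anything one would use in practice. The functions only call accessor methods that RDKit atom and bond objects provide (`GetSymbol`, `GetDegree`, `GetIsAromatic`, `GetBondTypeAsDouble`, `GetIsConjugated`, `IsInRing`); the two stand-in classes exist purely so that the sketch runs without RDKit installed. Step 3 is omitted because it needs RDKit and PyTorch Geometric themselves.

```python
# Hypothetical, deliberately small vocabularies for illustration.
ATOM_SYMBOLS = ["C", "N", "O", "Unknown"]
BOND_ORDERS = [1.0, 1.5, 2.0, 3.0]  # single, aromatic, double, triple


def one_hot(value, permitted):
    """Return a one-hot list; all zeros if `value` is not in `permitted`."""
    return [float(value == p) for p in permitted]


def atom_features(atom):
    """Step 1: map an RDKit-style atom object to a feature vector."""
    symbol = atom.GetSymbol()
    if symbol not in ATOM_SYMBOLS:
        symbol = "Unknown"
    return (one_hot(symbol, ATOM_SYMBOLS)
            + [float(atom.GetDegree()), float(atom.GetIsAromatic())])


def bond_features(bond):
    """Step 2: map an RDKit-style bond object to a feature vector."""
    return (one_hot(bond.GetBondTypeAsDouble(), BOND_ORDERS)
            + [float(bond.GetIsConjugated()), float(bond.IsInRing())])


class StandInAtom:
    """Stand-in exposing the subset of the RDKit atom interface used above."""

    def __init__(self, symbol, degree, aromatic):
        self._symbol, self._degree, self._aromatic = symbol, degree, aromatic

    def GetSymbol(self):
        return self._symbol

    def GetDegree(self):
        return self._degree

    def GetIsAromatic(self):
        return self._aromatic


class StandInBond:
    """Stand-in exposing the subset of the RDKit bond interface used above."""

    def __init__(self, order, conjugated, in_ring):
        self._order, self._conjugated, self._in_ring = order, conjugated, in_ring

    def GetBondTypeAsDouble(self):
        return self._order

    def GetIsConjugated(self):
        return self._conjugated

    def IsInRing(self):
        return self._in_ring


# Usage: a nitrogen with three neighbours and a conjugated double bond.
nitrogen_vector = atom_features(StandInAtom("N", degree=3, aromatic=False))
double_bond_vector = bond_features(StandInBond(2.0, conjugated=True, in_ring=False))
```

With RDKit available, the same two functions could be applied to the atoms and bonds of a molecule parsed from a SMILES string, and the resulting vectors stacked into the node and edge feature matrices of a graph object.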
Continue reading

Issues with graph neural networks: the cracks are where the light shines through

Deep convolutional neural networks have led to astonishing breakthroughs in the area of computer vision in recent years. The reason for the extraordinary performance of convolutional architectures in the image domain is their strong ability to extract informative high-level features from visual data. For prediction tasks on images, this has led to superhuman performance in a variety of applications and to an almost universal shift from classical feature engineering to differentiable feature learning.

Unfortunately, the picture is not quite as rosy yet in the area of molecular machine learning. Feature learning techniques which operate directly on raw molecular graphs without intermediate feature-engineering steps have only emerged in the last few years in the form of graph neural networks (GNNs). GNNs, however, still have not managed to definitively outcompete and replace more classical non-differentiable molecular representation methods such as extended-connectivity fingerprints (ECFPs). There is an increasing awareness in the computational chemistry community that GNNs have not quite lived up to the initial hype and still suffer from a number of technical limitations.

Continue reading