Universal graph pooling for GNNs

Graph neural networks (GNNs) have quickly become one of the most important tools in computational chemistry and molecular machine learning. GNNs are a type of deep learning architecture designed for the adaptive extraction of vectorial features directly from graph-shaped input data, such as low-level molecular graphs. The feature-extraction mechanism of most modern GNNs can be decomposed into two phases:

  • Message-passing: In this phase the node feature vectors of the graph are iteratively updated following a trainable local neighbourhood-aggregation scheme, often referred to as message-passing. Each iteration delivers a set of updated node feature vectors, which can be thought of as a new “layer” on top of all the previous sets of node feature vectors.
  • Global graph pooling: After a sufficient number of layers has been computed, the updated node feature vectors are used to generate a single vectorial representation of the entire graph. This step is known as global graph readout or global graph pooling. Usually only the top layer (i.e. the final set of updated node feature vectors) is used for global graph pooling, although variations exist that involve all computed graph layers and even the set of initial node feature vectors. Commonly employed global graph pooling strategies include taking the sum or the average of the node features in the top graph layer (see the short sketch after this list).
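
As a minimal sketch of these two common strategies in plain PyTorch (the tensor shapes below are arbitrary and only serve as an illustration):

```python
import torch

# node feature matrix of a single molecular graph: (num_nodes, feat_dim),
# e.g. the top-layer output of the message-passing phase
h_top = torch.randn(17, 64)

graph_vec_sum = h_top.sum(dim=0)    # sum pooling  -> graph vector of shape (64,)
graph_vec_mean = h_top.mean(dim=0)  # mean pooling -> graph vector of shape (64,)
```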

While a lot of research attention has been focused on designing novel and more powerful message-passing schemes for GNNs, the global graph pooling step has often been treated with relative neglect. As mentioned in my previous post on the issues of GNNs, I believe this to be problematic. Naive global pooling methods (such as simply summing up all final node feature vectors) can potentially form dangerous information bottlenecks within the neural graph learning pipeline. In the worst case, such information bottlenecks pose the risk of largely cancelling out the information signal delivered by the message-passing step, no matter how sophisticated the message-passing scheme.

In this post I would like to draw attention to a 2019 paper from Navarin, Van Tran and Sperduti titled Universal Readout for Graph Convolutional Neural Networks [1]. This paper immediately caught my attention; the authors took the issue of a potential information-bottleneck during global graph pooling seriously and presented an elegant solution approach, along with impressive theoretical underpinnings derived from results about deep learning on mathematical sets presented in [2].

There are three basic properties we reasonably want from a global pooling function:

  • (1) Arbitrary input length: we want the function to be defined for an arbitrary number of input node feature vectors (as input graphs can have different sizes).
  • (2) Permutation-invariance: we want the function to deliver the same result for all possible (arbitrary) orderings of input node feature vectors.
  • (3) Continuity: we want the function to be continuous in the mathematical sense, i.e. we want to ensure that a continuous variation of the input node feature vectors only ever leads to a continuous variation (and never to a sudden jump) of the output vector.
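
As a toy sanity check of property (2) in plain PyTorch (arbitrary numbers, purely for illustration): summation yields the same graph vector for any ordering of the node feature vectors, whereas e.g. concatenating them does not.

```python
import torch

h = torch.tensor([[1., 0.], [0., 2.], [3., 1.]])  # three node feature vectors
perm = torch.tensor([2, 0, 1])                    # an arbitrary reordering of the nodes

print(torch.equal(h.sum(dim=0), h[perm].sum(dim=0)))  # True:  sum pooling is permutation-invariant
print(torch.equal(h.flatten(), h[perm].flatten()))    # False: concatenating node features is not
```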

Navarin, Van Tran and Sperduti managed to describe a concrete trainable neural function that provably can approximate arbitrarily well every possible function fulfilling requirements (1) – (3). In other words, they found a trainable universal pooling function that can approximate every reasonable pooling function arbitrarily well.

Their proposed pooling function is simple to understand and implement. It consists of a neural network transformation of each node feature vector, followed by a summation, followed by another neural network transformation of the sum:
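
Written out (with notation chosen by me for this post: $h_v$ for the feature vector of node $v$, and $f$ and $\varphi$ for the two trainable networks), the readout takes the form

$$ h_G \,=\, \varphi\Big( \sum_{v \in G} f(h_v) \Big), $$

i.e. $f$ is applied to every node feature vector independently, the results are summed over all nodes of the graph, and $\varphi$ maps the sum to the final graph-level representation $h_G$.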

The involved neural networks each need to have at least one hidden layer (i.e. an input layer, a hidden layer, and an output layer), so that the universal approximation theorem guarantees that they can approximate arbitrary continuous functions.
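
To make this concrete, here is a minimal PyTorch sketch of such a readout (the class name, layer widths and the batching convention via a per-node graph index are my own illustrative choices, not taken from the authors' code):

```python
import torch
import torch.nn as nn


class UniversalReadout(nn.Module):
    """Universal graph readout: phi( sum over nodes of f(h_v) ), with f and phi small MLPs."""

    def __init__(self, node_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        # f: applied to every node feature vector independently
        self.f = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # phi: applied to the per-graph sum of the transformed node vectors
        self.phi = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        # h:     (num_nodes, node_dim) node features from the top GNN layer
        # batch: (num_nodes,) graph index of each node in the batch
        z = self.f(h)                                     # transform each node vector
        num_graphs = int(batch.max().item()) + 1
        pooled = torch.zeros(num_graphs, z.size(1), dtype=z.dtype, device=z.device)
        pooled.index_add_(0, batch, z)                    # per-graph summation
        return self.phi(pooled)                           # transform the sum
```

Both f and φ in this sketch contain exactly one hidden layer, i.e. the minimal depth required for the universal-approximation argument; in practice their widths and depths are hyperparameters like any others.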

The approximation power of this universal pooling function allows a GNN architecture, at least in theory, to always learn the global pooling operation that is most appropriate for the prediction task at hand. Note in particular that this would eliminate the risk of harming model performance by manually choosing an unsuitable pooling operation. In practice, the actual approximation power of this universal pooling function will of course strongly depend on a variety of factors, such as the specific architectures of the involved neural networks and the available amount of training data.

In spite of the compelling theoretical foundations of the universal pooling operator and its simplicity and elegance, it seems that it has not been used in a large number of publications yet. The paper by Navarin et al. has only been cited 22 times as of 5th of October 2022 according to Google Scholar. The experimental results presented in the paper are encouraging. However, I would expect that the full advantage of universal graph pooling relative to commonly used pooling operations (such as summation or averaging of node feature vectors) will become most apparent in settings where large amounts of data are available. This is why I intend to employ universal graph pooling in some of my own future experiments where I will implement my ideas in the realm of self-supervised molecular representation learning. There I will have the luxury of being able to train state-of-the-art GNN-models on literally millions of (unlabelled) molecular graphs. I am curious to see the results.

[1] Navarin, N., Van Tran, D., & Sperduti, A. (2019, July). Universal readout for graph convolutional neural networks. In 2019 International Joint Conference on Neural Networks (IJCNN) (pp. 1-7). IEEE.

[2] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., & Smola, A. J. (2017). Deep sets. Advances in neural information processing systems, 30.
