Transformers are Graph Neural Networks!

Overview

🚀 Gated Graph Transformers

Gated Graph Transformers for graph-level property prediction, i.e. graph classification and regression.

Associated article: Transformers are Graph Neural Networks, by Chaitanya K. Joshi, published with The Gradient.

This repository is a continuously updated personal project to build intuitions about and track progress in Graph Representation Learning research. I aim to develop the most universal and powerful model which unifies state-of-the-art architectures from Graph Neural Networks and Transformers, without incorporating domain-specific tricks.

Gated Graph Transformer

Key Architectural Ideas

🤖 Deep, Residual Transformer Backbone

  • As the backbone architecture, I borrow the two-sub-layered, pre-normalization variant of Transformer encoders that has emerged as the standard in the NLP community, e.g. GPT-3. Each Transformer block consists of a message-passing sub-layer followed by a node-wise feedforward sub-layer (a minimal sketch follows this list). The graph convolution used for message passing is described later.
  • The feedforward sub-layer projects node embeddings to an absurdly large dimension, passes them through a non-linear activation function, does dropout, and reduces back to the original embedding dimension.
  • The Transformer backbone enables training very deep and extremely overparameterized models. Overparameterization is important for performance in NLP and other combinatorially large domains, but was previously not possible for GNNs trained on small graph classification datasets. Coupled with unique node positional encodings (described later) and the feedforward sub-layer, overparameterization ensures that our GNN is Turing Universal (based on A. Loukas's recent insightful work, including this paper).
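
A minimal PyTorch sketch of this pre-norm block structure, assuming a generic message-passing layer with signature conv(graph, h) -> h; the class and parameter names here are illustrative, not the repository's exact implementation:

import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    # Pre-norm Transformer block: a message-passing sub-layer followed by a
    # node-wise feedforward sub-layer, each wrapped in LayerNorm + a residual connection.
    def __init__(self, conv, hidden_dim, ffn_mult=4, dropout=0.1):
        super().__init__()
        self.conv = conv                      # any graph convolution, e.g. Gated GraphConv
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_mult * hidden_dim),  # expand to a much larger dimension
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * hidden_dim, hidden_dim),  # project back to the embedding size
        )

    def forward(self, graph, h):
        h = h + self.conv(graph, self.norm1(h))  # pre-norm residual: message passing
        h = h + self.ffn(self.norm2(h))          # pre-norm residual: node-wise feedforward
        return h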

✉️ Anisotropic Graph Convolutions


Source: 'Deep Parametric Continuous Convolutional Neural Networks', Wang et al., 2018

  • As the graph convolution layer, I use the Gated Graph Convolution with a dense attention mechanism, which we found to be the best performing graph convolution in Benchmarking GNNs. Intuitively, Gated GraphConv generalizes directional CNN filters for 2D images to arbitrary graphs by learning a weighted aggregation over the local neighbors of each node (see the sketch after this list). It upgrades the node-to-node attention mechanism from GATs and MoNet (i.e. one attention weight per node pair) to dense feature-to-feature attention (i.e. d attention weights for each pair of d-dimensional node embeddings).
  • Another intuitive motivation for the Gated GraphConv is as a learnable directional diffusion process over the graph, or as a coupled PDE over node and edge features in the graph. Gated GraphConv makes the diffusion process/neighborhood aggregation anisotropic or directional, countering oversmoothing/oversquashing of features and enabling deeper models.
  • This graph convolution was originally proposed as a sentence encoder for NLP and further developed at NTU for molecule generation and combinatorial optimization. Evidently, I am partial to this idea. At the same time, it is worth noting that anisotropic local aggregations and generalizations of directed CNN filters have demonstrated strong performance across a myriad of applications, including 3D point clouds, drug discovery, material science, and programming languages.
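
For intuition, here is a dense-adjacency PyTorch sketch of such a gated, per-feature aggregation. It omits the edge features and batch normalization of the full Gated GraphConv and uses hypothetical names, so treat it as an illustration of the idea rather than the implementation in this repository:

import torch
import torch.nn as nn

class GatedGraphConvSketch(nn.Module):
    # Each edge (i, j) receives a d-dimensional gate (one weight per feature channel),
    # in contrast to the single scalar attention weight per edge used by GAT/MoNet.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.U = nn.Linear(dim, dim)  # self-connection
        self.V = nn.Linear(dim, dim)  # neighbour messages
        self.A = nn.Linear(dim, dim)  # gate term from the destination node
        self.B = nn.Linear(dim, dim)  # gate term from the source node
        self.eps = eps

    def forward(self, h, adj):
        # h: (N, d) node embeddings; adj: (N, N) binary adjacency matrix (small graphs only).
        gates = torch.sigmoid(self.A(h)[:, None, :] + self.B(h)[None, :, :])  # (N, N, d) edge gates
        gates = gates * adj[:, :, None]                                       # zero out non-edges
        msgs = gates * self.V(h)[None, :, :]                                  # gated messages from j to i
        agg = msgs.sum(dim=1) / (gates.sum(dim=1) + self.eps)                 # anisotropic aggregation
        return h + torch.relu(self.U(h) + agg)                                # residual node update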

🔄 Graph Positional Encodings


Source: 'Geometric Deep Learning: Going beyond Euclidean Data', Bronstein et al., 2017

  • I use the top-k non-trivial Laplacian Eigenvectors as unique node identifiers to inject structural/positional priors into the Transformer backbone. Laplacian Eigenvectors are a generalization of sinusoidal positional encodings from the original Transformers, and were concurrently proposed in the Benchmarking GNNs, EigenGNNs, and GCC papers.
  • Randomly flipping the sign of the Laplacian Eigenvectors during training (eigenvectors are only defined up to sign) can be seen as an additional data augmentation or regularization technique, helping delay overfitting to training patterns; a short sketch of both steps follows this list. Going further, the Directional Graph Networks paper presents a more principled approach for using Laplacian Eigenvectors.
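
A compact NumPy/PyTorch sketch of both steps, assuming a dense adjacency matrix and the symmetric normalized Laplacian; function names and details are illustrative rather than the repository's exact code:

import numpy as np
import torch

def laplacian_positional_encoding(adj, k):
    # Top-k non-trivial eigenvectors of the symmetric normalized Laplacian,
    # used as unique node positional encodings.
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.zeros_like(deg)
    d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    lap = np.eye(adj.shape[0]) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    eigval, eigvec = np.linalg.eigh(lap)                        # eigenvalues in ascending order
    return torch.tensor(eigvec[:, 1:k + 1], dtype=torch.float)  # skip the trivial first eigenvector

def random_sign_flip(pos_enc):
    # Eigenvectors are only defined up to sign: flip each one randomly during training.
    signs = (torch.randint(0, 2, (pos_enc.shape[1],)) * 2 - 1).to(pos_enc.dtype)
    return pos_enc * signs[None, :]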

Some ideas still in the pipeline include:

  • Graph-specific Normalization - Originally motivated in Benchmarking GNNs as 'graph size normalization', there have been several subsequent graph-specific normalization techniques such as GraphNorm and MessageNorm, aiming to replace or augment standard Batch Normalization. Intuitively, there is room for improvement as BatchNorm flattens mini-batches of graphs instead of accounting for the underlying graph structure.

  • Theoretically Expressive Aggregation - There are several exciting ideas aiming to bridge the gap between theoretical expressive power, computational feasibility, and generalization capacity for GNNs: PNA-style multi-head aggregation and scaling, generalized aggregators from DeeperGCNs, pre-computing structural motifs as in GSN, etc.

  • Virtual Node and Low Rank Global Attention - After the message-passing step, the virtual node trick adds messages to and from a virtual/super node connected to all graph nodes (see the sketch after this list). LRGA does something similar, with additional theoretical motivation. Intuitively, these techniques enable modelling long-range or latent interactions in graphs and counter the oversquashing problem in deeper networks.

  • General Purpose Pre-training - It isn't truly a Transformer unless it's pre-trained on hundreds of GPUs for thousands of hours...but general purpose pre-training for graph representation learning remains an open question!
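
As a rough illustration of the virtual node idea, the hypothetical sketch below sum-pools the node embeddings of each graph into a per-graph super node, updates it with a small MLP, and broadcasts its state back to every node; this sketches the general trick rather than a planned implementation:

import torch
import torch.nn as nn

class VirtualNodeSketch(nn.Module):
    # One virtual/super node per graph, connected to all of that graph's nodes.
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, h, vn, batch_index):
        # h: (N, d) node embeddings, vn: (B, d) virtual node states (one per graph),
        # batch_index: (N,) long tensor giving the graph id of each node.
        pooled = torch.zeros_like(vn).index_add_(0, batch_index, h)  # sum-pool nodes per graph
        vn = self.mlp(vn + pooled)                                   # update the virtual node
        return h + vn[batch_index], vn                               # broadcast back to all nodes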

Installation and Usage

# Create new Anaconda environment
conda create -n new-env python=3.7
conda activate new-env
# Install PyTorch 1.6 for CUDA 10.1
conda install pytorch=1.6 cudatoolkit=10.1 -c pytorch
# Install DGL for CUDA 10.1
conda install -c dglteam dgl-cuda10.1
# Install other dependencies
conda install tqdm scikit-learn pandas urllib3 tensorboard
pip install -U ogb

# Train GNNs on ogbg-mol* datasets
python main_mol.py --dataset [ogbg-molhiv/ogbg-molpcba] --gnn [gated-gcn/gcn/mlp]

# Prepare submission for OGB leaderboards
bash scripts/ogbg-mol*.sh

# Collate results for submission
python submit.py --dataset [ogbg-molhiv/ogbg-molpcba] --expt [path-to-logs]

Note: The code was tested on Ubuntu 16.04, using Python 3.7, PyTorch 1.6 and CUDA 10.1.

Citation

@article{joshi2020transformers,
  author = {Joshi, Chaitanya K},
  title = {Transformers are Graph Neural Networks},
  journal = {The Gradient},
  year = {2020},
  howpublished = {\url{https://thegradient.pub/transformers-are-graph-neural-networks/}},
}