Edge-Augmented Graph Transformer

Introduction

This is the official implementation of the Edge-augmented Graph Transformer (EGT), described in https://arxiv.org/abs/2108.03348, which augments the Transformer architecture with residual edge channels. The resulting architecture can directly process graph-structured data and achieves good results on the supervised graph-learning tasks presented by Dwivedi et al. It also performs well on the large-scale PCQM4M-LSC dataset (0.1263 MAE on the validation set). EGT beats convolutional/message-passing graph neural networks on a wide range of supervised tasks, which demonstrates that convolutional aggregation is not an essential inductive bias for graphs.

Requirements

  • python >= 3.7
  • tensorflow >= 2.1.0
  • h5py >= 2.8.0
  • numpy >= 1.18.4
  • scikit-learn >= 0.22.1

Download the Datasets

For our experiments, we converted the datasets to HDF5 format for the convenience of using them without any specific library. Only the h5py library is required. The datasets can be downloaded from -

Or you can simply run the provided bash scripts download_medium_scale_datasets.sh and download_large_scale_datasets.sh. By default, the datasets are placed in the datasets directory.

Run Training and Evaluations

You must create a JSON config file containing the configuration of a model, along with its training and evaluation configurations. The same config file is used for both training and evaluation.

  • To run training: python run_training.py <config_file.json>
  • To end training (prematurely): python end_training.py <config_file.json>
  • To perform evaluations: python do_evaluations.py <config_file.json>

Config files for the main results presented in the paper are contained in the configs/main directory, whereas configurations for the ablation study are contained in the configs/ablation directory. The paths and names of the files are self-explanatory.
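
As a minimal sketch of what such a config file might look like, the snippet below builds one with Python's json module. The key names come from the Config File section of this README; the file name my_config.json, the helper name write_config, and every value other than the zinc.eig scheme name are illustrative assumptions, not recommended settings.

```python
import json

# Illustrative values only; "scheme" is the one required key.
config = {
    "scheme": "zinc.eig",         # <dataset_name>.<positional_encoding>
    "model_name": "egt_example",  # hypothetical identifier
    "batch_size": 128,            # assumed value
    "num_epochs": 100,            # assumed value
}

def write_config(path, cfg):
    """Write the config dictionary to a JSON file and return the path."""
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)
    return path

# Show the JSON that would be written.
print(json.dumps(config, indent=2))

# Typical usage (writes a file, so it is left commented out here):
#   write_config("my_config.json", config)
#   python run_training.py my_config.json
```

Any configuration left out of the file falls back to its default value.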

More About Training and Evaluations

Once training starts, a model folder is created in the models directory, under the specified dataset name. This folder contains a copy of the input config file, for convenience when resuming training or evaluation. It also contains a config.json file that records all configs used for the training, including default values that were not specified. Training is checkpointed every epoch. If training is interrupted, you can resume it by running run_training.py with the config.json file again.

If you wish to finalize training midway, stop training and run the end_training.py script with the config.json file to save the model weights.

After training, you can run the do_evaluations.py script with the same config file to perform evaluations. Alongside being printed to stdout, results will be saved in the predictions directory, under the model directory.
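
The directory layout described above can be sketched as follows. This assumes the default save_path of models/<dataset_name>/<model_name> documented in the Config File section, and that the dataset name is the part of scheme before the dot. The helper name model_paths is hypothetical and is not part of the repository.

```python
from pathlib import PurePosixPath

def model_paths(scheme, model_name, save_path=None):
    """Return the expected model directory, saved config and predictions paths."""
    # scheme has the form <dataset_name>.<positional_encoding>
    dataset_name, _, _encoding = scheme.partition(".")
    if save_path:
        model_dir = PurePosixPath(save_path)
    else:
        model_dir = PurePosixPath("models") / dataset_name / model_name
    return {
        "model_dir": str(model_dir),
        "config": str(model_dir / "config.json"),
        "predictions": str(model_dir / "predictions"),
    }

paths = model_paths("zinc.eig", "egt_example")
print(paths["config"])  # models/zinc/egt_example/config.json
```

The config path printed here is the file you would pass to run_training.py to resume an interrupted run.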

Config File

The config file can contain many different configurations. However, the only required configuration is scheme, which specifies the training scheme. Any configuration that is not specified takes its default value. Here are some of the commonly used configurations:

scheme: Used to specify the training scheme. It has the format <dataset_name>.<positional_encoding>, for example cifar10.svd or zinc.eig. If no encoding is to be used, it can be something like pcqm4m.mat. For a full list, explore the lib/training/schemes directory.

dataset_path: If the datasets are contained in the default location in the datasets directory, this config need not be specified. Otherwise you have to point it towards the <dataset_name>.h5 file.

model_name: Serves as an identifier for the model, also specifies default path of the model directory, weight files etc.

save_path: The training process will create a model directory containing the logs, checkpoints, configs, model summary and predictions/evaluations. By default it creates a folder at models/<dataset_name>/<model_name> but it can be changed via this config.

cache_dir: The first time training or evaluation is run, the data is cached in a TensorFlow cache format. The default path is data_cache/<dataset_name>/<positional_encoding>, but it can be changed via this config.

distributed: In a multi-GPU setting, set this to True for distributed training.

batch_size: Batch size.

num_epochs: Maximum number of epochs.

initial_lr: Initial learning rate. In case of warmup it is the maximum learning rate.

rlr_factor: Reduce LR on plateau factor. Setting it to a value >= 1.0 turns off Reduce LR.

rlr_patience: Reduce LR patience, i.e. the number of epochs after which LR is reduced if validation loss doesn't improve.

min_lr_factor: The minimum LR, expressed as a fraction of the initial LR. Default is 0.01.

model_height: The number of layers L.

model_width: The dimensionality of the node channels d_h.

edge_width: The dimensionality of the edge channels d_e.

num_heads: The number of attention heads. Default is 8.

ffn_multiplier: FFN multiplier for both channels. Default is 2.0.

virtual_nodes: Number of virtual nodes. A value of 0 (default) results in global average pooling being used instead of virtual nodes.

upto_hop: Clipping value of the input distance matrix. A value of 1 (default) results in the adjacency matrix being used as the input structural matrix.

mlp_layers: Dimensionality of the final MLP layers, specified as a list of factors with respect to d_h. Default is [0.5, 0.25].

gate_attention: Set this to False to get the ungated EGT variant (EGT-U).

dropout: Dropout rate for both channels. Default is 0.

edge_dropout: If specified, applies a different dropout rate to the edge channels.

edge_channel_type: Used to create ablated variants of EGT. A value of "residual" (default) implies pure/full EGT. "constrained" implies EGT-constrained. "bias" implies EGT-simple.

warmup_steps: If specified, performs a linear learning rate warmup for the specified number of gradient update steps.

total_steps: If specified, performs a cosine annealing after warmup, so that the model is trained for the specified number of steps.
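
One common way to combine a linear warmup with cosine annealing is sketched below. It assumes the LR rises linearly from 0 to initial_lr over warmup_steps and then follows a half cosine down to initial_lr * min_lr_factor at total_steps. The exact formula used by the repository may differ, and the function name lr_at_step is hypothetical.

```python
import math

def lr_at_step(step, initial_lr, warmup_steps, total_steps, min_lr_factor=0.01):
    """Learning rate at a given gradient update step (illustrative schedule)."""
    min_lr = initial_lr * min_lr_factor
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from 0 up to the maximum LR (initial_lr).
        return initial_lr * step / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    span = max(total_steps - warmup_steps, 1)
    progress = min(max((step - warmup_steps) / span, 0.0), 1.0)
    # Cosine decay from initial_lr down to min_lr (assumed floor).
    return min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Print the LR at a few representative steps.
for s in (0, 500, 1000, 5500, 10000):
    print(s, lr_at_step(s, initial_lr=1e-3, warmup_steps=1000, total_steps=10000))
```

Under these assumptions, the LR peaks at initial_lr exactly when warmup ends, which matches the note that initial_lr is the maximum learning rate when warmup is used.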

[For SVD-based encodings]:

use_svd: Turning this off (False) would result in no positional encoding being used.

sel_svd_features: Rank of the SVD encodings r.

random_neg: Augment SVD encodings by random negation.

[For eigenvector-based encodings]:

use_eig: Turning this off (False) would result in no positional encoding being used.

sel_eig_features: Number of eigenvectors.

[For Distance prediction Objective (DO)]:

distance_target: Predict distance up to the specified hop, nu.

distance_loss: Factor by which to multiply the distance prediction loss, kappa.
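
The default-filling behavior described in this section can be illustrated with a short sketch. The defaults below are only those stated in this README. The function names resolve_config and mlp_widths are hypothetical, the repository's own config handling may work differently, and rounding the MLP widths to integers is an assumption.

```python
# Defaults as documented in this README (not an exhaustive list).
DOCUMENTED_DEFAULTS = {
    "min_lr_factor": 0.01,
    "num_heads": 8,
    "ffn_multiplier": 2.0,
    "virtual_nodes": 0,
    "upto_hop": 1,
    "mlp_layers": [0.5, 0.25],
    "dropout": 0,
    "edge_channel_type": "residual",
}

def resolve_config(user_config):
    """Overlay user settings on the documented defaults; 'scheme' is required."""
    if "scheme" not in user_config:
        raise ValueError("the 'scheme' configuration is required")
    resolved = dict(DOCUMENTED_DEFAULTS)
    resolved.update(user_config)
    return resolved

def mlp_widths(model_width, mlp_layers):
    """Convert mlp_layers factors into layer widths relative to d_h (assumes rounding)."""
    return [int(round(model_width * f)) for f in mlp_layers]

cfg = resolve_config({"scheme": "pcqm4m.mat", "model_width": 64, "dropout": 0.1})
print(cfg["num_heads"], cfg["dropout"])                    # 8 0.1
print(mlp_widths(cfg["model_width"], cfg["mlp_layers"]))   # [32, 16]
```

With model_width set to 64, the default mlp_layers of [0.5, 0.25] corresponds to final MLP layers of width 32 and 16.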

Creation of the HDF5 Datasets from Scratch

We included two Jupyter notebooks to demonstrate how the HDF5 datasets are created:

  • For the medium-scale datasets, see create_hdf_benchmarking_datasets.ipynb. You will need the pytorch, ogb==1.1.1 and dgl==0.4.2 libraries to run the notebook. The notebook can also be run on Google Colaboratory.
  • For the large-scale pcqm4m dataset, see create_hdf_pcqm4m.ipynb. You will need pytorch, ogb>=1.3.0 and rdkit>=2019.03.1 to run the notebook.

Python Environment

The Anaconda environment in which our experiments were conducted is specified in the environment.yml file.

Citation

Please cite the following paper if you find the code useful:

@article{hussain2021edge,
  title={Edge-augmented Graph Transformers: Global Self-attention is Enough for Graphs},
  author={Hussain, Md Shamim and Zaki, Mohammed J and Subramanian, Dharmashankar},
  journal={arXiv preprint arXiv:2108.03348},
  year={2021}
}
Owner
Md Shamim Hussain
Md Shamim Hussain is a Ph.D. student in Computer Science at Rensselaer Polytechnic Institute, NY. He got his B.Sc. and M.Sc. in EEE from BUET, Dhaka.