A PyTorch implementation of "Capsule Graph Neural Network" (ICLR 2019).

Overview

CapsGNN

PWC codebeat badge repo sizebenedekrozemberczki

A PyTorch implementation of Capsule Graph Neural Network (ICLR 2019).

Abstract

The high-quality node embeddings learned from the Graph Neural Networks (GNNs) have been applied to a wide range of node-based applications and some of them have achieved state-of-the-art (SOTA) performance. However, when applying node embeddings learned from GNNs to generate graph embeddings, the scalar node representation may not suffice to preserve the node/graph properties efficiently, resulting in sub-optimal graph embeddings. Inspired by the Capsule Neural Network (CapsNet), we propose the Capsule Graph Neural Network (CapsGNN), which adopts the concept of capsules to address the weakness in existing GNN-based graph embeddings algorithms. By extracting node features in the form of capsules, routing mechanism can be utilized to capture important information at the graph level. As a result, our model generates multiple embeddings for each graph to capture graph properties from different aspects. The attention module incorporated in CapsGNN is used to tackle graphs with various sizes which also enables the model to focus on critical parts of the graphs. Our extensive evaluations with 10 graph-structured datasets demonstrate that CapsGNN has a powerful mechanism that operates to capture macroscopic properties of the whole graph by data-driven. It outperforms other SOTA techniques on several graph classification tasks, by virtue of the new instrument.

This repository provides a PyTorch implementation of CapsGNN as described in the paper:

Capsule Graph Neural Network. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Paper]

The core Capsule Neural Network implementation adapted is available [here].

Requirements

The codebase is implemented in Python 3.5.2. package versions used for development are just below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0

Datasets

The code takes graphs for training from an input folder where each graph is stored as a JSON. Graphs used for testing are also stored as JSON files. Every node id and node label has to be indexed from 0. Keys of dictionaries are stored strings in order to make JSON serialization possible.

Every JSON file has the following key-value structure:

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
 "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
 "target": 1}

The **edges** key has an edge list value which descibes the connectivity structure. The **labels** key has labels for each node which are stored as a dictionary -- within this nested dictionary labels are values, node identifiers are keys. The **target** key has an integer value which is the class membership.

Outputs

The predictions are saved in the `output/` directory. Each embedding has a header and a column with the graph identifiers. Finally, the predictions are sorted by the identifier column.

Options

Training a CapsGNN model is handled by the `src/main.py` script which provides the following command line arguments.

Input and output options

  --training-graphs   STR    Training graphs folder.      Default is `dataset/train/`.
  --testing-graphs    STR    Testing graphs folder.       Default is `dataset/test/`.
  --prediction-path   STR    Output predictions file.     Default is `output/watts_predictions.csv`.

Model options

  --epochs                      INT     Number of epochs.                  Default is 100.
  --batch-size                  INT     Number fo graphs per batch.        Default is 32.
  --gcn-filters                 INT     Number of filters in GCNs.         Default is 20.
  --gcn-layers                  INT     Number of GCNs chained together.   Default is 2.
  --inner-attention-dimension   INT     Number of neurons in attention.    Default is 20.  
  --capsule-dimensions          INT     Number of capsule neurons.         Default is 8.
  --number-of-capsules          INT     Number of capsules in layer.       Default is 8.
  --weight-decay                FLOAT   Weight decay of Adam.              Defatuls is 10^-6.
  --lambd                       FLOAT   Regularization parameter.          Default is 0.5.
  --theta                       FLOAT   Reconstruction loss weight.        Default is 0.1.
  --learning-rate               FLOAT   Adam learning rate.                Default is 0.01.

Examples

The following commands learn a model and save the predictions. Training a model on the default dataset:

$ python src/main.py

Training a CapsGNNN model for a 100 epochs.

$ python src/main.py --epochs 100

Changing the batch size.

$ python src/main.py --batch-size 128

License

Comments
  •  Coordinate Addition module & Routing

    Coordinate Addition module & Routing

    Hi, thanks for your codes of GapsGNN. And I have some questions about Coordinate Addition module and Routing.

    1. Do you use Coordinate Addition module in this codes?
    2. In /src/layers.py, line 137 : c_ij = torch.nn.functional.softmax(b_ij, dim=0) . At this time, b_ij.size(0) == 1, why use dim =0 ?

    Thanks again.

    opened by S-rz 4
  • Something about reshape

    Something about reshape

    Hi @benedekrozemberczki ! Thank you for your work!

    I have a question at line 61 and 62 of CapsGNN/src/capsgnn.py

    hidden_representations = torch.cat(tuple(hidden_representations)) hidden_representations = hidden_representations.view(1, self.args.gcn_layers, self.args.gcn_filters,-1)

    Why you directly reshape L*N,D to 1,L,D,N instead of using permutation after reshape, e.g

    hidden_representations = hidden_representations.view(1, self.args.gcn_layers, -1,self.args.gcn_filters).permute(0,1,3,2)

    Thank you for your help!

    opened by yanx27 4
  • Reproduce Issues

    Reproduce Issues

    Hi, thanks for your PyTorch codes of GapsGNN. I try to run the codes on NCI, DD, and other graph classification datasets, but it doesn't work (For example, training loss converges to 2.0, and test acc is about 50% on NCI1 after several iterations.) How should I do if I want to run these codes on NCI, DD and etc? Thanks again.

    opened by veophi 1
  • D&D dataset

    D&D dataset

    I notice some datasets in your paper such as D&D dataset. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

    opened by try-to-anything 1
  • Other datasets

    Other datasets

    I notice some datasets in your paper such as RE-M5K and RE-M12K. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

    opened by HongyangGao 1
  • Not able to install torch-scatter with torch 0.4.1

    Not able to install torch-scatter with torch 0.4.1

    Hello,

    Thanks for sharing the implementation.

    While I'm try to run your code I get some error for installing the environment. I have torch 0.4.1, but not able to install torch-scatter.Got the following error: fatal error: torch/extension.h: No such file or directory

    But I can successfully install them for torch 1.0.

    Is your code working for torch 1.0? Or how to install torch-scatter for torch 0.4.1?

    Details:

    $ pip list Package Version


    backcall 0.1.0
    certifi 2018.8.24
    .... torch 0.4.1.post2 torch-geometric 1.1.1
    torchfile 0.1.0
    torchvision 0.2.1
    tornado 5.1
    tqdm 4.31.1
    traitlets 4.3.2
    urllib3 1.23
    visdom 0.1.8.5
    vispy 0.5.3
    .... ....

    $pip install torch-scatter

    opened by jkuh626 1
  • how to repeat your expriments?

    how to repeat your expriments?

    Enumerating feature and target values.

    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 14754.82it/s]

    Training started.

    Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00, 1.90it/s] CapsGNN (Loss=0.7279): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.92it/s]

    Scoring.

    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 128.47it/s]

    Accuracy: 0.3333

    Accuracy is too small

    opened by robotzheng 1
  • default input dir for graphs is

    default input dir for graphs is "input"

    The README mentions the default train and test graphs to be in dataset/train and dataset/test, whereas they are in input/train and input/test respectively. The param_parser.py has the correct default paths nevertheless.

    opened by Utkarsh87 0
Releases(v_0001)
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.
Benedek Rozemberczki
Generic image compressor for machine learning. Pytorch code for our paper "Lossy compression for lossless prediction".

Lossy Compression for Lossless Prediction Using: Training: This repostiory contains our implementation of the paper: Lossy Compression for Lossless Pr

Yann Dubois 84 Jan 02, 2023
RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation

Multipath RefineNet A MATLAB based framework for semantic image segmentation and general dense prediction tasks on images. This is the source code for

Guosheng Lin 575 Dec 06, 2022
Gated-Shape CNN for Semantic Segmentation (ICCV 2019)

GSCNN This is the official code for: Gated-SCNN: Gated Shape CNNs for Semantic Segmentation Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler

859 Dec 26, 2022
My solutions for Stanford University course CS224W: Machine Learning with Graphs Fall 2021 colabs (GNN, GAT, GraphSAGE, GCN)

machine-learning-with-graphs My solutions for Stanford University course CS224W: Machine Learning with Graphs Fall 2021 colabs Course materials can be

Marko Njegomir 7 Dec 14, 2022
Official Implementation and Dataset of "PPR10K: A Large-Scale Portrait Photo Retouching Dataset with Human-Region Mask and Group-Level Consistency", CVPR 2021

Portrait Photo Retouching with PPR10K Paper | Supplementary Material PPR10K: A Large-Scale Portrait Photo Retouching Dataset with Human-Region Mask an

184 Dec 11, 2022
PSPNet in Chainer

PSPNet This is an unofficial implementation of Pyramid Scene Parsing Network (PSPNet) in Chainer. Training Requirement Python 3.4.4+ Chainer 3.0.0b1+

Shunta Saito 76 Dec 12, 2022
Python package for dynamic system estimation of time series

PyDSE Toolset for Dynamic System Estimation for time series inspired by DSE. It is in a beta state and only includes ARMA models right now. Documentat

Blue Yonder GmbH 40 Oct 07, 2022
[Nature Machine Intelligence' 21] "Advancing COVID-19 Diagnosis with Privacy-Preserving Collaboration in Artificial Intelligence"

[UCADI] COVID-19 Diagnosis With Federated Learning Intro We developed a Federated Learning (FL) Framework for global researchers to collaboratively tr

HUST EIC AI-LAB 30 Dec 12, 2022
BBScan py3 - BBScan py3 With Python

BBScan_py3 This repository is forked from lijiejie/BBScan 1.5. I migrated the fo

baiyunfei 12 Dec 30, 2022
A Python package for time series augmentation

tsaug tsaug is a Python package for time series augmentation. It offers a set of augmentation methods for time series, as well as a simple API to conn

Arundo Analytics 278 Jan 01, 2023
StyleGAN2-ADA-training-jupyter - Training custom datasets in styleGAN2-ADA by NVIDIA using Jupyter

styleGAN2-ADA-training-jupyter Training custom datasets in styleGAN2-ADA on Jupyter Official StyleGAN2-ADA by NIVIDIA Paper Training Generative Advers

Mang Su Hyun 2 Feb 24, 2022
A task Provided by A respective Artenal Ai and Ml based Company to complete it

A task Provided by A respective Alternal Ai and Ml based Company to complete it .

Parth Madan 1 Jan 25, 2022
The repository for our EMNLP 2021 paper "Finnish Dialect Identification: The Effect of Audio and Text"

Finnish Dialect Identification The repository for our EMNLP 2021 paper "Finnish Dialect Identification: The Effect of Audio and Text". We present a te

Rootroo Ltd 2 Dec 25, 2021
The 2nd place solution of 2021 google landmark retrieval on kaggle.

Leaderboard, taxonomy, and curated list of few-shot object detection papers.

229 Dec 13, 2022
Multi-modal Vision Transformers Excel at Class-agnostic Object Detection

Multi-modal Vision Transformers Excel at Class-agnostic Object Detection

Muhammad Maaz 206 Jan 04, 2023
Experiments for Neural Flows paper

Neural Flows: Efficient Alternative to Neural ODEs [arxiv] TL;DR: We directly model the neural ODE solutions with neural flows, which is much faster a

54 Dec 07, 2022
Repository for the paper "Online Domain Adaptation for Occupancy Mapping", RSS 2020

RSS 2020 - Online Domain Adaptation for Occupancy Mapping Repository for the paper "Online Domain Adaptation for Occupancy Mapping", Robotics: Science

Anthony 26 Sep 22, 2022
Cancer Drug Response Prediction via a Hybrid Graph Convolutional Network

DeepCDR Cancer Drug Response Prediction via a Hybrid Graph Convolutional Network This work has been accepted to ECCB2020 and was also published in the

Qiao Liu 50 Dec 18, 2022
Rethinking Nearest Neighbors for Visual Classification

Rethinking Nearest Neighbors for Visual Classification arXiv Environment settings Check out scripts/env_setup.sh Setup data Download the following fin

Menglin Jia 29 Oct 11, 2022
Semi-supevised Semantic Segmentation with High- and Low-level Consistency

Semi-supevised Semantic Segmentation with High- and Low-level Consistency This Pytorch repository contains the code for our work Semi-supervised Seman

123 Dec 30, 2022