Learning from graph data using Keras

Overview

Steps to run =>

  • Download the Cora dataset from this link : https://linqs.soe.ucsc.edu/data
  • Unzip the files into the folder input/cora
  • cd code
  • python eda.py
  • python word_features_only.py # for baseline model 53.28% accuracy
  • python graph_embedding.py # for model_1 73.06% accuracy
  • python graph_features_embedding.py # for model_2 76.35% accuracy

Learning from Graph data using Keras and Tensorflow

Cora Data set Citation Graph

Motivation :

A lot of real-world data can be represented as a graph, for example in citation networks, social networks (followers graph, friends network, …), biological networks or telecommunications.
Using features extracted from the graph can boost the performance of predictive models by relying on the information flow between close nodes. However, representing graph data is not straightforward, especially if we don’t intend to implement hand-crafted features.
In this post we will explore some ways to deal with generic graphs in order to do node classification based on graph representations learned directly from the data.

Dataset :

The Cora citation network dataset will serve as the basis for the implementations and experiments throughout this post. Each node represents a scientific paper, and an edge between two nodes represents a citation relation between the two papers.
Each node is represented by a set of binary features (bag of words) as well as by the set of edges that link it to other nodes.
The dataset has 2708 nodes, each classified into one of seven classes, and 5429 links. Each binary word feature indicates the presence of the corresponding word in the paper; overall there are 1433 sparse binary features per node. In what follows we only use 140 samples for training and the rest for validation/test.
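As a rough sketch of how the raw files could be read, the helpers below assume the usual layout of the LINQS release: `cora.content` has one tab-separated line per paper (paper id, the binary word features, class label) and `cora.cites` has one line per link (cited paper id, citing paper id). The function names and the random 140-node split are illustrative assumptions, not necessarily what the linked repository does.

```python
import random
from collections import defaultdict


def parse_content(lines):
    """Parse cora.content-style lines: paper id, binary word features, class label."""
    ids, features, labels = [], [], []
    for line in lines:
        parts = line.strip().split("\t")
        if parts == [""]:  # skip blank lines
            continue
        ids.append(parts[0])
        features.append([int(x) for x in parts[1:-1]])
        labels.append(parts[-1])
    return ids, features, labels


def parse_cites(lines, known_ids):
    """Parse cora.cites-style lines (cited id, citing id) into an undirected adjacency map."""
    adjacency = defaultdict(set)
    for line in lines:
        parts = line.split()
        if len(parts) != 2:
            continue
        cited, citing = parts
        # Keep only links between known papers and drop self-citations.
        if cited in known_ids and citing in known_ids and cited != citing:
            adjacency[cited].add(citing)
            adjacency[citing].add(cited)
    return adjacency


def split_train_rest(ids, n_train=140, seed=0):
    """Assumed split: n_train random nodes for training, the rest for validation/test."""
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:]
```

With the files unzipped as in the steps above, something like `parse_content(open("input/cora/cora.content"))` would load the nodes; the exact file path inside `input/cora` is an assumption.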

Problem Setting :

Problem : Assigning a class label to nodes in a graph while having few training samples.
Intuition/Hypothesis : Nodes that are close in the graph are more likely to have similar labels.
Solution : Find a way to extract features from the graph to help classify new nodes.
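One simple way to quantify this hypothesis on a labelled graph is the fraction of edges whose two endpoints share a class label (often called edge homophily). This check is a sketch and not part of the original experiments; a value close to 1 would support the intuition that close nodes tend to have similar labels.

```python
def edge_homophily(edges, labels):
    """Fraction of edges (u, v) whose endpoints have the same class label."""
    if not edges:
        return 0.0
    same = sum(1 for u, v in edges if labels[u] == labels[v])
    return same / len(edges)
```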

Proposed Approach :


Baseline Model :

Simple Baseline Model

We first experiment with the simplest model, which learns to predict node classes using only the binary features and discards all graph information.
This model is a fully-connected Neural Network that takes as input the binary features and outputs the class probabilities for each node.
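To make the structure of the baseline concrete, here is a framework-free sketch of its forward pass: the binary features go through one hidden ReLU layer, then a softmax over the seven classes. The single hidden layer of size 64, the initialisation and the class name `BaselineMLP` are assumptions for illustration; the post's own model is built with Keras, and training is omitted here.

```python
import math
import random


def dense(x, weights, bias):
    """Fully-connected layer: out[j] = sum_i x[i] * W[i][j] + b[j]."""
    return [sum(xi * row[j] for xi, row in zip(x, weights)) + bias[j] for j in range(len(bias))]


def relu(v):
    return [max(0.0, a) for a in v]


def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


def init_layer(n_in, n_out, rng):
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_out)] for _ in range(n_in)]
    return weights, [0.0] * n_out


class BaselineMLP:
    """Hypothetical baseline: binary features -> hidden ReLU layer -> class probabilities."""

    def __init__(self, n_features=1433, n_hidden=64, n_classes=7, seed=0):
        rng = random.Random(seed)
        self.w1, self.b1 = init_layer(n_features, n_hidden, rng)
        self.w2, self.b2 = init_layer(n_hidden, n_classes, rng)

    def predict_proba(self, x):
        hidden = relu(dense(x, self.w1, self.b1))
        return softmax(dense(hidden, self.w2, self.b2))
```

The graph never appears in this model: each node is classified from its own 1433 features alone, which is exactly what the later models try to improve on.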

Baseline model Accuracy : 53.28%

This is the initial accuracy that we will try to improve on by adding graph-based features.

Adding Graph features :

One way to automatically learn graph features is to embed each node into a vector by training a network on the auxiliary task of predicting the inverse of the shortest path length between two input nodes, as illustrated in the figure below :

Learning an embedding vector for each node
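Training pairs for this auxiliary task can be generated with a breadth-first search from each sampled source node. In the sketch below, unreachable pairs get a target of 0 (the limit of 1/d as d grows) and a node paired with itself gets 1.0 because 1/0 is undefined; both conventions, as well as the uniform pair sampling, are assumptions rather than details taken from the original code. A network that learns one vector per node would then be trained to regress these targets.

```python
import random
from collections import deque


def bfs_distances(adjacency, source):
    """Shortest-path length (in hops) from source to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for neighbor in adjacency.get(node, ()):
            if neighbor not in dist:
                dist[neighbor] = dist[node] + 1
                queue.append(neighbor)
    return dist


def inverse_distance_target(distances, v):
    """Regression target for the pair (source, v): 1 / shortest-path length."""
    d = distances.get(v)
    if d is None:  # unreachable: assume 0, the limit of 1/d as d grows
        return 0.0
    if d == 0:  # same node: 1/0 is undefined, assume a capped target of 1.0
        return 1.0
    return 1.0 / d


def sample_training_pairs(adjacency, nodes, n_pairs, seed=0):
    """Sample (u, v, 1/d(u, v)) triples uniformly at random (assumed sampling scheme)."""
    rng = random.Random(seed)
    cache = {}  # BFS results per source node, reused across samples
    pairs = []
    for _ in range(n_pairs):
        u, v = rng.choice(nodes), rng.choice(nodes)
        if u not in cache:
            cache[u] = bfs_distances(adjacency, u)
        pairs.append((u, v, inverse_distance_target(cache[u], v)))
    return pairs
```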

The next step is to use the pre-trained node embeddings as input to the classification model. We also add an additional input: the average binary features of the neighboring nodes, where neighbors are determined by the distance between the learned embedding vectors.
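A minimal sketch of that extra input: for a given node, find the k nodes whose learned embeddings are closest and average their binary features. Euclidean distance, k=2 and the function name are assumptions chosen for illustration; the original code may use another metric or number of neighbors.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def neighbor_feature_average(embeddings, features, node, k=2):
    """Average the binary features of the k nodes whose embeddings are closest to `node`'s."""
    others = [n for n in embeddings if n != node]
    nearest = sorted(others, key=lambda n: euclidean(embeddings[node], embeddings[n]))[:k]
    n_features = len(features[node])
    if not nearest:
        return [0.0] * n_features
    return [sum(features[n][i] for n in nearest) / len(nearest) for i in range(n_features)]
```

This brute-force search compares each node with every other node, which is still manageable for Cora's 2708 nodes; larger graphs would call for an approximate nearest-neighbor index.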

The resulting classification network is described in the following figure :

Using pretrained embeddings to do node classification

Graph embedding classification model Accuracy : 73.06%

We can see that adding learned graph features as input to the classification model significantly improves the classification accuracy compared to the baseline model, from **53.28% to 73.06%** 😄.

Improving Graph feature learning :

We can further improve the previous model by pushing the pre-training further: the binary features are also used in the node embedding network, and the pre-trained weights that process the binary features are reused in the classifier in addition to the node embedding vector. This results in a model that relies on more useful representations of the binary features, learned from the graph structure.
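A hypothetical sketch of the idea: one encoder combines a per-node embedding with a projection of the binary features, and the same encoder instance is shared by the pre-training head and, later, by the classifier. The dimensions, the ReLU projection and the sigmoid dot-product head are assumptions; the actual architecture is in the linked repository, and training is omitted.

```python
import math
import random


class NodeEncoder:
    """Hypothetical shared encoder: learned per-node embedding + ReLU projection of binary features."""

    def __init__(self, node_ids, n_features, emb_dim=8, proj_dim=8, seed=0):
        rng = random.Random(seed)
        self.embedding = {n: [rng.gauss(0.0, 0.1) for _ in range(emb_dim)] for n in node_ids}
        self.proj = [[rng.gauss(0.0, 0.1) for _ in range(proj_dim)] for _ in range(n_features)]
        self.proj_dim = proj_dim

    def encode(self, node, features):
        """Concatenate the node's embedding with the projected binary features."""
        projected = [
            max(0.0, sum(x * row[j] for x, row in zip(features, self.proj)))
            for j in range(self.proj_dim)
        ]
        return self.embedding[node] + projected


def pair_score(encoder, u, features_u, v, features_v):
    """Assumed pre-training head: sigmoid of a dot product, regressing the inverse shortest-path length."""
    a = encoder.encode(u, features_u)
    b = encoder.encode(v, features_v)
    return 1.0 / (1.0 + math.exp(-sum(p * q for p, q in zip(a, b))))
```

After pre-training, the classifier would call `encoder.encode(node, features)` on the same `NodeEncoder` instance, so the feature-projection weights learned on the graph task are reused rather than re-initialised.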

Improved Graph embedding classification model Accuracy : 76.35%

This additional improvement adds about 3.3 percentage points of accuracy compared to the previous approach (from 73.06% to 76.35%).

Conclusion :

In this post we saw that we can learn useful representations from graph-structured data and then use these representations to improve the generalization performance of a node classification model from **53.28% to 76.35%** 😎.

Code to reproduce the results is available here : https://github.com/CVxTz/graph_classification

Owner
Mansar Youness