FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.

Overview

FedJAX: Federated learning with JAX

What is FedJAX?

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.

FedJAX is built around the common core components needed in the FL setting:

  • Federated datasets: Clients and a dataset for each client
  • Models: CNN, ResNet, etc.
  • Optimizers: SGD, Momentum, etc.
  • Federated algorithms: Client updates and server aggregation

For Models and Optimizers, FedJAX provides lightweight wrappers and containers that work with a variety of existing implementations (e.g. a model wrapper that supports both Haiku and Stax). Similarly, for Federated datasets, TFF already provides a well-established API, so FedJAX only provides utilities for converting that data to NumPy input that JAX accepts.

However, what FL researchers will likely find most useful is the collection of customizable Federated algorithms that FedJAX provides out of the box.
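As a rough illustration of what "client updates and server aggregation" means, one round of a FedAvg-style algorithm can be sketched in plain NumPy. This is not FedJAX code: the names `client_update` and `server_aggregate`, the least-squares loss, and the toy data are all assumptions made for this example.

```python
import numpy as np


def client_update(params, x, y, lr=0.05, steps=5):
    """Hypothetical client step: local gradient descent on a least-squares loss."""
    w = params.copy()
    for _ in range(steps):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)
        w = w - lr * grad
    return w - params  # delta sent back to the server


def server_aggregate(params, deltas, num_examples):
    """Hypothetical server step: apply the example-weighted mean of client deltas."""
    mean_delta = np.average(np.stack(deltas), axis=0, weights=num_examples)
    return params + mean_delta


# Toy federated dataset: two clients with noiseless linear targets.
true_w = np.array([1.0, -2.0])
x_a = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
x_b = np.array([[2.0, 1.0], [1.0, 2.0]])
clients = {"a": (x_a, x_a @ true_w), "b": (x_b, x_b @ true_w)}

# One federated round: every client trains locally, then the server aggregates.
params = np.zeros(2)
deltas = [client_update(params, x, y) for x, y in clients.values()]
sizes = [len(y) for _, y in clients.values()]
new_params = server_aggregate(params, deltas, sizes)

print(np.linalg.norm(params - true_w), np.linalg.norm(new_params - true_w))
```

On this toy problem the server parameters end the round closer to `true_w` than they started. A real algorithm would repeat the round many times and usually sample a subset of clients per round.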

Quickstart

The FedJAX Intro notebook provides an introduction to running existing FedJAX experiments. For more custom use cases, please refer to the FedJAX Advanced notebook.

You can also take a look at some of our examples:

Installation

You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow these instructions.

Then, install fedjax from PyPI:

pip install fedjax

Or, to upgrade to the latest version of fedjax:

pip install --upgrade git+https://github.com/google/fedjax.git

Useful pointers

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

Comments
  • FedJax depends on TensorFlow Federated?

    I am helping users install FedJax for use in their federated learning research projects and I noticed that installing FedJax is pulling in TensorFlow Federated (0.17) and TensorFlow (2.3). I don't see either of these listed as dependencies of FedJax so I am trying to understand why they are being pulled in by pip install fedjax.

    opened by davidrpugh 7
  • CIFAR 100 Questions

    Hi, thanks for the awesome library! I want to ask a couple of questions related to CIFAR100 datasets.

    1. I noticed that while the dataset is available in the library, the model is not. Curious if a model for CIFAR100 is work-in-progress, or if there is no short-term plan for this?
    2. Looking at the CIFAR100 dataset, this seems to be inconsistent with Google's TFF. Notably, the cropping size and normalizing are done differently from TFF. Is this intentional? Would it be correct to say that we could expect this to mirror TFF's design eventually?

    Thanks in advance for all the help!

    opened by HanGuo97 5
  • unbiased scale for DRIVE

    Following a discussion with @stheertha, I suggest using the unbiased scale (section 4.2 in Drive's paper) for cases where there is more than 1 client.

    Thank you for considering.

    opened by amitport 3
  • Problem of Quick Start in Readme.md

    I tried to run the code in the QuickStart and ran into some problems. federated_data = fedjax.FederatedData() cannot be executed because FederatedData is an abstract class, so I replaced it with:

    client_a_data = {
        'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        'y': np.array([7, 8]),
    }
    client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
    client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}
    federated_data = fedjax.InMemoryFederatedData(client_to_data_mapping)
    

    Everything else is the same as in the QuickStart, but I got an error:

    for client_id, client_output, _ in func(shared_input, clients):
    for client_id, client_batches, client_input in clients:
    ValueError: not enough values to unpack (expected 3, got 2)
    

    It seems that client_batches is missing and we need to batch the dataset, but there is no example which fits this situation.

    opened by Ichiruchan 2
  • Full EMNIST example does not exhibit parallelization

    Hi! I am facing an issue with parallelizing the base code provided by the developers.

    • My local workstation contains two GPUs.
    • I installed FedJAX in a conda environment.
    • I downloaded the "emnist_fed_avg.py" file from the "examples" folder, deleted the "fedjax.training.set_tf_cpu_only()" line, and replaced fed_avg.federated_averaging with fedjax.algorithms.fed_avg.federated_averaging on line 61.
    • Having activated the conda environment, I ran the file with python emnist_fed_avg.py. The file runs correctly and prints the expected output (round numbers and train/test metrics on every 10th round).
    • The nvidia-smi command shows zero percent utilization and almost zero memory usage on one of the GPUs (and ~40% utilization/maximum memory usage on the other).

    Any ideas what I am doing wrong?

    opened by gaseln 2
  • Clarifying the meaning of "weight"

    In the Intro notebook, the backward_pass_output from model.backward has a weight field. It seems to me that this is used for weighted averaging in FedAvg, but it is not clear to me how. Perhaps it could be renamed to batch_size?

    opened by Saipraneet 1
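One plausible reading of such a weight, assuming it is the number of examples in a batch (the source does not confirm this), is that it lets per-batch averages be combined into a correct per-example average. The sketch below uses made-up loss values and plain NumPy, not the FedJAX API.

```python
import numpy as np

# Made-up per-example losses, split into two batches of unequal size.
per_example_loss = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
batches = [per_example_loss[:3], per_example_loss[3:]]

batch_means = np.array([b.mean() for b in batches])
batch_weights = np.array([len(b) for b in batches], dtype=float)  # assumed meaning of "weight"

weighted_mean = np.average(batch_means, weights=batch_weights)
uniform_mean = batch_means.mean()
overall_mean = per_example_loss.mean()

print(weighted_mean, uniform_mean, overall_mean)
```

Here the weighted mean of the batch means matches the mean over all five examples, while the unweighted mean of the batch means does not, because the smaller batch is over-represented.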
  • [NumPy] Remove references to deprecated NumPy type aliases.

    This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

    NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.

    opened by copybara-service[bot] 0
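The replacement described in this change can be illustrated with a few dtype arguments. This is a minimal sketch with invented variable names, not code taken from FedJAX.

```python
import numpy as np

# Use the Python builtins where the deprecated aliases were used before.
mask = np.zeros(3, dtype=bool)               # previously dtype=np.bool
counts = np.arange(3).astype(int)            # previously .astype(np.int)
values = np.asarray([1, 2], dtype=float)     # previously dtype=np.float
names = np.array(["a", "b"], dtype=object)   # previously dtype=np.object

print(mask.dtype, counts.dtype, values.dtype, names.dtype)
```

The builtins map to NumPy's default types for the platform, which is the same behavior the old aliases had.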
  • Disable pytype import error for old stax import path

    Why? The deprecated jax.experimental.stax path will soon be removed (see https://github.com/google/jax/pull/11700), and this causes pytype to fail.

    opened by copybara-service[bot] 0
  • Rename jax.experimental.stax -> jax.example_libraries.stax

    Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed.

    opened by copybara-service[bot] 0
  • Implement standard CIFAR-100 model in fedjax.models.cifar100

    Add a standard implementation of the model for the CIFAR-100 task. The dataset can be found in fedjax.datasets.cifar100.

    For the model architecture, we should follow “Adaptive Federated Optimization”. The model architecture is detailed in section 4 as a ResNet-18 (replacing batch norm with group norm). Code for this paper and a Keras implementation of the model can be found here. We suggest using either haiku or flax to implement the model for use with JAX.

    If you choose to use haiku, you can use fedjax.create_model_from_haiku to create a fedjax compatible model. If you choose to use flax, wrapping it in a fedjax.Model is fairly straightforward and we can provide guidance for this.

    A good example to follow is #265 that checks in a simple linear model for CIFAR-100 and includes the model implementation, tests, and baseline results with FedAvg using this script. Make sure to add a flags file similar to https://github.com/google/fedjax/blob/main/experiments/fed_avg/fed_avg.CIFAR100_LOGISTIC.flags and add the new task to https://github.com/google/fedjax/blob/main/fedjax/training/tasks.py.

    Thanks for your contributions!

    enhancement contributions welcome 
    opened by jaehunro 1
  • Support for manually modifying client/server learning rate

    Hi, I'm playing around with the client learning rate, but I cannot find a clean way of modifying it.

    Basically, I need to change the LR following a schedule based on the current round. Is that possible?

    Thanks

    opened by marcociccone 1
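A round-based schedule of the kind asked about here can be written as a plain function of the round number. The sketch below is an assumption-laden illustration: the function name `client_learning_rate` and its step-decay parameters are hypothetical, and how the resulting value would be handed to a FedJAX optimizer is not covered by the source and is left out.

```python
def client_learning_rate(round_num, base_lr=0.1, decay=0.5, decay_every=100):
    """Hypothetical step decay: multiply the rate by `decay` every `decay_every` rounds."""
    return base_lr * decay ** (round_num // decay_every)


for round_num in (0, 99, 100, 250):
    print(round_num, client_learning_rate(round_num))
```

Keeping the schedule as a pure function of the round number makes it easy to test in isolation and to swap for another schedule later.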
  • Support for gldv2 and inaturalist datasets

    I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂

    Looking at the TFF codebase (gldv2, inaturalist), it looks like the load_data_from_cache function creates a TFRecords file for each client.

    The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedDataset? We could replicate something like FilePerUserClientData.

    Thanks!

    opened by marcociccone 7
  • Support for haiku models with non-trainable state

    Hi! Congrats on this great library! I started using it a few days ago and I love it!

    Is there any way to use a haiku model with a non-trainable state (e.g. to use batch norm)? I didn't find any nontrivial way, but maybe I'm missing something.

    Thanks a lot for your help!

    opened by marcociccone 2
  • How to create a validation dataset?

    Hello!

    I may need to split each client's train dataset into train and validation parts for grid search purposes (for example, tuning the stepsizes in a method). How can this be achieved in the framework?

    opened by gaseln 4
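One way to approach this outside the library is to split each client's arrays before building the federated dataset. The sketch below assumes each client's data is a dict of equal-length NumPy arrays, as in the client_to_data_mapping shown in an earlier issue; `split_client` and its parameters are hypothetical names, not part of FedJAX.

```python
import numpy as np


def split_client(examples, validation_fraction=0.2, seed=0):
    """Hypothetical helper: shuffle one client's examples and split them in two."""
    num_examples = len(next(iter(examples.values())))
    perm = np.random.default_rng(seed).permutation(num_examples)
    num_val = int(round(num_examples * validation_fraction))
    val_idx, train_idx = perm[:num_val], perm[num_val:]
    train = {name: array[train_idx] for name, array in examples.items()}
    val = {name: array[val_idx] for name, array in examples.items()}
    return train, val


# Toy client data with 10 and 5 examples.
client_to_data = {
    "a": {"x": np.arange(20.0).reshape(10, 2), "y": np.arange(10)},
    "b": {"x": np.arange(10.0).reshape(5, 2), "y": np.arange(5)},
}

train_mapping, val_mapping = {}, {}
for client_id, examples in client_to_data.items():
    train_mapping[client_id], val_mapping[client_id] = split_client(examples)

print({cid: len(data["y"]) for cid, data in train_mapping.items()})
print({cid: len(data["y"]) for cid, data in val_mapping.items()})
```

Each resulting mapping has the same shape as the original one, so the two could then be wrapped separately as train and validation federated datasets; that step is not shown here.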
  • Feature request: Convert standard dataset into a federated dataset

    Synthetic federated datasets can be constructed from standard centralized ones by artificially splitting them among clients. This is usually done using a Dirichlet distribution (e.g. Hsu et al. 2019). Such synthetic datasets are very useful since we can explicitly control the total number of users, as well as the heterogeneity.

    It would be great to have primitives that can automatically convert a standard NumPy dataset into a FedJAX dataset.

    contributions welcome 
    opened by Saipraneet 5
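A minimal sketch of such a split is shown below. It uses one common variant, in which the examples of each class are divided among clients with Dirichlet-drawn proportions; this is not necessarily the exact procedure of Hsu et al., and `dirichlet_partition` is a hypothetical helper rather than a FedJAX function.

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Hypothetical helper: returns one array of example indices per client."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for label in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == label))
        # Smaller alpha concentrates each class on fewer clients (more heterogeneity).
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cut_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, chunk in enumerate(np.split(idx, cut_points)):
            client_indices[client].extend(chunk.tolist())
    return [np.array(sorted(ix), dtype=int) for ix in client_indices]


# Toy centralized dataset: 3 classes with 20 examples each.
labels = np.repeat(np.arange(3), 20)
parts = dirichlet_partition(labels, num_clients=4, alpha=0.5)

for client, ix in enumerate(parts):
    print(client, np.bincount(labels[ix], minlength=3))
```

Every example is assigned to exactly one client, and some clients may receive no examples of a given class when alpha is small.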
Releases(v0.0.15)
Owner
Google