Overview

FedJAX: Federated learning with JAX

What is FedJAX?

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.

FedJAX is built around the common core components needed in the FL setting:

  • Federated datasets: Clients and a dataset for each client
  • Models: CNN, ResNet, etc.
  • Optimizers: SGD, Momentum, etc.
  • Federated algorithms: Client updates and server aggregation

For models and optimizers, FedJAX provides lightweight wrappers and containers that can work with a variety of existing implementations (e.g. a model wrapper that can support both Haiku and Stax). Similarly, for federated datasets, TFF already provides a well-established API, and FedJAX simply provides utilities for converting them to NumPy input acceptable to JAX.
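
For example, a small federated dataset can be built directly from per-client NumPy arrays. A minimal sketch (fedjax.InMemoryFederatedData is the only FedJAX call used; everything else is plain NumPy):

import numpy as np
import fedjax

# Each client maps feature names to NumPy arrays of examples.
client_a_data = {'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 'y': np.array([7, 8])}
client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}

# Wrap the per-client NumPy data as a FedJAX federated dataset.
federated_data = fedjax.InMemoryFederatedData({'a': client_a_data, 'b': client_b_data})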

However, what FL researchers will likely find most useful is the collection and customizability of federated algorithms that FedJAX provides out of the box.
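
As a hedged sketch of what this looks like, the snippet below builds a federated averaging algorithm. Only fedjax.algorithms.fed_avg.federated_averaging is confirmed on this page; the helper names and argument order (model_grad, the optimizers, the batching hyperparameters, the EMNIST model constructor) are assumptions based on the bundled examples, so check the Intro notebook for the exact API:

import jax
import fedjax
from fedjax.algorithms import fed_avg
from fedjax.models import emnist

# Assumed helpers: a fedjax.Model (here the EMNIST CNN), a gradient function
# derived from it, and SGD optimizers for the client and server updates.
model = emnist.create_conv_model(only_digits=False)
grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
server_optimizer = fedjax.optimizers.sgd(learning_rate=1.0)
batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=20)

# Construct the algorithm and initialize its server state from fresh parameters.
algorithm = fed_avg.federated_averaging(grad_fn, client_optimizer,
                                        server_optimizer, batch_hparams)
server_state = algorithm.init(model.init(jax.random.PRNGKey(0)))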

Quickstart

The FedJAX Intro notebook provides an introduction to running existing FedJAX experiments. For more custom use cases, please refer to the FedJAX Advanced notebook.

You can also take a look at some of our examples.

Installation

You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow these instructions.

Then, install fedjax from PyPI:

pip install fedjax

Or, to upgrade to the latest development version of fedjax from GitHub:

pip install --upgrade git+https://github.com/google/fedjax.git

Useful pointers

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

Comments
  • FedJax depends on TensorFlow Federated?

    I am helping users install FedJax for use in their federated learning research projects and I noticed that installing FedJax is pulling in TensorFlow Federated (0.17) and TensorFlow (2.3). I don't see either of these listed as dependencies of FedJax so I am trying to understand why they are being pulled in by pip install fedjax.

    opened by davidrpugh 7
  • CIFAR 100 Questions

    Hi, thanks for the awesome library! I want to ask a couple of questions related to the CIFAR-100 dataset.

    1. I noticed that while the dataset is available in the library, the model is not. Curious if a model for CIFAR-100 is work-in-progress, or if there is no short-term plan for this?
    2. Looking at the CIFAR-100 dataset, it seems to be inconsistent with Google's TFF. Notably, the crop size and normalization are done differently from TFF. Is this intentional? Would it be correct to say that we could expect this to mirror TFF's design eventually?

    Thanks in advance for all the help!

    opened by HanGuo97 5
  • unbiased scale for DRIVE

    Following a discussion with @stheertha, I suggest using the unbiased scale (Section 4.2 in the DRIVE paper) for cases where there is more than one client.

    Thank you for considering.

    opened by amitport 3
  • Problem of Quick Start in Readme.md

    I tried to run the code in the Quickstart and found some problems. federated_data = fedjax.FederatedData() cannot be executed because it is an abstract class, so I replaced it with:

    client_a_data = {
        'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        'y': np.array([7, 8])
    }
    client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
    client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}
    federated_data = fedjax.InMemoryFederatedData(client_to_data_mapping)
    

    Everything else is the same as in the Quickstart, but I got an error:

    for client_id, client_output, _ in func(shared_input, clients):
    for client_id, client_batches, client_input in clients:
    ValueError: not enough values to unpack (expected 3, got 2)
    

    It seems that client_batches is missing and we need to batch the dataset, but there is no example that fits this situation.
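
    A sketch of what I think is needed (the method names clients() and batch() are my guesses about the API, not taken from the Quickstart):

    import jax

    # Build (client_id, client_batches, client_input) triples by batching each
    # client's dataset and giving each client its own PRNGKey.
    clients = []
    for i, (client_id, client_dataset) in enumerate(federated_data.clients()):
        client_batches = list(client_dataset.batch(batch_size=2))
        clients.append((client_id, client_batches, jax.random.PRNGKey(i)))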

    opened by Ichiruchan 2
  • Full EMNIST example does not exhibit parallelization

    Hi! I am facing an issue with parallelizing the base code provided by the developers.

    • My local workstation contains two GPUs.
    • I installed FedJax in a conda environment
    • I downloaded "emnist_fed_avg.py" file from the folder "examples", deleted the "fedjax.training.set_tf_cpu_only()" line and replaced fed_avg.federated_averaging to fedjax.algorithms.fed_avg.federated_averaging on line 61
    • Having activated the conda environment, I ran the file with python emnist_fed_avg.py. The file runs correctly and prints the expected output (round nums and train/test metrics on each 10th round)
    • The nvidia-smi command shows zero percent utilization and almost zero memory usage on one of the GPUs (and ~40% utilization/maximum memory usage on another node)

    Any ideas what I am doing wrong?

    opened by gaseln 2
  • Clarifying the meaning of "weight"

    In the Intro notebook, the backward_pass_output from model.backward has a weight field. It seems to me that this is used for performing a weighted average in FedAvg, but it is not clear to me how. Perhaps this could be renamed to batch_size?

    opened by Saipraneet 1
  • [NumPy] Remove references to deprecated NumPy type aliases.

    This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

    NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.
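
    For example, the replacement is purely mechanical:

    import numpy as np

    # Before (np.float and np.bool are removed in NumPy 1.24):
    #   x = np.zeros(3, dtype=np.float)
    #   mask = np.array([1, 0], dtype=np.bool)
    # After, using the builtin types instead:
    x = np.zeros(3, dtype=float)
    mask = np.array([1, 0], dtype=bool)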

    opened by copybara-service[bot] 0
  • Disable pytype import error for old stax import path

    Why? The deprecated jax.experimental.stax path will soon be removed (see https://github.com/google/jax/pull/11700), and this causes pytype to fail.

    opened by copybara-service[bot] 0
  • Rename jax.experimental.stax -> jax.example_libraries.stax

    Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed.
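
    Concretely, the import changes like this:

    # Deprecated path, removed in newer JAX releases:
    #   from jax.experimental import stax
    # Replacement path:
    from jax.example_libraries import stax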

    opened by copybara-service[bot] 0
  • Implement standard CIFAR-100 model in fedjax.models.cifar100

    Add a standard implementation of the model for the CIFAR-100 task. The dataset can be found in fedjax.datasets.cifar100.

    For the model architecture, we should follow “Adaptive Federated Optimization”. The model architecture is detailed in section 4 as a ResNet-18 (replacing batch norm with group norm). Code for this paper and a Keras implementation of the model can be found here. We suggest using either haiku or flax to implement the model for use with JAX.

    If you choose to use haiku, you can use fedjax.create_model_from_haiku to create a fedjax compatible model. If you choose to use flax, wrapping it in a fedjax.Model is fairly straightforward and we can provide guidance for this.
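
    A rough sketch of the haiku route (the keyword names and metric helpers below are assumptions, so treat this as a starting point and check an existing model such as fedjax.models.emnist for the exact usage):

    import haiku as hk
    import jax.numpy as jnp
    import fedjax

    def forward(batch):
      # Placeholder network; the actual task calls for a ResNet-18 with group norm.
      return hk.nets.MLP([64, 100])(batch['x'])

    # A single dummy batch is used to shape-check and initialize the model.
    sample_batch = {'x': jnp.zeros((1, 3072)), 'y': jnp.zeros((1,), dtype=jnp.int32)}
    model = fedjax.create_model_from_haiku(
        transformed_forward_pass=hk.transform(forward),
        sample_batch=sample_batch,
        train_loss=lambda batch, preds: fedjax.metrics.unreduced_cross_entropy_loss(
            batch['y'], preds),
        eval_metrics={'accuracy': fedjax.metrics.Accuracy()})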

    A good example to follow is #265, which checks in a simple linear model for CIFAR-100 and includes the model implementation, tests, and baseline results with FedAvg using this script. Make sure to add a flags file similar to https://github.com/google/fedjax/blob/main/experiments/fed_avg/fed_avg.CIFAR100_LOGISTIC.flags and to add the new task to https://github.com/google/fedjax/blob/main/fedjax/training/tasks.py.

    Thanks for your contributions!

    enhancement contributions welcome 
    opened by jaehunro 1
  • Support for manually modifying client/server learning rate

    Hi, I'm playing around with the client learning rate but I cannot find a clean way of modifying it.

    Basically, I need to change the LR following a schedule based on the current round. Is that possible?

    Thanks

    opened by marcociccone 1
  • Support for gldv2 and inaturalist datasets

    I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂

    Looking at the TFF codebase (gldv2, inaturalist), it looks like the load_data_from_cache function creates a TFRecord file for each client.

    The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedDataset? We could replicate something like FilePerUserClientData.

    Thanks!

    opened by marcociccone 7
  • Support for haiku models with non-trainable state

    Hi! Congrats on this great library! I started using it a few days ago and I love it!

    Is there any way to use a haiku model with non-trainable state (e.g. to use batch norm)? I couldn't find a straightforward way, but maybe I'm missing something.

    Thanks a lot for your help!

    opened by marcociccone 2
  • How to create a validation dataset?

    Hello!

    I may need to split each client's training dataset into train and validation parts for grid search purposes (for example, tuning the step sizes of a method). How can this be achieved in the framework?
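
    For now I am working around this by splitting the raw per-client arrays myself and rebuilding two datasets. A sketch assuming in-memory-sized data with at least a few examples per client (fedjax.InMemoryFederatedData is the only library call used):

    import numpy as np
    import fedjax

    def split_train_valid(client_to_data, valid_fraction=0.2, seed=0):
      # Splits each client's examples into disjoint train/validation parts.
      rng = np.random.RandomState(seed)
      train, valid = {}, {}
      for client_id, data in client_to_data.items():
        perm = rng.permutation(len(data['y']))
        num_valid = max(1, int(valid_fraction * len(perm)))
        valid_idx, train_idx = perm[:num_valid], perm[num_valid:]
        train[client_id] = {k: v[train_idx] for k, v in data.items()}
        valid[client_id] = {k: v[valid_idx] for k, v in data.items()}
      return fedjax.InMemoryFederatedData(train), fedjax.InMemoryFederatedData(valid)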

    opened by gaseln 4
  • Feature request: Convert standard dataset into a federated dataset

    Synthetic federated datasets can be constructed from standard centralized ones by artificially splitting them among clients. This is usually done using a Dirichlet distribution (e.g. Hsu et al. 2019). Such synthetic datasets are very useful since we can explicitly control the total number of users, as well as the heterogeneity.

    It would be great to have primitives which can automatically convert a standard NumPy dataset into a FedJAX dataset.
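
    As a rough illustration of the recipe (pure NumPy plus fedjax.InMemoryFederatedData; the helper name and defaults are made up for this sketch):

    import numpy as np
    import fedjax

    def dirichlet_split(x, y, num_clients=10, alpha=0.5, seed=0):
      # Partitions a centralized (x, y) dataset across clients, drawing each
      # class's per-client proportions from a Dirichlet(alpha) distribution;
      # smaller alpha yields more heterogeneous clients.
      rng = np.random.RandomState(seed)
      client_indices = [[] for _ in range(num_clients)]
      for c in range(int(y.max()) + 1):
        idx = rng.permutation(np.where(y == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, chunk in enumerate(np.split(idx, cuts)):
          client_indices[client].extend(chunk.tolist())
      return fedjax.InMemoryFederatedData({
          str(i): {'x': x[np.array(ind, dtype=int)], 'y': y[np.array(ind, dtype=int)]}
          for i, ind in enumerate(client_indices)})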

    contributions welcome 
    opened by Saipraneet 5
Releases (v0.0.15)
Owner: Google