TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards

Overview

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when the model has massive linear layers (e.g., ViT, BERT and GPT) or a huge number of classes (millions). Its API design mirrors that of PyTorch.

Installation

pip install torchshard

See INSTALL.md for more options.

Usage

import torch
import torchshard as ts

ts.init_process_group(group_size=2)                       # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),               
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),    # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),       # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),       # parallel in column dimension
).cuda()

# x (input batch with 20 features) and y (target labels) are assumed to be defined on the GPU
x = m(x)                                                  # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)      # parallel loss function
loss.backward()                                           # backward

torch.save(
  ts.collect_state_dict(m, m.state_dict()), 'm.pt')       # save model state
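
To make the idea of slicing a linear layer concrete, here is a minimal single-process sketch in plain PyTorch on the CPU. It does not use TorchShard and is not its implementation; the variable names and the shard count are illustrative assumptions. It shows the two standard ways to split a linear layer: along the output features (shard outputs are concatenated) and along the input features (shard outputs are summed). How these map onto TorchShard's dim=0 and dim=1 is not confirmed here.

```python
import torch

torch.manual_seed(0)

world_size = 2                        # hypothetical number of shards
x = torch.randn(4, 30)                # batch of 4 inputs with 30 features
full = torch.nn.Linear(30, 30, bias=True)
reference = full(x).detach()          # output of the unsharded layer

# nn.Linear stores its weight with shape (out_features, in_features).
w, b = full.weight.detach(), full.bias.detach()

# (a) Split the output features: each shard computes a slice of the
#     outputs, and the slices are concatenated along the feature axis.
w_out = torch.chunk(w, world_size, dim=0)
b_out = torch.chunk(b, world_size, dim=0)
y_concat = torch.cat([x @ wi.t() + bi for wi, bi in zip(w_out, b_out)], dim=1)

# (b) Split the input features: each shard sees a slice of the input and
#     produces a partial sum; the partial sums are added, bias added once.
w_in = torch.chunk(w, world_size, dim=1)
x_in = torch.chunk(x, world_size, dim=1)
y_sum = sum(xi @ wi.t() for xi, wi in zip(x_in, w_in)) + b

print(torch.allclose(reference, y_concat, atol=1e-5))
print(torch.allclose(reference, y_sum, atol=1e-5))
```

In a real multi-GPU run each shard would live on a different device, and the concatenation or sum would be a collective operation such as all-gather or all-reduce.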

Performance

The first benchmark figure shows training ResNet-50 on 8 NVIDIA TITAN-XP (12196 MiB) GPUs while scaling the number of classes from 1000 to 1 million. The input size is 224 x 224 and the batch size is 256. The run uses 8-way data parallelism and 8-way model parallelism.
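
A rough back-of-the-envelope calculation suggests why sharding the classifier matters at this scale. The sketch below counts only the weight matrix of the final linear layer, assuming ResNet-50's 2048-dimensional feature vector and 32-bit floats; it ignores the bias, gradients, optimizer state, and activations, so real usage is higher. The function name is a hypothetical helper, not part of TorchShard.

```python
def classifier_weight_mib(num_classes, feature_dim=2048, bytes_per_param=4, shards=1):
    """Approximate size in MiB of one shard of a (num_classes x feature_dim) weight."""
    params_per_shard = feature_dim * num_classes / shards
    return params_per_shard * bytes_per_param / 2**20

small = classifier_weight_mib(1_000)                 # 1000-class classifier
full = classifier_weight_mib(1_000_000)              # 1M classes, unsharded
per_gpu = classifier_weight_mib(1_000_000, shards=8) # 1M classes, split 8 ways

print(f"1000 classes:         {small:.1f} MiB")
print(f"1M classes, 1 GPU:    {full:.1f} MiB")
print(f"1M classes, 8 shards: {per_gpu:.1f} MiB per GPU")
```

Under these assumptions the unsharded weight alone takes up most of a 12196 MiB card before gradients are counted, while each of 8 shards needs roughly an eighth of that.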

The second benchmark figure shows training minGPT on 8 NVIDIA TITAN-XP (12196 MiB) GPUs while scaling the number of parameters from 10 million to 808 million. The input size is 32 x 32 and the batch size is 16. The run uses 1-way data parallelism and 8-way model parallelism.

Contributing

TorchShard welcomes your expertise and enthusiasm!

If you are interested in torchshard, you are welcome to help:

  • polish code and develop new features
  • develop high-quality tutorials, projects, and advanced materials

Direct pull requests are welcome. Contact: kaiyuyue [at] umd.edu.

Citing TorchShard

If you find TorchShard helpful in your research and would like to cite it, please use the following BibTeX entry.

@misc{torchshard2021,
  author =       {Kaiyu Yue},
  title =        {TorchShard},
  howpublished = {\url{https://github.com/KaiyuYue/torchshard}},
  year =         {2021}
}

Comments

  • Future planning for this project.

    Hello Kaiyu, I love this awesome project. The API design is elegant and simple and the software is lightweight and user-friendly. My understanding is that this project has realized a series of PyTorch wrappers for tensor slicing.

    1. I am curious about the future plans for this project.
    2. Is there some overlap in functionality between torchshard and the N-D parallelism proposed in ColossalAI?
    3. How compatible is it with ZeRO? According to the am+zero example, the memory footprint changes only a little after combining torchshard with ZeRO.
    opened by feifeibear 2
  • Which one is faster?

    Thanks for contributing this great lib. I have one question. Which one is faster (in speed), dim=0 or dim=1? The documentation seems to contain only accuracy results.

    opened by NOBLES5E 2
  • 8-GPU test example raises an error.

    The unit tests pass when I use two GPUs with the command below:

    CUDA_VISIBLE_DEVICES=0,1 python3 -m unittest discover -v -s tests

    But when I run the unit tests with eight GPUs, they raise ncclSystemError. Command:

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m unittest discover -v -s tests

    Error:

    RuntimeError: NCCL error in ../torch/lib/c10d/ProcessGroupNCCL.cpp:825, unhandled system error, NCCL version 2.7.8 ncclSystemError: System call (socket, malloc, munmap, etc) failed.

    Is it necessary for the unit tests to pass on eight GPUs?

    opened by JiaquanYe 1
  • Error?

    Hi, thanks for the excellent work! When I install it from pip and run

    import torchshard as ts
    ts.init_process_group(group_size=2) 
    

    The AttributeError occurs:

    AttributeError: module 'torchshard' has no attribute 'init_process_group'
    
    opened by WangWenhao0716 1
  • Multi-node setting?

    https://github.com/KaiyuYue/torchshard/blob/89e21def180bf6063ceb2e312a61631173abc7e7/projects/minGPT/main.py#L150

    I noticed that group_size is set to world_size in the examples, but as I understand it, group_size can be set to other values.

    https://github.com/KaiyuYue/torchshard/blob/main/torchshard/distributed/core.py#L18

    I also found that get_world_size() returns the total number of processes.

    These two findings confuse me in a multi-node setting, say 2 nodes with 2 processes each.

    If group_size is 2, there are 2 distinct groups besides the default group (with overlap). However, calling get_world_size() without specifying a group can cause a layer to be split into 4 parts, whereas 2 is expected in this case.

    Correct me if I am wrong.

    Good Issue 
    opened by GeneZC 1
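
The scenario in the multi-node question above (2 nodes, 2 processes each, group_size of 2) can be worked through in a few lines of plain Python. This is only an illustration of the arithmetic: the contiguous rank grouping and both helper functions are assumptions made for the sketch, not TorchShard's confirmed behavior.

```python
def build_groups(world_size, group_size):
    """Partition ranks 0..world_size-1 into contiguous groups (an assumed layout)."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [list(range(start, start + group_size))
            for start in range(0, world_size, group_size)]

def shard_width(features, num_parts):
    """Number of features each part holds when a layer is split evenly."""
    if features % num_parts != 0:
        raise ValueError("features must be divisible by num_parts")
    return features // num_parts

world_size = 4    # 2 nodes x 2 processes
group_size = 2    # intended model-parallel width

groups = build_groups(world_size, group_size)
print(groups)

# Splitting by the size of the model-parallel group gives 2 parts per layer.
per_group = shard_width(30, len(groups[0]))
# Splitting by the global world size gives 4 parts, which is the concern raised.
per_world = shard_width(28, world_size)
print(per_group, per_world)
```

The difference between the two calls is the point of the question: the divisor should be the size of the model-parallel group, not the total number of processes.
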
  • Is it possible to collect the state dict on the CPU?

    When one training epoch finishes, the main_worker function calls ts.collect_state_dict(model, state_dict). Because GPU memory is limited on my machine, this call raises an out-of-memory error. It appears that the state_dict is gathered on the GPU; is there any way to gather it on the CPU instead?

    Good Issue 
    opened by JiaquanYe 2
Releases: v0.1
Owner: Kaiyu Yue