Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Overview

Elegy


Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface that you can use by following these steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:

import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

import jax
import jax.numpy as jnp
import numpy as np

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x, # inputs
        y_true, # labels
        states: elegy.States, # model state
        initializing: bool, # if True we should initialize our parameters
    ):  
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))

2. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
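
To see what the test_step above computes without any Elegy machinery, the same arithmetic can be sketched as a standalone JAX function. This is only an illustration on random dummy data: the function name linear_classifier_loss and the input shapes are assumptions, and the gradient call shows the kind of quantity an optimizer would consume, not how Elegy itself wires things together.

```python
import jax
import jax.numpy as jnp


def linear_classifier_loss(params, x, y_true, num_classes=10):
    """Hypothetical standalone version of the loss/accuracy math in test_step."""
    w, b = params
    # flatten + scale, as in test_step
    x = jnp.reshape(x, (x.shape[0], -1)) / 255
    logits = jnp.dot(x, w) + b
    # categorical crossentropy on one-hot labels
    labels = jax.nn.one_hot(y_true, num_classes)
    loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
    return loss, accuracy


# dummy data standing in for a batch of 28x28 images (an assumption)
key_w, key_b, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.uniform(key_x, (8, 28, 28), maxval=255.0)
y_true = jax.random.randint(key_y, (8,), 0, 10)
w = jax.random.uniform(key_w, shape=(28 * 28, 10))
b = jax.random.uniform(key_b, shape=(1,))

loss, accuracy = linear_classifier_loss((w, b), x, y_true)

# gradients of the loss with respect to (w, b); accuracy is returned as aux output
grads, _ = jax.grad(linear_classifier_loss, has_aux=True)((w, b), x, y_true)
```

The gradients come back with the same structure as the parameters, a tuple matching (w, b).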

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API:

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        net_states, net_params = variables.pop("params")
        
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)

More Info

Examples

To run the examples first install some required packages:

pip install -r examples/requirements.txt

Now run the example:

python examples/flax_mnist_vae.py 

Contributing

Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information check out our Contributing Guide.

About Us

We are some friends passionate about ML.

License

Apache

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.5.0},
year = {2020},
}

Where the current version may be retrieved either from the Release tag or the file elegy/__init__.py and the year corresponds to the project's release year.

Comments
  • Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks.

    As noted below, this PR contains the following features:

    • It turns Elegy into a framework-agnostic library by removing the dependencies between elegy.Model and elegy.Module. It proposes the GeneralizedModule API and implements it for Flax, Haiku, Elegy Module types, and regular Python functions.
    • It introduces a new low-level API similar to Pytorch Lightning that lets users manually override the core parts of the training loop when maximal flexibility is required.
    • General changes that enable the framework-agnostic mindset.
    • Many quality of life changes like standardization of hooks, simplification of the Module system, etc.

    Tasks:

    • [x] Create hooks module
    • [x] Refactor Model with low-level API and remove Module dependencies
    • [x] Refactor Module to use hooks
    • [x] Create GeneralizedModule and GeneralizedOptimizer Interfaces
    • [x] Implement GeneralizedModule for flax.linen.Module
    • [x] Implement GeneralizedModule for elegy.Module
    • [x] Implement GeneralizedModule for haiku.Module
    • [x] Implement GeneralizedOptimizer for optax.GradientTransformation
    • [x] Implement GeneralizedOptimizer for elegy.Optimizer
    • [x] Fix Model.summary
    • [x] Fix tests
    • [x] Fix examples
    • [ ] Fix README
    • [ ] Fix guides
    • [ ] Fix docstrings
    opened by cgarciae 27
  • WGAN-GP low-level API example

    A more extensive example using the new low-level API: Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.

    Some good generated images: epoch-0079 epoch-0084 epoch-0089

    Some notes:

    • I first tried to train a DCGAN which uses binary crossentropy but I've run into balancing issues. The discriminator quickly becomes too good so that the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore, most use BCE.
    • I'm still in favor of making Module.apply() return init(). It's just too much boilerplate to use an if-else every time. I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation which I think is also not nice.
    • Can we make Module.apply() accept params and states separately instead of collections? It's annoying having to construct a dict {'params':params, 'states':states} every time.
    • It would be nice if elegy.States was a dict so that users can decide for themselves what to put into it. With GANs, where you have to manage generator and discriminator states separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.
    • Model.save() fails on this model. Partially due to the extra jitted functions but even when I remove them, cloudpickle chokes on _HooksContext

    @cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

    opened by alexander-g 11
  • Add learning rate logging

    Implements the same functionality from #131 using only minor modifications to elegy.Optimizer.

    • [x] Add lr_schedule and steps_per_epoch to Optimizer.
    • [x] Implement Optimizer.get_effective_learning_rate
    • [x] Copy logging code from #131
    • [x] Add documentation

    @alexander-g Here is a proposal that is a bit simpler, closer to what I mentioned in #124. What do you think? @charlielito should we log the learning rate automatically if available or should we create a Callback?

    opened by cgarciae 9
  • Question: how to set the random state when calling model.predict(...)

    Not sure if this is the right place to post this...

    I have built and trained a VAE. When calling model.predict(x=test_set), I would like to make multiple predictions for each item in the test set (because VAEs are probabilistic). That way I can look at the distribution of predictions for each item in the test_set.

    The call() for the VAE includes the line
    intrinsic_latents = mean + stds * jax.random.normal(self.next_key(), mean.shape).

    I haven't been able to find an explanation for how self.next_key() works or how to change the random seed on each call so that I can get different predictions. I could rewrite the code so that random seeds are explicitly passed, but I assume there is some functionality built into elegy to make this easy?

    Could someone explain how this works, or point me to the documentation explaining it?

    Thanks!

    opened by jfcrenshaw 8
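
As background for the question above, here is a minimal sketch of how explicit key handling works in plain JAX, independent of Elegy. It does not show what self.next_key() does internally; the helper name sample_latents and the shapes are assumptions. The point is only that the same key always reproduces the same noise, while splitting a key yields independent samples.

```python
import jax
import jax.numpy as jnp


def sample_latents(key, mean, stds):
    """Reparameterized sample, mirroring the line quoted in the question."""
    return mean + stds * jax.random.normal(key, mean.shape)


mean = jnp.zeros((4, 2))  # dummy latent means for 4 items (an assumption)
stds = jnp.ones((4, 2))   # dummy latent standard deviations

# split one seed into 5 independent keys, one per repeated prediction
keys = jax.random.split(jax.random.PRNGKey(42), 5)
samples = jnp.stack([sample_latents(k, mean, stds) for k in keys])

# reusing a key reproduces the exact same sample
repeat = sample_latents(keys[0], mean, stds)
```

Here samples has one leading entry per key, so the spread across that axis gives a per-item distribution of draws.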
  • Examples Cleanup

    • refactored examples/imagenet/resnet_imagenet.py to accept parameters instead of modifying them inside the script
    • added README.md for examples/imagenet/
    • removed unnecessary Lambda class from examples/mnist.py
    • moved global average pooling in examples/mnist_conv.py before the Linear layer
    opened by alexander-g 7
  • Resnet

    • ResNet model architecture and an example for training on ImageNet
      • code is mostly adapted from the flax library
      • pretrained ResNet50 with 76.5% accuracy
      • pretrained ResNet18 with 68.7% accuracy
    • Experimental support for mixed precision: previously all layers set their parameters' dtype to the input's dtype. This is incorrect: for numerical stability reasons, all parameters should be float32 even when performing float16 computations. See more here.
    • Some issues I had during training:
      • There seems to be a memory leak during training, RAM constantly increased
      • I had to use smaller batch sizes than when training with flax or with TensorFlow before maxing out GPU memory (64 instead of 128 for ResNet50 on a RTX2080Ti). This might be of course due to a mistake in my code, but the number of parameters is identical to the flax and PyTorch versions, so I think it might be somewhere else
    opened by alexander-g 7
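
The mixed-precision point above can be illustrated with a small JAX sketch: parameters are stored in float32 and only cast to float16 for the computation. This is not the code from the PR; the function dense_mixed_precision and the layer sizes are hypothetical.

```python
import jax
import jax.numpy as jnp


def dense_mixed_precision(params, x, compute_dtype=jnp.float16):
    """Hypothetical dense layer: float32 storage, float16 computation."""
    w, b = params
    # cast copies of the parameters for the forward pass only;
    # the stored parameters themselves stay float32
    w_c = w.astype(compute_dtype)
    b_c = b.astype(compute_dtype)
    return jnp.dot(x.astype(compute_dtype), w_c) + b_c


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (16, 4), dtype=jnp.float32) * 0.1
b = jnp.zeros((4,), dtype=jnp.float32)
x = jnp.ones((2, 16), dtype=jnp.float16)

y = dense_mixed_precision((w, b), x)
```

The output is float16 while w and b remain float32, which is the separation the note above asks for.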
  • [Bug] Problem with computing metrics

    Describe the bug Hi, when I use the fit function I get an error message saying that the update function is not provided with y_true and y_pred. It seems to come from the model's metrics, because the error disappears if I comment out the metrics line.

    TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'
    

    Minimal code to reproduce

    import jax
    import jax.numpy as jnp
    import ml_collections
    import numpy as np
    import optax
    import elegy as eg
    
    
    class eCNN(eg.Module):
        """A simple CNN model."""
    
        @eg.compact
        def __call__(self, x):
            x=eg.Conv(10,kernel_size=(10,))(x)
            x=jax.nn.relu(x)
            x = eg.Linear(1)(x)
            x=jax.nn.sigmoid(x)
            return x
    
    n=200
    X_train = np.random.rand(n*100).reshape(n,100)
    y_train = np.random.rand(n).reshape(n,1)
    print(X_train.shape)
    print(y_train.shape)
    
    model = eg.Model(
        module=eCNN(),
        loss=[
            eg.losses.MeanSquaredError(),
        ],
        metrics=eg.metrics.MeanSquareError(),  #Line to be commented to get rid of the error
        optimizer=optax.rmsprop(1e-3),
    )
    
    model.fit(X_train,y_train,
        epochs=10,
        batch_size=20,
        #validation_data=0.1,
        shuffle=False,
        callbacks=[eg.callbacks.TensorBoard("summaries")]
        )
    

    Library Info

    import elegy
    print(elegy.__version__) 
    # 0.8.4
    
    bug 
    opened by organic-chemistry 6
  • Multi-gpu with pmap docs

    One of the selling points of jax is the pmap transformation, but best practices around actually making your training loop parallelizable are still confusing. What is elegy's story around multi-GPU training? Is it possible to get a pytorch-lightning-like API, as a single arg to model.fit?

    opened by sooheon 6
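
For context on the question above, this is a rough sketch of the data-parallel pattern that pmap enables in plain JAX: each device computes gradients on its own shard and the results are averaged with pmean. It says nothing about how Elegy exposes this; the names loss_fn and parallel_grads and the toy regression problem are assumptions.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # mean squared error of a linear model
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def parallel_grads(w, x, y):
    # per-device gradients on the local shard of the batch
    grads = jax.grad(loss_fn)(w, x, y)
    # average gradients across all devices
    return jax.lax.pmean(grads, axis_name="devices")


n = jax.local_device_count()
w = jnp.zeros((3,))
w_rep = jnp.stack([w] * n)   # replicate parameters, one copy per device
x = jnp.ones((n, 4, 3))      # leading axis is the device axis
y = jnp.ones((n, 4))

grads = parallel_grads(w_rep, x, y)
```

The sketch also runs on a single CPU device, in which case the leading device axis simply has length 1.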
  • SCCE fix for bug in Jax<0.2.7

    Small fix for a bug in Jax<0.2.7 where jax.lax.take_along_axis gives incorrect results for uint8 indices. Very relevant for semantic segmentation.

    Alternatively consider updating Jax

    opened by alexander-g 6
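
A workaround of the kind described above can be sketched as follows: cast the integer labels to int32 before gathering the log-probabilities. This is an illustrative standalone function, not the code from the PR, and the function name is an assumption.

```python
import jax
import jax.numpy as jnp


def sparse_categorical_crossentropy(y_true, logits):
    """Hypothetical SCCE that casts labels before indexing."""
    # cast narrow integer labels (e.g. uint8 segmentation masks) to int32
    y_true = y_true.astype(jnp.int32)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # pick the log-probability of the true class for every sample
    picked = jnp.take_along_axis(log_probs, y_true[..., None], axis=-1)
    return -jnp.squeeze(picked, axis=-1)


logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
y_true = jnp.array([0, 2], dtype=jnp.uint8)

losses = sparse_categorical_crossentropy(y_true, logits)

# reference computed through one-hot labels for comparison
reference = -jnp.sum(
    jax.nn.one_hot(y_true, 3) * jax.nn.log_softmax(logits), axis=-1
)
```

Comparing against the one-hot formulation is a cheap way to check that the indexing path gives the expected values.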
  • Dataset & DataLoader

    Dataset and parallel DataLoader API similar to PyTorch. Can be used with Model.fit()

    class MyDataset(elegy.data.Dataset):
        def __len__(self):
            return 128
    
        def __getitem__(self, i):
            #dummy data
            return np.random.random([224, 224, 3]),  np.random.randint(10)
    
    ds     = MyDataset()
    loader = elegy.data.DataLoader(ds, batch_size=8, n_workers=8, worker_type='thread', shuffle=True)
    
    batch = next(iter(loader))
    assert batch[0].shape == (8,224,224,3)
    assert batch[1].shape == (8,)
    assert len(loader) == 16
    
    model.fit(loader, epochs=10)
    
    opened by alexander-g 6
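
The assertions in the example above follow from simple batching arithmetic, which the following NumPy-only sketch reproduces. It is not the elegy.data implementation (no workers, no prefetching); iterate_batches and get_item are hypothetical names.

```python
import math

import numpy as np


def get_item(i):
    # dummy (image, label) pair, as in the dataset example
    return np.random.random([224, 224, 3]), np.random.randint(10)


def iterate_batches(dataset_len, get_item, batch_size, shuffle=False, seed=0):
    """Hypothetical single-threaded batching loop."""
    indices = np.arange(dataset_len)
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, dataset_len, batch_size):
        samples = [get_item(i) for i in indices[start:start + batch_size]]
        # turn a list of (image, label) pairs into (images, labels) arrays
        yield tuple(np.stack(field) for field in zip(*samples))


first_batch = next(iterate_batches(128, get_item, batch_size=8, shuffle=True))
n_batches = math.ceil(128 / 8)  # number of batches per epoch
```

With 128 items and a batch size of 8 this yields 16 batches, each stacking 8 images and 8 labels.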
  • Implemented BinaryCrossentropy metric

    Updates:

    • Created BinaryCrossentropy metric
    • Created basic tests for BinaryCrossentropy metric (passing)
    • Created docs for BinaryCrossentropy metric
    • Refactored main docs by balancing files and correcting language typos
    documentation 
    opened by sebasarango1180 6
  • use poetry-core

    poetry-core is intended to be a light weight, fully compliant, self-contained package allowing PEP 517 compatible build frontends to build Poetry managed projects.

    Using poetry-core allows distribution packages to depend only on the build backend.

    opened by dotlambda 0
  • Documentation/API reference not accessible via project website[Bug]

    Hi, It looks like I can't really access the API reference for Elegy. The corresponding link on the project's website simply takes me back to the main page (https://poets-ai.github.io/elegy/). Any idea what's up?

    bug 
    opened by geomlyd 0
  • [Bug] elegy does not work with latest haiku version

    Describe the bug When I type 'import elegy' I get this error

     File "/home/kpmurphy/mambaforge/lib/python3.10/site-packages/elegy/generalized_module/haiku_module.py", line 4, in <module>
        from haiku._src.base import current_bundle_name
    

    Minimal code to reproduce

    import elegy
    

    Library Info

    >> 
    >>> jax.__version__
    '0.2.28'
    >>> haiku.__version__
    '0.0.9.dev'
    >>> elegy.__version__. #  elegy-0.5.0-py3-none-any.whl 
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NameError: name 'elegy' is not defined
    >>> 
    

    bug 
    opened by murphyk 5
  • CSVLogger iteration over a 0-d array

    Describe the bug When using the CSVLogger callback, elegy crashes at the end of the first epoch.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    x = np.random.randn(64, 1)
    y = np.random.randn(64, 1)
    
    model = eg.Model(
        eg.nn.Linear(1),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        x,
        y,
        epochs=10,
        callbacks=[
            eg.callbacks.CSVLogger("train.csv"),  # <-- commenting out this line avoids the crash
        ]
    )
    

    Stack trace:

    Epoch 1/10
    2/2 [==============================] - ETA: 0s - loss: 1.3408 - mean_squared_error_loss: 1.3408
    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/csv.py", line 14, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 465, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/callback_list.py", line 221, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in on_epoch_end
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in <genexpr>
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 68, in handle_value
        return '"[%s]"' % (", ".join(map(str, k)))
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py", line 245, in __iter__
        raise TypeError("iteration over a 0-d array")  # same as numpy error
    TypeError: iteration over a 0-d array
    

    Expected behavior Should treat 0-d array as scalar.

    Library Info python version: 3.8.13, elegy version: 0.8.6, treex version: 0.6.10

    Additional context More detailed error information shows that the error occurs because the value is a jax DeviceArray rather than a np.ndarray, so the zero-dimensional check evaluates to False. The check uses the line

    is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
    
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py:6 │
    │ 8 in handle_value                                                                                │
    │                                                                                                  │
    │    65 │   │   │   if isinstance(k, six.string_types):                                            │
    │    66 │   │   │   │   return k                                                                   │
    │    67 │   │   │   elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:                   │
    │ ❱  68 │   │   │   │   return '"[%s]"' % (", ".join(map(str, k)))                                 │
    │    69 │   │   │   else:                                                                          │
    │    70 │   │   │   │   return k                                                                   │
    │    71                                                                                            │
    │                                                                                                  │
    │ ╭──────────────────────────── locals ─────────────────────────────╮                              │
    │ │ is_zero_dim_ndarray = False                                     │                              │
    │ │                   k = DeviceArray(4.8264385e-05, dtype=float32) │                              │
    │ ╰─────────────────────────────────────────────────────────────────╯                              │
    │                                                                                                  │
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py:245 in │
    │ __iter__                                                                                         │
    │                                                                                                  │
    │   242                                                                                            │
    │   243   def __iter__(self):                                                                      │
    │   244 │   if self.ndim == 0:                                                                     │
    │ ❱ 245 │     raise TypeError("iteration over a 0-d array")  # same as numpy error                 │
    │   246 │   else:                                                                                  │
    │   247 │     return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())            │
    │   248                                                                                            │
    │                                                                                                  │
    │ ╭───────────────────── locals ─────────────────────╮                                             │
    │ │ self = DeviceArray(4.8264385e-05, dtype=float32) │                                             │
    │ ╰──────────────────────────────────────────────────╯                                             │
    ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
    TypeError: iteration over a 0-d array
    
    bug 
    opened by ScottAlexanderCameron 0
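
One possible way to get the expected behavior described above is to test the number of dimensions directly instead of the array type. The sketch below is a suggestion, not the fix adopted by the project; it rewrites handle_value as a standalone function and uses a hypothetical ZeroDimArray class as a stand-in for a 0-d device array so that it runs without JAX.

```python
import collections.abc

import numpy as np


def handle_value(k):
    """Hypothetical variant of handle_value that treats any 0-d array as a scalar."""
    if isinstance(k, str):
        return k
    # np.ndim reads the object's own .ndim attribute when it has one,
    # so this also covers array types that are not np.ndarray
    is_zero_dim = np.ndim(k) == 0
    if isinstance(k, collections.abc.Iterable) and not is_zero_dim:
        return '"[%s]"' % ", ".join(map(str, k))
    return k


class ZeroDimArray:
    """Stand-in for a 0-d device array: looks iterable, but iteration fails."""

    ndim = 0

    def __init__(self, value):
        self.value = value

    def __iter__(self):
        raise TypeError("iteration over a 0-d array")


scalar_like = ZeroDimArray(4.8264385e-05)
result_scalar = handle_value(scalar_like)   # returned unchanged, no iteration
result_list = handle_value([1, 2, 3])       # formatted as a quoted list
result_text = handle_value("loss")          # strings pass through
```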
  • Metrics ignore "on" keyword arg

    Describe the bug I have an application where I need to output multiple values from a network, which I am doing using a dictionary and using the on keyword argument. This works fine for the loss functions but not for metrics.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    
    def data_generator():
        while True:
            yield (
                np.random.randn(10, 1),
                {"target": {"y": np.random.randn(10, 1)}},
            )
    
    
    class MyModule(eg.Module):
        @eg.compact
        def __call__(self, x):
            return {"y": eg.nn.Linear(1)(x)}
    
    
    model = eg.Model(
        MyModule(),
        loss=eg.losses.MeanSquaredError(on="y"),
        metrics=eg.metrics.MeanAbsoluteError(on="y"),  #  <-- works fine without this line
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        data_generator(),
        steps_per_epoch=10,
        epochs=10,
    )
    

    Stack trace:

    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
        tmp_logs = self.train_on_batch(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
        logs, model = train_step_fn(self, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
        return model.train_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
        grads, (logs, model) = grad_fn(params, model, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
        loss, logs, model = model.test_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
        batch_loss_and_logs.update(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
        self.metrics.update(**metrics_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
        metric.update(**metric_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
        values = _mean_absolute_error(preds, target)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
        target = target.astype(preds.dtype)
    AttributeError: 'dict' object has no attribute 'astype'
    

    Expected behavior Should produce the same result as if the dictionaries are removed and the on arg not specified.

    Library Info python version: 3.8.13, elegy version: 0.8.6, treex version: 0.6.10

    Additional context From my digging the cause seems to be due to the Metric.update() method being called instead of the __call__ method.

    bug 
    opened by ScottAlexanderCameron 0
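
The expected behavior in the report above amounts to indexing both the target and the predictions with the on key before computing the metric. The following NumPy sketch shows that idea only; select_on and mean_absolute_error are hypothetical helpers and do not reflect how treex implements on.

```python
import numpy as np


def select_on(value, on):
    """Hypothetical helper: follow `on` (a key or sequence of keys) into a nested structure."""
    if on is None:
        return value
    keys = on if isinstance(on, (list, tuple)) else [on]
    for key in keys:
        value = value[key]
    return value


def mean_absolute_error(target, preds, on=None):
    # index into both structures first, then compute the metric on arrays
    target = np.asarray(select_on(target, on))
    preds = np.asarray(select_on(preds, on))
    return np.mean(np.abs(preds - target.astype(preds.dtype)))


target = {"y": np.array([[1.0], [2.0]])}
preds = {"y": np.array([[1.5], [1.0]])}

mae_dict = mean_absolute_error(target, preds, on="y")
mae_plain = mean_absolute_error(target["y"], preds["y"])
```

Both calls give the same value, which is the equivalence the report asks for.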
  • [Bug] Elegy crash on GPU

    Describe the bug

    Running mnist_cnn.py in the example dir crashes the instance at the end of the first epoch.

    This was previously reported on Colab GPU instance. But I can reproduce this on CLI too.

    Running on CPU does not have this problem.

    Running on eager mode with GPU does not have this problem.

    Minimal code to reproduce

    python mnist_cnn.py
    

    Expected behavior Not stuck.

    Library Info CentOS Linux release 7.6.1810 elegy 0.8.6

    Additional context

    absl-py==1.2.0
    aiohttp==3.8.1
    aiosignal==1.2.0
    async-timeout==4.0.2
    attrs==22.1.0
    certifi==2021.10.8
    charset-normalizer==2.1.1
    chex==0.1.4
    click==8.1.3
    cloudpickle==1.6.0
    colorama==0.4.5
    commonmark==0.9.1
    cycler==0.11.0
    datasets==2.4.0
    dill==0.3.5.1
    dm-tree==0.1.7
    docker-pycreds==0.4.0
    einops==0.4.1
    elegy==0.8.6
    etils==0.7.1
    filelock==3.8.0
    flax==0.4.2
    fonttools==4.36.0
    frozenlist==1.3.1
    fsspec==2022.7.1
    gitdb==4.0.9
    GitPython==3.1.27
    h5py==3.6.0
    huggingface-hub==0.8.1
    idna==3.3
    importlib-resources==5.9.0
    jax==0.3.16
    jaxlib==0.3.15+cuda11.cudnn82
    kiwisolver==1.4.4
    matplotlib==3.5.3
    msgpack==1.0.4
    multidict==6.0.2
    multiprocess==0.70.13
    numpy==1.22.3
    opt-einsum==3.3.0
    optax==0.1.3
    packaging==21.3
    pandas==1.4.3
    pathtools==0.1.2
    Pillow==9.2.0
    promise==2.3
    protobuf==3.20.1
    psutil==5.9.1
    pyarrow==9.0.0
    Pygments==2.13.0
    pyparsing==3.0.9
    python-dateutil==2.8.2
    pytz==2022.2.1
    PyYAML==6.0
    requests==2.28.1
    responses==0.18.0
    rich==11.2.0
    scipy==1.8.0
    sentry-sdk==1.9.5
    setproctitle==1.3.2
    shortuuid==1.0.9
    six==1.16.0
    smmap==5.0.0
    tensorboardX==2.5.1
    toolz==0.12.0
    tqdm==4.64.0
    treeo==0.0.10
    treex==0.6.10
    typing_extensions==4.3.0
    urllib3==1.26.11
    wandb==0.12.21
    xxhash==3.0.0
    yarl==1.8.1
    zipp==3.8.1

    bug 
    opened by jiyuuchc 2
Releases(0.8.6)
  • 0.8.6(Mar 23, 2022)

    🚀 Features

    • Weights and Biases Callback for Elegy
      • PR: #220

    🐛 Fixes

    • Docs typos
      • PR: #222
    • Donate model's memory buffer to jit/pmap functions.
      • PR: #226
    • Lazy load wandb
      • PR: #228
  • 0.8.5(Feb 23, 2022)

  • 0.8.4(Dec 14, 2021)

  • 0.8.3(Dec 13, 2021)

  • 0.8.2(Dec 13, 2021)

  • 0.8.1(Nov 8, 2021)

    Elegy is now based on Treex 🎉

    Changes

    • Removes the module, nn, metrics, and losses modules from Elegy; Elegy now re-exports them from Treex.
    • GeneralizedModule and friends are gone; to use Flax Modules, use the elegy.nn.FlaxModule wrapper.
    • The low-level API is massively simplified:
      • States is removed; since Model is a pytree, all parameters are tracked automatically thanks to Treex / Treeo.
      • All static state arguments (training, initializing) are removed; Modules can simply use self.training to pick their training state and self.initializing() to check whether they are initializing.
      • The signature for pred_step, test_step, and train_step now simply consists of inputs and labels, where labels is a dict that can contain additional keys like sample_weight or class_weight as required by the losses and metrics.
    • Adds the DistributedStrategy class which currently has 3 instances
      • Eager: Runs model in a single device in eager mode (no jit)
      • JIT: Runs model in a single device with jit
      • DataParallel: Run the model in multiple devices using pmap.
    • Adds methods to change the model's distributed strategy:
      • .distributed(strategy = DataParallel): changes the distributed strategy, DataParallel used by default.
      • .local(): changes the distributed strategy to JIT.
      • .eager(): changes the distributed strategy to Eager.
    • Removes the .eager field in favor of the .eager() method.
  • 0.7.4(Jun 1, 2021)

  • 0.7.2(Mar 10, 2021)

  • 0.7.1(Mar 1, 2021)

  • 0.7.0(Feb 22, 2021)

    Features

    • init is now only called once internally and must be called explicitly by the user under certain circumstances
    • summary now uses jax.eval_shape under the hood, so it's very fast since it doesn't allocate memory or perform any computations on the device.
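
To illustrate why jax.eval_shape is cheap, here is a small standalone sketch: the function is traced with abstract shape/dtype descriptions only, so no parameter or input arrays are ever created. The mlp function and its layer sizes are assumptions for illustration and are unrelated to how Model.summary is implemented.

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    """Hypothetical two-layer network used only to demonstrate shape inference."""
    w1, w2 = params
    return jax.nn.relu(x @ w1) @ w2


# abstract descriptions of the inputs: shapes and dtypes, no data
params_spec = (
    jax.ShapeDtypeStruct((784, 300), jnp.float32),
    jax.ShapeDtypeStruct((300, 10), jnp.float32),
)
x_spec = jax.ShapeDtypeStruct((64, 784), jnp.float32)

# returns a ShapeDtypeStruct describing the output, without running the model
out = jax.eval_shape(mlp, params_spec, x_spec)
```

The result only carries shape and dtype information, which is all a summary table needs.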

    Merged pull requests:

    • Fix notebook #166 (cgarciae)
    • Single Initialization: Removes the current progressive initialization in favor of a single underlying call to init_step. #165 (cgarciae)
  • 0.6.0(Feb 14, 2021)

  • 0.5.0(Feb 8, 2021)

    This version simplifies parts of the low-level API in the spirit of what was introduced in 0.4.0, to provide a simpler and more homogeneous experience.

    Merged pull requests:

    • Improve States: uses __dict__ so States works with vars #159 (cgarciae)
    • Simplify API: Cleans-up some API details around Model and Module #158 (cgarciae)
  • 0.4.1(Feb 3, 2021)

  • 0.4.0(Feb 1, 2021)

    Implemented enhancements:

    • [Feature Request] Monitoring learning rates #124

    Merged pull requests:

  • 0.3.0(Dec 17, 2020)

    Implemented enhancements:

    • elegy.nn.Sequential docs not clear #107
    • [Feature Request] Community example repo. #98

    Fixed bugs:

    • [Bug] Accuracy from Model.evaluate() is inconsistent with manually computed accuracy #109
    • Exceptions in "Getting Started" colab notebook #104

    Closed issues:

    • l2_normalize #102
    • Need some help for contributing new losses. #93
    • Document Sum #62
    • Binary Accuracy Metric #58
    • Automate generation of API Reference folder structure #19
    • Implement Model.summary #3

    Merged pull requests:

  • 0.2.2(Aug 31, 2020)

  • 0.2.1(Aug 25, 2020)

  • 0.2.0(Aug 17, 2020)

  • 0.1.5(Jul 28, 2020)

    • Mean Absolute Percentage Error Implementation @Ciroye
    • Adds elegy.nn.Linear, elegy.nn.Conv2D, elegy.nn.Flatten, elegy.nn.Sequential @cgarciae
    • Add Elegy hooks @cgarciae
    • Improves Tensorboard support @Davidnet
    • Added coverage metrics to CI @charlielito
  • 0.1.4(Jul 24, 2020)

    • Adds elegy.metrics.BinaryCrossentropy @sebasarango1180
    • Adds elegy.nn.Dropout and elegy.nn.BatchNormalization @cgarciae
    • Improves documentation
    • Fixes a bug that caused an error when using is_training via dependency injection on Model.predict.
  • 0.1.3(Jul 23, 2020)
