Convert ONNX models to PyTorch.

Overview

onnx2torch

onnx2torch is an ONNX to PyTorch converter. Our converter:

  • Is easy to use – convert an ONNX model with a single call to convert;
  • Is easy to extend – write your own custom layer in PyTorch and register it with @add_converter;
  • Can convert back to ONNX – export the converted model back to ONNX with the torch.onnx.export function.

If you find an issue, please let us know! Feel free to open pull requests, too.

Please note that this converter covers only a limited number of PyTorch / ONNX models and operations.
Let us know which models you use or want to convert from ONNX to PyTorch.

Installation

From PyPI

pip install onnx2torch

Usage

Below are some usage examples.

Convert

import onnx
import torch
from onnx2torch.converter import convert

# Path to ONNX model
onnx_model_path = '/some/path/mobile_net_v2.onnx'
# You can pass the path to the ONNX model to convert it...
torch_model_1 = convert(onnx_model_path)

# ...or you can load a regular ONNX model and pass it to the converter
onnx_model = onnx.load(onnx_model_path)
torch_model_2 = convert(onnx_model)

Execute

You can execute the converted PyTorch model the same way as any other PyTorch model.

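import numpy as np
import onnxruntime as ort
import torch

# Create example data (MobileNetV2 expects 3-channel 224x224 images)
x = torch.ones((1, 3, 224, 224))

# Run the converted model on CPU without tracking gradients
torch_model_1.eval()
with torch.no_grad():
    out_torch = torch_model_1(x).numpy()

ort_sess = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
# Read the input name from the session instead of hard-coding 'input'
input_name = ort_sess.get_inputs()[0].name
outputs_ort = ort_sess.run(None, {input_name: x.numpy()})[0]

# Check the ONNX Runtime output against PyTorch
print(np.max(np.abs(outputs_ort - out_torch)))
print(np.allclose(outputs_ort, out_torch, atol=1.e-7))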

Models

We have tested the following models:

  • ResNet50
  • SSDLite with MobileNetV2 backbone

How to add new operations to the converter

Here we show how to add converters for two kinds of operations:

  1. Operations supported by both PyTorch and ONNX with the same behaviour.
    An example of such an operation is Relu:
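from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node

@add_converter(operation_type='Relu', version=6)
@add_converter(operation_type='Relu', version=13)
@add_converter(operation_type='Relu', version=14)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
    return OperationConverterResult(
        torch_module=nn.ReLU(),
        onnx_mapping=onnx_mapping_from_node(node=node),
    )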

Here we have registered a converter for the Relu operation for opset versions 6, 13 and 14.
Note that the torch_module argument of OperationConverterResult must be a torch.nn.Module, not just a callable object!
If an operation's behaviour differs from one opset version to another, you should implement a separate converter for each behaviour.

  2. Operations supported by both PyTorch and ONNX but with different behaviour, such as Expand:
class OnnxExpand(nn.Module):

    @staticmethod
    def _do_forward(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
        return input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device)

    def forward(self, *args) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            with skip_torch_tracing():
                output = self._do_forward(*args)
                return _ExpandExportToOnnx.set_output_and_apply(output, *args)

        return self._do_forward(*args)


class _ExpandExportToOnnx(CustomExportToOnnx):

    @staticmethod
    def symbolic(graph: torch_C.Graph, *args, **kwargs) -> torch_C.Value:
        return graph.op('Expand', *args, **kwargs, outputs=1)


@add_converter(operation_type='Expand', version=8)
@add_converter(operation_type='Expand', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
    return OperationConverterResult(
        torch_module=OnnxExpand(),
        onnx_mapping=onnx_mapping_from_node(node=node),
    )

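Here we used a trick to make the converted model exportable from PyTorch back to ONNX: the custom _ExpandExportToOnnx class defines a symbolic method that emits the ONNX Expand operator, and OnnxExpand routes through it whenever torch.onnx.is_in_onnx_export() is true.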
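Why does OnnxExpand multiply by a tensor of ones instead of calling torch.Tensor.expand? ONNX Expand uses bidirectional, NumPy-style broadcasting: the output shape is the broadcast of the input shape and the requested shape, so the requested shape may have fewer dimensions than the input or contain 1 where the input is larger. torch.Tensor.expand only stretches size-1 dimensions (and can prepend new ones), so it rejects such shapes. The standalone sketch below mirrors the idea of _do_forward; the function name onnx_expand is our own and is not part of onnx2torch.

```python
import torch


def onnx_expand(input_tensor: torch.Tensor, shape) -> torch.Tensor:
    # Multiplying by ones of the target shape lets PyTorch's NumPy-style
    # broadcasting produce broadcast(input.shape, shape), as ONNX Expand does.
    target = tuple(int(d) for d in shape)  # accepts a list or a 1-D tensor
    ones = torch.ones(target, dtype=input_tensor.dtype, device=input_tensor.device)
    return input_tensor * ones


x = torch.arange(3.0).reshape(3, 1)  # shape (3, 1)

# The target adds a leading dimension and stretches the last one.
print(onnx_expand(x, [2, 1, 6]).shape)  # torch.Size([2, 3, 6])

# A target with fewer dimensions keeps the input shape under ONNX rules...
print(onnx_expand(x, [1]).shape)  # torch.Size([3, 1])

# ...but Tensor.expand refuses it.
try:
    x.expand(1)
except RuntimeError as err:
    print('Tensor.expand failed:', err)
```

When both approaches accept a shape, they agree, so the multiplication only changes behaviour in the cases where Tensor.expand would raise.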
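The stacked @add_converter decorators in the Relu example register one converter function under several opset versions. The self-contained sketch below illustrates that decorator pattern with a hypothetical registry keyed by (operation type, version); register_converter and find_converter are illustrative names, not onnx2torch's API, and the real registry may resolve versions differently.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry for illustration only.
_REGISTRY: Dict[Tuple[str, int], Callable] = {}


def register_converter(operation_type: str, version: int) -> Callable[[Callable], Callable]:
    """Return a decorator that registers a converter under (operation_type, version)."""

    def decorator(converter: Callable) -> Callable:
        key = (operation_type, version)
        if key in _REGISTRY:
            raise ValueError(f'Converter for {operation_type} (version {version}) is already registered')
        _REGISTRY[key] = converter
        # Return the function unchanged so that several decorators can be stacked.
        return converter

    return decorator


def find_converter(operation_type: str, version: int) -> Callable:
    """Look up the converter registered for an exact (operation_type, version) pair."""
    try:
        return _REGISTRY[(operation_type, version)]
    except KeyError:
        raise NotImplementedError(f'No converter for {operation_type} (version {version})') from None


# Stacked decorators register the same function for three opset versions.
@register_converter(operation_type='Relu', version=6)
@register_converter(operation_type='Relu', version=13)
@register_converter(operation_type='Relu', version=14)
def convert_relu(node_name: str) -> str:
    return f'ReLU module for {node_name}'
```

Because every key is an exact version, a behaviour change in a later opset simply becomes another registration that points at a different function.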

Comments
  • IndexError because pytorch does not support int32 for indexing like onnx

    I've been trying a bunch of things to solve this, but the error message isn't very helpful. It sounds like the input dtype is simply handled incorrectly.

      File "/home/richard/miniconda3/envs/3.9.13/lib/python3.9/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
        raise e.with_traceback(None)
    IndexError: tensors used as indices must be long, byte or bool tensors
    

    Here is the onnx file I'm trying to convert https://drive.google.com/file/d/1FX_D6dcYEoVssr-y29F5RLbhmhf4RmyZ/view?usp=sharing

    It is supposed to take a tensor of type long. Here's an example in json: https://drive.google.com/file/d/1yTSxwOY10g0cULEVt3KQZWSzziDpD8bC/view?usp=sharing

    This is how I try to run it:

        model = torch.load(torch_path).eval().requires_grad_(False)
    
        tokens = torch.tensor(
            json.loads(Path(f"tests/{model_name}/tokens.json").read_text())
        ).long()
        text_encodings = torch.from_numpy(np.array(model(tokens)))
    

    The onnx file works with onnxruntime and gives the correct result.

        tokens = np.array(
            json.loads(Path(f"tests/{model_name}/tokens.json").read_text())
        ).astype(np.int64)
    
        providers = ["CPUExecutionProvider"]
        m = rt.InferenceSession(onnx_path, providers=providers)
        text_encoding = m.run(output_names, dict(inputs=tokens))
    

    Do you have any tips for what could be wrong?

    opened by samedii 12
  • Key error for Reshape when trying to convert ONNX model

    Do you think this error is occurring because the weights' name is missing in the ONNX file, or because the operation isn't found?

    I'm getting the right results if I load and run the model with ONNX runtime at least.

      File "/home/richard/miniconda3/envs/3.9.13/lib/python3.9/site-packages/onnx2torch/converter.py", line 109, in convert
        torch_module, onnx_mapping = converter(onnx_node, onnx_graph)
      File "/home/richard/miniconda3/envs/3.9.13/lib/python3.9/site-packages/onnx2torch/node_converters/gemm.py", line 65, in _
        weights = graph.initializers[weights_value_name]
    KeyError: 'Reshape__2177:0'
    

    Here is a link to the ONNX file if you want to try it yourself https://drive.google.com/file/d/1P_Bl7n2hbUoOhfMh_9UJfgK_N8kP_xjB/view?usp=sharing

    I'm trying to make this model https://paperswithcode.com/paper/lit-zero-shot-transfer-with-locked-image-text easier to use for people

    bug 
    opened by samedii 5
  • release on conda-forge or similar?

    Hi there,

    Are there plans to release the package for conda installation through conda-forge or another channel? I would love to be able to conda install this package so it will track in my conda environments.

    (My personal machine is on Mac OS, and I also work on a Linux server, so if I could put in requests for build OSs it would be those.)

    thank you for your consideration!

    feat 
    opened by monicathieu 5
  • feat: Add LRN opset 13 support

    Hi y'all,

    After installing the package (through conda-forge! 🙏🏼 ), I went to convert my ONNX model, which is a modified AlexNet model (same layer architecture, just retrained with different image classes and new weights in the last layer). At this point, I realized that local response normalization was not yet implemented in onnx2torch! I looked at the relevant pytorch docs and ONNX docs, and it looks like both of them implement the formula directly from the original AlexNet paper, just with slightly different argument names/order. Accordingly, I thought I might be able to handle adding a new converter for LRN.

    I do want to note that I haven't added any tests for the converter, because that was beyond my Python/pytorch ability (this is all part of me learning!), but when I use my local version of the package to import my modified AlexNet model, it seeeeems not to have been mangled.

    If you'd like me to work on adding the tests, I can, but it would take me some time to figure out the syntax. I would be thrilled if it ended up being very fast for you guys to add the tests yourselves.

    feat 
    opened by monicathieu 4
  • A lot of Ops with their implementations.

    Hi, developer team of onnx2torch. I am currently developing a neural network quantization framework: https://github.com/openppl-public/ppq/tree/master/ppq. The really interesting part is that we both need to run ONNX models with PyTorch : ) I am glad to share our operator implementations with you: https://github.com/openppl-public/ppq/blob/master/ppq/executor/op/torch/default.py

    We currently support the following ONNX operators (still a work in progress):

    1. 'Abs': Abs_forward,
    2. 'AdaptiveAvgPool2d': AdaptiveAvgPool2d_forward,
    3. 'And':And_forward,
    4. 'Add': Add_forward,
    5. 'ArgMax': ArgMax_forward,
    6. 'AveragePool': AveragePool_forward,
    7. 'BatchNormalization': BatchNormalization_forward,
    8. 'Cast': Cast_forward,
    9. 'Clip': Clip_forward,
    10. 'Concat': Concat_forward,
    11. 'Constant': Constant_forward,
    12. 'ConstantOfShape': ConstantOfShape_forward,
    13. 'Conv': Conv_forward,
    14. 'ConvTranspose': ConvTranspose_forward,
    15. 'Cos': Cos_forward,
    16. 'Div': Eltwise_forward,
    17. 'Equal': Equal_forward,
    18. 'Exp': UnaryEltwise_forward,
    19. 'Expand': Expand_forward,
    20. 'Flatten': Flatten_forward,
    21. 'Gather': Gather_forward,
    22. 'GatherElements': Gather_forward,
    23. 'GatherND': GatherND_forward,
    24. 'Gelu': Gelu_forward,
    25. 'Gemm': Gemm_forward,
    26. 'grid_sampler': Grid_sampler_forward,
    27. 'GlobalAveragePool': AveragePool_forward,
    28. 'GlobalMaxPool': MaxPool2d_forward,
    29. 'Greater': Greater_forward,
    30. 'LayerNorm': LayerNorm_forward,
    31. 'LeakyRelu': LeakyRelu_forward,
    32. 'Less': Less_forward,
    33. 'LogSoftmax': LogSoftmax_forward,
    34. 'MatMul': MatMul_forward,
    35. 'Max': Eltwise_forward,
    36. 'MaxPool': MaxPool2d_forward,
    37. 'Min': Eltwise_forward,
    38. 'Mul': Mul_forward,
    39. 'MultiHeadAttention': MultiHeadAttention_forward,
    40. 'NonMaxSuppression': _NMS_forward,
    41. 'NonZero': NonZero_forward,
    42. 'Not': Not_forward,
    43. 'Pad': Pad_forward,
    44. 'PRelu': PRelu_forward,
    45. 'Range': Range_forward,
    46. 'ReduceL2': ReduceL2_forward,
    47. 'ReduceMax': ReduceMax_forward,
    48. 'ReduceMean': ReduceMean_forward,
    49. 'ReduceSum': ReduceSum_forward,
    50. 'Relu': UnaryEltwise_forward,
    51. 'Reshape': Reshape_forward,
    52. 'Resize': Resize_forward,
    53. 'ScatterElements': ScatterElements_forward,
    54. 'ScatterND': ScatterND_forward,
    55. 'Shape': Shape_forward,
    56. 'Sigmoid': UnaryEltwise_forward,
    57. 'Sin': Sin_forward,
    58. 'Slice': Slice_forward,
    59. 'Softmax': Softmax_forward,
    60. 'Softplus': Softplus_forward,
    61. 'Split': Split_forward,
    62. 'Squeeze': Squeeze_forward,
    63. 'Sub': Eltwise_forward,
    64. 'Tile': Tile_forward,
    65. 'TopK': TopK_forward,
    66. 'Transpose': Transpose_forward,
    67. 'Unsqueeze': Unsqueeze_forward,
    68. 'Where': Where_forward,
    69. 'Sqrt': Sqrt_forward,
    70. 'Log': Log_forward,
    71. 'Floor': Floor_forward,
    72. 'RoiAlign': RoiAlign_forward,
    73. 'MMCVRoiAlign': MMCVRoiAlign_forward,
    74. 'SpaceToDepth': SpaceToDepth_forward,
    75. 'DepthToSpace': DepthToSpace_forward,
    76. 'Scale': Scale_forward, # caffe op
    77. 'Tanh': Tanh_forward,
    78. 'Pow': Pow_forward,
    79. 'Crop': Crop_forward, # caffe op
    80. 'ChannelShuffle': ChannelShuffle_forward, # caffe op
    81. 'InstanceNormalization': InstanceNormalization_forward,
    82. 'Parameter': Parameter_forward, # caffe op
    83. 'Interp': Interp_forward, # caffe op
    84. 'CaffeArgMax': CaffeArgMax_forward, # caffe op
    85. 'HardSigmoid': HardSigmoid_forward,
    86. 'HardSwish': HardSwish_forward,
    87. 'Neg': Neg_forward,
    88. 'GRU': GRU_forward,
    89. 'PPQDeviceSwitch': PPQDeviceSwitch_forward,
    90. 'Identity': Identity_forward,
    91. 'OneHot': Onehot_forward,
    92. 'Reciprocal': Reciprocal_forward,
    93. 'LSTM': LSTM_forward,
    Stale 
    opened by ZhangZhiPku 4
  • Does not support the convert of xception?

    I successfully converted the Xception ONNX model file to a .pt file, but the predictions are inconsistent with those of the original ONNX file.

    Is there an unsupported operation? No error message was raised during conversion, though.

    bug 
    opened by koon-kai 3
  • GlobalAveragePool: make it compatible with PyTorch FX QAT

    range() is not compatible with torch.fx symbolic tracing yet [1]. I created a slightly different module that avoids range() in forward() when possible (i.e., inferred shapes exist).

    [1] https://github.com/pytorch/pytorch/issues/44851

    Here is an example of the current issue:

    import onnx
    import onnx.checker
    import onnx.helper
    import onnx2torch
    import torch.quantization.quantize_fx
    
    # https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.0-12.onnx
    model = onnx.load_model('squeezenet1.0-12.onnx')
    net = onnx2torch.convert(model)
    
    net.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
    qconfig_dict = {'': net.qconfig}
    torch.quantization.quantize_fx.prepare_qat_fx(net, qconfig_dict)
    

    torch.fx fails even after adding "torch.fx.wrap('len')" to onnx2torch/node_converters/global_average_pool.py:

    Traceback (most recent call last):
      File "/home/yen/var/local/Computer/machine-learning/onnx2torch/t.py", line 20, in <module>
        main()
      File "/home/yen/var/local/Computer/machine-learning/onnx2torch/t.py", line 17, in main
        quantize_fx.prepare_qat_fx(net, qconfig_dict)
      File "/usr/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 564, in prepare_qat_fx
        return _prepare_fx(
      File "/usr/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 237, in _prepare_fx
        graph_module = GraphModule(model, tracer.trace(model))
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 566, in trace
        self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
      File "<eval_with_key>.0", line 71, in forward
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 556, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
      File "/usr/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 158, in call_module
        return super().call_module(m, forward, args, kwargs)
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 372, in call_module
        return forward(*args, **kwargs)
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 552, in forward
        return _orig_module_call(mod, *args, **kwargs)
      File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yen/var/local/Computer/machine-learning/onnx2torch/onnx2torch/node_converters/global_average_pool.py", line 19, in forward
        x_dims = list(range(2, len(input_tensor.shape)))
    TypeError: 'Proxy' object cannot be interpreted as an integer
    
    fix 
    opened by yan12125 3
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks whether all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions, you may contact us through this project's lead researcher, Kasimir Schulz.

    Stale 
    opened by TrellixVulnTeam 2
  • Gemm Bug

    https://github.com/ENOT-AutoDL/onnx2torch/blob/2791442c88af2bd243d0092963838ce940f8f1b0/onnx2torch/node_converters/gemm.py#L67 is wrong; it should be: if bias is None or bias.dim() == 1:

    Stale 
    opened by JiangYongYu1 2
  • feat: add LRN operator support (attempt 2)

    This PR is for the same additions as #96, but I had to make a new one. I renamed the "main" branch on my fork so that the pre-commit no-commit-to-branch hook would pass, but in doing so, it deleted the branch associated with the old PR and auto-closed the PR. So, here we are. Phew! I think that's it though?

    opened by monicathieu 2
  • [Code Quality] Remove trailing whitespace and add isort for module import

    Use pre-commit to auto remove all trailing whitespace.

    Usage

    Use pip (or brew on macOS) to install pre-commit first

    pip install pre-commit
    

    Install the git hook scripts to this project

    pre-commit install
    

    Run against all the files

    $ pre-commit run --all-files
    

    Run for some files

    $ pre-commit run --files a.py b.py xxx.py
    enhancement 
    opened by triple-Mu 2
  • KeyError: 'Tensor "onnx::Clip_2052" is not found in constant values'

    Trying to convert a LayoutMl model to a PyTorch model, I get this error:

    KeyError                                  Traceback (most recent call last)
    /usr/local/lib/python3.8/dist-packages/onnx2torch/node_converters/clip.py in _(node, graph)
         56     try:
    ---> 57         min_val = float(get_const_value(min_name, graph)) if min_name is not None else None
         58         max_val = float(get_const_value(max_name, graph)) if max_name is not None else None

    KeyError: 'Tensor "onnx::Clip_2052" is not found in constant values'

    The above exception was the direct cause of the following exception:

    NotImplementedError                       Traceback (most recent call last)
    /usr/local/lib/python3.8/dist-packages/onnx2torch/node_converters/clip.py in _(node, graph)
         58         max_val = float(get_const_value(max_name, graph)) if max_name is not None else None
         59     except KeyError as exc:
    ---> 60         raise NotImplementedError('Dynamic value of min/max is not implemented') from exc
         61
         62     torch_module = _create_torch_module(min_val=min_val, max_val=max_val)

    NotImplementedError: Dynamic value of min/max is not implemented

    opened by NeoKun004 0
  • save converted yolov5

    I exported the module to a folder using to_folder and got a "NameError: name 'onnx2torch_converter_lambda_' is not defined" error when running inference after importing the module. Any tips on what's wrong?

    opened by clh-b 2
  • Is There any possible way to have OnnxGemm converted into torch.nn.Linear?

    Hi, it is kind of you to share your work. I really appreciate it.

    I have noticed that the Gemm layer of the ONNX graph is converted into an OnnxGemm object.

    For example: [image]

    Then we check the PyTorch model [image] and its forward method [image].

    I wonder whether it is possible to convert these Gemm nodes into a native PyTorch module instead, which would be much more scalable.

    Thank you!!

    opened by feanor21115 0
  • GatherND missing

    While mentioned in release v1.2.0, GatherND is not implemented. See https://github.com/ENOT-AutoDL/onnx2torch/pull/25. Are there any plans to still add the operator?

    opened by jonas-doevenspeck 4
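The IndexError comment above comes from indexing with an int32 tensor: ONNX allows int32 or int64 indices, while the PyTorch version in that traceback only accepted long, byte or bool index tensors. A minimal sketch of one workaround, assuming the indices reach a plain tensor-indexing step, is to widen them to int64 first. gather_rows is a hypothetical helper that covers only ONNX Gather with axis=0; it is not onnx2torch's actual fix.

```python
import torch


def gather_rows(table: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # ONNX Gather with axis=0: output[i...] = table[indices[i...]].
    # Casting to int64 keeps this working on PyTorch versions that reject int32 indices.
    return table[indices.to(torch.int64)]


table = torch.arange(12.0).reshape(4, 3)
idx = torch.tensor([[0, 2], [3, 1]], dtype=torch.int32)
print(gather_rows(table, idx).shape)  # torch.Size([2, 2, 3])
```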
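Three of the comments above (the Reshape KeyError raised from gemm.py, the Gemm bias bug, and the request to map OnnxGemm to torch.nn.Linear) revolve around the ONNX Gemm definition Y = alpha * A' @ B' + beta * C, where A' and B' are A and B optionally transposed by the transA and transB attributes. The traceback in the Reshape comment suggests the converter looked for B among the graph initializers while B was actually the output of another node; folding into nn.Linear is only possible when B is a constant. The sketch below uses hypothetical function names: gemm evaluates the formula directly (usable for dynamic B), and gemm_to_linear folds a constant B and an optional scalar or 1-D C into an nn.Linear, assuming transA == 0.

```python
from typing import Optional

import torch
from torch import nn


def gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_a: int = 0,
    trans_b: int = 0,
) -> torch.Tensor:
    # ONNX Gemm: Y = alpha * A' @ B' + beta * C, where A'/B' are optionally transposed.
    a = a.t() if trans_a else a
    b = b.t() if trans_b else b
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * c  # C is broadcast to the (M, N) output
    return y


def gemm_to_linear(
    b: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    trans_b: int = 0,
) -> nn.Linear:
    # Hypothetical folding helper: valid only when transA == 0 and B / C are constants.
    # nn.Linear computes x @ W.T + bias, with W of shape (out_features, in_features) = (N, K).
    weight = b if trans_b else b.t()
    out_features, in_features = weight.shape
    linear = nn.Linear(in_features, out_features, bias=c is not None)
    with torch.no_grad():
        linear.weight.copy_(alpha * weight)
        if c is not None:
            if c.dim() > 1:
                raise ValueError('only a scalar or 1-D C can be folded into an nn.Linear bias')
            linear.bias.copy_(beta * c)  # a scalar C is broadcast to all N outputs
    return linear


torch.manual_seed(0)
a = torch.randn(4, 5)
b = torch.randn(5, 3)
c = torch.randn(3)
linear = gemm_to_linear(b, c, alpha=0.5, beta=2.0)
print(torch.allclose(linear(a), gemm(a, b, c, alpha=0.5, beta=2.0), atol=1e-6))
```

The 1-D check mirrors the condition proposed in the Gemm bug comment: a bias that is missing or one-dimensional is the case nn.Linear can represent directly.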
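The GlobalAveragePool comment shows range(2, len(input_tensor.shape)) failing under torch.fx symbolic tracing, because the shape is a Proxy while tracing. The fix it describes avoids range() in forward() when the input rank is known ahead of time. A sketch of that idea follows; the class name StaticGlobalAveragePool and its input_rank argument are hypothetical, not the module onnx2torch ships.

```python
import torch
from torch import fx, nn


class StaticGlobalAveragePool(nn.Module):
    def __init__(self, input_rank: int):
        super().__init__()
        # Computed once from a plain int, so forward() never calls range()/len() on a traced shape.
        self.dims = tuple(range(2, input_rank))

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # Average over all spatial dimensions, keeping them as size-1 axes: (N, C, 1, ..., 1).
        return torch.mean(input_tensor, dim=self.dims, keepdim=True)


traced = fx.symbolic_trace(StaticGlobalAveragePool(input_rank=4))
x = torch.randn(2, 3, 8, 8)
print(traced(x).shape)  # torch.Size([2, 3, 1, 1])
```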
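The CVE-2007-4559 pull request describes checking that every tarfile member will be extracted safely before calling extractall(). A minimal sketch of such a check follows; safe_extractall is a hypothetical helper, not the code from that pull request. Its own check vets member names only, not symlink or hardlink targets, so on Python versions that ship the PEP 706 extraction filters it also passes filter='data', which rejects those cases as well.

```python
import os
import tarfile


def safe_extractall(archive: tarfile.TarFile, destination: str) -> None:
    # Refuse to extract if any member name would resolve outside `destination`
    # (for example '../../etc/passwd' or an absolute path).
    dest_root = os.path.realpath(destination)
    for member in archive.getmembers():
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise ValueError(f'Blocked path traversal in tar member: {member.name!r}')
    if hasattr(tarfile, 'data_filter'):
        # PEP 706 filter: also rejects links pointing outside the destination.
        archive.extractall(destination, filter='data')
    else:
        archive.extractall(destination)
```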
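The Clip KeyError comment ends in "NotImplementedError: Dynamic value of min/max is not implemented": the Clip converter expects min and max to be constants it can read at conversion time. Since opset 11, ONNX Clip takes min and max as optional inputs rather than attributes, so they can also be computed at runtime. Assuming they arrive as tensors, a module could accept them as extra forward() arguments; DynamicClip below is a hypothetical sketch, not part of onnx2torch.

```python
from typing import Optional

import torch
from torch import nn


class DynamicClip(nn.Module):
    # Hypothetical module: min/max are runtime tensors instead of constants baked in at conversion.
    def forward(
        self,
        input_tensor: torch.Tensor,
        min_val: Optional[torch.Tensor] = None,
        max_val: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        output = input_tensor
        if min_val is not None:
            # .to(output) matches the dtype and device of the input.
            output = torch.maximum(output, min_val.to(output))
        if max_val is not None:
            output = torch.minimum(output, max_val.to(output))
        return output


clip = DynamicClip()
x = torch.tensor([-2.0, 0.5, 3.0])
print(clip(x, torch.tensor(0.0), torch.tensor(1.0)))  # tensor([0.0000, 0.5000, 1.0000])
```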
Releases(v1.5.4)
  • v1.5.4(Nov 14, 2022)

  • v1.5.3(Sep 16, 2022)

  • v1.5.2(Sep 9, 2022)

  • v1.5.1(Sep 1, 2022)

  • v1.5.0(Aug 30, 2022)

    You've been waiting, you've asked... And here it is: a new release (1.5.0)!

    In this release:

    • Fixed shape inference of big models (now you can load your lovely big language model, like gptj6B);
    • Fixed the BatchNorm converter according to the ONNX spec;
    • Added asymmetric padding for both Conv and MaxPool.

    Thank you for your discussions and issues! We are always ready to help you!

  • v1.4.1(Jul 12, 2022)

  • v1.4.0(Jul 6, 2022)

    :rocket: New day - new release! :rocket: New features in release 1.4.0:

    • Added operations:

      • Einsum,
      • Reciprocal,
      • Neg,
      • Prelu,
      • Mean,
      • Min,
      • Max,
      • CumSum,
      • Sum.
    • Fixes:

      • GlobalAveragePool compatible with PyTorch FX QAT,
      • Compatibility with torch 1.12,
      • Allow ignoring the batch and channel dimensions in the size input of Resize.
      • Check roi only for tf_crop_and_resize mode.
    • Code style and workflows:

    We thank all contributors for their work! Together we will make onnx2torch great! :star: 10^6

  • v1.3.0(Apr 20, 2022)

    New features in release 1.3.0:

    • Add Pad operation
    • Add Hardswish, Celu, Elu, Selu, Softplus, Softsign functions
    • Add GatherElements operation
    • Add new model tests

    Fixes:

    • Fix ConvTranspose operation
  • v1.2.5(Apr 7, 2022)

  • v1.2.0(Mar 23, 2022)

    New features in release 1.2.0:

    • Support VIT, SWIN, FasterRcnn models;
    • Add RoiAlign operation;
    • Add Floor, Ceil, Round and trigonometric functions;
    • New model tests for classification, detection and segmentation;

    Fixes:

    • Fix ScatterND and its speed;
    • Fix names of placeholders;
    • Fix optional arguments if last arguments are not passed;
    • Fix checking of empty inputs;
    • Fix avg pool operation;
  • v1.1.0(Jan 18, 2022)

    New features in release 1.1.0:

    • Support for optional input arguments;
    • Fixed dynamic axes for Squeeze and Unsqueeze.

    Also added the new operations listed below:

    • Resize;
    • Reduce operations (ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceMax, ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare);
    • Logical operations (Or, Xor, And, Not);
    • HardSigmoid, LeakyRelu;
    • Pow, Sqrt.

  • v1.0.0(Dec 14, 2021)
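The v1.5.0 notes mention asymmetric padding for Conv and MaxPool. ONNX stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...] over the spatial axes, while torch.nn.functional.pad expects (begin, end) pairs starting from the last dimension. A minimal sketch of that reordering follows, padding explicitly and then convolving with padding=0; the helper name onnx_pads_to_torch is hypothetical, and onnx2torch's own implementation may differ.

```python
from typing import List, Sequence

import torch
import torch.nn.functional as F


def onnx_pads_to_torch(pads: Sequence[int]) -> List[int]:
    # ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    # F.pad: (last_begin, last_end, second_last_begin, second_last_end, ...)
    half = len(pads) // 2
    begins, ends = list(pads[:half]), list(pads[half:])
    torch_pads: List[int] = []
    for begin, end in zip(reversed(begins), reversed(ends)):
        torch_pads.extend([begin, end])
    return torch_pads


x = torch.randn(1, 1, 5, 5)
weight = torch.randn(1, 1, 3, 3)
onnx_pads = [1, 2, 3, 4]  # top, left, bottom, right for an NCHW input
padded = F.pad(x, onnx_pads_to_torch(onnx_pads))  # W padded by (2, 4), H by (1, 3)
y = F.conv2d(padded, weight, padding=0)
print(padded.shape, y.shape)  # torch.Size([1, 1, 9, 11]) torch.Size([1, 1, 7, 9])
```

For MaxPool the padded cells should not win the max, so the constant should be -inf rather than the default 0, for example F.pad(x, pads, value=float('-inf')).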

Owner
ENOT