The Simplest DCGAN Implementation

Overview

DCGAN in TensorLayer

This is the TensorLayer implementation of Deep Convolutional Generative Adversarial Networks. Looking for Text to Image Synthesis? Click here.

  • 🆕 🔥 2019 May: We have just updated this project to support TF2 and TL2. Enjoy!
  • 🆕 🔥 2019 May: This project has been chosen as the default template for TL projects.

Prerequisites

  • Python 3.5 or 3.6
  • TensorFlow 2.0.0a0: pip3 install tensorflow-gpu==2.0.0a0
  • TensorLayer 2.1.0: pip3 install tensorlayer==2.1.0

Usage

First, download the aligned face images from Google or Baidu into a data folder.

Second, train the GAN:

$ python train.py

Result on celebA

Comments
  • ValueError: Trying to share variable discriminator/d/h4/lin_sigmoid/W

    Hi, I am using TensorFlow 1.10 (CPU version), TensorLayer 1.10.1, Python 3.6 and Windows 10.

    After running python main.py, I get a ValueError:

    Traceback (most recent call last):
      File "c:/Users/USER/Desktop/פרוייקט תואר שני/playgroung/dcgan/main.py", line 166, in <module>
        tf.app.run()
      File "C:\Python\3.6\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
        _sys.exit(main(argv))
      File "c:/Users/USER/Desktop/פרוייקט תואר שני/playgroung/dcgan/main.py", line 59, in main
        _, d2_logits = discriminator(real_images, is_train=True, reuse=True)
      File "c:\Users\USER\Desktop\פרוייקט תואר שני\playgroung\dcgan\model.py", line 81, in discriminator
        W_init = w_init, name='d/h4/lin_sigmoid')
      File "c:\Users\USER\Desktop\פרוייקט תואר שני\playgroung\dcgan\tensorlayer\layers\core.py", line 926, in __init__
        W = tf.get_variable(name='W', shape=(n_in, n_units), initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args)
      File "C:\Python\3.6\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1467, in get_variable
        aggregation=aggregation)
      File "C:\Python\3.6\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1217, in get_variable
        aggregation=aggregation)
      File "C:\Python\3.6\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 527, in get_variable
        aggregation=aggregation)
      File "C:\Python\3.6\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 481, in _true_getter
        aggregation=aggregation)
      File "C:\Python\3.6\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 853, in _get_single_variable
        found_var.get_shape()))
    ValueError: Trying to share variable discriminator/d/h4/lin_sigmoid/W, but specified shape (8192, 1) and found shape (2048, 1).

    opened by motkeg 2
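
The two shapes in the error message are consistent with the discriminator being built on inputs of two different spatial sizes. As a rough check, the sketch below computes the flattened feature count under the assumption that the discriminator applies four stride-2, SAME-padded convolutions and ends with 512 channels before the final dense layer. That is an assumption about model.py, which is not shown in the report, and the helper name flattened_features is hypothetical.

```python
import math


def flattened_features(image_size, num_stride2_convs=4, final_channels=512):
    """Flattened feature count after a stack of stride-2, SAME-padded convs.

    Assumes square inputs. The layer count and channel width are assumptions
    about the discriminator, not values taken from model.py.
    """
    side = image_size
    for _ in range(num_stride2_convs):
        # SAME padding with stride 2 yields ceil(side / 2) along each axis.
        side = math.ceil(side / 2)
    return side * side * final_channels


for size in (64, 32):
    print(size, flattened_features(size))
```

Under these assumptions a 64x64 input gives 8192 features and a 32x32 input gives 2048, matching the two shapes in the message, so it may be worth checking that the generator output and the real images have the same spatial size.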
  • Error during download

    Tried twice, same error.

    python download.py celebA
    ./data/img_align_celeba.zip: 25.5kB [12:15, 34.7B/s]Traceback (most recent call last):
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/urllib3/response.py", line 302, in _error_catcher
        yield
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/urllib3/response.py", line 601, in read_chunked
        chunk = self._handle_chunk(amt)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/urllib3/response.py", line 557, in _handle_chunk
        value = self._fp._safe_read(amt)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/http/client.py", line 614, in _safe_read
        raise IncompleteRead(b''.join(s), amt)
    http.client.IncompleteRead: IncompleteRead(5887 bytes read, 26881 more expected)
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/requests/models.py", line 745, in generate
        for chunk in self.raw.stream(chunk_size, decode_content=True):
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/urllib3/response.py", line 432, in stream
        for line in self.read_chunked(amt, decode_content=decode_content):
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/urllib3/response.py", line 626, in read_chunked
        self._original_response.close()
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/contextlib.py", line 99, in __exit__
        self.gen.throw(type, value, traceback)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/urllib3/response.py", line 320, in _error_catcher
        raise ProtocolError('Connection broken: %r' % e, e)
    urllib3.exceptions.ProtocolError: ('Connection broken: IncompleteRead(5887 bytes read, 26881 more expected)', IncompleteRead(5887 bytes read, 26881 more expected))
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "download.py", line 164, in <module>
        download_celeb_a('./data')
      File "download.py", line 91, in download_celeb_a
        download_file_from_google_drive(drive_id, save_path)
      File "download.py", line 56, in download_file_from_google_drive
        save_response_content(response, destination)
      File "download.py", line 68, in save_response_content
        unit='B', unit_scale=True, desc=destination):
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/tqdm/_tqdm.py", line 962, in __iter__
        for obj in iterable:
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/requests/models.py", line 748, in generate
        raise ChunkedEncodingError(e)
    requests.exceptions.ChunkedEncodingError: ('Connection broken: IncompleteRead(5887 bytes read, 26881 more expected)', IncompleteRead(5887 bytes read, 26881 more expected))
    
    opened by arisliang 2
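
A broken connection of this kind is sometimes transient, so one option is to retry the download a few times with an increasing delay. The sketch below is a generic retry helper, not code from download.py; the name retry and all of its parameters are hypothetical.

```python
import time


def retry(operation, attempts=3, base_delay=1.0, exceptions=(Exception,),
          sleep=time.sleep):
    """Call operation() until it succeeds or the attempts run out.

    Waits base_delay, 2 * base_delay, 4 * base_delay, ... between tries and
    re-raises the last error once all attempts have failed.
    """
    for attempt in range(attempts):
        try:
            return operation()
        except exceptions:
            if attempt == attempts - 1:
                raise  # Out of attempts: surface the last error.
            sleep(base_delay * (2 ** attempt))
```

In download.py this could wrap the call that appears in the traceback, for example retry(lambda: download_celeb_a('./data')). Each retry restarts the download from the beginning, so with a throughput like the 34.7 B/s shown above a manual download from another mirror may still be the more practical route.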
  • Traceback error - main.py

    When I try to run main.py, I get the following error:

    Traceback (most recent call last):
      File "main.py", line 23, in <module>
        flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
      File "/anaconda2/lib/python2.7/site-packages/tensorflow/python/platform/flags.py", line 58, in wrapper
        return original_function(*args, **kwargs)
      File "/anaconda2/lib/python2.7/site-packages/absl/flags/_defines.py", line 315, in DEFINE_integer
        DEFINE(parser, name, default, help, flag_values, serializer, **args)
      File "/anaconda2/lib/python2.7/site-packages/absl/flags/_defines.py", line 81, in DEFINE
        DEFINE_flag(_flag.Flag(parser, serializer, name, default, help, **args),
      File "/anaconda2/lib/python2.7/site-packages/absl/flags/_flag.py", line 107, in __init__
        self._set_default(default)
      File "/anaconda2/lib/python2.7/site-packages/absl/flags/_flag.py", line 196, in _set_default
        self.default = self._parse(value)
      File "/anaconda2/lib/python2.7/site-packages/absl/flags/_flag.py", line 169, in _parse
        'flag --%s=%s: %s' % (self.name, argument, e))
    absl.flags._exceptions.IllegalFlagValueError: flag --train_size=inf: Expect argument to be a string or int, found <type 'float'>
    

    Has anyone else encountered this?

    opened by 1140lex 2
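
The error message says that DEFINE_integer expects its default to be a string or an int, while np.inf is a float. One possible workaround is to use None to mean "no limit" and resolve it once the number of image files is known. The sketch below uses the standard-library argparse rather than the project's flags module, and the function names are hypothetical.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # None stands for "use every image"; an integer caps the training set.
    parser.add_argument("--train_size", type=int, default=None,
                        help="Maximum number of training images [all]")
    return parser


def effective_train_size(num_files, train_size):
    """Return how many files to use given an optional integer cap."""
    if train_size is None:
        return num_files
    return min(num_files, train_size)


# Example: 1000 files on disk, capped at 500 from the command line.
args = build_parser().parse_args(["--train_size", "500"])
print(effective_train_size(1000, args.train_size))
```

With the flag omitted, train_size stays None and every file is used, which is the behaviour the np.inf default was presumably meant to express.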
  • Error running train.py

    When running train.py, I get the following error:

    Traceback (most recent call last):
      File "train.py", line 60, in <module>
        train()
      File "train.py", line 44, in train
        grad = tape.gradient(g_loss, G.trainable_weights)
    AttributeError: 'Model' object has no attribute 'trainable_weights'
    

    Am I running the wrong version of TensorFlow/TensorLayer (2.0.0a0/2.0.0)?

    opened by ghost 1
  • The URL for CelebA does not work now

    The URL for the CelebA dataset in download.py no longer works because of a Dropbox problem, so it would be better to change it to a Google Drive or Baidu Drive URL.

    opened by fangpin 1
  • Minor fixes and improvements

    • Added a mini-batch generator function.
    • Renamed generator and discriminator functions.
    • Removed logits from generator as it was redundant.
    • Changed w_init in both models to tf.glorot_normal_initializer since it has been shown to improve model convergence.
    • Used imageio instead of scipy.misc wherever possible, since the latter is deprecated.
    opened by pskrunner14 0
  • TL compatibility and code readability

    • Made code compatible with TL1.10.1 and TF1.10.0.
    • Removed out size and batch_size from DeConv2D layers since they're no longer needed.
    • Replaced tl.act.lrelu with tf.nn.leaky_relu since the former is deprecated.
    • Replaced tl.layers.initialize_global_variables with tf.global_variables_initializer since the former is deprecated.
    • Changed flag train_size dtype from int to float due to an error.
    • Improved code readability and removed redundant comments and code.
    • Removed pprint in favour of custom logging of the flags, since pprint wasn't able to access flag values.
    • Fixed wildcard imports and removed unused ones.
    • Updated main module doc.
    • Removed tensorlayer package from root dir since it is installed on system.

    Note: Fixes #9 and #12.

    opened by pskrunner14 0
  • --sample_size option causes ValueError

    When I pass --sample_size 256 as a flag to main.py, I get

    ValueError: Cannot feed value of shape (256, 100) for Tensor 'z_noise:0', which has shape '(64, 100)'

    When I pass --sample_size 256 and --batch_size 256, the code went into an endless loop just saying "Sample images updated!" without running each epoch.

    opened by jkraybill 0
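
The ValueError indicates that the z_noise tensor has a fixed first dimension of 64, so it cannot accept 256 rows at once. One way around this, assuming the graph itself is left unchanged, is to feed the sample noise in batch-sized chunks and concatenate the generated images afterwards. The plain-Python sketch below shows only the chunking step; chunk_rows is a hypothetical helper, not part of main.py.

```python
def chunk_rows(rows, batch_size):
    """Split rows into consecutive chunks of exactly batch_size rows."""
    if len(rows) % batch_size != 0:
        raise ValueError("sample size must be a multiple of batch size")
    return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]


# 256 noise vectors of length 100, to be fed 64 at a time.
sample_z = [[0.0] * 100 for _ in range(256)]
chunks = chunk_rows(sample_z, 64)
print(len(chunks), len(chunks[0]), len(chunks[0][0]))
```

Each chunk then has the (64, 100) shape that the tensor expects. This does not address the endless "Sample images updated!" loop, whose cause is not clear from the report.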
  • Performance issue in /data.py (by P3)

    Hello! I've found a performance issue in /data.py: ds.batch(batch_size) should be called before ds.map(_map_fn, num_parallel_calls=4), which could make your program more efficient.

    The TensorFlow documentation supports this.

    You also need to check whether the function _map_fn passed to ds.map(_map_fn, num_parallel_calls=4) is affected, so that the changed code still works properly. For example, if _map_fn takes input of shape (x, y, z) before the fix, it would need to accept input of shape (batch_size, x, y, z) after the fix.

    Looking forward to your reply. By the way, I would be glad to create a PR to fix it if you are too busy.

    opened by DLPerf 1
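
The shape caveat in this report can be illustrated without TensorFlow. The plain-Python sketch below is not tf.data code, and all names in it are hypothetical. It contrasts mapping a per-element function before batching with mapping a batch-aware function after batching: both orders produce the same batches, but the function used after batching must accept a whole batch.

```python
def batch(items, batch_size):
    """Group items into lists of batch_size (the last one may be shorter)."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def scale_one(pixels):
    """Per-element map function: scales one flat image to [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]


def scale_batch(images):
    """Batch-aware map function: takes a list of images, not one image."""
    return [scale_one(pixels) for pixels in images]


images = [[0, 255], [255, 0], [51, 204]]

# Order used before the proposed change: map each element, then batch.
map_then_batch = batch([scale_one(img) for img in images], 2)
# Proposed order: batch first, then map a function that handles a batch.
batch_then_map = [scale_batch(b) for b in batch(images, 2)]

assert map_then_batch == batch_then_map
```

In tf.data the expected gain comes from invoking the mapped function once per batch instead of once per element, so it only pays off if _map_fn can be written with operations that work on a whole batch.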
  • How to train on existing checkpoint

    I would like to continue training from an existing saved checkpoint, but it seems the GAN learns from scratch.

    Is there some way to load an existing checkpoint?

    Thanks in advance.

    opened by Fqlox 0
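
Another report on this page shows the weights being written with D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz'). A minimal sketch of resuming, assuming the models expose a matching load_weights(path, format='npz') method (worth checking against the installed TensorLayer version), is to load those files before the training loop when they exist. The helper name maybe_resume is hypothetical.

```python
import os


def maybe_resume(model, path):
    """Load weights into model if a checkpoint file exists at path.

    Assumes model has load_weights(path, format='npz'), mirroring the
    save_weights(..., format='npz') call used during training.
    Returns True if weights were loaded, False otherwise.
    """
    if not os.path.isfile(path):
        return False
    model.load_weights(path, format='npz')
    return True


# Hypothetical usage before the training loop:
# maybe_resume(G, os.path.join(flags.checkpoint_dir, 'G.npz'))
# maybe_resume(D, os.path.join(flags.checkpoint_dir, 'D.npz'))
```

This restores only the network weights. Optimizer state and the epoch counter would start fresh unless they are saved and restored separately.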
  • local variable 'z' referenced before assignment

    I get this error. Can you help me, please?

    WARNING: Logging before flag parsing goes to stderr.
    I0718 14:45:17.243686 140030963623808 tl_logging.py:99] [!] checkpoint exists ...
    I0718 14:45:17.247771 140030963623808 tl_logging.py:99] [!] samples exists ...
    W0718 14:45:19.318463 140030963623808 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py:505: py_func (from tensorflow.python.ops.script_ops) is deprecated and will be removed in a future version.
    Instructions for updating:
    tf.py_func is deprecated in TF V2. Instead, there are two options available in V2.
        - tf.py_function takes a python function which manipulates tf eager tensors instead of numpy arrays. It's easy to convert a tf eager tensor to an ndarray (just call tensor.numpy()) but having access to eager tensors means tf.py_functions can use accelerators such as GPUs as well as being differentiable using a gradient tape.
        - tf.numpy_function maintains the semantics of the deprecated tf.py_func (it is not differentiable, and manipulates numpy arrays). It drops the stateful argument making all functions stateful.

    I0718 14:45:19.464441 140030963623808 tl_logging.py:99] Input _inputlayer_1: [None, 100]
    I0718 14:45:19.974161 140030963623808 tl_logging.py:99] Dense dense_1: 8192 No Activation
    I0718 14:45:20.649731 140030963623808 tl_logging.py:99] Reshape reshape_1
    I0718 14:45:20.707117 140030963623808 tl_logging.py:99] BatchNorm batchnorm_1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
    I0718 14:45:20.782280 140030963623808 tl_logging.py:99] DeConv2d deconv2d_1: n_filters: 256 strides: (2, 2) padding: SAME act: No Activation dilation: (1, 1)
    I0718 14:45:22.796400 140030963623808 tl_logging.py:99] BatchNorm batchnorm2d_1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
    I0718 14:45:22.873876 140030963623808 tl_logging.py:99] DeConv2d deconv2d_2: n_filters: 128 strides: (2, 2) padding: SAME act: No Activation dilation: (1, 1)
    I0718 14:45:23.076492 140030963623808 tl_logging.py:99] BatchNorm batchnorm2d_2: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
    I0718 14:45:23.153287 140030963623808 tl_logging.py:99] DeConv2d deconv2d_3: n_filters: 64 strides: (2, 2) padding: SAME act: No Activation dilation: (1, 1)
    I0718 14:45:23.227763 140030963623808 tl_logging.py:99] BatchNorm batchnorm2d_3: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
    I0718 14:45:23.307783 140030963623808 tl_logging.py:99] DeConv2d deconv2d_4: n_filters: 3 strides: (2, 2) padding: SAME act: tanh dilation: (1, 1)
    I0718 14:45:23.390201 140030963623808 tl_logging.py:99] Input _inputlayer_2: [None, 64, 64, 3]
    I0718 14:45:23.456669 140030963623808 tl_logging.py:99] Conv2d conv2d_1: n_filter: 64 filter_size: (5, 5) strides: (2, 2) pad: SAME act:
    I0718 14:45:23.530629 140030963623808 tl_logging.py:99] Conv2d conv2d_2: n_filter: 128 filter_size: (5, 5) strides: (2, 2) pad: SAME act: No Activation
    I0718 14:45:23.608061 140030963623808 tl_logging.py:99] BatchNorm batchnorm2d_4: decay: 0.900000 epsilon: 0.000010 act: is_train: False
    I0718 14:45:23.691769 140030963623808 tl_logging.py:99] Conv2d conv2d_3: n_filter: 256 filter_size: (5, 5) strides: (2, 2) pad: SAME act: No Activation
    I0718 14:45:23.775271 140030963623808 tl_logging.py:99] BatchNorm batchnorm2d_5: decay: 0.900000 epsilon: 0.000010 act: is_train: False
    I0718 14:45:23.942049 140030963623808 tl_logging.py:99] Conv2d conv2d_4: n_filter: 512 filter_size: (5, 5) strides: (2, 2) pad: SAME act: No Activation
    I0718 14:45:24.041154 140030963623808 tl_logging.py:99] BatchNorm batchnorm2d_6: decay: 0.900000 epsilon: 0.000010 act: is_train: False
    I0718 14:45:24.114703 140030963623808 tl_logging.py:99] Flatten flatten_1:
    I0718 14:45:24.177696 140030963623808 tl_logging.py:99] Dense dense_2: 1 identity
    I0718 14:45:29.721007 140030963623808 tl_logging.py:99] [*] Saving TL weights into checkpoint/G.npz
    I0718 14:45:31.268130 140030963623808 tl_logging.py:99] [*] Saved
    I0718 14:45:31.274845 140030963623808 tl_logging.py:99] [*] Saving TL weights into checkpoint/D.npz
    I0718 14:45:32.590787 140030963623808 tl_logging.py:99] [*] Saved

    UnboundLocalError                         Traceback (most recent call last)
    in ()
         58
         59 if __name__ == '__main__':
    ---> 60     train()

    in train()
         53 D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
         54 G.eval()
    ---> 55 result = G(z)
         56 G.train()
         57 tl.visualize.save_images(result.numpy(), [num_tiles, num_tiles], '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))

    UnboundLocalError: local variable 'z' referenced before assignment

    opened by emnab 4
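
The traceback shows G(z) being called right after the weights are saved, with z never assigned. One plausible cause, assuming z is created only inside the per-batch loop, is that the dataset yielded no batches, for example because no images were found in the data folder. The stand-alone sketch below reproduces that pattern and shows one possible fix, which is to create a fixed noise batch before the loop. All names are hypothetical stand-ins for the training script.

```python
import random


def make_noise(batch_size, z_dim):
    """Draw a batch of Gaussian noise vectors as nested lists."""
    return [[random.gauss(0.0, 1.0) for _ in range(z_dim)]
            for _ in range(batch_size)]


def train_epoch_buggy(batches):
    for _ in batches:
        z = make_noise(4, 8)  # z exists only if the loop body runs
    return z  # UnboundLocalError when batches is empty


def train_epoch_fixed(batches, sample_z):
    steps = 0
    for _ in batches:
        z = make_noise(4, 8)  # per-step noise for the generator update
        steps += 1
    # Sampling uses noise created outside the loop, so it is always defined.
    return sample_z, steps


sample_z = make_noise(4, 8)
try:
    train_epoch_buggy([])
except UnboundLocalError as err:
    print("buggy:", err)
print("fixed steps:", train_epoch_fixed([], sample_z)[1])
```

If the loop really is empty, the underlying problem is the missing data, so checking that the data folder contains the downloaded images is a sensible first step.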
  • TypeError when running main.py

    I tried both the copy of TensorLayer bundled in the dcgan folder (1.4.5) and the installed version (1.8.1). The TensorFlow version is 1.6.0.

    $ python main.py 
    /home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
      from ._conv import register_converters as _register_converters
    Traceback (most recent call last):
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_flag.py", line 166, in _parse
        return self.parser.parse(argument)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_argument_parser.py", line 152, in parse
        val = self.convert(argument)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_argument_parser.py", line 268, in convert
        type(argument)))
    TypeError: Expect argument to be a string or int, found <class 'float'>
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "main.py", line 23, in <module>
        flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/tensorflow/python/platform/flags.py", line 58, in wrapper
        return original_function(*args, **kwargs)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_defines.py", line 315, in DEFINE_integer
        DEFINE(parser, name, default, help, flag_values, serializer, **args)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_defines.py", line 81, in DEFINE
        DEFINE_flag(_flag.Flag(parser, serializer, name, default, help, **args),
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_flag.py", line 107, in __init__
        self._set_default(default)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_flag.py", line 196, in _set_default
        self.default = self._parse(value)
      File "/home/ly/anaconda3/envs/learning/lib/python3.6/site-packages/absl/flags/_flag.py", line 169, in _parse
        'flag --%s=%s: %s' % (self.name, argument, e))
    absl.flags._exceptions.IllegalFlagValueError: flag --train_size=inf: Expect argument to be a string or int, found <class 'float'>
    
    
    opened by arisliang 5
Releases(0.5.0)
Owner
TensorLayer Community
A neutral open community to promote AI technology.