Keras Implementation of Sketch-RNN
by Eyal Zakkay
In this repo there’s a Keras implementation of the Sketch-RNN algorithm,
as described in the paper A Neural Representation of Sketch Drawings by David Ha and Douglas Eck (Google AI).
The implementation is ported from the official Tensorflow implementation that was released under project Magenta by the authors.
Overview
Sketch-RNN consists of a Sequence to Sequence Variational Autoencoder (Seq2SeqVAE), which is able to encode a series of pen strokes (a sketch) into a latent space, using a bidirectional LSTM as the encoder. The latent representation can then be decoded back into a series of strokes.
The model is trained to reconstruct the original stroke sequences while constraining the latent space elements to follow a normal distribution. Since the encoding is performed stochastically, as is the decoder’s sampling mechanism, the reconstructed sketches are always different.
This allows a trained model to draw new, unique sketches that it has not seen before. Designing the model as a variational autoencoder also makes it possible to perform latent space manipulations and obtain interesting interpolations between different sketches.
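The stochastic encoding mentioned above is the standard VAE reparameterization trick: the encoder predicts a mean and standard deviation per latent dimension, and the latent vector is drawn as z = μ + σ·ε with ε ~ N(0, I). A minimal NumPy sketch of the idea (the latent size here is illustrative, not necessarily the repo’s):

```python
import numpy as np

def sample_latent(mu, sigma, rng):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps

rng = np.random.default_rng(0)
mu = np.zeros(128)     # encoder's predicted mean (128 is an illustrative latent size)
sigma = np.ones(128)   # encoder's predicted standard deviation
z1 = sample_latent(mu, sigma, rng)
z2 = sample_latent(mu, sigma, rng)
# Two encodings of the same sketch differ, which is why reconstructions differ too.
```

Because every encoding draws a fresh ε, decoding the same sketch twice yields two different latent vectors, and hence two different reconstructions.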
There’s no need to elaborate on the specifics of the algorithm here, since many great resources already exist for that purpose.
I recommend David Ha’s blog post Teaching Machines to Draw.
Implementation Details
You can find in this repo some useful solutions for common pitfalls when porting from TF to Keras (and writing Keras in general), for example:
- Injecting values to intermediate tensors and predicting the corresponding values of other tensors by building sub-models
- Using custom generators to wrap data loader classes
- Using an auxiliary loss term that uses intermediate layers’ outputs rather than the model’s predictions
- Using a CuDNN LSTM layer, while allowing inference on CPU
- Resuming a training process from a checkpoint in the case that custom callbacks are used with dynamic internal variables
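As an illustration of the generator-wrapping point above, the pattern looks roughly like this: a plain Python generator pulls batches from a loader object indefinitely, which Keras’s `fit_generator` can then consume. The `DataLoader` class and its methods below are hypothetical stand-ins, not the repo’s actual API:

```python
import numpy as np

class DataLoader:
    """Hypothetical stand-in for a dataset loader class."""
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def random_batch(self):
        idx = np.random.randint(0, len(self.data), self.batch_size)
        batch = self.data[idx]
        return batch, batch  # (inputs, targets): an autoencoder reconstructs its input

def batch_generator(loader):
    """Infinite generator suitable for Keras's fit_generator."""
    while True:
        yield loader.random_batch()

loader = DataLoader(np.random.rand(100, 250, 5), batch_size=32)
gen = batch_generator(loader)
x, y = next(gen)  # one (inputs, targets) batch per call
```

The infinite `while True` loop is what lets Keras pull as many batches per epoch as `steps_per_epoch` dictates.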
Dependencies
Tested in the following environment:
- Keras 2.2.4 (Tensorflow 1.11 backend)
- Python 3.5
- Windows OS
Hopefully I will soon update this with minimum requirements.
Usage
Training
To train a model, you need a dataset in the appropriate format. You can download one of the many prepared sketch datasets released by Google. Simply download one or more .npz
files and save them in the same directory (it is recommended to use a datasets
directory within the project’s main directory).
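For reference, the prepared .npz datasets store each sketch in the paper’s stroke-3 format: each row is (Δx, Δy, pen_lifted), where the third value is 1 when the pen is lifted at the end of a stroke. A toy example (the array below is made up purely for illustration):

```python
import numpy as np

# A tiny made-up sketch in stroke-3 format: rows of (dx, dy, pen_lifted).
sketch = np.array([
    [ 10,   0, 0],   # pen moves right, staying down
    [  0,  10, 0],   # pen moves up, staying down
    [-10, -10, 1],   # pen returns to the origin and lifts: end of stroke
], dtype=np.int16)

# Absolute pen positions are the cumulative sum of the offsets.
positions = np.cumsum(sketch[:, :2], axis=0)
```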
Example usage:
python seq2seqVAE_train.py --data_dir=datasets --data_set=cat --experiment_dir=\sketch_rnn\experiments
Currently, configurable hyperparameters can only be modified by changing their default values in seq2seqVAE.py.
I might add an option to configure them via the command line in the future.
You can also resume training from a saved checkpoint by supplying the --checkpoint=[file path]
and --initial_epoch=[epoch from which you are restarting]
arguments.
The full list of configurable parameters:
Using a trained model to draw
In the notebook Skecth_RNN_Keras.ipynb
you can supply a path to a trained model and a dataset and explore what the model has learned.
There are examples of encoding and decoding sketches, interpolating in latent space, sampling under different temperature
values, etc.
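Temperature sampling scales the model’s output distribution before drawing from it: low temperatures make the decoder nearly greedy, while high temperatures push it toward uniform randomness. A minimal NumPy version of the idea, independent of the repo’s actual sampling code:

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
    """Sample an index from softmax(logits / temperature)."""
    scaled = logits / temperature
    scaled -= scaled.max()  # subtract max for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return rng.choice(len(probs), p=probs)

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5])
# Near-zero temperature: the softmax collapses onto the argmax (index 0).
idx_cold = sample_with_temperature(logits, 0.01, rng)
# Higher temperatures flatten the distribution, so any index may be drawn.
idx_hot = sample_with_temperature(logits, 5.0, rng)
```

In Sketch-RNN this kind of scaling is applied when sampling strokes from the decoder’s output distribution, which is why raising the temperature produces wilder, less predictable sketches.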
You can also load models trained on multiple datasets and generate nifty interpolations, such as these guitar-cats!
References
- The guitar animation was created using this tutorial