A Keras Implementation of Sketch-RNN
In this repo there’s a Keras implementation of the Sketch-RNN algorithm,
as described in the paper A Neural Representation of Sketch Drawings by David Ha and Douglas Eck (Google AI).
Sketch-RNN consists of a Sequence to Sequence Variational Autoencoder (Seq2SeqVAE), which is able to encode a series of pen strokes (a sketch) into a latent space, using a bidirectional LSTM as the encoder. The latent representation can then be decoded back into a series of strokes.
The model is trained to reconstruct the original stroke sequences while maintaining a normal distribution across latent space elements. Since encoding is performed stochastically, and so is the sampling mechanism of the decoder, the reconstructed sketches are always different.
This lets a trained model draw new and unique sketches that it has not seen before. Designing the model as a variational autoencoder also makes it possible to manipulate the latent space and obtain interesting interpolations between different sketches.
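As an illustration of such latent space manipulation, interpolating between two encoded sketches is often done with spherical interpolation (slerp) rather than a straight line. The sketch below is a minimal, self-contained example of the idea; `z_cat` and `z_guitar` are placeholder random vectors standing in for real encoder outputs:

```python
import numpy as np

def slerp(z0, z1, t):
    """Spherical interpolation between latent vectors z0 and z1, t in [0, 1]."""
    z0, z1 = np.asarray(z0, dtype=float), np.asarray(z1, dtype=float)
    omega = np.arccos(np.clip(
        np.dot(z0 / np.linalg.norm(z0), z1 / np.linalg.norm(z1)), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        # Vectors are nearly parallel: fall back to linear interpolation
        return (1.0 - t) * z0 + t * z1
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)

# Interpolate between two (hypothetical) encoded sketches in 10 steps,
# then decode each intermediate vector to get a morphing sequence
z_cat = np.random.randn(128)
z_guitar = np.random.randn(128)
path = [slerp(z_cat, z_guitar, t) for t in np.linspace(0.0, 1.0, 10)]
```

Each vector in `path` can then be fed to the decoder to render one frame of the interpolation.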
There’s no need to elaborate on the specifics of the algorithm here, since many great resources already exist for that purpose.
I recommend David Ha’s blog post Teaching Machines to Draw.
In this repo you can find useful solutions for common pitfalls when porting from TensorFlow to Keras (and when writing Keras in general), for example:
- Injecting values to intermediate tensors and predicting the corresponding values of other tensors by building sub-models
- Using custom generators to wrap data loader classes
- Using an auxiliary loss term that uses intermediate layers’ outputs rather than the model’s predictions
- Using a CuDNN LSTM layer, while allowing inference on CPU
- Resuming a training process from a checkpoint when custom callbacks with dynamic internal state are used
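The first point above, building sub-models to read or inject intermediate tensors, can be sketched as follows. This is a hypothetical toy graph, not the repo’s actual Seq2SeqVAE architecture; layer names and shapes are placeholders:

```python
import numpy as np
from tensorflow import keras

# Toy stand-in for an encoder -> latent -> decoder graph
inputs = keras.Input(shape=(16,))
latent = keras.layers.Dense(4, name="latent")(inputs)
outputs = keras.layers.Dense(16, name="reconstruction")(latent)
model = keras.Model(inputs, outputs)

# Sub-model 1: predict the value of the intermediate "latent" tensor
encoder = keras.Model(inputs, model.get_layer("latent").output)

# Sub-model 2: inject values directly into the latent space and decode them,
# reusing the (shared) weights of the trained decoder layer
latent_in = keras.Input(shape=(4,))
decoder = keras.Model(latent_in, model.get_layer("reconstruction")(latent_in))

z = encoder.predict(np.zeros((1, 16)), verbose=0)
x = decoder.predict(z, verbose=0)
```

Because the sub-models reuse the original layers, training the full model updates them too; no weight copying is needed.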
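The custom-generator pattern from the list can be sketched like this. The `DataLoader` class and the stroke tensor shapes below are simplified placeholders, not the repo’s actual loader:

```python
import numpy as np

class DataLoader:
    """Stand-in for a data loader holding padded stroke sequences."""
    def __init__(self, strokes, batch_size):
        self.strokes = strokes
        self.batch_size = batch_size

    def random_batch(self):
        idx = np.random.choice(len(self.strokes), self.batch_size)
        return self.strokes[idx]

def train_generator(loader):
    """Infinite generator suitable for Keras' fit_generator: yields (inputs, targets)."""
    while True:
        batch = loader.random_batch()
        # A VAE reconstructs its input, so the inputs double as targets
        yield batch, batch

# 100 sketches, each padded to 250 steps of (dx, dy, p1, p2, p3)
loader = DataLoader(np.random.randn(100, 250, 5).astype("float32"), batch_size=8)
gen = train_generator(loader)
x, y = next(gen)
```

Wrapping the loader this way keeps batching and augmentation logic out of the training script while still feeding Keras an ordinary generator.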
Tested in the following environment:
- Keras 2.2.4 (Tensorflow 1.11 backend)
- Python 3.5
- Windows OS
I hope to update this soon with minimum requirements.
To train a model, you need a dataset in the appropriate format. You can download one of the many prepared sketch datasets released by Google. Simply download one or more `.npz` files and save them in the same directory (a `datasets` directory within the project’s main directory is recommended).
```
python seq2seqVAE_train.py --data_dir=datasets --data_set=cat --experiment_dir=\sketch_rnn\experiments
```
Currently, configurable hyperparameters can only be modified by changing their default values in the code.
I might add an option to configure them via command line in the future.
You can also resume training from a saved checkpoint by supplying the `--checkpoint=[file path]` and `--initial_epoch=[epoch from which you are restarting]` arguments.
The full list of configurable parameters:
Using a trained model to draw
In the notebook Skecth_RNN_Keras.ipynb you can supply a path to a trained model and a dataset and explore what the model has learned.
There are examples of encoding and decoding sketches, interpolating in latent space, sampling under different temperature values, etc.
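Temperature controls how adventurous the decoder’s sampling is: the categorical distribution over mixture components (and pen states) is sharpened at low temperature and flattened at high temperature before sampling. A minimal sketch of the idea (not the repo’s exact implementation):

```python
import numpy as np

def adjust_temperature(pi, temperature):
    """Rescale a categorical distribution pi by a temperature factor."""
    logits = np.log(np.asarray(pi, dtype=float)) / temperature
    logits -= logits.max()          # subtract max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()

pi = np.array([0.7, 0.2, 0.1])
cold = adjust_temperature(pi, 0.1)    # low temperature: nearly deterministic
hot = adjust_temperature(pi, 1000.0)  # high temperature: nearly uniform
```

At temperatures near zero the most likely component is almost always chosen, producing clean but repetitive sketches; higher temperatures yield more varied (and messier) drawings.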
You can also load models trained on multiple datasets and generate nifty interpolations such as these guitar-cats!
- The guitar animation was created using this tutorial