How to train your GAN
Strategies for training generative models
If you’ve heard about algorithms that generate photorealistic images from scratch — human faces wearing fake expressions, daytime photos that look like night, cat sketches that are turned into realistic cat images — you’ve heard of GANs.
Generative Adversarial Networks, which I described a couple posts ago while trying to generate real-looking faces from random vectors of numbers, have huge potential. But there’s a catch, of course. They’re fiendishly hard to train.
In this post, I’ll write a bit about my experience training them, trying out the latest GAN-training fad diets. I’ll also link to better resources to help you train your GAN.
Why are GANs so awful?
Depending on your GAN’s hyperparameters, your GAN might train flawlessly – more likely, though, is that you’ll fall victim to one of GAN’s well-documented failure modes, and you won’t learn anything.
You might get stuck with mode collapse. Your generator’s job is to produce outputs that fool the discriminator. It has a little conversation in its head, that goes like this:
Woah! This photo is easy to draw but the discriminator is totally fooled by it. Sucker! My work here is done! I’ll just output this image, every single time, regardless of the random vector I’ve been sent.
⏰ Five minutes later 🕟
Fuck! Looks like the discriminator caught on about this image being fake! Gotta make the image a little different and keep pumping it out.
Because the generator doesn’t output diverse samples, the discriminator doesn’t learn much that’s useful, except to flag the specific image (“mode”) that the generator’s currently outputting as fake. I don’t have a solid mathematical intuition for what exactly triggers the generator to start producing a bunch of very similar images, and I’d like to know.
There are plenty of supposed solutions to mode collapse, and I’ll talk about them in a bit.
Another common failure mode is vanishing gradients. This one’s pretty simple – the discriminator gets really good, faster than the generator, and is able to mark the generator’s output as fake with super-high probability. If your discriminator is marking fakes as 99.9999% fake, there’s only a tiny sliver of “realness.” The generator is trained to increase the “realness” of its samples according to the discriminator, so if all we have is a tiny sliver of realness, the gradients are tiny and the training signal is weak. We end up in a sad, sad world where the poor generator doesn’t have the feedback (in the form of gradients) to learn much of anything. You’ll see tiny discriminator loss and huge generator loss.
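To put a number on that “tiny sliver,” here’s a minimal sketch. It assumes the generator minimizes the classic log(1 - D(G(z))) loss and that the discriminator ends in a sigmoid; the function names are made up for illustration. Under those assumptions, the gradient of the loss with respect to the discriminator’s logit is just minus the “realness” the discriminator assigns to the fake.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logit_for_realness(p_real):
    """Logit at which a sigmoid discriminator outputs probability p_real."""
    return math.log(p_real / (1.0 - p_real))

def saturating_generator_grad(logit):
    """d/d(logit) of log(1 - sigmoid(logit)), which simplifies to -sigmoid(logit)."""
    return -sigmoid(logit)

# The more confident the discriminator is that a sample is fake,
# the smaller the gradient that flows back to the generator.
for p_real in (0.5, 0.1, 1e-6):
    grad = saturating_generator_grad(logit_for_realness(p_real))
    print(f"realness={p_real:g}  gradient={grad:.2e}")
```

When the discriminator calls a sample 99.9999% fake, the gradient magnitude is about one millionth, which is why the generator stops learning.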
WGAN to the rescue
In early 2017, three researchers published a paper introducing the Wasserstein GAN, claiming it alleviated all mode-collapse issues and made training GANs significantly more stable.
Most of the paper is spent establishing a theoretical model of why GANs are difficult to train. In GANs, we attempt to train the generator to match the probability distribution of the real data. To accomplish this with gradient descent, we need to minimize some “distance metric” between the generator’s current output distribution and the real data distribution. The WGAN paper interprets the traditional GAN algorithm as minimizing Jensen-Shannon (JS) divergence, a close relative of KL divergence. When the real and generated distributions barely overlap, which is typical early in training, JS divergence stays nearly constant and gives the generator almost no useful gradient, which makes training really difficult.
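A toy example shows why the choice of distance matters. The sketch below is illustrative, not code from the paper: it compares Jensen-Shannon divergence (what the traditional GAN objective corresponds to when the discriminator is optimal) with the earth mover distance that WGAN uses, on two distributions that each put all their mass on a single point. The helper names and the dict-based representation are assumptions made for the example.

```python
import math

def kl(p, q):
    """KL divergence between discrete distributions given as {point: prob} dicts."""
    return sum(pi * math.log(pi / q[x]) for x, pi in p.items() if pi > 0)

def js_divergence(p, q):
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def earth_mover_1d(p, q):
    """1-D earth mover distance: area between the two cumulative distributions."""
    xs = sorted(set(p) | set(q))
    total, cdf_gap = 0.0, 0.0
    for left, right in zip(xs, xs[1:]):
        cdf_gap += p.get(left, 0.0) - q.get(left, 0.0)
        total += abs(cdf_gap) * (right - left)
    return total

real = {0.0: 1.0}                 # all real data sits at x = 0
for shift in (4.0, 2.0, 1.0, 0.5):
    fake = {shift: 1.0}           # generator output sits at x = shift
    print(shift, js_divergence(real, fake), earth_mover_1d(real, fake))
```

As long as the two distributions don’t overlap, JS divergence is stuck at log 2 no matter how close the generator gets, so it can’t tell the generator which way to move. The earth mover distance shrinks steadily as the fake point approaches the real one.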
They suggest a different metric, “earth mover distance,” which is continuous and differentiable almost everywhere, and offers good gradients. The incredible thing about the paper is that changing the GAN algorithm to be equivalent to minimizing earth-mover distance requires just three weird tricks:
Rather than outputting classification probabilities (using something like softmax cross-entropy real/fake), the discriminator should output unbounded numbers. Train this discriminator (they call it a critic instead) to return a high positive number for real inputs, and a large negative number for fake inputs.
Clip the critic’s weights after each training iteration. The paper’s authors suggest clipping weights to [-0.01, 0.01].
Rather than trying to balance generator/discriminator training, train the critic close to convergence before each generator update (the paper runs five critic steps per generator step). The critic should still give good gradients, even if it’s really strong.
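The three tricks above can be sketched on a deliberately tiny problem. This is a toy under loud assumptions, not the paper’s implementation: the critic is a single weight (f(x) = w * x), the generator just shifts noise by a learned offset theta, gradients are worked out by hand, and plain gradient steps stand in for the RMSProp optimizer the paper uses. The learning rates and the function name train_toy_wgan are made up for the example.

```python
import random

CLIP = 0.01      # weight-clipping bound suggested in the WGAN paper
N_CRITIC = 5     # critic updates per generator update (the paper's default)

def clip(value, bound=CLIP):
    return max(-bound, min(bound, value))

def mean(xs):
    return sum(xs) / len(xs)

def train_toy_wgan(steps=500, lr_critic=0.05, lr_gen=5.0, seed=0):
    rng = random.Random(seed)
    w = 0.0       # critic f(x) = w * x, an unbounded score rather than a probability
    theta = 0.0   # generator g(z) = z + theta
    for _ in range(steps):
        for _ in range(N_CRITIC):
            real = [rng.gauss(2.0, 0.1) for _ in range(32)]
            fake = [rng.gauss(0.0, 0.1) + theta for _ in range(32)]
            # Critic ascends mean f(real) - mean f(fake);
            # for f(x) = w * x the gradient w.r.t. w is mean(real) - mean(fake).
            w += lr_critic * (mean(real) - mean(fake))
            w = clip(w)  # trick 2: clip the critic's weights after every update
        # Generator ascends mean f(fake); the gradient w.r.t. theta is w.
        theta += lr_gen * w
    return w, theta

if __name__ == "__main__":
    w, theta = train_toy_wgan()
    print(f"critic weight={w:.4f}  generator offset={theta:.3f}")
```

The real data is centered at 2.0, so the generator offset should drift toward 2 and then hover around it. Because the clipped critic weight can never exceed 0.01, the generator’s gradient is small by construction, and the sketch needs a large generator learning rate to make progress.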
I played around with implementing a WGAN to generate synthetic samples based on MNIST digits and CelebA human faces.
Generating MNIST digits was a breeze — without any hyperparameter tuning, WGAN generated great samples: [📸 See the code]
CelebA didn’t work too well. The generator didn’t collapse and give terrible results — the sample quality clearly improved over time, but training was incredibly slow, far slower than a DCGAN. Here’s what I got before I stopped training:
Least-squares GAN
Another paper promising an improved GAN model is “Least Squares Generative Adversarial Networks.” This post by Agustinus Kristiadi gives a good overview of the paper, and makes some bold claims: that it’s as stable as WGAN but not nearly as slow, and also generates higher-quality samples.
The basic idea is even simpler than WGAN’s – rather than using softmax cross-entropy classification loss in the discriminator, use least-squares predictions instead. (i.e. train the discriminator to output 1 for real samples and -1 for fake samples, and train using L2 loss.) This makes sense because it forces the discriminator to output reasonably-sized numbers like -1 and 1, no matter how “good” or “confident” it is, which should give better gradients than a discriminator trained to convergence using softmax cross-entropy.
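Here’s a minimal sketch of those losses, using the 1 / -1 targets described above. Two things are assumptions on my part: the generator is trained to push the discriminator’s score on fakes toward the “real” target of 1, and the 0.5 factor is just a convention that makes the derivative tidy. The function names are hypothetical.

```python
def mean(xs):
    return sum(xs) / len(xs)

def lsgan_discriminator_loss(real_scores, fake_scores,
                             real_target=1.0, fake_target=-1.0):
    """L2 loss pulling scores on real samples to 1 and on fakes to -1."""
    real_term = mean([(s - real_target) ** 2 for s in real_scores])
    fake_term = mean([(s - fake_target) ** 2 for s in fake_scores])
    return 0.5 * (real_term + fake_term)

def lsgan_generator_loss(fake_scores, real_target=1.0):
    """Assumed generator objective: make the discriminator score fakes as real."""
    return 0.5 * mean([(s - real_target) ** 2 for s in fake_scores])

def generator_grad_wrt_score(score, real_target=1.0):
    """Derivative of 0.5 * (score - real_target)**2 with respect to the score."""
    return score - real_target

# Even when the discriminator is completely sure a sample is fake (score -1),
# the gradient with respect to the score is -2, not something close to zero.
print(generator_grad_wrt_score(-1.0))
```

The gradient grows linearly with how far the score is from the target, so a confident discriminator gives the generator a stronger signal instead of a vanishing one.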
I wasn’t able to achieve particularly miraculous results using LSGAN on CelebA — I needed a bit of hyperparameter optimization to get anywhere reasonable, and my generator still tended to “collapse” occasionally and stop outputting good images.
That being said, the recent (and super cool) paper “Unsupervised Image-to-Image Translation Networks” uses least-squares loss instead of the traditional GAN formulation, so it’s clearly useful.
Improved WGAN (with Gradient Penalty)
A recent paper finds theoretical issues with one of the tricks WGAN uses to work. WGAN uses ‘weight clipping’ to enforce a ‘Lipschitz constraint’ on the critic (I have no idea what this means). The paper suggests that weight clipping is a suboptimal way to enforce Lipschitz-ness, and ends up biasing the critic towards simpler models of the true distribution. Instead of clipping weights, they suggest augmenting the critic loss function with a penalty that encourages the norm of the critic’s gradient with respect to its input images to be close to 1.
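The penalty term can be sketched in a few lines. This is an illustrative stand-in, separate from any linked implementation: it estimates the critic’s input gradient with finite differences so it runs without a deep learning framework, whereas real code would use automatic differentiation. The penalty weight of 10 is the value the paper uses; the function names and the list-of-floats “images” are assumptions for the example.

```python
import math
import random

def numerical_gradient(f, x, h=1e-5):
    """Central-difference estimate of the gradient of f at point x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

def gradient_penalty(critic, real_batch, fake_batch, lam=10.0, rng=random):
    """lam * mean((||grad_x critic(x_mixed)|| - 1)^2) over random real/fake mixes."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.random()  # random point on the line between a real and a fake sample
        mixed = [eps * r + (1 - eps) * f for r, f in zip(real, fake)]
        grad = numerical_gradient(critic, mixed)
        norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((norm - 1.0) ** 2)
    return lam * mean_of(penalties)

def mean_of(xs):
    return sum(xs) / len(xs)

# A linear critic has the same gradient everywhere, which makes the result easy to check.
steep_critic = lambda x: 3.0 * x[0] + 4.0 * x[1]    # gradient norm 5
gentle_critic = lambda x: 0.6 * x[0] + 0.8 * x[1]   # gradient norm 1
real = [[1.0, 2.0], [0.5, -1.0]]
fake = [[0.0, 0.0], [2.0, 2.0]]
print(gradient_penalty(steep_critic, real, fake))   # roughly 10 * (5 - 1)^2 = 160
print(gradient_penalty(gentle_critic, real, fake))  # roughly 0
```

The penalty is added to the critic’s loss, so a critic whose gradient norm strays far from 1 pays for it, which takes the place of hard-clipping the weights.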
Implementing the gradient penalty isn’t that difficult. I was able to get an improved WGAN working on MNIST pretty quickly: 🎒 here’s the code for that.
Other tricks
There are lots and lots of people offering strategies for improving GAN training and stability. Many of these work for WGAN and LSGAN as well, but may not be as useful, given those algorithms’ stability promises.
Here are some I find interesting:
Store a ‘replay buffer’ of previous generator outputs. Occasionally, rather than training the discriminator on the latest generator outputs, train it on some old generator outputs — it should still know they’re fake!
Use a different optimizer in the discriminator, rather than Adam (which is usually my go-to). Apparently, momentum might cause instability.
“Principled” attempts to balance discriminator and generator strength. I’ve tried keeping track of discriminator accuracy and pausing discriminator training while its accuracy is above 80%, giving the generator a chance to “catch up.”
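Two of these tricks are easy to sketch: the replay buffer, and skipping discriminator updates while the discriminator is too accurate so the generator can catch up. The class and function names, the buffer capacity of 50, and the 50% replay probability are all assumptions for illustration; only the 80% threshold comes from my own experiments above.

```python
import random

class ReplayBuffer:
    """Keeps a fixed-size pool of past generator outputs."""

    def __init__(self, capacity=50, seed=None):
        self.capacity = capacity
        self.items = []
        self.rng = random.Random(seed)

    def push_and_sample(self, sample, replay_prob=0.5):
        """Store `sample`; return either it or an older output to train on."""
        if len(self.items) < self.capacity:
            self.items.append(sample)   # buffer still filling up: use the new sample
            return sample
        if self.rng.random() < replay_prob:
            idx = self.rng.randrange(self.capacity)
            old = self.items[idx]
            self.items[idx] = sample    # the new sample takes the old one's slot
            return old
        return sample

def should_train_discriminator(accuracy, threshold=0.8):
    """Skip discriminator updates while it is already beating the generator."""
    return accuracy <= threshold

buffer = ReplayBuffer(capacity=2, seed=0)
for output in ("fake_1", "fake_2", "fake_3", "fake_4"):
    print(buffer.push_and_sample(output))
```

In a training loop, the discriminator would be fed whatever push_and_sample returns in place of the freshest generator output, and its update step would be wrapped in a check on should_train_discriminator.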
There are plenty of GAN-training resources that provide more (and better motivated) tricks than these — here are some of my favorites:
GAN Hacks by Soumith Chintala
Improved GANs from OpenAI














