Adaptive Symmetrization of the KL Divergence

A non-adversarial approach to minimizing the symmetric Jeffreys divergence.

Illustration of different training dynamics, where the goal is to train $p_{\theta}$ and $q_{\psi}$ to match the data distribution $\pi$. Each point represents a probability distribution in a conceptual distribution space, and the arrows indicate the direction of gradient descent (GD). (a) In maximum likelihood estimation (MLE), GD minimizes the divergence $D_{\text{KL}}\left(\pi \parallel p_{\theta} \right)$. In practice, GD minimizes a finite-sample estimate of the KL divergence rather than the true KL divergence. (b) In generative adversarial networks (GANs), GD pulls the generator $p_{\theta}$ towards the discriminator $q_{\psi}$ and the discriminator towards $\pi$, but it also pushes $q_{\psi}$ away from $p_{\theta}$. This repelling dynamic makes GAN training unstable. (c) In our framework, GD pulls $p_{\theta}$ and $q_{\psi}$ towards each other, while also pulling both towards $\pi$. The mutual attraction stabilizes training.

Introduction

Many tasks in machine learning can be described as learning a distribution $\pi$ from a finite set of samples. The most widely used approach is to minimize a statistical divergence between $\pi$ and the model. A common choice is the forward KL divergence, since minimizing it is equivalent to minimizing the cross entropy, that is, to maximum likelihood estimation (MLE). However, the KL divergence is not symmetric, and minimizing the forward direction may lead the trained model to overestimate the weight of modes (a behavior also known as mode-covering). Minimizing the reverse KL instead tends towards mode-seeking behavior, which can counteract the mode-covering of the forward KL, but it requires access to the distribution we are trying to learn, which is impossible when only a finite set of samples is available.
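The asymmetry described above can be checked numerically on a small discrete example. The sketch below is illustrative only: the five-bin target `pi` and the two candidate models `broad` and `narrow` are made-up values, not data or models from this work.

```python
import numpy as np


def kl(p, q):
    """D_KL(p || q) in nats for discrete distributions on a shared support."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # terms with p(x) = 0 contribute nothing to the sum
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


# Hypothetical bimodal target: most of the mass sits in the two outer bins.
pi = np.array([0.45, 0.04, 0.02, 0.04, 0.45])

# Two hypothetical candidate models.
broad = np.array([0.20, 0.20, 0.20, 0.20, 0.20])   # spreads mass everywhere
narrow = np.array([0.90, 0.04, 0.02, 0.02, 0.02])  # commits to a single mode

forward = {"broad": kl(pi, broad), "narrow": kl(pi, narrow)}   # D_KL(pi || model)
reverse = {"broad": kl(broad, pi), "narrow": kl(narrow, pi)}   # D_KL(model || pi)

for name in ("broad", "narrow"):
    print(f"{name:>6}: forward KL = {forward[name]:.3f}, "
          f"reverse KL = {reverse[name]:.3f}")
```

With these particular values, the forward KL ranks the broad model as the better fit, while the reverse KL ranks the narrow, single-mode model as the better fit. Note that the reverse KL here is computable only because the toy example has `pi` in closed form, which is exactly what is unavailable when learning from samples.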

GANs take another approach and minimize a symmetric divergence through a variational representation, in which an additional model is introduced to express the divergence as an optimization problem. This approach results in a min-max objective that is unstable to train and sensitive to the choice of hyper-parameters.

In this work, we propose to introduce an additional model that serves as a proxy for the true distribution $\pi$ and to use it to approximate the reverse KL. By minimizing the sum of the forward and reverse KLs, also known as the Jeffreys divergence, we obtain both mode-covering and mode-seeking behaviors.
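For reference, the Jeffreys divergence between $\pi$ and a model $p_{\theta}$ can be written out explicitly. The symbol $D_{\text{J}}$ is notation chosen here for illustration:

$$
D_{\text{J}}\left(\pi, p_{\theta}\right) = D_{\text{KL}}\left(\pi \parallel p_{\theta}\right) + D_{\text{KL}}\left(p_{\theta} \parallel \pi\right)
$$

The second term is the one that cannot be estimated from samples of $\pi$ alone, because it requires evaluating $\pi(x)$ at points drawn from the model. One way to read the proxy idea, stated here as an assumption rather than as the exact objective of this work, is to substitute the proxy $q_{\psi}$ for $\pi$ inside that term:

$$
D_{\text{KL}}\left(p_{\theta} \parallel \pi\right) = \mathbb{E}_{x \sim p_{\theta}}\left[\log p_{\theta}(x) - \log \pi(x)\right] \approx \mathbb{E}_{x \sim p_{\theta}}\left[\log p_{\theta}(x) - \log q_{\psi}(x)\right]
$$

Under this reading, the approximation is accurate to the extent that $q_{\psi}$ is close to $\pi$ on the regions where $p_{\theta}$ places mass.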

Comparison of negative log-likelihood (NLL) with normalizing flows (NF) and Wasserstein GAN (WGAN). The left graph (a) shows that on a small dataset, NF learns the empirical distribution and over time diverges from the true distribution. Our method achieves a lower (better) and consistent NLL. The right graph (b) shows that increasing the learning rate prevents WGAN from training, while barely affecting our method.
