.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/sinkhorn_multiscale/plot_transport_blur.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_sinkhorn_multiscale_plot_transport_blur.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_sinkhorn_multiscale_plot_transport_blur.py:


4) Sinkhorn vs. blurred Wasserstein distances
==========================================================

Sinkhorn divergences rely on a simple idea: by **blurring** the transport
plan through the addition of an entropic penalty, we can reduce the
effective dimensionality of the transportation problem and compute
**sensible approximations of the Wasserstein distance at a low computational cost**.

.. GENERATED FROM PYTHON SOURCE LINES 13-69

As discussed in previous notebooks, the *vanilla* Sinkhorn loop can be
symmetrized, de-biased and turned into a genuine multiscale algorithm:
available through the :mod:`SamplesLoss("sinkhorn") <geomloss>` layer,
the **Sinkhorn divergence**

.. math::
    \text{S}_\varepsilon(\alpha,\beta)~=~
    \text{OT}_\varepsilon(\alpha,\beta)
    \,-\, \tfrac{1}{2}\text{OT}_\varepsilon(\alpha,\alpha)
    \,-\, \tfrac{1}{2}\text{OT}_\varepsilon(\beta,\beta)

is a tractable approximation of the Wasserstein distance that **retains its
key geometric properties**: positivity, convexity and metrization of the
convergence in law.

**But is it really the best way of smoothing our transportation problem?**
When :math:`p = 2` and :math:`\text{C}(x,y)=\tfrac{1}{2}\|x-y\|^2`,
a very sensible alternative to Sinkhorn divergences is the
**blurred Wasserstein distance**

.. math::
    \text{B}_\varepsilon(\alpha,\beta) ~=~
    \text{W}_2(\,k_{\varepsilon/4}\star\alpha,\,k_{\varepsilon/4}\star\beta\,),

where :math:`\text{W}_2` denotes the *true* Wasserstein distance associated
to our cost function :math:`\text{C}` and

.. math::
    k_{\varepsilon/4}:~x~\mapsto~ \exp(-\|x\|^2 \,/\, \tfrac{\varepsilon}{2})

is a Gaussian kernel of deviation :math:`\sigma = \sqrt{\varepsilon}/2`
(i.e. of variance :math:`\sigma^2 = \varepsilon/4`, so that the
denominator :math:`2\sigma^2` in the exponent equals :math:`\varepsilon/2`).

On top of making explicit our intuitions on **low-frequency Optimal
Transport**, this simple divergence enjoys a collection of desirable
properties:

- It is the **square of a distance** that metrizes the convergence in law.
- It takes the "correct" values on atomic **Dirac masses**, lifting the
  ground cost function to the space of positive measures:

  .. math::
      \text{B}_\varepsilon(\delta_x,\delta_y)~=~\text{C}(x,y)
      ~=~\tfrac{1}{2}\|x-y\|^2~=~\text{S}_\varepsilon(\delta_x,\delta_y).

- It has the same **asymptotic properties** as the Sinkhorn divergence,
  interpolating between the true Wasserstein distance
  (when :math:`\varepsilon \rightarrow 0`) and a degenerate kernel norm
  (when :math:`\varepsilon \rightarrow +\infty`).
- Thanks to the joint convexity of the Wasserstein distance,
  :math:`\text{B}_\varepsilon(\alpha,\beta)` is a **decreasing** function of
  :math:`\varepsilon`: as we remove small-scale details, we lower the
  overall transport cost.
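Before setting up the 1D experiment below, note that the three-term,
de-biased definition of :math:`\text{S}_\varepsilon` above can be checked
numerically. The minimal sketch below assumes that the ``debias`` flag of
:mod:`SamplesLoss <geomloss>` switches between the de-biased divergence
:math:`\text{S}_\varepsilon` (``debias=True``, the default) and the raw
entropic cost :math:`\text{OT}_\varepsilon` (``debias=False``); the two
sides should then agree up to the convergence tolerance of the Sinkhorn
loop:

.. code-block:: python

    import torch
    from geomloss import SamplesLoss

    x = torch.rand(500, 2).double()        # samples from a measure alpha
    y = torch.rand(500, 2).double() + 0.2  # samples from a measure beta

    # De-biased Sinkhorn divergence vs. raw entropic transport cost:
    S_eps = SamplesLoss("sinkhorn", p=2, blur=0.1, scaling=0.99, debias=True)
    OT_eps = SamplesLoss("sinkhorn", p=2, blur=0.1, scaling=0.99, debias=False)

    lhs = S_eps(x, y)
    rhs = OT_eps(x, y) - 0.5 * OT_eps(x, x) - 0.5 * OT_eps(y, y)
    print(lhs.item(), rhs.item())  # Two nearly identical values.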
To compare the Sinkhorn and blurred Wasserstein divergences, a simple
experiment is to **display their values on pairs of 1D measures** for
increasing values of the temperature :math:`\varepsilon`: having generated
random samples :math:`\alpha` and :math:`\beta` on the unit interval, we can
compute :math:`\text{S}_\varepsilon(\alpha,\beta)` with our
:mod:`SamplesLoss("sinkhorn") <geomloss>` layer, while the blurred
Wasserstein loss :math:`\text{B}_\varepsilon(\alpha,\beta)` can be quickly
approximated with the **addition of a Gaussian noise** followed by a
**sorting pass**.

.. GENERATED FROM PYTHON SOURCE LINES 71-74

Setup
---------------------

Standard imports:

.. GENERATED FROM PYTHON SOURCE LINES 74-86

.. code-block:: default


    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.neighbors import KernelDensity  # display as density curves

    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    # N.B.: We use float64 numbers to get nice limits when blur -> +infinity.
    dtype = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor


.. GENERATED FROM PYTHON SOURCE LINES 87-88

Display routine:

.. GENERATED FROM PYTHON SOURCE LINES 88-101

.. code-block:: default


    t_plot = np.linspace(-0.5, 1.5, 1000)[:, np.newaxis]


    def display_samples(ax, x, color, label=None):
        """Displays samples on the unit interval using a density curve."""
        kde = KernelDensity(kernel="gaussian", bandwidth=0.005).fit(x.data.cpu().numpy())
        dens = np.exp(kde.score_samples(t_plot))
        dens[0] = 0
        dens[-1] = 0
        ax.fill(t_plot, dens, color=color, label=label)


.. GENERATED FROM PYTHON SOURCE LINES 102-104

Experiment
-------------

.. GENERATED FROM PYTHON SOURCE LINES 104-177

.. code-block:: default


    def rweight():
        """Random weight."""
        return torch.rand(1).type(dtype)


    N = 100 if not use_cuda else 10 ** 3  # Number of samples per measure
    C = 100 if not use_cuda else 10000  # Number of copies for the Gaussian blur

    for _ in range(5):  # Repeat the experiment 5 times
        K = 5  # Generate random 1D measures as the superposition of K=5 intervals
        t = torch.linspace(0, 1, N // K).type(dtype).view(-1, 1)
        X_i = torch.cat([rweight() ** 2 * t + rweight() - 0.5 for k in range(K)], dim=0)
        Y_j = torch.cat([rweight() ** 2 * t + rweight() - 0.5 for k in range(K)], dim=0)

        # Compute the limits when blur = 0...
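        # N.B.: In 1D, the optimal transport plan between two uniformly weighted
        #       samples of the same size is the monotone (sorted) matching, so the
        #       true Wasserstein distance boils down to a simple sorting pass: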
        x_, _ = X_i.sort(dim=0)
        y_, _ = Y_j.sort(dim=0)
        true_wass = (0.5 / len(X_i)) * ((x_ - y_) ** 2).sum()
        true_wass = true_wass.item()

        # ... and when blur = +infinity:
        mean_diff = 0.5 * ((X_i.mean(0) - Y_j.mean(0)) ** 2).sum()
        mean_diff = mean_diff.item()

        blurs = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
        sink, bwass = [], []

        for blur in blurs:
            # Compute the Sinkhorn divergence.
            # N.B.: To be super-precise, we use the well-tested "online" backend
            #       with a very large 'scaling' coefficient:
            loss = SamplesLoss("sinkhorn", p=2, blur=blur, scaling=0.99, backend="online")
            sink.append(loss(X_i, Y_j).item())

            # Compute the blurred Wasserstein distance:
            x_i = torch.cat([X_i] * C, dim=0)
            y_j = torch.cat([Y_j] * C, dim=0)
            x_i = x_i + 0.5 * blur * torch.randn(x_i.shape).type(dtype)
            y_j = y_j + 0.5 * blur * torch.randn(y_j.shape).type(dtype)

            x_, _ = x_i.sort(dim=0)
            y_, _ = y_j.sort(dim=0)
            wass = (0.5 / len(x_i)) * ((x_ - y_) ** 2).sum()
            bwass.append(wass.item())

        # Fancy display:
        plt.figure(figsize=(12, 5))

        if N < 10 ** 5:
            ax = plt.subplot(1, 2, 1)
            display_samples(ax, X_i, (1.0, 0, 0, 0.5), label="$\\alpha$")
            display_samples(ax, Y_j, (0, 0, 1.0, 0.5), label="$\\beta$")
            plt.axis([-0.5, 1.5, -0.1, 5.5])
            plt.ylabel("density")
            ax.legend()
            plt.tight_layout()

        ax = plt.subplot(1, 2, 2)
        plt.plot([0.01, 10], [true_wass, true_wass], "g", label="True Wasserstein")
        plt.plot(blurs, sink, "r-o", label="Sinkhorn divergence")
        plt.plot(blurs, bwass, "b-o", label="Blurred Wasserstein")
        plt.plot(
            [0.01, 10], [mean_diff, mean_diff], "m", label="Squared difference of means"
        )
        ax.set_xscale("log")
        ax.legend()
        plt.axis([0.01, 10.0, 0.0, 1.5 * bwass[0]])
        plt.xlabel("blur $\\sqrt{\\varepsilon}$")
        plt.tight_layout()

    plt.show()




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_001.png
         :alt: plot transport blur
         :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_002.png
         :alt: plot transport blur
         :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_003.png
         :alt: plot transport blur
         :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_004.png
         :alt: plot transport blur
         :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_005.png
         :alt: plot transport blur
         :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_transport_blur_005.png
         :class: sphx-glr-multi-img
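Before concluding: the "Gaussian noise + sorting" estimator of
:math:`\text{B}_\varepsilon` used above is easy to package as a standalone
routine. The sketch below is a direct refactoring of the experiment's inner
loop; the function name ``blurred_wasserstein_1d`` and the ``copies``
parameter are illustrative choices, not part of the GeomLoss API:

.. code-block:: python

    import torch


    def blurred_wasserstein_1d(x, y, blur, copies=100):
        """Monte-Carlo estimate of B_eps for two (N, 1) samples x and y.

        Here, blur = sqrt(eps) and the Gaussian jitter has deviation blur / 2.
        """
        # Duplicate each point cloud "copies" times and add the Gaussian blur:
        x_c = torch.cat([x] * copies, dim=0)
        y_c = torch.cat([y] * copies, dim=0)
        x_c = x_c + 0.5 * blur * torch.randn_like(x_c)
        y_c = y_c + 0.5 * blur * torch.randn_like(y_c)
        # In 1D, the optimal matching between uniform samples is monotone:
        x_s, _ = x_c.sort(dim=0)
        y_s, _ = y_c.sort(dim=0)
        # Average the p=2 ground cost C(x,y) = |x - y|^2 / 2 along the matching:
        return (0.5 / len(x_s)) * ((x_s - y_s) ** 2).sum()

Increasing ``copies`` reduces the Monte-Carlo variance of the estimate, at a
linear cost in time and memory.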
.. GENERATED FROM PYTHON SOURCE LINES 178-216

Conclusion
--------------

In practice, the Sinkhorn and blurred Wasserstein divergences are **nearly
indistinguishable**. But as far as we can tell *today*, these two loss
functions have very different properties:

- :math:`\text{B}_\varepsilon` is **easy to define**, compute in 1D and
  **analyze** from a geometric or statistical point of view...
  But it cannot (?) be computed efficiently in higher dimensions, where the
  true OT problem is nearly intractable.
- :math:`\text{S}_\varepsilon` is simply available through the
  :mod:`SamplesLoss("sinkhorn") <geomloss>` layer, but has a weird,
  composite definition and is pretty **hard to study** rigorously -
  as evidenced by the recent, technical proofs of
  `positivity, definiteness (Feydy et al., 2018) <https://arxiv.org/abs/1810.08278>`_
  and `sample complexity (Genevay et al., 2018) <https://arxiv.org/abs/1810.02733>`_.

**So couldn't we get the best of both worlds?**
In an ideal world, we'd like to tweak the *efficient* multiscale Sinkhorn
algorithm to compute the *natural* divergence
:math:`\text{B}_\varepsilon`... but this may be out of reach.
A realistic target could be to **quantify the difference** between these two
objects, thus legitimizing the use of the
:mod:`SamplesLoss("sinkhorn") <geomloss>` layer as a **cheap proxy** for the
intuitive and well-understood *blurred Wasserstein distance*.

In my opinion, investigating the link between these two quantities is one of
the most interesting questions left open in the field of discrete entropic
OT. The geometric loss functions implemented in GeomLoss are probably *good
enough* for most practical purposes, but getting a **rigorous
understanding** of the multiscale, wavelet-like behavior of our algorithms
as we add small details through an exponential decay of the blurring scale
:math:`\sqrt{\varepsilon}` would be truly insightful. In some sense,
couldn't we prove a
`Hilbert <https://en.wikipedia.org/wiki/David_Hilbert>`_-`Plancherel <https://en.wikipedia.org/wiki/Plancherel_theorem>`_
theorem for the Wasserstein distance?


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 26.314 seconds)


.. _sphx_glr_download__auto_examples_sinkhorn_multiscale_plot_transport_blur.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_transport_blur.py <plot_transport_blur.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_transport_blur.ipynb <plot_transport_blur.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_