.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/comparisons/plot_gradient_flows_2D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_comparisons_plot_gradient_flows_2D.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_comparisons_plot_gradient_flows_2D.py:


Gradient flows in 2D
====================

Let's showcase the properties of **kernel MMDs**, **Hausdorff**
and **Sinkhorn** divergences on a simple toy problem:
the registration of one blob onto another.

.. GENERATED FROM PYTHON SOURCE LINES 12-14

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 14-25

.. code-block:: default


    import numpy as np
    import matplotlib.pyplot as plt
    import time

    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor


.. GENERATED FROM PYTHON SOURCE LINES 26-28

Display routine
~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 28-64

.. code-block:: default


    import numpy as np
    import torch
    from random import choices
    from imageio import imread

    from matplotlib import pyplot as plt


    def load_image(fname):
        img = imread(fname, as_gray=True)  # Grayscale
        img = (img[::-1, :]) / 255.0
        return 1 - img


    def draw_samples(fname, n, dtype=torch.FloatTensor):
        A = load_image(fname)
        xg, yg = np.meshgrid(
            np.linspace(0, 1, A.shape[0]),
            np.linspace(0, 1, A.shape[1]),
            indexing="xy",
        )

        grid = list(zip(xg.ravel(), yg.ravel()))
        dens = A.ravel() / A.sum()
        dots = np.array(choices(grid, dens, k=n))
        dots += (0.5 / A.shape[0]) * np.random.standard_normal(dots.shape)

        return torch.from_numpy(dots).type(dtype)


    def display_samples(ax, x, color):
        x_ = x.detach().cpu().numpy()
        ax.scatter(x_[:, 0], x_[:, 1], 25 * 500 / len(x_), color, edgecolors="none")


.. GENERATED FROM PYTHON SOURCE LINES 65-74

Dataset
~~~~~~~~~~~~~~~~~~

Our source and target samples are drawn from densities stored in
simple PNG files on the unit square, and define discrete
probability measures:

.. math::
    \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
    \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

.. GENERATED FROM PYTHON SOURCE LINES 74-81

.. code-block:: default


    N, M = (100, 100) if not use_cuda else (10000, 10000)

    X_i = draw_samples("data/density_a.png", N, dtype)
    Y_j = draw_samples("data/density_b.png", M, dtype)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/code/geomloss/geomloss/examples/comparisons/plot_gradient_flows_2D.py:38: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.
      img = imread(fname, as_gray=True)  # Grayscale


.. GENERATED FROM PYTHON SOURCE LINES 82-93

Wasserstein gradient flow
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To study the influence of the :math:`\text{Loss}` function in measure-fitting
applications, we perform gradient descent on the positions
:math:`x_i` of the samples that make up :math:`\alpha`
as we minimize the cost :math:`\text{Loss}(\alpha,\beta)`.
This procedure can be understood as a discrete (Lagrangian)
`Wasserstein gradient flow `_
and as a "model-free" machine learning program,
where we optimize directly on the samples' locations.
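Concretely, each update is a single explicit Euler step on the particle
positions. As a minimal sketch (separate from the full ``gradient_flow``
routine below, and using the Energy Distance as a stand-in loss), one
descent step could read:

.. code-block:: python

    # Illustrative sketch of a single Euler step on the particle positions;
    # the full loop, with display code, is defined just below.
    x_i, y_j = X_i.clone(), Y_j.clone()  # work on copies of the reference samples
    x_i.requires_grad = True             # differentiate wrt. the positions x_i

    lr = 0.05                                  # time step of the Euler scheme
    L_ab = SamplesLoss("energy")(x_i, y_j)     # Loss(α, β); any divergence works here
    [g] = torch.autograd.grad(L_ab, [x_i])     # gradient of Loss(α, β) wrt. x_i
    x_i.data -= lr * len(x_i) * g              # scale by N: gradient wrt. the measure α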
.. GENERATED FROM PYTHON SOURCE LINES 93-156

.. code-block:: default


    def gradient_flow(loss, lr=0.05):
        """Flows along the gradient of the cost function, using a simple Euler scheme.

        Parameters:
            loss ((x_i,y_j) -> torch float number):
                Real-valued loss function.
            lr (float, default = .05):
                Learning rate, i.e. time step.
        """

        # Parameters for the gradient descent
        Nsteps = int(5 / lr) + 1
        display_its = [int(t / lr) for t in [0, 0.25, 0.50, 1.0, 2.0, 5.0]]

        # Use colors to identify the particles
        colors = (10 * X_i[:, 0]).cos() * (10 * X_i[:, 1]).cos()
        colors = colors.detach().cpu().numpy()

        # Make sure that we won't modify the reference samples
        x_i, y_j = X_i.clone(), Y_j.clone()

        # We're going to perform gradient descent on Loss(α, β)
        # wrt. the positions x_i of the Dirac masses that make up α:
        x_i.requires_grad = True

        t_0 = time.time()
        plt.figure(figsize=(12, 8))
        k = 1
        for i in range(Nsteps):  # Euler scheme ===============
            # Compute cost and gradient
            L_αβ = loss(x_i, y_j)
            [g] = torch.autograd.grad(L_αβ, [x_i])

            if i in display_its:  # display
                ax = plt.subplot(2, 3, k)
                k = k + 1
                plt.set_cmap("hsv")
                plt.scatter(
                    [10], [10]
                )  # shameless hack to prevent a slight change of axis...

                display_samples(ax, y_j, [(0.55, 0.55, 0.95)])
                display_samples(ax, x_i, colors)
                ax.set_title("t = {:1.2f}".format(lr * i))

                plt.axis([0, 1, 0, 1])
                plt.gca().set_aspect("equal", adjustable="box")
                plt.xticks([], [])
                plt.yticks([], [])
                plt.tight_layout()

            # in-place modification of the tensor's values
            x_i.data -= lr * len(x_i) * g
        plt.title(
            "t = {:1.2f}, elapsed time: {:.2f}s/it".format(
                lr * i, (time.time() - t_0) / Nsteps
            )
        )


.. GENERATED FROM PYTHON SOURCE LINES 157-168

Kernel norms, MMDs
------------------------------------

Gaussian MMD
~~~~~~~~~~~~~~~

The smooth Gaussian kernel
:math:`k(x,y) = \exp(-\|x-y\|^2/2\sigma^2)`
is blind to details which are smaller than the blurring scale :math:`\sigma`:
its gradient stops being informative when :math:`\alpha`
and :math:`\beta` become equal "up to the high frequencies".

.. GENERATED FROM PYTHON SOURCE LINES 168-172

.. code-block:: default


    gradient_flow(SamplesLoss("gaussian", blur=0.5))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_001.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.01s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_001.png
   :class: sphx-glr-single-img




.. GENERATED FROM PYTHON SOURCE LINES 173-178

On the other hand, if the radius :math:`\sigma`
of the kernel is too small, particles :math:`x_i`
won't be attracted to the target, and may **spread out**
to minimize the auto-correlation term
:math:`\tfrac{1}{2}\langle \alpha, k\star\alpha\rangle`.

.. GENERATED FROM PYTHON SOURCE LINES 178-182

.. code-block:: default


    gradient_flow(SamplesLoss("gaussian", blur=0.1))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_002.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.01s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_002.png
   :class: sphx-glr-single-img




.. GENERATED FROM PYTHON SOURCE LINES 183-190

Laplacian MMD
~~~~~~~~~~~~~~~~

The pointy exponential kernel
:math:`k(x,y) = \exp(-\|x-y\|/\sigma)`
provides a better fit, but still decays to zero at infinity
and remains very prone to **screening artifacts**.

.. GENERATED FROM PYTHON SOURCE LINES 190-194

.. code-block:: default


    gradient_flow(SamplesLoss("laplacian", blur=0.1))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_003.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.01s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_003.png
   :class: sphx-glr-single-img
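For reference, the squared kernel norm
:math:`\tfrac{1}{2}\|\alpha-\beta\|^2_k` that these MMD losses minimize can be
written out by hand. The sketch below is an illustration only, assuming the
Gaussian kernel defined above with uniform weights; the exact value returned
by ``SamplesLoss`` may differ slightly depending on its internal conventions,
and a brute-force evaluation is only reasonable for moderate point-cloud sizes:

.. code-block:: python

    def gaussian_mmd(x, y, blur=0.5):
        """Brute-force ½‖α-β‖²_k with k(x,y) = exp(-‖x-y‖²/(2σ²)) and uniform weights."""
        k = lambda a, b: (-torch.cdist(a, b) ** 2 / (2 * blur ** 2)).exp()
        # ½ ( ⟨α, k⋆α⟩ + ⟨β, k⋆β⟩ - 2 ⟨α, k⋆β⟩ )
        return 0.5 * (k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean())


    print(gaussian_mmd(X_i, Y_j, blur=0.5))
    print(SamplesLoss("gaussian", blur=0.5)(X_i, Y_j))  # should be comparable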
.. GENERATED FROM PYTHON SOURCE LINES 195-201

Energy Distance MMD
~~~~~~~~~~~~~~~~~~~~~~

The scale-equivariant kernel
:math:`k(x,y)=-\|x-y\|` provides a robust baseline:
the Energy Distance.

.. GENERATED FROM PYTHON SOURCE LINES 201-207

.. code-block:: default


    # sphinx_gallery_thumbnail_number = 4
    gradient_flow(SamplesLoss("energy"))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_004.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.01s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_004.png
   :class: sphx-glr-single-img




.. GENERATED FROM PYTHON SOURCE LINES 208-224

Sinkhorn divergence
----------------------

(Unbiased) Sinkhorn divergences have recently been introduced
in the machine learning literature,
and can be understood as modern iterations
of the classic `SoftAssign `_ algorithm
from `economics `_ and `computer vision `_.

Wasserstein-1 distance
~~~~~~~~~~~~~~~~~~~~~~~~

When ``p = 1``, the Sinkhorn divergence :math:`\text{S}_\varepsilon`
interpolates between the Energy Distance (when :math:`\varepsilon` is large):

.. GENERATED FROM PYTHON SOURCE LINES 224-227

.. code-block:: default


    gradient_flow(SamplesLoss("sinkhorn", p=1, blur=1.0))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_005.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.02s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_005.png
   :class: sphx-glr-single-img




.. GENERATED FROM PYTHON SOURCE LINES 228-230

and the Earth Mover's (Wasserstein-1) distance (when :math:`\varepsilon` is small):

.. GENERATED FROM PYTHON SOURCE LINES 230-234

.. code-block:: default


    gradient_flow(SamplesLoss("sinkhorn", p=1, blur=0.01))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_006.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.04s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_006.png
   :class: sphx-glr-single-img




.. GENERATED FROM PYTHON SOURCE LINES 235-247

Wasserstein-2 distance
~~~~~~~~~~~~~~~~~~~~~~~~

When ``p = 2``, :math:`\text{S}_\varepsilon` interpolates between
the degenerate kernel norm

.. math::
    \tfrac{1}{2}\| \alpha-\beta\|^2_{-\tfrac{1}{2}\|\cdot\|^2}
    ~=~ \tfrac{1}{2}\big\| \int x \,\text{d}\alpha(x) ~-~ \int y \,\text{d}\beta(y)\big\|^2,

which only registers the means of both measures with each other
(when :math:`\varepsilon` is large):

.. GENERATED FROM PYTHON SOURCE LINES 247-250

.. code-block:: default


    gradient_flow(SamplesLoss("sinkhorn", p=2, blur=1.0))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_007.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.02s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_007.png
   :class: sphx-glr-single-img




.. GENERATED FROM PYTHON SOURCE LINES 251-254

and the quadratic Wasserstein-2 Optimal Transport distance,
which has been studied extensively by mathematicians since the 1980s
(when :math:`\varepsilon` is small):

.. GENERATED FROM PYTHON SOURCE LINES 254-257

.. code-block:: default


    gradient_flow(SamplesLoss("sinkhorn", p=2, blur=0.01))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_008.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.04s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_008.png
   :class: sphx-glr-single-img
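The large-:math:`\varepsilon` end of this interpolation can be sanity-checked
numerically. The sketch below is an illustration only: the deliberately huge
``blur=10.0`` is a hypothetical value chosen to dominate the unit square, and
the exact constants depend on geomloss's normalization conventions, so we only
expect the two numbers to be of the same order:

.. code-block:: python

    # With a very large blur, the Sinkhorn loss with p = 2 should essentially
    # compare the means of the two measures, as in the degenerate kernel norm above.
    S_large = SamplesLoss("sinkhorn", p=2, blur=10.0)(X_i, Y_j)
    mean_gap = 0.5 * ((X_i.mean(0) - Y_j.mean(0)) ** 2).sum()
    print(S_large.item(), mean_gap.item())  # expect values of comparable magnitude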
.. GENERATED FROM PYTHON SOURCE LINES 258-270

Introduced in 2016-2018, the *unbalanced* setting
(Gaussian-Hellinger, Wasserstein-Fisher-Rao, etc.)
provides a principled way of introducing a **threshold**
in Optimal Transport computations:
it allows you to introduce **laziness** in the transportation problem
by replacing distance fields :math:`\|x-y\|`
with a robustified analogue :math:`\rho\cdot( 1 - e^{-\|x-y\|/\rho} )`,
whose gradient saturates beyond a given **reach** :math:`\rho`
(at least, that's the idea).
In real-life applications, this tunable parameter
could allow you to be a little bit more **robust to outliers**!

.. GENERATED FROM PYTHON SOURCE LINES 270-272

.. code-block:: default


    gradient_flow(SamplesLoss("sinkhorn", p=2, blur=0.01, reach=0.3))




.. image-sg:: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_009.png
   :alt: t = 0.00, t = 0.25, t = 0.50, t = 1.00, t = 2.00, t = 5.00, elapsed time: 0.04s/it
   :srcset: /_auto_examples/comparisons/images/sphx_glr_plot_gradient_flows_2D_009.png
   :class: sphx-glr-single-img




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 20.557 seconds)


.. _sphx_glr_download__auto_examples_comparisons_plot_gradient_flows_2D.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gradient_flows_2D.py <plot_gradient_flows_2D.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gradient_flows_2D.ipynb <plot_gradient_flows_2D.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_