.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/optimal_transport/plot_optimal_transport_2D.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_optimal_transport_plot_optimal_transport_2D.py:


Optimal Transport in 2D
=========================

Let's use the gradient of the Sinkhorn divergence to compute an Optimal Transport map.

Setup
---------------------

.. code-block:: default

    import numpy as np
    import matplotlib.pyplot as plt
    import time

    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

Display routines
~~~~~~~~~~~~~~~~~

.. code-block:: default

    from random import choices
    import imageio.v2 as imageio  # the v2 API keeps the legacy imread() behavior


    def load_image(fname):
        img = imageio.imread(fname, as_gray=True)  # Grayscale
        img = (img[::-1, :]) / 255.0  # Flip vertically so that row 0 sits at y = 0
        return 1 - img  # Dark pixels = high density


    def draw_samples(fname, n, dtype=torch.FloatTensor):
        A = load_image(fname)
        # One (x, y) grid node per pixel: x follows the columns of A, y its rows.
        xg, yg = np.meshgrid(
            np.linspace(0, 1, A.shape[1]),
            np.linspace(0, 1, A.shape[0]),
            indexing="xy",
        )

        grid = list(zip(xg.ravel(), yg.ravel()))
        dens = A.ravel() / A.sum()
        # Sample n pixels with probability proportional to their intensity...
        dots = np.array(choices(grid, dens, k=n))
        # ...and add a Gaussian jitter at the scale of half a pixel.
        dots += (0.5 / A.shape[0]) * np.random.standard_normal(dots.shape)

        return torch.from_numpy(dots).type(dtype)


    def display_samples(ax, x, color):
        x_ = x.detach().cpu().numpy()
        ax.scatter(x_[:, 0], x_[:, 1], s=25 * 500 / len(x_), c=color, edgecolors="none")

Dataset
~~~~~~~~~~~~~~~~~~

Our source and target samples are drawn from measures whose densities
are stored in simple PNG files.
They allow us to define a pair of discrete probability measures:

.. math::

    \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
    \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

.. code-block:: default

    N, M = (100, 100) if not use_cuda else (10000, 10000)

    X_i = draw_samples("data/density_a.png", N, dtype)
    Y_j = draw_samples("data/density_b.png", M, dtype)

Lagrangian gradient descent
-------------------------------

.. code-block:: default

    def gradient_descent(loss, lr=1):
        """Flows along the gradient of the loss function.

        Parameters:
            loss ((x_i,y_j) -> torch float number): Real-valued loss function.
            lr (float, default = 1): Learning rate, i.e. time step.
        """

        # Parameters for the gradient descent
        Nsteps = 11
        display_its = [0, 1, 2, 10]

        # Use colors to identify the particles
        colors = (10 * X_i[:, 0]).cos() * (10 * X_i[:, 1]).cos()
        colors = colors.detach().cpu().numpy()

        # Make sure that we won't modify the reference samples
        x_i, y_j = X_i.clone(), Y_j.clone()

        # We're going to perform gradient descent on Loss(α, β)
        # wrt. the positions x_i of the Dirac masses that make up α:
        x_i.requires_grad = True

        t_0 = time.time()
        plt.figure(figsize=(12, 12))
        k = 1
        for i in range(Nsteps):  # Euler scheme ===============
            # Compute cost and gradient
            L_αβ = loss(x_i, y_j)
            [g] = torch.autograd.grad(L_αβ, [x_i])

            if i in display_its:  # display
                ax = plt.subplot(2, 2, k)
                k = k + 1
                plt.set_cmap("hsv")
                plt.scatter(
                    [10], [10]
                )  # shameless hack to prevent a slight change of axis...

                display_samples(ax, y_j, [(0.55, 0.55, 0.95)])
                display_samples(ax, x_i, colors)

                ax.set_title("it = {}".format(i))

                plt.axis([0, 1, 0, 1])
                plt.gca().set_aspect("equal", adjustable="box")
                plt.xticks([], [])
                plt.yticks([], [])
                plt.tight_layout()

            # Explicit Euler step. Since α_i = 1/N, scaling the gradient by N = len(x_i)
            # moves each particle along v_i = -(1/α_i) * grad_{x_i} Loss(α, β).
            with torch.no_grad():
                x_i -= lr * len(x_i) * g

        plt.title(
            "it = {}, elapsed time: {:.2f}s/it".format(i, (time.time() - t_0) / Nsteps)
        )

Wasserstein-2 Optimal Transport
----------------------------------

Sinkhorn divergences rely on blurry transport plans
:math:`\pi_{\varepsilon,\rho}^{\alpha,\beta}`,
:math:`\pi_{\varepsilon,\rho}^{\alpha,\alpha}` and
:math:`\pi_{\varepsilon,\rho}^{\beta,\beta}`,
solutions of entropy-regularized transport problems,
which cannot be readily interpreted as deterministic maps.
However, when **p = 2**, we can interpret the gradient field
:math:`v_i \,=\, -\tfrac{1}{\alpha_i} \nabla_{x_i} \text{S}_{\varepsilon,\rho}(\alpha,\beta)`
as a Brenier-like transport map, which sends each source point :math:`x_i`
to (approximately) a barycenter :math:`x_i+v_i` of the targets
at scale :math:`\text{blur}\,=\,\sqrt{\varepsilon}`.
Since :math:`\alpha_i = 1/N`, this is precisely the displacement applied
at each step of ``gradient_descent`` with ``lr = 1``.

.. code-block:: default

    gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.1))

.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_001.png
   :alt: it = 0, it = 1, it = 2, it = 10, elapsed time: 0.04s/it
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_001.png
   :class: sphx-glr-single-img
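
The barycentric interpretation can be checked on a toy problem. The NumPy sketch below is our own illustration, not GeomLoss's implementation: it runs a fixed number of plain log-domain Sinkhorn iterations for uniform weights and the cost :math:`\tfrac{1}{2}|x-y|^2`, assumes :math:`\varepsilon = \text{blur}^2`, and omits both the debiasing terms of :math:`\text{S}_{\varepsilon,\rho}` and the multiscale tricks of the library. The helper names ``entropic_plan`` and ``barycentric_map`` are hypothetical.

.. code-block:: python

    import numpy as np


    def logsumexp(a, axis):
        """Numerically stable log(sum(exp(a))) along one axis."""
        a_max = a.max(axis=axis, keepdims=True)
        out = a_max + np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True))
        return np.squeeze(out, axis=axis)


    def entropic_plan(x, y, blur=0.05, n_iter=200):
        """Entropic OT plan between uniform clouds x (N, D) and y (M, D).

        Log-domain Sinkhorn for the cost C(x, y) = |x - y|^2 / 2,
        assuming eps = blur**2 (the p = 2 case). Not debiased.
        """
        eps = blur**2
        N, M = len(x), len(y)
        log_a = np.full(N, -np.log(N))  # log alpha_i, with alpha_i = 1/N
        log_b = np.full(M, -np.log(M))  # log beta_j,  with beta_j = 1/M
        C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # (N, M)
        f, g = np.zeros(N), np.zeros(M)
        for _ in range(n_iter):
            # Update g to fit the column marginal beta, then f to fit the row marginal alpha.
            g = -eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
            f = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        # pi_ij = alpha_i * beta_j * exp((f_i + g_j - C_ij) / eps)
        return np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / eps)


    def barycentric_map(x, y, pi):
        """Send each x_i to (1 / alpha_i) * sum_j pi_ij y_j, alpha_i being the row sums of pi."""
        return (pi @ y) / pi.sum(axis=1, keepdims=True)


    # Two sources, and two targets shifted upwards by 0.5:
    # each x_i should be sent onto the target right above it.
    x = np.array([[0.0, 0.0], [1.0, 0.0]])
    y = x + np.array([0.0, 0.5])
    print(barycentric_map(x, y, entropic_plan(x, y, blur=0.1)))  # approximately [[0, 0.5], [1, 0.5]]

For this non-debiased entropic cost, the envelope theorem gives
:math:`\tfrac{1}{\alpha_i}\nabla_{x_i}\text{OT}_\varepsilon = x_i - \tfrac{1}{\alpha_i}\sum_j \pi_{ij}\, y_j`,
so the barycentric map above is exactly :math:`x_i - \tfrac{1}{\alpha_i}\nabla_{x_i}\text{OT}_\varepsilon`.
Roughly speaking, the debiasing terms of :math:`\text{S}_{\varepsilon,\rho}` only add a correction
that vanishes as the blur tends to zero.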

Crucially, as the blurring scale :math:`\sqrt{\varepsilon}` tends to zero,
:math:`\pi_{\varepsilon,\rho}^{\alpha,\beta}` converges towards a "genuine" Monge map
between :math:`\alpha` and :math:`\beta`, while
:math:`\pi_{\varepsilon,\rho}^{\alpha,\alpha}` and :math:`\pi_{\varepsilon,\rho}^{\beta,\beta}`
collapse to the identity maps.
The Sinkhorn gradient then converges towards the **Brenier map**
and allows us to quickly register our measures with each other.

.. code-block:: default

    gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01))

.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_002.png
   :alt: it = 0, it = 1, it = 2, it = 10, elapsed time: 0.04s/it
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_002.png
   :class: sphx-glr-single-img

The ``reach`` parameter allows us to introduce **laziness** into the classical
Monge problem, specifying a maximum scale (half-life) of interaction
between the :math:`x_i`'s and the :math:`y_j`'s.
It may be useful in situations where **outliers** are common,
as it limits the influence of samples that are too far away.

.. code-block:: default

    gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01, reach=0.1))

.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_003.png
   :alt: it = 0, it = 1, it = 2, it = 10, elapsed time: 0.04s/it
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_003.png
   :class: sphx-glr-single-img
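
To see how ``reach`` tames outliers, the sketch below adapts plain log-domain Sinkhorn iterations to unbalanced OT, where the hard marginal constraints are replaced by KL penalties of strength :math:`\rho`. In that setting, each dual update becomes a soft-min damped by the factor :math:`\rho/(\rho+\varepsilon)`. We assume :math:`\varepsilon = \text{blur}^2` and :math:`\rho = \text{reach}^2` (the p = 2 case); the function name ``unbalanced_plan`` is hypothetical and this is not the library's code.

.. code-block:: python

    import numpy as np


    def logsumexp(a, axis):
        """Numerically stable log(sum(exp(a))) along one axis."""
        a_max = a.max(axis=axis, keepdims=True)
        out = a_max + np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True))
        return np.squeeze(out, axis=axis)


    def unbalanced_plan(x, y, blur=0.05, reach=None, n_iter=500):
        """Entropic OT plan with KL-relaxed marginals, for uniform x (N, D) and y (M, D).

        Cost C(x, y) = |x - y|^2 / 2. Assumes eps = blur**2 and rho = reach**2;
        reach=None means rho = +infinity, i.e. balanced OT.
        """
        eps = blur**2
        # Dampening factor rho / (rho + eps), equal to 1 in the balanced case.
        damp = 1.0 if reach is None else 1.0 / (1.0 + eps / reach**2)
        N, M = len(x), len(y)
        log_a = np.full(N, -np.log(N))
        log_b = np.full(M, -np.log(M))
        C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
        f, g = np.zeros(N), np.zeros(M)
        for _ in range(n_iter):
            g = -damp * eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
            f = -damp * eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        return np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / eps)


    # Four sources near the origin; three nearby targets plus one far-away outlier.
    x = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1]])
    y = np.array([[0.05, 0.0], [0.0, 0.05], [0.05, 0.05], [5.0, 5.0]])

    for reach in (None, 0.1):
        pi = unbalanced_plan(x, y, blur=0.05, reach=reach)
        print(f"reach={reach}: mass sent to the outlier = {pi[:, -1].sum():.3g}")

With ``reach=None``, the outlier still has to receive (approximately) its full share of mass,
which would drag source points towards it in a gradient flow. With ``reach=0.1``, the KL penalty
makes it cheaper to leave the outlier unmatched, and it receives essentially no mass at all.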

Optimal Transport is *not* the panacea
-----------------------------------------------

Optimal Transport theory is all about **discarding the topological structure**
of the data to get a simple, convex registration algorithm:
the Monge map transports **bags of sand** from one location to another,
and may tear shapes apart as needed.

In generative modelling, this versatility allows us to fit "Gaussian blobs"
to any kind of empirical distribution:

.. code-block:: default

    X_i = draw_samples("data/crescent_a.png", N, dtype)
    Y_j = draw_samples("data/crescent_b.png", M, dtype)
    gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01))

.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_004.png
   :alt: it = 0, it = 1, it = 2, it = 10, elapsed time: 0.04s/it
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_004.png
   :class: sphx-glr-single-img

Going further, in simple situations, Optimal Transport may even be used
as a "cheap and easy" registration routine...

.. code-block:: default

    X_i = draw_samples("data/worm_a.png", N, dtype)
    Y_j = draw_samples("data/worm_b.png", M, dtype)
    gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01))

.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_005.png
   :alt: it = 0, it = 1, it = 2, it = 10, elapsed time: 0.04s/it
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_005.png
   :class: sphx-glr-single-img

But beware! Out of the box, Optimal Transport will **not** match the salient
features of both shapes (e.g. ends or corners) with each other.
In real-life applications, Sinkhorn divergences should thus always
be used in a relevant **feature space** (e.g. of SIFT descriptors),
in conjunction with a prior-enforcing **generative model**
(e.g. a convolutional neural network or a thin plate spline deformation).

.. code-block:: default

    X_i = draw_samples("data/moon_a.png", N, dtype)
    Y_j = draw_samples("data/moon_b.png", M, dtype)
    gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01))
    plt.show()

.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_006.png
   :alt: it = 0, it = 1, it = 2, it = 10, elapsed time: 0.04s/it
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_optimal_transport_2D_006.png
   :class: sphx-glr-single-img
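
One way to read this advice is to let OT provide correspondences, but only accept deformations from a constrained model. The sketch below is our own toy illustration, not a GeomLoss feature. It computes an exact OT matching by brute force, which is only viable for a handful of points, and then projects that matching onto an affine map by least squares. The affine map is a crude stand-in for the thin plate spline or neural network mentioned above. The names ``exact_ot_matching`` and ``fit_affine`` are hypothetical.

.. code-block:: python

    import itertools

    import numpy as np


    def exact_ot_matching(x, y):
        """Exact OT between two uniform clouds of the same (tiny!) size, by brute force.

        With uniform weights and N = M, some optimal plan is a permutation, so we scan
        all N! assignments for the one that minimizes sum_i |x_i - y_sigma(i)|^2.
        """
        best_cost, best_perm = np.inf, None
        for perm in itertools.permutations(range(len(y))):
            cost = ((x - y[list(perm)]) ** 2).sum()
            if cost < best_cost:
                best_cost, best_perm = cost, perm
        return np.array(best_perm)


    def fit_affine(x, z):
        """Least-squares fit of z ~ x @ A.T + b, a crude stand-in for a generative model."""
        X = np.hstack([x, np.ones((len(x), 1))])  # homogeneous coordinates, shape (N, D + 1)
        W, *_ = np.linalg.lstsq(X, z, rcond=None)  # shape (D + 1, D)
        return W[:-1].T, W[-1]  # A, b


    rng = np.random.default_rng(0)
    x = rng.random((6, 2))
    A_true, b_true = 1.5 * np.eye(2), np.array([0.3, -0.2])
    y = (x @ A_true.T + b_true)[rng.permutation(6)]  # deformed and shuffled targets

    perm = exact_ot_matching(x, y)  # OT correspondences x_i <-> y[perm[i]]
    A, b = fit_affine(x, y[perm])  # project them onto the affine model
    print(np.round(A, 3), np.round(b, 3))

Here the true deformation :math:`x \mapsto 1.5\,x + b` is the gradient of a convex function.
OT therefore recovers the correspondences exactly, and the fit returns ``A`` close to ``1.5 * I``
and ``b`` close to ``(0.3, -0.2)``. For a deformation that is not of this form,
such as a large rotation, the OT matching may pair up the wrong points, mirroring the warning above.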

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 4.617 seconds)