.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/optimal_transport/plot_wasserstein_barycenters_2D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_optimal_transport_plot_wasserstein_barycenters_2D.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_optimal_transport_plot_wasserstein_barycenters_2D.py:


Wasserstein barycenters in 2D
==================================

Let's compute pseudo-Wasserstein barycenters between 2D densities,
using the gradient of the Sinkhorn divergence as a cheap approximation
of the Monge map.

.. GENERATED FROM PYTHON SOURCE LINES 11-13

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 13-26

.. code-block:: Python


    import numpy as np
    import matplotlib.pyplot as plt
    from imageio.v2 import imread  # v2 API, avoids the ImageIO v3 deprecation warning
    from sklearn.neighbors import KernelDensity
    from torch.nn.functional import avg_pool2d

    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor








.. GENERATED FROM PYTHON SOURCE LINES 27-32

Dataset
~~~~~~~~~~~~~~~~~~

In this tutorial, we work with square images
understood as densities on the unit square.

.. GENERATED FROM PYTHON SOURCE LINES 32-57

.. code-block:: Python



    def grid(W):
        x, y = torch.meshgrid([torch.arange(0.0, W).type(dtype) / W] * 2, indexing="xy")
        return torch.stack((x, y), dim=2).view(-1, 2)


    def load_image(fname):
        img = np.mean(imread(fname), axis=2)  # Grayscale
        img = (img[:, :]) / 255.0
        return 1 - img  # black = 1, white = 0


    def as_measure(fname, size):
        weights = torch.from_numpy(load_image(fname)).type(dtype)
        sampling = weights.shape[0] // size
        weights = (
            avg_pool2d(weights.unsqueeze(0).unsqueeze(0), sampling).squeeze(0).squeeze(0)
        )
        weights = weights / weights.sum()

        samples = grid(size)
        return weights.view(-1), samples
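
As a rough illustration of what ``as_measure`` returns, the sketch below redoes the
same steps in plain NumPy on a synthetic 8x8 array instead of a png file.
The helper ``as_measure_np`` and the test image are hypothetical stand-ins that are
not part of the original script, and block averaging with ``reshape`` is assumed
to be an adequate substitute for ``avg_pool2d`` here.

.. code-block:: Python

    import numpy as np


    def as_measure_np(img, size):
        """Downsample a square image to (size, size); return (weights, samples)."""
        factor = img.shape[0] // size
        # Average pooling over non-overlapping (factor x factor) blocks
        cropped = img[: size * factor, : size * factor]
        pooled = cropped.reshape(size, factor, size, factor).mean(axis=(1, 3))
        # Normalize so that the weights sum to 1
        weights = pooled / pooled.sum()
        # Pixel coordinates in [0, 1), with x varying fastest as in grid()
        t = np.arange(size) / size
        x, y = np.meshgrid(t, t, indexing="xy")
        samples = np.stack((x, y), axis=2).reshape(-1, 2)
        return weights.reshape(-1), samples


    img = np.zeros((8, 8))
    img[2:6, 2:6] = 1.0  # synthetic centred square, standing in for a png bitmap
    weights, samples = as_measure_np(img, 4)

Here, the four non-zero weights are equal to 0.25 and sit at flat indices 5, 6, 9 and 10,
i.e. at the points (0.25, 0.25), (0.5, 0.25), (0.25, 0.5) and (0.5, 0.5).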









.. GENERATED FROM PYTHON SOURCE LINES 58-61

To perform Lagrangian computations,
we turn these **png** bitmaps into **weighted point clouds**,
regularly spaced on a grid:

.. GENERATED FROM PYTHON SOURCE LINES 61-68

.. code-block:: Python


    N, M = (8, 8) if not use_cuda else (128, 64)

    A, B = as_measure("data/A.png", M), as_measure("data/B.png", M)
    C, D = as_measure("data/C.png", M), as_measure("data/D.png", M)










.. GENERATED FROM PYTHON SOURCE LINES 69-71

The starting point of our algorithm is
a finely grained uniform sample on the unit square:

.. GENERATED FROM PYTHON SOURCE LINES 71-77

.. code-block:: Python


    x_i = grid(N).view(-1, 2)
    a_i = (torch.ones(N * N) / (N * N)).type_as(x_i)

    x_i.requires_grad = True








.. GENERATED FROM PYTHON SOURCE LINES 78-84

Display routine
~~~~~~~~~~~~~~~~~

To display our interpolating point clouds, we put points
into square bins and display the resulting density,
using an appropriate threshold to mitigate quantization artifacts:

.. GENERATED FROM PYTHON SOURCE LINES 84-104

.. code-block:: Python


    import matplotlib

    matplotlib.rc("image", cmap="gray")

    grid_plot = grid(M).view(-1, 2).cpu().numpy()


    def display_samples(ax, x, weights=None):
        """Displays samples on the unit square using a simple binning algorithm."""
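        # Clamp just below 1 so that every point falls into a valid bin
        x = x.clamp(0, 1 - 0.1 / M)
        # Flat bin index: column + M * row
        bins = (x[:, 0] * M).floor() + M * (x[:, 1] * M).floor()
        # (Weighted) number of points per bin
        count = bins.int().bincount(weights=weights, minlength=M * M)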
        ax.imshow(
            count.detach().float().view(M, M).cpu().numpy(),
            vmin=0,
            vmax=0.5 * count.max().item(),
        )
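
The binning step can be checked by hand on a few points. The following NumPy sketch
mirrors the clamp, floor and bincount logic of ``display_samples`` on a hypothetical
4x4 grid; the sample coordinates are made up for illustration and the plotting call is left out.

.. code-block:: Python

    import numpy as np

    M = 4  # number of bins per axis (hypothetical, smaller than in the tutorial)
    x = np.array(
        [[0.10, 0.10], [0.30, 0.10], [0.30, 0.12], [0.99, 0.99], [1.20, -0.50]]
    )

    # Clamp the points so that each one falls into a valid bin
    x = np.clip(x, 0, 1 - 0.1 / M)
    # Flat bin index: column + M * row
    bins = np.floor(x[:, 0] * M) + M * np.floor(x[:, 1] * M)
    # Number of points per bin, reshaped as an (M, M) image
    count = np.bincount(bins.astype(int), minlength=M * M).reshape(M, M)

Rows of ``count`` index the vertical bin and columns the horizontal bin.
The second and third points share a bin, and the out-of-range point is clamped
into the last column of the first row.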









.. GENERATED FROM PYTHON SOURCE LINES 105-131

In the :doc:`notebook on Wasserstein barycenters <plot_wasserstein_barycenters_1D>`,
we've seen how to solve generic optimization problems
of the form

.. math::
      \alpha^\star~=~\arg\min_\alpha
            w_a \cdot \text{S}_{\varepsilon,\rho}(\,\alpha,\,A\,)
       ~&+~  w_b \cdot \text{S}_{\varepsilon,\rho}(\,\alpha,\,B\,) \\
       ~+~  w_c \cdot \text{S}_{\varepsilon,\rho}(\,\alpha,\,C\,)
        ~&+~  w_d \cdot \text{S}_{\varepsilon,\rho}(\,\alpha,\,D\,)

using Eulerian and Lagrangian schemes.

Focusing on the Lagrangian descent, a **single** (weighted)
**gradient step** on the points :math:`x_i`
that make up the variable distribution
:math:`\alpha = \sum_{i=1}^N \alpha_i \delta_{x_i}`
results in an update

.. math::
      x_i ~\gets~ x_i + w_a\cdot v_i^A + w_b\cdot v_i^B + w_c\cdot v_i^C  + w_d\cdot v_i^D,

where the vectors :math:`\,v_i^A\,=\,-\tfrac{1}{\alpha_i}\nabla_{x_i}\text{S}_{\varepsilon,\rho}(\,\alpha,\,A\,)\,`,
:math:`v_i^B`, :math:`v_i^C` and :math:`v_i^D` (defined in the same way) are the displacements
that map the starting (uniform) sample :math:`\alpha` to the target measures
:math:`A`, :math:`B`, :math:`C` and :math:`D`.
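
To see why dividing the gradient by the weights :math:`\alpha_i` yields a displacement
towards the target, consider a deliberately simplified setting. The sketch below
replaces the Sinkhorn divergence with a toy quadratic loss in which each point
:math:`x_i` is assumed to be matched to a known target :math:`y_i`; this fixed matching,
the function ``toy_loss`` and the random data are illustrative assumptions and do not
come from GeomLoss.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.random((5, 2))  # source points x_i
    y = rng.random((5, 2))  # hypothetical targets: x_i is matched to y_i
    alpha = np.full(5, 1 / 5)  # uniform weights alpha_i


    def toy_loss(x):
        # Stand-in for S(alpha, A): sum_i alpha_i * |x_i - y_i|^2 / 2
        return 0.5 * np.sum(alpha[:, None] * (x - y) ** 2)


    # Analytic gradient of the toy loss with respect to the positions x_i
    grad = alpha[:, None] * (x - y)

    # Displacement v_i = -(1 / alpha_i) * grad_i, as in the update above
    v = -grad / alpha[:, None]
    targets = x + v

In this toy case, ``targets`` coincides with ``y``: a single rescaled gradient step
sends every point onto its match. With the Sinkhorn divergence, the same formula
is only expected to produce an approximate, regularized version of this behaviour.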

.. GENERATED FROM PYTHON SOURCE LINES 131-141

.. code-block:: Python


    Loss = SamplesLoss("sinkhorn", blur=0.01, scaling=0.9)
    models = []
    for b_j, y_j in [A, B, C, D]:
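        # Sinkhorn divergence between the uniform sample and the target measure
        L_ab = Loss(a_i, x_i, b_j, y_j)
        # Gradient with respect to the point positions x_i
        [g_i] = torch.autograd.grad(L_ab, [x_i])
        # Target points x_i + v_i, with v_i = -g_i / a_i
        models.append(x_i - g_i / a_i.view(-1, 1))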

    a, b, c, d = models








.. GENERATED FROM PYTHON SOURCE LINES 142-152

If the weights :math:`w_k` sum to 1, this update is a barycentric
combination of the **target points** :math:`x_i + v_i^A`, :math:`~\dots\,`, :math:`x_i + v_i^D`
(stored as ``a``, ``b``, ``c`` and ``d`` in the code above). These are the
images of the source sample :math:`x_i`
under the action of the :doc:`generalized Monge/Brenier maps <../sinkhorn_multiscale/plot_epsilon_scaling>` that transport
our uniform sample onto the four target measures.

Using the resulting sample as an **ersatz for the true Wasserstein barycenter**
is thus an approximation that holds in dimension 1 and is reasonable
for most applications. As evidenced below, it allows us to interpolate
between arbitrary densities at a low numerical cost:

.. GENERATED FROM PYTHON SOURCE LINES 152-188

.. code-block:: Python


    plt.figure(figsize=(14, 14))

    # Display the target measures in the corners of our Figure
    ax = plt.subplot(7, 7, 1)
    ax.imshow(A[0].reshape(M, M).cpu())
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    ax = plt.subplot(7, 7, 7)
    ax.imshow(B[0].reshape(M, M).cpu())
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    ax = plt.subplot(7, 7, 43)
    ax.imshow(C[0].reshape(M, M).cpu())
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    ax = plt.subplot(7, 7, 49)
    ax.imshow(D[0].reshape(M, M).cpu())
    ax.set_xticks([], [])
    ax.set_yticks([], [])

    # Display the interpolating densities as a 5x5 waffle plot
    for i in range(5):
        for j in range(5):
            x, y = j / 4, i / 4
            barycenter = (
                (1 - x) * (1 - y) * a + x * (1 - y) * b + (1 - x) * y * c + x * y * d
            )

            ax = plt.subplot(7, 7, 7 * (i + 1) + j + 2)
            display_samples(ax, barycenter)
            ax.set_xticks([], [])
            ax.set_yticks([], [])

    plt.tight_layout()
    plt.show()



.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_2D_001.png
   :alt: plot wasserstein barycenters 2D
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_2D_001.png
   :class: sphx-glr-single-img
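
The coefficients used in the plotting loop are bilinear interpolation weights on
the unit square. As a small sanity check, the sketch below verifies that they are
non-negative and sum to 1 at every node of the 5x5 grid, so that each displayed
sample is indeed a barycentric combination of ``a``, ``b``, ``c`` and ``d``.
The helper ``bilinear_weights`` is a hypothetical name introduced for this illustration.

.. code-block:: Python

    import numpy as np


    def bilinear_weights(x, y):
        """Weights of the corners (A, B, C, D) at position (x, y) in the unit square."""
        return np.array([(1 - x) * (1 - y), x * (1 - y), (1 - x) * y, x * y])


    # Same 5x5 grid of positions as in the plotting loop
    weights = np.array(
        [[bilinear_weights(j / 4, i / 4) for j in range(5)] for i in range(5)]
    )

    # At every node, the four weights are non-negative and sum to 1
    assert np.all(weights >= 0)
    assert np.allclose(weights.sum(axis=-1), 1.0)

The corner nodes put all the mass on a single target, while the centre node
weighs the four targets equally.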






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.130 seconds)


.. _sphx_glr_download__auto_examples_optimal_transport_plot_wasserstein_barycenters_2D.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_wasserstein_barycenters_2D.ipynb <plot_wasserstein_barycenters_2D.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_wasserstein_barycenters_2D.py <plot_wasserstein_barycenters_2D.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_