.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/optimal_transport/plot_wasserstein_barycenters_1D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_optimal_transport_plot_wasserstein_barycenters_1D.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_optimal_transport_plot_wasserstein_barycenters_1D.py:


Wasserstein barycenters in 1D
==================================

Let's compute Wasserstein barycenters
with a Sinkhorn divergence,
using Eulerian and Lagrangian optimization schemes.

.. GENERATED FROM PYTHON SOURCE LINES 11-13

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 13-24

.. code-block:: Python


    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.neighbors import KernelDensity  # display as density curves

    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor








.. GENERATED FROM PYTHON SOURCE LINES 25-53

Dataset
~~~~~~~~~~~~~~~~~~

Given a weight :math:`w\in[0,1]`
and two *endpoint* measures :math:`\alpha`
and :math:`\beta`, we wish to compute
the **Sinkhorn barycenter**

.. math::
      \gamma^\star ~=~ \arg\min_{\gamma}~
          (1-w)\cdot\text{S}_{\varepsilon,\rho}(\gamma,\alpha)
         \,+\, w\cdot\text{S}_{\varepsilon,\rho}(\gamma,\beta),

which coincides with :math:`\alpha` when :math:`w=0`
and with :math:`\beta` when :math:`w=1`.

If our input measures

.. math::
  \alpha ~=~ \frac{1}{M}\sum_{i=1}^M \delta_{x_i}, ~~~
  \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j},

are fixed, the optimization problem
above is `known to be convex <https://arxiv.org/abs/1810.08278>`_  with
respect to the weights :math:`\gamma_k` of the *variable* measure

.. math::
  \gamma  ~=~ \sum_{k=1}^N \gamma_k\,\delta_{z_k}.

.. GENERATED FROM PYTHON SOURCE LINES 53-63

.. code-block:: Python



    N, M = (50, 50) if not use_cuda else (500, 500)

    t_i = torch.linspace(0, 1, M).type(dtype).view(-1, 1)
    t_j = torch.linspace(0, 1, M).type(dtype).view(-1, 1)

    X_i, Y_j = 0.1 * t_i, 0.2 * t_j + 0.8  # Intervals [0., 0.1] and [.8, 1.].
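
For reference, the Sinkhorn divergence that appears in the formula above can be
sketched in a few lines of NumPy. This is a minimal, balanced illustration
(without the :math:`\rho` term), not the GeomLoss implementation: the helper names
``ot_eps`` and ``sinkhorn_divergence``, the cost :math:`\tfrac{1}{2}|x-y|^2`,
the choice :math:`\varepsilon = \text{blur}^2` and the fixed number of iterations
are all assumptions made for this example.

.. code-block:: Python

    import numpy as np


    def logsumexp(a, axis):
        """Numerically stable log-sum-exp along one axis."""
        m = a.max(axis=axis, keepdims=True)
        return np.squeeze(m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True)), axis=axis)


    def ot_eps(a, x, b, y, eps, n_iter=200):
        """Entropic OT cost between sum_i a_i delta_{x_i} and sum_j b_j delta_{y_j}."""
        C = 0.5 * (x[:, None] - y[None, :]) ** 2  # assumed cost: |x - y|^2 / 2
        f, g = np.zeros(len(x)), np.zeros(len(y))
        log_a, log_b = np.log(a), np.log(b)
        for _ in range(n_iter):  # alternate "softmin" updates of the dual potentials
            f = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
            g = -eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
        return a @ f + b @ g  # value of the dual objective


    def sinkhorn_divergence(a, x, b, y, blur=0.05):
        """Debiased divergence: OT(a, b) - OT(a, a) / 2 - OT(b, b) / 2."""
        eps = blur**2
        return (
            ot_eps(a, x, b, y, eps)
            - 0.5 * ot_eps(a, x, a, x, eps)
            - 0.5 * ot_eps(b, y, b, y, eps)
        )


    M = 20
    x = 0.1 * np.linspace(0, 1, M)  # samples on [0, 0.1]
    y = 0.2 * np.linspace(0, 1, M) + 0.8  # samples on [0.8, 1]
    a = b = np.full(M, 1 / M)  # uniform weights

    s_ab = sinkhorn_divergence(a, x, b, y)  # positive: the measures differ
    s_aa = sinkhorn_divergence(a, x, a, x)  # zero: the measures coincide
    print(s_ab, s_aa)

The debiasing terms are what make the divergence vanish when both arguments
coincide, which is why the barycenter reduces to an endpoint when
:math:`w=0` or :math:`w=1`.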









.. GENERATED FROM PYTHON SOURCE LINES 64-73

In this notebook, we thus propose to solve the barycentric
optimization problem in two ways: through a (quasi-)convex optimization
on the (log-)weights :math:`\log(\gamma_k)`, with fixed locations :math:`z_k`,
and through a well-conditioned
descent on the samples' positions :math:`z_k`,
with uniform weights :math:`\gamma_k = 1/N`.

In both sections (Eulerian vs. Lagrangian), we'll start from a
uniform sample on the unit interval:

.. GENERATED FROM PYTHON SOURCE LINES 73-78

.. code-block:: Python


    t_k = torch.linspace(0, 1, N).type(dtype).view(-1, 1)
    Z_k = t_k









.. GENERATED FROM PYTHON SOURCE LINES 79-84

Display routine
~~~~~~~~~~~~~~~~~

We display our samples using (smoothed) density
curves, computed with a straightforward Gaussian convolution:

.. GENERATED FROM PYTHON SOURCE LINES 84-100

.. code-block:: Python


    t_plot = np.linspace(-0.1, 1.1, 1000)[:, np.newaxis]


    def display_samples(ax, x, color, weights=None, blur=0.002):
        """Displays samples on the unit interval using a density curve."""
        kde = KernelDensity(kernel="gaussian", bandwidth=blur).fit(
            x.data.cpu().numpy(),
            sample_weight=None if weights is None else weights.data.cpu().numpy(),
        )
        dens = np.exp(kde.score_samples(t_plot))
        dens[0] = 0
        dens[-1] = 0
        ax.fill(t_plot, dens, color=color)
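
To make the "Gaussian convolution" explicit, here is a small NumPy sketch of the
quantity that such a density curve represents: a weighted sum of Gaussian bumps
centered on the samples. The function name ``gaussian_density`` is hypothetical
and only meant as an illustration; the routine above relies on scikit-learn instead.

.. code-block:: Python

    import numpy as np


    def gaussian_density(t, x, weights=None, blur=0.002):
        """Evaluates sum_k w_k * N(t; x_k, blur^2) at the points t."""
        if weights is None:
            weights = np.full(len(x), 1 / len(x))  # uniform weights by default
        weights = weights / weights.sum()  # normalize to a probability vector
        sq_dists = (t[:, None] - x[None, :]) ** 2
        kernel = np.exp(-sq_dists / (2 * blur**2)) / (blur * np.sqrt(2 * np.pi))
        return kernel @ weights


    t = np.linspace(-0.1, 1.1, 1000)
    x = 0.1 * np.linspace(0, 1, 50)  # samples on [0, 0.1]
    dens = gaussian_density(t, x, blur=0.005)
    mass = dens.sum() * (t[1] - t[0])  # Riemann sum of the density, close to 1
    print(mass)

A larger ``blur`` yields a smoother curve, at the cost of hiding fine details of the sample.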









.. GENERATED FROM PYTHON SOURCE LINES 101-109

Eulerian gradient flow
------------------------------------------

Taking advantage of the **convexity** of Sinkhorn divergences
with respect to the measures' weights, we first solve
the barycentric optimization problem through a
(quasi-convex) **Eulerian**
descent on the **log-weights** :math:`l_k = \log(\gamma_k)`:

.. GENERATED FROM PYTHON SOURCE LINES 109-198

.. code-block:: Python



    from geomloss.examples.optimal_transport.model_fitting import (
        fit_model,
    )  # Wrapper around scipy.optimize
    from torch.nn import Module, Parameter  # PyTorch syntax for optimization problems


    class Barycenter(Module):
        """Abstract model for the computation of Sinkhorn barycenters."""

        def __init__(self, loss, w=0.5):
            super(Barycenter, self).__init__()
            self.loss = loss  # Sinkhorn divergence to optimize
            self.w = w  # Interpolation coefficient
            # We copy the reference starting points, to prevent in-place modification:
            self.x_i, self.y_j, self.z_k = X_i.clone(), Y_j.clone(), Z_k.clone()

        def fit(self, display=False, tol=1e-10):
            """Uses a custom wrapper around the scipy.optimize module."""
            fit_model(self, method="L-BFGS", lr=1.0, display=display, tol=tol, gtol=tol)

        def weights(self):
            """The default weights are uniform, equal to 1/N."""
            return (torch.ones(len(self.z_k)) / len(self.z_k)).type_as(self.z_k)

        def plot(self, nit=0, cost=0, ax=None, title=None):
            """Displays the descent using a custom 'waffle' layout.

            N.B.: As the L-BFGS descent typically induces high-frequencies in
                  the optimization process, we blur the 'interpolating' measure
                  a little bit more than the two endpoints.
            """
            if ax is None:
                if nit == 0 or nit % 16 == 4:
                    plt.pause(0.01)
                    plt.figure(figsize=(16, 4))

                if nit <= 4 or nit % 4 == 0:
                    if nit < 4:
                        index = nit + 1
                    else:
                        index = (nit // 4 - 1) % 4 + 1
                    ax = plt.subplot(1, 4, index)

            if ax is not None:
                display_samples(ax, self.x_i, (0.95, 0.55, 0.55))
                display_samples(ax, self.y_j, (0.55, 0.55, 0.95))
                display_samples(
                    ax, self.z_k, (0.55, 0.95, 0.55), weights=self.weights(), blur=0.005
                )

                if title is None:
                    ax.set_title("nit = {}, cost = {:3.4f}".format(nit, cost))
                else:
                    ax.set_title(title)

                ax.axis([-0.1, 1.1, -0.1, 20.5])
                ax.set_xticks([], [])
                ax.set_yticks([], [])
                plt.tight_layout()


    class EulerianBarycenter(Barycenter):
        """Barycentric model with fixed locations z_k, as we optimize on the log-weights l_k."""

        def __init__(self, loss, w=0.5):
            super(EulerianBarycenter, self).__init__(loss, w)

            # We're going to work with variable weights, so we should explicitly
            # define the (uniform) weights on the "endpoint" samples:
            self.a_i = (torch.ones(len(self.x_i)) / len(self.x_i)).type_as(self.x_i)
            self.b_j = (torch.ones(len(self.y_j)) / len(self.y_j)).type_as(self.y_j)

            # Our parameter to optimize: the logarithms of our weights
            self.l_k = Parameter(torch.zeros(len(self.z_k)).type_as(self.z_k))

        def weights(self):
            """Turns the l_k's into the weights of a positive probability measure."""
            return torch.nn.functional.softmax(self.l_k, dim=0)

        def forward(self):
            """Returns the cost to minimize."""
            c_k = self.weights()
            # Weight (1 - w) on alpha and w on beta, as in the formula above:
            return (1 - self.w) * self.loss(c_k, self.z_k, self.a_i, self.x_i) + (
                self.w * self.loss(c_k, self.z_k, self.b_j, self.y_j)
            )
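
The softmax parameterization used in ``weights()`` can be illustrated with a
short NumPy sketch. The standalone ``softmax`` helper below is written for this
example only; it shows that zero log-weights give uniform weights, and that the
log-weights are only defined up to an additive constant.

.. code-block:: Python

    import numpy as np


    def softmax(l):
        """Maps log-weights l_k to positive weights that sum up to 1."""
        e = np.exp(l - l.max())  # subtracting the max avoids overflow
        return e / e.sum()


    N = 5
    uniform = softmax(np.zeros(N))  # zero log-weights: 1/N everywhere
    print(uniform)

    # The weights are invariant to a global shift of the log-weights:
    l_k = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
    shift_invariant = np.allclose(softmax(l_k), softmax(l_k + 10.0))
    print(shift_invariant)

Optimizing over the :math:`l_k` lets an unconstrained solver such as L-BFGS work
on the probability simplex without any explicit projection step.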









.. GENERATED FROM PYTHON SOURCE LINES 199-203

For this first experiment, we err on the side of caution
and use a small **blur** value in conjunction
with a **scaling** coefficient close to 1, i.e. a large number of iterations
in the Sinkhorn loop:

.. GENERATED FROM PYTHON SOURCE LINES 203-206

.. code-block:: Python


    EulerianBarycenter(SamplesLoss("sinkhorn", blur=0.001, scaling=0.99)).fit(display=True)




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_001.png
         :alt: nit = 0, cost = 0.1210, nit = 1, cost = 0.1195, nit = 2, cost = 0.1144, nit = 3, cost = 0.0935
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_002.png
         :alt: nit = 4, cost = 0.0916, nit = 8, cost = 0.0905, nit = 12, cost = 0.0904, nit = 16, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_003.png
         :alt: nit = 20, cost = 0.0904, nit = 24, cost = 0.0904, nit = 28, cost = 0.0904, nit = 32, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_004.png
         :alt: nit = 36, cost = 0.0904, nit = 40, cost = 0.0904, nit = 44, cost = 0.0904, nit = 48, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_005.png
         :alt: nit = 52, cost = 0.0904, nit = 56, cost = 0.0904, nit = 60, cost = 0.0904, nit = 64, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_006.png
         :alt: nit = 68, cost = 0.0904, nit = 72, cost = 0.0904, nit = 76, cost = 0.0904, nit = 80, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_007.png
         :alt: nit = 84, cost = 0.0904, nit = 88, cost = 0.0904, nit = 92, cost = 0.0904, nit = 96, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_007.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_008.png
         :alt: nit = 100, cost = 0.0904, nit = 104, cost = 0.0904, nit = 108, cost = 0.0904, nit = 112, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_008.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_009.png
         :alt: nit = 116, cost = 0.0904, nit = 120, cost = 0.0904, nit = 124, cost = 0.0904, nit = 128, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_009.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_010.png
         :alt: nit = 132, cost = 0.0904, nit = 136, cost = 0.0904, nit = 140, cost = 0.0904, nit = 144, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_010.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_011.png
         :alt: nit = 148, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_011.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 207-212

As evidenced here, the **Eulerian** descent fits **one by one**
the Fourier modes of the "true" Wasserstein barycenter:
we start from a Gaussian blob and progressively
integrate the higher frequencies, slowly converging
towards a **sharp** step function.
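
In 1D, the un-regularized Wasserstein-2 barycenter of two uniform samples of
the same size has a closed form: we sort both samples and interpolate them
pairwise. The sketch below, which uses NumPy instead of the Sinkhorn solver and
re-creates the two endpoint samples on its own, shows which step function the
descent is expected to approach for :math:`w = 0.5`.

.. code-block:: Python

    import numpy as np

    M, w = 50, 0.5
    t = np.linspace(0, 1, M)
    x = np.sort(0.1 * t)  # samples of alpha, on [0, 0.1]
    y = np.sort(0.2 * t + 0.8)  # samples of beta, on [0.8, 1]

    # Sorted samples play the role of quantiles, which we interpolate linearly:
    z = (1 - w) * x + w * y

    print(z.min(), z.max())  # endpoints of the barycenter's support

The interpolated samples are evenly spaced between 0.4 and 0.55, so the target
density is a step function supported on that interval.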

.. GENERATED FROM PYTHON SOURCE LINES 215-224

Lagrangian gradient flow
------------------------------------------

The procedure above is theoretically sound (thanks to the **convexity** of
Sinkhorn divergences),
but may be too slow for practical purposes.
A simple workaround is to tackle the barycentric interpolation problem
using a Lagrangian, particle-based scheme and optimize our weighted
loss with respect to the **samples' positions**:

.. GENERATED FROM PYTHON SOURCE LINES 224-241

.. code-block:: Python



    class LagrangianBarycenter(Barycenter):
        """Barycentric model with uniform weights, as we optimize on the locations z_k."""

        def __init__(self, loss, w=0.5):
            super(LagrangianBarycenter, self).__init__(loss, w)

            # Our parameter to optimize: the locations z_k of the barycenter's samples
            self.z_k = Parameter(Z_k.clone())

        def forward(self):
            """Returns the cost to minimize."""
            # By default, the weights are uniform and sum up to 1:
            return (1 - self.w) * self.loss(self.z_k, self.x_i) + self.w * self.loss(
                self.z_k, self.y_j
            )
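
To see why a descent on the positions is well-conditioned, here is a toy NumPy
version of the Lagrangian scheme. It is a sketch under simplifying assumptions:
the Sinkhorn divergence is replaced by the un-regularized 1D Wasserstein-2 cost,
whose optimal matching is given by sorting, and L-BFGS is replaced by plain
gradient descent. The helper ``grad_half_w2`` is hypothetical.

.. code-block:: Python

    import numpy as np


    def grad_half_w2(z, x):
        """Gradient of 0.5 * W_2^2 between uniform samples z and x, w.r.t. z.

        In 1D, the optimal matching pairs the k-th smallest z with the k-th smallest x.
        The usual 1/N factor is dropped, which amounts to rescaling the step size.
        """
        order = np.argsort(z)
        grad = np.empty_like(z)
        grad[order] = z[order] - np.sort(x)
        return grad


    N, w, lr = 50, 0.5, 0.5
    t = np.linspace(0, 1, N)
    x, y = 0.1 * t, 0.2 * t + 0.8  # endpoint samples
    z = t.copy()  # initial positions: uniform sample on [0, 1]

    for _ in range(30):
        z -= lr * ((1 - w) * grad_half_w2(z, x) + w * grad_half_w2(z, y))

    print(z.min(), z.max())

Each particle moves straight towards the average of its two matched targets, so
the error is halved at every step with this learning rate.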









.. GENERATED FROM PYTHON SOURCE LINES 242-244

As evidenced below, this algorithm converges quickly towards
a decent interpolator, even for small-ish values of the scaling coefficient:

.. GENERATED FROM PYTHON SOURCE LINES 244-248

.. code-block:: Python


    LagrangianBarycenter(SamplesLoss("sinkhorn", blur=0.01, scaling=0.9)).fit(display=True)





.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_012.png
         :alt: nit = 0, cost = 0.1210, nit = 1, cost = 0.1109, nit = 2, cost = 0.0904, nit = 3, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_012.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_013.png
         :alt: nit = 4, cost = 0.0904, nit = 8, cost = 0.0904, nit = 12, cost = 0.0904, nit = 16, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_013.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_014.png
         :alt: nit = 20, cost = 0.0904
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_014.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 249-255

This algorithm can be understood as a generalization
of :doc:`Optimal Transport registration <plot_optimal_transport_2D>`
to **multi-target** applications and can be used
to efficiently compute :doc:`Wasserstein barycenters in 2D <plot_wasserstein_barycenters_2D>`.
The trade-off between speed and accuracy (especially with respect to oscillating artifacts)
can be tuned with the **tol** and **scaling** parameters:

.. GENERATED FROM PYTHON SOURCE LINES 255-260

.. code-block:: Python


    LagrangianBarycenter(SamplesLoss("sinkhorn", blur=0.01, scaling=0.5)).fit(
        display=True, tol=1e-5
    )
    plt.show()



.. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_015.png
   :alt: nit = 0, cost = 0.1210, nit = 1, cost = 0.1109, nit = 2, cost = 0.0904, nit = 3, cost = 0.0904
   :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_wasserstein_barycenters_1D_015.png
   :class: sphx-glr-single-img
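
As a rough guide to the cost of each setting, the sketch below estimates the
number of steps in the Sinkhorn loop. It assumes that the solver shrinks its
blurring scale geometrically, by a factor **scaling** at every step, from the
diameter of the data (taken to be 1 here) down to **blur**. The function
``n_scaling_steps`` is hypothetical and the actual iteration count of the
library may differ by a few steps.

.. code-block:: Python

    import math


    def n_scaling_steps(blur, scaling, diameter=1.0):
        """Rough number of steps needed to shrink a scale from diameter to blur."""
        return math.ceil(math.log(blur / diameter) / math.log(scaling))


    # The three (blur, scaling) settings used in this example:
    estimates = [
        n_scaling_steps(blur, scaling)
        for blur, scaling in [(0.001, 0.99), (0.01, 0.9), (0.01, 0.5)]
    ]
    print(estimates)

Under this assumption, the cautious Eulerian setting needs hundreds of steps per
loss evaluation, while the last Lagrangian run needs fewer than ten.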






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 45.727 seconds)


.. _sphx_glr_download__auto_examples_optimal_transport_plot_wasserstein_barycenters_1D.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_wasserstein_barycenters_1D.ipynb <plot_wasserstein_barycenters_1D.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_wasserstein_barycenters_1D.py <plot_wasserstein_barycenters_1D.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_