.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/track_barycenter.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_brain_tractograms_track_barycenter.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_track_barycenter.py:


Create an atlas using Wasserstein barycenters
==================================================

In this tutorial, we compute the barycenter of a dataset of probability tracks.
The barycenter is defined as the Fréchet mean for the Sinkhorn divergence, and is computed with a Lagrangian optimization scheme.

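Writing :math:`\alpha_1, \dots, \alpha_K` for the input measures and :math:`\mathrm{S}_\varepsilon` for the Sinkhorn divergence, the barycenter solves:

.. math::

    \beta^\star \,=\, \arg\min_{\beta} \, \sum_{k=1}^{K} \mathrm{S}_\varepsilon(\beta, \alpha_k).
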
.. GENERATED FROM PYTHON SOURCE LINES 10-12

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 12-36

.. code-block:: Python


    import time
    import gzip
    import shutil

    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    import nibabel as nib
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor



.. GENERATED FROM PYTHON SOURCE LINES 37-41

Dataset
~~~~~~~~~~~~~~~~~~

In this tutorial, we work with probability tracks, which can be understood as normalized 3D images. We will compute the Wasserstein barycenter of this dataset.

.. GENERATED FROM PYTHON SOURCE LINES 41-97

.. code-block:: Python



    import os


    def fetch_file(name):
        os.makedirs("data", exist_ok=True)  # make sure the cache folder exists
        if not os.path.exists(f"data/{name}.nii.gz"):
            import urllib.request

            print("Fetching the atlas... ", end="", flush=True)
            urllib.request.urlretrieve(
                f"https://www.kernel-operations.io/data/{name}.nii.gz",
                f"data/{name}.nii.gz",
            )
            with gzip.open(f"data/{name}.nii.gz", "rb") as f_in:
                with open(f"data/{name}.nii", "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
            print("Done.")


    for i in range(5):
        fetch_file(f"manual_ifof{i+1}")


    affine_transform = nib.load("data/manual_ifof1.nii").affine


    # Load a volume in the .nii format as a 3D array, normalized to [0, 1].
    def load_data_nii(fname):
        img = nib.load(fname)
        data = img.get_fdata()
        data_norm = data / np.max(data)
        return torch.from_numpy(data_norm).type(dtype)


    def grid(nx, ny, nz):
        x, y, z = torch.meshgrid(
            torch.arange(0.0, nx).type(dtype),
            torch.arange(0.0, ny).type(dtype),
            torch.arange(0.0, nz).type(dtype),
            indexing="ij",
        )
        return torch.stack((x, y, z), dim=3).view(-1, 3).detach().cpu().numpy()


    # load the data set (here, we have 5 subjects)
    dataset = []
    for i in range(5):
        fname = "data/manual_ifof" + str(i + 1) + ".nii"
        image_norm = load_data_nii(fname)
        print(image_norm.shape)
        dataset.append(image_norm)


.. GENERATED FROM PYTHON SOURCE LINES 98-99

These 3D images can be understood as densities supported on the voxel grid.

.. GENERATED FROM PYTHON SOURCE LINES 99-119

.. code-block:: Python



    def img_to_points_cloud(data_norm):  # normalized images (between 0 and 1)
        nx, ny, nz = data_norm.shape
        ind = data_norm.nonzero()
        indx = ind[:, 0]
        indy = ind[:, 1]
        indz = ind[:, 2]
        data_norm = data_norm / data_norm.sum()
        a_i = data_norm[indx, indy, indz]

        return ind.type(dtype), a_i


    def measure_to_image(x, nx, ny, nz, weights=None):
        bins = (x[:, 2]).floor() + nz * (x[:, 1]).floor() + nz * ny * (x[:, 0]).floor()
        count = bins.int().bincount(weights=weights, minlength=nx * ny * nz)
        return count.view(nx, ny, nz)



.. GENERATED FROM PYTHON SOURCE LINES 120-121

To perform our computations, we turn these 3D arrays into weighted point clouds, regularly spaced on the grid.

.. GENERATED FROM PYTHON SOURCE LINES 121-130

.. code-block:: Python



    a, b = img_to_points_cloud(dataset[0]), img_to_points_cloud(dataset[1])
    c, d, e = (
        img_to_points_cloud(dataset[2]),
        img_to_points_cloud(dataset[3]),
        img_to_points_cloud(dataset[4]),
    )

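As a sanity check, ``img_to_points_cloud`` and ``measure_to_image`` invert each other up to the global normalization. Below is a minimal NumPy mirror of the two helpers, run on a toy volume (a sketch, independent of the torch code above):

```python
import numpy as np


def img_to_points_cloud_np(data):
    # Nonzero voxel coordinates become point positions, voxel values weights.
    ind = np.stack(np.nonzero(data), axis=1)
    data = data / data.sum()  # normalize to a probability distribution
    a_i = data[ind[:, 0], ind[:, 1], ind[:, 2]]
    return ind.astype(float), a_i


def measure_to_image_np(x, nx, ny, nz, weights=None):
    # Flatten 3D voxel indices into linear bins, then accumulate the weights.
    bins = np.floor(x[:, 2]) + nz * np.floor(x[:, 1]) + nz * ny * np.floor(x[:, 0])
    count = np.bincount(bins.astype(int), weights=weights, minlength=nx * ny * nz)
    return count.reshape(nx, ny, nz)


img = np.zeros((2, 3, 4))
img[0, 1, 2] = 2.0
img[1, 0, 3] = 6.0
x, a = img_to_points_cloud_np(img)
recon = measure_to_image_np(x, 2, 3, 4, weights=a)
# recon equals img / img.sum(): the round trip preserves mass voxel-wise.
```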

.. GENERATED FROM PYTHON SOURCE LINES 131-132

We initialize the barycenter with an upsampled arithmetic mean of the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 132-157

.. code-block:: Python



    nx, ny, nz = image_norm.shape


    def initialize_barycenter(dataset):
        mean = torch.zeros(nx, ny, nz).type(dtype)
        for k in range(len(dataset)):
            img = dataset[k]
            mean = mean + img
        mean = mean / len(dataset)
        x_i, a_i = img_to_points_cloud(mean)
        bar_pos, bar_weight = torch.tensor([]).type(dtype), torch.tensor([]).type(dtype)
        for d in range(3):
            x_i_d1, x_i_d2 = x_i.clone(), x_i.clone()
            x_i_d1[:, d], a_i_d1 = x_i_d1[:, d] + 0.25, a_i / 6
            x_i_d2[:, d], a_i_d2 = x_i_d2[:, d] - 0.25, a_i / 6
            bar_pos, bar_weight = torch.cat((bar_pos, x_i_d1, x_i_d2), 0), torch.cat(
                (bar_weight, a_i_d1, a_i_d2), 0
            )
        return bar_pos, bar_weight


    x_i, a_i = initialize_barycenter(dataset)

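The splitting rule above replaces each voxel by six jittered copies, shifted by ±0.25 along each axis and each carrying one sixth of the voxel's mass, so the total weight still sums to one. A standalone NumPy sketch of this rule, on toy positions and weights (not the real data):

```python
import numpy as np

# Hypothetical toy positions and normalized weights:
x_i = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
a_i = np.array([0.25, 0.75])

bar_pos, bar_weight = np.empty((0, 3)), np.empty((0,))
for d in range(3):
    x1, x2 = x_i.copy(), x_i.copy()
    x1[:, d] += 0.25  # shift a quarter voxel up along axis d...
    x2[:, d] -= 0.25  # ...and a quarter voxel down
    bar_pos = np.concatenate((bar_pos, x1, x2), axis=0)
    bar_weight = np.concatenate((bar_weight, a_i / 6, a_i / 6), axis=0)

# Six times as many particles, with the total mass unchanged.
```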

.. GENERATED FROM PYTHON SOURCE LINES 158-160

The barycenter will be the minimizer of the sum of Sinkhorn distances to the dataset.
It is computed through a Lagrangian gradient descent on the particles' positions.

.. GENERATED FROM PYTHON SOURCE LINES 160-181

.. code-block:: Python


    Loss = SamplesLoss("sinkhorn", blur=1, scaling=0.9, debias=False)
    models = []
    x_i.requires_grad = True


    start = time.time()
    for j in range(len(dataset)):
        img_j = dataset[j]
        y_j, b_j = img_to_points_cloud(img_j)
        L_ab = Loss(a_i, x_i, b_j, y_j)
        [g_i] = torch.autograd.grad(L_ab, [x_i])
        models.append(x_i - g_i / a_i.view(-1, 1))

    a, b, c, d, e = models
    barycenter = (a + b + c + d + e) / 5
    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()
    print("barycenter computed in {:.3f}s.".format(end - start))

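The code above performs a single averaged descent step. In practice, the update can be iterated until the particles stabilize. Here is a minimal sketch of the loop structure, where a weighted quadratic surrogate stands in for the Sinkhorn loss so that the snippet runs without ``geomloss`` (with it installed, ``Loss(a_i, x_i, b_j, y_j)`` would take the surrogate's place):

```python
import torch

# Toy target point clouds standing in for the y_j's (hypothetical data):
targets = [torch.full((4, 3), float(c)) for c in range(3)]

x_i = torch.zeros(4, 3, requires_grad=True)
a_i = torch.full((4,), 1.0 / 4)  # uniform weights

for it in range(5):
    models = []
    for y_j in targets:
        # Surrogate for Loss(a_i, x_i, b_j, y_j): a weighted squared distance.
        L_ab = 0.5 * (a_i.view(-1, 1) * (x_i - y_j) ** 2).sum()
        [g_i] = torch.autograd.grad(L_ab, [x_i])
        # Lagrangian step: move each particle along -gradient / weight.
        models.append(x_i - g_i / a_i.view(-1, 1))
    with torch.no_grad():
        x_i.copy_(sum(models) / len(models))

# With this surrogate, x_i converges to the mean of the targets.
```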

.. GENERATED FROM PYTHON SOURCE LINES 182-183

We can plot slices of the computed barycenter:

.. GENERATED FROM PYTHON SOURCE LINES 183-188

.. code-block:: Python

    img_barycenter = measure_to_image(barycenter, nx, ny, nz, a_i)
    plt.figure()
    plt.imshow(img_barycenter.detach().cpu().numpy()[20, :, :])
    plt.show()


.. GENERATED FROM PYTHON SOURCE LINES 189-190

Alternatively, we can save the 3D image in the .nii format, once it has been mapped back to the coordinate system of the input images.

.. GENERATED FROM PYTHON SOURCE LINES 190-200

.. code-block:: Python

    # The inverse of the affine map [A, t; 0, 1] is [A^{-1}, -A^{-1} t; 0, 1]:
    linear_transform_inv = np.linalg.inv(affine_transform[:3, :3])
    translation_inv = -linear_transform_inv @ affine_transform[:3, 3]
    affine_inv = np.r_[
        np.c_[linear_transform_inv, translation_inv], np.array([[0, 0, 0, 1]])
    ]
    barycenter_nib = nib.Nifti1Image(
        521 * (img_barycenter / img_barycenter.max()).detach().cpu().numpy(),
        affine_transform,
    )
    nib.save(barycenter_nib, "barycenter_image.nii")

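Note that the translation part of an affine inverse is :math:`-A^{-1}t`, not :math:`-t`. A small self-contained check of this construction, using a made-up affine matrix rather than the atlas one:

```python
import numpy as np

# Hypothetical 4x4 affine: a diagonal linear block plus a translation.
affine = np.array(
    [
        [2.0, 0.0, 0.0, 10.0],
        [0.0, 3.0, 0.0, -5.0],
        [0.0, 0.0, 1.5, 2.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
linear_inv = np.linalg.inv(affine[:3, :3])
translation_inv = -linear_inv @ affine[:3, 3]  # -A^{-1} t, not -t
affine_inv = np.r_[np.c_[linear_inv, translation_inv], np.array([[0, 0, 0, 1]])]

# affine_inv composes with affine to the 4x4 identity.
```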

.. _sphx_glr_download__auto_examples_brain_tractograms_track_barycenter.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: track_barycenter.ipynb <track_barycenter.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: track_barycenter.py <track_barycenter.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_