.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/transfer_labels.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_brain_tractograms_transfer_labels.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_transfer_labels.py:


Transferring labels from a segmented atlas
=============================================

We use a new multiscale algorithm for solving regularized Optimal Transport
problems on the GPU, with a linear memory footprint.

We use the resulting smooth assignments to perform label transfer for atlas-based
segmentation of fiber tractograms. The parameters of our method -- *blur* and *reach* --
are meaningful, defining the minimum and maximum distances at which
two fibers are compared with each other. They can be set according to anatomical knowledge.
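
As a rough sketch of the underlying model (following GeomLoss's documented
conventions, where the temperature is eps = blur**p and the marginal penalty
strength is rho = reach**p, with p = 2 in this example), the solver tackles
an unbalanced, entropy-regularized OT problem:

.. math::

    \text{OT}_{\varepsilon,\rho}(\alpha,\beta)
    \,=\, \min_{\pi \geqslant 0}\;
    \langle \pi, C\rangle
    \,+\, \varepsilon\,\text{KL}(\pi\,\|\,\alpha\otimes\beta)
    \,+\, \rho\,\text{KL}(\pi \mathbf{1}\,\|\,\alpha)
    \,+\, \rho\,\text{KL}(\pi^\top \mathbf{1}\,\|\,\beta),

where the ground cost C(x, y) = |x - y|^2 / 2 is computed between sampled
fibers seen as points in R^60.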

.. GENERATED FROM PYTHON SOURCE LINES 16-20

Setup
---------------------

Standard imports:

.. GENERATED FROM PYTHON SOURCE LINES 20-31

.. code-block:: Python


    import numpy as np
    import matplotlib.pyplot as plt
    import time
    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    dtypeint = torch.cuda.LongTensor if use_cuda else torch.LongTensor


.. GENERATED FROM PYTHON SOURCE LINES 32-34

Routines for loading and saving data:


.. GENERATED FROM PYTHON SOURCE LINES 34-39

.. code-block:: Python


    from tract_io import read_vtk, streamlines_resample, save_vtk, save_vtk_labels
    from tract_io import save_tract, save_tract_numpy
    from tract_io import save_tract_with_labels, save_tracts_labels_separate


.. GENERATED FROM PYTHON SOURCE LINES 40-44

Dataset
---------------------

Fetch data from the KeOps website:

.. GENERATED FROM PYTHON SOURCE LINES 44-64

.. code-block:: Python


    import os


    def fetch_file(name):
        if not os.path.exists(f"data/{name}.npy"):
            import urllib.request

            os.makedirs("data", exist_ok=True)
            print(f"Fetching {name}... ", end="", flush=True)
            urllib.request.urlretrieve(
                f"https://www.kernel-operations.io/data/{name}.npy", f"data/{name}.npy"
            )
            print("Done.")


    fetch_file("tracto_atlas")
    fetch_file("atlas_labels")
    fetch_file("tracto1")



.. GENERATED FROM PYTHON SOURCE LINES 65-68

Fibers do not have a canonical orientation. Since our ground distance is a simple
L2-distance on the sampled fibers, we augment the dataset with the mirror flip
of all fibers and perform the OT on this augmented dataset.

.. GENERATED FROM PYTHON SOURCE LINES 68-82

.. code-block:: Python



    def torch_load(X, dtype=dtype):
        return torch.from_numpy(X).type(dtype).contiguous()


    def add_flips(X):
        """Augments the input fiber track with the mirror flip of every fiber."""
        X_flip = torch.flip(X, (1,))  # Reverse the order of the sampling points
        X = torch.stack((X, X_flip), dim=1)  # (Nfibers, 2, NPOINTS, 3)
        return X
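
    # A quick sanity check on a toy fiber (hypothetical values, not from the dataset):
    toy = torch_load(np.arange(6.0).reshape(1, 2, 3))  # one fiber sampled with 2 points
    assert add_flips(toy).shape == (1, 2, 2, 3)
    assert (add_flips(toy)[0, 1, 0] == toy[0, -1]).all()  # the flip reverses the point order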



.. GENERATED FROM PYTHON SOURCE LINES 83-87

Source atlas
~~~~~~~~~~~~~~~~~~~

Load atlas (segmented, each fiber has a label):

.. GENERATED FROM PYTHON SOURCE LINES 87-91

.. code-block:: Python


    Y_j = torch_load(np.load("data/tracto_atlas.npy"))
    labels_j = torch_load(np.load("data/atlas_labels.npy"), dtype=dtypeint)


.. GENERATED FROM PYTHON SOURCE LINES 93-96

.. code-block:: Python


    M, NPOINTS = Y_j.shape[0], Y_j.shape[1]  # Number of fibers, points per fiber


.. GENERATED FROM PYTHON SOURCE LINES 98-101

.. code-block:: Python


    Y_j = Y_j.view(M, NPOINTS, 3) / np.sqrt(NPOINTS)


.. GENERATED FROM PYTHON SOURCE LINES 103-105

.. code-block:: Python


    Y_j = add_flips(Y_j)  # Shape (M, 2, NPOINTS, 3)

.. GENERATED FROM PYTHON SOURCE LINES 106-111

Target subject
~~~~~~~~~~~~~~~~~~~~

Load a new, unlabelled subject:


.. GENERATED FROM PYTHON SOURCE LINES 111-125

.. code-block:: Python


    X_i = torch_load(np.load("data/tracto1.npy"))
    N, NPOINTS_i = X_i.shape[0], X_i.shape[1]  # Number of fibers, points per fiber

    if NPOINTS != NPOINTS_i:
        raise ValueError(
            "The atlas and the subject are not sampled with the same number of points: "
            + f"{NPOINTS} and {NPOINTS_i}, respectively."
        )

    X_i = X_i.view(N, NPOINTS, 3) / np.sqrt(NPOINTS)
    X_i = add_flips(X_i)  # Shape (N, 2, NPOINTS, 3)



.. GENERATED FROM PYTHON SOURCE LINES 126-131

Feature engineering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Add some extra weight on both ends of our fibers:


.. GENERATED FROM PYTHON SOURCE LINES 131-138

.. code-block:: Python


    gamma = 2.0
    X_i[:, :, 0, :] *= gamma
    X_i[:, :, -1, :] *= gamma
    Y_j[:, :, 0, :] *= gamma
    Y_j[:, :, -1, :] *= gamma


.. GENERATED FROM PYTHON SOURCE LINES 139-144

Optimizing performance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Contiguous memory accesses are critical for performance on the GPU.


.. GENERATED FROM PYTHON SOURCE LINES 144-161

.. code-block:: Python


    from pykeops.torch.cluster import sort_clusters, cluster_ranges

    ranges_j = cluster_ranges(labels_j)  # Ranges for all clusters
    Y_j, labels_j = sort_clusters(
        Y_j, labels_j
    )  # Make sure that all clusters are contiguous in memory

    C = len(ranges_j)  # Number of classes

    if C != labels_j.max() + 1:
        raise ValueError(
            f"Expected the atlas labels to range from 0 to {C - 1} without gaps."
        )

    for j, (start_j, end_j) in enumerate(ranges_j):
        if start_j >= end_j:
            raise ValueError(f"The {j}-th cluster of the atlas seems to be empty.")


.. GENERATED FROM PYTHON SOURCE LINES 162-167

Each fiber is sampled with 20 points in R^3.
Thus, one tractogram is a matrix of size n x 60, where n is the number of fibers.
The atlas is labelled, which means that each fiber belongs to a cluster.
This is summarized by the vector labels_j of size n x 1: labels_j[i] is the label of the fiber i.
Subsample the data if you want to reduce the computational time:

.. GENERATED FROM PYTHON SOURCE LINES 167-171

.. code-block:: Python


    subsample = 20  # Set this to 1 to keep the full dataset



.. GENERATED FROM PYTHON SOURCE LINES 173-181

.. code-block:: Python


    to_keep = []
    for start_j, end_j in ranges_j:
        to_keep += list(range(start_j, end_j, subsample))

    Y_j, labels_j = Y_j[to_keep].contiguous(), labels_j[to_keep].contiguous()
    ranges_j = cluster_ranges(labels_j)  # Keep the ranges up to date!


.. GENERATED FROM PYTHON SOURCE LINES 183-187

.. code-block:: Python


    X_i = X_i[::subsample].contiguous()



.. GENERATED FROM PYTHON SOURCE LINES 189-195

.. code-block:: Python


    N, M = len(X_i), len(Y_j)

    print("Data loaded.")



.. GENERATED FROM PYTHON SOURCE LINES 196-199

Pre-computing cluster prototypes
--------------------------------------


.. GENERATED FROM PYTHON SOURCE LINES 199-213

.. code-block:: Python


    from pykeops.torch import LazyTensor


    def nn_search(x_i, y_j, ranges=None):
        x_i = LazyTensor(x_i[:, None, :])  # Broadcasted "line" variable
        y_j = LazyTensor(y_j[None, :, :])  # Broadcasted "column" variable

        D_ij = ((x_i - y_j) ** 2).sum(-1)  # Symbolic matrix of squared distances
        D_ij.ranges = ranges  # Apply our block-sparsity pattern

        return D_ij.argmin(dim=1).view(-1)
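
    # A quick smoke test on random data (hypothetical sizes, just to check the output shape):
    x_test = torch.randn(100, 6).type(dtype)
    y_test = torch.randn(10, 6).type(dtype)
    assert nn_search(x_test, y_test).shape == (100,)  # one nearest-centroid index per point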



.. GENERATED FROM PYTHON SOURCE LINES 214-216

K-Means loop:


.. GENERATED FROM PYTHON SOURCE LINES 216-232

.. code-block:: Python



    def KMeans(x_i, c_j, Nits=10, ranges=None):
        D = x_i.shape[1]
        for i in range(Nits):
            # Points -> Nearest cluster:
            labs_i = nn_search(x_i, c_j, ranges=ranges)
            # Class cardinals:
            Ncl = torch.bincount(labs_i.view(-1)).type(dtype)
            # Compute the cluster centroids with torch.bincount:
            for d in range(D):  # Unfortunately, vector weights are not supported...
                c_j[:, d] = torch.bincount(labs_i.view(-1), weights=x_i[:, d]) / Ncl

        return c_j, labs_i



.. GENERATED FROM PYTHON SOURCE LINES 233-239

On the subject
~~~~~~~~~~~~~~~~~~~~~~~~

For the new (unlabelled) subject, we perform a simple K-means
in R^60 to obtain a clustering of the data.


.. GENERATED FROM PYTHON SOURCE LINES 239-258

.. code-block:: Python


    K = 1000

    # Pick K fibers at random:
    perm = torch.randperm(N)
    random_labels = perm[:K]
    C_i = X_i[random_labels]  # (K, 2, NPOINTS, 3)

    # Reshape our data as "N-by-60" tensors:
    C_i_flat = C_i.view(K * 2, NPOINTS * 3)  # Flattened list of centroids
    X_i_flat = X_i.view(N * 2, NPOINTS * 3)  # Flattened list of fibers

    # Retrieve our new centroids:
    C_i_flat, labs_i = KMeans(X_i_flat, C_i_flat)
    C_i = C_i_flat.view(K, 2, NPOINTS, 3)
    # Standard deviation of our clusters:
    std_i = ((X_i_flat - C_i_flat[labs_i.view(-1), :]) ** 2).sum(dim=1).mean().sqrt()



.. GENERATED FROM PYTHON SOURCE LINES 259-268

On the atlas
~~~~~~~~~~~~~~~~~~~~~~~

To use the multiscale version of the regularized OT,
we need a clustering of our input data (atlas and new subject).
For the atlas, the clustering is given by the segmentation. We use a K-means to
separate the fibers from their flips within each cluster, so that every cluster
contains fibers with a similar orientation.


.. GENERATED FROM PYTHON SOURCE LINES 269-278

.. code-block:: Python

    # Once flips are interleaved, the atlas cluster [start, end) spans rows [2*start, 2*end):
    ranges_yi = 2 * ranges_j

    # Each class c owns the two centroid rows [2*c, 2*c + 2), one per orientation:
    ranges_cj = 2 * torch.arange(C).type_as(ranges_j)
    ranges_cj = torch.stack((ranges_cj, ranges_cj + 2)).t().contiguous()

    # Block-sparsity pattern: fibers of class c only interact with the centroids of class c.
    slices_i = 1 + torch.arange(C).type_as(ranges_j)
    ranges_yi_cj = (ranges_yi, slices_i, ranges_cj, ranges_cj, slices_i, ranges_yi)



.. GENERATED FROM PYTHON SOURCE LINES 279-280

Pick one unoriented fiber (i.e. two oriented fibers) per class:

.. GENERATED FROM PYTHON SOURCE LINES 280-286

.. code-block:: Python


    first_labels = ranges_j[:, 0]  # One label per class

    C_j = Y_j[first_labels.type(dtypeint), :, :, :]  # (C, 2, NPOINTS, 3)
    C_j_flat = C_j.view(C * 2, NPOINTS * 3)  # Flattened list of centroids


.. GENERATED FROM PYTHON SOURCE LINES 288-297

.. code-block:: Python



    Y_j_flat = Y_j.view(M * 2, NPOINTS * 3)
    C_j_flat, labs_j = KMeans(Y_j_flat, C_j_flat, ranges=ranges_yi_cj)
    C_j = C_j_flat.view(C, 2, NPOINTS, 3)

    std_j = ((Y_j_flat - C_j_flat[labs_j.view(-1), :]) ** 2).sum(dim=1).mean().sqrt()



.. GENERATED FROM PYTHON SOURCE LINES 298-309

Compute the OT plan with the multiscale algorithm
------------------------------------------------------

To use the **multiscale** Sinkhorn algorithm,
we should simply provide:

- explicit **labels** and **weights** for both input measures,
- a typical **cluster_scale** which specifies the iteration at which
  the Sinkhorn loop jumps from a **coarse** to a **fine** representation
  of the data.


.. GENERATED FROM PYTHON SOURCE LINES 309-322

.. code-block:: Python

    blur = 3.0
    OT_solver = SamplesLoss(
        "sinkhorn",
        p=2,
        blur=blur,
        reach=20,
        scaling=0.9,
        cluster_scale=max(std_i, std_j),
        debias=False,
        potentials=True,
        verbose=True,
    )


.. GENERATED FROM PYTHON SOURCE LINES 323-325

To specify explicit cluster labels, SamplesLoss also requires
explicit weights. Let's go with the default option, a uniform distribution:

.. GENERATED FROM PYTHON SOURCE LINES 325-344

.. code-block:: Python


    a_i = torch.ones(2 * N).type(dtype) / (2 * N)
    b_j = torch.ones(2 * M).type(dtype) / (2 * M)

    start = time.time()

    # Compute the dual vectors F_i and G_j:
    # 6 args -> labels_i, weights_i, locations_i, labels_j, weights_j, locations_j
    F_i, G_j = OT_solver(
        labs_i, a_i, X_i.view(N * 2, NPOINTS * 3), labs_j, b_j, Y_j.view(M * 2, NPOINTS * 3)
    )

    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()

    print("OT computed in  in {:.3f}s.".format(end - start))



.. GENERATED FROM PYTHON SOURCE LINES 345-351

Use the OT to perform the transfer of labels
----------------------------------------------

The transport plan pi_{i,j} gives the probability for
a fiber i of the subject to be assigned to the (labelled) fiber j of the atlas.
We assign to each fiber i the label that receives the largest total mass
in its soft assignment.
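
With potentials=True, GeomLoss returns the dual potentials F_i and G_j rather
than the plan itself. Following the standard optimality conditions of entropic
OT (a sketch, with eps = blur**2 since p = 2 here), the plan can be recovered as:

.. math::

    \pi_{i,j} \,=\, a_i\, b_j\,
    \exp\Big(\frac{F_i + G_j - C(x_i, y_j)}{\varepsilon}\Big),

which is what the kernel matrix KK_ij below implements, up to the constant
weights a_i and b_j.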

.. GENERATED FROM PYTHON SOURCE LINES 351-357

.. code-block:: Python


    # Return to the original data (unflipped)
    X_i = X_i[:, 0, :, :].contiguous()  # (N, NPOINTS, 3)
    F_i = F_i[::2].contiguous()  # (N,)



.. GENERATED FROM PYTHON SOURCE LINES 358-360

Compute the transport plan:


.. GENERATED FROM PYTHON SOURCE LINES 360-373

.. code-block:: Python


    XX_i = LazyTensor(X_i.view(N, 1, NPOINTS * 3))
    YY_j = LazyTensor(Y_j.view(1, M * 2, NPOINTS * 3))
    FF_i = LazyTensor(F_i.view(N, 1, 1))
    GG_j = LazyTensor(G_j.view(1, M * 2, 1))

    # Cost matrix:
    CC_ij = ((XX_i - YY_j) ** 2).sum(-1) / 2  # (N, M * 2, 1) LazyTensor

    # Scaled kernel matrix:
    KK_ij = ((FF_i + GG_j - CC_ij) / blur**2).exp()  # (N, M * 2, 1) LazyTensor



.. GENERATED FROM PYTHON SOURCE LINES 374-376

Transfer the labels, bypassing the one-hot vector encoding
for the sake of efficiency:

.. GENERATED FROM PYTHON SOURCE LINES 376-402

.. code-block:: Python



    def slicing_ranges(start, end):
        """KeOps does not yet support sliced indexing of LazyTensors, so we have to resort to some black magic..."""
        ranges_i = (
            torch.Tensor([[0, N]]).type(dtypeint).int()
        )  # Int32, on the correct device
        slices_i = torch.Tensor([1]).type(dtypeint).int()
        redranges_j = torch.Tensor([[start, end]]).type(dtypeint).int()
        return (ranges_i, slices_i, redranges_j, redranges_j, slices_i, ranges_i)


    weights_i = torch.zeros(C + 1, N).type(torch.FloatTensor)  # C classes + outliers

    for c in range(C):
        start, end = 2 * ranges_j[c]
        KK_ij.ranges = slicing_ranges(
            start, end
        )  # equivalent to "KK_ij[:, start:end]", which is not supported yet...
        weights_i[c] = (KK_ij.sum(dim=1).view(N) / (2 * M)).cpu()

    weights_i[C] = 0.2  # Fibers whose class weights all fall below this value are flagged as outliers

    labels_i = weights_i.argmax(dim=0)  # (N,) vector
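
    # For reference, a toy version of the dense one-hot route that we bypassed above
    # (hypothetical sizes; the real N-by-2M kernel matrix may not fit in memory):
    K_toy = torch.rand(5, 8)  # plan between 5 subject fibers and 8 atlas fibers
    labs_toy = torch.tensor([0, 0, 1, 1, 2, 2, 2, 0])  # atlas labels, 3 classes
    one_hot = torch.nn.functional.one_hot(labs_toy, 3).float()  # (8, 3)
    print((K_toy @ one_hot).argmax(dim=1))  # dense counterpart of labels_i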



.. GENERATED FROM PYTHON SOURCE LINES 403-404

Save our new cluster information as a signal:

.. GENERATED FROM PYTHON SOURCE LINES 404-412

.. code-block:: Python


    # Undo the extra weight that we added on the fiber endpoints:
    X_i[:, 0, :] /= gamma
    X_i[:, -1, :] /= gamma

    save_tracts_labels_separate(
        "output/labels_subject", X_i, labels_i, 0, labels_i.max() + 1
    )  # save the data


.. _sphx_glr_download__auto_examples_brain_tractograms_transfer_labels.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: transfer_labels.ipynb <transfer_labels.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: transfer_labels.py <transfer_labels.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_