.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/transfer_labels.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_brain_tractograms_transfer_labels.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_transfer_labels.py:


Transferring labels from a segmented atlas
=============================================

We use a new multiscale algorithm for solving regularized Optimal Transport
problems on the GPU, with a linear memory footprint.
We use the resulting smooth assignments to perform label transfer
for atlas-based segmentation of fiber tractograms.
The two parameters of our method, ``blur`` and ``reach``, are meaningful:
they define the minimum and maximum distances at which two fibers
are compared with each other, and can be set according to anatomical knowledge.

.. GENERATED FROM PYTHON SOURCE LINES 16-20

Setup
---------------------

Standard imports:

.. GENERATED FROM PYTHON SOURCE LINES 20-31

.. code-block:: Python

    import numpy as np
    import matplotlib.pyplot as plt
    import time
    import torch

    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    dtypeint = torch.cuda.LongTensor if use_cuda else torch.LongTensor

.. GENERATED FROM PYTHON SOURCE LINES 32-34

Routines for loading and saving data:

.. GENERATED FROM PYTHON SOURCE LINES 34-39

.. code-block:: Python

    from tract_io import read_vtk, streamlines_resample, save_vtk, save_vtk_labels
    from tract_io import save_tract, save_tract_numpy
    from tract_io import save_tract_with_labels, save_tracts_labels_separate

.. GENERATED FROM PYTHON SOURCE LINES 40-44

Dataset
---------------------

Fetch data from the KeOps website:

..
.. GENERATED FROM PYTHON SOURCE LINES 44-64

.. code-block:: Python

    import os


    def fetch_file(name):
        """Downloads ``data/{name}.npy`` from the KeOps website if it is missing."""
        if not os.path.exists(f"data/{name}.npy"):
            import urllib.request

            os.makedirs("data", exist_ok=True)  # urlretrieve does not create folders
            print(f"Fetching {name}... ", end="", flush=True)
            urllib.request.urlretrieve(
                f"https://www.kernel-operations.io/data/{name}.npy", f"data/{name}.npy"
            )
            print("Done.")


    fetch_file("tracto_atlas")
    fetch_file("atlas_labels")
    fetch_file("tracto1")

.. GENERATED FROM PYTHON SOURCE LINES 65-68

Fibers do not have a canonical orientation. Since our ground distance is a simple
L2-distance on the sampled fibers, we augment the dataset with the mirror flip
of all fibers and perform the OT on this augmented dataset.

.. GENERATED FROM PYTHON SOURCE LINES 68-82

.. code-block:: Python

    def torch_load(X, dtype=dtype):
        """Converts a NumPy array to a contiguous torch tensor of the given type."""
        return torch.from_numpy(X).type(dtype).contiguous()


    def add_flips(X):
        """Stacks each fiber with its mirror flip (points in reversed order)."""
        X_flip = torch.flip(X, (1,))
        X = torch.stack((X, X_flip), dim=1)  # (Nfibers, 2, NPOINTS, 3)
        return X

.. GENERATED FROM PYTHON SOURCE LINES 83-87

Source atlas
~~~~~~~~~~~~~~~~~~~

Load the atlas (segmented, each fiber has a label):

.. GENERATED FROM PYTHON SOURCE LINES 87-91

.. code-block:: Python

    Y_j = torch_load(np.load("data/tracto_atlas.npy"))
    labels_j = torch_load(np.load("data/atlas_labels.npy"), dtype=dtypeint)

.. GENERATED FROM PYTHON SOURCE LINES 93-96

.. code-block:: Python

    M, NPOINTS = Y_j.shape[0], Y_j.shape[1]  # Number of fibers, points per fiber

.. GENERATED FROM PYTHON SOURCE LINES 98-101

.. code-block:: Python

    Y_j = Y_j.view(M, NPOINTS, 3) / np.sqrt(NPOINTS)

.. GENERATED FROM PYTHON SOURCE LINES 103-105

.. code-block:: Python

    Y_j = add_flips(Y_j)  # Shape (M, 2, NPOINTS, 3)

.. GENERATED FROM PYTHON SOURCE LINES 106-111

Target subject
~~~~~~~~~~~~~~~~~~~~

Load a new (unlabelled) subject:

.. GENERATED FROM PYTHON SOURCE LINES 111-125

..
.. code-block:: Python

    X_i = torch_load(np.load("data/tracto1.npy"))
    N, NPOINTS_i = X_i.shape[0], X_i.shape[1]  # Number of fibers, points per fiber

    if NPOINTS != NPOINTS_i:
        raise ValueError(
            "The atlas and the subject are not sampled with the same number of points: "
            + f"{NPOINTS} and {NPOINTS_i}, respectively."
        )

    X_i = X_i.view(N, NPOINTS, 3) / np.sqrt(NPOINTS)
    X_i = add_flips(X_i)  # Shape (N, 2, NPOINTS, 3)

.. GENERATED FROM PYTHON SOURCE LINES 126-131

Feature engineering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Add some weight on both ends of our fibers:

.. GENERATED FROM PYTHON SOURCE LINES 131-138

.. code-block:: Python

    gamma = 2.0
    X_i[:, :, 0, :] *= gamma
    X_i[:, :, -1, :] *= gamma
    Y_j[:, :, 0, :] *= gamma
    Y_j[:, :, -1, :] *= gamma

.. GENERATED FROM PYTHON SOURCE LINES 139-144

Optimizing performance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Contiguous memory accesses are critical for performance on the GPU.

.. GENERATED FROM PYTHON SOURCE LINES 144-161

.. code-block:: Python

    from pykeops.torch.cluster import sort_clusters, cluster_ranges

    ranges_j = cluster_ranges(labels_j)  # Ranges for all clusters
    Y_j, labels_j = sort_clusters(
        Y_j, labels_j
    )  # Make sure that all clusters are contiguous in memory

    C = len(ranges_j)  # Number of classes

    if C != labels_j.max() + 1:
        raise ValueError(
            f"Found {C} cluster ranges, but the labels go up to {int(labels_j.max())}."
        )

    for j, (start_j, end_j) in enumerate(ranges_j):
        if start_j >= end_j:
            raise ValueError(f"The {j}-th cluster of the atlas seems to be empty.")

.. GENERATED FROM PYTHON SOURCE LINES 162-167

Each fiber is sampled with 20 points in :math:`\mathbb{R}^3`.
Thus, one tractogram is a matrix of size n x 60, where n is the number of fibers.
The atlas is labelled, which means that each fiber belongs to a cluster.
This is summarized by the vector ``labels_j`` of size n x 1:
``labels_j[i]`` is the label of fiber i.
Subsample the data (by a factor of 20 here) to reduce the computational time:

.. GENERATED FROM PYTHON SOURCE LINES 167-171

.. code-block:: Python

    subsample = 20 if True else 1

..
.. GENERATED FROM PYTHON SOURCE LINES 173-181

.. code-block:: Python

    to_keep = []
    for start_j, end_j in ranges_j:
        to_keep += list(range(start_j, end_j, subsample))

    Y_j, labels_j = Y_j[to_keep].contiguous(), labels_j[to_keep].contiguous()
    ranges_j = cluster_ranges(labels_j)  # Keep the ranges up to date!

.. GENERATED FROM PYTHON SOURCE LINES 183-187

.. code-block:: Python

    X_i = X_i[::subsample].contiguous()

.. GENERATED FROM PYTHON SOURCE LINES 189-195

.. code-block:: Python

    N, M = len(X_i), len(Y_j)

    print("Data loaded.")

.. GENERATED FROM PYTHON SOURCE LINES 196-199

Pre-computing cluster prototypes
--------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 199-213

.. code-block:: Python

    from pykeops.torch import LazyTensor


    def nn_search(x_i, y_j, ranges=None):
        """Returns, for each row of x_i, the index of its nearest neighbor in y_j."""
        x_i = LazyTensor(x_i[:, None, :])  # Broadcasted "line" variable
        y_j = LazyTensor(y_j[None, :, :])  # Broadcasted "column" variable

        D_ij = ((x_i - y_j) ** 2).sum(-1)  # Symbolic matrix of squared distances
        D_ij.ranges = ranges  # Apply our block-sparsity pattern

        return D_ij.argmin(dim=1).view(-1)

.. GENERATED FROM PYTHON SOURCE LINES 214-216

K-Means loop:

.. GENERATED FROM PYTHON SOURCE LINES 216-232

.. code-block:: Python

    def KMeans(x_i, c_j, Nits=10, ranges=None):
        D = x_i.shape[1]
        for i in range(Nits):  # Use the Nits argument instead of a hard-coded 10
            # Points -> Nearest cluster
            labs_i = nn_search(x_i, c_j, ranges=ranges)
            # Class cardinals:
            Ncl = torch.bincount(labs_i.view(-1)).type(dtype)
            # Compute the cluster centroids with torch.bincount:
            for d in range(D):  # Unfortunately, vector weights are not supported...
                c_j[:, d] = torch.bincount(labs_i.view(-1), weights=x_i[:, d]) / Ncl

        return c_j, labs_i

.. GENERATED FROM PYTHON SOURCE LINES 233-239

On the subject
~~~~~~~~~~~~~~~~~~~~~~~~

For the new (unlabelled) subject, we run a simple K-means
in :math:`\mathbb{R}^{60}` to obtain a clustering of the data.

.. GENERATED FROM PYTHON SOURCE LINES 239-258

..
.. code-block:: Python

    K = 1000

    # Pick K fibers at random:
    perm = torch.randperm(N)
    random_labels = perm[:K]
    C_i = X_i[random_labels]  # (K, 2, NPOINTS, 3)

    # Reshape our data as "N-by-60" tensors:
    C_i_flat = C_i.view(K * 2, NPOINTS * 3)  # Flattened list of centroids
    X_i_flat = X_i.view(N * 2, NPOINTS * 3)  # Flattened list of fibers

    # Retrieve our new centroids:
    C_i_flat, labs_i = KMeans(X_i_flat, C_i_flat)
    C_i = C_i_flat.view(K, 2, NPOINTS, 3)
    # Standard deviation of our clusters:
    std_i = ((X_i_flat - C_i_flat[labs_i.view(-1), :]) ** 2).sum(dim=1).mean().sqrt()

.. GENERATED FROM PYTHON SOURCE LINES 259-268

On the atlas
~~~~~~~~~~~~~~~~~~~~~~~

To use the multiscale version of regularized OT, we need a clustering
of both inputs (the atlas and the new subject).
For the atlas, the clusters are given by the segmentation.
We then use a K-means to separate the fibers from their flips within each cluster,
in order to obtain clusters whose fibers have a similar orientation.

.. GENERATED FROM PYTHON SOURCE LINES 269-278

.. code-block:: Python

    ranges_yi = 2 * ranges_j
    ranges_cj = 2 * torch.arange(C).type_as(ranges_j)
    ranges_cj = torch.stack((ranges_cj, ranges_cj + 2)).t().contiguous()

    slices_i = 1 + torch.arange(C).type_as(ranges_j)
    ranges_yi_cj = (ranges_yi, slices_i, ranges_cj, ranges_cj, slices_i, ranges_yi)

.. GENERATED FROM PYTHON SOURCE LINES 279-280

Pick one unoriented fiber (i.e. two oriented fibers) per class:

.. GENERATED FROM PYTHON SOURCE LINES 280-286

.. code-block:: Python

    first_labels = ranges_j[:, 0]  # One label per class
    C_j = Y_j[first_labels.type(dtypeint), :, :, :]  # (C, 2, NPOINTS, 3)
    C_j_flat = C_j.view(C * 2, NPOINTS * 3)  # Flattened list of centroids

.. GENERATED FROM PYTHON SOURCE LINES 288-297

.. code-block:: Python

    Y_j_flat = Y_j.view(M * 2, NPOINTS * 3)
    C_j_flat, labs_j = KMeans(Y_j_flat, C_j_flat, ranges=ranges_yi_cj)
    C_j = C_j_flat.view(C, 2, NPOINTS, 3)

    std_j = ((Y_j_flat - C_j_flat[labs_j.view(-1), :]) ** 2).sum(dim=1).mean().sqrt()

..
.. GENERATED FROM PYTHON SOURCE LINES 298-309

Compute the OT plan with the multiscale algorithm
------------------------------------------------------

To use the **multiscale** Sinkhorn algorithm,
we should simply provide:

- explicit **labels** and **weights** for both input measures,
- a typical **cluster_scale** which specifies the iteration at which
  the Sinkhorn loop jumps from a **coarse** to a **fine** representation
  of the data.

.. GENERATED FROM PYTHON SOURCE LINES 309-322

.. code-block:: Python

    blur = 3.0
    OT_solver = SamplesLoss(
        "sinkhorn",
        p=2,
        blur=blur,
        reach=20,
        scaling=0.9,
        cluster_scale=max(std_i, std_j),
        debias=False,
        potentials=True,
        verbose=True,
    )

.. GENERATED FROM PYTHON SOURCE LINES 323-325

To specify explicit cluster labels, SamplesLoss also requires
explicit weights. Let's go with the default option, a uniform distribution:

.. GENERATED FROM PYTHON SOURCE LINES 325-344

.. code-block:: Python

    a_i = torch.ones(2 * N).type(dtype) / (2 * N)
    b_j = torch.ones(2 * M).type(dtype) / (2 * M)

    start = time.time()

    # Compute the dual vectors F_i and G_j:
    # 6 args -> labels_i, weights_i, locations_i, labels_j, weights_j, locations_j
    F_i, G_j = OT_solver(
        labs_i, a_i, X_i.view(N * 2, NPOINTS * 3), labs_j, b_j, Y_j.view(M * 2, NPOINTS * 3)
    )
    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()

    print("OT computed in {:.3f}s.".format(end - start))

.. GENERATED FROM PYTHON SOURCE LINES 345-351

Use the OT to perform the transfer of labels
----------------------------------------------

The transport plan :math:`\pi_{i,j}` gives the probability for
a fiber :math:`i` of the subject to be assigned to the (labelled) fiber :math:`j` of the atlas.
We assign to fiber :math:`i` the label :math:`l` that receives the largest
total probability among the soft assignments of :math:`i`.

.. GENERATED FROM PYTHON SOURCE LINES 351-357

.. code-block:: Python

    # Return to the original data (unflipped)
    X_i = X_i[:, 0, :, :].contiguous()  # (N, NPOINTS, 3)
    F_i = F_i[::2].contiguous()  # (N,)

..
.. GENERATED FROM PYTHON SOURCE LINES 358-360

Compute the transport plan:

.. GENERATED FROM PYTHON SOURCE LINES 360-373

.. code-block:: Python

    XX_i = LazyTensor(X_i.view(N, 1, NPOINTS * 3))
    YY_j = LazyTensor(Y_j.view(1, M * 2, NPOINTS * 3))

    FF_i = LazyTensor(F_i.view(N, 1, 1))
    GG_j = LazyTensor(G_j.view(1, M * 2, 1))

    # Cost matrix:
    CC_ij = ((XX_i - YY_j) ** 2).sum(-1) / 2  # (N, M * 2, 1) LazyTensor

    # Scaled kernel matrix:
    KK_ij = ((FF_i + GG_j - CC_ij) / blur**2).exp()  # (N, M * 2, 1) LazyTensor

.. GENERATED FROM PYTHON SOURCE LINES 374-376

Transfer the labels, bypassing the one-hot vector encoding
for the sake of efficiency:

.. GENERATED FROM PYTHON SOURCE LINES 376-402

.. code-block:: Python

    def slicing_ranges(start, end):
        """KeOps does not yet support sliced indexing of LazyTensors,
        so we have to resort to some black magic..."""
        ranges_i = (
            torch.Tensor([[0, N]]).type(dtypeint).int()
        )  # Int32, on the correct device
        slices_i = torch.Tensor([1]).type(dtypeint).int()
        redranges_j = torch.Tensor([[start, end]]).type(dtypeint).int()
        return (ranges_i, slices_i, redranges_j, redranges_j, slices_i, ranges_i)


    weights_i = torch.zeros(C + 1, N).type(torch.FloatTensor)  # C classes + outliers

    for c in range(C):
        start, end = 2 * ranges_j[c]
        KK_ij.ranges = slicing_ranges(
            start, end
        )  # equivalent to "KK_ij[:, start:end]", which is not supported yet...
        weights_i[c] = (KK_ij.sum(dim=1).view(N) / (2 * M)).cpu()

    weights_i[C] = 0.2  # If no label has a weight bigger than 0.2, this fiber is an outlier

    labels_i = weights_i.argmax(dim=0)  # (N,) vector

.. GENERATED FROM PYTHON SOURCE LINES 403-404

Save our new cluster information as a signal:

.. GENERATED FROM PYTHON SOURCE LINES 404-412

.. code-block:: Python

    # Come back to the original data: undo the weighting of the fiber ends
    X_i[:, 0, :] /= gamma
    X_i[:, -1, :] /= gamma

    save_tracts_labels_separate(
        "output/labels_subject", X_i, labels_i, 0, labels_i.max() + 1
    )  # save the data

.. _sphx_glr_download__auto_examples_brain_tractograms_transfer_labels.py:

.. only:: html

  ..
  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: transfer_labels.ipynb <transfer_labels.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: transfer_labels.py <transfer_labels.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
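As a minimal, CPU-only sketch of the label-transfer rule described in "Use the OT to perform the transfer of labels", the snippet below applies the same steps (scaled kernel, per-class sums, outlier threshold, argmax) to a tiny hand-made problem in plain Python. The function name ``transfer_labels_dense``, the toy potentials, the toy costs and the class ranges are hypothetical and chosen for illustration only; the dense loops merely stand in for the KeOps ``LazyTensor`` reductions used in the tutorial and are not meant for real tractograms.

```python
import math


def transfer_labels_dense(F, G, cost, class_ranges, blur, outlier_weight):
    """Toy, dense version of the label-transfer rule (illustration only).

    F, G           -- dual potentials for the N subject and M atlas fibers
    cost           -- N-by-M nested list, cost[i][j] = |x_i - y_j|^2 / 2
    class_ranges   -- list of (start, end) column ranges, one per atlas class
    blur           -- Sinkhorn blur; the temperature is blur**2, as in the tutorial
    outlier_weight -- score given to the extra "outlier" class
    """
    M = len(G)
    labels = []
    for i in range(len(F)):
        # Score of a class: sum of the scaled kernel over its columns,
        # multiplied by the uniform atlas weight b_j = 1 / M.
        scores = [
            sum(math.exp((F[i] + G[j] - cost[i][j]) / blur**2) for j in range(start, end))
            / M
            for (start, end) in class_ranges
        ]
        scores.append(outlier_weight)  # The last index encodes the outlier class
        labels.append(max(range(len(scores)), key=scores.__getitem__))
    return labels


# Hypothetical toy problem: 3 subject fibers, 4 atlas fibers split into 2 classes.
F = [0.0, 0.0, 0.0]
G = [0.0, 0.0, 0.0, 0.0]
cost = [
    [0.1, 0.2, 9.0, 9.0],  # Close to the fibers of class 0
    [9.0, 9.0, 0.1, 0.3],  # Close to the fibers of class 1
    [9.0, 9.0, 9.0, 9.0],  # Far from every atlas fiber
]
class_ranges = [(0, 2), (2, 4)]

labels = transfer_labels_dense(F, G, cost, class_ranges, blur=1.0, outlier_weight=0.2)
print(labels)  # prints [0, 1, 2]
```

In this toy run, the third fiber is far from every atlas fiber, so all of its class scores stay below the outlier weight and it receives the extra label ``2``, which plays the role of index ``C`` in the tutorial.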