.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/transfer_labels.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_transfer_labels.py:


Transferring labels from a segmented atlas
=============================================

We use a new multiscale algorithm for solving regularized Optimal Transport
problems on the GPU, with a linear memory footprint.
We use the resulting smooth assignments to perform label transfer for
atlas-based segmentation of fiber tractograms.
The two parameters of our method, ``blur`` and ``reach``, are meaningful:
they define the minimum and maximum distances at which two fibers are
compared with each other, and can be set according to anatomical knowledge.

.. GENERATED FROM PYTHON SOURCE LINES 16-20

Setup
---------------------

Standard imports:

.. GENERATED FROM PYTHON SOURCE LINES 20-31

.. code-block:: default


    import numpy as np
    import matplotlib.pyplot as plt
    import time
    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    dtypeint = torch.cuda.LongTensor if use_cuda else torch.LongTensor


.. GENERATED FROM PYTHON SOURCE LINES 32-34

Routines for loading and saving data:

.. GENERATED FROM PYTHON SOURCE LINES 34-39

.. code-block:: default


    from tract_io import read_vtk, streamlines_resample, save_vtk, save_vtk_labels
    from tract_io import save_tract, save_tract_numpy
    from tract_io import save_tract_with_labels, save_tracts_labels_separate


.. GENERATED FROM PYTHON SOURCE LINES 40-44

Dataset
---------------------

Fetch data from the KeOps website:

.. GENERATED FROM PYTHON SOURCE LINES 44-64
.. code-block:: default


    import os


    def fetch_file(name):
        if not os.path.exists(f"data/{name}.npy"):
            import urllib.request

            # Make sure that the target folder exists before downloading
            os.makedirs("data", exist_ok=True)
            print(f"Fetching {name}.npy... ", end="", flush=True)
            urllib.request.urlretrieve(
                f"https://www.kernel-operations.io/data/{name}.npy", f"data/{name}.npy"
            )
            print("Done.")


    fetch_file("tracto_atlas")
    fetch_file("atlas_labels")
    fetch_file("tracto1")


.. GENERATED FROM PYTHON SOURCE LINES 65-68

Fibers do not have a canonical orientation. Since our ground distance is a
simple L2-distance on the sampled fibers, we augment the dataset with the
mirror flip of every fiber and perform the OT on this augmented dataset.

.. GENERATED FROM PYTHON SOURCE LINES 68-82

.. code-block:: default



    def torch_load(X, dtype=dtype):
        """Converts a NumPy array to a contiguous torch tensor of the given type."""
        return torch.from_numpy(X).type(dtype).contiguous()


    def add_flips(X):
        """Stacks each fiber with its flipped copy (points in reverse order)."""
        # X = X[:,None,:,:]
        X_flip = torch.flip(X, (1,))
        X = torch.stack((X, X_flip), dim=1)  # (Nfibers, 2, NPOINTS, 3)
        return X


.. GENERATED FROM PYTHON SOURCE LINES 83-87

Source atlas
~~~~~~~~~~~~~~~~~~~

Load the atlas (segmented, each fiber has a label):

.. GENERATED FROM PYTHON SOURCE LINES 87-91

.. code-block:: default


    Y_j = torch_load(np.load("data/tracto_atlas.npy"))
    labels_j = torch_load(np.load("data/atlas_labels.npy"), dtype=dtypeint)


.. GENERATED FROM PYTHON SOURCE LINES 93-96

.. code-block:: default


    M, NPOINTS = Y_j.shape[0], Y_j.shape[1]  # Number of fibers, points per fiber


.. GENERATED FROM PYTHON SOURCE LINES 98-101

.. code-block:: default


    Y_j = Y_j.view(M, NPOINTS, 3) / np.sqrt(NPOINTS)


.. GENERATED FROM PYTHON SOURCE LINES 103-105

.. code-block:: default


    Y_j = add_flips(Y_j)  # Shape (M, 2, NPOINTS, 3)


.. GENERATED FROM PYTHON SOURCE LINES 106-111

Target subject
~~~~~~~~~~~~~~~~~~~~

Load a new subject (unlabelled):

.. GENERATED FROM PYTHON SOURCE LINES 111-125
.. code-block:: default



    X_i = torch_load(np.load("data/tracto1.npy"))
    N, NPOINTS_i = X_i.shape[0], X_i.shape[1]  # Number of fibers, points per fiber

    if NPOINTS != NPOINTS_i:
        raise ValueError(
            "The atlas and the subject are not sampled with the same number of points: "
            + f"{NPOINTS} and {NPOINTS_i}, respectively."
        )

    X_i = X_i.view(N, NPOINTS, 3) / np.sqrt(NPOINTS)
    X_i = add_flips(X_i)  # Shape (N, 2, NPOINTS, 3)


.. GENERATED FROM PYTHON SOURCE LINES 126-131

Feature engineering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Add some weight on both ends of our fibers:

.. GENERATED FROM PYTHON SOURCE LINES 131-138

.. code-block:: default


    gamma = 2.0
    X_i[:, :, 0, :] *= gamma
    X_i[:, :, -1, :] *= gamma
    Y_j[:, :, 0, :] *= gamma
    Y_j[:, :, -1, :] *= gamma


.. GENERATED FROM PYTHON SOURCE LINES 139-144

Optimizing performance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Contiguous memory accesses are critical for performance on the GPU.

.. GENERATED FROM PYTHON SOURCE LINES 144-161

.. code-block:: default


    from pykeops.torch.cluster import sort_clusters, cluster_ranges

    ranges_j = cluster_ranges(labels_j)  # Ranges for all clusters
    Y_j, labels_j = sort_clusters(
        Y_j, labels_j
    )  # Make sure that all clusters are contiguous in memory

    C = len(ranges_j)  # Number of classes

    if C != labels_j.max() + 1:
        raise ValueError("The atlas labels should be the consecutive integers 0, ..., C-1.")

    for j, (start_j, end_j) in enumerate(ranges_j):
        if start_j >= end_j:
            raise ValueError(f"The {j}-th cluster of the atlas seems to be empty.")


.. GENERATED FROM PYTHON SOURCE LINES 162-167

Each fiber is sampled with 20 points in R^3, so a tractogram with n fibers
is a matrix of size n x 60.
The atlas is labelled, which means that each fiber belongs to a cluster.
This is summarized by the vector ``labels_j`` of size n x 1, where
``labels_j[i]`` is the label of fiber i.
Subsample the data if you want to reduce the computational time
(the code below uses a factor of 20):

.. GENERATED FROM PYTHON SOURCE LINES 167-171

.. code-block:: default


    subsample = 20 if True else 1

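To make the notion of contiguous cluster ranges concrete, here is a small
plain-Python sketch. It is not the pykeops implementation:
``toy_cluster_ranges`` and ``stride`` are hypothetical names introduced for
illustration only. The helper mimics ``[start, end)`` ranges on a toy list of
labels, and a strided selection inside each range then keeps at least one
element per class:

.. code-block:: python

    from collections import Counter
    from itertools import accumulate


    def toy_cluster_ranges(labels, n_classes):
        """[start, end) index range of each class in the sorted list of labels.

        Hypothetical helper, written for illustration only.
        """
        counts = Counter(labels)
        ends = list(accumulate(counts[c] for c in range(n_classes)))
        starts = [0] + ends[:-1]
        return list(zip(starts, ends))


    labels = [2, 0, 1, 0, 2, 2, 0, 2, 1]
    ranges = toy_cluster_ranges(labels, n_classes=3)

    # Sort the labels so that every class is contiguous in memory
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    sorted_labels = [labels[i] for i in order]

    # Strided subsampling inside each range: a non-empty range always
    # contributes its first element, so no class disappears.
    stride = 2
    to_keep = []
    for start, end in ranges:
        to_keep += list(range(start, end, stride))
    kept_labels = [sorted_labels[i] for i in to_keep]

Subsampling range by range, rather than with a single global stride, is what
guarantees that small classes survive the reduction.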
.. GENERATED FROM PYTHON SOURCE LINES 173-181

.. code-block:: default


    to_keep = []
    for start_j, end_j in ranges_j:
        to_keep += list(range(start_j, end_j, subsample))

    Y_j, labels_j = Y_j[to_keep].contiguous(), labels_j[to_keep].contiguous()
    ranges_j = cluster_ranges(labels_j)  # Keep the ranges up to date!


.. GENERATED FROM PYTHON SOURCE LINES 183-187

.. code-block:: default


    X_i = X_i[::subsample].contiguous()


.. GENERATED FROM PYTHON SOURCE LINES 189-195

.. code-block:: default


    N, M = len(X_i), len(Y_j)

    print("Data loaded.")


.. GENERATED FROM PYTHON SOURCE LINES 196-199

Pre-computing cluster prototypes
--------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 199-213

.. code-block:: default


    from pykeops.torch import LazyTensor


    def nn_search(x_i, y_j, ranges=None):
        x_i = LazyTensor(x_i[:, None, :])  # Broadcasted "line" variable
        y_j = LazyTensor(y_j[None, :, :])  # Broadcasted "column" variable

        D_ij = ((x_i - y_j) ** 2).sum(-1)  # Symbolic matrix of squared distances
        D_ij.ranges = ranges  # Apply our block-sparsity pattern

        return D_ij.argmin(dim=1).view(-1)


.. GENERATED FROM PYTHON SOURCE LINES 214-216

K-Means loop:

.. GENERATED FROM PYTHON SOURCE LINES 216-233

.. code-block:: default



    def KMeans(x_i, c_j, Nits=10, ranges=None):
        D = x_i.shape[1]
        for i in range(Nits):
            # Points -> Nearest cluster
            labs_i = nn_search(x_i, c_j, ranges=ranges)
            # Class cardinals:
            Ncl = torch.bincount(labs_i.view(-1)).type(dtype)
            # Compute the cluster centroids with torch.bincount:
            for d in range(D):  # Unfortunately, vector weights are not supported...
                c_j[:, d] = torch.bincount(labs_i.view(-1), weights=x_i[:, d]) / Ncl

        return c_j, labs_i


.. GENERATED FROM PYTHON SOURCE LINES 234-240

On the subject
~~~~~~~~~~~~~~~~~~~~~~~~

For the new (unlabelled) subject, we run a simple K-means in R^60 to obtain
a clustering of the data.

.. GENERATED FROM PYTHON SOURCE LINES 240-259
.. code-block:: default


    K = 1000

    # Pick K fibers at random:
    perm = torch.randperm(N)
    random_labels = perm[:K]
    C_i = X_i[random_labels]  # (K, 2, NPOINTS, 3)

    # Reshape our data as "N-by-60" tensors:
    C_i_flat = C_i.view(K * 2, NPOINTS * 3)  # Flattened list of centroids
    X_i_flat = X_i.view(N * 2, NPOINTS * 3)  # Flattened list of fibers

    # Retrieve our new centroids:
    C_i_flat, labs_i = KMeans(X_i_flat, C_i_flat)
    C_i = C_i_flat.view(K, 2, NPOINTS, 3)
    # Standard deviation of our clusters:
    std_i = ((X_i_flat - C_i_flat[labs_i.view(-1), :]) ** 2).sum(dim=1).mean().sqrt()


.. GENERATED FROM PYTHON SOURCE LINES 260-269

On the atlas
~~~~~~~~~~~~~~~~~~~~~~~

To use the multiscale version of the regularized OT, we need a clustering of
our input data (atlas and new subject).
For the atlas, the clustering is given by the segmentation. We use a K-means
to separate the fibers from their flips within each cluster, so that every
cluster only contains fibers with a similar orientation.

.. GENERATED FROM PYTHON SOURCE LINES 270-279

.. code-block:: default


    ranges_yi = 2 * ranges_j
    ranges_cj = 2 * torch.arange(C).type_as(ranges_j)
    ranges_cj = torch.stack((ranges_cj, ranges_cj + 2)).t().contiguous()

    slices_i = 1 + torch.arange(C).type_as(ranges_j)
    ranges_yi_cj = (ranges_yi, slices_i, ranges_cj, ranges_cj, slices_i, ranges_yi)


.. GENERATED FROM PYTHON SOURCE LINES 280-281

Pick one unoriented (i.e. two oriented) fiber per class:

.. GENERATED FROM PYTHON SOURCE LINES 281-287

.. code-block:: default


    first_labels = ranges_j[:, 0]  # One label per class
    C_j = Y_j[first_labels.type(dtypeint), :, :, :]  # (C, 2, NPOINTS, 3)
    C_j_flat = C_j.view(C * 2, NPOINTS * 3)  # Flattened list of centroids


.. GENERATED FROM PYTHON SOURCE LINES 289-298

.. code-block:: default


    Y_j_flat = Y_j.view(M * 2, NPOINTS * 3)
    C_j_flat, labs_j = KMeans(Y_j_flat, C_j_flat, ranges=ranges_yi_cj)
    C_j = C_j_flat.view(C, 2, NPOINTS, 3)

    std_j = ((Y_j_flat - C_j_flat[labs_j.view(-1), :]) ** 2).sum(dim=1).mean().sqrt()

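The effect of running K-means on flip-augmented data can be illustrated on a
toy problem. The sketch below is a plain-Python analogue written for this
page, not the KeOps routine used above: ``sq_dist`` and ``lloyd`` are
hypothetical helpers, and the "fibers" are made of three 1-D points, so
reversing the list of coordinates is the same as reversing the order of the
points. Initialising two centroids with the first fiber and its flip sends
every fiber and its mirror image to different clusters:

.. code-block:: python

    def sq_dist(u, v):
        """Squared Euclidean distance between two flat coordinate lists."""
        return sum((a - b) ** 2 for a, b in zip(u, v))


    def lloyd(points, centroids, n_iter=10):
        """Plain Lloyd iterations (hypothetical toy analogue of a K-means loop)."""
        centroids = [list(c) for c in centroids]
        labels = [0] * len(points)
        for _ in range(n_iter):
            # Assignment step: index of the nearest centroid for every point
            labels = [
                min(range(len(centroids)), key=lambda k: sq_dist(p, centroids[k]))
                for p in points
            ]
            # Update step: each centroid becomes the mean of its members
            for k in range(len(centroids)):
                members = [p for p, lab in zip(points, labels) if lab == k]
                if members:  # keep the previous centroid if a cluster is empty
                    centroids[k] = [sum(col) / len(members) for col in zip(*members)]
        return centroids, labels


    # Three toy fibers; the second one is stored in the opposite orientation.
    fibers = [[0.0, 1.0, 2.0], [2.1, 1.0, 0.1], [0.2, 1.1, 2.2]]

    # Interleave every fiber with its flip: (fiber, flip, fiber, flip, ...)
    augmented = []
    for f in fibers:
        augmented += [f, f[::-1]]

    # Initialise with the first fiber and its flip
    centroids, labels = lloyd(augmented, augmented[:2])

Each cluster then gathers one consistently oriented copy of every fiber,
which is the property that the block-sparse K-means on the atlas is meant to
provide within each class.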
.. GENERATED FROM PYTHON SOURCE LINES 299-310

Compute the OT plan with the multiscale algorithm
------------------------------------------------------

To use the **multiscale** Sinkhorn algorithm, we should simply provide:

- explicit **labels** and **weights** for both input measures,
- a typical **cluster_scale** which specifies the iteration at which
  the Sinkhorn loop jumps from a **coarse** to a **fine** representation
  of the data.

.. GENERATED FROM PYTHON SOURCE LINES 310-323

.. code-block:: default


    blur = 3.0
    OT_solver = SamplesLoss(
        "sinkhorn",
        p=2,
        blur=blur,
        reach=20,
        scaling=0.9,
        cluster_scale=max(std_i, std_j),
        debias=False,
        potentials=True,
        verbose=True,
    )


.. GENERATED FROM PYTHON SOURCE LINES 324-326

To specify explicit cluster labels, SamplesLoss also requires
explicit weights. Let's go with the default option, a uniform distribution:

.. GENERATED FROM PYTHON SOURCE LINES 326-345

.. code-block:: default


    a_i = torch.ones(2 * N).type(dtype) / (2 * N)
    b_j = torch.ones(2 * M).type(dtype) / (2 * M)

    start = time.time()

    # Compute the dual vectors F_i and G_j:
    # 6 args -> labels_i, weights_i, locations_i, labels_j, weights_j, locations_j
    F_i, G_j = OT_solver(
        labs_i, a_i, X_i.view(N * 2, NPOINTS * 3), labs_j, b_j, Y_j.view(M * 2, NPOINTS * 3)
    )

    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()

    print("OT computed in {:.3f}s.".format(end - start))


.. GENERATED FROM PYTHON SOURCE LINES 346-352

Use the OT to perform the transfer of labels
----------------------------------------------

The transport plan pi_{i,j} gives the probability for a fiber i of the
subject to be assigned to the (labelled) fiber j of the atlas.
We assign to fiber i the label l that receives the largest total probability
among the soft assignments of i.

.. GENERATED FROM PYTHON SOURCE LINES 352-358

.. code-block:: default


    # Return to the original data (unflipped)
    X_i = X_i[:, 0, :, :].contiguous()  # (N, NPOINTS, 3)
    F_i = F_i[::2].contiguous()  # (N,)

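As a reading aid, the soft assignments can be written in terms of the dual
potentials. The following is a sketch under the usual entropic OT convention,
assuming that the temperature is :math:`\varepsilon = \text{blur}^p` with
:math:`p = 2` and that the ground cost is
:math:`C_{i,j} = \tfrac{1}{2} \|x_i - y_j\|^2`; the symbols
:math:`\varepsilon` and :math:`w_i(c)` are notation introduced here:

.. math::

    \pi_{i,j} = a_i \, b_j \,
    \exp\!\Big( \frac{F_i + G_j - C_{i,j}}{\varepsilon} \Big),
    \qquad
    w_i(c) = \sum_{j \,:\, \text{label}(j) = c} b_j \,
    \exp\!\Big( \frac{F_i + G_j - C_{i,j}}{\varepsilon} \Big)
    = \frac{1}{a_i} \sum_{j \,:\, \text{label}(j) = c} \pi_{i,j}.

With uniform weights :math:`b_j = 1 / (2M)`, the quantity :math:`w_i(c)` is
the mass of fiber i sent to class c, relative to its own weight :math:`a_i`.
Comparing the :math:`w_i(c)` across classes therefore amounts to comparing
blocks of the i-th row of the plan.
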
.. GENERATED FROM PYTHON SOURCE LINES 359-361

Compute the transport plan:

.. GENERATED FROM PYTHON SOURCE LINES 361-374

.. code-block:: default


    XX_i = LazyTensor(X_i.view(N, 1, NPOINTS * 3))
    YY_j = LazyTensor(Y_j.view(1, M * 2, NPOINTS * 3))
    FF_i = LazyTensor(F_i.view(N, 1, 1))
    GG_j = LazyTensor(G_j.view(1, M * 2, 1))

    # Cost matrix:
    CC_ij = ((XX_i - YY_j) ** 2).sum(-1) / 2  # (N, M * 2, 1) LazyTensor

    # Scaled kernel matrix:
    KK_ij = ((FF_i + GG_j - CC_ij) / blur ** 2).exp()  # (N, M * 2, 1) LazyTensor


.. GENERATED FROM PYTHON SOURCE LINES 375-377

Transfer the labels, bypassing the one-hot vector encoding
for the sake of efficiency:

.. GENERATED FROM PYTHON SOURCE LINES 377-403

.. code-block:: default



    def slicing_ranges(start, end):
        """KeOps does not yet support sliced indexing of LazyTensors, so we have to resort to some black magic..."""
        ranges_i = (
            torch.Tensor([[0, N]]).type(dtypeint).int()
        )  # Int32, on the correct device
        slices_i = torch.Tensor([1]).type(dtypeint).int()
        redranges_j = torch.Tensor([[start, end]]).type(dtypeint).int()
        return (ranges_i, slices_i, redranges_j, redranges_j, slices_i, ranges_i)


    weights_i = torch.zeros(C + 1, N).type(torch.FloatTensor)  # C classes + outliers

    for c in range(C):
        start, end = 2 * ranges_j[c]
        KK_ij.ranges = slicing_ranges(
            start, end
        )  # equivalent to "KK_ij[:, start:end]", which is not supported yet...
        weights_i[c] = (KK_ij.sum(dim=1).view(N) / (2 * M)).cpu()

    weights_i[C] = 0.2  # If no label has a bigger weight than 0.2, this fiber is an outlier

    labels_i = weights_i.argmax(dim=0)  # (N,) vector


.. GENERATED FROM PYTHON SOURCE LINES 404-405

Save our new cluster information as a signal:

.. GENERATED FROM PYTHON SOURCE LINES 405-413

.. code-block:: default


    # Come back to the original data
    X_i[:, 0, :] /= gamma
    X_i[:, -1, :] /= gamma

    save_tracts_labels_separate(
        "output/labels_subject", X_i, labels_i, 0, labels_i.max() + 1
    )  # save the data

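The decision rule used for the label transfer can be illustrated on a toy
problem. The sketch below is plain Python and does not use KeOps: the matrix
``soft``, the helper ``toy_transfer_labels`` and the variable names are
hypothetical, introduced for illustration. ``soft`` stands in for the scaled
kernel matrix weighted by the atlas weights, and the extra class with index
``n_classes`` plays the role of the outlier class:

.. code-block:: python

    # Toy soft assignments: 3 subject fibers (rows) x 4 atlas fibers (columns).
    soft = [
        [0.30, 0.25, 0.02, 0.01],  # mostly matched with atlas fibers 0 and 1
        [0.01, 0.02, 0.20, 0.40],  # mostly matched with atlas fibers 2 and 3
        [0.03, 0.02, 0.04, 0.01],  # weakly matched with every atlas fiber
    ]
    atlas_labels = [0, 0, 1, 1]  # class of each atlas fiber
    n_classes = 2
    outlier_weight = 0.2  # same threshold value as in the script above


    def toy_transfer_labels(soft, atlas_labels, n_classes, outlier_weight):
        """Assigns to each row the class that collects the largest total weight."""
        labels = []
        for row in soft:
            # Sum the soft assignments class by class (no one-hot encoding)
            weights = [0.0] * n_classes
            for mass, lab in zip(row, atlas_labels):
                weights[lab] += mass
            # Extra "outlier" class: wins if no real class exceeds the threshold
            weights.append(outlier_weight)
            labels.append(max(range(n_classes + 1), key=lambda c: weights[c]))
        return labels


    toy_labels = toy_transfer_labels(soft, atlas_labels, n_classes, outlier_weight)

The first two fibers inherit the class that dominates their row, while the
third one, whose total weight stays below the threshold for every class,
falls into the outlier class.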
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download__auto_examples_brain_tractograms_transfer_labels.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: transfer_labels.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: transfer_labels.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_