
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/track_barycenter.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_brain_tractograms_track_barycenter.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_track_barycenter.py:


Create an atlas using Wasserstein barycenters
==================================================

In this tutorial, we compute the barycenter of a dataset of probability tracks.
The barycenter is computed as the Fréchet mean for the Sinkhorn divergence,
using a Lagrangian optimization scheme.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 12-36

.. code-block:: Python


    import gzip
    import shutil
    import time

    import matplotlib.pyplot as plt
    import nibabel as nib
    import numpy as np
    import torch
    from geomloss import SamplesLoss

    # Run on the GPU when one is available.
    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor


.. GENERATED FROM PYTHON SOURCE LINES 37-41

Dataset
~~~~~~~~~~~~~~~~~~

In this tutorial, we work with probability tracks, which can be understood as normalized 3D images.
We will compute the Wasserstein barycenter of this dataset.

.. GENERATED FROM PYTHON SOURCE LINES 41-97

.. code-block:: Python


    import os


    def fetch_file(name):
        # Download and decompress the file only if it is not already cached.
        if not os.path.exists(f"data/{name}.nii.gz"):
            import urllib.request

            # urlretrieve does not create missing folders.
            os.makedirs("data", exist_ok=True)

            print("Fetching the atlas... ", end="", flush=True)
            urllib.request.urlretrieve(
                f"https://www.kernel-operations.io/data/{name}.nii.gz",
                f"data/{name}.nii.gz",
            )
            with gzip.open(f"data/{name}.nii.gz", "rb") as f_in:
                with open(f"data/{name}.nii", "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
            print("Done.")


    for i in range(5):
        fetch_file(f"manual_ifof{i+1}")


    # Voxel-to-world transform, reused when saving the barycenter.
    affine_transform = nib.load("data/manual_ifof1.nii").affine


    # load data in the nii format to a 3D, normalized array.
    def load_data_nii(fname):
        img = nib.load(fname)
        data = img.get_fdata()
        # Rescale the intensities to [0, 1].
        data_norm = data / np.max(data)
        data_norm = torch.from_numpy(data_norm).type(dtype)
        return data_norm


    def grid(nx, ny, nz):
        # (nx * ny * nz, 3) array that lists the coordinates of all voxels.
        x, y, z = torch.meshgrid(
            torch.arange(0.0, nx).type(dtype),
            torch.arange(0.0, ny).type(dtype),
            torch.arange(0.0, nz).type(dtype),
            indexing="ij",
        )
        return torch.stack((x, y, z), dim=3).view(-1, 3).detach().cpu().numpy()


    # load the data set (here, we have 5 subjects)
    dataset = []
    for i in range(5):
        fname = "data/manual_ifof" + str(i + 1) + ".nii"
        image_norm = load_data_nii(fname)
        print(image_norm.shape)
        dataset.append(image_norm)


.. GENERATED FROM PYTHON SOURCE LINES 98-99

In this tutorial, we work with 3D images, understood as densities on the cube.

.. GENERATED FROM PYTHON SOURCE LINES 99-119

.. code-block:: Python



    def img_to_points_cloud(data_norm):  # normalized images (between 0 and 1)
        nx, ny, nz = data_norm.shape
        # Coordinates of the voxels that carry some mass.
        ind = data_norm.nonzero()
        indx = ind[:, 0]
        indy = ind[:, 1]
        indz = ind[:, 2]
        # Normalize the image to a probability measure (total mass 1).
        data_norm = data_norm / data_norm.sum()
        a_i = data_norm[indx, indy, indz]

        return ind.type(dtype), a_i


    def measure_to_image(x, nx, ny, nz, weights=None):
        # Flat (row-major) index of the voxel that contains each point.
        bins = (x[:, 2]).floor() + nz * (x[:, 1]).floor() + nz * ny * (x[:, 0]).floor()
        # Sum the weights that fall in each voxel.
        count = bins.int().bincount(weights=weights, minlength=nx * ny * nz)
        return count.view(nx, ny, nz)


.. GENERATED FROM PYTHON SOURCE LINES 120-121

To perform our computations, we turn these 3D arrays into weighted point clouds, regularly spaced on the grid.

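The round trip between a dense volume and a weighted point cloud can be checked on a toy array.
The snippet below is a minimal NumPy sketch, not part of the tutorial's pipeline: the helper names
``image_to_point_cloud`` and ``point_cloud_to_image`` are hypothetical stand-ins for the torch
functions ``img_to_points_cloud`` and ``measure_to_image`` defined above, and the 2 x 3 x 4 volume
is made up for illustration.

.. code-block:: Python


    import numpy as np


    def image_to_point_cloud(image):
        # Hypothetical NumPy analogue of img_to_points_cloud.
        # Integer (i, j, k) indices of the nonzero voxels, in row-major order.
        coords = np.argwhere(image > 0)
        weights = image[coords[:, 0], coords[:, 1], coords[:, 2]]
        # Positions as floats, weights normalized to sum to 1.
        return coords.astype(float), weights / weights.sum()


    def point_cloud_to_image(points, weights, shape):
        # Hypothetical NumPy analogue of measure_to_image.
        nx, ny, nz = shape
        idx = np.floor(points).astype(int)
        # Flat (row-major) index of the voxel that contains each point.
        flat = idx[:, 2] + nz * idx[:, 1] + nz * ny * idx[:, 0]
        return np.bincount(flat, weights=weights, minlength=nx * ny * nz).reshape(shape)


    # Toy volume with two nonzero voxels (illustrative values only).
    toy = np.zeros((2, 3, 4))
    toy[0, 1, 2] = 3.0
    toy[1, 2, 3] = 1.0

    points, weights = image_to_point_cloud(toy)
    rebuilt = point_cloud_to_image(points, weights, toy.shape)


As long as the points stay inside the grid, binning them back recovers the input volume
rescaled to unit total mass.
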
.. GENERATED FROM PYTHON SOURCE LINES 121-130

.. code-block:: Python



    a, b = img_to_points_cloud(dataset[0]), img_to_points_cloud(dataset[1])
    c, d, e = (
        img_to_points_cloud(dataset[2]),
        img_to_points_cloud(dataset[3]),
        img_to_points_cloud(dataset[4]),
    )


.. GENERATED FROM PYTHON SOURCE LINES 131-132

We initialize the barycenter as an upsampled arithmetic mean of the dataset:
each voxel of the mean image is split into six particles, shifted by ±0.25 along each axis.

.. GENERATED FROM PYTHON SOURCE LINES 132-157

.. code-block:: Python



    nx, ny, nz = image_norm.shape


    def initialize_barycenter(dataset):
        # Arithmetic mean of the images.
        mean = torch.zeros(nx, ny, nz).type(dtype)
        for k in range(len(dataset)):
            img = dataset[k]
            mean = mean + img
        mean = mean / len(dataset)
        x_i, a_i = img_to_points_cloud(mean)

        # Upsampling: six shifted copies per voxel, each with 1/6 of the weight.
        bar_pos, bar_weight = torch.tensor([]).type(dtype), torch.tensor([]).type(dtype)
        for d in range(3):
            x_i_d1, x_i_d2 = x_i.clone(), x_i.clone()
            x_i_d1[:, d], a_i_d1 = x_i_d1[:, d] + 0.25, a_i / 6
            x_i_d2[:, d], a_i_d2 = x_i_d2[:, d] - 0.25, a_i / 6
            bar_pos, bar_weight = torch.cat((bar_pos, x_i_d1, x_i_d2), 0), torch.cat(
                (bar_weight, a_i_d1, a_i_d2), 0
            )

        return bar_pos, bar_weight


    x_i, a_i = initialize_barycenter(dataset)


.. GENERATED FROM PYTHON SOURCE LINES 158-160

The barycenter minimizes the sum of the Sinkhorn losses to the subjects of the dataset.
It is computed through a Lagrangian gradient step on the particles' positions:
the particles are displaced along the gradient of the loss towards each subject, and the displaced clouds are averaged.

.. GENERATED FROM PYTHON SOURCE LINES 160-181

.. code-block:: Python


    Loss = SamplesLoss("sinkhorn", blur=1, scaling=0.9, debias=False)
    models = []
    x_i.requires_grad = True

    start = time.time()
    for j in range(len(dataset)):
        img_j = dataset[j]
        y_j, b_j = img_to_points_cloud(img_j)
        L_ab = Loss(a_i, x_i, b_j, y_j)
        [g_i] = torch.autograd.grad(L_ab, [x_i])
        # Lagrangian step: the gradient is rescaled by the particle weights.
        models.append(x_i - g_i / a_i.view(-1, 1))

    # Average the five displaced point clouds.
    a, b, c, d, e = models
    barycenter = (a + b + c + d + e) / 5

    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()
    print("barycenter computed in {:.3f}s.".format(end - start))


.. GENERATED FROM PYTHON SOURCE LINES 182-183

We can plot slices of the computed barycenter:

.. GENERATED FROM PYTHON SOURCE LINES 183-188

.. code-block:: Python


    img_barycenter = measure_to_image(barycenter, nx, ny, nz, a_i)
    plt.figure()
    # Display the slice of index 20 along the first axis.
    plt.imshow(img_barycenter.detach().cpu().numpy()[20, :, :])
    plt.show()


.. GENERATED FROM PYTHON SOURCE LINES 189-190

We can also save the 3D image in the .nii format, in the same coordinate system as the input images.

.. GENERATED FROM PYTHON SOURCE LINES 190-200

.. code-block:: Python


    # Reuse the affine transform of the input images.
    barycenter_nib = nib.Nifti1Image(
        521 * (img_barycenter / img_barycenter.max()).detach().cpu().numpy(),
        affine_transform,
    )
    nib.save(barycenter_nib, "barycenter_image.nii")


.. _sphx_glr_download__auto_examples_brain_tractograms_track_barycenter.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: track_barycenter.ipynb <track_barycenter.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: track_barycenter.py <track_barycenter.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_