.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/track_barycenter.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_brain_tractograms_track_barycenter.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_track_barycenter.py:


Create an atlas using Wasserstein barycenters
==================================================

In this tutorial, we compute the barycenter of a dataset of probability tracks.
The barycenter is computed as the Fréchet mean for the Sinkhorn divergence,
using a Lagrangian optimization scheme.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 12-36

.. code-block:: default


    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.neighbors import KernelDensity
    from torch.nn.functional import avg_pool2d

    import torch
    from geomloss import SamplesLoss
    import time

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

    from scipy.interpolate import RegularGridInterpolator
    import gzip
    import shutil
    import nibabel as nib


.. GENERATED FROM PYTHON SOURCE LINES 37-41

Dataset
~~~~~~~~~~~~~~~~~~

In this tutorial, we work with probability tracks, which can be understood as
normalized 3D images. We will compute the Wasserstein barycenter of this dataset.

.. GENERATED FROM PYTHON SOURCE LINES 41-96

.. code-block:: default


    import os


    def fetch_file(name):
        os.makedirs("data", exist_ok=True)  # make sure the target folder exists
        if not os.path.exists(f"data/{name}.nii.gz"):
            import urllib.request

            print("Fetching the atlas... ", end="", flush=True)
            urllib.request.urlretrieve(
                f"https://www.kernel-operations.io/data/{name}.nii.gz",
                f"data/{name}.nii.gz",
            )
            with gzip.open(f"data/{name}.nii.gz", "rb") as f_in:
                with open(f"data/{name}.nii", "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
            print("Done.")


    for i in range(5):
        fetch_file(f"manual_ifof{i+1}")

    affine_transform = nib.load("data/manual_ifof1.nii").affine


    # Load a .nii file as a 3D, normalized torch array:
    def load_data_nii(fname):
        img = nib.load(fname)
        affine_mat = img.affine
        hdr = img.header
        data = img.get_fdata()
        data_norm = data / np.max(data)
        data_norm = torch.from_numpy(data_norm).type(dtype)
        return data_norm


    def grid(nx, ny, nz):
        x, y, z = torch.meshgrid(
            torch.arange(0.0, nx).type(dtype),
            torch.arange(0.0, ny).type(dtype),
            torch.arange(0.0, nz).type(dtype),
            indexing="ij",
        )
        return torch.stack((x, y, z), dim=3).view(-1, 3).detach().cpu().numpy()


    # Load the dataset (here, we have 5 subjects):
    dataset = []
    for i in range(5):
        fname = "data/manual_ifof" + str(i + 1) + ".nii"
        image_norm = load_data_nii(fname)
        print(image_norm.shape)
        dataset.append(image_norm)


.. GENERATED FROM PYTHON SOURCE LINES 97-98

In this tutorial, we work with 3D images, understood as densities on the cube.
The two helper functions below convert between such volumes and weighted point clouds.

.. GENERATED FROM PYTHON SOURCE LINES 98-118

.. code-block:: default


    def img_to_points_cloud(data_norm):  # normalized images (between 0 and 1)
        nx, ny, nz = data_norm.shape
        ind = data_norm.nonzero()
        indx = ind[:, 0]
        indy = ind[:, 1]
        indz = ind[:, 2]
        data_norm = data_norm / data_norm.sum()
        a_i = data_norm[indx, indy, indz]
        return ind.type(dtype), a_i


    def measure_to_image(x, nx, ny, nz, weights=None):
        bins = (x[:, 2]).floor() + nz * (x[:, 1]).floor() + nz * ny * (x[:, 0]).floor()
        count = bins.int().bincount(weights=weights, minlength=nx * ny * nz)
        return count.view(nx, ny, nz)


.. GENERATED FROM PYTHON SOURCE LINES 119-120

To perform our computations, we turn these 3D arrays into weighted point clouds,
regularly spaced on the grid.

.. GENERATED FROM PYTHON SOURCE LINES 120-129

.. code-block:: default


    # Each variable below is a (positions, weights) pair:
    a, b = img_to_points_cloud(dataset[0]), img_to_points_cloud(dataset[1])
    c, d, e = (
        img_to_points_cloud(dataset[2]),
        img_to_points_cloud(dataset[3]),
        img_to_points_cloud(dataset[4]),
    )


.. GENERATED FROM PYTHON SOURCE LINES 130-131

We initialize the barycenter as an upsampled arithmetic mean of the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 131-156

.. code-block:: default


    nx, ny, nz = image_norm.shape


    def initialize_barycenter(dataset):
        mean = torch.zeros(nx, ny, nz).type(dtype)
        for k in range(len(dataset)):
            img = dataset[k]
            mean = mean + img
        mean = mean / len(dataset)
        x_i, a_i = img_to_points_cloud(mean)
        bar_pos, bar_weight = torch.tensor([]).type(dtype), torch.tensor([]).type(dtype)
        for d in range(3):
            x_i_d1, x_i_d2 = x_i.clone(), x_i.clone()
            x_i_d1[:, d], a_i_d1 = x_i_d1[:, d] + 0.25, a_i / 6
            x_i_d2[:, d], a_i_d2 = x_i_d2[:, d] - 0.25, a_i / 6
            bar_pos, bar_weight = torch.cat((bar_pos, x_i_d1, x_i_d2), 0), torch.cat(
                (bar_weight, a_i_d1, a_i_d2), 0
            )
        return bar_pos, bar_weight


    x_i, a_i = initialize_barycenter(dataset)


.. GENERATED FROM PYTHON SOURCE LINES 157-159

The barycenter will be the minimizer of the sum of Sinkhorn distances to the dataset.
It is computed through a Lagrangian gradient descent on the particles' positions.
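
To make this more concrete (the notation :math:`\beta_k` and :math:`\text{S}_{\varepsilon}`
is introduced here for convenience and does not appear in the original script), the
barycenter measure :math:`\alpha = \sum_i a_i\, \delta_{x_i}` is sought as a minimizer,
over the particle positions :math:`x_i`, of

.. math::

    \mathcal{F}(x_1, \dots, x_N)
    \;=\; \sum_{k=1}^{K} \text{S}_{\varepsilon}\Big( \sum_i a_i\, \delta_{x_i},\; \beta_k \Big),

where :math:`\beta_1, \dots, \beta_K` are the :math:`K = 5` subject measures and
:math:`\text{S}_{\varepsilon}` is the Sinkhorn loss instantiated below through
``SamplesLoss("sinkhorn", blur=1, scaling=0.9, debias=False)``. The code below takes one
Lagrangian step towards each subject,
:math:`x_i \mapsto x_i - \nabla_{x_i} \text{S}_{\varepsilon}(\alpha, \beta_k) / a_i`,
and averages the five resulting point clouds.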

.. GENERATED FROM PYTHON SOURCE LINES 159-180

.. code-block:: default


    Loss = SamplesLoss("sinkhorn", blur=1, scaling=0.9, debias=False)
    models = []
    x_i.requires_grad = True

    start = time.time()
    for j in range(len(dataset)):
        img_j = dataset[j]
        y_j, b_j = img_to_points_cloud(img_j)
        L_ab = Loss(a_i, x_i, b_j, y_j)
        [g_i] = torch.autograd.grad(L_ab, [x_i])
        models.append(x_i - g_i / a_i.view(-1, 1))  # Lagrangian step towards subject j

    a, b, c, d, e = models
    barycenter = (a + b + c + d + e) / 5

    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()
    print("barycenter computed in {:.3f}s.".format(end - start))


.. GENERATED FROM PYTHON SOURCE LINES 181-182

We can plot slices of the computed barycenter.

.. GENERATED FROM PYTHON SOURCE LINES 182-187

.. code-block:: default


    img_barycenter = measure_to_image(barycenter, nx, ny, nz, a_i)
    plt.figure()
    plt.imshow(img_barycenter.detach().cpu().numpy()[20, :, :])
    plt.show()


.. GENERATED FROM PYTHON SOURCE LINES 188-189

Or save the 3D image in the .nii format, once it has been put in the same coordinate
system as the data images.

.. GENERATED FROM PYTHON SOURCE LINES 189-199

.. code-block:: default


    linear_transform_inv = np.linalg.inv(affine_transform[:3, :3])
    translation_inv = -affine_transform[:3, 3]
    affine_inv = np.r_[
        np.c_[linear_transform_inv, translation_inv], np.array([[0, 0, 0, 1]])
    ]
    barycenter_nib = nib.Nifti1Image(
        521 * (img_barycenter / img_barycenter.max()).detach().cpu().numpy(),
        affine_transform,
    )
    nib.save(barycenter_nib, "barycenter_image.nii")
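
Note that the script above performs a single Lagrangian update from the initialization.
Below is a minimal sketch, not part of the original example, of how the same update could
be iterated for a few steps; the names ``targets``, ``x_k``, ``n_steps``, ``descent``,
``x_plot`` and ``img_iterated`` are introduced here, while everything else reuses objects
defined above.

.. code-block:: python

    # A minimal sketch (not part of the original example) of iterating the
    # Lagrangian update instead of taking a single step.
    targets = [img_to_points_cloud(img) for img in dataset]  # (positions, weights) pairs
    x_k = x_i.detach().clone().requires_grad_(True)

    n_steps = 5  # arbitrary number of descent steps for this sketch
    for _ in range(n_steps):
        descent = torch.zeros_like(x_k)
        for y_j, b_j in targets:
            L_ab = Loss(a_i, x_k, b_j, y_j)
            [g_i] = torch.autograd.grad(L_ab, [x_k])
            descent = descent + g_i / a_i.view(-1, 1)
        # Average the Lagrangian updates over the five subjects and move the particles:
        x_k = (x_k - descent / len(targets)).detach().requires_grad_(True)

    # Rasterize the iterated barycenter, clamping particles that may have
    # drifted slightly outside of the voxel grid:
    x_plot = torch.stack(
        (
            x_k[:, 0].clamp(0, nx - 1),
            x_k[:, 1].clamp(0, ny - 1),
            x_k[:, 2].clamp(0, nz - 1),
        ),
        dim=1,
    )
    img_iterated = measure_to_image(x_plot, nx, ny, nz, a_i)

With ``n_steps = 1``, this reduces to the single averaged update computed above, since
averaging the transported clouds is the same as averaging the gradients before stepping.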

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download__auto_examples_brain_tractograms_track_barycenter.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: track_barycenter.py <track_barycenter.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: track_barycenter.ipynb <track_barycenter.ipynb>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_