Create an atlas using Wasserstein barycenters

In this tutorial, we compute the barycenter of a dataset of probability tracks. The barycenter is computed as the Fréchet mean for the Sinkhorn divergence, using a Lagrangian optimization scheme.

Setup

import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
from sklearn.neighbors import KernelDensity
from torch.nn.functional import avg_pool2d
import torch
from geomloss import SamplesLoss
import time

# Select the legacy tensor type matching the available device: CUDA floats
# when a GPU is present, CPU floats otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
from scipy.interpolate import RegularGridInterpolator


import gzip
import shutil
import pdb


import nibabel as nib
import matplotlib.pyplot as plt

Dataset

In this tutorial, we work with probability tracks, which can be understood as normalized 3D images. We will compute the Wasserstein barycenter of this dataset.

import os


def fetch_file(name, data_dir="data"):
    """Make ``{data_dir}/{name}.nii`` available locally, downloading it if needed.

    The gzipped archive is fetched from
    ``https://www.kernel-operations.io/data/{name}.nii.gz`` into
    ``{data_dir}/{name}.nii.gz``. It is then decompressed to
    ``{data_dir}/{name}.nii``.

    Each step is skipped when its output file already exists. Each output is
    first written to a ``.part`` file and renamed only once complete, so an
    interrupted run never leaves a truncated file that later calls would trust.

    Args:
        name: base file name on the server, without the ``.nii.gz`` extension.
        data_dir: local directory holding both files; created if missing.
    """
    archive = os.path.join(data_dir, f"{name}.nii.gz")
    image = os.path.join(data_dir, f"{name}.nii")

    if not os.path.exists(archive):
        import urllib.request

        # urlretrieve does not create missing parent directories.
        os.makedirs(data_dir, exist_ok=True)
        print("Fetching the atlas... ", end="", flush=True)
        partial = archive + ".part"
        urllib.request.urlretrieve(
            f"https://www.kernel-operations.io/data/{name}.nii.gz",
            partial,
        )
        os.replace(partial, archive)
        print("Done.")

    # Checked independently of the download, so an archive that is already on
    # disk (e.g. from an earlier, interrupted run) still gets decompressed.
    if not os.path.exists(image):
        partial = image + ".part"
        with gzip.open(archive, "rb") as f_in, open(partial, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial, image)


# Download (if needed) the five probability tracks manual_ifof1 ... manual_ifof5.
subjects = [f"manual_ifof{k}" for k in range(1, 6)]
for subject in subjects:
    fetch_file(subject)


# Affine transform of the first subject's image, reused when saving the barycenter.
affine_transform = nib.load("data/manual_ifof1.nii").affine

def load_data_nii(fname):
    """Load a NIfTI image and return its voxels divided by the image maximum.

    Args:
        fname: path to a NIfTI image file readable by ``nib.load``.

    Returns:
        A tensor of type ``dtype`` holding ``data / max(data)``, so its largest
        value is 1.

    Raises:
        ValueError: if the image maximum is not strictly positive (all-zero,
            all-negative or NaN image). Max-normalization is undefined there; it
            would yield NaN/inf voxels, and NaN voxels count as nonzero support
            downstream.
    """
    data = nib.load(fname).get_fdata()
    peak = np.max(data)
    # `not peak > 0` also catches NaN, for which every comparison is False.
    if not peak > 0:
        raise ValueError(f"cannot max-normalize {fname!r}: image maximum is {peak}")
    return torch.from_numpy(data / peak).type(dtype)


def grid(nx, ny, nz):
    """Return the integer voxel coordinates of an ``nx x ny x nz`` grid.

    Args:
        nx, ny, nz: number of voxels along each axis.

    Returns:
        A ``float32`` NumPy array of shape ``(nx * ny * nz, 3)``. Its rows are the
        ``(i, j, k)`` triples in row-major order (``k`` varies fastest), i.e. the
        same order as flattening an ``(nx, ny, nz)`` array.
    """
    # Built directly with NumPy: no allocation on the (possibly CUDA) torch
    # device followed by a copy back to the host.
    axes = (
        np.arange(nx, dtype=np.float32),
        np.arange(ny, dtype=np.float32),
        np.arange(nz, dtype=np.float32),
    )
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)


# Load the dataset (5 subjects); each volume is normalized by load_data_nii.
# `image_norm` is left bound to the last loaded volume.
dataset = []
for i in range(5):
    fname = f"data/manual_ifof{i + 1}.nii"
    image_norm = load_data_nii(fname)
    print(image_norm.shape)
    dataset.append(image_norm)

In this tutorial, we work with 3D images, understood as densities on the cube.

def img_to_points_cloud(data_norm):  # normalized images (between 0 and 1)
    """Turn a 3D image into a weighted point cloud supported on its nonzero voxels.

    Args:
        data_norm: 3D tensor (the shape unpacking below rejects other ranks);
            presumably nonnegative intensities — TODO confirm for new inputs.

    Returns:
        ``(positions, weights)``. ``positions`` is an ``(N, 3)`` tensor of the
        nonzero voxel indices, cast to ``dtype``. ``weights`` holds the matching
        intensities divided by the image total, so the weights sum to 1.
    """
    _, _, _ = data_norm.shape  # only enforces a 3D input
    support = data_norm.nonzero()
    density = data_norm / data_norm.sum()
    weights = density[support[:, 0], support[:, 1], support[:, 2]]
    return support.type(dtype), weights


def measure_to_image(x, nx, ny, nz, weights=None):
    """Rasterize a (weighted) point cloud onto an ``nx x ny x nz`` voxel grid.

    Voxel ``(i, j, k)`` is represented by the point ``(i, j, k)``, which is how
    ``img_to_points_cloud`` places voxels. Each point is therefore assigned to
    the voxel with the nearest integer index. Points whose nearest voxel lies
    outside the grid are dropped instead of crashing ``bincount`` (negative
    bins) or aliasing into another row of the grid.

    Args:
        x: ``(N, 3)`` tensor of point positions in voxel units.
        nx, ny, nz: grid size along each axis.
        weights: optional ``(N,)`` tensor of point masses; when ``None`` each
            point counts as 1.

    Returns:
        An ``(nx, ny, nz)`` tensor with the summed weights (or point counts)
        falling in each voxel.
    """
    # Nearest voxel: a particle at k - 0.25 belongs to voxel k, not k - 1.
    idx = torch.round(x.detach()).long()
    dims = torch.tensor([nx, ny, nz], device=idx.device)
    inside = ((idx >= 0) & (idx < dims)).all(dim=1)
    idx = idx[inside]
    if weights is not None:
        weights = weights[inside]
    # Row-major flat index, matching .view(nx, ny, nz) below.
    bins = idx[:, 2] + nz * idx[:, 1] + nz * ny * idx[:, 0]
    count = bins.bincount(weights=weights, minlength=nx * ny * nz)
    return count.view(nx, ny, nz)

To perform our computations, we turn these 3D arrays into weighted point clouds, supported on the nonzero voxels of the regular grid and weighted by the normalized voxel intensities.

# One (positions, weights) point cloud per subject, in dataset order.
a, b, c, d, e = [img_to_points_cloud(dataset[k]) for k in range(5)]

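We initialize the barycenter as an upsampled arithmetic mean of the dataset. The voxel-wise mean image is turned into a point cloud, and every point is split into six particles of equal mass, shifted by ±0.25 voxel along each axis.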

nx, ny, nz = image_norm.shape


def initialize_barycenter(dataset):
    """Build the initial barycenter as a weighted point cloud.

    The images are averaged voxel-wise and the mean image is turned into a point
    cloud with ``img_to_points_cloud``. Every point is then replaced by six
    particles, shifted by +0.25 and -0.25 voxel along each of the three axes. Each
    particle carries 1/6 of the original weight, so the total mass is unchanged.

    Args:
        dataset: non-empty sequence of 3D image tensors, assumed to share one
            shape (the voxel-wise sum requires compatible shapes).

    Returns:
        ``(positions, weights)`` of shapes ``(6N, 3)`` and ``(6N,)``. Rows are
        grouped by shift, in the order +x, -x, +y, -y, +z, -z.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("initialize_barycenter needs a non-empty dataset")

    # Same dtype/device as the inputs; summed in dataset order.
    mean = torch.zeros_like(dataset[0])
    for img in dataset:
        mean = mean + img
    mean = mean / len(dataset)
    x_i, a_i = img_to_points_cloud(mean)

    # Collect all shifted copies and concatenate once, instead of growing from an
    # empty 1-D tensor (a legacy special case of torch.cat).
    positions, weights = [], []
    for axis in range(3):
        for offset in (0.25, -0.25):
            shifted = x_i.clone()
            shifted[:, axis] += offset
            positions.append(shifted)
            weights.append(a_i / 6)
    return torch.cat(positions, 0), torch.cat(weights, 0)


x_i, a_i = initialize_barycenter(dataset)

The barycenter is the minimizer of the sum of Sinkhorn distances to the images of the dataset. Here it is approximated with a single Lagrangian gradient step on the particles’ positions towards each subject; the resulting positions are then averaged.

Loss = SamplesLoss("sinkhorn", blur=1, scaling=0.9, debias=False)
models = []
x_i.requires_grad = True


start = time.perf_counter()
for img_j in dataset:
    y_j, b_j = img_to_points_cloud(img_j)
    L_ab = Loss(a_i, x_i, b_j, y_j)
    [g_i] = torch.autograd.grad(L_ab, [x_i])
    # Gradient step rescaled by each particle's mass. NOTE(review): with
    # geomloss' default quadratic cost this is expected to move each particle
    # to its transport target on subject j — confirm against the geomloss docs.
    models.append(x_i - g_i / a_i.view(-1, 1))

# Average the transported positions over however many subjects were loaded
# (previously hard-coded to exactly 5).
barycenter = sum(models) / len(models)
if use_cuda:
    torch.cuda.synchronize()
end = time.perf_counter()
print("barycenter computed in {:.3f}s.".format(end - start))

We can plot a slice of the computed barycenter (here, the slice at index 20 along the first axis).

# Rasterize the barycenter particles (weighted by a_i) back onto the voxel grid,
# then display the slice at index 20 along the first axis.
img_barycenter = measure_to_image(barycenter, nx, ny, nz, a_i)
slice_20 = img_barycenter.detach().cpu().numpy()[20, :, :]
plt.figure()
plt.imshow(slice_20)
plt.show()

Or save the 3D image in .nii format, using the same affine transform as the data images so that it lies in the same coordinate system.

# Inverse of the image affine. For A = [[M, t], [0, 1]] the inverse is
# [[M^-1, -M^-1 t], [0, 1]]: the translation must be mapped through M^-1 too.
# (affine_inv is not used further in this script.)
linear_transform_inv = np.linalg.inv(affine_transform[:3, :3])
translation_inv = -linear_transform_inv @ affine_transform[:3, 3]
affine_inv = np.r_[
    np.c_[linear_transform_inv, translation_inv], np.array([[0, 0, 0, 1]])
]
# Save the barycenter with the same affine as the first subject, so it shares
# the data images' coordinate system. NOTE(review): the maximum is rescaled to
# 521; that constant is not explained here — presumably chosen to match the
# input intensity range, TODO confirm.
barycenter_nib = nib.Nifti1Image(
    521 * (img_barycenter / img_barycenter.max()).detach().cpu().numpy(),
    affine_transform,
)
nib.save(barycenter_nib, "barycenter_image.nii")

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery