Surface registration

Example of a diffeomorphic matching of surfaces using varifold metrics. We perform an LDDMM matching of two meshes using the geodesic shooting algorithm.

Define our dataset

Standard imports

import os
import time

import torch
from torch.autograd import grad

import plotly.graph_objs as go

from pykeops.torch import Vi, Vj

# torch type and device
use_cuda = torch.cuda.is_available()
# Always build a torch.device (also on CPU): with the bare string "cpu",
# `torchdeviceId.index` below would be the bound `str.index` method instead
# of a device ordinal.
torchdeviceId = torch.device("cuda:0") if use_cuda else torch.device("cpu")
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # GPU ordinal (0 here), or None on CPU
KeOpsdtype = str(torchdtype).split(".")[1]  # 'float32'

Import the data file, one of:

  • “hippos.pt” : original data (6611 vertices),

  • “hippos_red.pt” : reduced size (1654 vertices),

  • “hippos_reduc.pt” : further reduced (662 vertices),

  • “hippos_reduc_reduc.pt” : further reduced (68 vertices)

# The full-resolution meshes are only practical on a GPU; on CPU fall back to
# the smallest version so the example still runs in reasonable time.
datafile = "data/hippos.pt" if use_cuda else "data/hippos_reduc_reduc.pt"

Define the kernels

Define Gaussian kernel \((K(x,y)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2)b_j\)

def GaussKernel(sigma):
    """Build a KeOps Gaussian kernel product.

    The returned reduction is called as ``K(x, y, b)`` with 3-dimensional
    ``x`` (indexed by i), ``y`` and ``b`` (indexed by j), and computes
    ``out_i = sum_j exp(-||x_i - y_j||^2 / sigma^2) * b_j``.
    """
    gamma = 1 / (sigma * sigma)
    x_i = Vi(0, 3)
    y_j = Vj(1, 3)
    b_j = Vj(2, 3)
    sq_dist = x_i.sqdist(y_j)
    # Keep the LazyTensor on the left so KeOps handles the product with gamma.
    weight = (-sq_dist * gamma).exp()
    return (weight * b_j).sum_reduction(axis=1)

Define the “Gaussian–Cauchy–Binet” kernel \((K(x,y,u,v)b)_i = \sum_j \exp(-\gamma\|x_i-y_j\|^2) \langle u_i,v_j\rangle^2 b_j\)

def GaussLinKernel(sigma):
    """Build the KeOps "Gaussian-Cauchy-Binet" kernel product.

    The returned reduction is called as ``K(x, y, u, v, b)``. ``x``, ``u``
    (3-dim) are indexed by i; ``y``, ``v`` (3-dim) and the scalar ``b`` are
    indexed by j. It computes
    ``out_i = sum_j exp(-||x_i - y_j||^2 / sigma^2) * <u_i, v_j>^2 * b_j``.
    """
    gamma = 1 / (sigma * sigma)
    x_i, u_i = Vi(0, 3), Vi(2, 3)
    y_j, v_j, b_j = Vj(1, 3), Vj(3, 3), Vj(4, 1)
    spatial = (-x_i.sqdist(y_j) * gamma).exp()
    # Squared inner product: insensitive to the sign (orientation) of u, v.
    angular = (u_i * v_j).sum() ** 2
    return (spatial * angular * b_j).sum_reduction(axis=1)

Custom ODE solver (Ralston's second-order Runge–Kutta scheme) for ODE systems whose state is a tuple of tensors

def RalstonIntegrator():
    """Return a Ralston (2nd-order Runge-Kutta) integrator for tuple states.

    The returned function ``integrate(ODESystem, x0, nt, deltat=1.0)``:
      * ``ODESystem(*state)`` must return a tuple of derivatives, one per
        component of ``state``;
      * ``x0`` is the initial state tuple (each component is cloned);
      * ``nt`` steps of size ``deltat / nt`` are taken.
    It returns the list of the ``nt + 1`` visited states, initial one first.
    """

    def integrate(ODESystem, x0, nt, deltat=1.0):
        dt = deltat / nt
        state = tuple(component.clone() for component in x0)
        trajectory = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            # Intermediate evaluation at 2/3 of the step (Ralston's choice).
            midpoint = tuple(s + (2 * dt / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*midpoint)
            # Weighted combination: 1/4 of k1 + 3/4 of k2.
            state = tuple(
                s + (0.25 * dt) * (d1 + 3 * d2)
                for s, d1, d2 in zip(state, k1, k2)
            )
            trajectory.append(state)
        return trajectory

    return integrate

LDDMM implementation

Deformations: diffeomorphism

Hamiltonian system

def Hamiltonian(K):
    """Return the LDDMM Hamiltonian ``H(p, q) = 1/2 * sum(p * K(q, q, p))``.

    ``K`` is a kernel product called as ``K(x, y, b)``. Here it is evaluated
    on the control points ``q`` with the momenta ``p`` as weights.
    """

    def H(p, q):
        velocity = K(q, q, p)
        return 0.5 * (p * velocity).sum()

    return H


def HamiltonianSystem(K):
    """Return the right-hand side of Hamilton's equations for kernel ``K``.

    The returned function maps ``(p, q)`` to ``(-dH/dq, dH/dp)``, i.e. the
    time derivatives of the momenta and of the control points, with
    ``H = Hamiltonian(K)``. The gradients are built with
    ``create_graph=True`` so the trajectory stays differentiable (this lets
    the loss backpropagate through the shooting).
    """
    hamiltonian = Hamiltonian(K)

    def HS(p, q):
        energy = hamiltonian(p, q)
        dH_dp, dH_dq = grad(energy, (p, q), create_graph=True)
        p_dot = -dH_dq
        q_dot = dH_dp
        return p_dot, q_dot

    return HS

Shooting approach

def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate Hamilton's equations from the initial state ``(p0, q0)``.

    Returns whatever ``Integrator`` returns. With the default
    ``RalstonIntegrator`` this is the list of ``nt + 1`` states ``(p, q)``
    over unit time, initial state first.
    """
    equations = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(equations, initial_state, nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport arbitrary points ``x0`` by the deformation shot from ``(p0, q0)``.

    The points move with velocity ``K(x, q, p)`` while ``(p, q)`` follow
    Hamilton's equations.

    Args:
        x0: initial positions of the points to transport.
        p0, q0: initial momenta and control points.
        K: kernel product called as ``K(x, y, b)``.
        deltat: total integration time.
        Integrator: ODE integrator factory result, called as
            ``Integrator(system, state, nt, deltat)``.
        nt: number of integration steps (new, defaults to 10 like ``Shooting``).

    Returns:
        The positions of ``x0`` at time ``deltat``.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # The integrator's third positional argument is the number of steps; the
    # previous code passed `deltat` there (range(1.0) -> TypeError) and then
    # returned `[0]`, i.e. the *initial* state tuple.
    trajectory = Integrator(FlowEq, (x0, p0, q0), nt, deltat)
    # Final state, first component: the transported points.
    return trajectory[-1][0]


def LDDMMloss(K, dataloss, gamma=0):
    """Build the LDDMM objective as a function of ``(p0, q0)``.

    The objective is ``gamma * H(p0, q0) + dataloss(q1)``, where ``q1`` is
    the final position of the control points after shooting with ``K``.
    ``gamma`` weights the deformation energy (0 by default, so only the
    data term drives the optimization).
    """

    def loss(p0, q0):
        _, q_final = Shooting(p0, q0, K)[-1]
        energy = Hamiltonian(K)(p0, q0)
        return gamma * energy + dataloss(q_final)

    return loss

Data attachment term

Varifold data attachment loss for surfaces

def lossVarifoldSurf(FS, VT, FT, K):
    """Build the varifold data-attachment loss towards a fixed target surface.

    Each triangle is summarised by its center ``c``, its area ``a`` and its
    unit normal ``n``. For a source vertex set ``VS`` the loss is

        <T, T> + <S, S> - 2 <S, T>,
        <A, B> = sum_i sum_j a_i * K(c_i, c_j, n_i, n_j) * b_j,

    where ``a`` and ``b`` are the triangle areas of A and B. The target terms
    are computed once, when the loss is built.

    Args:
        FS: face connectivity of the source surface (row indices into VS).
        VT: vertex coordinates of the target surface.
        FT: face connectivity of the target surface (row indices into VT).
        K: kernel called as ``K(x, y, u, v, b)`` (centers, centers, unit
            normals, unit normals, weights), e.g. ``GaussLinKernel(sigma)``.

    Returns:
        A function ``loss(VS)`` of the source vertex coordinates.
    """

    def get_center_length_normal(F, V):
        V0, V1, V2 = (
            V.index_select(0, F[:, 0]),
            V.index_select(0, F[:, 1]),
            V.index_select(0, F[:, 2]),
        )
        centers = (V0 + V1 + V2) / 3
        # dim=1 is required: without it torch.cross picks the *first* axis of
        # size 3, which is the face axis when the mesh has exactly 3 faces
        # (and the no-dim form is deprecated anyway).
        normals = 0.5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        # Norm of the half cross product = triangle area ("length" here).
        # NOTE(review): degenerate (zero-area) faces divide by zero below.
        length = (normals**2).sum(dim=1)[:, None].sqrt()
        return centers, length, normals / length

    CT, LT, NTn = get_center_length_normal(FT, VT)
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = get_center_length_normal(FS, VS)
        return (
            cst
            + (LS * K(CS, CS, NSn, NSn, LS)).sum()
            - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()
        )

    return loss

Registration

Load the dataset and plot it

VS, FS, VT, FT = torch.load(datafile)

# Source vertices are the initial control points q0; the loss is
# differentiated through them, so they track gradients.
q0 = VS.clone().detach().to(device=torchdeviceId, dtype=torchdtype)
q0.requires_grad_(True)
VT = VT.clone().detach().to(device=torchdeviceId, dtype=torchdtype)
# Face connectivities are integer row indices into the vertex arrays.
FS = FS.clone().detach().to(device=torchdeviceId, dtype=torch.long)
FT = FT.clone().detach().to(device=torchdeviceId, dtype=torch.long)
# Kernel width, shared below by the deformation and varifold kernels.
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

# Per-axis NumPy copies of the vertex coordinates and face indices, in the
# column-wise layout expected by plotly's Mesh3d.
x, y, z = (q0[:, axis].detach().cpu().numpy() for axis in range(3))
i, j, k = (FS[:, axis].detach().cpu().numpy() for axis in range(3))

xt, yt, zt = (VT[:, axis].detach().cpu().numpy() for axis in range(3))
it, jt, kt = (FT[:, axis].detach().cpu().numpy() for axis in range(3))

parent_levels = [os.pardir] * 4
save_folder = os.path.join(*parent_levels, "doc", "_build", "html", "_images")
os.makedirs(save_folder, exist_ok=True)

# Target in blue, source in red, both half transparent.
target_mesh = go.Mesh3d(x=xt, y=yt, z=zt, i=it, j=jt, k=kt, color="blue", opacity=0.50)
source_mesh = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color="red", opacity=0.50)
fig = go.Figure(data=[target_mesh, source_mesh])
fig.write_html(os.path.join(save_folder, "data.html"), auto_open=False)
# sphinx_gallery_thumbnail_path = '_static/plot_LDDMM_Surface_thumb.png'

Define data attachment and LDDMM functional

# Deformation kernel (generates the velocity field).
Kv = GaussKernel(sigma=sigma)
# Varifold data term towards the target mesh, with the same kernel width.
dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=sigma))
# Full objective; gamma=0 (the default) means no deformation-energy penalty.
loss = LDDMMloss(Kv, dataloss, gamma=0)
/home/code/keops/pykeops/pykeops/tutorials/surface_registration/plot_LDDMM_Surface.py:176: UserWarning:

Using torch.cross without specifying the dim arg is deprecated.
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /opt/conda/conda-bld/pytorch_1704987288773/work/aten/src/ATen/native/Cross.cpp:63.)

Perform optimization

# Initial momenta are all zero, so the first shot leaves q0 unchanged.
p0 = torch.zeros_like(q0, dtype=torchdtype, device=torchdeviceId, requires_grad=True)

optimizer = torch.optim.LBFGS([p0], max_iter=10, max_eval=10)
print("performing optimization...")
start = time.time()


def closure():
    """Re-evaluate the objective and its gradient w.r.t. p0 for L-BFGS."""
    optimizer.zero_grad()
    current = loss(p0, q0)
    print("loss", current.detach().cpu().numpy())
    current.backward()
    return current


for i in range(10):
    print("it ", i, ": ", end="")
    optimizer.step(closure)

elapsed = round(time.time() - start, 2)
print("Optimization (L-BFGS) time: ", elapsed, " seconds")
performing optimization...
it  0 : loss 87667.31
loss 81358.0
loss 21514.5
loss 15458.5
loss 9952.25
loss 7281.75
loss 4681.125
loss 4017.0
loss 3976.875
loss 3928.5
it  1 : loss 3928.5
loss 3887.625
loss 3773.5
loss 3568.0
loss 3267.875
loss 2933.5
loss 2562.75
loss 2287.625
loss 1969.375
loss 1708.25
it  2 : loss 1708.25
loss 1504.125
loss 1403.5
loss 1321.0
loss 1244.0
loss 1163.875
loss 1087.125
loss 1057.5
loss 1042.375
loss 1028.375
it  3 : loss 1028.375
loss 1015.5
loss 993.75
loss 952.75
loss 881.375
loss 839.0
loss 780.375
loss 737.5
loss 720.125
loss 703.0
it  4 : loss 703.0
loss 687.75
loss 675.625
loss 657.625
loss 627.0
loss 595.25
loss 570.25
loss 557.25
loss 551.5
loss 547.125
it  5 : loss 547.125
loss 535.625
loss 521.875
loss 511.75
loss 491.875
loss 465.375
loss 443.5
loss 435.5
loss 428.0
loss 423.5
it  6 : loss 423.5
loss 422.25
loss 421.25
loss 420.0
loss 417.0
loss 411.125
loss 404.625
loss 399.125
loss 396.125
loss 393.75
it  7 : loss 393.75
loss 392.5
loss 391.5
loss 390.75
loss 389.875
loss 388.75
loss 386.0
loss 380.625
loss 369.5
loss 351.0
it  8 : loss 351.0
loss 330.25
loss 367.125
loss 314.5
loss 313.125
loss 312.25
loss 312.25
it  9 : loss 312.25
loss 312.0
loss 311.375
loss 310.5
loss 309.0
loss 307.125
loss 304.25
loss 302.0
loss 299.375
loss 297.125
Optimization (L-BFGS) time:  26.86  seconds

Display output

The animated version of the deformation:

# Re-shoot with the optimized momenta, keeping every intermediate state.
nt = 15
listpq = Shooting(p0, q0, K=Kv, nt=nt)

The code to generate the figure:

VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()

# Create figure; trace 0 is the target mesh, kept visible at every step.
fig = go.Figure()
fig.add_trace(
    go.Mesh3d(
        visible=True,
        x=VTnp[:, 0],
        y=VTnp[:, 1],
        z=VTnp[:, 2],
        i=FTnp[:, 0],
        j=FTnp[:, 1],
        k=FTnp[:, 2],
    )
)

# Add one hidden trace per shooting state. `listpq` holds nt + 1 states (the
# initial one plus one per integration step); iterating over range(nt) would
# drop the last one, i.e. the final registered shape.
for state in listpq:
    qnp = state[1].detach().cpu().numpy()
    fig.add_trace(
        go.Mesh3d(
            visible=False,
            x=qnp[:, 0],
            y=qnp[:, 1],
            z=qnp[:, 2],
            i=FSnp[:, 0],
            j=FSnp[:, 1],
            k=FSnp[:, 2],
        )
    )

# Initially show the first deformed state (t = 0), matching slider active=0.
fig.data[1].visible = True

# Create and add slider: step i shows the target plus deformed state i.
steps = []
for i in range(len(fig.data) - 1):
    step = dict(
        method="restyle",
        args=["visible", [False] * len(fig.data)],
    )
    step["args"][1][0] = True  # keep the target visible
    step["args"][1][i + 1] = True  # Toggle i'th deformed state to "visible"
    steps.append(step)

sliders = [
    dict(active=0, currentvalue={"prefix": "time: "}, pad={"t": 20}, steps=steps)
]

fig.update_layout(sliders=sliders)

fig.write_html(os.path.join(save_folder, "results.html"), auto_open=False)

Total running time of the script: (0 minutes 27.279 seconds)

Gallery generated by Sphinx-Gallery