Surface registration

Example of a diffeomorphic matching of surfaces using varifold metrics: we perform an LDDMM matching of two meshes using the geodesic shooting algorithm.

Define our dataset

Standard imports

import os
import time

import torch
from torch.autograd import grad

import plotly.graph_objs as go

from pykeops.torch import Vi, Vj

# torch type and device
use_cuda = torch.cuda.is_available()
# Always build a real torch.device: with the bare string "cpu", the
# ``.index`` lookup below would silently resolve to the ``str.index`` method
# instead of a device index.
torchdeviceId = torch.device("cuda:0" if use_cuda else "cpu")
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # GPU index when CUDA is used, None on CPU
KeOpsdtype = str(torchdtype).split(".")[1]  # 'float32'

Import the data file, one of:

  • “hippos.pt” : original data (6611 vertices),

  • “hippos_red.pt” : reduced size (1654 vertices),

  • “hippos_reduc.pt” : further reduced (662 vertices),

  • “hippos_reduc_reduc.pt” : further reduced (68 vertices)

# Full-resolution meshes when a GPU is available, the smallest ones otherwise.
datafile = "data/hippos.pt" if use_cuda else "data/hippos_reduc_reduc.pt"

Define the kernels

Define Gaussian kernel $(K(x,y)b)_i = \sum_j \exp(-\gamma \|x_i - y_j\|^2)\, b_j$, with $\gamma = 1/\sigma^2$

def GaussKernel(sigma):
    """Build the symbolic Gaussian kernel reduction.

    The returned KeOps routine takes three arguments (x, y, b), each declared
    with 3 components, and evaluates
    ``sum_j exp(-|x_i - y_j|^2 / sigma^2) * b_j``.

    Args:
        sigma: kernel width; only ``1 / (sigma * sigma)`` is used.
    """
    inv_sigma2 = 1 / (sigma * sigma)
    # Argument slots 0, 1, 2 of the generated routine: x is an "i" variable,
    # y and b are "j" variables.
    x_i = Vi(0, 3)
    y_j = Vj(1, 3)
    b_j = Vj(2, 3)
    gauss = (-x_i.sqdist(y_j) * inv_sigma2).exp()
    return (gauss * b_j).sum_reduction(axis=1)

Define “Gaussian-CauchyBinet” kernel $(K(x,y,u,v)b)_i = \sum_j \exp(-\gamma \|x_i - y_j\|^2)\, \langle u_i, v_j \rangle^2\, b_j$, with $\gamma = 1/\sigma^2$

def GaussLinKernel(sigma):
    """Build the symbolic "Gaussian-CauchyBinet" kernel reduction.

    The returned KeOps routine takes five arguments (x, y, u, v, b) and
    evaluates ``sum_j exp(-|x_i - y_j|^2 / sigma^2) * <u_i, v_j>^2 * b_j``.
    x, y, u, v are declared with 3 components and b with 1.

    Args:
        sigma: kernel width; only ``1 / (sigma * sigma)`` is used.
    """
    inv_sigma2 = 1 / (sigma * sigma)
    # Argument slots 0..4 of the generated routine: x and u are "i" variables,
    # y, v and b are "j" variables.
    x_i = Vi(0, 3)
    y_j = Vj(1, 3)
    u_i = Vi(2, 3)
    v_j = Vj(3, 3)
    b_j = Vj(4, 1)
    gauss = (-x_i.sqdist(y_j) * inv_sigma2).exp()
    binet = (u_i * v_j).sum() ** 2
    return (gauss * binet * b_j).sum_reduction(axis=1)

Custom ODE solver, for ODE systems which are defined on tuples

def RalstonIntegrator():
    """Return a two-stage explicit integrator for ODE systems on tuples.

    The returned function ``f(ODESystem, x0, nt, deltat=1.0)`` integrates
    ``ODESystem`` over a time span ``deltat`` in ``nt`` equal steps.
    ``ODESystem`` is called with the unpacked state and must return one
    derivative per state component. The result is a list of ``nt + 1`` state
    tuples: a clone of ``x0`` followed by the state after each step.
    """

    def f(ODESystem, x0, nt, deltat=1.0):
        # Clone so that the caller's tensors are never aliased in the output.
        state = tuple(component.clone() for component in x0)
        dt = deltat / nt
        trajectory = [state]
        for _ in range(nt):
            slope = ODESystem(*state)
            # Intermediate stage, two thirds of a step ahead.
            stage = tuple(s + (2 * dt / 3) * ds for s, ds in zip(state, slope))
            slope_stage = ODESystem(*stage)
            # Combine both slopes with weights 1/4 and 3/4.
            state = tuple(
                s + (0.25 * dt) * (ds + 3 * ds_stage)
                for s, ds, ds_stage in zip(state, slope, slope_stage)
            )
            trajectory.append(state)
        return trajectory

    return f

LDDMM implementation

Deformations: diffeomorphism

Hamiltonian system

def Hamiltonian(K):
    """Return the function ``H(p, q) = 0.5 * sum(p * K(q, q, p))``.

    Args:
        K: kernel called as ``K(q, q, p)``; its result is multiplied
            elementwise with ``p``.
    """

    def H(p, q):
        kernel_applied = K(q, q, p)
        return 0.5 * (p * kernel_applied).sum()

    return H


def HamiltonianSystem(K):
    """Return the right-hand side of Hamilton's equations for kernel ``K``.

    The returned ``HS(p, q)`` differentiates the Hamiltonian with autograd and
    returns ``(-dH/dq, dH/dp)``, i.e. the time derivatives of ``(p, q)``.
    """
    energy = Hamiltonian(K)

    def HS(p, q):
        # create_graph=True keeps the derivatives differentiable, so a loss
        # built on the integrated trajectory can be backpropagated.
        dH_dp, dH_dq = grad(energy(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return HS

Shooting approach

def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate the Hamiltonian system of ``K`` starting from ``(p0, q0)``.

    Args:
        p0: initial momentum.
        q0: initial positions.
        K: kernel defining the Hamiltonian.
        nt: number of integration steps.
        Integrator: callable ``(ODESystem, x0, nt)`` doing the integration.

    Returns:
        Whatever ``Integrator`` returns; with the default integrator, the
        list of ``(p, q)`` states along the trajectory.
    """
    dynamics = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(dynamics, initial_state, nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Transport the points ``x0`` along the flow generated by ``(p0, q0)``.

    The points ``x`` follow the velocity ``K(x, q, p)`` while ``(p, q)``
    evolve under the Hamiltonian system of ``K``.

    Args:
        x0: points to transport.
        p0: initial momentum.
        q0: initial control point positions.
        K: kernel defining both the velocity field and the Hamiltonian.
        deltat: length of the integration time span.
        Integrator: callable ``(ODESystem, x0, nt, deltat)`` returning the
            list of states along the trajectory.
        nt: number of integration steps.

    Returns:
        The transported points, i.e. the ``x`` component of the final state.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # The integrator expects (ODESystem, x0, nt, deltat). Passing deltat in
    # the third slot used it as a step count, so range(1.0) raised TypeError.
    trajectory = Integrator(FlowEq, (x0, p0, q0), nt, deltat)
    # trajectory[0] is the initial state; the flowed points are the first
    # component of the last state.
    return trajectory[-1][0]


def LDDMMloss(K, dataloss, gamma=0):
    """Build the LDDMM functional ``loss(p0, q0)``.

    The functional is ``gamma * H(p0, q0) + dataloss(q)``, where ``H`` is the
    Hamiltonian of ``K`` and ``q`` is the position component of the last
    state returned by ``Shooting(p0, q0, K)``.

    Args:
        K: kernel used for both the shooting and the Hamiltonian term.
        dataloss: callable applied to the final positions.
        gamma: weight of the Hamiltonian term.
    """

    def loss(p0, q0):
        _, q_final = Shooting(p0, q0, K)[-1]
        regularization = gamma * Hamiltonian(K)(p0, q0)
        return regularization + dataloss(q_final)

    return loss

Data attachment term

Varifold data attachment loss for surfaces

def lossVarifoldSurf(FS, VT, FT, K):
    """Build the varifold data attachment loss between two triangulated surfaces.

    Args:
        FS: face connectivity of the source surface; each row holds the three
            vertex indices of a triangle.
        VT: vertex coordinates of the target surface.
        FT: face connectivity of the target surface.
        K: kernel called as ``K(centers_a, centers_b, normals_a, normals_b,
            weights_b)``; its result is multiplied by the weights of ``a``
            and summed.

    Returns:
        A function ``loss(VS)`` of the source vertex coordinates, equal to
        ``<T, T> + <S, S> - 2 <S, T>`` for the pairing defined by ``K``.
    """

    def get_center_length_normal(F, V):
        # Gather the three corners of every triangle.
        V0, V1, V2 = (V.index_select(0, F[:, corner]) for corner in range(3))
        centers = (V0 + V1 + V2) / 3
        # Pass dim explicitly: without it torch.cross is deprecated and picks
        # the first axis of size 3, which is the wrong one (axis 0) for a
        # mesh with exactly three faces.
        normals = 0.5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        # Norm of the half cross product, i.e. the triangle area.
        length = (normals**2).sum(dim=1)[:, None].sqrt()
        return centers, length, normals / length

    CT, LT, NTn = get_center_length_normal(FT, VT)
    # The target/target term does not depend on the source: compute it once.
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = get_center_length_normal(FS, VS)
        return (
            cst
            + (LS * K(CS, CS, NSn, NSn, LS)).sum()
            - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()
        )

    return loss

Registration

Load the dataset and plot it

# The data file stores source/target vertices (VS, VT) and faces (FS, FT).
VS, FS, VT, FT = torch.load(datafile)

# Source vertices are the initial positions q0 and must carry gradients;
# connectivity arrays are cast to long so they can be used as indices.
q0 = VS.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
VT = VT.clone().detach().to(dtype=torchdtype, device=torchdeviceId)
FS = FS.clone().detach().to(dtype=torch.long, device=torchdeviceId)
FT = FT.clone().detach().to(dtype=torch.long, device=torchdeviceId)
sigma = torch.tensor([20], dtype=torchdtype, device=torchdeviceId)

# Per-coordinate / per-corner NumPy columns, as expected by plotly's Mesh3d.
x, y, z = q0.detach().cpu().numpy().T
i, j, k = FS.detach().cpu().numpy().T

xt, yt, zt = VT.detach().cpu().numpy().T
it, jt, kt = FT.detach().cpu().numpy().T

save_folder = os.path.join("..", "..", "..", "..", "doc", "_build", "html", "_images")
os.makedirs(save_folder, exist_ok=True)

# Target in blue, source in red, both half transparent.
target_mesh = go.Mesh3d(x=xt, y=yt, z=zt, i=it, j=jt, k=kt, color="blue", opacity=0.50)
source_mesh = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color="red", opacity=0.50)
fig = go.Figure(data=[target_mesh, source_mesh])
fig.write_html(os.path.join(save_folder, "data.html"), auto_open=False)
# sphinx_gallery_thumbnail_path = '_static/plot_LDDMM_Surface_thumb.png'

Define data attachment and LDDMM functional

# Kernel driving the deformation (Hamiltonian dynamics).
Kv = GaussKernel(sigma=sigma)
# Varifold attachment between the deformed source and the target.
dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=sigma))
# Functional minimised with respect to the initial momentum.
loss = LDDMMloss(Kv, dataloss)
/home/code/keops/pykeops/pykeops/tutorials/surface_registration/plot_LDDMM_Surface.py:176: UserWarning:

Using torch.cross without specifying the dim arg is deprecated.
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /pytorch/aten/src/ATen/native/Cross.cpp:62.)

Perform optimization

# Start from a zero momentum, one vector per vertex of q0.
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)

# Only p0 is optimized; q0 stays fixed.
optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)
print("performing optimization...")
start = time.time()


def closure():
    """Evaluate the loss, print it and backpropagate; called by L-BFGS."""
    optimizer.zero_grad()
    value = loss(p0, q0)
    print("loss", value.detach().cpu().numpy())
    value.backward()
    return value


# Ten outer L-BFGS steps; each one may call the closure several times.
for i in range(10):
    print("it ", i, ": ", end="")
    optimizer.step(closure)

elapsed = round(time.time() - start, 2)
print("Optimization (L-BFGS) time: ", elapsed, " seconds")
performing optimization...
it  0 : loss 87667.31
loss 81358.0
loss 21514.5
loss 15458.375
loss 9952.25
loss 7281.875
loss 4681.25
loss 4017.125
loss 3976.75
loss 3928.5
it  1 : loss 3928.5
loss 3887.625
loss 3773.5
loss 3568.125
loss 3267.875
loss 2933.5
loss 2563.0
loss 2287.5
loss 1968.875
loss 1708.25
it  2 : loss 1708.25
loss 1504.125
loss 1403.5
loss 1320.5
loss 1243.0
loss 1162.75
loss 1087.0
loss 1058.625
loss 1043.625
loss 1030.75
it  3 : loss 1030.75
loss 1018.625
loss 997.25
loss 958.0
loss 890.5
loss 846.25
loss 790.5
loss 747.75
loss 728.375
loss 708.0
it  4 : loss 708.0
loss 690.125
loss 675.0
loss 655.0
loss 634.25
loss 600.25
loss 573.75
loss 561.5
loss 556.0
loss 552.5
it  5 : loss 552.5
loss 546.0
loss 537.125
loss 524.25
loss 505.0
loss 486.75
loss 467.625
loss 453.0
loss 441.75
loss 430.375
it  6 : loss 430.375
loss 424.75
loss 420.5
loss 418.625
loss 417.0
loss 415.125
loss 412.0
loss 407.375
loss 403.5
loss 398.0
it  7 : loss 398.0
loss 394.0
loss 391.25
loss 390.375
loss 389.875
loss 388.875
loss 387.375
loss 383.625
loss 376.125
loss 362.125
it  8 : loss 362.125
loss 344.625
loss 328.0
loss 347.25
loss 314.5
loss 313.25
loss 312.375
loss 312.0
loss 311.75
loss 311.5
it  9 : loss 311.5
loss 311.125
loss 310.125
loss 308.5
loss 305.25
loss 301.375
loss 298.625
loss 297.875
loss 297.75
loss 297.75
Optimization (L-BFGS) time:  23.87  seconds

Display output

The animated version of the deformation:

# Number of integration steps used for the displayed trajectory.
nt = 15
# Shoot again from the optimized momentum p0; listpq holds the (p, q) states
# returned by the integrator along the trajectory.
listpq = Shooting(p0, q0, Kv, nt=nt)

The code to generate the figure:

# NumPy copies of the meshes for plotly.
VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()

# Create figure; trace 0 is the target and stays visible at every slider step.
fig = go.Figure()
fig.add_trace(
    go.Mesh3d(
        visible=True,
        x=VTnp[:, 0],
        y=VTnp[:, 1],
        z=VTnp[:, 2],
        i=FTnp[:, 0],
        j=FTnp[:, 1],
        k=FTnp[:, 2],
    )
)

# Add one trace per state of the trajectory. The integrator returns the
# initial state plus one state per step (nt + 1 entries), so iterate over the
# whole list: range(nt) would leave out the final, fully deformed shape.
for state in listpq:
    qnp = state[1].detach().cpu().numpy()
    fig.add_trace(
        go.Mesh3d(
            visible=False,
            x=qnp[:, 0],
            y=qnp[:, 1],
            z=qnp[:, 2],
            i=FSnp[:, 0],
            j=FSnp[:, 1],
            k=FSnp[:, 2],
        )
    )

# Show the first source frame (trace 1), matching the slider's initial position
fig.data[1].visible = True

# Create and add slider: each step shows the target and exactly one frame
steps = []
for i in range(len(fig.data) - 1):
    visibility = [False] * len(fig.data)
    visibility[0] = True  # target
    visibility[i + 1] = True  # i'th frame of the trajectory
    steps.append(dict(method="restyle", args=["visible", visibility]))

sliders = [
    dict(active=0, currentvalue={"prefix": "time: "}, pad={"t": 20}, steps=steps)
]

fig.update_layout(sliders=sliders)

fig.write_html(os.path.join(save_folder, "results.html"), auto_open=False)

Total running time of the script: (0 minutes 24.125 seconds)

Gallery generated by Sphinx-Gallery