.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/optimal_transport/plot_interpolation_3D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_optimal_transport_plot_interpolation_3D.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_optimal_transport_plot_interpolation_3D.py:

Creating a fancy interpolation video between 3D meshes.
==============================================================

N.B.: I am currently very busy writing my PhD thesis. Comments will come soon!

.. GENERATED FROM PYTHON SOURCE LINES 9-14

Setup
----------------------

Standard imports.


.. GENERATED FROM PYTHON SOURCE LINES 14-31

.. code-block:: Python



    import numpy as np
    import torch
    import os

    use_cuda = torch.cuda.is_available()
    tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    numpy = lambda x: x.detach().cpu().numpy()

    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    from geomloss import SamplesLoss
    from pykeops.torch import LazyTensor









.. GENERATED FROM PYTHON SOURCE LINES 32-33

Utility: turn a triangle mesh into a weighted point cloud.

.. GENERATED FROM PYTHON SOURCE LINES 33-55

.. code-block:: Python



    def to_measure(points, triangles):
        """Turns a triangle mesh into a weighted point cloud: one Dirac atom per face."""

        # Our mesh is given as a collection of ABC triangles:
        A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

        # Locations and weights of our Dirac atoms:
        X = (A + B + C) / 3  # centers of the faces
        S = np.sqrt(np.sum(np.cross(B - A, C - A) ** 2, 1)) / 2  # areas of the faces

        print(
            "File loaded, and encoded as the weighted sum of {:,} atoms in 3D.".format(
                len(X)
            )
        )

        # We return a (normalized) vector of weights + a "list" of points
        return tensor(S / np.sum(S)), tensor(X)

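To see what ``to_measure`` computes, here is a NumPy-only sketch of the same construction, applied to a unit square split into two triangles. The name ``mesh_to_weights`` is hypothetical, and it returns NumPy arrays instead of ``tensor`` objects. Each face becomes one Dirac atom located at its centroid and weighted by its area, i.e. half the norm of the cross product of two of its edges.

.. code-block:: Python

    import numpy as np


    def mesh_to_weights(points, triangles):
        """Hypothetical NumPy-only variant of to_measure: one atom per triangle."""
        A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]
        centers = (A + B + C) / 3  # centroids of the faces
        areas = np.linalg.norm(np.cross(B - A, C - A), axis=1) / 2  # areas of the faces
        return areas / areas.sum(), centers


    # Unit square in the z = 0 plane, split along its diagonal into two triangles:
    points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
    triangles = np.array([[0, 1, 2], [0, 2, 3]])

    weights, centers = mesh_to_weights(points, triangles)
    print(weights)  # [0.5 0.5]: both faces have an area of 1/2
    print(centers)  # rows (2/3, 1/3, 0) and (1/3, 2/3, 0)

Since the weights are proportional to the face areas, a finely tessellated region does not receive more mass than a coarsely tessellated region of the same area.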








.. GENERATED FROM PYTHON SOURCE LINES 56-58

Utility: load a ``.ply`` mesh file.


.. GENERATED FROM PYTHON SOURCE LINES 58-75

.. code-block:: Python


    from plyfile import PlyData, PlyElement


    def load_ply_file(fname):
        """Loads a .ply mesh to return a collection of weighted Dirac atoms: one per triangle face."""

        # Load the data, and read the connectivity information:
        plydata = PlyData.read(fname)
        triangles = np.vstack(plydata["face"].data["vertex_indices"])

        # Read the (x, y, z) coordinates of the vertices:
        points = np.vstack([[v[0], v[1], v[2]] for v in plydata["vertex"]])

        return to_measure(points, triangles)

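``load_ply_file`` relies on the ``vertex`` and ``face`` elements of the PLY format, the latter storing one ``vertex_indices`` list per face. To illustrate this layout, the sketch below is a hypothetical and deliberately minimal reader for the ASCII flavour of the format only: it assumes no binary data, a vertex block that precedes the face block, x, y, z as the first three vertex properties, and triangular faces. It is not meant to replace ``plyfile``.

.. code-block:: Python

    import numpy as np

    # A single triangle, stored as an ASCII PLY file:
    PLY_TEXT = """ply
    format ascii 1.0
    element vertex 3
    property float x
    property float y
    property float z
    element face 1
    property list uchar int vertex_indices
    end_header
    0 0 0
    1 0 0
    0 1 0
    3 0 1 2
    """


    def parse_ascii_ply(text):
        """Hypothetical minimal reader: returns (points, triangles) arrays."""
        lines = [line.strip() for line in text.strip().splitlines()]
        header_end = lines.index("end_header")
        counts = {}  # number of items per element, e.g. {"vertex": 3, "face": 1}
        for line in lines[:header_end]:
            words = line.split()
            if words[0] == "element":
                counts[words[1]] = int(words[2])

        body = lines[header_end + 1 :]
        n_v, n_f = counts["vertex"], counts["face"]
        points = np.array([[float(w) for w in row.split()[:3]] for row in body[:n_v]])
        # Each face line reads "<count> i j k": we skip the count, assuming triangles.
        triangles = np.array([[int(w) for w in row.split()[1:4]] for row in body[n_v : n_v + n_f]])
        return points, triangles


    points, triangles = parse_ascii_ply(PLY_TEXT)
    print(points.shape, triangles.tolist())  # (3, 3) [[0, 1, 2]]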








.. GENERATED FROM PYTHON SOURCE LINES 76-78

Utility: load a ``.nii`` volume file.


.. GENERATED FROM PYTHON SOURCE LINES 78-93

.. code-block:: Python


    import SimpleITK as sitk
    from skimage.measure import marching_cubes


    def load_nii_file(fname, threshold=0.5):
        """Uses the marching cubes algorithm to turn a .nii binary mask into a weighted surface point cloud."""

        mask = sitk.GetArrayFromImage(sitk.ReadImage(fname))
        # mask = skimage.transform.downscale_local_mean(mask, (4,4,4))
        # N.B.: the vertices are expressed in voxel units, with the axes in the
        # (z, y, x) order of the NumPy array returned by SimpleITK.
        verts, faces, normals, values = marching_cubes(mask, threshold)

        return to_measure(verts, faces)









.. GENERATED FROM PYTHON SOURCE LINES 94-96

Synthetic sphere, a typical source measure:


.. GENERATED FROM PYTHON SOURCE LINES 96-113

.. code-block:: Python



    def create_sphere(n_samples=1000):
        """Creates a uniform sample on the unit sphere."""
        n_samples = int(n_samples)

        indices = np.arange(0, n_samples, dtype=float) + 0.5
        phi = np.arccos(1 - 2 * indices / n_samples)
        theta = np.pi * (1 + 5**0.5) * indices

        x, y, z = np.cos(theta) * np.sin(phi), np.sin(theta) * np.sin(phi), np.cos(phi)
        points = np.vstack((x, y, z)).T
        weights = np.ones(n_samples) / n_samples

        return tensor(weights), tensor(points)

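``create_sphere`` builds a spiral (Fibonacci-style) lattice. The heights ``z = cos(phi) = 1 - 2 (i + 1/2) / n`` are evenly spaced in [-1, 1], which, by Archimedes' hat-box theorem, gives every point the same share of the sphere's area. The azimuth increment ``pi (1 + sqrt(5))`` is, modulo ``2 pi``, minus the golden angle ``pi (3 - sqrt(5))``, which spreads consecutive points around the vertical axis. The NumPy check below re-derives the angles used above and verifies these facts; it does not call ``create_sphere`` itself.

.. code-block:: Python

    import numpy as np

    n = 1000
    indices = np.arange(n, dtype=float) + 0.5
    phi = np.arccos(1 - 2 * indices / n)  # polar angles, as in create_sphere
    theta = np.pi * (1 + 5**0.5) * indices  # azimuths, as in create_sphere

    # 1. The heights z = cos(phi) are evenly spaced, with a step of 2/n:
    z = np.cos(phi)
    print(np.allclose(np.diff(z), -2 / n))  # True

    # 2. The azimuth step is minus the golden angle, modulo 2*pi:
    golden_angle = np.pi * (3 - 5**0.5)
    step = (np.pi * (1 + 5**0.5)) % (2 * np.pi)
    print(np.isclose(step, 2 * np.pi - golden_angle))  # True

    # 3. All the points lie on the unit sphere:
    pts = np.stack([np.cos(theta) * np.sin(phi), np.sin(theta) * np.sin(phi), z], axis=1)
    print(np.allclose(np.linalg.norm(pts, axis=1), 1))  # True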








.. GENERATED FROM PYTHON SOURCE LINES 114-116

Simple (slow) display routine:


.. GENERATED FROM PYTHON SOURCE LINES 116-132

.. code-block:: Python



    def display_cloud(ax, measure, color):
        w_i, x_i = numpy(measure[0]), numpy(measure[1])

        ax.view_init(elev=110, azim=-90)
        # ax.set_aspect('equal')

        weights = w_i / w_i.sum()
        ax.scatter(x_i[:, 0], x_i[:, 1], x_i[:, 2], s=25 * 500 * weights, c=color)

        ax.axes.set_xlim3d(left=-1.4, right=1.4)
        ax.axes.set_ylim3d(bottom=-1.4, top=1.4)
        ax.axes.set_zlim3d(bottom=-1.4, top=1.4)









.. GENERATED FROM PYTHON SOURCE LINES 133-134

Save the output as a folder of VTK files, to be rendered with ParaView:

.. GENERATED FROM PYTHON SOURCE LINES 134-150

.. code-block:: Python


    folder = "output/wasserstein_3D/"
    os.makedirs(folder, exist_ok=True)

    import pyvista as pv


    def save_vtk(fname, points, colors):
        """Saves a colored point cloud as a VTK file in `folder`.

        N.B.: ParaView is a good VTK viewer, which supports ray-tracing.
        """

        # Use PyVista to save the point cloud as a VTK file:
        cloud = pv.PolyData(points)
        cloud["colors"] = colors
        cloud.save(os.path.join(folder, fname))

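If PyVista is not available, a colored point cloud can also be written in the legacy ASCII VTK format, which ParaView reads as well. The helper ``save_vtk_ascii`` below is a hypothetical, dependency-free sketch: plain text, no compression, and one ``VERTICES`` cell per point so that the points get rendered. It is not what ``save_vtk`` does internally.

.. code-block:: Python

    import os
    import tempfile


    def save_vtk_ascii(path, points, colors):
        """Hypothetical writer: legacy ASCII VTK POLYDATA with one scalar per point."""
        n = len(points)
        with open(path, "w") as f:
            f.write("# vtk DataFile Version 3.0\n")
            f.write("Colored point cloud\n")
            f.write("ASCII\n")
            f.write("DATASET POLYDATA\n")
            f.write(f"POINTS {n} float\n")
            for x, y, z in points:
                f.write(f"{x} {y} {z}\n")
            # One vertex cell per point, "1 <point index>": 2 integers per cell.
            f.write(f"VERTICES {n} {2 * n}\n")
            for i in range(n):
                f.write(f"1 {i}\n")
            # Per-point scalar field, named "colors" as in save_vtk:
            f.write(f"POINT_DATA {n}\n")
            f.write("SCALARS colors float 1\n")
            f.write("LOOKUP_TABLE default\n")
            for c in colors:
                f.write(f"{c}\n")


    path = os.path.join(tempfile.mkdtemp(), "frame_0.vtk")
    save_vtk_ascii(path, [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)], [0.25, 0.75])
    with open(path) as f:
        print(f.read().splitlines()[4])  # POINTS 2 float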








.. GENERATED FROM PYTHON SOURCE LINES 151-155

Data
----------------

Shall we work on subsampled data or at full resolution?

.. GENERATED FROM PYTHON SOURCE LINES 155-163

.. code-block:: Python


    fast_demo = not use_cuda  # Set to True to run a quick demo on the GPU as well

    if use_cuda:
        Npoints = 1e4 if fast_demo else 2e5
    else:
        Npoints = 1e3








.. GENERATED FROM PYTHON SOURCE LINES 164-165

Create a reference template:

.. GENERATED FROM PYTHON SOURCE LINES 165-168

.. code-block:: Python


    template = create_sphere(Npoints)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/code/geomloss/geomloss/examples/optimal_transport/plot_interpolation_3D.py:110: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
      return tensor(weights), tensor(points)




.. GENERATED FROM PYTHON SOURCE LINES 169-171

Use color labels to track the particles:


.. GENERATED FROM PYTHON SOURCE LINES 171-177

.. code-block:: Python


    K = 12
    colors = (K * template[1][:, 0]).cos()
    colors = numpy(colors.view(-1))









.. GENERATED FROM PYTHON SOURCE LINES 178-180

Fetch the data:


.. GENERATED FROM PYTHON SOURCE LINES 180-198

.. code-block:: Python



    os.makedirs("data", exist_ok=True)
    if not os.path.exists("data/wasserstein_3D_models/Stanford_dragon_200k.ply"):
        print("Fetching the data... ", end="", flush=True)
        import urllib.request

        urllib.request.urlretrieve(
            "http://www.kernel-operations.io/data/wasserstein_3D_models.zip",
            "data/wasserstein_3D_models.zip",
        )

        import shutil

        shutil.unpack_archive("data/wasserstein_3D_models.zip", "data")
        print("Done.")









.. GENERATED FROM PYTHON SOURCE LINES 199-200

Load the data (on the GPU, if one is available):

.. GENERATED FROM PYTHON SOURCE LINES 200-210

.. code-block:: Python



    print("Loading the data:")
    # N.B.: Since Plyfile is far from being optimized, this may take some time!
    targets = [
        load_ply_file("data/wasserstein_3D_models/Stanford_dragon_200k.ply"),
        load_ply_file("data/wasserstein_3D_models/vertebrae_400k_biol260_sketchfab_CC.ply"),
        load_nii_file("data/wasserstein_3D_models/brain.nii.gz"),
    ]





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loading the data:
    File loaded, and encoded as the weighted sum of 202,520 atoms in 3D.
    File loaded, and encoded as the weighted sum of 394,237 atoms in 3D.
    File loaded, and encoded as the weighted sum of 382,074 atoms in 3D.




.. GENERATED FROM PYTHON SOURCE LINES 211-212

Subsample every shape if required, then normalize it:

.. GENERATED FROM PYTHON SOURCE LINES 212-239

.. code-block:: Python



    def normalize(measure, n=None):
        """Subsamples a point cloud to at most n points, then normalizes its weights, center and scale."""
        weights, locations = measure
        N = len(weights)

        if n is not None and n < N:
            n = int(n)
            indices = torch.randperm(N)
            indices = indices[:n]
            weights, locations = weights[indices], locations[indices]

        weights = weights / weights.sum()
        weights, locations = weights.contiguous(), locations.contiguous()

        # Center and rescale the point cloud (out of place, to leave the input tensor untouched):
        mean = (weights.view(-1, 1) * locations).sum(dim=0)
        locations = locations - mean
        std = (weights.view(-1) * (locations**2).sum(dim=1)).sum().sqrt()
        locations = locations / std

        return weights, locations


    targets = [normalize(t, n=Npoints) for t in targets]

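``normalize`` centers each cloud on its weighted mean and rescales it so that its weighted mean squared norm equals one: with weights ``w_i`` that sum to 1, we get ``sum_i w_i x_i = 0`` and ``sum_i w_i |x_i|^2 = 1``. The NumPy transcription below (a hypothetical ``normalize_np``, without the subsampling step) checks both identities on a random weighted cloud.

.. code-block:: Python

    import numpy as np


    def normalize_np(weights, locations):
        """Hypothetical NumPy transcription of the normalization step of `normalize`."""
        weights = weights / weights.sum()
        mean = (weights[:, None] * locations).sum(axis=0)  # weighted barycenter
        locations = locations - mean
        std = np.sqrt((weights * (locations**2).sum(axis=1)).sum())  # weighted RMS radius
        return weights, locations / std


    rng = np.random.default_rng(0)
    w, x = normalize_np(rng.random(500), 3.0 * rng.standard_normal((500, 3)) + 7.0)

    print(np.allclose((w[:, None] * x).sum(axis=0), 0))  # True: the cloud is centered
    print(np.isclose((w * (x**2).sum(axis=1)).sum(), 1))  # True: unit spread

Subsampling happens before this step, so the weights of the retained points are renormalized to sum to 1 before the moments are computed.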







.. GENERATED FROM PYTHON SOURCE LINES 240-241

Fine tuning: adjust the position, scale and orientation of the shapes:

.. GENERATED FROM PYTHON SOURCE LINES 241-250

.. code-block:: Python


    # Smaller sphere, shifted towards the back of the dragon:
    template = template[0], template[1] / 2 + tensor([0.5, 0.0, 0.0])
    # Reorient the vertebra by swapping its x and z axes:
    targets[1] = targets[1][0], targets[1][1] @ tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
    # Flip the brain through the origin:
    targets[2] = targets[2][0], -targets[2][1]

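The matrix used to reorient the vertebra is a permutation matrix that swaps the x and z coordinates of every point. Its determinant is -1, so it is a mirror reflection (across the plane ``x = z``) rather than a proper rotation; the point reflection ``x -> -x`` applied to the brain also reverses orientation in 3D. A quick NumPy check:

.. code-block:: Python

    import numpy as np

    P = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=float)
    x = np.array([[1.0, 2.0, 3.0]])  # a single point, stored as a row vector

    print(x @ P)  # [[3. 2. 1.]]: the x and z coordinates are swapped
    print(round(np.linalg.det(P)))  # -1: orientation-reversing
    print(round(np.linalg.det(-np.eye(3))))  # -1: the brain flip reverses orientation too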







.. GENERATED FROM PYTHON SOURCE LINES 251-255

Optimal Transport matchings
--------------------------------

Define our solver:

.. GENERATED FROM PYTHON SOURCE LINES 255-288

.. code-block:: Python



    import time

    Loss = SamplesLoss("sinkhorn", p=2, blur=0.01, scaling=0.5, truncate=1)


    def OT_registration(source, target, name):
        a, x = source  # weights, locations
        b, y = target  # weights, locations

        x.requires_grad = True
        z = x.clone()  # Moving point cloud

        if use_cuda:
            torch.cuda.synchronize()
        start = time.time()

        nits = 4 if fast_demo else 10

        for it in range(nits):
            wasserstein_zy = Loss(a, z, b, y)
            [grad_z] = torch.autograd.grad(wasserstein_zy, [z])
            z -= grad_z / a[:, None]  # Apply the regularized Brenier map

            # save_vtk(f"matching_{name}_it_{it}.vtk", numpy(z), colors)

        end = time.time()
        print("Registered {} in {:.3f}s.".format(name, end - start))

        return z

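The update ``z -= grad_z / a[:, None]`` acts as a transport map because, for ``p=2``, GeomLoss uses the cost ``|x - y|^2 / 2``: the gradient of the entropic OT cost with respect to a source point ``x_i`` is then ``sum_j pi_ij (x_i - y_j)``, where ``pi`` is the blurred optimal transport plan, whose rows sum to ``a_i``. Dividing by ``a_i`` and taking one gradient step sends ``x_i`` to the barycentric projection ``sum_j pi_ij y_j / a_i``. The NumPy sketch below illustrates this on a toy problem with a plain log-domain Sinkhorn loop. It is a simplification under stated assumptions: ``sinkhorn_plan`` is a hypothetical helper, not a GeomLoss function, and it ignores both the debiasing terms of the Sinkhorn divergence and the multiscale options (``scaling``, ``truncate``) used by ``SamplesLoss``. Its ``eps`` parameter is the entropic temperature, which GeomLoss derives from ``blur ** p``.

.. code-block:: Python

    import numpy as np


    def logsumexp(M, axis):
        """Numerically stable log(sum(exp(M))) along one axis."""
        M_max = M.max(axis=axis, keepdims=True)
        return np.squeeze(M_max, axis) + np.log(np.exp(M - M_max).sum(axis=axis))


    def sinkhorn_plan(a, x, b, y, eps, n_iter=100):
        """Hypothetical helper: entropic OT plan for the cost |x - y|^2 / 2."""
        C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / 2
        f, g = np.zeros(len(a)), np.zeros(len(b))
        for _ in range(n_iter):  # log-domain Sinkhorn updates of the dual potentials
            f = -eps * logsumexp(np.log(b)[None, :] + (g[None, :] - C) / eps, axis=1)
            g = -eps * logsumexp(np.log(a)[:, None] + (f[:, None] - C) / eps, axis=0)
        return np.exp(np.log(a)[:, None] + np.log(b)[None, :] + (f[:, None] + g[None, :] - C) / eps)


    # Toy problem: four points, and the same points shifted by a small vector.
    x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    y = x + np.array([0.1, 0.05])
    a = b = np.full(4, 0.25)

    pi = sinkhorn_plan(a, x, b, y, eps=0.01)
    grad_x = pi.sum(axis=1)[:, None] * x - pi @ y  # sum_j pi_ij (x_i - y_j)
    z = x - grad_x / a[:, None]  # one step of the update used in OT_registration
    print(np.allclose(z, y))  # True: each point lands on its shifted copy

Because the plan is blurred at the scale of the temperature, the projection averages over all the targets that lie within the blur radius of the optimal match, which is why a small ``blur`` is used relative to the normalized shapes.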








.. GENERATED FROM PYTHON SOURCE LINES 289-291

Register the source onto the targets:


.. GENERATED FROM PYTHON SOURCE LINES 291-297

.. code-block:: Python


    matchings = [
        OT_registration(template, target, f"shape{i+1}")
        for (i, target) in enumerate(targets)
    ]





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Registered shape1 in 3.134s.
    Registered shape2 in 3.562s.
    Registered shape3 in 2.661s.




.. GENERATED FROM PYTHON SOURCE LINES 298-299

Display our matchings:

.. GENERATED FROM PYTHON SOURCE LINES 299-316

.. code-block:: Python


    for i, (matching, target) in enumerate(zip(matchings, targets)):
        fig = plt.figure(figsize=(6, 6))
        plt.set_cmap("hsv")

        ax = fig.add_subplot(1, 1, 1, projection="3d")

        display_cloud(ax, (template[0], matching), colors)
        display_cloud(ax, target, "blue")
        ax.set_title(
            "Registered (N={:,}) and target {} (M={:,}) point clouds".format(
                len(matching), i + 1, len(target[0])
            )
        )
        plt.tight_layout()





.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_001.png
         :alt: Registered (N=200,000) and target 1 (M=200,000) point clouds
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_002.png
         :alt: Registered (N=200,000) and target 2 (M=200,000) point clouds
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_003.png
         :alt: Registered (N=200,000) and target 3 (M=200,000) point clouds
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_003.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 317-321

Movie
-------------

Save them as a collection of VTK files:

.. GENERATED FROM PYTHON SOURCE LINES 321-347

.. code-block:: Python


    FPS = 32  # Number of frames per transition

    source = template[1]
    pairs = [
        (source, source),
        (source, matchings[0]),
        (matchings[0], matchings[0]),
        (matchings[0], matchings[1]),
        (matchings[1], matchings[1]),
        (matchings[1], matchings[2]),
        (matchings[2], matchings[2]),
        (matchings[2], source),
    ]

    frame = 0

    print("Save as a VTK movie...", end="", flush=True)
    for A, B in pairs:
        A, B = numpy(A), numpy(B)
        for t in np.linspace(0, 1, FPS):
            save_vtk(f"frame_{frame}.vtk", (1 - t) * A + t * B, colors)
            frame += 1

    print("Done.")
    plt.show()




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Save as a VTK movie...Done.


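Each pair ``(A, B)`` in the movie loop is rendered as ``FPS`` frames of the particle-wise linear interpolation ``(1 - t) A + t B``, for ``t`` evenly spaced in [0, 1]. Since every matching is a displaced copy of the same template, particle ``i`` of ``A`` corresponds to particle ``i`` of ``B``, and each colored particle moves along a straight line. The NumPy sketch below reproduces the frame schedule on hypothetical random clouds, without writing any file, and counts the frames.

.. code-block:: Python

    import numpy as np

    FPS = 32
    rng = np.random.default_rng(0)
    source = rng.standard_normal((100, 3))  # hypothetical stand-ins for the clouds
    matchings = [rng.standard_normal((100, 3)) for _ in range(3)]

    # Same schedule as above: hold each shape, then morph it into the next one.
    shapes = [source] + matchings
    pairs = []
    for current, following in zip(shapes, shapes[1:] + [source]):
        pairs += [(current, current), (current, following)]

    frames = [(1 - t) * A + t * B for A, B in pairs for t in np.linspace(0, 1, FPS)]
    print(len(frames))  # 256 = 8 pairs x 32 frames
    print(np.allclose(frames[0], source), np.allclose(frames[-1], source))  # True True

Because ``np.linspace`` includes both endpoints, the last frame of one pair and the first frame of the next one coincide. Passing ``endpoint=False`` would avoid this duplicate, at the cost of never showing ``B`` exactly within its own pair.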



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 42.690 seconds)


.. _sphx_glr_download__auto_examples_optimal_transport_plot_interpolation_3D.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_interpolation_3D.ipynb <plot_interpolation_3D.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_interpolation_3D.py <plot_interpolation_3D.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_