.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/optimal_transport/plot_interpolation_3D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_optimal_transport_plot_interpolation_3D.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_optimal_transport_plot_interpolation_3D.py:


Creating a fancy interpolation video between 3D meshes.
==============================================================

N.B.: I am currently very busy writing my PhD thesis. Comments will come soon!

.. GENERATED FROM PYTHON SOURCE LINES 9-14

Setup
----------------------

Standard imports.

.. GENERATED FROM PYTHON SOURCE LINES 14-31

.. code-block:: Python


    import numpy as np
    import torch
    import os

    use_cuda = torch.cuda.is_available()
    tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    numpy = lambda x: x.detach().cpu().numpy()

    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    from geomloss import SamplesLoss
    from pykeops.torch import LazyTensor

.. GENERATED FROM PYTHON SOURCE LINES 32-33

Utility: turn a triangle mesh into a weighted point cloud.

.. GENERATED FROM PYTHON SOURCE LINES 33-55

.. code-block:: Python



    def to_measure(points, triangles):
        """Turns a triangle mesh into a weighted point cloud."""

        # Our mesh is given as a collection of ABC triangles:
        A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

        # Locations and weights of our Dirac atoms:
        X = (A + B + C) / 3  # centers of the faces
        S = np.sqrt(np.sum(np.cross(B - A, C - A) ** 2, 1)) / 2  # areas of the faces

        print(
            "File loaded, and encoded as the weighted sum of {:,} atoms in 3D.".format(
                len(X)
            )
        )

        # We return a (normalized) vector of weights + a "list" of points
        return tensor(S / np.sum(S)), tensor(X)

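
To see what ``to_measure`` computes, the same centroid and area formulas can be run on a toy mesh: the unit square, split into two right triangles. The NumPy-only sketch below uses a hypothetical helper, ``triangle_centers_and_areas``, which is not part of this example, and swaps the explicit square root for ``np.linalg.norm``. Each face should get an area of 1/2, hence a weight of 1/2, and the weighted mean of the two face centers should land on the center of the square.

.. code-block:: Python

    import numpy as np


    def triangle_centers_and_areas(points, triangles):
        """Hypothetical helper: face centroids and face areas of a triangle mesh."""
        A = points[triangles[:, 0]]
        B = points[triangles[:, 1]]
        C = points[triangles[:, 2]]
        centers = (A + B + C) / 3
        # Half the norm of the cross product (B - A) x (C - A) is the face area:
        areas = np.linalg.norm(np.cross(B - A, C - A), axis=1) / 2
        return centers, areas


    # Unit square in the z = 0 plane, split into two triangles:
    points = np.array(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
    )
    triangles = np.array([[0, 1, 2], [0, 2, 3]])

    centers, areas = triangle_centers_and_areas(points, triangles)
    weights = areas / areas.sum()  # Normalized weights, as in to_measure

    print(areas)  # [0.5 0.5]
    print(weights @ centers)  # Weighted barycenter: the center of the square

Since the weights are proportional to the face areas, the resulting measure approximates the uniform surface measure of the mesh, whatever the local density of the triangulation.
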

.. GENERATED FROM PYTHON SOURCE LINES 56-58

Utility: load ".ply" mesh file.

.. GENERATED FROM PYTHON SOURCE LINES 58-75

.. code-block:: Python


    from plyfile import PlyData, PlyElement


    def load_ply_file(fname):
        """Loads a .ply mesh to return a collection of weighted Dirac atoms: one per triangle face."""

        # Load the data, and read the connectivity information:
        plydata = PlyData.read(fname)
        triangles = np.vstack(plydata["face"].data["vertex_indices"])

        # Read the vertex coordinates as an (N, 3) array:
        points = np.vstack([[v[0], v[1], v[2]] for v in plydata["vertex"]])

        return to_measure(points, triangles)

.. GENERATED FROM PYTHON SOURCE LINES 76-78

Utility: load ".nii" volume file.

.. GENERATED FROM PYTHON SOURCE LINES 78-93

.. code-block:: Python


    import SimpleITK as sitk
    from skimage.measure import marching_cubes


    def load_nii_file(fname, threshold=0.5):
        """Uses the marching cubes algorithm to turn a .nii binary mask into a surface weighted point cloud."""

        mask = sitk.GetArrayFromImage(sitk.ReadImage(fname))
        # mask = skimage.transform.downscale_local_mean(mask, (4,4,4))
        verts, faces, normals, values = marching_cubes(mask, threshold)

        return to_measure(verts, faces)

.. GENERATED FROM PYTHON SOURCE LINES 94-96

Synthetic sphere - a typical source measure:

.. GENERATED FROM PYTHON SOURCE LINES 96-113

.. code-block:: Python



    def create_sphere(n_samples=1000):
        """Creates a quasi-uniform sample on the unit sphere (golden-angle spiral)."""
        n_samples = int(n_samples)

        indices = np.arange(0, n_samples, dtype=float) + 0.5
        phi = np.arccos(1 - 2 * indices / n_samples)
        theta = np.pi * (1 + 5**0.5) * indices

        x, y, z = np.cos(theta) * np.sin(phi), np.sin(theta) * np.sin(phi), np.cos(phi)
        points = np.vstack((x, y, z)).T
        weights = np.ones(n_samples) / n_samples

        return tensor(weights), tensor(points)

.. GENERATED FROM PYTHON SOURCE LINES 114-116

Simple (slow) display routine:

.. GENERATED FROM PYTHON SOURCE LINES 116-132


.. code-block:: Python



    def display_cloud(ax, measure, color):
        w_i, x_i = numpy(measure[0]), numpy(measure[1])

        ax.view_init(elev=110, azim=-90)
        # ax.set_aspect('equal')

        weights = w_i / w_i.sum()
        ax.scatter(x_i[:, 0], x_i[:, 1], x_i[:, 2], s=25 * 500 * weights, c=color)

        ax.axes.set_xlim3d(left=-1.4, right=1.4)
        ax.axes.set_ylim3d(bottom=-1.4, top=1.4)
        ax.axes.set_zlim3d(bottom=-1.4, top=1.4)

.. GENERATED FROM PYTHON SOURCE LINES 133-134

Save the output as a folder of VTK files, to be rendered with ParaView:

.. GENERATED FROM PYTHON SOURCE LINES 134-150

.. code-block:: Python


    folder = "output/wasserstein_3D/"
    os.makedirs(folder, exist_ok=True)

    import pyvista as pv


    def save_vtk(fname, points, colors):
        """N.B.: ParaView is a good VTK viewer, which supports ray-tracing."""

        # Use PyVista to save the point cloud as a VTK file:
        points = pv.PolyData(points)
        points["colors"] = colors
        points.save(folder + fname)

.. GENERATED FROM PYTHON SOURCE LINES 151-155

Data
----------------

Shall we work on subsampled data or at full resolution?

.. GENERATED FROM PYTHON SOURCE LINES 155-163

.. code-block:: Python


    fast_demo = not use_cuda  # Quick, low-resolution run when no GPU is available

    if use_cuda:
        Npoints = 1e4 if fast_demo else 2e5
    else:
        Npoints = 1e3

.. GENERATED FROM PYTHON SOURCE LINES 164-165

Create a reference template:

.. GENERATED FROM PYTHON SOURCE LINES 165-168

.. code-block:: Python


    template = create_sphere(Npoints)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/code/geomloss/geomloss/examples/optimal_transport/plot_interpolation_3D.py:110: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
      return tensor(weights), tensor(points)

.. GENERATED FROM PYTHON SOURCE LINES 169-171

Use color labels to track the particles:

.. GENERATED FROM PYTHON SOURCE LINES 171-177


.. code-block:: Python


    # Periodic color labels, computed from the x-coordinates of the template points:
    K = 12
    colors = (K * template[1][:, 0]).cos()
    colors = colors.view(-1).detach().cpu().numpy()

.. GENERATED FROM PYTHON SOURCE LINES 178-180

Fetch the data:

.. GENERATED FROM PYTHON SOURCE LINES 180-198

.. code-block:: Python


    os.makedirs("data", exist_ok=True)
    if not os.path.exists("data/wasserstein_3D_models/Stanford_dragon_200k.ply"):
        print("Fetching the data... ", end="", flush=True)

        import urllib.request

        urllib.request.urlretrieve(
            "http://www.kernel-operations.io/data/wasserstein_3D_models.zip",
            "data/wasserstein_3D_models.zip",
        )

        import shutil

        shutil.unpack_archive("data/wasserstein_3D_models.zip", "data")

        print("Done.")

.. GENERATED FROM PYTHON SOURCE LINES 199-200

Load the data (on the GPU, if one is available):

.. GENERATED FROM PYTHON SOURCE LINES 200-210

.. code-block:: Python


    print("Loading the data:")
    # N.B.: Since Plyfile is far from being optimized, this may take some time!
    targets = [
        load_ply_file("data/wasserstein_3D_models/Stanford_dragon_200k.ply"),
        load_ply_file("data/wasserstein_3D_models/vertebrae_400k_biol260_sketchfab_CC.ply"),
        load_nii_file("data/wasserstein_3D_models/brain.nii.gz"),
    ]


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loading the data:
    File loaded, and encoded as the weighted sum of 202,520 atoms in 3D.
    File loaded, and encoded as the weighted sum of 394,237 atoms in 3D.
    File loaded, and encoded as the weighted sum of 382,074 atoms in 3D.

.. GENERATED FROM PYTHON SOURCE LINES 211-212

Normalize every target, and subsample it if required:

.. GENERATED FROM PYTHON SOURCE LINES 212-239


.. code-block:: Python



    def normalize(measure, n=None):
        """Reduce a point cloud to at most n points and normalize the weights and point cloud."""
        weights, locations = measure
        N = len(weights)

        if n is not None and n < N:
            n = int(n)
            indices = torch.randperm(N)
            indices = indices[:n]
            weights, locations = weights[indices], locations[indices]
        weights = weights / weights.sum()

        weights, locations = weights.contiguous(), locations.contiguous()

        # Center the point cloud, and scale it to a unit (weighted) RMS radius:
        mean = (weights.view(-1, 1) * locations).sum(dim=0)
        locations -= mean
        std = (weights.view(-1) * (locations**2).sum(dim=1).view(-1)).sum().sqrt()
        locations /= std

        return weights, locations


    targets = [normalize(t, n=Npoints) for t in targets]

.. GENERATED FROM PYTHON SOURCE LINES 240-241

Fine tuning: shrink and shift the template, and reorient two of the targets:

.. GENERATED FROM PYTHON SOURCE LINES 241-250

.. code-block:: Python


    template = template[0], template[1] / 2 + tensor(
        [0.5, 0.0, 0.0]
    )  # Smaller sphere, towards the back of the dragon
    targets[1] = targets[1][0], targets[1][1] @ tensor(
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
    )  # Turn the vertebra
    targets[2] = targets[2][0], -targets[2][1]  # Flip the brain

.. GENERATED FROM PYTHON SOURCE LINES 251-255

Optimal Transport matchings
--------------------------------

Define our solver:

.. GENERATED FROM PYTHON SOURCE LINES 255-288

.. code-block:: Python


    import time

    Loss = SamplesLoss("sinkhorn", p=2, blur=0.01, scaling=0.5, truncate=1)


    def OT_registration(source, target, name):
        a, x = source  # weights, locations
        b, y = target  # weights, locations

        x.requires_grad = True
        z = x.clone()  # Moving point cloud

        if use_cuda:
            torch.cuda.synchronize()
        start = time.time()

        nits = 4 if fast_demo else 10

        for it in range(nits):
            wasserstein_zy = Loss(a, z, b, y)
            [grad_z] = torch.autograd.grad(wasserstein_zy, [z])
            z -= grad_z / a[:, None]  # Apply the regularized Brenier map
            # save_vtk(f"matching_{name}_it_{it}.vtk", numpy(z), colors)

        if use_cuda:
            torch.cuda.synchronize()  # Wait for pending GPU work before stopping the clock
        end = time.time()
        print("Registered {} in {:.3f}s.".format(name, end - start))

        return z

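
The update ``z -= grad_z / a[:, None]`` can be unpacked on a smaller problem. For the plain entropic transport cost with ground cost C(x, y) = |x - y|² / 2 (the cost that GeomLoss uses for ``p=2``), the gradient with respect to a source point x_i is, at convergence of the Sinkhorn loop, a_i (x_i - T(x_i)), where T(x_i) = Σ_j π_ij y_j / a_i is the barycentric projection of the transport plan π. Dividing by a_i and subtracting therefore moves each point onto T(x_i). The sketch below is a stand-in under explicit assumptions: it runs a naive, dense, log-domain Sinkhorn loop in NumPy (the helpers ``logsumexp`` and ``sinkhorn_plan`` are hypothetical), not the debiased, multiscale KeOps solver behind ``SamplesLoss``, whose debiasing adds a correction term; it also computes T directly from the plan instead of going through autograd.

.. code-block:: Python

    import numpy as np


    def logsumexp(M, axis):
        """Numerically stable log(sum(exp(M))) along one axis."""
        m = M.max(axis=axis, keepdims=True)
        return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)


    def sinkhorn_plan(a, x, b, y, eps=0.01, n_iter=200):
        """Hypothetical helper: entropic OT plan between sum_i a_i delta_{x_i} and
        sum_j b_j delta_{y_j} for the cost |x - y|^2 / 2, via log-domain Sinkhorn."""
        C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / 2
        f, g = np.zeros(len(a)), np.zeros(len(b))
        for _ in range(n_iter):
            # Alternately enforce the row (a) and column (b) marginal constraints:
            f = -eps * logsumexp(np.log(b)[None, :] + (g[None, :] - C) / eps, axis=1)
            g = -eps * logsumexp(np.log(a)[:, None] + (f[:, None] - C) / eps, axis=0)
        # Transport plan: pi_ij = a_i * b_j * exp((f_i + g_j - C_ij) / eps)
        return a[:, None] * b[None, :] * np.exp((f[:, None] + g[None, :] - C) / eps)


    # Toy data: a 3 x 3 grid in the z = 0 plane, and a translated copy of it.
    u = np.arange(3.0)
    gx, gy = np.meshgrid(u, u)
    x = np.stack([gx.ravel(), gy.ravel(), np.zeros(9)], axis=1)
    y = x + np.array([0.3, 0.2, 0.0])
    a = np.full(len(x), 1 / len(x))
    b = np.full(len(y), 1 / len(y))

    plan = sinkhorn_plan(a, x, b, y)
    z = plan @ y / a[:, None]  # Barycentric projection T(x_i), i.e. one "Brenier" step
    print(np.abs(z - y).max())  # Tiny: each grid point lands on its translated copy

With well-separated points and a small ``eps``, the plan is close to a permutation, so T recovers the translation up to a tiny entropic blur; as ``eps`` grows, the plan tends to the product a ⊗ b and every T(x_i) collapses onto the weighted mean of the target.
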

.. GENERATED FROM PYTHON SOURCE LINES 289-291

Register the source onto the targets:

.. GENERATED FROM PYTHON SOURCE LINES 291-297

.. code-block:: Python


    matchings = [
        OT_registration(template, target, f"shape{i+1}")
        for (i, target) in enumerate(targets)
    ]


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Registered shape1 in 3.134s.
    Registered shape2 in 3.562s.
    Registered shape3 in 2.661s.

.. GENERATED FROM PYTHON SOURCE LINES 298-299

Display our matchings:

.. GENERATED FROM PYTHON SOURCE LINES 299-316

.. code-block:: Python


    for i, (matching, target) in enumerate(zip(matchings, targets)):
        fig = plt.figure(figsize=(6, 6))
        plt.set_cmap("hsv")

        ax = fig.add_subplot(1, 1, 1, projection="3d")
        display_cloud(ax, (template[0], matching), colors)
        display_cloud(ax, target, "blue")
        ax.set_title(
            "Registered (N={:,}) and target {} (M={:,}) point clouds".format(
                len(matching), i + 1, len(target[0])
            )
        )
        plt.tight_layout()



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_001.png
          :alt: Registered (N=200,000) and target 1 (M=200,000) point clouds
          :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_001.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_002.png
          :alt: Registered (N=200,000) and target 2 (M=200,000) point clouds
          :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_002.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_003.png
          :alt: Registered (N=200,000) and target 3 (M=200,000) point clouds
          :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_003.png
          :class: sphx-glr-multi-img


.. GENERATED FROM PYTHON SOURCE LINES 317-321

Movie
-------------

Save them as a collection of VTK files:

.. GENERATED FROM PYTHON SOURCE LINES 321-347


.. code-block:: Python


    FPS = 32  # Number of frames per transition

    source = template[1]
    pairs = [
        (source, source),
        (source, matchings[0]),
        (matchings[0], matchings[0]),
        (matchings[0], matchings[1]),
        (matchings[1], matchings[1]),
        (matchings[1], matchings[2]),
        (matchings[2], matchings[2]),
        (matchings[2], source),
    ]
    frame = 0

    print("Save as a VTK movie...", end="", flush=True)
    for A, B in pairs:
        A, B = numpy(A), numpy(B)
        for t in np.linspace(0, 1, FPS):
            # Linear interpolation between the two point clouds:
            save_vtk(f"frame_{frame}.vtk", (1 - t) * A + t * B, colors)
            frame += 1

    print("Done.")
    plt.show()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Save as a VTK movie...Done.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 42.690 seconds)


.. _sphx_glr_download__auto_examples_optimal_transport_plot_interpolation_3D.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_interpolation_3D.ipynb <plot_interpolation_3D.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_interpolation_3D.py <plot_interpolation_3D.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_