.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/optimal_transport/plot_interpolation_3D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_optimal_transport_plot_interpolation_3D.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_optimal_transport_plot_interpolation_3D.py:


Creating a fancy interpolation video between 3D meshes.
==============================================================

N.B.: I am currently very busy writing my PhD thesis. Comments will come soon!

.. GENERATED FROM PYTHON SOURCE LINES 9-14

Setup
----------------------

Standard imports.

.. GENERATED FROM PYTHON SOURCE LINES 14-31

.. code-block:: default


    import numpy as np
    import torch
    import os

    use_cuda = torch.cuda.is_available()
    tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    numpy = lambda x: x.detach().cpu().numpy()

    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    from geomloss import SamplesLoss
    from pykeops.torch import LazyTensor


.. GENERATED FROM PYTHON SOURCE LINES 32-33

Utility: turn a triangle mesh into a weighted point cloud.

.. GENERATED FROM PYTHON SOURCE LINES 33-55

.. code-block:: default



    def to_measure(points, triangles):
        """Turns a triangle mesh into a weighted point cloud.

        Each triangle face becomes one Dirac atom, located at the centroid of
        the face and weighted by its area; the weights are normalized to sum to 1.

        Args:
            points (array): vertex coordinates, indexed by the entries of `triangles`.
            triangles ((F, 3) integer array): vertex indices of the F faces.

        Returns:
            (weights, locations): float tensors with F rows each.
        """

        # Our mesh is given as a collection of ABC triangles:
        A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

        # Locations and weights of our Dirac atoms:
        X = (A + B + C) / 3  # centers of the faces
        S = np.sqrt(np.sum(np.cross(B - A, C - A) ** 2, 1)) / 2  # areas of the faces

        print(
            "File loaded, and encoded as the weighted sum of {:,} atoms in 3D.".format(
                len(X)
            )
        )

        # We return a (normalized) vector of weights + a "list" of points
        return tensor(S / np.sum(S)), tensor(X)





.. GENERATED FROM PYTHON SOURCE LINES 56-58

Utility: load ".ply" mesh file.

.. GENERATED FROM PYTHON SOURCE LINES 58-75

.. code-block:: default


    from plyfile import PlyData, PlyElement


    def load_ply_file(fname):
        """Loads a .ply mesh to return a collection of weighted Dirac atoms: one per triangle face."""

        # Load the data, and read the connectivity information:
        plydata = PlyData.read(fname)
        triangles = np.vstack(plydata["face"].data["vertex_indices"])

        # Read the vertex coordinates, i.e. the first three properties of each vertex.
        # This is a vectorized column read instead of a Python loop over the vertices.
        # N.B.: the point cloud is *not* normalized here - this is done by normalize() below.
        vertex = plydata["vertex"].data
        points = np.stack([vertex[prop] for prop in vertex.dtype.names[:3]], axis=1)

        return to_measure(points, triangles)





.. GENERATED FROM PYTHON SOURCE LINES 76-78

Utility: load ".nii" volume file.

.. GENERATED FROM PYTHON SOURCE LINES 78-93

.. code-block:: default


    import SimpleITK as sitk
    from skimage.measure import marching_cubes


    def load_nii_file(fname, threshold=0.5):
        """Uses the marching cubes algorithm to turn a .nii binary mask into a surface weighted point cloud.

        `threshold` is the iso-level passed to marching_cubes.
        """

        mask = sitk.GetArrayFromImage(sitk.ReadImage(fname))
        # mask = skimage.transform.downscale_local_mean(mask, (4,4,4))
        verts, faces, _, _ = marching_cubes(mask, threshold)  # normals and values are unused

        return to_measure(verts, faces)





.. GENERATED FROM PYTHON SOURCE LINES 94-96

Synthetic sphere - a typical source measure:

.. GENERATED FROM PYTHON SOURCE LINES 96-113

.. code-block:: default



    def create_sphere(n_samples=1000):
        """Creates a uniform sample on the unit sphere.

        The n_samples points lie on a golden-angle (Fibonacci) spiral and
        carry uniform weights 1 / n_samples.
        """
        n_samples = int(n_samples)

        indices = np.arange(0, n_samples, dtype=float) + 0.5
        phi = np.arccos(1 - 2 * indices / n_samples)
        theta = np.pi * (1 + 5 ** 0.5) * indices

        x, y, z = np.cos(theta) * np.sin(phi), np.sin(theta) * np.sin(phi), np.cos(phi)
        points = np.vstack((x, y, z)).T
        weights = np.ones(n_samples) / n_samples

        return tensor(weights), tensor(points)





.. GENERATED FROM PYTHON SOURCE LINES 114-116

Simple (slow) display routine:

.. GENERATED FROM PYTHON SOURCE LINES 116-133

.. code-block:: default



    def display_cloud(ax, measure, color):
        """Scatter-plots a weighted point cloud on a 3D Matplotlib axis.

        `measure` is a (weights, locations) pair of tensors; marker areas are
        proportional to the renormalized weights.
        """

        w_i, x_i = numpy(measure[0]), numpy(measure[1])

        ax.view_init(elev=110, azim=-90)
        # ax.set_aspect('equal')

        weights = w_i / w_i.sum()
        ax.scatter(x_i[:, 0], x_i[:, 1], x_i[:, 2], s=25 * 500 * weights, c=color)

        ax.axes.set_xlim3d(left=-1.4, right=1.4)
        ax.axes.set_ylim3d(bottom=-1.4, top=1.4)
        ax.axes.set_zlim3d(bottom=-1.4, top=1.4)





.. GENERATED FROM PYTHON SOURCE LINES 134-135

Save the outputs as VTK files in a dedicated folder, to be rendered with Paraview:

.. GENERATED FROM PYTHON SOURCE LINES 135-152

.. code-block:: default


    folder = "output/wasserstein_3D/"
    os.makedirs(os.path.dirname(folder), exist_ok=True)

    from pyvtk import PolyData, PointData, CellData, Scalars, VtkData


    def save_vtk(fname, points, colors):
        """Saves a point cloud and its per-point "colors" scalars as a binary VTK file in `folder`.

        N.B.: Paraview is a good VTK viewer, which supports ray-tracing.
        """

        structure = PolyData(points=points, vertices=np.arange(len(points)))
        values = PointData(Scalars(colors, name="colors"))
        vtk = VtkData(structure, values)

        vtk.tofile(folder + fname, "binary")





.. GENERATED FROM PYTHON SOURCE LINES 153-157

Data
----------------

Shall we work on subsampled data or at full resolution?

.. GENERATED FROM PYTHON SOURCE LINES 157-165

.. code-block:: default


    fast_demo = not use_cuda

    if use_cuda:
        Npoints = 1e4 if fast_demo else 2e5
    else:
        Npoints = 1e3


.. GENERATED FROM PYTHON SOURCE LINES 166-167

Create a reference template:

.. GENERATED FROM PYTHON SOURCE LINES 167-170

.. code-block:: default


    template = create_sphere(Npoints)


.. GENERATED FROM PYTHON SOURCE LINES 171-173

Use color labels to track the particles:

.. GENERATED FROM PYTHON SOURCE LINES 173-179

.. code-block:: default


    K = 12
    colors = (K * template[1][:, 0]).cos()
    colors = numpy(colors.view(-1))




.. GENERATED FROM PYTHON SOURCE LINES 180-182

Fetch the data:

.. GENERATED FROM PYTHON SOURCE LINES 182-200

.. code-block:: default


    os.makedirs(os.path.dirname("data/"), exist_ok=True)

    # The three models loaded below, all shipped in the same archive:
    model_files = [
        "data/wasserstein_3D_models/Stanford_dragon_200k.ply",
        "data/wasserstein_3D_models/vertebrae_400k_biol260_sketchfab_CC.ply",
        "data/wasserstein_3D_models/brain.nii.gz",
    ]

    # Download and unpack the archive unless *every* model is already on disk:
    # checking a single file would miss e.g. an interrupted extraction.
    if not all(os.path.exists(f) for f in model_files):
        print("Fetching the data... ", end="", flush=True)

        import urllib.request

        urllib.request.urlretrieve(
            "http://www.kernel-operations.io/data/wasserstein_3D_models.zip",
            "data/wasserstein_3D_models.zip",
        )

        import shutil

        shutil.unpack_archive("data/wasserstein_3D_models.zip", "data")

        print("Done.")


.. GENERATED FROM PYTHON SOURCE LINES 201-202

Load the data (on the GPU, if one is available):

.. GENERATED FROM PYTHON SOURCE LINES 202-212

.. code-block:: default


    print("Loading the data:")
    # N.B.: Since Plyfile is far from being optimized, this may take some time!
    targets = [
        load_ply_file("data/wasserstein_3D_models/Stanford_dragon_200k.ply"),
        load_ply_file("data/wasserstein_3D_models/vertebrae_400k_biol260_sketchfab_CC.ply"),
        load_nii_file("data/wasserstein_3D_models/brain.nii.gz"),
    ]


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loading the data:
    File loaded, and encoded as the weighted sum of 202,520 atoms in 3D.
    File loaded, and encoded as the weighted sum of 394,237 atoms in 3D.
    File loaded, and encoded as the weighted sum of 382,074 atoms in 3D.




.. GENERATED FROM PYTHON SOURCE LINES 213-214

Normalize every point cloud, and subsample it if required:

.. GENERATED FROM PYTHON SOURCE LINES 214-241

.. code-block:: default



    def normalize(measure, n=None):
        """Reduce a point cloud to at most n points and normalize the weights and point cloud.

        If the cloud has more than n atoms, n of them are drawn at random.
        The weights are then rescaled to sum to 1, and the locations are centered
        on their weighted mean and rescaled to a unit weighted RMS norm.
        The input tensors are never modified in place.
        """
        weights, locations = measure
        N = len(weights)

        if n is not None and n < N:
            n = int(n)
            indices = torch.randperm(N)[:n]
            weights, locations = weights[indices], locations[indices]
        weights = weights / weights.sum()

        weights, locations = weights.contiguous(), locations.contiguous()

        # Center, normalize the point cloud.
        # Out-of-place operations: without subsampling, `locations` is still the
        # caller's tensor and must not be overwritten.
        mean = (weights.view(-1, 1) * locations).sum(dim=0)
        locations = locations - mean
        std = (weights.view(-1) * (locations ** 2).sum(dim=1).view(-1)).sum().sqrt()
        locations = locations / std

        return weights, locations


    targets = [normalize(t, n=Npoints) for t in targets]


.. GENERATED FROM PYTHON SOURCE LINES 242-243

Fine tuning:

.. GENERATED FROM PYTHON SOURCE LINES 243-252

.. code-block:: default


    template = template[0], template[1] / 2 + tensor(
        [0.5, 0.0, 0.0]
    )  # Smaller sphere, towards the back of the dragon
    targets[1] = targets[1][0], targets[1][1] @ tensor(
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
    )  # Turn the vertebra (swap the x and z axes)
    targets[2] = targets[2][0], -targets[2][1]  # Flip the brain


.. GENERATED FROM PYTHON SOURCE LINES 253-257

Optimal Transport matchings
--------------------------------

Define our solver:

.. GENERATED FROM PYTHON SOURCE LINES 257-290

.. code-block:: default


    import time

    Loss = SamplesLoss("sinkhorn", p=2, blur=0.01, scaling=0.5, truncate=1)


    def OT_registration(source, target, name):
        """Registers the source point cloud onto the target one.

        At every iteration, each source point is moved along the gradient of the
        Sinkhorn loss, divided by its weight (regularized Brenier map).

        Args:
            source ((a, x) pair): weights and locations of the moving point cloud.
            target ((b, y) pair): weights and locations of the fixed point cloud.
            name (str): label used in the timing message.

        Returns:
            Tensor: the registered locations, with the same shape as x.
        """
        a, x = source  # weights, locations
        b, y = target  # weights, locations

        # Work on a fresh leaf copy, so that the caller's `x` is left untouched:
        z = x.detach().clone().requires_grad_(True)  # Moving point cloud

        if use_cuda:
            torch.cuda.synchronize()
        start = time.time()

        nits = 4 if fast_demo else 10

        for it in range(nits):
            wasserstein_zy = Loss(a, z, b, y)
            [grad_z] = torch.autograd.grad(wasserstein_zy, [z])
            with torch.no_grad():  # in-place update of a leaf tensor
                z -= grad_z / a[:, None]  # Apply the regularized Brenier map

            # save_vtk(f"matching_{name}_it_{it}.vtk", numpy(z), colors)

        # CUDA kernels run asynchronously: wait for them before stopping the clock.
        if use_cuda:
            torch.cuda.synchronize()
        end = time.time()
        print("Registered {} in {:.3f}s.".format(name, end - start))

        return z





.. GENERATED FROM PYTHON SOURCE LINES 291-293

Register the source onto the targets:

.. GENERATED FROM PYTHON SOURCE LINES 293-299

.. code-block:: default


    matchings = [
        OT_registration(template, target, f"shape{i+1}")
        for (i, target) in enumerate(targets)
    ]


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Registered shape1 in 8.594s.
    Registered shape2 in 17.520s.
    Registered shape3 in 4.895s.




.. GENERATED FROM PYTHON SOURCE LINES 300-301

Display our matchings:

.. GENERATED FROM PYTHON SOURCE LINES 301-319

.. code-block:: default


    for (i, (matching, target)) in enumerate(zip(matchings, targets)):
        fig = plt.figure(figsize=(6, 6))
        plt.set_cmap("hsv")

        ax = fig.add_subplot(1, 1, 1, projection="3d")
        display_cloud(ax, (template[0], matching), colors)
        display_cloud(ax, target, "blue")
        ax.set_title(
            "Registered (N={:,}) and target {} (M={:,}) point clouds".format(
                len(matching), i + 1, len(target[0])
            )
        )
        plt.tight_layout()




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_001.png
         :alt: Registered (N=200,000) and target 1 (M=200,000) point clouds
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_002.png
         :alt: Registered (N=200,000) and target 2 (M=200,000) point clouds
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_003.png
         :alt: Registered (N=200,000) and target 3 (M=200,000) point clouds
         :srcset: /_auto_examples/optimal_transport/images/sphx_glr_plot_interpolation_3D_003.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 320-324

Movie
-------------

Save them as a collection of VTK files:

.. GENERATED FROM PYTHON SOURCE LINES 324-350

.. code-block:: default


    FPS = 32  # frames per interpolation segment

    source = template[1]
    pairs = [
        (source, source),
        (source, matchings[0]),
        (matchings[0], matchings[0]),
        (matchings[0], matchings[1]),
        (matchings[1], matchings[1]),
        (matchings[1], matchings[2]),
        (matchings[2], matchings[2]),
        (matchings[2], source),
    ]
    frame = 0

    print("Save as a VTK movie...", end="", flush=True)
    for (A, B) in pairs:
        A, B = numpy(A), numpy(B)
        for t in np.linspace(0, 1, FPS):
            save_vtk(f"frame_{frame}.vtk", (1 - t) * A + t * B, colors)
            frame += 1

    print("Done.")
    plt.show()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Save as a VTK movie...Done.





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 3 minutes 4.567 seconds)


.. _sphx_glr_download__auto_examples_optimal_transport_plot_interpolation_3D.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_interpolation_3D.py <plot_interpolation_3D.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_interpolation_3D.ipynb <plot_interpolation_3D.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_