3) Optimal Transport in high dimension
=======================================

Let's use a custom clustering scheme to generalize the **multiscale** Sinkhorn algorithm to high-dimensional settings.

Setup
---------------------

Standard imports:

.. code-block:: default

    import numpy as np
    import matplotlib.pyplot as plt
    import time
    import torch
    from geomloss import SamplesLoss

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor


    def display_4d_samples(ax1, ax2, x, color):
        """Scatter a (N, 4) point cloud on two 2d views: dims (0, 1) and (2, 3)."""
        x_ = x.detach().cpu().numpy()
        if isinstance(color, torch.Tensor):
            # Integer labels: map them to colors with the "tab10" colormap.
            color, cmap = color.detach().cpu().numpy(), "tab10"
        else:
            # Explicit RGB color(s): no colormap is needed.
            cmap = None
        size = 25 * 500 / len(x_)
        ax1.scatter(x_[:, 0], x_[:, 1], s=size, c=color, edgecolors="none", cmap=cmap)
        ax2.scatter(x_[:, 2], x_[:, 3], s=size, c=color, edgecolors="none", cmap=cmap)

**Dataset.** Our source and target samples are drawn from (noisy) discrete
sub-manifolds in :math:`\mathbb{R}^4`.
They allow us to define a pair of discrete probability measures:

.. math::
    \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
    \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

.. code-block:: default

    N, M = (100, 100) if not use_cuda else (50000, 50000)

    # Generate some kind of 4d-helix:
    t = torch.linspace(0, 2 * np.pi, N).type(dtype)
    X_i = (
        torch.stack((t * (2 * t).cos() / 7, t * (2 * t).sin() / 7, t / 7, t ** 2 / 50))
        .t()
        .contiguous()
    )
    X_i = X_i + 0.05 * torch.randn(N, 4).type(dtype)  # + some noise

    # The y_j's are sampled non-uniformly on the unit sphere of R^4:
    Y_j = torch.randn(M, 4).type(dtype)
    Y_j[:, 0] += 2
    Y_j = Y_j / (1e-4 + Y_j.norm(dim=1, keepdim=True))

We display our 4d-samples using two 2d-views:

.. code-block:: default

    plt.figure(figsize=(12, 6))

    ax1 = plt.subplot(1, 2, 1)
    plt.title("Dimensions 0, 1")
    ax2 = plt.subplot(1, 2, 2)
    plt.title("Dimensions 2, 3")

    display_4d_samples(ax1, ax2, X_i, [(0.95, 0.55, 0.55)])
    display_4d_samples(ax1, ax2, Y_j, [(0.55, 0.55, 0.95)])

    plt.tight_layout()

.. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_optimal_transport_cluster_001.png
   :alt: Dimensions 0, 1, Dimensions 2, 3
   :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_optimal_transport_cluster_001.png
   :class: sphx-glr-single-img

Online Sinkhorn algorithm
-------------------------------

When working with large point clouds in dimension > 3,
the ``SamplesLoss("sinkhorn")`` layer relies on an **online** implementation
of the Sinkhorn algorithm (in the log-domain, with :math:`\varepsilon`-scaling)
which computes softmin reductions **on-the-fly**, with a **linear memory footprint**:

.. code-block:: default

    from geomloss import SamplesLoss

    # Compute the Wasserstein-2 distance between our samples,
    # with a small blur radius and a conservative value of the
    # scaling "decay" coefficient (.8 is pretty close to 1):
    Loss = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.8)

    start = time.time()
    Wass_xy = Loss(X_i, Y_j)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()

    print(
        "Wasserstein distance: {:.3f}, computed in {:.3f}s.".format(
            Wass_xy.item(), end - start
        )
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Wasserstein distance: 0.509, computed in 0.708s.
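To make "softmin reductions computed on-the-fly" more concrete, here is a minimal sketch in plain PyTorch. It is not the KeOps/GeomLoss implementation: it evaluates the log-domain update
:math:`f_i = -\varepsilon \log \sum_j \exp\big(h_j - C(x_i, y_j)/\varepsilon\big)`, with :math:`h_j = \log b_j + g_j/\varepsilon`
and the quadratic cost :math:`C(x, y) = \tfrac{1}{2}|x - y|^2` (an assumption of this sketch), by streaming over blocks of :math:`y_j`
and merging partial log-sum-exp values. Only an :math:`N \times \text{block}` slice of the cost matrix is ever stored.
The name ``softmin_chunked`` and its ``block`` parameter are hypothetical.

.. code-block:: python

    import torch


    def softmin_chunked(eps, x, y, h_j, block=1024):
        """Return f_i = -eps * log sum_j exp(h_j - C(x_i, y_j) / eps).

        Hypothetical helper. C(x, y) = |x - y|^2 / 2. The sum over j is
        streamed in blocks of `block` columns, so only an (N, block) slice
        of the cost matrix is held in memory at any time.
        """
        running = torch.full((x.shape[0],), -float("inf"), dtype=x.dtype, device=x.device)
        for start in range(0, y.shape[0], block):
            y_b, h_b = y[start:start + block], h_j[start:start + block]
            C_b = ((x[:, None, :] - y_b[None, :, :]) ** 2).sum(-1) / 2  # (N, block)
            chunk = torch.logsumexp(h_b[None, :] - C_b / eps, dim=1)  # (N,)
            # Merge partial log-sum-exp values in a numerically stable way:
            running = torch.logaddexp(running, chunk)
        return -eps * running


    # Sanity check against a dense evaluation on a small problem:
    torch.manual_seed(0)
    x, y = torch.randn(200, 4), torch.randn(300, 4)
    h_j = torch.full((300,), 1.0 / 300).log()  # log of uniform weights b_j, with g_j = 0
    eps = 0.1
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / 2
    f_dense = -eps * torch.logsumexp(h_j - C / eps, dim=1)
    print(torch.allclose(softmin_chunked(eps, x, y, h_j, block=64), f_dense, atol=1e-5))

Alternating such updates for :math:`f` and :math:`g` over a decreasing sequence of :math:`\varepsilon` gives a Sinkhorn loop whose memory usage grows linearly with :math:`N` and :math:`M` for a fixed block size; KeOps obtains the same effect without ever materializing the cost matrix.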

Multiscale Sinkhorn algorithm
-------------------------------

Thanks to the :math:`\varepsilon`-scaling heuristic,
this **online** backend already outperforms
a naive implementation of the Sinkhorn/Auction algorithm
by a factor of ~10, for comparable values of the **blur** parameter.
But we can go further.

A key insight from recent works on computational Optimal Transport is that
the dual optimization problem on the potentials (or *prices*) :math:`f` and :math:`g`
can often be solved efficiently in a **coarse-to-fine** fashion,
using a clever subsampling of the input measures in the first iterations of the
:math:`\varepsilon`-scaling descent.

For regularized Optimal Transport, the main reference on the subject
is (Schmitzer, 2016), which combines an octree-like encoding with
a kernel truncation (*pruning*) scheme to achieve log-linear complexity.
Going further, (Gerber and Maggioni, 2017) generalize these ideas
to high-dimensional scenarios, using a multiscale decomposition
that relies on the **manifold-like structure** of the data, if any.

Leveraging the block-sparse routines of the KeOps library,
the **multiscale** backend of the ``SamplesLoss("sinkhorn")`` layer
provides the **first GPU implementation** of these strategies.
In dimensions 1, 2 and 3, clustering is automatically performed using
a straightforward cubic grid. In the general case, however, clustering
information can simply be provided through a **vector of labels**,
alongside the weights and the samples' locations.

**Clustering in high dimension.** In this tutorial, we rely on an off-the-shelf
K-means clustering, copy-pasted from the examples gallery of the KeOps library:
feel free to replace it with a more clever scheme if needed!

.. code-block:: default

    from pykeops.torch import generic_argmin


    def KMeans(x, K=10, Niter=10, verbose=True):
        N, D = x.shape  # Number of samples, dimension of the ambient space

        # Define our KeOps CUDA kernel:
        nn_search = generic_argmin(  # Argmin reduction for generic formulas:
            "SqDist(x,y)",  # A simple squared L2 distance
            "ind = Vi(1)",  # Output one index per "line" (reduction over "j")
            "x = Vi({})".format(D),  # 1st arg: one point per "line"
            "y = Vj({})".format(D),  # 2nd arg: one point per "column"
        )

        # K-means loop:
        # - x  is the point cloud,
        # - cl is the vector of class labels,
        # - c  is the cloud of cluster centroids.
        start = time.time()

        # Simplistic random initialization for the cluster centroids:
        perm = torch.randperm(N)
        idx = perm[:K]
        c = x[idx, :].clone()

        for i in range(Niter):
            cl = nn_search(x, c).view(-1).long()  # Points -> Nearest cluster
            # Class weights; minlength=K keeps one entry per cluster,
            # even if the last clusters happen to be empty:
            Ncl = torch.bincount(cl, minlength=K).type(dtype)
            for d in range(D):  # Compute the cluster centroids with torch.bincount:
                c[:, d] = torch.bincount(cl, weights=x[:, d], minlength=K) / Ncl

        if use_cuda:
            torch.cuda.synchronize()
        end = time.time()
        if verbose:
            print("KMeans performed in {:.3f}s.".format(end - start))

        return cl, c


    lab_i, c_i = KMeans(X_i, K=100 if use_cuda else 10)
    lab_j, c_j = KMeans(Y_j, K=400 if use_cuda else 10)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    KMeans performed in 0.019s.
    KMeans performed in 0.009s.

The typical cluster size, i.e. the root-mean-square distance between each sample
and its cluster centroid, can be computed with one line of code per measure:

.. code-block:: default

    std_i = ((X_i - c_i[lab_i, :]) ** 2).sum(1).mean().sqrt()
    std_j = ((Y_j - c_j[lab_j, :]) ** 2).sum(1).mean().sqrt()

    print(
        "Our clusters have standard deviations of {:.3f} and {:.3f}.".format(std_i, std_j)
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Our clusters have standard deviations of 0.082 and 0.133.
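The labels returned by ``KMeans`` are what allows a measure to be summarized at a coarse scale. As a minimal sketch of this idea, we assume (this is not checked here against GeomLoss's internals) that the coarse measure puts the total weight of each cluster at its weighted centroid. In plain PyTorch, with the hypothetical helper name ``coarse_measure``:

.. code-block:: python

    import torch


    def coarse_measure(labels, weights, points, K=None):
        """Summarize a weighted point cloud by one Dirac mass per cluster.

        Hypothetical helper. Returns (A, X): A[k] is the total weight of
        cluster k and X[k] its weighted centroid. Empty clusters get a
        zero weight and a zero position.
        """
        K = int(labels.max()) + 1 if K is None else K
        A = torch.zeros(K, dtype=weights.dtype, device=weights.device)
        A.index_add_(0, labels, weights)  # A[k] = sum of the weights with label k
        X = torch.zeros(K, points.shape[1], dtype=points.dtype, device=points.device)
        X.index_add_(0, labels, weights[:, None] * points)  # Weighted sums of positions
        X = X / A.clamp(min=torch.finfo(A.dtype).tiny)[:, None]  # -> Weighted centroids
        return A, X


    # Toy example: 6 points with uniform weights, split into 3 clusters.
    labels = torch.tensor([0, 0, 1, 2, 2, 2])
    w = torch.full((6,), 1 / 6, dtype=torch.float64)
    pts = torch.arange(12, dtype=torch.float64).reshape(6, 2)
    A, X = coarse_measure(labels, w, pts)
    print(A, X)

With the names of this tutorial, ``coarse_measure(lab_i, a_i, X_i)`` should return cluster positions that match ``c_i`` for non-empty clusters, since the ``KMeans`` loop above ends with a centroid update.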

As expected, our samples are now distributed in small, convex clusters
that partition the input data:

.. code-block:: default

    # sphinx_gallery_thumbnail_number = 2
    plt.figure(figsize=(12, 12))

    ax1 = plt.subplot(2, 2, 1)
    plt.title("Dimensions 0, 1")
    ax2 = plt.subplot(2, 2, 2)
    plt.title("Dimensions 2, 3")
    ax3 = plt.subplot(2, 2, 3)
    plt.title("Dimensions 0, 1")
    ax4 = plt.subplot(2, 2, 4)
    plt.title("Dimensions 2, 3")

    display_4d_samples(ax1, ax2, X_i, lab_i)
    display_4d_samples(ax3, ax4, Y_j, lab_j)

    plt.tight_layout()

.. image-sg:: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_optimal_transport_cluster_002.png
   :alt: Dimensions 0, 1, Dimensions 2, 3, Dimensions 0, 1, Dimensions 2, 3
   :srcset: /_auto_examples/sinkhorn_multiscale/images/sphx_glr_plot_optimal_transport_cluster_002.png
   :class: sphx-glr-single-img

To use this information in the **multiscale** Sinkhorn algorithm,
we should simply provide:

- explicit **labels** and **weights** for both input measures,
- a typical **cluster_scale**, which determines the :math:`\varepsilon`-scaling
  iteration at which the Sinkhorn loop jumps from a **coarse**
  to a **fine** representation of the data.

.. code-block:: default

    Loss = SamplesLoss(
        "sinkhorn",
        p=2,
        blur=0.05,
        scaling=0.8,
        cluster_scale=max(std_i, std_j),
        verbose=True,
    )

    # To specify explicit cluster labels, SamplesLoss also requires
    # explicit weights. Let's go with the default option: a uniform distribution.
    a_i = torch.ones(N).type(dtype) / N
    b_j = torch.ones(M).type(dtype) / M

    start = time.time()

    # 6 args -> labels_i, weights_i, locations_i, labels_j, weights_j, locations_j
    Wass_xy = Loss(lab_i, a_i, X_i, lab_j, b_j, Y_j)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.time()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    100x400 clusters, computed at scale = 0.133
    Successive scales :  4.012, 4.012, 3.210, 2.568, 2.054, 1.643, 1.315, 1.052, 0.841, 0.673, 0.538, 0.431, 0.345, 0.276, 0.221, 0.176, 0.141, 0.113, 0.090, 0.072, 0.058, 0.050
    Jump from coarse to fine between indices 16 (σ=0.141) and 17 (σ=0.113).
    Keep 13111/40000 = 32.8% of the coarse cost matrix.
    Keep 2710/10000 = 27.1% of the coarse cost matrix.
    Keep 26776/160000 = 16.7% of the coarse cost matrix.

That's it! As expected, leveraging the structure of the data has allowed
us to gain another ~10x speedup on large-scale transportation problems:

.. code-block:: default

    print(
        "Wasserstein distance: {:.3f}, computed in {:.3f}s.".format(
            Wass_xy.item(), end - start
        )
    )

    plt.show()
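The "Successive scales" and "Jump from coarse to fine" lines of the verbose log above can be reproduced with a few lines of NumPy. The sketch below is our own reconstruction, checked only against the printed numbers: it assumes that the blur scales :math:`\sigma = \varepsilon^{1/p}` start at the data diameter (about 4.012 here, listed twice), decay geometrically by the ``scaling`` factor in log-space, end at ``blur``, and that the jump happens at the last scale that is still above ``cluster_scale``. The helpers ``scale_schedule`` and ``jump_index`` are hypothetical names, not GeomLoss functions.

.. code-block:: python

    import numpy as np


    def scale_schedule(diameter, blur, scaling, p=2):
        """Blur scales sigma_k = eps_k ** (1/p) of an eps-scaling descent.

        Hypothetical reconstruction, consistent with the verbose log above.
        """
        # Geometric decay of eps = sigma**p, from diameter**p down to blur**p:
        log_eps = np.arange(p * np.log(diameter), p * np.log(blur), p * np.log(scaling))
        eps_list = [diameter ** p] + list(np.exp(log_eps)) + [blur ** p]
        return [eps ** (1 / p) for eps in eps_list]


    def jump_index(scales, cluster_scale):
        """Index of the last coarse iteration: the last scale above cluster_scale."""
        for i in range(1, len(scales)):
            if scales[i] < cluster_scale:
                return i - 1
        return len(scales) - 1


    scales = scale_schedule(diameter=4.012, blur=0.05, scaling=0.8, p=2)
    print("Successive scales :", ", ".join("{:.3f}".format(s) for s in scales))

    k = jump_index(scales, cluster_scale=0.133)
    print("Jump from coarse to fine between indices {} (σ={:.3f}) and {} (σ={:.3f}).".format(
        k, scales[k], k + 1, scales[k + 1]))

With ``diameter=4.012``, ``blur=0.05`` and ``scaling=0.8``, this sketch yields the same 22 rounded values as the log, and a cluster scale of 0.133 places the jump between indices 16 and 17. A smaller ``scaling`` would shorten the schedule, at the price of a less conservative descent.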