# Source code for geomloss.samples_loss

import torch
from torch.nn import Module
from functools import partial
import warnings

from .kernel_samples import kernel_tensorized, kernel_online, kernel_multiscale

from .sinkhorn_samples import sinkhorn_tensorized
from .sinkhorn_samples import sinkhorn_online
from .sinkhorn_samples import sinkhorn_multiscale

from .kernel_samples import kernel_tensorized as hausdorff_tensorized
from .kernel_samples import kernel_online as hausdorff_online
from .kernel_samples import kernel_multiscale as hausdorff_multiscale


# Dispatch table: ``routines[loss][backend]`` is the callable that evaluates
# the requested loss with the requested backend.
#   - "sinkhorn" entries come from the sinkhorn_samples module;
#   - "hausdorff" entries are the kernel_* routines imported under an alias,
#     with no preset kernel name;
#   - the remaining entries are the kernel_* routines with ``name`` pre-bound
#     to the kernel identifier through functools.partial.
routines = {
    "sinkhorn": dict(
        tensorized=sinkhorn_tensorized,
        online=sinkhorn_online,
        multiscale=sinkhorn_multiscale,
    ),
    "hausdorff": dict(
        tensorized=hausdorff_tensorized,
        online=hausdorff_online,
        multiscale=hausdorff_multiscale,
    ),
    **{
        kernel_name: dict(
            tensorized=partial(kernel_tensorized, name=kernel_name),
            online=partial(kernel_online, name=kernel_name),
            multiscale=partial(kernel_multiscale, name=kernel_name),
        )
        for kernel_name in ("energy", "gaussian", "laplacian")
    },
}


class SamplesLoss(Module):
    r"""Creates a criterion that computes distances between sampled measures on a vector space.

    Warning:
        If **loss** is ``"sinkhorn"`` and **reach** is **None** (balanced Optimal Transport),
        the resulting routine will expect measures whose total masses are equal with each other.

    Parameters:
        loss (string, default = ``"sinkhorn"``): The loss function to compute.
            The supported values are:

              - ``"sinkhorn"``: (Un-biased) Sinkhorn divergence, which interpolates
                between Wasserstein (blur=0) and kernel (blur= :math:`+\infty` ) distances.
              - ``"hausdorff"``: Weighted Hausdorff distance, which interpolates
                between the ICP loss (blur=0) and a kernel distance (blur= :math:`+\infty` ).
              - ``"energy"``: Energy Distance MMD, computed using the kernel
                :math:`k(x,y) = -\|x-y\|_2`.
              - ``"gaussian"``: Gaussian MMD, computed using the kernel
                :math:`k(x,y) = \exp \big( -\|x-y\|_2^2 \,/\, 2\sigma^2 \big)`
                of standard deviation :math:`\sigma` = **blur**.
              - ``"laplacian"``: Laplacian MMD, computed using the kernel
                :math:`k(x,y) = \exp \big( -\|x-y\|_2 \,/\, \sigma \big)`
                of standard deviation :math:`\sigma` = **blur**.

        p (int, default=2): If **loss** is ``"sinkhorn"`` or ``"hausdorff"``,
            specifies the ground cost function between points.
            The supported values are:

              - **p** = 1: :math:`~~C(x,y) ~=~ \|x-y\|_2`.
              - **p** = 2: :math:`~~C(x,y) ~=~ \tfrac{1}{2}\|x-y\|_2^2`.

        blur (float, default=.05): The finest level of detail that
            should be handled by the loss function - in
            order to prevent overfitting on the samples' locations.

            - If **loss** is ``"gaussian"`` or ``"laplacian"``,
              it is the standard deviation :math:`\sigma` of the convolution kernel.
            - If **loss** is ``"sinkhorn"`` or ``"hausdorff"``,
              it is the typical scale :math:`\sigma` associated
              to the temperature :math:`\varepsilon = \sigma^p`.
              The default value of .05 is sensible for input
              measures that lie in the unit square/cube.

            Note that the *Energy Distance* is scale-equivariant, and won't
            be affected by this parameter.

        reach (float, default=None= :math:`+\infty` ): If **loss** is ``"sinkhorn"``
            or ``"hausdorff"``,
            specifies the typical scale :math:`\tau` associated
            to the constraint strength :math:`\rho = \tau^p`.

        diameter (float, default=None): A rough indication of the maximum
            distance between points, which is used to tune the
            :math:`\varepsilon`-scaling descent and provide a default heuristic
            for clustering **multiscale** schemes.
            If **None**, a conservative estimate will be computed on-the-fly.

        scaling (float, default=.5): If **loss** is ``"sinkhorn"``,
            specifies the ratio between successive values
            of :math:`\sigma=\varepsilon^{1/p}` in the
            :math:`\varepsilon`-scaling descent.
            This parameter allows you to specify the trade-off between
            speed (**scaling** < .4) and accuracy (**scaling** > .9).

        truncate (float or None, default=5): If **backend**
            is ``"multiscale"``, specifies the effective support of
            a Gaussian/Laplacian kernel as a multiple of its standard deviation.
            If **truncate** is not **None**, kernel truncation
            steps will assume that
            :math:`\exp(-x/\sigma)` or
            :math:`\exp(-x^2/2\sigma^2)` are zero when
            :math:`\|x\| \,>\, \text{truncate}\cdot \sigma`.

        cost (function or string, default=None): if **loss** is ``"sinkhorn"``
            or ``"hausdorff"``, specifies the cost function that should
            be used instead of :math:`\tfrac{1}{p}\|x-y\|^p`:

            - If **backend** is ``"tensorized"``, **cost** should be a
              python function that takes as input a
              (B,N,D) torch Tensor **x**, a (B,M,D) torch Tensor **y**
              and returns a batched Cost matrix as a (B,N,M) Tensor.
            - Otherwise, if **backend** is ``"online"`` or ``"multiscale"``,
              **cost** should be a `KeOps formula <http://www.kernel-operations.io/api/math-operations.html>`_,
              given as a string, with variables ``X`` and ``Y``.
              The default values are ``"Norm2(X-Y)"`` (for **p** = 1) and
              ``"(SqDist(X,Y) / IntCst(2))"`` (for **p** = 2).

        kernel (default=None): Stored on the module and forwarded unchanged,
            as the ``kernel`` keyword argument, to the backend routine.

        cluster_scale (float, default=None): If **backend** is ``"multiscale"``,
            specifies the coarse scale at which cluster centroids will be computed.
            If **None**, a conservative estimate will be computed from
            **diameter** and the ambient space's dimension,
            making sure that memory overflows won't take place.

        debias (bool, default=True): If **loss** is ``"sinkhorn"``,
            specifies if we should compute the **unbiased**
            Sinkhorn divergence instead of the classic,
            entropy-regularized "SoftAssign" loss.

        potentials (bool, default=False): When this parameter is set to True,
            the :mod:`SamplesLoss` layer returns a pair of optimal dual potentials
            :math:`F` and :math:`G`, sampled on the input measures,
            instead of differentiable scalar value.
            These dual vectors :math:`(F(x_i))` and :math:`(G(y_j))`
            are encoded as Torch tensors, with the same shape
            as the input weights :math:`(\alpha_i)` and :math:`(\beta_j)`.

        verbose (bool, default=False): If **backend** is ``"multiscale"``,
            specifies whether information on the clustering and
            :math:`\varepsilon`-scaling descent should be displayed
            in the standard output.

        backend (string, default = ``"auto"``): The implementation that
            will be used in the background; this choice has a major impact
            on performance. The supported values are:

              - ``"auto"``: Choose automatically, using a simple
                heuristic based on the inputs' shapes.
              - ``"tensorized"``: Relies on a full cost/kernel matrix, computed
                once and for all and stored on the device memory.
                This method is fast, but has a quadratic
                memory footprint and does not scale beyond ~5,000 samples per measure.
              - ``"online"``: Computes cost/kernel values on-the-fly, leveraging
                online map-reduce CUDA routines provided by
                the `pykeops <https://www.kernel-operations.io>`_ library.
              - ``"multiscale"``: Fast implementation that scales to millions
                of samples in dimension 1-2-3, relying on the block-sparse
                reductions provided by the `pykeops <https://www.kernel-operations.io>`_ library.

    """

    def __init__(
        self,
        loss="sinkhorn",
        p=2,
        blur=0.05,
        reach=None,
        diameter=None,
        scaling=0.5,
        truncate=5,
        cost=None,
        kernel=None,
        cluster_scale=None,
        debias=True,
        potentials=False,
        verbose=False,
        backend="auto",
    ):
        """Stores every option as an attribute of the same name.

        No validation happens here: the options are only read (and the
        loss/backend names only looked up) when the layer is called.
        """
        super().__init__()
        self.loss = loss
        self.backend = backend
        self.p = p
        self.blur = blur
        self.reach = reach
        self.truncate = truncate
        self.diameter = diameter
        self.scaling = scaling
        self.cost = cost
        self.kernel = kernel
        self.cluster_scale = cluster_scale
        self.debias = debias
        self.potentials = potentials
        self.verbose = verbose

    def forward(self, *args):
        """Computes the loss between sampled measures.

        Args:
            *args: one of the three call signatures accepted by
                :meth:`process_args`: ``(x, y)``, ``(α, x, β, y)`` or
                ``(l_x, α, x, l_y, β, y)``. Samples are (N,D) / (M,D)
                tensors, or (B,N,D) / (B,M,D) tensors for batches.

        Returns:
            - If ``self.potentials`` is False: the value(s) computed by the
              backend routine, as a scalar for un-batched inputs and as a
              vector of B values for batched inputs.
            - If ``self.potentials`` is True: the pair ``(F, G)`` returned by
              the backend routine, viewed with the shapes of the weights
              ``α`` and ``β`` as they were handed to that routine (i.e. after
              validation and the batch-dimension adjustment done below).

        Raises:
            ValueError: on inconsistent input shapes, on a wrong number of
                arguments, or if labels are given with a backend other than
                ``"auto"`` / ``"multiscale"``.
            NotImplementedError: if labels are given with batched inputs.
            KeyError: if ``self.loss`` or ``self.backend`` is not a
                supported name.
        """
        l_x, α, x, l_y, β, y = self.process_args(*args)
        B, N, M, D, l_x, α, l_y, β = self.check_shapes(l_x, α, x, l_y, β, y)

        # Choose the backend ---------------------------------------------------------------
        backend = self.backend
        if l_x is not None or l_y is not None:
            # Explicit cluster labels force the multiscale scheme.
            if backend in ["auto", "multiscale"]:
                backend = "multiscale"
            else:
                raise ValueError(
                    'Explicit cluster labels are only supported with the "auto" and "multiscale" backends.'
                )
        elif backend == "auto":
            if M * N <= 5000 ** 2:
                # Fast backend, with a quadratic memory footprint
                backend = "tensorized"
            elif (
                D <= 3
                and self.loss == "sinkhorn"
                and M * N > 10000 ** 2
                and self.p == 2
            ):
                # Super scalable algorithm in low dimension
                backend = "multiscale"
            else:
                # Play it safe, without kernel truncation
                backend = "online"

        # Check compatibility between the batchsize and the backend ------------------------
        if backend == "multiscale":
            # multiscale routines work on single measures
            if B == 1:
                α, x, β, y = α.squeeze(0), x.squeeze(0), β.squeeze(0), y.squeeze(0)
            elif B > 1:
                warnings.warn(
                    "The 'multiscale' backend do not support batchsize > 1. "
                    + "Using 'tensorized' instead: beware of memory overflows!"
                )
                backend = "tensorized"

        if B == 0 and backend in ["tensorized", "online"]:
            # tensorized and online routines work on batched tensors
            α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)

        # Look up the routine; keep the KeyError type of a plain dict lookup
        # but tell the user which option was wrong.
        try:
            routine = routines[self.loss][backend]
        except KeyError:
            raise KeyError(
                "Unsupported loss/backend combination: loss={0!r}, backend={1!r}. "
                "Supported losses are {2}; supported backends are "
                "'auto', 'tensorized', 'online' and 'multiscale'.".format(
                    self.loss, backend, sorted(routines)
                )
            ) from None

        # Run ------------------------------------------------------------------------------
        values = routine(
            α,
            x,
            β,
            y,
            p=self.p,
            blur=self.blur,
            reach=self.reach,
            diameter=self.diameter,
            scaling=self.scaling,
            truncate=self.truncate,
            cost=self.cost,
            kernel=self.kernel,
            cluster_scale=self.cluster_scale,
            debias=self.debias,
            potentials=self.potentials,
            labels_x=l_x,
            labels_y=l_y,
            verbose=self.verbose,
        )

        # Make sure that the output has the correct shape ----------------------------------
        if self.potentials:
            # Return some dual potentials (= test functions) sampled on the input measures
            F, G = values
            return F.view_as(α), G.view_as(β)

        # Return a scalar cost value
        if backend == "multiscale":
            # The multiscale routine handles a single pair of measures
            if B == 0:
                return values  # The user expects a scalar value
            return values.view(-1)  # The user expects a "batch list" of distances

        # Other backends return a "batch vector" of values
        if B == 0:
            return values[0]  # The user expects a scalar value
        return values  # The user expects a "batch vector" of distances

    def process_args(self, *args):
        """Normalizes the call arguments to ``(l_x, α, x, l_y, β, y)``.

        - six arguments are returned as they are;
        - four arguments ``(α, x, β, y)`` get ``None`` labels;
        - two arguments ``(x, y)`` get ``None`` labels and uniform weights
          built by :meth:`generate_weights`.

        Raises:
            ValueError: for any other number of arguments.
        """
        if len(args) == 6:
            return args
        if len(args) == 4:
            α, x, β, y = args
            return None, α, x, None, β, y
        if len(args) == 2:
            x, y = args
            α = self.generate_weights(x)
            β = self.generate_weights(y)
            return None, α, x, None, β, y
        raise ValueError(
            "A SamplesLoss accepts two (x, y), four (α, x, β, y) or six (l_x, α, x, l_y, β, y) arguments."
        )

    def generate_weights(self, x):
        """Returns uniform weights ``1/N`` for the point cloud ``x``.

        Args:
            x: (N,D) tensor, or (B,N,D) tensor for batches.

        Returns:
            An (N,) tensor, or a (B,N) tensor for batches, filled with
            ``1/N`` and allocated directly with the dtype and device of ``x``.

        Raises:
            ValueError: if ``x`` has neither 2 nor 3 dimensions.
        """
        if x.dim() == 2:
            N = x.shape[0]
            return torch.ones(N, dtype=x.dtype, device=x.device) / N
        elif x.dim() == 3:
            B, N, _ = x.shape
            return torch.ones(B, N, dtype=x.dtype, device=x.device) / N
        else:
            raise ValueError(
                "Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors."
            )

    @staticmethod
    def _check_labels(labels, count, name, size, cloud):
        """Validates an un-batched label vector and returns it flattened.

        ``name`` ('l_x' / 'l_y'), ``size`` ('N' / 'M') and ``cloud``
        ('x' / 'y') are only used to build the error messages; ``count`` is
        the expected number of labels.
        """
        if labels.dim() not in [1, 2]:
            raise ValueError(
                "Without batches, the vector of labels '{0}' should be encoded as an ({1},) or ({1},1) tensor.".format(
                    name, size
                )
            )
        if labels.dim() == 2:
            if labels.shape[1] > 1:
                raise ValueError(
                    "Without batches, the vector of labels '{0}' should be encoded as ({1},) or ({1},1) tensors.".format(
                        name, size
                    )
                )
            labels = labels.view(-1)
        if len(labels) != count:
            raise ValueError(
                "The vector of labels '{0}' should have the same length as the point cloud '{1}'.".format(
                    name, cloud
                )
            )
        return labels

    def check_shapes(self, l_x, α, x, l_y, β, y):
        """Validates the input shapes and flattens trailing singleton axes.

        Args:
            l_x, l_y: optional label vectors (``None`` if absent).
            α, β: weights, (N,) / (N,1) and (M,) / (M,1) without batches,
                (B,N) / (B,N,1) and (B,M) / (B,M,1) with batches.
            x, y: samples, (N,D) and (M,D) without batches,
                (B,N,D) and (B,M,D) with batches.

        Returns:
            The tuple ``(B, N, M, D, l_x, α, l_y, β)`` where ``B`` is 0 for
            un-batched inputs, and where weights and labels have had their
            trailing singleton dimension removed.

        Raises:
            ValueError: if the shapes are inconsistent.
            NotImplementedError: if labels are given with batched inputs.
        """
        if α.dim() != β.dim():
            raise ValueError(
                "Input weights 'α' and 'β' should have the same number of dimensions."
            )
        if x.dim() != y.dim():
            raise ValueError(
                "Input samples 'x' and 'y' should have the same number of dimensions."
            )
        if x.shape[-1] != y.shape[-1]:
            raise ValueError(
                "Input samples 'x' and 'y' should have the same last dimension."
            )

        if x.dim() == 2:
            # No batch ---------------------------------------------------------------------
            B = 0  # Batchsize
            N, D = x.shape  # Number of "i" samples, dimension of the feature space
            M, _ = y.shape  # Number of "j" samples, dimension of the feature space

            if α.dim() not in [1, 2]:
                raise ValueError(
                    "Without batches, input weights 'α' and 'β' should be encoded as (N,) or (N,1) tensors."
                )
            elif α.dim() == 2:
                # β has the same number of dimensions as α (checked above).
                if α.shape[1] > 1:
                    raise ValueError(
                        "Without batches, input weights 'α' should be encoded as (N,) or (N,1) tensors."
                    )
                if β.shape[1] > 1:
                    raise ValueError(
                        "Without batches, input weights 'β' should be encoded as (M,) or (M,1) tensors."
                    )
                α, β = α.view(-1), β.view(-1)

            if l_x is not None:
                l_x = self._check_labels(l_x, N, "l_x", "N", "x")
            if l_y is not None:
                l_y = self._check_labels(l_y, M, "l_y", "M", "y")

            N2, M2 = α.shape[0], β.shape[0]

        elif x.dim() == 3:
            # Batch computation ------------------------------------------------------------
            # Batchsize, number of "i" samples, dimension of the feature space
            B, N, D = x.shape
            # Batchsize, number of "j" samples, dimension of the feature space
            B2, M, _ = y.shape
            if B != B2:
                raise ValueError("Samples 'x' and 'y' should have the same batchsize.")

            if α.dim() not in [2, 3]:
                raise ValueError(
                    "With batches, input weights 'α' and 'β' should be encoded as (B,N) or (B,N,1) tensors."
                )
            elif α.dim() == 3:
                if α.shape[2] > 1:
                    raise ValueError(
                        "With batches, input weights 'α' should be encoded as (B,N) or (B,N,1) tensors."
                    )
                if β.shape[2] > 1:
                    raise ValueError(
                        "With batches, input weights 'β' should be encoded as (B,M) or (B,M,1) tensors."
                    )
                α, β = α.squeeze(-1), β.squeeze(-1)

            if l_x is not None or l_y is not None:
                raise NotImplementedError(
                    'The "multiscale" backend has not been implemented with batches.'
                )

            B2, N2 = α.shape
            B3, M2 = β.shape
            if B != B2:
                raise ValueError(
                    "Samples 'x' and weights 'α' should have the same batchsize."
                )
            if B != B3:
                raise ValueError(
                    "Samples 'y' and weights 'β' should have the same batchsize."
                )

        else:
            raise ValueError(
                "Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors."
            )

        if N != N2:
            raise ValueError(
                "Weights 'α' and samples 'x' should have compatible shapes."
            )
        if M != M2:
            raise ValueError(
                "Weights 'β' and samples 'y' should have compatible shapes."
            )

        return B, N, M, D, l_x, α, l_y, β