import torch
from torch.nn import Module
from functools import partial
import warnings
from .kernel_samples import kernel_tensorized, kernel_online, kernel_multiscale
from .sinkhorn_samples import sinkhorn_tensorized
from .sinkhorn_samples import sinkhorn_online
from .sinkhorn_samples import sinkhorn_multiscale
from .kernel_samples import kernel_tensorized as hausdorff_tensorized
from .kernel_samples import kernel_online as hausdorff_online
from .kernel_samples import kernel_multiscale as hausdorff_multiscale
routines = {
"sinkhorn": {
"tensorized": sinkhorn_tensorized,
"online": sinkhorn_online,
"multiscale": sinkhorn_multiscale,
},
"hausdorff": {
"tensorized": hausdorff_tensorized,
"online": hausdorff_online,
"multiscale": hausdorff_multiscale,
},
"energy": {
"tensorized": partial(kernel_tensorized, name="energy"),
"online": partial(kernel_online, name="energy"),
"multiscale": partial(kernel_multiscale, name="energy"),
},
"gaussian": {
"tensorized": partial(kernel_tensorized, name="gaussian"),
"online": partial(kernel_online, name="gaussian"),
"multiscale": partial(kernel_multiscale, name="gaussian"),
},
"laplacian": {
"tensorized": partial(kernel_tensorized, name="laplacian"),
"online": partial(kernel_online, name="laplacian"),
"multiscale": partial(kernel_multiscale, name="laplacian"),
},
}
[docs]class SamplesLoss(Module):
"""Creates a criterion that computes distances between sampled measures on a vector space.
Warning:
If **loss** is ``"sinkhorn"`` and **reach** is **None** (balanced Optimal Transport),
the resulting routine will expect measures whose total masses are equal with each other.
Parameters:
loss (string, default = ``"sinkhorn"``): The loss function to compute.
The supported values are:
- ``"sinkhorn"``: (Un-biased) Sinkhorn divergence, which interpolates
between Wasserstein (blur=0) and kernel (blur= :math:`+\infty` ) distances.
- ``"hausdorff"``: Weighted Hausdorff distance, which interpolates
between the ICP loss (blur=0) and a kernel distance (blur= :math:`+\infty` ).
- ``"energy"``: Energy Distance MMD, computed using the kernel
:math:`k(x,y) = -\|x-y\|_2`.
- ``"gaussian"``: Gaussian MMD, computed using the kernel
:math:`k(x,y) = \exp \\big( -\|x-y\|_2^2 \,/\, 2\sigma^2)`
of standard deviation :math:`\sigma` = **blur**.
- ``"laplacian"``: Laplacian MMD, computed using the kernel
:math:`k(x,y) = \exp \\big( -\|x-y\|_2 \,/\, \sigma)`
of standard deviation :math:`\sigma` = **blur**.
p (int, default=2): If **loss** is ``"sinkhorn"`` or ``"hausdorff"``,
specifies the ground cost function between points.
The supported values are:
- **p** = 1: :math:`~~C(x,y) ~=~ \|x-y\|_2`.
- **p** = 2: :math:`~~C(x,y) ~=~ \\tfrac{1}{2}\|x-y\|_2^2`.
blur (float, default=.05): The finest level of detail that
should be handled by the loss function - in
order to prevent overfitting on the samples' locations.
- If **loss** is ``"gaussian"`` or ``"laplacian"``,
it is the standard deviation :math:`\sigma` of the convolution kernel.
- If **loss** is ``"sinkhorn"`` or ``"hausdorff"``,
it is the typical scale :math:`\sigma` associated
to the temperature :math:`\\varepsilon = \sigma^p`.
The default value of .05 is sensible for input
measures that lie in the unit square/cube.
Note that the *Energy Distance* is scale-equivariant, and won't
be affected by this parameter.
reach (float, default=None= :math:`+\infty` ): If **loss** is ``"sinkhorn"``
or ``"hausdorff"``,
specifies the typical scale :math:`\\tau` associated
to the constraint strength :math:`\\rho = \\tau^p`.
diameter (float, default=None): A rough indication of the maximum
distance between points, which is used to tune the :math:`\\varepsilon`-scaling
descent and provide a default heuristic for clustering **multiscale** schemes.
If **None**, a conservative estimate will be computed on-the-fly.
scaling (float, default=.5): If **loss** is ``"sinkhorn"``,
specifies the ratio between successive values
of :math:`\sigma=\\varepsilon^{1/p}` in the
:math:`\\varepsilon`-scaling descent.
This parameter allows you to specify the trade-off between
speed (**scaling** < .4) and accuracy (**scaling** > .9).
truncate (float, default=None= :math:`+\infty`): If **backend**
is ``"multiscale"``, specifies the effective support of
a Gaussian/Laplacian kernel as a multiple of its standard deviation.
If **truncate** is not **None**, kernel truncation
steps will assume that
:math:`\\exp(-x/\sigma)` or
:math:`\\exp(-x^2/2\sigma^2) are zero when
:math:`\|x\| \,>\, \\text{truncate}\cdot \sigma`.
cost (function or string, default=None): if **loss** is ``"sinkhorn"``
or ``"hausdorff"``, specifies the cost function that should
be used instead of :math:`\\tfrac{1}{p}\|x-y\|^p`:
- If **backend** is ``"tensorized"``, **cost** should be a
python function that takes as input a
(B,N,D) torch Tensor **x**, a (B,M,D) torch Tensor **y**
and returns a batched Cost matrix as a (B,N,M) Tensor.
- Otherwise, if **backend** is ``"online"`` or ``"multiscale"``,
**cost** should be a `KeOps formula <http://www.kernel-operations.io/api/math-operations.html>`_,
given as a string, with variables ``X`` and ``Y``.
The default values are ``"Norm2(X-Y)"`` (for **p** = 1) and
``"(SqDist(X,Y) / IntCst(2))"`` (for **p** = 2).
cluster_scale (float, default=None): If **backend** is ``"multiscale"``,
specifies the coarse scale at which cluster centroids will be computed.
If **None**, a conservative estimate will be computed from
**diameter** and the ambient space's dimension,
making sure that memory overflows won't take place.
debias (bool, default=True): If **loss** is ``"sinkhorn"``,
specifies if we should compute the **unbiased**
Sinkhorn divergence instead of the classic,
entropy-regularized "SoftAssign" loss.
potentials (bool, default=False): When this parameter is set to True,
the :mod:`SamplesLoss` layer returns a pair of optimal dual potentials
:math:`F` and :math:`G`, sampled on the input measures,
instead of differentiable scalar value.
These dual vectors :math:`(F(x_i))` and :math:`(G(y_j))`
are encoded as Torch tensors, with the same shape
as the input weights :math:`(\\alpha_i)` and :math:`(\\beta_j)`.
verbose (bool, default=False): If **backend** is ``"multiscale"``,
specifies whether information on the clustering and
:math:`\\varepsilon`-scaling descent should be displayed
in the standard output.
backend (string, default = ``"auto"``): The implementation that
will be used in the background; this choice has a major impact
on performance. The supported values are:
- ``"auto"``: Choose automatically, using a simple
heuristic based on the inputs' shapes.
- ``"tensorized"``: Relies on a full cost/kernel matrix, computed
once and for all and stored on the device memory.
This method is fast, but has a quadratic
memory footprint and does not scale beyond ~5,000 samples per measure.
- ``"online"``: Computes cost/kernel values on-the-fly, leveraging
online map-reduce CUDA routines provided by
the `pykeops <https://www.kernel-operations.io>`_ library.
- ``"multiscale"``: Fast implementation that scales to millions
of samples in dimension 1-2-3, relying on the block-sparse
reductions provided by the `pykeops <https://www.kernel-operations.io>`_ library.
"""
def __init__(
self,
loss="sinkhorn",
p=2,
blur=0.05,
reach=None,
diameter=None,
scaling=0.5,
truncate=5,
cost=None,
kernel=None,
cluster_scale=None,
debias=True,
potentials=False,
verbose=False,
backend="auto",
):
super(SamplesLoss, self).__init__()
self.loss = loss
self.backend = backend
self.p = p
self.blur = blur
self.reach = reach
self.truncate = truncate
self.diameter = diameter
self.scaling = scaling
self.cost = cost
self.kernel = kernel
self.cluster_scale = cluster_scale
self.debias = debias
self.potentials = potentials
self.verbose = verbose
def forward(self, *args):
"""Computes the loss between sampled measures.
Documentation and examples: Soon!
Until then, please check the tutorials :-)"""
l_x, α, x, l_y, β, y = self.process_args(*args)
B, N, M, D, l_x, α, l_y, β = self.check_shapes(l_x, α, x, l_y, β, y)
backend = (
self.backend
) # Choose the backend -----------------------------------------
if l_x is not None or l_y is not None:
if backend in ["auto", "multiscale"]:
backend = "multiscale"
else:
raise ValueError(
'Explicit cluster labels are only supported with the "auto" and "multiscale" backends.'
)
elif backend == "auto":
if M * N <= 5000 ** 2:
backend = (
"tensorized" # Fast backend, with a quadratic memory footprint
)
else:
if (
D <= 3
and self.loss == "sinkhorn"
and M * N > 10000 ** 2
and self.p == 2
):
backend = "multiscale" # Super scalable algorithm in low dimension
else:
backend = "online" # Play it safe, without kernel truncation
# Check compatibility between the batchsize and the backend --------------------------
if backend in ["multiscale"]: # multiscale routines work on single measures
if B == 1:
α, x, β, y = α.squeeze(0), x.squeeze(0), β.squeeze(0), y.squeeze(0)
elif B > 1:
warnings.warn(
"The 'multiscale' backend do not support batchsize > 1. "
+ "Using 'tensorized' instead: beware of memory overflows!"
)
backend = "tensorized"
if B == 0 and backend in [
"tensorized",
"online",
]: # tensorized and online routines work on batched tensors
α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)
# Run --------------------------------------------------------------------------------
values = routines[self.loss][backend](
α,
x,
β,
y,
p=self.p,
blur=self.blur,
reach=self.reach,
diameter=self.diameter,
scaling=self.scaling,
truncate=self.truncate,
cost=self.cost,
kernel=self.kernel,
cluster_scale=self.cluster_scale,
debias=self.debias,
potentials=self.potentials,
labels_x=l_x,
labels_y=l_y,
verbose=self.verbose,
)
# Make sure that the output has the correct shape ------------------------------------
if (
self.potentials
): # Return some dual potentials (= test functions) sampled on the input measures
F, G = values
return F.view_as(α), G.view_as(β)
else: # Return a scalar cost value
if backend in ["multiscale"]: # KeOps backends return a single scalar value
if B == 0:
return values # The user expects a scalar value
else:
return values.view(
-1
) # The user expects a "batch list" of distances
else: # "tensorized" backend returns a "batch vector" of values
if B == 0:
return values[0] # The user expects a scalar value
else:
return values # The user expects a "batch vector" of distances
def process_args(self, *args):
if len(args) == 6:
return args
if len(args) == 4:
α, x, β, y = args
return None, α, x, None, β, y
elif len(args) == 2:
x, y = args
α = self.generate_weights(x)
β = self.generate_weights(y)
return None, α, x, None, β, y
else:
raise ValueError(
"A SamplesLoss accepts two (x, y), four (α, x, β, y) or six (l_x, α, x, l_y, β, y) arguments."
)
def generate_weights(self, x):
if x.dim() == 2: #
N = x.shape[0]
return torch.ones(N).type_as(x) / N
elif x.dim() == 3:
B, N, _ = x.shape
return torch.ones(B, N).type_as(x) / N
else:
raise ValueError(
"Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors."
)
def check_shapes(self, l_x, α, x, l_y, β, y):
if α.dim() != β.dim():
raise ValueError(
"Input weights 'α' and 'β' should have the same number of dimensions."
)
if x.dim() != y.dim():
raise ValueError(
"Input samples 'x' and 'y' should have the same number of dimensions."
)
if x.shape[-1] != y.shape[-1]:
raise ValueError(
"Input samples 'x' and 'y' should have the same last dimension."
)
if (
x.dim() == 2
): # No batch --------------------------------------------------------------------
B = 0 # Batchsize
N, D = x.shape # Number of "i" samples, dimension of the feature space
M, _ = y.shape # Number of "j" samples, dimension of the feature space
if α.dim() not in [1, 2]:
raise ValueError(
"Without batches, input weights 'α' and 'β' should be encoded as (N,) or (N,1) tensors."
)
elif α.dim() == 2:
if α.shape[1] > 1:
raise ValueError(
"Without batches, input weights 'α' should be encoded as (N,) or (N,1) tensors."
)
if β.shape[1] > 1:
raise ValueError(
"Without batches, input weights 'β' should be encoded as (M,) or (M,1) tensors."
)
α, β = α.view(-1), β.view(-1)
if l_x is not None:
if l_x.dim() not in [1, 2]:
raise ValueError(
"Without batches, the vector of labels 'l_x' should be encoded as an (N,) or (N,1) tensor."
)
elif l_x.dim() == 2:
if l_x.shape[1] > 1:
raise ValueError(
"Without batches, the vector of labels 'l_x' should be encoded as (N,) or (N,1) tensors."
)
l_x = l_x.view(-1)
if len(l_x) != N:
raise ValueError(
"The vector of labels 'l_x' should have the same length as the point cloud 'x'."
)
if l_y is not None:
if l_y.dim() not in [1, 2]:
raise ValueError(
"Without batches, the vector of labels 'l_y' should be encoded as an (M,) or (M,1) tensor."
)
elif l_y.dim() == 2:
if l_y.shape[1] > 1:
raise ValueError(
"Without batches, the vector of labels 'l_y' should be encoded as (M,) or (M,1) tensors."
)
l_y = l_y.view(-1)
if len(l_y) != M:
raise ValueError(
"The vector of labels 'l_y' should have the same length as the point cloud 'y'."
)
N2, M2 = α.shape[0], β.shape[0]
elif (
x.dim() == 3
): # batch computation ---------------------------------------------------------
(
B,
N,
D,
) = x.shape
# Batchsize, number of "i" samples, dimension of the feature space
(
B2,
M,
_,
) = y.shape
# Batchsize, number of "j" samples, dimension of the feature space
if B != B2:
raise ValueError("Samples 'x' and 'y' should have the same batchsize.")
if α.dim() not in [2, 3]:
raise ValueError(
"With batches, input weights 'α' and 'β' should be encoded as (B,N) or (B,N,1) tensors."
)
elif α.dim() == 3:
if α.shape[2] > 1:
raise ValueError(
"With batches, input weights 'α' should be encoded as (B,N) or (B,N,1) tensors."
)
if β.shape[2] > 1:
raise ValueError(
"With batches, input weights 'β' should be encoded as (B,M) or (B,M,1) tensors."
)
α, β = α.squeeze(-1), β.squeeze(-1)
if l_x is not None:
raise NotImplementedError(
'The "multiscale" backend has not been implemented with batches.'
)
if l_y is not None:
raise NotImplementedError(
'The "multiscale" backend has not been implemented with batches.'
)
B2, N2 = α.shape
B3, M2 = β.shape
if B != B2:
raise ValueError(
"Samples 'x' and weights 'α' should have the same batchsize."
)
if B != B3:
raise ValueError(
"Samples 'y' and weights 'β' should have the same batchsize."
)
else:
raise ValueError(
"Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors."
)
if N != N2:
raise ValueError(
"Weights 'α' and samples 'x' should have compatible shapes."
)
if M != M2:
raise ValueError(
"Weights 'β' and samples 'y' should have compatible shapes."
)
return B, N, M, D, l_x, α, l_y, β