# Source code for imodal.Attachment.attachment

try:
    from collections.abc import Iterable
except ImportError:  # very old Pythons only exposed the ABCs directly in ``collections``
    from collections import Iterable

import geomloss
import torch

from imodal.Kernels.kernels import distances, scal


class Attachment:
    """Base class for attachment terms.

    An attachment term is a value quantifying the attachment between a source
    and a target object. Subclasses implement :meth:`loss`; calling an
    instance returns that loss multiplied by :attr:`weight`.

    Parameters
    ----------
    weight : float, default 1.
        Multiplicative weight applied by :meth:`__call__`.
    """

    def __init__(self, weight=1.):
        self.__weight = weight

    def __str__(self):
        return "{classname} (weight={weight})".format(classname=self.__class__.__name__, weight=self.weight)

    @property
    def weight(self):
        """Weight given to this attachment term. Useful when dealing with compound attachment."""
        return self.__weight

    def __call__(self, source, target):
        """Computes the weighted attachment term between two objects.

        Parameters
        ----------
        source :
            The source term.
        target :
            The target term.

        Returns
        -------
        torch.Tensor
            ``weight * loss(source, target)``, the value quantifying the
            attachment between the source and the target.
        """
        return self.__weight*self.loss(source, target)

    def loss(self, source, target):
        """Computes the unweighted attachment term between two objects.

        Must be overridden by subclasses; :meth:`__call__` applies the weight.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError


class CompoundAttachment(Attachment):
    """Compound attachment measure. Can be used to combine different measures together.

    The compound loss is the sum of the *weighted* values of its member
    attachments, so each member's own ``weight`` is honoured. The compound's
    own ``weight`` is applied on top when the compound itself is called.

    Parameters
    ----------
    attachments : iterable of Attachment
        The attachment terms to combine. The iterable is materialised into a
        list so that one-shot iterators (e.g. generators) can be evaluated on
        every call rather than only the first.
    weight : float, default 1.
        Weight applied to the summed loss.

    Raises
    ------
    TypeError
        If `attachments` is not iterable.
    """

    def __init__(self, attachments, weight=1.):
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(attachments, Iterable):
            raise TypeError("attachments must be an iterable of Attachment objects, got {}".format(type(attachments).__name__))

        self.__attachments = list(attachments)
        super().__init__(weight)

    def loss(self, source, target):
        """Sum of the weighted member attachments evaluated on (source, target).

        Returns a zero tensor when the compound has no members, so the result
        is always a tensor.
        """
        if not self.__attachments:
            return torch.tensor(0.)

        # Call each member (not ``member.loss``) so its individual weight is applied.
        return sum(attachment(source, target) for attachment in self.__attachments)


class EnergyAttachment(Attachment):
    """Energy Distance between two sampled probability measures.

    Each of `source` and `target` is either a 1-tuple ``(points,)``, in which
    case every point gets a unit weight, or a pair ``(points, weights)``. The
    two sides are normalised independently, so a weighted source may be
    compared with an unweighted target and vice versa.

    The loss is ``.5 a^T K_xx a - a^T K_xy b + .5 b^T K_yy b`` with
    ``K = -distances(., .)``. This assumes ``distances`` returns the pairwise
    Euclidean distance matrix and ``scal`` a flattened dot product; confirm
    against imodal.Kernels.kernels.
    """

    def __init__(self, weight=1.):
        super().__init__(weight)

    @staticmethod
    def _as_weighted_points(measure):
        """Return ``(points, weights)``, filling in unit weights when absent."""
        if len(measure) == 1:
            points = measure[0]
            # Match the points' dtype/device, otherwise torch.mm fails for
            # float64 or CUDA inputs (torch.ones defaults to the global dtype/CPU).
            weights = torch.ones(points.shape[0], dtype=points.dtype, device=points.device)
            return points, weights

        points, weights = measure
        return points, weights

    def loss(self, source, target):
        """Unweighted energy-distance term between `source` and `target`."""
        x_i, a_i = self._as_weighted_points(source)
        y_j, b_j = self._as_weighted_points(target)

        K_xx = -distances(x_i, x_i)
        K_xy = -distances(x_i, y_j)
        K_yy = -distances(y_j, y_j)

        return (.5*scal(a_i, torch.mm(K_xx, a_i.view(-1, 1)))
                - scal(a_i, torch.mm(K_xy, b_j.view(-1, 1)))
                + .5*scal(b_j, torch.mm(K_yy, b_j.view(-1, 1))))


class GeomlossAttachment(Attachment):
    """Attachment term backed by ``geomloss.SamplesLoss``.

    Parameters
    ----------
    weight : float, default 1.
        Weight applied when the attachment is called.
    **kwargs :
        Forwarded unchanged to ``geomloss.SamplesLoss``.

    `source` and `target` may be plain tensors of points, 1-tuples
    ``(points,)``, or pairs ``(points, weights)``. Pairs are forwarded as
    ``(weights_x, points_x, weights_y, points_y)``. Plain points (or 1-tuples
    on both sides) are forwarded as ``(points_x, points_y)``, letting geomloss
    pick its default weights.
    """

    def __init__(self, weight=1., **kwargs):
        super().__init__(weight)
        self.__geomloss = geomloss.SamplesLoss(**kwargs)

    def loss(self, source, target):
        """Evaluates the configured ``SamplesLoss`` between `source` and `target`."""
        if isinstance(source, Iterable) and not isinstance(source, torch.Tensor):
            # Points-only tuples, as accepted by the other attachments in this module.
            if len(source) == 1 and len(target) == 1:
                return self.__geomloss(source[0], target[0])

            return self.__geomloss(source[1], source[0], target[1], target[0])

        return self.__geomloss(source, target)


class EuclideanPointwiseDistanceAttachment(Attachment):
    """Euclidean pointwise distance between two measures.

    Only the first entry of each of `source` and `target` (the points) is used.
    Row ``k`` of the source points is paired with row ``k`` of the target
    points, and the Euclidean norms of the row differences are summed.
    """

    def __init__(self, weight=1.):
        super().__init__(weight)

    def loss(self, source, target):
        """Sum over rows of ``||source[0][k] - target[0][k]||``."""
        source_points, target_points = source[0], target[0]
        per_point = (source_points - target_points).norm(dim=1)
        return per_point.sum()


class NullLoss(Attachment):
    """Attachment term that is identically zero.

    Its weight is fixed to ``0.`` and :meth:`loss` ignores its arguments.
    """

    def __init__(self):
        super().__init__(weight=0.)

    def loss(self, source, target):
        """Returns a 0-dimensional zero tensor regardless of the inputs."""
        return torch.zeros(())