# Source code for DeformationModules.GlobalTranslation

import torch

from imodal.StructuredFields import ConstantField
from imodal.Manifolds import EmptyManifold
from imodal.DeformationModules.Abstract import DeformationModule


class GlobalTranslation(DeformationModule):
    """Global translation deformation module.

    The module carries no geometric descriptors (its manifold is an
    ``EmptyManifold``). Its controls are a single 1-D tensor of length
    ``dim``. The vector field it generates is ``ConstantField(controls)``;
    presumably this applies the same translation vector at every point.
    """

    def __init__(self, dim, coeff=1., label=None):
        """
        Parameters
        ----------
        dim : int
            Dimension of the ambient space; also the length of the controls.
        coeff : float, default 1.
            Weight of the quadratic control cost (see :meth:`cost`).
        label : optional
            Label forwarded to the ``DeformationModule`` base class.
        """
        super().__init__(label)
        self.__controls = torch.zeros(dim)
        self.__coeff = coeff
        self.__manifold = EmptyManifold(dim)

    def __str__(self):
        outstr = "Global translation\n"
        if self.label:
            outstr += " Label=" + self.label + "\n"
        outstr += " Coeff=" + str(self.__coeff)
        return outstr

    @classmethod
    def build(cls, dim, coeff=1., label=None):
        """Alternate constructor; equivalent to ``cls(dim, coeff, label)``."""
        return cls(dim, coeff, label)

    def to_(self, *args, **kwargs):
        """Move the manifold and the controls in place.

        The arguments are forwarded unchanged to ``self.manifold.to_`` and to
        ``torch.Tensor.to`` for the controls tensor.
        """
        self.__manifold.to_(*args, **kwargs)
        self.__controls = self.__controls.to(*args, **kwargs)

    @property
    def coeff(self):
        """Weight of the quadratic control cost."""
        return self.__coeff

    @coeff.setter
    def coeff(self, coeff):
        self.__coeff = coeff

    @property
    def manifold(self):
        """The (empty) manifold attached to this module."""
        return self.__manifold

    @property
    def device(self):
        """Device reported by the attached manifold."""
        return self.__manifold.device

    def __get_controls(self):
        return self.__controls

    def fill_controls(self, controls):
        """Replace the controls with a clone of ``controls``.

        The tensor is cloned, so later in-place edits of the caller's tensor
        do not leak into the module.
        """
        self.__controls = controls.clone()

    # The setter is ``fill_controls`` itself, so ``module.controls = t``
    # also stores a clone of ``t``.
    controls = property(__get_controls, fill_controls)

    def fill_controls_zero(self):
        """Reset the controls to zeros with the same shape, dtype and device."""
        self.fill_controls(torch.zeros_like(self.__controls))

    def __call__(self, points):
        """Evaluate the generated vector field at ``points``."""
        return self.field_generator()(points)

    def cost(self):
        """Return ``0.5 * coeff * <controls, controls>``."""
        return 0.5 * self.__coeff * torch.dot(self.__controls, self.__controls)

    def compute_geodesic_control(self, man):
        """Compute the geodesic controls from ``man`` and store them.

        For each canonical basis vector ``e_i`` of the control space, the
        constant field ``ConstantField(e_i)`` is paired with ``man`` via
        ``man.inner_prod_field``, and the result is divided by ``coeff``.
        Assuming ``inner_prod_field`` is linear in the field (TODO confirm),
        this gives the controls that minimise ``cost()`` minus that pairing.

        The freshly computed tensor replaces the current controls directly
        (no clone).

        Parameters
        ----------
        man :
            Object exposing ``inner_prod_field(field)``; presumably a
            manifold carrying momenta, returning a scalar tensor — verify
            against callers.
        """
        geodesic_controls = torch.zeros_like(self.__controls)
        for i in range(self.__controls.shape[0]):
            # Unit vector along axis i, same dtype/device as the controls.
            basis_i = torch.zeros_like(self.__controls)
            basis_i[i] = 1.
            geodesic_controls[i] = man.inner_prod_field(ConstantField(basis_i)) / self.__coeff

        self.__controls = geodesic_controls

    def field_generator(self):
        """Return the vector field ``ConstantField(controls)``."""
        return ConstantField(self.__controls)