Source code for Manifolds.Compound

from imodal.StructuredFields.Abstract import SumStructuredField
from imodal.Manifolds.Abstract import BaseManifold
from imodal.Utilities import tensors_device, tensors_dtype


class CompoundManifold(BaseManifold):
    """Manifold built from an ordered collection of sub-manifolds.

    Almost every operation is forwarded, in order, to each sub-manifold.
    Where a value is produced, the per-manifold results are either gathered
    in a list (``gd``, ``tan``, ``cotan``, ``roll_*``), concatenated
    (``unroll_*``, ``numel_gd``) or summed (``nb_pts``, ``len_gd``,
    ``inner_prod_field``).
    """

    def __init__(self, manifolds):
        """Store the sub-manifolds and initialise the base class.

        ``manifolds`` must support iteration, ``len()`` and integer indexing
        (all three are used by the methods below). The device and dtype
        handed to ``BaseManifold`` are whatever ``tensors_device`` and
        ``tensors_dtype`` compute from that collection.
        """
        self.__manifolds = manifolds
        device = tensors_device(self.__manifolds)
        dtype = tensors_dtype(self.__manifolds)
        super().__init__(device, dtype)

    def to_(self, *argv, **kwargs):
        """Forward the given arguments to ``to_()`` of every sub-manifold."""
        for manifold in self.__manifolds:
            manifold.to_(*argv, **kwargs)

    def _to_device(self, device):
        """Forward ``device`` to ``_to_device()`` of every sub-manifold."""
        for manifold in self.__manifolds:
            manifold._to_device(device)

    def _to_dtype(self, dtype):
        """Forward ``dtype`` to ``_to_dtype()`` of every sub-manifold."""
        for manifold in self.__manifolds:
            manifold._to_dtype(dtype)

    @property
    def device(self):
        """Device reported by the first sub-manifold."""
        return self.__manifolds[0].device

    def clone(self, requires_grad=False):
        """Return a new CompoundManifold made of clones of each sub-manifold.

        ``requires_grad`` is passed through to each sub-manifold's
        ``clone()``.
        """
        return CompoundManifold(
            [m.clone(requires_grad=requires_grad) for m in self.__manifolds])

    @property
    def manifolds(self):
        """The collection of sub-manifolds given at construction."""
        return self.__manifolds

    def __getitem__(self, index):
        """Index directly into the collection of sub-manifolds."""
        return self.__manifolds[index]

    @property
    def dim(self):
        """Dimension reported by the first sub-manifold."""
        return self.__manifolds[0].dim

    @property
    def nb_pts(self):
        """Sum of ``nb_pts`` over all sub-manifolds."""
        return sum(m.nb_pts for m in self.__manifolds)

    @property
    def len_gd(self):
        """Sum of ``len_gd`` over all sub-manifolds."""
        return sum(m.len_gd for m in self.__manifolds)

    @property
    def numel_gd(self):
        """Tuple concatenating the ``numel_gd`` tuples of all sub-manifolds."""
        # sum() with a () start value concatenates the per-manifold tuples.
        return tuple(sum((m.numel_gd for m in self.__manifolds), ()))

    def unroll_gd(self):
        """Returns a flattened list of all gd tensors."""
        unrolled = []
        for man in self.__manifolds:
            unrolled.extend(man.unroll_gd())
        return unrolled

    def unroll_cotan(self):
        """Returns a flattened list built from each sub-manifold's ``unroll_cotan()``."""
        unrolled = []
        for man in self.__manifolds:
            unrolled.extend(man.unroll_cotan())
        return unrolled

    def roll_gd(self, l):
        """Unflattens the list into one suitable for fill_gd() or all
        ``*_gd()`` numerical operations.

        The same list object ``l`` is handed to ``roll_gd()`` of every
        sub-manifold, in order; the results are returned as a list with one
        entry per sub-manifold.
        NOTE(review): presumably each sub-manifold consumes its own entries
        from ``l`` -- confirm against the sub-manifold implementations.
        """
        return [man.roll_gd(l) for man in self.__manifolds]

    def roll_cotan(self, l):
        """Counterpart of ``roll_gd()`` for cotangents.

        The same list object ``l`` is handed to ``roll_cotan()`` of every
        sub-manifold, in order; returns one entry per sub-manifold.
        """
        return [man.roll_cotan(l) for man in self.__manifolds]

    def __get_gd(self):
        # Getter backing the ``gd`` property: one entry per sub-manifold.
        return [m.gd for m in self.__manifolds]

    def __get_tan(self):
        # Getter backing the ``tan`` property: one entry per sub-manifold.
        return [m.tan for m in self.__manifolds]

    def __get_cotan(self):
        # Getter backing the ``cotan`` property: one entry per sub-manifold.
        return [m.cotan for m in self.__manifolds]

    def fill(self, manifold, copy=False, requires_grad=True):
        """Fill gd, tan and cotan from the corresponding attributes of ``manifold``.

        ``copy`` and ``requires_grad`` are passed through to ``fill_gd()``,
        ``fill_tan()`` and ``fill_cotan()``.
        """
        self.fill_gd(manifold.gd, copy=copy, requires_grad=requires_grad)
        self.fill_tan(manifold.tan, copy=copy, requires_grad=requires_grad)
        self.fill_cotan(manifold.cotan, copy=copy, requires_grad=requires_grad)

    def fill_gd(self, gd, copy=False, requires_grad=None):
        """Fill each sub-manifold's gd with the matching element of ``gd``.

        Elements are paired with sub-manifolds positionally via ``zip``, so
        pairing stops at the shorter of the two sequences (same behaviour as
        ``fill_tan()`` and ``fill_cotan()``). Each sub-manifold is filled
        exactly once.
        """
        for manifold, elem in zip(self.__manifolds, gd):
            manifold.fill_gd(elem, copy=copy, requires_grad=requires_grad)

    def fill_tan(self, tan, copy=False, requires_grad=None):
        """Fill each sub-manifold's tan with the matching element of ``tan``.

        Pairing is positional via ``zip`` and stops at the shorter sequence.
        """
        for manifold, elem in zip(self.__manifolds, tan):
            manifold.fill_tan(elem, copy=copy, requires_grad=requires_grad)

    def fill_cotan(self, cotan, copy=False, requires_grad=None):
        """Fill each sub-manifold's cotan with the matching element of ``cotan``.

        Pairing is positional via ``zip`` and stops at the shorter sequence.
        """
        for manifold, elem in zip(self.__manifolds, cotan):
            manifold.fill_cotan(elem, copy=copy, requires_grad=requires_grad)

    def fill_gd_randn(self, requires_grad=True):
        """Call ``fill_gd_randn()`` on every sub-manifold."""
        for manifold in self.__manifolds:
            manifold.fill_gd_randn(requires_grad=requires_grad)

    def fill_tan_randn(self, requires_grad=True):
        """Call ``fill_tan_randn()`` on every sub-manifold."""
        for manifold in self.__manifolds:
            manifold.fill_tan_randn(requires_grad=requires_grad)

    def fill_cotan_randn(self, requires_grad=True):
        """Call ``fill_cotan_randn()`` on every sub-manifold."""
        for manifold in self.__manifolds:
            manifold.fill_cotan_randn(requires_grad=requires_grad)

    # Reading returns a list with one entry per sub-manifold; assigning
    # goes through the matching fill_*() with its default keyword arguments.
    gd = property(__get_gd, fill_gd)
    tan = property(__get_tan, fill_tan)
    cotan = property(__get_cotan, fill_cotan)

    def add_gd(self, gd):
        """Call ``add_gd(gd[i])`` on the i-th sub-manifold, for every i.

        ``gd`` must be indexable with at least one entry per sub-manifold.
        """
        for i, manifold in enumerate(self.__manifolds):
            manifold.add_gd(gd[i])

    def add_tan(self, tan):
        """Call ``add_tan(tan[i])`` on the i-th sub-manifold, for every i.

        ``tan`` must be indexable with at least one entry per sub-manifold.
        """
        for i, manifold in enumerate(self.__manifolds):
            manifold.add_tan(tan[i])

    def add_cotan(self, cotan):
        """Call ``add_cotan(cotan[i])`` on the i-th sub-manifold, for every i.

        ``cotan`` must be indexable with at least one entry per sub-manifold.
        """
        for i, manifold in enumerate(self.__manifolds):
            manifold.add_cotan(cotan[i])

    def negate_gd(self):
        """Call ``negate_gd()`` on every sub-manifold."""
        for m in self.__manifolds:
            m.negate_gd()

    def negate_tan(self):
        """Call ``negate_tan()`` on every sub-manifold."""
        for m in self.__manifolds:
            m.negate_tan()

    def negate_cotan(self):
        """Call ``negate_cotan()`` on every sub-manifold."""
        for m in self.__manifolds:
            m.negate_cotan()

    def cot_to_vs(self, sigma, backend=None):
        """Wrap every sub-manifold's ``cot_to_vs(sigma, backend=backend)`` in a
        ``SumStructuredField``.
        """
        return SumStructuredField(
            [m.cot_to_vs(sigma, backend=backend) for m in self.__manifolds])

    def inner_prod_field(self, field):
        """Sum of ``inner_prod_field(field)`` over all sub-manifolds."""
        return sum(m.inner_prod_field(field) for m in self.__manifolds)

    def infinitesimal_action(self, field):
        """Return a CompoundManifold of each sub-manifold's
        ``infinitesimal_action(field)`` result.
        """
        return CompoundManifold(
            [m.infinitesimal_action(field) for m in self.__manifolds])