# Source code for Utilities.plotting

import numpy as np
import torch
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, FancyArrowPatch, Rectangle, PathPatch
from matplotlib.path import Path

from imodal.Utilities import close_shape

def plot_closed_shape(shape, **kwargs):
    """
    Plot a shape as a closed curve on the current matplotlib axes.

    Parameters
    ----------
    shape :
        Points of the shape. ``close_shape`` is applied to it first, and
        columns 0 and 1 of the result are used as abscissa and ordinate.
    kwargs : dict
        Keyword arguments that get passed to ``plt.plot()``.
    """
    closed = close_shape(shape)
    abscissa = closed[:, 0]
    ordinate = closed[:, 1]
    plt.plot(abscissa, ordinate, **kwargs)
def plot_grid(ax, gridx, gridy, **kwargs):
    """
    Plot grid.

    One curve is drawn per row of the grid, then one curve per column.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes on which the grid will be drawn.
    gridx :
        Abscissa component of the grid that will be drawn (2D, indexed as
        ``gridx[row, col]``).
    gridy :
        Ordinate component of the grid that will be drawn, indexed like ``gridx``.
    kwargs : dict
        Keyword arguments that get passed to the ``ax.plot()`` calls.
    """
    # Curves along the second index: one per row.
    for row in range(gridx.shape[0]):
        ax.plot(gridx[row], gridy[row], **kwargs)

    # Curves along the first index: one per column.
    for col in range(gridx.shape[1]):
        ax.plot(gridx[:, col], gridy[:, col], **kwargs)
def plot_C_arrows(ax, pos, C, R=None, c_index=0, scale=1., **kwargs):
    """
    Plot growth constants as arrows.

    For each position, two double-headed arrows are drawn: one along the first
    axis of the local frame with length ``C[i, 0, c_index]*scale``, and one
    along the second axis with length ``C[i, 1, c_index]*scale``. Both arrows
    are centred on ``pos[i]``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes on which the arrows will be drawn.
    pos : torch.Tensor
        Positions of the growth constants arrows.
    C : torch.Tensor
        Growth constants, indexed as ``C[i, dim, c_index]``.
    R : torch.Tensor, default=None
        Local frame (2x2 rotation matrix) of each position. If None, the
        identity is assumed.
    c_index : int, default=0
        The dimension of the growth constants that will get drawn.
    scale : float, default=1.
        Scale applied to the arrow lengths.
    kwargs : dict
        Keyword arguments that get passed to the underlying
        ``FancyArrowPatch`` objects.
    """
    for i in range(pos.shape[0]):
        C_i = scale*C[i, :, c_index]
        # Plain floats, so NumPy never has to convert a tensor implicitly
        # (which fails for tensors requiring grad or living on a GPU).
        c_x = float(C_i[0])
        c_y = float(C_i[1])

        # Non-positive constants get a negative head length, presumably to
        # draw inward-pointing heads that distinguish shrinking from growth.
        arrowstyle_x = "<->"
        arrowstyle_y = "<->"
        if c_x <= 0.:
            arrowstyle_x += ",head_length=-0.4"
        if c_y <= 0.:
            arrowstyle_y += ",head_length=-0.4"

        if R is not None:
            rotmat = R[i].detach().cpu().numpy()
        else:
            rotmat = np.eye(2)
        center = pos[i].detach().cpu().numpy()

        # Half-extent vectors expressed in the local frame, rotated into the
        # global frame; each arrow spans center - half to center + half.
        half_x = np.dot(rotmat, np.array([c_x/2., 0.]))
        half_y = np.dot(rotmat, np.array([0., c_y/2.]))

        ax.add_patch(FancyArrowPatch(center - half_x, center + half_x, arrowstyle=arrowstyle_x, **kwargs))
        ax.add_patch(FancyArrowPatch(center - half_y, center + half_y, arrowstyle=arrowstyle_y, **kwargs))
def plot_C_ellipses(ax, pos, C, R=None, c_index=0, scale=1., **kwargs):
    """
    Plot growth constants as ellipses.

    Each ellipse is centred on ``pos[i]``. Its width and height are the
    absolute values of the scaled growth constants, and it is rotated by the
    angle of the local frame ``R[i]``. Its face colour encodes the signs of the
    two constants and overrides any ``facecolor`` passed through ``kwargs``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes on which the ellipses will be drawn.
    pos : torch.Tensor
        Positions of the growth constants ellipses.
    C : torch.Tensor
        Growth constants, indexed as ``C[i, dim, c_index]``.
    R : torch.Tensor, default=None
        Local frame of each position. If None, the identity is assumed.
    c_index : int, default=0
        The dimension of the growth constants that will get drawn.
    scale : float, default=1.
        Scale applied to the ellipses.
    kwargs : dict
        Keyword arguments that get passed to the underlying ``Ellipse`` objects.
    """
    if R is not None:
        # Rotation angle, in degrees, of the first column of each local frame.
        angle = torch.atan2(R[:, 1, 0], R[:, 0, 0])/math.pi*180.
    else:
        angle = torch.zeros(C.shape[0])

    for i in range(pos.shape[0]):
        C_i = scale*C[i, :, c_index]
        # Pass plain floats to matplotlib: implicit tensor -> NumPy conversion
        # fails for tensors that require grad or live on a GPU.
        center = (pos[i, 0].item(), pos[i, 1].item())
        e = Ellipse(xy=center, width=abs(C_i[0].item()), height=abs(C_i[1].item()),
                    angle=angle[i].item(), **kwargs)

        # a (resp. b) is 1 for a positive constant, 0.5 for zero, 0 for a negative one.
        a = 0.5*(1+torch.sign(C_i[0])).item()
        b = 0.5*(1+torch.sign(C_i[1])).item()
        # Colour ranges from (0.5, 0, 0.5) when both constants are negative
        # to pure blue (0, 0, 1) when both are positive.
        e.set_facecolor((0.5-0.25*(a+b), 0, 0.5+0.25*(a+b)))
        ax.add_artist(e)
def set_aspect_equal_3d(ax):
    """
    Fix equal aspect bug for 3D plots. Equivalent to `plt.axis('equal')` in 3D.

    Every axis is re-centred on the midpoint of its current limits and given
    the same half-width: the largest half-width found among the three axes.

    Parameters
    ----------
    ax :
        3D axes exposing ``get_{x,y,z}lim3d`` and ``set_{x,y,z}lim3d``.
    """
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    centers = [np.mean(lim) for lim in limits]

    radius = max(abs(bound - center)
                 for lim, center in zip(limits, centers)
                 for bound in lim)

    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for setter, center in zip(setters, centers):
        setter([center - radius, center + radius])
def plot_aabb(ax, aabb, **kwargs):
    """
    Draw an axis-aligned bounding box as an unfilled rectangle.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes on which the box will be drawn.
    aabb :
        Bounding box. Its ``xmin``, ``ymin``, ``width`` and ``height``
        attributes are read.
    kwargs : dict
        Keyword arguments that get passed to the ``Rectangle`` patch.
    """
    # Rectangle is anchored at its lower-left corner given as (x, y), so the
    # anchor is (xmin, ymin); the former (ymin, xmin) order swapped it.
    # NOTE(review): assumes aabb.width is the x-extent and aabb.height the
    # y-extent, as the attribute names suggest — confirm against AABB.
    ax.add_artist(Rectangle((aabb.xmin, aabb.ymin), aabb.width, aabb.height, fill=False, **kwargs))
def plot_polyline(ax, polyline, close=False, **kwargs):
    """
    Draw a polyline as a ``PathPatch``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes on which the polyline will be drawn.
    polyline :
        Ordered vertices, as an object with ``tolist()`` (e.g. a tensor or an
        array; presumably of shape (N, 2) — TODO confirm).
    close : bool, default=False
        If True, the path is closed back to its first vertex.
    kwargs : dict
        Keyword arguments that get passed to ``PathPatch``.
    """
    vertices = polyline.tolist()
    codes = [Path.MOVETO] + [Path.LINETO]*(len(polyline) - 1)

    if close:
        # CLOSEPOLY still needs a vertex entry; matplotlib ignores its value.
        vertices.append((0., 0.))
        codes.append(Path.CLOSEPOLY)

    ax.add_artist(PathPatch(Path(vertices, codes), **kwargs))
def plot_basis3d(points, basis, length=0.1, **kwargs):
    """
    Draw a local 3D frame at each point as three coloured arrows.

    Arrow ``k`` at point ``i`` has components ``basis[i, :, k]``. The first,
    second and third basis vectors are drawn in red, green and blue
    respectively, with ``plt.quiver`` on the current axes.

    Parameters
    ----------
    points : torch.Tensor
        Arrow origins; columns 0, 1 and 2 are the x, y and z coordinates.
    basis : torch.Tensor
        Frames, indexed as ``basis[i, component, vector]``.
    length : float, default=0.1
        Length passed to ``plt.quiver``.
    kwargs : dict
        Keyword arguments that get passed to each ``plt.quiver()`` call.
    """
    origin_x = points[:, 0].numpy()
    origin_y = points[:, 1].numpy()
    origin_z = points[:, 2].numpy()

    for vector, color in enumerate(('red', 'green', 'blue')):
        plt.quiver(origin_x, origin_y, origin_z,
                   basis[:, 0, vector].numpy(),
                   basis[:, 1, vector].numpy(),
                   basis[:, 2, vector].numpy(),
                   length=length, color=color, **kwargs)