Analyzing Differences Between Tree Images

Image registration using an implicit module of order 1. The segmentations provided with the data are used to initialize the module's points.

Initialization

Import relevant Python modules.

import time
import pickle
import sys
sys.path.append("../")

import matplotlib.pyplot as plt
import torch

import imodal


# Computation device. The example was written for the second GPU ('cuda:1');
# fall back to the first GPU, then to the CPU, so the script still runs on
# machines with fewer GPUs instead of failing on an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
torch.set_default_dtype(torch.float64)
imodal.Utilities.set_compute_backend('keops')

Load source and target images, along with the source curve.

# Load the example dataset. pickle executes arbitrary code on load, so this
# file must come from a trusted source (it ships with the example).
with open("../data/tree_growth.pickle", 'rb') as f:
    data = pickle.load(f)

# Cast the tensors to the default floating-point dtype configured above.
source_shape, source_image, target_image = (
    data[key].to(torch.get_default_dtype())
    for key in ('source_shape', 'source_image', 'target_image')
)

# Segmentations as Axis Aligned Bounding Boxes (AABB)
aabb_trunk, aabb_crown, extent = data['aabb_trunk'], data['aabb_leaves'], data['extent']

Display source and target images, along with the segmented source curve (orange for the trunk, green for the crown).

# Masks selecting the source-curve points inside each segmentation box.
shape_is_trunk = aabb_trunk.is_inside(source_shape)
shape_is_crown = aabb_crown.is_inside(source_shape)

# Left panel: source image with the trunk (orange) and crown (green) curve parts.
plt.subplot(1, 2, 1)
plt.title("Source")
plt.imshow(source_image, cmap='gray', origin='lower', extent=extent.totuple())
for part_mask, part_color in ((shape_is_trunk, 'orange'), (shape_is_crown, 'green')):
    plt.plot(source_shape[part_mask, 0].numpy(), source_shape[part_mask, 1].numpy(), lw=2., color=part_color)
plt.axis('off')

# Right panel: target image alone.
plt.subplot(1, 2, 2)
plt.title("Target")
plt.imshow(target_image, cmap='gray', origin='lower', extent=extent.totuple())
plt.axis('off')
plt.show()
Source, Target

Generate the points of the implicit module of order 1 and its growth model tensor.

implicit1_density = 500.


def area(x, **kwargs):
    """Return the mask of points lying in or around the tree shape.

    Union of imodal's shape-interior test and polyline-outline test; both
    receive the same keyword arguments (``shape``, ``polyline``, ``width``
    in the call below).
    """
    return imodal.Utilities.area_shape(x, **kwargs) | imodal.Utilities.area_polyline_outline(x, **kwargs)


polyline_width = 0.07

# Generation of the points of the initial geometrical descriptor
implicit1_points = imodal.Utilities.fill_area_uniform_density(area, imodal.Utilities.AABB(xmin=0., xmax=1., ymin=0., ymax=1.), implicit1_density, shape=source_shape, polyline=source_shape, width=polyline_width)

# Masks that flag points into either the trunk or the crown
implicit1_trunk_points = aabb_trunk.is_inside(implicit1_points)
implicit1_crown_points = aabb_crown.is_inside(implicit1_points)

# Keep only the points belonging to one of the segments. The masks are
# filtered alongside the points rather than being recomputed.
implicit1_keep = implicit1_trunk_points | implicit1_crown_points
implicit1_points = implicit1_points[implicit1_keep]
implicit1_trunk_points = implicit1_trunk_points[implicit1_keep]
implicit1_crown_points = implicit1_crown_points[implicit1_keep]

# Every kept point must belong to exactly one segment; otherwise the growth
# tensor below would give it both trunk and crown components. An explicit
# check is used instead of `assert`, which is stripped under `python -O`.
implicit1_overlap = int((implicit1_trunk_points & implicit1_crown_points).sum())
if implicit1_overlap:
    raise ValueError("{} implicit module points lie in both the trunk and crown bounding boxes".format(implicit1_overlap))

# Initial normal frames
implicit1_r = torch.eye(2).repeat(implicit1_points.shape[0], 1, 1)

# Growth model tensor
implicit1_c = torch.zeros(implicit1_points.shape[0], 2, 4)

# Horizontal stretching for the trunk
implicit1_c[implicit1_trunk_points, 0, 0] = 1.
# Vertical stretching for the trunk
implicit1_c[implicit1_trunk_points, 1, 1] = 1.
# Horizontal stretching for the crown
implicit1_c[implicit1_crown_points, 0, 2] = 1.
# Vertical stretching for the crown
implicit1_c[implicit1_crown_points, 1, 3] = 1.

Plot each of the four components of the growth model tensor.

# One panel per growth component (last axis of implicit1_c), each showing the
# corresponding ellipses over the source image.
n_growth_components = implicit1_c.shape[-1]
plt.figure(figsize=[20., 5.])
for i in range(n_growth_components):
    ax = plt.subplot(1, n_growth_components, i + 1)
    # Use the extent tuple, as every other imshow call in this example does.
    plt.imshow(source_image, origin='lower', extent=extent.totuple(), cmap='gray')
    imodal.Utilities.plot_C_ellipses(ax, implicit1_points, implicit1_c, c_index=i, color='blue', scale=0.03)
    plt.xlim(0., 1.)
    plt.ylim(0., 1.)
    plt.axis('off')

plt.show()
plot tree growth

Create the deformation model as a combination of two modules: a global translation and the implicit module of order 1.

Create and initialize the global translation module global_translation.

# 2D global translation module and the coefficient weighting its cost.
global_translation_coeff = 1.0
global_translation = imodal.DeformationModules.GlobalTranslation(
    2,
    coeff=global_translation_coeff,
)

Create and initialize the implicit module of order 1 implicit1.

# sigma1 shrinks with the square root of the point density (presumably the
# kernel scale of the module — confirm against imodal's ImplicitModule1).
sigma1 = 2. / implicit1_density ** 0.5
implicit1_coeff = 0.1
implicit1_nu = 100.
implicit1 = imodal.DeformationModules.ImplicitModule1(
    2, implicit1_points.shape[0], sigma1, implicit1_c,
    nu=implicit1_nu,
    gd=(implicit1_points, implicit1_r),
    coeff=implicit1_coeff,
)
implicit1.eps = 1e-2

Define deformables used by the registration model.

# Wrap both images as bitmap deformables over the same extent, then move them
# to the computation device.
source_image_deformable, target_image_deformable = (
    imodal.Models.DeformableImage(bitmap, output='bitmap', extent=extent)
    for bitmap in (source_image, target_image)
)

for image_deformable in (source_image_deformable, target_image_deformable):
    image_deformable.to_device(device)

Registration

Define the registration model.

# L2 image attachment; the deformation uses the implicit module plus the
# global translation.
attachment_image = imodal.Attachment.L2NormAttachment(weight=1.)
deformation_modules = [implicit1, global_translation]

model = imodal.Models.RegistrationModel(
    [source_image_deformable], deformation_modules, [attachment_image], lam=1.)
model.to_device(device)

Fit the model using the Torch L-BFGS optimizer.

# Time integration used when shooting the deformation.
shoot_solver = 'euler'
shoot_it = 10

# Options forwarded to the fitter (solver settings plus L-BFGS parameters).
fit_options = {
    'shoot_solver': shoot_solver,
    'shoot_it': shoot_it,
    'line_search_fn': 'strong_wolfe',
    'history_size': 500,
}

costs = {}
fitter = imodal.Models.Fitter(model, optimizer='torch_lbfgs')
fitter.fit([target_image_deformable], 500, costs=costs, options=fit_options)

Out:

Starting optimization with method torch LBFGS, using solver euler with 10 iterations.
Initial cost={'deformation': 0.0, 'attach': 76225141.99999996}
1e-10
Evaluated model with costs=76225141.99999996
Evaluated model with costs=171572063.7711776
Evaluated model with costs=139392999.86042202
Evaluated model with costs=78050686.03371653
Evaluated model with costs=73566406.0385577
Evaluated model with costs=71212065.45047338
Evaluated model with costs=73142005.10259366
Evaluated model with costs=68493909.30452502
Evaluated model with costs=63894687.42130799
Evaluated model with costs=54071234.86216366
Evaluated model with costs=57372934.35765917
Evaluated model with costs=51772472.06776254
Evaluated model with costs=51614666.44162917
Evaluated model with costs=51631760.30806436
Evaluated model with costs=51596060.578789294
Evaluated model with costs=51541810.05689387
Evaluated model with costs=51636113.499598734
Evaluated model with costs=51508983.80014829
Evaluated model with costs=51528232.72499685
Evaluated model with costs=51484278.55399751
Evaluated model with costs=51482157.516750224
Evaluated model with costs=51470192.29707427
Evaluated model with costs=51463974.3567045
Evaluated model with costs=51461508.19759631
Evaluated model with costs=51643472.872851625
Evaluated model with costs=51431312.27498791
Evaluated model with costs=51400410.929490514
================================================================================
Time: 285.3828136604279
Iteration: 0
Costs
deformation=0.0014938654113724153
attach=51400410.92799665
Total cost=51400410.929490514
1e-10
Evaluated model with costs=51400410.929490514
Evaluated model with costs=51372153.25674013
Evaluated model with costs=51243307.67112245
Evaluated model with costs=50857145.378125444
Evaluated model with costs=49862989.73707366
Evaluated model with costs=47105362.782228194
Evaluated model with costs=44072906.09299316
Evaluated model with costs=42682094.04029929
Evaluated model with costs=41343749.08348512
Evaluated model with costs=40948648.94783037
Evaluated model with costs=42963919.34603904
Evaluated model with costs=40818709.774088524
Evaluated model with costs=40650731.54944925
Evaluated model with costs=40640567.94245269
Evaluated model with costs=40635001.743004285
Evaluated model with costs=40594452.644408636
Evaluated model with costs=40495298.68054157
Evaluated model with costs=40476382.02194792
Evaluated model with costs=40468444.216567166
Evaluated model with costs=40483691.63419058
Evaluated model with costs=40466532.498703115
Evaluated model with costs=40462907.40166026
Evaluated model with costs=40460895.28739086
Evaluated model with costs=40471881.66002083
Evaluated model with costs=40459384.66148927
================================================================================
Time: 554.9204118531197
Iteration: 1
Costs
deformation=0.0031694979136689615
attach=40459384.65831977
Total cost=40459384.66148927
1e-10
Evaluated model with costs=40459384.66148927
Evaluated model with costs=40457936.19874938
Evaluated model with costs=40457094.23267533
Evaluated model with costs=40453377.51434859
Evaluated model with costs=40446074.422715515
Evaluated model with costs=40439276.149429545
Evaluated model with costs=40421923.642156616
Evaluated model with costs=40438744.49859901
Evaluated model with costs=40412731.219843626
Evaluated model with costs=40381014.7191554
Evaluated model with costs=40334598.869233325
Evaluated model with costs=40306579.21563403
Evaluated model with costs=40289297.91425672
Evaluated model with costs=40356189.12713901
Evaluated model with costs=40287311.02751078
Evaluated model with costs=40255513.244219854
Evaluated model with costs=40280673.59499148
Evaluated model with costs=40251691.412622325
Evaluated model with costs=40246971.27636479
Evaluated model with costs=40254034.73996697
Evaluated model with costs=40246810.98547113
Evaluated model with costs=40244712.03575505
Evaluated model with costs=40239748.486258365
Evaluated model with costs=40239143.19981024
Evaluated model with costs=40239609.1093603
Evaluated model with costs=40231177.76497096
================================================================================
Time: 834.8457618234679
Iteration: 2
Costs
deformation=0.0029585594626364542
attach=40231177.7620124
Total cost=40231177.76497096
1e-10
Evaluated model with costs=40231177.76497096
Evaluated model with costs=41292855.98371006
Evaluated model with costs=40234181.71671496
Evaluated model with costs=40232155.75473622
Evaluated model with costs=40231102.49086602
Evaluated model with costs=40231028.07684453
Evaluated model with costs=40230983.97293015
Evaluated model with costs=40230920.237769105
Evaluated model with costs=40230886.98778354
Evaluated model with costs=40230836.7508002
Evaluated model with costs=40230813.458332665
Evaluated model with costs=40230769.37910962
Evaluated model with costs=40230754.77452757
Evaluated model with costs=40230719.95041699
Evaluated model with costs=40241850.594302274
Evaluated model with costs=40231602.50590897
Evaluated model with costs=40231788.30883913
Evaluated model with costs=40232266.59302029
Evaluated model with costs=40232316.556837305
Evaluated model with costs=40230721.28523559
================================================================================
Time: 1052.240169564262
Iteration: 3
Costs
deformation=0.0029572417620960805
attach=40230721.28227835
Total cost=40230721.28523559
1e-10
Evaluated model with costs=40230719.95041699
Evaluated model with costs=40241850.594302274
Evaluated model with costs=40231602.50590897
Evaluated model with costs=40231788.30883913
Evaluated model with costs=40232266.59302029
Evaluated model with costs=40232316.556837305
Evaluated model with costs=40230721.28523559
================================================================================
Time: 1128.331251487136
Iteration: 4
Costs
deformation=0.0029572417620960805
attach=40230721.28227835
Total cost=40230721.28523559
================================================================================
Optimisation process exited with message: Convergence achieved.
Final cost=40230721.28523559
Model evaluation count=105
Time elapsed = 1128.3314475277439

Visualization

Compute the optimized deformation trajectory.

# Shoot the fitted model once more without gradient tracking; the recorded
# intermediates (states and controls) are reused by the helpers below.
deformed_intermediates = {}
tic = time.perf_counter()
with torch.autograd.no_grad():
    deformed_image = model.compute_deformed(shoot_solver, shoot_it, intermediates=deformed_intermediates)[0][0]
    deformed_image = deformed_image.detach().cpu()
print("Elapsed={elapsed}".format(elapsed=time.perf_counter() - tic))

Out:

Elapsed=2.6374162305146456

Display the source image, the deformed source image and the target image.

# Side-by-side comparison: source, deformed source and target.
plt.figure(figsize=[15., 5.])
result_panels = (("Source", source_image), ("Deformed", deformed_image), ("Target", target_image))
for panel_position, (panel_title, panel_image) in enumerate(result_panels, start=1):
    plt.subplot(1, 3, panel_position)
    plt.title(panel_title)
    plt.imshow(panel_image, extent=extent.totuple(), origin='lower')
    plt.axis('off')

plt.show()
Source, Deformed, Target

We can follow the action of each part of the total deformation by setting all control components to zero except one.

Functions that generate controls isolating one part of the deformation.

def generate_implicit1_controls(table):
    """Mask the implicit module controls of every recorded time step.

    Parameters
    ----------
    table : sequence of bool
        One flag per control component of the implicit module (presumably one
        per column of the growth tensor ``implicit1_c``). Components flagged
        False are zeroed.

    Returns
    -------
    list of torch.Tensor
        The implicit module control (``control[1]``) of each entry of
        ``deformed_intermediates['controls']``, multiplied by the mask.
    """
    # Build the mask once on the target device instead of once per time step
    # (each construction is a host-to-device copy).
    mask = torch.tensor(table, dtype=torch.get_default_dtype(), device=device)
    return [control[1] * mask for control in deformed_intermediates['controls']]


def generate_controls(implicit1_table, trans):
    """Build per-time-step controls for the implicit module and the translation.

    Parameters
    ----------
    implicit1_table : sequence of bool
        Implicit module control components to keep (see
        ``generate_implicit1_controls``).
    trans : bool
        Whether to keep the global translation control (``control[2]``); it
        is multiplied by 1. or 0.

    Returns
    -------
    list of list of torch.Tensor
        For each recorded time step, ``[implicit1_control, translation_control]``.
    """
    implicit1_controls = generate_implicit1_controls(implicit1_table)
    # Scalar scale built once, not once per time step.
    trans_scale = torch.tensor(trans, dtype=torch.get_default_dtype(), device=device)
    return [[implicit1_control, control[2] * trans_scale]
            for control, implicit1_control in zip(deformed_intermediates['controls'], implicit1_controls)]

Function to compute a deformation given a set of controls up to some time point.

# Number of grid lines along each axis of the deformed visualization grid.
grid_resolution = [16] * 2


def compute_intermediate_deformed(it, controls, t1, intermediates=None):
    """Re-shoot the source image and a grid from the optimized initial state.

    Fresh modules are rebuilt from the first recorded state of
    ``deformed_intermediates``, then driven by ``controls`` for ``it`` steps
    up to time ``t1``.

    Parameters
    ----------
    it : int
        Number of integration steps.
    controls : list
        Per-step controls, as produced by ``generate_controls``.
    t1 : float or 0-d tensor
        Final integration time.
    intermediates : dict, optional
        Filled with the intermediate states when given.

    Returns
    -------
    The first output of the first deformable (the deformed source image).
    """
    initial_state = deformed_intermediates['states'][0]
    silent_state = initial_state[0]
    implicit1_state = initial_state[1]

    gd_points, gd_frames = implicit1_state.gd[0], implicit1_state.gd[1]
    cotan_points, cotan_frames = implicit1_state.cotan[0], implicit1_state.cotan[1]

    # Geometric descriptors and the growth tensor are cloned so the recorded
    # state is left untouched; the cotangents are shared as before.
    local_implicit1 = imodal.DeformationModules.ImplicitModule1(
        2, gd_points.shape[0], sigma1, implicit1_c.clone(),
        nu=implicit1_nu,
        gd=(gd_points.clone(), gd_frames.clone()),
        cotan=(cotan_points, cotan_frames),
        coeff=implicit1_coeff)
    local_translation = imodal.DeformationModules.GlobalTranslation(2, coeff=global_translation_coeff)

    for module in (local_implicit1, local_translation):
        module.to_(device=device)

    source_deformable = imodal.Models.DeformableImage(source_image, output='bitmap', extent=extent)
    source_deformable.silent_module.manifold.cotan = silent_state.cotan

    grid_deformable = imodal.Models.DeformableGrid(extent, grid_resolution)

    for deformable in (source_deformable, grid_deformable):
        deformable.to_device(device)

    with torch.autograd.no_grad():
        deformed = imodal.Models.deformables_compute_deformed(
            [source_deformable, grid_deformable], [local_implicit1, local_translation],
            shoot_solver, it, controls=controls, t1=t1,
            intermediates=intermediates, costs={})

    return deformed[0][0]

Function that computes and plots the deformation trajectory for a given set of controls.

def generate_images(table, trans, outputfilename):
    """Replay the optimized deformation with masked controls and plot snapshots.

    Parameters
    ----------
    table : sequence of bool
        Implicit module control components to keep (see
        ``generate_implicit1_controls``).
    trans : bool
        Whether to keep the global translation controls.
    outputfilename : str
        Currently unused; kept for interface compatibility.
    """
    incontrols = generate_controls(table, trans)

    # Full replay, used only to record the deformed grid at every step. It uses
    # the same number of steps (shoot_it) that produced the controls.
    intermediates_shape = {}
    compute_intermediate_deformed(shoot_it, incontrols, 1., intermediates=intermediates_shape)
    trajectory_grid = [imodal.Utilities.vec2grid(state[1].gd, grid_resolution[0], grid_resolution[1])
                       for state in intermediates_shape['states']]

    # Snapshot steps at t = 0, ~0.3, ~0.7 and 1 (steps 0, 3, 7, 10 when shoot_it == 10).
    t = torch.linspace(0., 1., shoot_it + 1)
    indices = [0, round(0.3 * shoot_it), round(0.7 * shoot_it), shoot_it]

    trajectory = [source_image]
    print("Computing trajectories...")
    for index in indices[1:]:
        print("{}, t={}".format(index, t[index]))
        # NOTE(review): the 4*index slice looks sized for an RK4 solver (four
        # control evaluations per step); with 'euler' it over-provides
        # controls. Kept unchanged — confirm against imodal's shooting code.
        trajectory.append(compute_intermediate_deformed(index, incontrols[:4*index], t[index]))

    print("Generating images...")
    plt.figure(figsize=[5.*len(indices), 5.])
    for i, (deformed, index) in enumerate(zip(trajectory, indices)):
        ax = plt.subplot(1, len(indices), i + 1)

        grid = trajectory_grid[index]
        plt.imshow(deformed.cpu(), origin='lower', extent=extent.totuple(), cmap='gray')
        imodal.Utilities.plot_grid(ax, grid[0].cpu(), grid[1].cpu(), color='xkcd:light blue', lw=1)
        plt.xlim(0., 1.)
        plt.ylim(0., 1.)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

Generate the trajectory of the total optimized deformation.

# Full deformation: every implicit component plus the global translation.
generate_images(table=[True] * 4, trans=True, outputfilename="deformed_all")
plot tree growth

Out:

Computing trajectories...
3, t=0.30000000000000004
7, t=0.7000000000000001
10, t=1.0
Generating images...

Generate trajectory following vertical elongation of the trunk.

# Only component 1 (vertical stretching of the trunk); translation disabled.
generate_images(table=[False, True, False, False], trans=False, outputfilename="deformed_trunk_vertical")
plot tree growth

Out:

Computing trajectories...
3, t=0.30000000000000004
7, t=0.7000000000000001
10, t=1.0
Generating images...

Generate trajectory following horizontal elongation of the crown.

# Only component 2 (horizontal stretching of the crown); translation disabled.
generate_images(table=[False, False, True, False], trans=False, outputfilename="deformed_crown_horizontal")
plot tree growth

Out:

Computing trajectories...
3, t=0.30000000000000004
7, t=0.7000000000000001
10, t=1.0
Generating images...

Total running time of the script: ( 19 minutes 16.473 seconds)

Gallery generated by Sphinx-Gallery