Fitting an image and points

In this example, we fit a cross onto the same cross, rotated. Since we know that the two crosses differ by a rotation, we take advantage of this knowledge and use a rotation deformation module. We also add some noise to the initial guess of the rotation center to show how the geometrical descriptors can be fitted. In addition to the images, we place points at the extremity of each branch of the cross. These points are matched as well, which illustrates matching with multiple deformables, and they help the fitting process by giving the model more information.

Import relevant modules.

import sys
sys.path.append("../")

import math

import torch
import matplotlib.pyplot as plt
import scipy.ndimage

import imodal

imodal.Utilities.set_compute_backend('torch')

Load image data and generate dots.

source_image = imodal.Utilities.load_greyscale_image("../data/images/cross_+_30.png", origin='lower')
target_image = imodal.Utilities.load_greyscale_image("../data/images/cross_+.png", origin='lower')

# Optional Gaussian smoothing (a standard deviation of 0. leaves the images unchanged)
sig_smooth = 0.
source_image = torch.tensor(scipy.ndimage.gaussian_filter(source_image, sig_smooth))
target_image = torch.tensor(scipy.ndimage.gaussian_filter(target_image, sig_smooth))

extent_length = 31.
extent = imodal.Utilities.AABB(0., extent_length, 0., extent_length)

dots = torch.tensor([[0., 0.5],
                     [0.5, 0.],
                     [0., -0.5],
                     [-0.5, 0.]])

# Dots at the tips of the cross branches: rotated, scaled and centered on the image
source_dots = 0.6*extent_length*imodal.Utilities.linear_transform(dots, imodal.Utilities.rot2d(math.pi/3)) + extent_length*torch.tensor([0.5, 0.5])

target_dots = 0.6*extent_length*imodal.Utilities.linear_transform(dots, imodal.Utilities.rot2d(math.pi)) + extent_length*torch.tensor([0.5, 0.5])

# Deliberately imprecise initial guess for the rotation center
center = extent_length*torch.tensor([[0.3, 0.1]])
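The dot coordinates above come from a unit plus shape that is rotated, scaled, and moved to the image center. The sketch below reproduces that arithmetic in plain PyTorch so the geometry can be checked on its own. The helpers `rotation_matrix` and `rotate_rows` are hypothetical stand-ins for `imodal.Utilities.rot2d` and `imodal.Utilities.linear_transform`, assuming those apply a counter-clockwise 2D rotation to each row of the point tensor.

```python
import math
import torch

def rotation_matrix(theta):
    # Hypothetical stand-in for imodal.Utilities.rot2d: counter-clockwise 2x2 rotation
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]])

def rotate_rows(points, matrix):
    # Hypothetical stand-in for imodal.Utilities.linear_transform: p_i -> matrix @ p_i for each row
    return points @ matrix.T

extent_length = 31.
dots = torch.tensor([[0., 0.5], [0.5, 0.], [0., -0.5], [-0.5, 0.]])
image_center = extent_length * torch.tensor([0.5, 0.5])

source_dots = 0.6 * extent_length * rotate_rows(dots, rotation_matrix(math.pi / 3)) + image_center
target_dots = 0.6 * extent_length * rotate_rows(dots, rotation_matrix(math.pi)) + image_center

# Every dot lies 0.6 * 0.5 * extent_length = 9.3 units away from the image center
print((source_dots - image_center).norm(dim=1))
print((target_dots - image_center).norm(dim=1))
```

A rotation by pi maps the plus shape onto itself, so `target_dots` occupies the same four positions as the unrotated plus, only in a different order. Assuming the point attachment pairs dots by index, that order decides which source dot is matched to which target dot.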

Plot everything.

plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower', extent=extent.totuple())
plt.plot(source_dots.numpy()[:, 0], source_dots.numpy()[:, 1], '.')
plt.plot(center.numpy()[:, 0], center.numpy()[:, 1], '.')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(target_image, origin='lower', extent=extent.totuple())
plt.plot(target_dots.numpy()[:, 0], target_dots.numpy()[:, 1], '.')

plt.show()

Figure: Source image, Target image

We know that the target cross results from a rotation of the source cross about its center, so we use a local rotation deformation module, initialized with an imprecise center position to simulate data acquisition noise.
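To build intuition for what a local rotation module does, the sketch below evaluates an infinitesimal rotation about a center c, weighted by a Gaussian of width sigma: v(x) = theta * exp(-|x - c|^2 / (2 sigma^2)) * J(x - c), where J rotates vectors by 90 degrees. This kernel is an assumption chosen for illustration and is not necessarily the formula implemented by `imodal.DeformationModules.LocalRotation`. It shows two useful properties: the field vanishes at the center and is tangent to circles around it, and with `sigma = 2 * extent_length` the weight barely decays across the image, so the motion is close to a rigid rotation.

```python
import math
import torch

def local_rotation_field(x, center, theta, sigma):
    # Hypothetical Gaussian-weighted rotation velocity field, for illustration only
    d = x - center                                   # offsets from the rotation center, shape (N, 2)
    weight = torch.exp(-(d ** 2).sum(dim=1, keepdim=True) / (2 * sigma ** 2))
    j_d = torch.stack([-d[:, 1], d[:, 0]], dim=1)    # offsets rotated by +90 degrees
    return theta * weight * j_d

extent_length = 31.
center = torch.tensor([[15.5, 15.5]])
sigma = 2. * extent_length

# Evaluate the field at the center and at the four image corners
points = torch.tensor([[15.5, 15.5], [0., 0.], [31., 0.], [0., 31.], [31., 31.]])
v = local_rotation_field(points, center, theta=1., sigma=sigma)

print(v[0])                                      # zero velocity at the center
print(((points - center) * v).sum(dim=1))        # no radial component anywhere
corner_weight = math.exp(-(2 * 15.5 ** 2) / (2 * sigma ** 2))
print(corner_weight)                             # ≈ 0.94, so the motion is nearly rigid
```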

rotation = imodal.DeformationModules.LocalRotation(2, 2.*extent_length, gd=center)

Create the model, setting fit_gd=[True] so that the rotation center (the geometrical descriptor of the rotation module) is optimized as well.

source_deformable = imodal.Models.DeformableImage(source_image, output='bitmap', extent='match', backward=True)
target_deformable = imodal.Models.DeformableImage(target_image, output='bitmap', extent='match', backward=True)

source_dots_deformable = imodal.Models.DeformablePoints(source_dots)
target_dots_deformable = imodal.Models.DeformablePoints(target_dots)

attachment = imodal.Attachment.L2NormAttachment(transform=None)

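# One attachment per deformable: L2 norm for the image, pointwise Euclidean distance for the dots
model = imodal.Models.RegistrationModel(
    [source_deformable, source_dots_deformable],
    [rotation],
    [attachment, imodal.Attachment.EuclideanPointwiseDistanceAttachment()],
    fit_gd=[True],  # also optimize the rotation center
    lam=1000.)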
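The model combines the deformation cost with one attachment term per deformable. A minimal sketch of that bookkeeping follows, assuming that the image attachment is a sum of squared pixel differences, that the point attachment is a sum of squared distances between dots with the same index, and that the attachments are weighted by `lam` before being added to the deformation cost. These helper functions are hypothetical and are not the imodal implementations. The last line only reuses numbers printed in the fitting log, where the total cost is the sum of the deformation and attach terms.

```python
import torch

def l2_image_attachment(image, target):
    # Hypothetical: sum of squared pixel differences between two bitmaps
    return ((image - target) ** 2).sum()

def pointwise_attachment(points, target_points):
    # Hypothetical: squared Euclidean distances between dots paired by index
    return ((points - target_points) ** 2).sum()

def total_cost(deformation, attachments, lam):
    # Attachments are summed, weighted by lam, then added to the deformation cost
    return deformation + lam * sum(attachments)

image, target = torch.zeros(4, 4), torch.ones(4, 4)
points = torch.tensor([[0., 0.], [1., 1.]])
target_points = torch.tensor([[0., 1.], [1., 1.]])

attachments = [l2_image_attachment(image, target), pointwise_attachment(points, target_points)]
total = total_cost(torch.tensor(2.), attachments, lam=1000.)
print(total)  # tensor(17002.)

# Iteration 0 of the fitting log: deformation + attach equals the reported total cost
print(1121.9609375 + 57162.8984375)  # 58284.859375
```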

Fit the model.

shoot_solver = 'rk4'
shoot_it = 10
max_it = 100

costs = {}
fitter = imodal.Models.Fitter(model, optimizer='torch_lbfgs')

fitter.fit([target_deformable, target_dots_deformable], max_it, costs=costs, options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})
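The fitter uses PyTorch's L-BFGS with a strong Wolfe line search, and since `fit_gd=[True]`, the rotation center is optimized together with the rotation itself. The toy problem below mimics this on points alone: it fits a rigid rotation angle and its center with `torch.optim.LBFGS(..., line_search_fn='strong_wolfe')`, starting from the same deliberately wrong center guess (9.3, 3.1). It is a sketch under simplifying assumptions (rigid rotation, no images, no deformation cost) and not what `imodal.Models.Fitter` does internally.

```python
import math
import torch

def rotate_about(points, theta, center):
    # Rigid rotation of each point by angle theta about the given center
    c, s = torch.cos(theta), torch.sin(theta)
    rotation = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
    return (points - center) @ rotation.T + center

dtype = torch.float64
true_center = torch.tensor([15.5, 15.5], dtype=dtype)
true_theta = torch.tensor(math.pi / 6, dtype=dtype)

# Four dots 9.3 units from the true center, and their rotated targets
offsets = 9.3 * torch.tensor([[0., 1.], [1., 0.], [0., -1.], [-1., 0.]], dtype=dtype)
source = true_center + offsets
target = rotate_about(source, true_theta, true_center)

# Parameters to fit: the rotation angle and an imprecise initial center guess
theta_estimate = torch.zeros((), dtype=dtype, requires_grad=True)
center_estimate = torch.tensor([9.3, 3.1], dtype=dtype, requires_grad=True)

optimizer = torch.optim.LBFGS([theta_estimate, center_estimate],
                              max_iter=200, line_search_fn='strong_wolfe')

def closure():
    # Sum of squared distances between the rotated source dots and the target dots
    optimizer.zero_grad()
    loss = ((rotate_about(source, theta_estimate, center_estimate) - target) ** 2).sum()
    loss.backward()
    return loss

optimizer.step(closure)

# Should end up close to pi/6 and [15.5, 15.5]
print(theta_estimate.item(), center_estimate.tolist())
```

The center only becomes identifiable once the angle is non-zero: at theta = 0 the loss does not depend on the center at all, which is why the initial rotation estimate matters as much as the center guess.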

Out:

Starting optimization with method torch LBFGS, using solver rk4 with 10 iterations.
Initial cost={'deformation': 0.0, 'attach': 262755.8125}
1e-10
Evaluated model with costs=262755.8125
Evaluated model with costs=261072.8893268155
Evaluated model with costs=260262.20402636006
Evaluated model with costs=244300.85577321053
Evaluated model with costs=202114.67821502686
Evaluated model with costs=216731.24096679688
Evaluated model with costs=195680.00700378418
Evaluated model with costs=195489.87872314453
Evaluated model with costs=194501.6148071289
Evaluated model with costs=192203.5576019287
Evaluated model with costs=169738.61282348633
Evaluated model with costs=175044.40286254883
Evaluated model with costs=163547.61013793945
Evaluated model with costs=229734.60982704163
Evaluated model with costs=161746.28846740723
Evaluated model with costs=141094.7487487793
Evaluated model with costs=108456.06268310547
Evaluated model with costs=142422.21313476562
Evaluated model with costs=62099.16650390625
Evaluated model with costs=132232.1014404297
Evaluated model with costs=58461.879150390625
Evaluated model with costs=58920.74267578125
Evaluated model with costs=58320.41455078125
Evaluated model with costs=58507.76745605469
Evaluated model with costs=58284.859375
================================================================================
Time: 61.855171380000684
Iteration: 0
Costs
deformation=1121.9609375
attach=57162.8984375
Total cost=58284.859375
1e-10
Evaluated model with costs=58284.859375
Evaluated model with costs=58284.33435058594
Evaluated model with costs=58284.33410644531
Evaluated model with costs=58284.33312988281
Evaluated model with costs=58284.32946777344
Evaluated model with costs=58284.32556152344
Evaluated model with costs=58284.32971191406
Evaluated model with costs=58284.32556152344
================================================================================
Time: 80.87274068099941
Iteration: 1
Costs
deformation=1123.3411865234375
attach=57160.984375
Total cost=58284.32556152344
1e-10
Evaluated model with costs=58284.32556152344
Evaluated model with costs=58284.32971191406
Evaluated model with costs=58284.32556152344
================================================================================
Time: 87.91832934500053
Iteration: 2
Costs
deformation=1123.3411865234375
attach=57160.984375
Total cost=58284.32556152344
================================================================================
Optimisation process exited with message: Convergence achieved.
Final cost=58284.32556152344
Model evaluation count=36
Time elapsed = 87.91887121699983

Plot total cost evolution.

# Sum the cost terms recorded at each iteration
total_costs = [sum(cost) for cost in zip(*costs.values())]

plt.title("Total cost evolution")
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.grid(True)
plt.plot(range(len(total_costs)), total_costs, color='black', lw=0.7)
plt.show()

Figure: Total cost evolution
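The total-cost computation transposes the cost history. Assuming `costs` maps each cost term to its list of per-iteration values (as the summation implies), `zip(*costs.values())` yields one tuple per iteration, which is then summed. The values below are copied from iterations 0 and 1 of the fitting log purely for illustration.

```python
costs = {
    'deformation': [1121.9609375, 1123.3411865234375],
    'attach': [57162.8984375, 57160.984375],
}

# zip(*...) pairs the i-th value of every cost term, giving one tuple per iteration
per_iteration = list(zip(*costs.values()))
total_costs = [sum(terms) for terms in per_iteration]

print(per_iteration)
print(total_costs)  # matches the "Total cost" lines of iterations 0 and 1
```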

Compute the final deformed source and plot it.

with torch.autograd.no_grad():
    model.deformables[0].output = 'bitmap'
    deformed = model.compute_deformed(shoot_solver, shoot_it)

    deformed_image = deformed[0][0].view_as(source_image)
    deformed_dots = deformed[1][0]

fitted_center = model.init_manifold[2].gd.detach()

print("Fitted rotation center: {center}".format(center=fitted_center.tolist()))

plt.subplot(1, 3, 1)
plt.title("Source image")
plt.imshow(source_image.numpy(), origin='lower', extent=extent.totuple())
plt.plot(source_dots.numpy()[:, 0], source_dots.numpy()[:, 1], '.')
plt.plot(center.numpy()[0, 0], center.numpy()[0, 1], 'X')  # initial (imprecise) center guess
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Fitted image")
plt.imshow(deformed_image.numpy(), origin='lower', extent=extent.totuple())
plt.plot(deformed_dots.numpy()[:, 0], deformed_dots.numpy()[:, 1], '.')
plt.plot(fitted_center.numpy()[0, 0], fitted_center.numpy()[0, 1], 'X')  # fitted center
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title("Target image")
plt.imshow(target_image.numpy(), origin='lower', extent=extent.totuple())
plt.plot(target_dots.numpy()[:, 0], target_dots.numpy()[:, 1], '.')
plt.axis('off')

plt.show()

Figure: Source image, Fitted image, Target image

Out:

Fitted rotation center: [[15.522696495056152, 15.47985553741455]]
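For a bitmap deformable, the deformed image is obtained by resampling the source image. The sketch below illustrates backward warping with `torch.nn.functional.affine_grid` and `grid_sample`: each output pixel looks up the source image at the inverse-rotated position and interpolates bilinearly. It assumes that `backward=True` selects this kind of pull-back resampling, and it uses a rigid rotation about the image center; it illustrates the technique and is not imodal's code.

```python
import math
import torch
import torch.nn.functional as F

def rotate_image_backward(image, theta):
    # Backward warping: output(y) = image(R(-theta) y), in normalized coordinates about the center
    c, s = math.cos(theta), math.sin(theta)
    inverse = torch.tensor([[c, s, 0.], [-s, c, 0.]])  # R(-theta) as a 2x3 affine matrix
    grid = F.affine_grid(inverse.unsqueeze(0), (1, 1, *image.shape), align_corners=True)
    warped = F.grid_sample(image[None, None], grid, mode='bilinear', align_corners=True)
    return warped[0, 0]

# Small plus-shaped bitmap with an odd size, so the center falls on a pixel
image = torch.zeros(9, 9)
image[4, 1:8] = 1.
image[1:8, 4] = 1.

rotated = rotate_image_backward(image, math.pi / 2)
print(torch.allclose(rotated, image, atol=1e-5))  # a plus is invariant under a quarter turn
```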

Total running time of the script: ( 1 minutes 31.530 seconds)
