.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/plot_cross_fit.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_plot_cross_fit.py:


Fitting an image and points
===========================

In this example, we will fit a cross onto the same cross, but rotated. Since
we know the deformation is a rotation, we take advantage of this knowledge and
use a rotation deformation module. We will also add some noise to the initial
guess of the rotation center to show how to fit the geometrical descriptors.

In addition to the images, we add points at the extremity of each branch,
which are matched as well, in order to illustrate the matching of multiple
deformables. These points also help the fitting process by adding more
information to the model.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Import relevant modules.

.. GENERATED FROM PYTHON SOURCE LINES 12-26

.. code-block:: default


    import sys
    sys.path.append("../")

    import math

    import torch
    import matplotlib.pyplot as plt
    import scipy.ndimage

    import imodal

    imodal.Utilities.set_compute_backend('torch')


.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load image data and generate dots.

.. GENERATED FROM PYTHON SOURCE LINES 29-53

.. code-block:: default


    source_image = imodal.Utilities.load_greyscale_image("../data/images/cross_+_30.png", origin='lower')
    target_image = imodal.Utilities.load_greyscale_image("../data/images/cross_+.png", origin='lower')

    # Optional Gaussian smoothing; sig_smooth = 0. leaves the images unchanged
    sig_smooth = 0.
    source_image = torch.tensor(scipy.ndimage.gaussian_filter(source_image, sig_smooth))
    target_image = torch.tensor(scipy.ndimage.gaussian_filter(target_image, sig_smooth))

    extent_length = 31.
    extent = imodal.Utilities.AABB(0., extent_length, 0., extent_length)

    # Extremities of the four branches of a cross centered at the origin
    dots = torch.tensor([[0., 0.5],
                         [0.5, 0.],
                         [0., -0.5],
                         [-0.5, 0.]])

    # Rotate the dots with rot2d(pi/3) (source) and rot2d(pi) (target),
    # then scale them and move them to the center of the image extent
    source_dots = 0.6*extent_length*imodal.Utilities.linear_transform(dots, imodal.Utilities.rot2d(math.pi/3)) + extent_length*torch.tensor([0.5, 0.5])
    target_dots = 0.6*extent_length*imodal.Utilities.linear_transform(dots, imodal.Utilities.rot2d(math.pi)) + extent_length*torch.tensor([0.5, 0.5])

    # Deliberately imprecise initial guess for the rotation center
    center = extent_length*torch.tensor([[0.3, 0.1]])


.. GENERATED FROM PYTHON SOURCE LINES 54-56

Plot everything.

.. GENERATED FROM PYTHON SOURCE LINES 56-70

.. code-block:: default


    plt.subplot(1, 2, 1)
    plt.title("Source image")
    plt.imshow(source_image, origin='lower', extent=extent.totuple())
    plt.plot(source_dots.numpy()[:, 0], source_dots.numpy()[:, 1], '.')
    plt.plot(center.numpy()[:, 0], center.numpy()[:, 1], '.')

    plt.subplot(1, 2, 2)
    plt.title("Target image")
    plt.imshow(target_image, origin='lower', extent=extent.totuple())
    plt.plot(target_dots.numpy()[:, 0], target_dots.numpy()[:, 1], '.')

    plt.show()


.. image:: /_auto_examples/images/sphx_glr_plot_cross_fit_001.png
    :alt: Source image, Target image
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 71-75

We know that the target cross is the result of a rotation of the source cross
about its center, so we use a local rotation deformation module, initialized
with an imprecise center position to simulate data acquisition noise.

.. GENERATED FROM PYTHON SOURCE LINES 75-79

.. code-block:: default


    rotation = imodal.DeformationModules.LocalRotation(2, 2.*extent_length, gd=center)


.. GENERATED FROM PYTHON SOURCE LINES 80-83

Create the model, setting ``True`` in ``fit_gd`` so that it also optimizes the
rotation center.

.. GENERATED FROM PYTHON SOURCE LINES 83-95

.. code-block:: default


    # Image deformables for the source and target crosses
    source_deformable = imodal.Models.DeformableImage(source_image, output='bitmap', extent='match', backward=True)
    target_deformable = imodal.Models.DeformableImage(target_image, output='bitmap', extent='match', backward=True)

    # Point deformables for the branch extremities
    source_dots_deformable = imodal.Models.DeformablePoints(source_dots)
    target_dots_deformable = imodal.Models.DeformablePoints(target_dots)

    attachment = imodal.Attachment.L2NormAttachment(transform=None)

    # One attachment per deformable (images, then dots);
    # fit_gd=[True] also optimizes the geometrical descriptor (center) of the rotation module
    model = imodal.Models.RegistrationModel(
        [source_deformable, source_dots_deformable],
        [rotation],
        [attachment, imodal.Attachment.EuclideanPointwiseDistanceAttachment()],
        fit_gd=[True], lam=1000.)


.. GENERATED FROM PYTHON SOURCE LINES 96-98

Fit the model.

.. GENERATED FROM PYTHON SOURCE LINES 98-109

.. code-block:: default


    # Solver and number of iterations used to shoot (integrate) the deformation
    shoot_solver = 'rk4'
    shoot_it = 10
    # Maximum number of optimizer iterations
    max_it = 100

    # Filled in place with the cost history of the fit
    costs = {}
    fitter = imodal.Models.Fitter(model, optimizer='torch_lbfgs')

    fitter.fit([target_deformable, target_dots_deformable], max_it, costs=costs,
               options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Starting optimization with method torch LBFGS, using solver rk4 with 10 iterations.
    Initial cost={'deformation': 0.0, 'attach': 262755.8125}
    1e-10
    Evaluated model with costs=262755.8125
    Evaluated model with costs=261072.8893268155
    Evaluated model with costs=260262.20402636006
    Evaluated model with costs=244300.85577321053
    Evaluated model with costs=202114.67821502686
    Evaluated model with costs=216731.24096679688
    Evaluated model with costs=195680.00700378418
    Evaluated model with costs=195489.87872314453
    Evaluated model with costs=194501.6148071289
    Evaluated model with costs=192203.5576019287
    Evaluated model with costs=169738.61282348633
    Evaluated model with costs=175044.40286254883
    Evaluated model with costs=163547.61013793945
    Evaluated model with costs=229734.60982704163
    Evaluated model with costs=161746.28846740723
    Evaluated model with costs=141094.7487487793
    Evaluated model with costs=108456.06268310547
    Evaluated model with costs=142422.21313476562
    Evaluated model with costs=62099.16650390625
    Evaluated model with costs=132232.1014404297
    Evaluated model with costs=58461.879150390625
    Evaluated model with costs=58920.74267578125
    Evaluated model with costs=58320.41455078125
    Evaluated model with costs=58507.76745605469
    Evaluated model with costs=58284.859375
    ================================================================================
    Time: 61.855171380000684
    Iteration: 0
    Costs
    deformation=1121.9609375
    attach=57162.8984375
    Total cost=58284.859375
    1e-10
    Evaluated model with costs=58284.859375
    Evaluated model with costs=58284.33435058594
    Evaluated model with costs=58284.33410644531
    Evaluated model with costs=58284.33312988281
    Evaluated model with costs=58284.32946777344
    Evaluated model with costs=58284.32556152344
    Evaluated model with costs=58284.32971191406
    Evaluated model with costs=58284.32556152344
    ================================================================================
    Time: 80.87274068099941
    Iteration: 1
    Costs
    deformation=1123.3411865234375
    attach=57160.984375
    Total cost=58284.32556152344
    1e-10
    Evaluated model with costs=58284.32556152344
    Evaluated model with costs=58284.32971191406
    Evaluated model with costs=58284.32556152344
    ================================================================================
    Time: 87.91832934500053
    Iteration: 2
    Costs
    deformation=1123.3411865234375
    attach=57160.984375
    Total cost=58284.32556152344
    ================================================================================
    Optimisation process exited with message: Convergence achieved.
    Final cost=58284.32556152344
    Model evaluation count=36
    Time elapsed = 87.91887121699983


.. GENERATED FROM PYTHON SOURCE LINES 110-112

Plot total cost evolution.

.. GENERATED FROM PYTHON SOURCE LINES 112-123

.. code-block:: default


    # Sum the cost terms recorded at each iteration
    total_costs = [sum(cost) for cost in zip(*costs.values())]

    plt.title("Total cost evolution")
    plt.xlabel("Iteration")
    plt.ylabel("Cost")
    plt.grid(True)
    plt.plot(range(len(total_costs)), total_costs, color='black', lw=0.7)
    plt.show()


.. image:: /_auto_examples/images/sphx_glr_plot_cross_fit_002.png
    :alt: Total cost evolution
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 124-126

Compute the final deformed source and plot it.

.. GENERATED FROM PYTHON SOURCE LINES 126-160

.. code-block:: default


    with torch.no_grad():
        model.deformables[0].output = 'bitmap'
        deformed = model.compute_deformed(shoot_solver, shoot_it)

        # First deformable: the deformed image; second deformable: the deformed dots
        deformed_image = deformed[0][0].view_as(source_image)
        deformed_dots = deformed[1][0]

    # Index 2: manifold of the rotation module, after the two deformables
    fitted_center = model.init_manifold[2].gd.detach()

    print("Fitted rotation center: {center}".format(center=fitted_center.tolist()))

    plt.subplot(1, 3, 1)
    plt.title("Source image")
    plt.imshow(source_image.numpy(), origin='lower', extent=extent.totuple())
    plt.plot(source_dots.numpy()[:, 0], source_dots.numpy()[:, 1], '.')
    plt.plot(center.numpy()[0, 0], center.numpy()[0, 1], 'X')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title("Fitted image")
    plt.imshow(deformed_image.numpy(), origin='lower', extent=extent.totuple())
    plt.plot(deformed_dots.numpy()[:, 0], deformed_dots.numpy()[:, 1], '.')
    plt.plot(fitted_center.numpy()[0, 0], fitted_center.numpy()[0, 1], 'X')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title("Target image")
    plt.imshow(target_image.numpy(), origin='lower', extent=extent.totuple())
    plt.plot(target_dots.numpy()[:, 0], target_dots.numpy()[:, 1], '.')
    plt.axis('off')

    plt.show()


.. image:: /_auto_examples/images/sphx_glr_plot_cross_fit_003.png
    :alt: Source image, Fitted image, Target image
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Fitted rotation center: [[15.522696495056152, 15.47985553741455]]


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 31.530 seconds)


.. _sphx_glr_download__auto_examples_plot_cross_fit.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_cross_fit.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_cross_fit.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
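
Because the correspondence between ``source_dots`` and ``target_dots`` is known,
the rotation relating the two point sets can also be estimated in closed form
(2D Procrustes analysis), which gives a reference value for the fitted rotation
center printed above. The sketch below is standalone NumPy and is not part of
imodal. It rebuilds the dots under the assumption that
``imodal.Utilities.rot2d(theta)`` is the standard counter-clockwise rotation
matrix and that ``linear_transform`` applies it to each row; the helpers
``rot2d`` and ``fit_rotation_about_center`` are hypothetical names.

.. code-block:: python

    import numpy as np


    def rot2d(theta):
        """Counter-clockwise 2D rotation matrix (assumed to match imodal.Utilities.rot2d)."""
        c, s = np.cos(theta), np.sin(theta)
        return np.array([[c, -s], [s, c]])


    def fit_rotation_about_center(source, target):
        """Least-squares angle and center such that target_i ~ R(theta) (source_i - center) + center.

        Rows of source and target are assumed to be matched point to point.
        """
        mu_s, mu_t = source.mean(axis=0), target.mean(axis=0)
        xs, xt = source - mu_s, target - mu_t
        # sum_i <R(theta) xs_i, xt_i> = cos(theta) * dot + sin(theta) * cross, maximized by atan2
        dot = np.sum(xs * xt)
        cross = np.sum(xs[:, 0] * xt[:, 1] - xs[:, 1] * xt[:, 0])
        theta = np.arctan2(cross, dot)
        R = rot2d(theta)
        # Optimal translation is mu_t - R mu_s; a rotation about c has translation (I - R) c
        # (solvable only when theta is not a multiple of 2 pi)
        translation = mu_t - R @ mu_s
        center = np.linalg.solve(np.eye(2) - R, translation)
        return theta, center


    # Rebuild the dots as in the example (rotation convention assumed, see above)
    extent_length = 31.
    dots = np.array([[0., 0.5], [0.5, 0.], [0., -0.5], [-0.5, 0.]])
    offset = extent_length * np.array([0.5, 0.5])
    source_dots = 0.6 * extent_length * dots @ rot2d(np.pi / 3).T + offset
    target_dots = 0.6 * extent_length * dots @ rot2d(np.pi).T + offset

    theta, center = fit_rotation_about_center(source_dots, target_dots)
    print("angle (degrees):", np.degrees(theta))
    print("center:", center)

Under these assumptions the estimate is a rotation of 120 degrees about
(15.5, 15.5), the center of the image extent; with the opposite rotation
convention only the sign of the angle changes, not the center. The fitted center
printed above is close to this value, although the image attachment also
contributes to the fit, so an exact match is not expected.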