Source code for imodal.Models.fitter

import time

import torch

from imodal.Models import BaseOptimizer, create_optimizer, get_default_optimizer
from imodal.Models import BaseModel
from imodal.Utilities import append_in_dict_of_list


class Fitter:
    """Drive the optimisation of a model with an optimizer and report progress.

    The fitter owns an iteration counter and an optimizer; :meth:`fit`
    delegates the actual optimisation to ``optimizer.optimize`` and handles
    cost bookkeeping and console output around it.
    """

    def __init__(self, model, optimizer=None, post_iteration_callback=None):
        """Build a fitter for ``model``.

        Parameters
        ----------
        model : BaseModel
            Model to optimise.
        optimizer : BaseOptimizer, str or None
            Optimizer instance, or the name of one (resolved through
            ``create_optimizer(name, model)``). ``None`` selects
            ``get_default_optimizer()``, whose result is then resolved the
            same way if it is a string.
        post_iteration_callback : object, optional
            Stored on the instance.
        """
        assert isinstance(model, BaseModel)
        assert (isinstance(optimizer, BaseOptimizer)
                or isinstance(optimizer, str)
                or optimizer is None)

        if optimizer is None:
            optimizer = get_default_optimizer()

        if isinstance(optimizer, str):
            optimizer = create_optimizer(optimizer, model)

        self.__optimizer = optimizer
        # NOTE(review): this callback is stored but never invoked by the
        # code in this class - confirm whether fit() is expected to call it.
        self.__post_iteration_callback = post_iteration_callback
        self.__model = model

        self.reset()

    def reset(self):
        """Reset the iteration counter to zero and reset the optimizer."""
        self.__it = 0
        self.__optimizer.reset()

    def fit(self, target, max_iter, options=None, costs=None, disp=True):
        """Run the optimizer against ``target``.

        Parameters
        ----------
        target : object
            Forwarded unchanged to ``model.evaluate`` and
            ``optimizer.optimize``.
        max_iter : int
            Forwarded to ``optimizer.optimize``.
        options : dict, optional
            Optimisation options. The keys ``'shoot_solver'`` (default
            ``'euler'``), ``'shoot_it'`` (default ``10``) and ``'tol'``
            (default ``None``) are consumed here; every remaining entry is
            forwarded to the optimizer through its ``options`` keyword.
            The dictionary given by the caller is never modified.
        costs : dict, optional
            When given, the initial costs and the costs of each iteration
            are appended to it with ``append_in_dict_of_list``.
        disp : bool
            When true, progress and a final summary are printed.
        """
        assert isinstance(costs, dict) or costs is None

        # Work on a private copy: the keys handled here are removed before
        # the remainder is forwarded, and neither the caller's dictionary
        # nor a shared default value may be altered by that.
        options = {} if options is None else dict(options)
        shoot_solver = options.pop('shoot_solver', 'euler')
        shoot_it = options.pop('shoot_it', 10)
        tol = options.pop('tol', None)

        # Initial cost
        if costs is not None or disp:
            # NOTE: this evaluation is not wrapped in torch.autograd.no_grad().
            cost_0 = self.__model.evaluate(target, shoot_solver, shoot_it)

            if costs is not None:
                append_in_dict_of_list(costs, cost_0)

            if disp:
                print("Starting optimization with method {}, using solver {} with {} iterations.".format(self.__optimizer.method_name, shoot_solver, shoot_it))
                print("Initial cost={cost}".format(cost=cost_0))

        start_time = time.perf_counter()

        def _post_iteration_callback(model, last_costs):
            # Display progression
            if disp:
                print("="*80)
                print("Time: {time}".format(time=time.perf_counter() - start_time))
                print("Iteration: {it}".format(it=self.__it))
                self.__print_costs(last_costs)

            self.__it = self.__it + 1

            # Same test as for the initial cost, so recording does not
            # depend on whether the dictionary already holds entries.
            if costs is not None:
                append_in_dict_of_list(costs, last_costs)

        res = self.__optimizer.optimize(target, max_iter, _post_iteration_callback, costs, shoot_solver, shoot_it, tol, options=options)

        if disp:
            print("="*80)
            print("Optimisation process exited with message: {message}".format(message=res['message']))
            print("Final cost={cost}".format(cost=res['final']))
            print("Model evaluation count={neval}".format(neval=res['neval']))
            # The gradient evaluation count is optional in the result.
            if 'neval_grad' in res:
                print("Model gradient evaluation count={neval_grad}".format(neval_grad=res['neval_grad']))
            print("Time elapsed =", time.perf_counter() - start_time)

    def __print_costs(self, costs):
        """Print each entry of ``costs`` followed by the sum of its values."""
        print("Costs")
        for key, value in costs.items():
            print("{cost}={value}".format(cost=key, value=value))

        print("Total cost={total}".format(total=sum(costs.values())))