Source code for pybamm.solvers.scikits_ode_solver

#
# Solver class using the scikits.odes ODE solvers
#
import casadi
import pybamm

import numpy as np
import importlib
import scipy.sparse as sparse

# `import importlib` alone does not guarantee that the `importlib.util`
# submodule has been loaded, so import it explicitly before it is used below.
import importlib.util

# scikits.odes is an optional dependency. Probe for the parent "scikits"
# package first: looking up a dotted name imports its parent package, so
# asking for "scikits.odes" directly would fail when "scikits" is absent.
scikits_odes_spec = importlib.util.find_spec("scikits")
if scikits_odes_spec is not None:
    scikits_odes_spec = importlib.util.find_spec("scikits.odes")
    if scikits_odes_spec is not None:
        # Load the module from its spec; `scikits_odes` is only defined when
        # the package is actually available (see `have_scikits_odes`).
        scikits_odes = importlib.util.module_from_spec(scikits_odes_spec)
        scikits_odes_spec.loader.exec_module(scikits_odes)


def have_scikits_odes():
    """Return True if a module spec for scikits.odes was found at import time."""
    if scikits_odes_spec is None:
        return False
    return True


class ScikitsOdeSolver(pybamm.BaseSolver):
    """Solve a discretised model, using scikits.odes.

    Parameters
    ----------
    method : str, optional
        The scikits.odes integration method to use (default is "cvode").
    rtol : float, optional
        The relative tolerance for the solver (default is 1e-6).
    atol : float, optional
        The absolute tolerance for the solver (default is 1e-6).
    extrap_tol : float, optional
        The tolerance to assert whether extrapolation occurs or not (default is 0).
    linsolver : str, optional
        Deprecated. Passing any value other than the default raises a
        ``ValueError``; put 'linsolver' in ``extra_options`` instead.
    extra_options : dict, optional
        Any options to pass to the solver.
        Please consult `scikits.odes documentation
        <https://bmcage.github.io/odes/dev/index.html>`_ for details.
        Some common keys:

        - 'linsolver': can be 'dense' (= default), 'lapackdense', 'spgmr',
          'spbcgs', 'sptfqmr'

    Raises
    ------
    ImportError
        If scikits.odes is not installed.
    ValueError
        If the deprecated ``linsolver`` argument is passed.
    """

    def __init__(
        self,
        method="cvode",
        rtol=1e-6,
        atol=1e-6,
        extrap_tol=0,
        linsolver="deprecated",
        extra_options=None,
    ):
        if scikits_odes_spec is None:
            raise ImportError("scikits.odes is not installed")

        super().__init__(method, rtol, atol, extrap_tol=extrap_tol)
        self.extra_options = extra_options or {}
        # "deprecated" is a sentinel default: any other value means the caller
        # is still using the removed keyword argument.
        if linsolver != "deprecated":
            raise ValueError(
                "linsolver has been deprecated. Pass 'linsolver' to extra_options "
                "dictionary instead"
            )
        self.ode_solver = True
        self.name = "Scikits ODE solver ({})".format(method)

        pybamm.citations.register("Malengier2018")
        pybamm.citations.register("Hindmarsh2000")
        pybamm.citations.register("Hindmarsh2005")

    def _integrate(self, model, t_eval, inputs_dict=None):
        """
        Solve a model defined by dydt with initial conditions y0.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate.
        t_eval : numeric type
            The times at which to compute the solution
        inputs_dict : dict, optional
            Any input parameters to pass to the model when solving

        Returns
        -------
        :class:`pybamm.Solution`
            The solution, with ``integration_time`` set to the time spent in
            the scikits.odes ``solve`` call.

        Raises
        ------
        :class:`pybamm.SolverError`
            If the solver returns a flag other than 0 (all of ``t_eval``
            solved) or 2 (root found).
        """
        if model.rhs_eval.form == "casadi":
            # inputs_dict defaults to None; treat that as "no inputs" instead
            # of failing on None.values().
            input_values = {} if inputs_dict is None else inputs_dict
            inputs = casadi.vertcat(*input_values.values())
        else:
            inputs = inputs_dict

        y0 = model.y0
        if isinstance(y0, casadi.DM):
            # Convert the casadi matrix to a flat numpy array
            y0 = y0.full().flatten()

        derivs = model.rhs_eval
        events = model.terminate_events_eval
        jacobian = model.jacobian_eval

        # Callbacks below write their results into the output buffer supplied
        # as the last positional argument.
        def eqsydot(t, y, return_ydot):
            return_ydot[:] = derivs(t, y, inputs)

        def rootfn(t, y, return_root):
            return_root[:] = [event(t, y, inputs) for event in events]

        if jacobian:
            # Evaluate the jacobian once to find out whether it is sparse, and
            # pick dense-conversion / matrix-vector product accordingly.
            jac_y0_t0 = jacobian(t_eval[0], y0, inputs)
            if sparse.issparse(jac_y0_t0):

                def jacfn(t, y, fy, J):
                    J[:][:] = jacobian(t, y, inputs).toarray()

                def jac_times_vecfn(v, Jv, t, y, userdata):
                    Jv[:] = userdata._jac_eval * v
                    return 0

            else:

                def jacfn(t, y, fy, J):
                    J[:][:] = jacobian(t, y, inputs)

                def jac_times_vecfn(v, Jv, t, y, userdata):
                    Jv[:] = np.matmul(userdata._jac_eval, v)
                    return 0

            def jac_times_setupfn(t, y, fy, userdata):
                # Cache the jacobian on the user data object (this solver) so
                # that jac_times_vecfn can reuse it.
                userdata._jac_eval = jacobian(t, y, inputs)
                return 0

        # Solver-controlled keys come last so they override user-supplied ones
        extra_options = {
            **self.extra_options,
            "old_api": False,
            "rtol": self.rtol,
            "atol": self.atol,
        }

        # Read linsolver (defaults to dense)
        linsolver = extra_options.get("linsolver", "dense")

        if jacobian:
            if linsolver in ("dense", "lapackdense"):
                extra_options.update({"jacfn": jacfn})
            elif linsolver in ("spgmr", "spbcgs", "sptfqmr"):
                extra_options.update(
                    {
                        "jac_times_setupfn": jac_times_setupfn,
                        "jac_times_vecfn": jac_times_vecfn,
                        "user_data": self,
                    }
                )

        if events:
            extra_options.update({"rootfn": rootfn, "nr_rootfns": len(events)})

        ode_solver = scikits_odes.ode(self.method, eqsydot, **extra_options)
        timer = pybamm.Timer()
        sol = ode_solver.solve(t_eval, y0)
        integration_time = timer.time()

        # 0 = solved for all t_eval, 2 = found root(s); anything else is an error
        if sol.flag not in (0, 2):
            raise pybamm.SolverError(sol.message)

        termination = "final time" if sol.flag == 0 else "event"
        t_root = sol.roots.t
        # Only transpose when a root state exists; np.transpose(None) would
        # otherwise hand a 0-d object array to the Solution instead of None.
        y_root = None if sol.roots.y is None else np.transpose(sol.roots.y)

        # We need to transpose y to match scipy's ivp interface
        solution = pybamm.Solution(
            sol.values.t,
            np.transpose(sol.values.y),
            model,
            inputs_dict,
            t_root,
            y_root,
            termination,
        )
        solution.integration_time = integration_time
        return solution