# Source code for pybamm.expression_tree.functions

#
# Function classes and methods
#
import autograd
import numbers
import numpy as np
from scipy import special
import pybamm


class Function(pybamm.Symbol):
    """A node in the expression tree representing an arbitrary function

    Parameters
    ----------
    function : method
        A function can have 0 or many inputs. If no inputs are given,
        self.evaluate() simply returns func(). Otherwise, self.evaluate(t, y, u)
        returns func(child0.evaluate(t, y, u), child1.evaluate(t, y, u), etc).
    children : :class:`pybamm.Symbol`
        The children nodes to apply the function to. Plain numbers are wrapped
        in :class:`pybamm.Scalar`.
    name : str, optional
        Name of the node. If None (default), the name is built as
        "function (<function.__name__>)", falling back to the function's class
        when it has no ``__name__``.
    derivative : str, optional
        Which derivative to use when differentiating ("autograd" or "derivative").
        Default is "autograd".
    differentiated_function : method, optional
        The function which was differentiated to obtain this one. Default is None.

    **Extends:** :class:`pybamm.Symbol`
    """

    def __init__(
        self,
        function,
        *children,
        name=None,
        derivative="autograd",
        differentiated_function=None
    ):
        # Turn numbers into scalars so that every child is a pybamm.Symbol
        children = [
            pybamm.Scalar(child) if isinstance(child, numbers.Number) else child
            for child in children
        ]

        if name is None:
            try:
                name = "function ({})".format(function.__name__)
            except AttributeError:
                # callables without a __name__ are labelled by their class
                name = "function ({})".format(function.__class__)

        domain = self.get_children_domains(children)
        auxiliary_domains = self.get_children_auxiliary_domains(children)
        self.function = function
        self.derivative = derivative
        self.differentiated_function = differentiated_function

        super().__init__(
            name, children=children, domain=domain, auxiliary_domains=auxiliary_domains
        )

    def __str__(self):
        """ See :meth:`pybamm.Symbol.__str__()`. """
        # Default names look like "function (<name>)": display only <name>.
        # Any other (custom) name is displayed unchanged instead of being
        # blindly sliced.
        prefix = "function ("
        label = self.name
        if label.startswith(prefix) and label.endswith(")"):
            label = label[len(prefix) : -1]
        # join() handles zero children correctly, giving "name()"
        arguments = ", ".join(str(child) for child in self.children)
        return "{}({})".format(label, arguments)

    def get_children_domains(self, children_list):
        """Obtains the unique domain of the children. If the children have
        different domains then raise an error"""
        domains = [child.domain for child in children_list if child.domain != []]

        # check that there is one common domain amongst children
        distinct_domains = set(tuple(dom) for dom in domains)
        if len(distinct_domains) > 1:
            raise pybamm.DomainError(
                "Functions can only be applied to variables on the same domain"
            )
        elif len(distinct_domains) == 0:
            domain = []
        else:
            domain = domains[0]

        return domain

    def diff(self, variable):
        """ See :meth:`pybamm.Symbol.diff()`. """
        if variable.id == self.id:
            return pybamm.Scalar(1)

        children = self.orphans
        partial_derivatives = []
        for i, child in enumerate(self.children):
            # if variable appears in the function, differentiate
            # function, and apply chain rule
            if any(symbol.id == variable.id for symbol in child.pre_order()):
                partial_derivatives.append(
                    self._function_diff(children, i) * child.diff(variable)
                )

        # No child depends on `variable`: the derivative is identically zero.
        # Checked explicitly rather than relying on `sum([]) == 0`.
        if not partial_derivatives:
            return pybamm.Scalar(0)
        return sum(partial_derivatives)

    def _function_diff(self, children, idx):
        """
        Derivative with respect to child number 'idx'.
        See :meth:`pybamm.Symbol._diff()`.

        Raises
        ------
        ValueError
            If ``self.derivative`` is "derivative" and there is more than one
            child, or if ``self.derivative`` is not a recognised mode.
        """
        # Store differentiated function, needed in case we want to convert to CasADi
        if self.derivative == "autograd":
            return Function(
                autograd.elementwise_grad(self.function, idx),
                *children,
                differentiated_function=self.function
            )
        elif self.derivative == "derivative":
            if len(children) > 1:
                raise ValueError(
                    "differentiation using '.derivative()' not implemented for "
                    "functions with more than one child"
                )
            # keep using "derivative" as derivative
            return pybamm.Function(
                self.function.derivative(),
                *children,
                derivative="derivative",
                differentiated_function=self.function
            )
        # Previously an unknown mode silently returned None, which only failed
        # later with an obscure TypeError when multiplied by a child derivative
        raise ValueError(
            "derivative must be 'autograd' or 'derivative', not '{}'".format(
                self.derivative
            )
        )

    def _function_jac(self, children_jacs):
        """ Calculate the jacobian of a function. """

        if all(child.evaluates_to_constant_number() for child in self.children):
            jacobian = pybamm.Scalar(0)
        else:
            # if at least one child contains variable dependence, then
            # calculate the required partial jacobians and add them
            jacobian = None
            children = self.orphans
            for i, child in enumerate(children):
                if not child.evaluates_to_constant_number():
                    jac_fun = self._function_diff(children, i) * children_jacs[i]
                    jac_fun.clear_domains()
                    if jacobian is None:
                        jacobian = jac_fun
                    else:
                        jacobian += jac_fun

        return jacobian

    def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
        """ See :meth:`pybamm.Symbol.evaluate()`. """
        if known_evals is not None:
            # memoised path: evaluate each node at most once, keyed by id
            if self.id not in known_evals:
                evaluated_children = [None] * len(self.children)
                for i, child in enumerate(self.children):
                    evaluated_children[i], known_evals = child.evaluate(
                        t, y, y_dot, inputs, known_evals=known_evals
                    )
                known_evals[self.id] = self._function_evaluate(evaluated_children)
            return known_evals[self.id], known_evals
        else:
            evaluated_children = [
                child.evaluate(t, y, y_dot, inputs) for child in self.children
            ]
            return self._function_evaluate(evaluated_children)

    def _evaluates_on_edges(self, dimension):
        """ See :meth:`pybamm.Symbol._evaluates_on_edges()`. """
        return any(child.evaluates_on_edges(dimension) for child in self.children)

    def is_constant(self):
        """ See :meth:`pybamm.Symbol.is_constant()`. """
        return all(child.is_constant() for child in self.children)

    def _evaluate_for_shape(self):
        """
        Default behaviour: has same shape as all child
        See :meth:`pybamm.Symbol.evaluate_for_shape()`
        """
        evaluated_children = [child.evaluate_for_shape() for child in self.children]
        return self._function_evaluate(evaluated_children)

    def _function_evaluate(self, evaluated_children):
        # apply the wrapped function to the already-evaluated children
        return self.function(*evaluated_children)

    def new_copy(self):
        """ See :meth:`pybamm.Symbol.new_copy()`. """
        children_copy = [child.new_copy() for child in self.children]
        return self._function_new_copy(children_copy)

    def _function_new_copy(self, children):
        """Returns a new copy of the function.

        Inputs
        ------
        children : : list
            A list of the children of the function

        Returns
        -------
            : :pybamm.Function
            A new copy of the function
        """
        return pybamm.simplify_if_constant(
            pybamm.Function(
                self.function,
                *children,
                name=self.name,
                derivative=self.derivative,
                differentiated_function=self.differentiated_function
            ),
        )
class SpecificFunction(Function):
    """
    Base class for single-argument functions whose derivative is known in
    closed form. Subclasses are expected to override ``_function_diff``.

    Parameters
    ----------
    function : method
        Function to be applied to child
    child : :class:`pybamm.Symbol`
        The child to apply the function to
    """

    def __init__(self, function, child):
        super().__init__(function, child)

    def _function_new_copy(self, children):
        """ See :meth:`pybamm.Function._function_new_copy()` """
        # rebuild through the concrete subclass so its specific derivative is kept
        copied = type(self)(*children)
        return pybamm.simplify_if_constant(copied)
class Arcsinh(SpecificFunction):
    """Arcsinh function."""

    def __init__(self, child):
        super().__init__(np.arcsinh, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx arcsinh(x) = 1 / sqrt(x^2 + 1)
        return 1 / Sqrt(children[0] ** 2 + 1)


def arcsinh(child):
    """Returns arcsinh function of child."""
    return pybamm.simplify_if_constant(Arcsinh(child))
class Cos(SpecificFunction):
    """Cosine function."""

    def __init__(self, child):
        super().__init__(np.cos, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx cos(x) = -sin(x)
        return -Sin(children[0])
def cos(child):
    """Returns the cosine of child, passed through pybamm.simplify_if_constant."""
    node = Cos(child)
    return pybamm.simplify_if_constant(node)
class Cosh(SpecificFunction):
    """Hyperbolic cosine function."""

    def __init__(self, child):
        super().__init__(np.cosh, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx cosh(x) = sinh(x)
        return Sinh(children[0])
def cosh(child):
    """Returns the hyperbolic cosine of child, passed through
    pybamm.simplify_if_constant."""
    node = Cosh(child)
    return pybamm.simplify_if_constant(node)
class Exponential(SpecificFunction):
    """Exponential function, exp(x)."""

    def __init__(self, child):
        super().__init__(np.exp, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # exp is its own derivative
        argument = children[0]
        return Exponential(argument)
def exp(child):
    """Returns the exponential of child, passed through
    pybamm.simplify_if_constant."""
    node = Exponential(child)
    return pybamm.simplify_if_constant(node)
class Log(SpecificFunction):
    """Natural logarithm function."""

    def __init__(self, child):
        super().__init__(np.log, child)

    def _function_evaluate(self, evaluated_children):
        # Suppress numpy's "invalid value" RuntimeWarning; NaNs are returned
        # silently instead
        with np.errstate(invalid="ignore"):
            result = np.log(*evaluated_children)
        return result

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx log(x) = 1 / x
        argument = children[0]
        return 1 / argument
def log(child, base="e"):
    """Returns logarithmic function of child (any base, default 'e').

    Parameters
    ----------
    child : :class:`pybamm.Symbol` or number
        The argument of the logarithm
    base : str or number, optional
        The base of the logarithm. The string "e" (default) gives the natural
        logarithm; any other value is used via the change-of-base formula
        log_b(x) = ln(x) / ln(b).
    """
    natural_log = Log(child)
    if base == "e":
        return pybamm.simplify_if_constant(natural_log)
    # Simplify the change-of-base quotient as well, so that constant arguments
    # are folded consistently whatever the base
    return pybamm.simplify_if_constant(natural_log / np.log(base))
def log10(child):
    """Returns the base-10 logarithm of child."""
    ten = 10
    return log(child, base=ten)
def max(child):
    """
    Returns max function of child (via ``np.max``). Not to be confused with
    :meth:`pybamm.maximum`, which returns the larger of two objects.
    """
    max_node = Function(np.max, child)
    return pybamm.simplify_if_constant(max_node)
def min(child):
    """
    Returns min function of child (via ``np.min``). Not to be confused with
    :meth:`pybamm.minimum`, which returns the smaller of two objects.
    """
    min_node = Function(np.min, child)
    return pybamm.simplify_if_constant(min_node)
def sech(child):
    """Returns the hyperbolic secant of child, built as 1 / cosh(child)."""
    reciprocal_cosh = 1 / Cosh(child)
    return pybamm.simplify_if_constant(reciprocal_cosh)
class Sin(SpecificFunction):
    """Sine function."""

    def __init__(self, child):
        super().__init__(np.sin, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx sin(x) = cos(x)
        argument = children[0]
        return Cos(argument)
def sin(child):
    """Returns the sine of child, passed through pybamm.simplify_if_constant."""
    node = Sin(child)
    return pybamm.simplify_if_constant(node)
class Sinh(SpecificFunction):
    """Hyperbolic sine function."""

    def __init__(self, child):
        super().__init__(np.sinh, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx sinh(x) = cosh(x)
        argument = children[0]
        return Cosh(argument)
def sinh(child):
    """Returns the hyperbolic sine of child, passed through
    pybamm.simplify_if_constant."""
    node = Sinh(child)
    return pybamm.simplify_if_constant(node)
class Sqrt(SpecificFunction):
    """Square root function."""

    def __init__(self, child):
        super().__init__(np.sqrt, child)

    def _function_evaluate(self, evaluated_children):
        # don't raise RuntimeWarning for NaNs
        with np.errstate(invalid="ignore"):
            return np.sqrt(*evaluated_children)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx sqrt(x) = 1 / (2 sqrt(x))
        return 1 / (2 * Sqrt(children[0]))


def sqrt(child):
    """Returns square root function of child."""
    return pybamm.simplify_if_constant(Sqrt(child))


class Tanh(SpecificFunction):
    """Hyperbolic tan function."""

    def __init__(self, child):
        super().__init__(np.tanh, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx tanh(x) = sech(x)^2
        return sech(children[0]) ** 2


def tanh(child):
    """Returns hyperbolic tan function of child."""
    return pybamm.simplify_if_constant(Tanh(child))


class Arctan(SpecificFunction):
    """Arctan function."""

    def __init__(self, child):
        super().__init__(np.arctan, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx arctan(x) = 1 / (x^2 + 1)
        return 1 / (children[0] ** 2 + 1)


def arctan(child):
    """Returns arctan (inverse tangent) function of child."""
    return pybamm.simplify_if_constant(Arctan(child))


class Erf(SpecificFunction):
    """Error function."""

    def __init__(self, child):
        super().__init__(special.erf, child)

    def _function_diff(self, children, idx):
        """
        See :meth:`pybamm.Function._function_diff()`.
        """
        # d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2); note -x ** 2 == -(x ** 2)
        return 2 / np.sqrt(np.pi) * Exponential(-children[0] ** 2)


def erf(child):
    """Returns error function of child."""
    return pybamm.simplify_if_constant(Erf(child))


def erfc(child):
    """Returns complementary error function of child, computed as 1 - erf(child)."""
    return pybamm.simplify_if_constant(1 - Erf(child))