#
# Concatenation classes
#
import copy
import numpy as np
import pybamm
from scipy.sparse import vstack, issparse
from collections import defaultdict
class Concatenation(pybamm.Symbol):
    """A node in the expression tree representing a concatenation of symbols.

    **Extends**: :class:`pybamm.Symbol`

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The symbols to concatenate
    name : str, optional
        Name of the node. Defaults to "concatenation".
    check_domain : bool, optional
        If True (default), the domain and auxiliary domains of the node are
        built from those of the children. If False, both are left empty.
    concat_fun : callable, optional
        Function that receives the list of evaluated children and returns the
        concatenated result. Used by :meth:`_concatenation_evaluate`.
    """

    def __init__(self, *children, name=None, check_domain=True, concat_fun=None):
        if name is None:
            name = "concatenation"
        if check_domain:
            domain = self.get_children_domains(children)
            auxiliary_domains = self.get_children_auxiliary_domains(children)
        else:
            domain = []
            auxiliary_domains = {}
        self.concatenation_function = concat_fun
        super().__init__(
            name, children, domain=domain, auxiliary_domains=auxiliary_domains
        )

    def __str__(self):
        """ See :meth:`pybamm.Symbol.__str__()`. """
        # Join the children instead of trimming a trailing ", ": with no
        # children there is nothing to trim and the name itself would be cut.
        return "{}({})".format(
            self.name, ", ".join(str(child) for child in self.children)
        )

    def get_children_domains(self, children):
        """Combine the domains of `children` into a single list.

        Parameters
        ----------
        children : iterable of :class:`pybamm.Symbol`
            The symbols whose domains are combined, in order.

        Returns
        -------
        list
            The children's domains, appended in the order of the children.

        Raises
        ------
        TypeError
            If a child is not a :class:`pybamm.Symbol`.
        :class:`pybamm.DomainError`
            If a child has an empty domain, or if the domain of a child
            overlaps with the domains collected so far.
        """
        domain = []
        for child in children:
            if not isinstance(child, pybamm.Symbol):
                raise TypeError("{} is not a pybamm symbol".format(child))
            child_domain = child.domain
            if child_domain == []:
                raise pybamm.DomainError(
                    "Cannot concatenate child '{}' with empty domain".format(child)
                )
            if set(domain).isdisjoint(child_domain):
                domain += child_domain
            else:
                raise pybamm.DomainError("domain of children must be disjoint")
        return domain

    def _concatenation_evaluate(self, children_eval):
        """Concatenate the evaluated children.

        Returns an empty numpy array when `children_eval` is empty, otherwise
        the result of applying ``self.concatenation_function`` to it.
        """
        if len(children_eval) == 0:
            return np.array([])
        return self.concatenation_function(children_eval)

    def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
        """ See :meth:`pybamm.Symbol.evaluate()`.

        When `known_evals` is provided, the result is stored in it under
        ``self.id`` and the pair ``(result, known_evals)`` is returned;
        otherwise only the result is returned.
        """
        children = self.cached_children
        if known_evals is None:
            children_eval = [child.evaluate(t, y, y_dot, inputs) for child in children]
            return self._concatenation_evaluate(children_eval)

        if self.id not in known_evals:
            children_eval = []
            for child in children:
                # each child returns the (possibly extended) cache alongside
                # its value, so the cache is threaded through the loop
                child_eval, known_evals = child.evaluate(
                    t, y, y_dot, inputs, known_evals
                )
                children_eval.append(child_eval)
            known_evals[self.id] = self._concatenation_evaluate(children_eval)
        return known_evals[self.id], known_evals

    def new_copy(self):
        """ See :meth:`pybamm.Symbol.new_copy()`. """
        new_children = [child.new_copy() for child in self.children]
        return self._concatenation_new_copy(new_children)

    def _concatenation_new_copy(self, children):
        """Create a new node of the same class from `children`.

        See :meth:`pybamm.Symbol.new_copy()`.
        """
        return self.__class__(*children)

    def _concatenation_jac(self, children_jacs):
        """Calculate the jacobian of a concatenation.

        The base class has no implementation; subclasses override this.

        Raises
        ------
        NotImplementedError
            Always, for the base class.
        """
        # Raise the exception rather than returning the exception class, which
        # a caller would otherwise silently treat as a jacobian.
        raise NotImplementedError(
            "jacobian is not implemented for '{}'".format(type(self).__name__)
        )

    def _evaluate_for_shape(self):
        """ See :meth:`pybamm.Symbol.evaluate_for_shape` """
        if len(self.children) == 0:
            return np.array([])
        # Default: use np.concatenate when no concatenation function was given
        concatenation_function = self.concatenation_function or np.concatenate
        return concatenation_function(
            [child.evaluate_for_shape() for child in self.children]
        )

    def is_constant(self):
        """ See :meth:`pybamm.Symbol.is_constant()`. """
        return all(child.is_constant() for child in self.children)
class NumpyConcatenation(Concatenation):
    """A node in the expression tree representing a concatenation of equations,
    when we *don't* care about domains.

    The class :class:`pybamm.DomainConcatenation`, which *is* careful about
    domains and uses broadcasting where appropriate, should be used whenever
    possible instead.

    Upon evaluation, equations are concatenated using numpy concatenation.

    **Extends**: :class:`Concatenation`

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The equations to concatenate
    """

    def __init__(self, *children):
        # A child that evaluates to a number is multiplied by a length-one
        # Vector so that it evaluates to a vector and can be concatenated.
        vector_children = [
            child * pybamm.Vector([1]) if child.evaluates_to_number() else child
            for child in children
        ]
        super().__init__(
            *vector_children,
            name="numpy_concatenation",
            check_domain=False,
            concat_fun=np.concatenate,
        )

    def _concatenation_jac(self, children_jacs):
        """Stack the jacobians of the children.

        Returns ``pybamm.Scalar(0)`` when the node has no children, otherwise
        a :class:`SparseStack` of `children_jacs`.
        """
        if len(self.cached_children) == 0:
            return pybamm.Scalar(0)
        return SparseStack(*children_jacs)
class DomainConcatenation(Concatenation):
    """A node in the expression tree representing a concatenation of symbols,
    being careful about domains.

    It is assumed that each child has a domain, and the final concatenated
    vector will respect the sizes and ordering of domains established in mesh
    keys.

    **Extends**: :class:`pybamm.Concatenation`

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The symbols to concatenate
    full_mesh : :class:`pybamm.BaseMesh`
        The underlying mesh for discretisation, used to order the domains and
        to obtain the number of mesh points in each domain.
    copy_this : :class:`pybamm.DomainConcatenation` (optional)
        If provided, the stored mesh, slices, size and number of secondary
        points are copied from `copy_this` instead of being computed.
        `full_mesh` is then only used to order the domains.
    """

    def __init__(self, children, full_mesh, copy_this=None):
        # The base class checks the children and combines their domains
        super().__init__(*list(children), name="domain_concatenation")

        # Reorder the combined domain to follow the order of the mesh keys
        position_in_mesh = {
            dom: full_mesh.domain_order.index(dom) for dom in self.domain
        }
        self.domain = sorted(position_in_mesh, key=position_in_mesh.get)

        if copy_this is not None:
            # Reuse (shallow copies of) what the other concatenation computed
            self._full_mesh = copy.copy(copy_this._full_mesh)
            self._slices = copy.copy(copy_this._slices)
            self._size = copy.copy(copy_this._size)
            self._children_slices = copy.copy(copy_this._children_slices)
            self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts
        else:
            # The mesh must be stored first: the helpers below read it
            self._full_mesh = full_mesh
            self.secondary_dimensions_npts = self._get_auxiliary_domain_repeats(
                self.domains
            )
            # domain => slices of the final vector
            self._slices = self.create_slices(self)
            # size of the final vector: where the last slice of the last
            # domain stops
            self._size = self._slices[self.domain[-1]][-1].stop
            # for each child, domain => slices of that child's vector
            self._children_slices = [
                self.create_slices(child) for child in self.cached_children
            ]

    def _get_auxiliary_domain_repeats(self, auxiliary_domains):
        """Helper method to read the 'auxiliary_domain' meshes.

        Returns the product of the number of points of the combined
        "secondary" and "tertiary" submeshes; a level that is not a key of
        `auxiliary_domains` contributes a factor of 1.
        """
        repeats = 1
        for level in ("secondary", "tertiary"):
            if level in auxiliary_domains:
                level_mesh = self.full_mesh.combine_submeshes(
                    *auxiliary_domains[level]
                )
                repeats *= level_mesh.npts
        return repeats

    @property
    def full_mesh(self):
        """The mesh stored on this concatenation."""
        return self._full_mesh

    def create_slices(self, node):
        """Build a mapping from each domain of `node` to a list of slices.

        The slices are laid out consecutively starting at 0. The domains of
        `node` are walked once per secondary point, so each domain receives
        one slice per secondary point.

        Raises
        ------
        ValueError
            If the number of secondary points does not match
            ``self.secondary_dimensions_npts``.
        """
        domain_slices = defaultdict(list)
        # NOTE(review): this is computed from `self.domains`, not from
        # `node.domains`, so the check below compares the concatenation with
        # itself - confirm whether the child's domains were intended.
        repeats = self._get_auxiliary_domain_repeats(self.domains)
        if repeats != self.secondary_dimensions_npts:
            raise ValueError(
                "Concatenation and children must have the same number of\n"
                "points in secondary dimensions"
            )
        offset = 0
        for _ in range(repeats):
            for dom in node.domain:
                stop = offset + self.full_mesh[dom].npts
                domain_slices[dom].append(slice(offset, stop))
                offset = stop
        return domain_slices

    def _concatenation_evaluate(self, children_eval):
        """Write the evaluated children into a single column vector.

        See :meth:`Concatenation._concatenation_evaluate()`.
        """
        # preallocate the result; every entry is expected to be overwritten
        result = np.empty((self._size, 1))
        # copy each child's sub-vectors to their place in the final vector
        for child_values, child_slices in zip(children_eval, self._children_slices):
            for dom, source_slices in child_slices.items():
                for repeat, source in enumerate(source_slices):
                    result[self._slices[dom][repeat]] = child_values[source]
        return result

    def _concatenation_jac(self, children_jacs):
        """Stack the jacobians of the children.

        Note that this assumes that the children are in the right order and
        only have one domain each.

        Raises
        ------
        NotImplementedError
            If a child has more than one domain.
        """
        stacked = []
        for repeat in range(self.secondary_dimensions_npts):
            for child_jac, child_slices in zip(children_jacs, self._children_slices):
                if len(child_slices) > 1:
                    raise NotImplementedError(
                        "jacobian only implemented for when each child has\n"
                        "a single domain"
                    )
                # the only entry: the slices of the child's single domain
                only_domain_slices = next(iter(child_slices.values()))
                stacked.append(pybamm.Index(child_jac, only_domain_slices[repeat]))
        return SparseStack(*stacked)

    def _concatenation_new_copy(self, children):
        """Create a copy from `children`, reusing this node's mesh data.

        See :meth:`pybamm.Symbol.new_copy()`.
        """
        return simplified_domain_concatenation(
            children, self.full_mesh, copy_this=self
        )
class SparseStack(Concatenation):
    """A node in the expression tree representing a concatenation of sparse
    matrices. As with NumpyConcatenation, we *don't* care about domains.

    The class :class:`pybamm.DomainConcatenation`, which *is* careful about
    domains and uses broadcasting where appropriate, should be used whenever
    possible instead.

    **Extends**: :class:`Concatenation`

    Parameters
    ----------
    children : iterable of :class:`Concatenation`
        The equations to concatenate
    """

    def __init__(self, *children):
        stack_children = list(children)
        # Use the scipy sparse vstack as soon as one child evaluates to a
        # sparse matrix; otherwise use the numpy vstack.
        has_sparse_child = any(
            issparse(child.evaluate_for_shape()) for child in stack_children
        )
        stack_function = vstack if has_sparse_child else np.vstack
        super().__init__(
            *stack_children,
            name="sparse_stack",
            check_domain=False,
            concat_fun=stack_function,
        )
def simplified_numpy_concatenation(*children):
    """Perform simplifications on a numpy concatenation.

    A child that is itself a :class:`NumpyConcatenation` is replaced by its
    orphans, so that a concatenation of concatenations becomes a single
    concatenation. The result is passed through
    ``pybamm.simplify_if_constant``.
    """
    flattened = []
    for child in children:
        if not isinstance(child, NumpyConcatenation):
            flattened.append(child)
            continue
        # pull the children out of the nested numpy concatenation
        flattened.extend(child.orphans)
    concatenation = NumpyConcatenation(*flattened)
    return pybamm.simplify_if_constant(concatenation)
def numpy_concatenation(*children):
    """Helper function to create numpy concatenations.

    Forwards `children` to :func:`simplified_numpy_concatenation`.
    """
    # TODO: add option to turn off simplifications
    concatenation = simplified_numpy_concatenation(*children)
    return concatenation
def simplified_domain_concatenation(children, mesh, copy_this=None):
    """Perform simplifications on a domain concatenation.

    When every child is a :class:`pybamm.StateVector` and the children's
    evaluation arrays sum to exactly 1 between the start of the first child's
    first slice and the stop of the last child's last slice, a single
    StateVector over that range is returned. Otherwise the
    :class:`DomainConcatenation` is returned, passed through
    ``pybamm.simplify_if_constant``.
    """
    # Built first: its domain and auxiliary domains are needed for the merged
    # StateVector
    concat = DomainConcatenation(children, mesh, copy_this=copy_this)

    if not all(isinstance(child, pybamm.StateVector) for child in children):
        return pybamm.simplify_if_constant(concat)

    # Pad every evaluation array with zeros up to the length of the last
    # child's array, so that the arrays can be summed elementwise.
    padded_length = len(children[-1]._evaluation_array)
    padded_arrays = {
        child: np.concatenate(
            [
                child.evaluation_array,
                np.zeros(padded_length - len(child.evaluation_array)),
            ]
        )
        for child in children
    }
    first_start = children[0].y_slices[0].start
    last_stop = children[-1].y_slices[-1].stop
    # number of children covering each entry of the range
    coverage = sum(padded_arrays.values())[first_start:last_stop]
    if all(coverage == 1):
        return pybamm.StateVector(
            slice(first_start, last_stop),
            domain=concat.domain,
            auxiliary_domains=concat.auxiliary_domains,
        )
    return pybamm.simplify_if_constant(concat)
def domain_concatenation(children, mesh):
    """Helper function to create domain concatenations.

    Forwards `children` and `mesh` to
    :func:`simplified_domain_concatenation`.
    """
    # TODO: add option to turn off simplifications
    concatenation = simplified_domain_concatenation(children, mesh)
    return concatenation