Commit 3c267441 authored by Dominic Kempf's avatar Dominic Kempf

Allow coefficient cargo to be a dict

parent 0a56d806
Pipeline #23884 failed with stage
in 1 minute and 11 seconds
......@@ -239,8 +239,10 @@ def driver_block_set_coefficient_function(lop_name, coefficient):
def _cf_ident(coefficient):
ident = str(coefficient.count() - 2)
return ident
name = coefficient.codegen_cargo("name")
if name is None:
name = str(coefficient.count() - 2)
return name
@class_member(classtag="driver_block")
......@@ -251,7 +253,7 @@ def typedef_coefficient_vector(name, coefficient):
# form depends on non ansatz coefficients we need to pass this information
# along. We pass this along through the codegen_cargo() method on the
# coeffiecient.
is_dirichlet = coefficient.codegen_cargo()
is_dirichlet = coefficient.codegen_cargo("is_dirichlet")
# Note: root=False makes sure that no driver_block_get_... method is called
gfs_type = type_gfs(element, is_dirichlet, root=False)
......@@ -285,7 +287,7 @@ def driver_block_typedef_gfs(name, coefficient):
# form depends on non ansatz coefficients we need to pass this information
# along. We pass this along through the codegen_cargo() method on the
# coeffiecient.
is_dirichlet = coefficient.codegen_cargo()
is_dirichlet = coefficient.codegen_cargo("is_dirichlet")
# Note: root=False makes sure that no driver_block_get_... method is called
gfs_type = type_gfs(element, is_dirichlet, root=False)
......
......@@ -19,7 +19,7 @@ class TrialFunction(ufl.Coefficient):
class Coefficient(ufl.Coefficient):
""" A coefficient that honors the reserved index 0. """
def __init__(self, element, count=None, cargo=None):
def __init__(self, element, count=None, cargo={}):
"""The cargo member can be used to transport data through the Coefficient
without UFL knowing about it. Use case: Transport is_dirichlet data
through the coefficient for operator splitting
......@@ -35,9 +35,10 @@ class Coefficient(ufl.Coefficient):
count = 3
ufl.Coefficient.__init__(self, element, count)
def codegen_cargo(self):
def codegen_cargo(self, key):
# Transport data through the Coefficient
return self.cargo
return self.cargo.get(key, None)
def split(obj):
return ufl.split_functions.split(obj)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment