From 427126ed4e96644511eab7aa268d29a6a9bdd94d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carsten=20Gr=C3=A4ser?= <graeser@math.fau.de>
Date: Sat, 15 Feb 2025 16:32:04 +0100
Subject: [PATCH] [example] Handle hanging-node and Dirichlet constraints
 uniformly

---
 examples/poisson-adaptive.cc | 191 +++++++++++++++--------------------
 examples/utilities.hh        |  34 +++++++
 2 files changed, 116 insertions(+), 109 deletions(-)

diff --git a/examples/poisson-adaptive.cc b/examples/poisson-adaptive.cc
index 2df812d0..fd034261 100644
--- a/examples/poisson-adaptive.cc
+++ b/examples/poisson-adaptive.cc
@@ -42,14 +42,29 @@
 #include <dune/fufem/functiontools/boundarydofs.hh>
 #include <dune/fufem/forms/forms.hh>
 #include <dune/fufem/constraints/continuityconstraints.hh>
+#include <dune/fufem/constraints/boundaryconstraints.hh>
 
 #include "utilities.hh"
 
 using namespace Dune;
 
+template<std::size_t codim, class BV, class V, class MI, class C, class Basis>
+auto computeCodimConstraints(Dune::Fufem::AffineConstraints<BV, V, MI, C>& constraints, const Basis& basis)
+{
+  auto& isConstrained = constraints.isConstrained();
+  auto localView = basis.localView();
+  for(const auto& element : elements(basis.gridView()))
+  {
+    localView.bind(element);
+    const auto& localCoefficients = localView.tree().finiteElement().localCoefficients();
+    for(auto i : Dune::range(localCoefficients.size()))
+      if (localCoefficients.localKey(i).codim() == codim)
+        isConstrained[localView.index(i)] = true;
+  }
+}
 
-template<class BilinearForm, class Residual, class ExtensionBasis, class Constraints, class IsConstraints, class Vector, class... Args>
-auto hierarchicalErrorEstimator(BilinearForm b, Residual residual, const ExtensionBasis& extension, const Constraints& constraints, const IsConstraints& isConstrained, Vector& defect, Args... args)
+template<class BilinearForm, class Residual, class ExtensionBasis, class Constraints, class Vector, class... Args>
+auto hierarchicalErrorEstimator(BilinearForm b, Residual residual, const ExtensionBasis& extension, const Constraints& constraints, Vector& defect, Args... args)
 {
   using namespace Dune::Fufem::Forms;
 
@@ -60,9 +75,9 @@ auto hierarchicalErrorEstimator(BilinearForm b, Residual residual, const Extensi
   auto B = Dune::BCRSMatrix<double>();
 
   auto globalAssembler = GlobalAssembler(extension, extension, 0, std::forward<Args>(args)...);
-  globalAssembler.assembleFunctional(r, residual(v), constraints);
-  globalAssembler.assembleOperator(B, b(d,v), constraints);
+  globalAssembler.assembleSystem(B, r, b(d,v), residual(v), constraints);
 
+  const auto& isConstrained = constraints.isConstrained();
   auto eta = std::vector<double>(extension.size(), 0.0);
   double etaSum = 0;
   for(auto k : Dune::range(extension.size()))
@@ -192,9 +207,6 @@ int main (int argc, char *argv[]) try
 
   log("MPI initialized");
 
-  constexpr auto order = 1;
-  constexpr auto extensionOrder = 2;
-
   ///////////////////////////////////
   //   Generate the grid
   ///////////////////////////////////
@@ -213,31 +225,46 @@ int main (int argc, char *argv[]) try
 
   log(Dune::formatString("Grid uniformly refined %d times", refinements));
 
+  // *********************************************
+  // Boundary and right hand side data
+  // *********************************************
 
-  // Problem data
+  // We use an example combining two features that have to be
+  // detected by the error estimator:
+  // * An L-shape domain with a reentrent corner.
+  // * A boundary condition not visible on the coarsest mesh.
 
   auto dirichletIndicatorFunction = [](auto x) { return true; };
+
   auto rightHandSide = [] (const auto& x) { return 10;};
-  auto dirichletValues = [](const auto& x){ return 0; };
-//  auto rightHandSide = [] (const auto& x) { return 0;};
-//  auto dirichletValues = [](const auto& x){
-//    return (std::fabs(x[0])<1e-5)*(x[1]<.25)*x[1]*(.25-x[1])*20.;
-//  };
 
-  std::size_t refStep = 0;
+  auto dirichletValues = [](const auto& x){
+    auto z = x[0]+x[1];
+    return std::max(10.0*z*(.25-z), 0.0);
+  };
+
+  // *********************************************
+  // Standard adaptive loop
+  // *********************************************
+
+  // Discretize -> solve -> estimate -> mark -> refine
+
+  std::size_t refinementStep = 0;
   while(true)
   {
 
     auto gridView = grid.leafGridView();
     using GridView = decltype(gridView);
 
-    /////////////////////////////////////////////////////////
-    //   Choose a finite element space
-    /////////////////////////////////////////////////////////
+    // *********************************************
+    // Choose a finite element space
+    // *********************************************
+
+    constexpr auto order = 1;
+    constexpr auto extensionOrder = 2;
 
     auto basis = Dune::Functions::LagrangeBasis<GridView,order>(gridView);
 
-    
     // Define suitable matrix and vector types
     using Matrix = Dune::BCRSMatrix<double>;
     using Vector = Dune::BlockVector<double>;
@@ -245,31 +272,37 @@ int main (int argc, char *argv[]) try
 
     log(Dune::formatString("Basis created with dimension %d", basis.dimension()));
 
-    auto constraints = Dune::Fufem::makeContinuityConstraints<BitVector>(basis);
+    // *********************************************
+    // Compute hanging node and Dirichlet constraints
+    // *********************************************
+
+    auto dirichletPatch = BoundaryPatch(basis.gridView());
+    dirichletPatch.insertFacesByProperty([&](auto&& intersection) { return dirichletIndicatorFunction(intersection.geometry().center()); });
+    log("Boundary patch setup");
 
-    log("Constraints computed");
+    auto constraints = Dune::Fufem::makeAffineConstraints<BitVector, Vector>(basis);
 
-    constraints.check();
+    computeContinuityConstraints(constraints, basis);
+    log("Hanging node constraints computed");
 
+    computeBoundaryConstraints(constraints, basis, dirichletValues, dirichletPatch);
+    log("Boundary constraints computed");
+
+    constraints.check();
     log("Constraints checked");
 
-    /////////////////////////////////////////////////////////
-    //   Stiffness matrix and right hand side vector
-    /////////////////////////////////////////////////////////
+    // *********************************************
+    // Stiffness matrix and right hand side vector
+    // *********************************************
 
     Vector rhs;
     Vector sol;
-    BitVector isConstrained;
     Matrix stiffnessMatrix;
 
     // *********************************************
     // Assemble the system
     // *********************************************
 
-
-    auto dirichletPatch = BoundaryPatch(basis.gridView());
-    dirichletPatch.insertFacesByProperty([&](auto&& intersection) { return dirichletIndicatorFunction(intersection.geometry().center()); });
-
     // Disable parallel assembly for UGGrid before 2.10
 #if DUNE_VERSION_GT(DUNE_FUNCTIONS, 2, 9)
     auto gridViewPartition = Dune::Fufem::coloredGridViewPartition(basis.gridView());
@@ -282,7 +315,6 @@ int main (int argc, char *argv[]) try
 
     {
       using namespace ::Dune::Fufem::Forms;
-      namespace F = ::Dune::Fufem::Forms;
 
       auto v = testFunction(basis);
       auto u = trialFunction(basis);
@@ -291,34 +323,8 @@ int main (int argc, char *argv[]) try
       auto a = integrate(dot(grad(u), grad(v)));
       auto b = integrate(f*v);
 
-      auto operatorAssembler = Dune::Fufem::DuneFunctionsOperatorAssembler{basis, basis};
-      log("Assembler set up");
-
-      //    globalAssembler.assembleOperator(stiffnessMatrix, a);
-
-      auto&& matrixBackend = Dune::Fufem::istlMatrixBackend(stiffnessMatrix);
-      auto patternBackend = matrixBackend.patternBuilder();
-      operatorAssembler.assembleBulkPattern(patternBackend);
-      log("Matrix pattern assembled");
-
-      constraints.constrainMatrixPattern(patternBackend, basis);
-      log("Matrix pattern constrained");
-
-      patternBackend.setupMatrix();
-      log("Matrix set up");
-
-      matrixBackend.assign(0);
-      operatorAssembler.assembleBulkEntries(matrixBackend, a);
-      log("Matrix assembled");
-
-      constraints.constrainMatrix(stiffnessMatrix);
-      log("Matrix constrained");
-
-      globalAssembler.assembleFunctional(rhs, b);
-      log("RHS assembled");
-
-      constraints.constrainVector(rhs);
-      log("RHS constrained");
+      globalAssembler.assembleSystem(stiffnessMatrix, rhs, a, b, constraints);
+      log("Liner problem assembled");
     }
 
     // *********************************************
@@ -328,21 +334,6 @@ int main (int argc, char *argv[]) try
     Dune::Functions::istlVectorBackend(sol).resize(basis);
     Dune::Functions::istlVectorBackend(sol) = 0;
 
-    // *********************************************
-    // Incorporate Dirichlet boundary conditions
-    // *********************************************
-
-    Dune::Functions::istlVectorBackend(isConstrained).resize(basis);
-    Dune::Functions::istlVectorBackend(isConstrained) = 0;
-    Dune::Fufem::markBoundaryPatchDofs(dirichletPatch, basis, isConstrained);
-    log("Boundary DOFs marked");
-
-    interpolate(basis, sol, dirichletValues, isConstrained);
-    log("Boundary DOFs interpolated");
-
-    incorporateEssentialConstraints(stiffnessMatrix, rhs, isConstrained, sol);
-    log("Boundary condition incorporated into system");
-
     // *********************************************
     // Solve linear system
     // *********************************************
@@ -393,40 +384,22 @@ int main (int argc, char *argv[]) try
       auto u = bindToCoefficients(trialFunction(basis), sol);
 
       auto extensionBasis = Dune::Functions::HierarchicalLagrangeBasis<GridView,extensionOrder>(gridView);
-      auto extensionConstraints = Dune::Fufem::makeContinuityConstraints<BitVector>(extensionBasis);
-      log("Constraints for extension space computed");
+      log("Extension basis created");
 
-      auto defect = Vector();
-      Dune::Functions::istlVectorBackend(defect).resize(extensionBasis);
-      Dune::Functions::istlVectorBackend(defect) = 0;
+      auto extensionConstraints = Dune::Fufem::makeAffineConstraints<BitVector, Vector>(extensionBasis);
 
-      // Mark all constrained and lower order DOFs to be ignored
-      auto ignore = extensionConstraints.isConstrained();
-      {
-        auto localView = extensionBasis.localView();
-        for(const auto& element : elements(extensionBasis.gridView()))
-        {
-          localView.bind(element);
-          const auto& localCoefficients = localView.tree().finiteElement().localCoefficients();
-          for(auto i : Dune::range(localCoefficients.size()))
-            if (localCoefficients.localKey(i).codim() == Grid::dimension)
-              ignore[localView.index(i)] = true;
-        }
-      }
+      computeContinuityConstraints(extensionConstraints, extensionBasis);
+      log("Extension hanging node constraints computed");
+
+      computeCodimConstraints<Grid::dimension>(extensionConstraints, extensionBasis);
+      log("Extension DOFs identified");
 
-      // Interpolate all non-ignored boundary DOFs
-      auto extensionIsBoundary = std::vector<bool>();
-      Dune::Fufem::markBoundaryPatchDofs(dirichletPatch, extensionBasis, extensionIsBoundary);
-      for(auto i: Dune::range(ignore.size()))
-        if (ignore[i])
-          extensionIsBoundary[i] = false;
-      interpolate(extensionBasis, defect, dirichletValues, extensionIsBoundary);
-      log("Boundary DOFs for extension space marked and interpolated");
+      computeBoundaryConstraints(extensionConstraints, extensionBasis, dirichletValues, dirichletPatch);
+      log("Extension boundary constraints computed");
 
-      // Mark all boundary DOFs for being ignored, too.
-      for(auto i: Dune::range(ignore.size()))
-        if (extensionIsBoundary[i])
-          ignore[i] = true;
+      auto defect = Vector();
+      Dune::Functions::istlVectorBackend(defect).resize(extensionBasis);
+      Dune::Functions::istlVectorBackend(defect) = 0;
 
       // Define defect problem assemblers
       auto f = Coefficient(Dune::Functions::makeGridViewFunction(rightHandSide, basis.gridView()));
@@ -439,15 +412,15 @@ int main (int argc, char *argv[]) try
 
       // Compute local error estimates
 #if DUNE_VERSION_GT(DUNE_FUNCTIONS, 2, 9)
-      auto [error, eta] = hierarchicalErrorEstimator(a, residual, extensionBasis, extensionConstraints, ignore, defect, std::cref(gridViewPartition), threadCount);
+      auto [error, eta] = hierarchicalErrorEstimator(a, residual, extensionBasis, extensionConstraints, defect, std::cref(gridViewPartition), threadCount);
 #else
-      auto [error, eta] = hierarchicalErrorEstimator(a, residual, extensionBasis, extensionConstraints, ignore, defect);
+      auto [error, eta] = hierarchicalErrorEstimator(a, residual, extensionBasis, extensionConstraints, defect);
 #endif
 
       log(Dune::formatString("Estimated total error :    % 12.5e", error));
 
       // *********************************************
-      // Write solution
+      // Write solution and solution of localized defect problem
       // *********************************************
 
       {
@@ -459,12 +432,12 @@ int main (int argc, char *argv[]) try
         auto vtkWriter = UnstructuredGridWriter(LagrangeDataCollector(basis.gridView(), extensionOrder));
         vtkWriter.addPointData(u, FieldInfo("sol", FieldInfo::Type::scalar, 1));
         vtkWriter.addPointData(d, FieldInfo("defect", FieldInfo::Type::scalar, 1));
-        vtkWriter.write(Dune::formatString("poisson-adaptive-%03d", refStep));
+        vtkWriter.write(Dune::formatString("poisson-adaptive-%03d", refinementStep));
 #else
         Dune::SubsamplingVTKWriter<GridView> vtkWriter(gridView, Dune::refinementLevels(1));
         vtkWriter.addVertexData(u, VTK::FieldInfo("sol", VTK::FieldInfo::Type::scalar, 1));
         vtkWriter.addVertexData(d, VTK::FieldInfo("defect", VTK::FieldInfo::Type::scalar, 1));
-        vtkWriter.write(Dune::formatString("poisson-adaptive-%03d", refStep));
+        vtkWriter.write(Dune::formatString("poisson-adaptive-%03d", refinementStep));
 #endif
         log("Solution written to vtk file");
       }
@@ -484,7 +457,7 @@ int main (int argc, char *argv[]) try
       grid.postAdapt();
       log(Dune::formatString("Grid adapted. New grid has %d levels.", grid.maxLevel()));
 
-      refStep++;
+      ++refinementStep;
     }
   }
 }
diff --git a/examples/utilities.hh b/examples/utilities.hh
index 64ca8d64..e4340143 100644
--- a/examples/utilities.hh
+++ b/examples/utilities.hh
@@ -114,6 +114,40 @@ public:
       std::cout << "Assembling functional took " << timer.elapsed() << "s" << std::endl;
   }
 
+  template<class Matrix, class Vector, class LocalOperatorAssembler, class LocalFunctionalAssembler, class Constraints>
+  void assembleSystem(Matrix& matrix, Vector& vector, LocalOperatorAssembler&& localOperatorAssembler, LocalFunctionalAssembler&& localFunctionalAssembler, const Constraints& constraints) const
+  {
+    Dune::Timer timer;
+
+    auto operatorAssembler = Dune::Fufem::DuneFunctionsOperatorAssembler{testBasis_, trialBasis_};
+    std::apply([&](auto&&... args) {
+      auto&& matrixBackend = Dune::Fufem::istlMatrixBackend(matrix);
+      auto patternBackend = matrixBackend.patternBuilder();
+
+      operatorAssembler.assembleBulkPattern(patternBackend, Dune::resolveRef(args)...);
+
+      constraints.constrainMatrixPattern(patternBackend, testBasis_);
+
+      patternBackend.setupMatrix();
+
+      matrixBackend.assign(0);
+      operatorAssembler.assembleBulkEntries(matrixBackend, localOperatorAssembler, Dune::resolveRef(args)...);
+    }, args_);
+    if (verbosity_>0)
+      std::cout << "Assembling operator took " << timer.elapsed() << "s" << std::endl;
+
+    auto functionalAssembler = Dune::Fufem::DuneFunctionsFunctionalAssembler{testBasis_};
+    std::apply([&](auto&&... args) {
+      functionalAssembler.assembleBulk(vector, localFunctionalAssembler, Dune::resolveRef(args)...);
+    }, args_);
+    if (verbosity_>0)
+      std::cout << "Assembling functional took " << timer.elapsed() << "s" << std::endl;
+
+    constraints.constrainLinearSystem(matrix, vector);
+    if (verbosity_>0)
+      std::cout << "Constraining linear system took " << timer.elapsed() << "s" << std::endl;
+  }
+
   template<class Vector>
   void initializeVector(Vector& vector) const
   {
-- 
GitLab