unfitted_transfer_matrix_solver.hh 6.47 KB
Newer Older
1 2
#ifndef DUNEURO_UNFITTED_TRANSFER_MATRIX_SOLVER_HH
#define DUNEURO_UNFITTED_TRANSFER_MATRIX_SOLVER_HH
3 4 5 6 7

#include <dune/common/parametertree.hh>

#include <duneuro/common/make_dof_vector.hh>
#include <duneuro/eeg/projection_utilities.hh>
8
#include <duneuro/eeg/unfitted_transfer_matrix_rhs.hh>
9 10 11 12
#include <duneuro/io/data_tree.hh>

namespace duneuro
{
13
  template <class S>
14
  struct UnfittedTransferMatrixSolverTraits {
15 16 17
    using Solver = S;
    using SubTriangulation = typename Solver::Traits::SubTriangulation;
    static const unsigned int dimension = SubTriangulation::dim;
18 19 20
    using FunctionSpace = typename Solver::Traits::FunctionSpace;
    using DomainDOFVector = typename Solver::Traits::DomainDOFVector;
    using RangeDOFVector = typename Solver::Traits::RangeDOFVector;
21
    using CoordinateFieldType = typename SubTriangulation::ctype;
22 23 24 25 26
    using Coordinate = Dune::FieldVector<CoordinateFieldType, dimension>;
    using Element = typename Solver::Traits::FundamentalGridView::template Codim<0>::Entity;
    using ProjectedPosition = duneuro::ProjectedPosition<Element, Coordinate>;
  };

27
  template <class S, class RHSFactory>
28
  class UnfittedTransferMatrixSolver
29 30
  {
  public:
31
    using Traits = UnfittedTransferMatrixSolverTraits<S>;
32

33 34
    UnfittedTransferMatrixSolver(
        std::shared_ptr<typename Traits::SubTriangulation> subTriangulation,
35
        std::shared_ptr<typename Traits::Solver> solver, const Dune::ParameterTree& config)
36
        : subTriangulation_(subTriangulation)
37
        , solver_(solver)
38
        , rightHandSideVector_(solver_->functionSpace().getGFS(), 0.0)
39 40 41 42
        , config_(config)
    {
    }

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
    template <class SolverBackend>
    std::unique_ptr<DenseMatrix<double>> solve(
        SolverBackend& solverBackend,
        const ProjectedElectrodes<typename Traits::SubTriangulation::GridView>& projectedElectrodes,
        const Dune::ParameterTree& config, DataTree dataTree = DataTree())
    {
      auto transferMatrix = Dune::Std::make_unique<DenseMatrix<double>>(
          projectedElectrodes.size(), solver_->functionSpace().getGFS().ordering().size());
      auto solver_config = config.sub("solver");
      typename Traits::DomainDOFVector solution(solver_->functionSpace().getGFS(), 0.0);
      for (std::size_t index = 1; index < projectedElectrodes.size(); ++index) {
        solve(solverBackend.get(), projectedElectrodes.projectedPosition(0),
              projectedElectrodes.projectedPosition(index), solution, rightHandSideVector_,
              solver_config, dataTree.sub("solver.electrode_" + std::to_string(index)));
        set_matrix_row(*transferMatrix, index, Dune::PDELab::Backend::native(solution));
      }
      return transferMatrix;
    }

#if HAVE_TBB
    template <class SolverBackend>
    std::unique_ptr<DenseMatrix<double>> solve(
        tbb::enumerable_thread_specific<SolverBackend>& solverBackend,
        const ProjectedElectrodes<typename Traits::SubTriangulation::GridView>& projectedElectrodes,
        const Dune::ParameterTree& config, DataTree dataTree = DataTree())
    {
      auto transferMatrix = Dune::Std::make_unique<DenseMatrix<double>>(
          projectedElectrodes.size(), solver_->functionSpace().getGFS().ordering().size());
      auto solver_config = config.sub("solver");
      tbb::task_scheduler_init init(solver_config.hasKey("numberOfThreads") ?
                                        solver_config.get<std::size_t>("numberOfThreads") :
                                        tbb::task_scheduler_init::automatic);
      auto grainSize = solver_config.get<int>("grainSize", 16);
      tbb::enumerable_thread_specific<typename Traits::DomainDOFVector> solution(
          solver_->functionSpace().getGFS(), 0.0);
      tbb::parallel_for(
          tbb::blocked_range<std::size_t>(1, projectedElectrodes.size(), grainSize),
          [&](const tbb::blocked_range<std::size_t>& range) {
            auto& mySolution = solution.local();
            for (std::size_t index = range.begin(); index != range.end(); ++index) {
              solve(solverBackend.local().get(), projectedElectrodes.projectedPosition(0),
                    projectedElectrodes.projectedPosition(index), mySolution,
                    rightHandSideVector_.local(), solver_config,
                    dataTree.sub("solver.electrode_" + std::to_string(index)));
              set_matrix_row(*transferMatrix, index, Dune::PDELab::Backend::native(mySolution));
            }
          });
      return transferMatrix;
    }
#endif

    const typename Traits::FunctionSpace& functionSpace() const
    {
      return solver_->functionSpace();
    }

  private:
    std::shared_ptr<typename Traits::SubTriangulation> subTriangulation_;
    std::shared_ptr<typename Traits::Solver> solver_;
#if HAVE_TBB
    tbb::enumerable_thread_specific<typename Traits::RangeDOFVector> rightHandSideVector_;
#else
    typename Traits::RangeDOFVector rightHandSideVector_;
#endif
    Dune::ParameterTree config_;

109 110
    template <class SolverBackend>
    void solve(SolverBackend& solverBackend, const typename Traits::ProjectedPosition& reference,
111
               const typename Traits::ProjectedPosition& electrode,
112 113 114
               typename Traits::DomainDOFVector& solution,
               typename Traits::RangeDOFVector& rightHandSideVector,
               const Dune::ParameterTree& config, DataTree dataTree = DataTree()) const
115 116
    {
      Dune::Timer timer;
117
      rightHandSideVector = 0.0;
118
      // assemble right hand side
119 120
      auto rhsAssembler =
          RHSFactory::template create<typename Traits::RangeDOFVector>(*solver_, config);
121
#if HAVE_TBB
122
      { // mutex, as the finite element map is not thread safe
123
        tbb::mutex::scoped_lock lock(solver_->functionSpaceMutex());
124 125
        rhsAssembler->bind(reference.element, reference.localPosition, electrode.element,
                           electrode.localPosition);
126
      }
127
      rhsAssembler->assembleRightHandSide(rightHandSideVector);
128
#else
129 130 131
      rhsAssembler->bind(reference.element, reference.localPosition, electrode.element,
                         electrode.localPosition);
      rhsAssembler->assembleRightHandSide(rightHandSideVector);
132
#endif
133 134 135 136
      timer.stop();
      dataTree.set("time_rhs_assembly", timer.lastElapsed());
      timer.start();
      // solve system
137
      solver_->solve(solverBackend, rightHandSideVector, solution, config,
138
                     dataTree.sub("linear_system_solver"));
139 140 141 142 143 144 145
      timer.stop();
      dataTree.set("time_solution", timer.lastElapsed());
      dataTree.set("time", timer.elapsed());
    }
  };
}

146
#endif // DUNEURO_UNFITTED_TRANSFER_MATRIX_SOLVER_HH