Commit bfd5f465 authored by Andreas Nüßing's avatar Andreas Nüßing

unify transfer matrix solver

Instead of using two basically identical transfer matrix solver classes,
we merge them into a single class that is parameterized with the solver
and the transfer matrix rhs factory.
parent 85bfe42b
......@@ -39,8 +39,9 @@ namespace duneuro
struct CGSolverTraits {
static const int dimension = VC::dim;
using VolumeConductor = VC;
using GridView = typename VC::GridView;
using CoordinateFieldType = typename VC::ctype;
using ElementSearch = KDTreeElementSearch<typename VC::GridView>;
using ElementSearch = KDTreeElementSearch<GridView>;
using Problem = ConvectionDiffusionCGDefaultParameter<VC>;
using DirichletExtension = Dune::PDELab::ConvectionDiffusionDirichletExtensionAdapter<Problem>;
using BoundaryCondition = Dune::PDELab::ConvectionDiffusionBoundaryConditionAdapter<Problem>;
......
......@@ -25,13 +25,13 @@ namespace duneuro
template <class ST, int comps, int degree, class P, class DF, class RF, class JF>
struct CutFEMSolverTraits {
using SubTriangulation = ST;
using FundamentalGridView = typename ST::BaseT::GridView;
using CoordinateFieldType = typename FundamentalGridView::ctype;
using ElementSearch = KDTreeElementSearch<FundamentalGridView>;
static const int dimension = FundamentalGridView::dimension;
using GridView = typename ST::BaseT::GridView;
using CoordinateFieldType = typename GridView::ctype;
using ElementSearch = KDTreeElementSearch<GridView>;
static const int dimension = GridView::dimension;
static const int compartments = comps;
using Problem = P;
using FunctionSpace = CutFEMMultiPhaseSpace<FundamentalGridView, RF, degree, compartments>;
using FunctionSpace = CutFEMMultiPhaseSpace<GridView, RF, degree, compartments>;
using DomainField = DF;
using RangeField = RF;
using DomainDOFVector = Dune::PDELab::Backend::Vector<typename FunctionSpace::GFS, DF>;
......@@ -42,7 +42,7 @@ namespace duneuro
ConvectionDiffusion_DG_LocalOperator<Problem, EdgeNormProvider, PenaltyFluxWeighting>;
using WrappedLocalOperator = Dune::UDG::CutFEMMultiPhaseLocalOperatorWrapper<LocalOperator>;
// using WrappedLocalOperator = Dune::UDG::MultiPhaseLocalOperatorWrapper<LocalOperator>;
using UnfittedSubTriangulation = Dune::PDELab::UnfittedSubTriangulation<FundamentalGridView>;
using UnfittedSubTriangulation = Dune::PDELab::UnfittedSubTriangulation<GridView>;
using MatrixBackend = Dune::PDELab::istl::BCRSMatrixBackend<>;
using RawGridOperator =
Dune::UDG::UDGGridOperator<typename FunctionSpace::GFS, typename FunctionSpace::GFS,
......@@ -143,13 +143,6 @@ namespace duneuro
return false;
}
#if HAVE_TBB
tbb::mutex& functionSpaceMutex()
{
return fsMutex_;
}
#endif
private:
std::shared_ptr<const typename Traits::SubTriangulation> subTriangulation_;
std::shared_ptr<const typename Traits::ElementSearch> search_;
......@@ -163,10 +156,6 @@ namespace duneuro
typename Traits::RawGridOperator rawGridOperator_;
typename Traits::GridOperator gridOperator_;
typename Traits::LinearSolver linearSolver_;
#if HAVE_TBB
tbb::mutex fsMutex_;
#endif
};
}
......
......@@ -41,8 +41,9 @@ namespace duneuro
struct DGSolverTraits {
static const int dimension = VC::dim;
using VolumeConductor = VC;
using GridView = typename VC::GridView;
using CoordinateFieldType = typename VC::ctype;
using ElementSearch = KDTreeElementSearch<typename VC::GridView>;
using ElementSearch = KDTreeElementSearch<GridView>;
using Problem = P;
using FunctionSpace = typename DGFunctionSpaceTraits<VC, degree, elementType>::Type;
using DomainDOFVector = Dune::PDELab::Backend::Vector<typename FunctionSpace::GFS, DF>;
......
......@@ -26,13 +26,13 @@ namespace duneuro
template <class ST, int comps, int degree, class P, class DF, class RF, class JF>
struct UDGSolverTraits {
using SubTriangulation = ST;
using FundamentalGridView = typename ST::BaseT::GridView;
using CoordinateFieldType = typename FundamentalGridView::ctype;
using ElementSearch = KDTreeElementSearch<FundamentalGridView>;
static const int dimension = FundamentalGridView::dimension;
using GridView = typename ST::BaseT::GridView;
using CoordinateFieldType = typename GridView::ctype;
using ElementSearch = KDTreeElementSearch<GridView>;
static const int dimension = GridView::dimension;
static const int compartments = comps;
using Problem = P;
using FunctionSpace = UDGQkMultiPhaseSpace<FundamentalGridView, RF, degree, compartments>;
using FunctionSpace = UDGQkMultiPhaseSpace<GridView, RF, degree, compartments>;
using DomainField = DF;
using RangeField = RF;
using DomainDOFVector = Dune::PDELab::Backend::Vector<typename FunctionSpace::GFS, DF>;
......@@ -42,7 +42,7 @@ namespace duneuro
using LocalOperator =
ConvectionDiffusion_DG_LocalOperator<Problem, EdgeNormProvider, PenaltyFluxWeighting>;
using WrappedLocalOperator = Dune::UDG::MultiPhaseLocalOperatorWrapper<LocalOperator>;
using UnfittedSubTriangulation = Dune::PDELab::UnfittedSubTriangulation<FundamentalGridView>;
using UnfittedSubTriangulation = Dune::PDELab::UnfittedSubTriangulation<GridView>;
using MatrixBackend = Dune::PDELab::istl::BCRSMatrixBackend<>;
using GridOperator =
Dune::UDG::UDGGridOperator<typename FunctionSpace::GFS, typename FunctionSpace::GFS,
......@@ -139,13 +139,6 @@ namespace duneuro
return true;
}
#if HAVE_TBB
tbb::mutex& functionSpaceMutex()
{
return fsMutex_;
}
#endif
private:
std::shared_ptr<const typename Traits::SubTriangulation> subTriangulation_;
std::shared_ptr<const typename Traits::ElementSearch> search_;
......@@ -158,10 +151,6 @@ namespace duneuro
typename Traits::UnfittedSubTriangulation unfittedSubTriangulation_;
typename Traits::GridOperator gridOperator_;
typename Traits::LinearSolver linearSolver_;
#if HAVE_TBB
tbb::mutex fsMutex_;
#endif
};
}
......
......@@ -19,13 +19,13 @@
#include <duneuro/common/kdtree.hh>
#endif
//#include <dune/biomag/localoperator/boundaryprojection.hh>
#include <duneuro/eeg/projection_utilities.hh>
#include <duneuro/eeg/electrode_projection_interface.hh>
#include <duneuro/io/data_tree.hh>
namespace duneuro
{
template <class GV>
class ProjectedElectrodes
class ProjectedElectrodes : public ElectrodeProjectionInterface<GV>
{
typedef typename GV::template Codim<0>::Entity Element;
......@@ -33,7 +33,7 @@ namespace duneuro
typedef typename GV::ctype ctype;
enum { dim = GV::dimension };
using Projection = ProjectedPosition<Element, Dune::FieldVector<ctype, dim>>;
using Projection = ProjectedElectrode<GV>;
ProjectedElectrodes(const std::vector<Dune::FieldVector<ctype, dim>>& electrodes, const GV& gv,
DataTree dataTree = DataTree())
......@@ -43,7 +43,7 @@ namespace duneuro
std::vector<std::pair<Projection, ctype>> minDistance;
for (std::size_t i = 0; i < electrodes.size(); ++i) {
minDistance.push_back(std::make_pair(
Projection(gridView_.template begin<0>(), Dune::FieldVector<ctype, dim>(0.0)),
Projection{gridView_.template begin<0>(), Dune::FieldVector<ctype, dim>(0.0)},
std::numeric_limits<ctype>::max()));
}
std::cout << "\n";
......@@ -62,34 +62,10 @@ namespace duneuro
diff -= eg.global(local);
auto diff2n = diff.two_norm();
if (diff2n < minDistance[i].second) {
minDistance[i].first = Projection(element, local);
minDistance[i].first = Projection{element, local};
minDistance[i].second = diff2n;
}
//}
/* for (const auto& intersection : intersections(gridView_, element)) {
if (intersection.neighbor())
continue;
const auto& ig = intersection.geometry();
const auto& reference = ReferenceElements<ctype, dim - 1>::general(ig.type());
for (unsigned int i = 0; i < electrodes.size(); ++i) {
auto local = ig.local(electrodes[i]);
if (reference.checkInside(local)) {
auto projectedGlobal = ig.global(local);
auto diff = electrodes[i];
diff -= projectedGlobal;
// if (Dune::FloatCmp::ge(intersection.centerUnitOuterNormal() * diff,
0.0)) {
auto diff2n = diff.two_norm();
if (diff2n < minDistance[i].second) {
minDistance[i].first = Projection(
intersection.inside(),
intersection.geometryInInside().global(local));
minDistance[i].second = diff2n;
}
//}
}
}
}*/
}
}
ctype maxdiff = 0.0;
......@@ -119,12 +95,18 @@ namespace duneuro
DUNE_THROW(Dune::Exception, "element of electrode at "
<< electrode << " is not a host cell for any domain");
}
projections_.push_back(Projection(element, element.geometry().local(electrode)));
projections_.push_back(Projection{element, element.geometry().local(electrode)});
}
dataTree.set("time", timer.elapsed());
}
#endif
virtual void
setElectrodes(const std::vector<Dune::FieldVector<ctype, dim>>& electrodes) override
{
DUNE_THROW(Dune::Exception, "should not be called");
}
template <class DGF, class OutputIterator>
void evaluateAtProjections(const DGF& dgf, OutputIterator out) const
{
......@@ -163,12 +145,12 @@ namespace duneuro
return projections_[i].element;
}
const Projection& projectedPosition(std::size_t i) const
virtual const Projection& getProjection(std::size_t i) const override
{
return projections_[i];
}
std::size_t size() const
virtual std::size_t size() const override
{
return projections_.size();
}
......
#ifndef DUNEURO_FITTED_TRANSFER_MATRIX_SOLVER_HH
#define DUNEURO_FITTED_TRANSFER_MATRIX_SOLVER_HH
#ifndef DUNEURO_TRANSFER_MATRIX_SOLVER_HH
#define DUNEURO_TRANSFER_MATRIX_SOLVER_HH
#include <dune/common/parametertree.hh>
#include <dune/common/timer.hh>
#include <duneuro/common/flags.hh>
#include <duneuro/common/make_dof_vector.hh>
#include <duneuro/eeg/electrode_projection_interface.hh>
#include <duneuro/io/data_tree.hh>
......@@ -12,52 +10,48 @@
namespace duneuro
{
template <class S>
struct FittedTransferMatrixSolverTraits {
struct TransferMatrixSolverTraits {
using Solver = S;
static const unsigned int dimension = S::Traits::dimension;
using VolumeConductor = typename S::Traits::VolumeConductor;
using FunctionSpace = typename S::Traits::FunctionSpace;
using DomainDOFVector = typename S::Traits::DomainDOFVector;
using RangeDOFVector = typename S::Traits::RangeDOFVector;
using CoordinateFieldType = typename VolumeConductor::ctype;
static const unsigned int dimension = Solver::Traits::dimension;
using FunctionSpace = typename Solver::Traits::FunctionSpace;
using DomainDOFVector = typename Solver::Traits::DomainDOFVector;
using RangeDOFVector = typename Solver::Traits::RangeDOFVector;
using CoordinateFieldType = typename Solver::Traits::CoordinateFieldType;
using Coordinate = Dune::FieldVector<CoordinateFieldType, dimension>;
using Element = typename VolumeConductor::GridView::template Codim<0>::Entity;
using ProjectedPosition = ProjectedElectrode<typename VolumeConductor::GridView>;
using ProjectedPosition = duneuro::ProjectedElectrode<typename Solver::Traits::GridView>;
};
template <class S, class RHSFactory>
class FittedTransferMatrixSolver
class TransferMatrixSolver
{
public:
using Traits = FittedTransferMatrixSolverTraits<S>;
using Traits = TransferMatrixSolverTraits<S>;
FittedTransferMatrixSolver(std::shared_ptr<typename Traits::VolumeConductor> volumeConductor,
std::shared_ptr<typename Traits::Solver> solver)
: volumeConductor_(volumeConductor)
, solver_(solver)
TransferMatrixSolver(std::shared_ptr<typename Traits::Solver> solver,
const Dune::ParameterTree& config)
: solver_(solver)
, rightHandSideVector_(solver_->functionSpace().getGFS(), 0.0)
, config_(config)
{
}
template <class SolverBackend>
std::unique_ptr<DenseMatrix<double>>
solve(SolverBackend& solverBackend,
const ElectrodeProjectionInterface<typename Traits::VolumeConductor::GridView>&
electrodeProjection,
const ElectrodeProjectionInterface<typename Traits::Solver::Traits::GridView>&
projectedElectrodes,
const Dune::ParameterTree& config, DataTree dataTree = DataTree())
{
auto transferMatrix = Dune::Std::make_unique<DenseMatrix<double>>(
electrodeProjection.size(), solver_->functionSpace().getGFS().ordering().size());
projectedElectrodes.size(), solver_->functionSpace().getGFS().ordering().size());
auto solver_config = config.sub("solver");
typename Traits::DomainDOFVector solution(solver_->functionSpace().getGFS(), 0.0);
for (std::size_t index = 1; index < electrodeProjection.size(); ++index) {
solve(solverBackend.get(), electrodeProjection.getProjection(0),
electrodeProjection.getProjection(index), solution, rightHandSideVector_,
for (std::size_t index = 1; index < projectedElectrodes.size(); ++index) {
solve(solverBackend.get(), projectedElectrodes.getProjection(0),
projectedElectrodes.getProjection(index), solution, rightHandSideVector_,
solver_config, dataTree.sub("solver.electrode_" + std::to_string(index)));
set_matrix_row(*transferMatrix, index, Dune::PDELab::Backend::native(solution));
}
return transferMatrix;
}
......@@ -65,13 +59,12 @@ namespace duneuro
template <class SolverBackend>
std::unique_ptr<DenseMatrix<double>>
solve(tbb::enumerable_thread_specific<SolverBackend>& solverBackend,
const ElectrodeProjectionInterface<typename Traits::VolumeConductor::GridView>&
electrodeProjection,
const ElectrodeProjectionInterface<typename Traits::Solver::Traits::GridView>&
projectedElectrodes,
const Dune::ParameterTree& config, DataTree dataTree = DataTree())
{
auto transferMatrix = Dune::Std::make_unique<DenseMatrix<double>>(
electrodeProjection.size(), solver_->functionSpace().getGFS().ordering().size());
projectedElectrodes.size(), solver_->functionSpace().getGFS().ordering().size());
auto solver_config = config.sub("solver");
tbb::task_scheduler_init init(solver_config.hasKey("numberOfThreads") ?
solver_config.get<std::size_t>("numberOfThreads") :
......@@ -79,42 +72,30 @@ namespace duneuro
auto grainSize = solver_config.get<int>("grainSize", 16);
tbb::enumerable_thread_specific<typename Traits::DomainDOFVector> solution(
solver_->functionSpace().getGFS(), 0.0);
// split the electrodes into blocks of at most grainSize number of elements. solve these
// blocks in parallel
tbb::parallel_for(tbb::blocked_range<std::size_t>(1, electrodeProjection.size(), grainSize),
tbb::parallel_for(tbb::blocked_range<std::size_t>(1, projectedElectrodes.size(), grainSize),
[&](const tbb::blocked_range<std::size_t>& range) {
auto& mySolution = solution.local();
for (std::size_t index = range.begin(); index != range.end(); ++index) {
solve(solverBackend.local().get(), electrodeProjection.getProjection(0),
electrodeProjection.getProjection(index), mySolution,
solve(solverBackend.local().get(), projectedElectrodes.getProjection(0),
projectedElectrodes.getProjection(index), mySolution,
rightHandSideVector_.local(), solver_config,
dataTree.sub("solver.electrode_" + std::to_string(index)));
set_matrix_row(*transferMatrix, index,
Dune::PDELab::Backend::native(mySolution));
}
});
return transferMatrix;
}
#endif
const typename Traits::FunctionSpace& functionSpace() const
{
return solver_->functionSpace();
}
private:
std::shared_ptr<typename Traits::VolumeConductor> volumeConductor_;
std::shared_ptr<typename Traits::Solver> solver_;
#if HAVE_TBB
tbb::enumerable_thread_specific<typename Traits::RangeDOFVector> rightHandSideVector_;
#else
typename Traits::RangeDOFVector rightHandSideVector_;
#endif
template <class V>
friend struct MakeDOFVectorHelper;
Dune::ParameterTree config_;
template <class SolverBackend>
void solve(SolverBackend& solverBackend, const typename Traits::ProjectedPosition& reference,
......@@ -124,8 +105,8 @@ namespace duneuro
const Dune::ParameterTree& config, DataTree dataTree = DataTree()) const
{
Dune::Timer timer;
// assemble right hand side
rightHandSideVector = 0.0;
// assemble right hand side
auto rhsAssembler =
RHSFactory::template create<typename Traits::RangeDOFVector>(*solver_, config);
rhsAssembler->bind(reference.element, reference.localPosition, electrode.element,
......@@ -143,4 +124,5 @@ namespace duneuro
}
};
}
#endif // DUNEURO_FITTED_TRANSFER_MATRIX_SOLVER_HH
#endif // DUNEURO_TRANSFER_MATRIX_SOLVER_HH
......@@ -10,8 +10,7 @@ namespace duneuro
{
struct UnfittedTransferMatrixRHSFactory {
template <class Vector, class Solver>
static std::unique_ptr<TransferMatrixRHSInterface<typename Solver::Traits::FundamentalGridView,
Vector>>
static std::unique_ptr<TransferMatrixRHSInterface<typename Solver::Traits::GridView, Vector>>
create(const Solver& solver, const Dune::ParameterTree& config)
{
auto type = config.get<std::string>("type", "point");
......
#ifndef DUNEURO_UNFITTED_TRANSFER_MATRIX_SOLVER_HH
#define DUNEURO_UNFITTED_TRANSFER_MATRIX_SOLVER_HH
#include <dune/common/parametertree.hh>
#include <duneuro/common/make_dof_vector.hh>
#include <duneuro/eeg/projection_utilities.hh>
#include <duneuro/io/data_tree.hh>
namespace duneuro
{
template <class S>
struct UnfittedTransferMatrixSolverTraits {
using Solver = S;
using SubTriangulation = typename Solver::Traits::SubTriangulation;
static const unsigned int dimension = SubTriangulation::dim;
using FunctionSpace = typename Solver::Traits::FunctionSpace;
using DomainDOFVector = typename Solver::Traits::DomainDOFVector;
using RangeDOFVector = typename Solver::Traits::RangeDOFVector;
using CoordinateFieldType = typename SubTriangulation::ctype;
using Coordinate = Dune::FieldVector<CoordinateFieldType, dimension>;
using Element = typename Solver::Traits::FundamentalGridView::template Codim<0>::Entity;
using ProjectedPosition = duneuro::ProjectedPosition<Element, Coordinate>;
};
template <class S, class RHSFactory>
class UnfittedTransferMatrixSolver
{
public:
using Traits = UnfittedTransferMatrixSolverTraits<S>;
UnfittedTransferMatrixSolver(
std::shared_ptr<typename Traits::SubTriangulation> subTriangulation,
std::shared_ptr<typename Traits::Solver> solver, const Dune::ParameterTree& config)
: subTriangulation_(subTriangulation)
, solver_(solver)
, rightHandSideVector_(solver_->functionSpace().getGFS(), 0.0)
, config_(config)
{
}
template <class SolverBackend>
std::unique_ptr<DenseMatrix<double>> solve(
SolverBackend& solverBackend,
const ProjectedElectrodes<typename Traits::SubTriangulation::GridView>& projectedElectrodes,
const Dune::ParameterTree& config, DataTree dataTree = DataTree())
{
auto transferMatrix = Dune::Std::make_unique<DenseMatrix<double>>(
projectedElectrodes.size(), solver_->functionSpace().getGFS().ordering().size());
auto solver_config = config.sub("solver");
typename Traits::DomainDOFVector solution(solver_->functionSpace().getGFS(), 0.0);
for (std::size_t index = 1; index < projectedElectrodes.size(); ++index) {
solve(solverBackend.get(), projectedElectrodes.projectedPosition(0),
projectedElectrodes.projectedPosition(index), solution, rightHandSideVector_,
solver_config, dataTree.sub("solver.electrode_" + std::to_string(index)));
set_matrix_row(*transferMatrix, index, Dune::PDELab::Backend::native(solution));
}
return transferMatrix;
}
#if HAVE_TBB
template <class SolverBackend>
std::unique_ptr<DenseMatrix<double>> solve(
tbb::enumerable_thread_specific<SolverBackend>& solverBackend,
const ProjectedElectrodes<typename Traits::SubTriangulation::GridView>& projectedElectrodes,
const Dune::ParameterTree& config, DataTree dataTree = DataTree())
{
auto transferMatrix = Dune::Std::make_unique<DenseMatrix<double>>(
projectedElectrodes.size(), solver_->functionSpace().getGFS().ordering().size());
auto solver_config = config.sub("solver");
tbb::task_scheduler_init init(solver_config.hasKey("numberOfThreads") ?
solver_config.get<std::size_t>("numberOfThreads") :
tbb::task_scheduler_init::automatic);
auto grainSize = solver_config.get<int>("grainSize", 16);
tbb::enumerable_thread_specific<typename Traits::DomainDOFVector> solution(
solver_->functionSpace().getGFS(), 0.0);
tbb::parallel_for(
tbb::blocked_range<std::size_t>(1, projectedElectrodes.size(), grainSize),
[&](const tbb::blocked_range<std::size_t>& range) {
auto& mySolution = solution.local();
for (std::size_t index = range.begin(); index != range.end(); ++index) {
solve(solverBackend.local().get(), projectedElectrodes.projectedPosition(0),
projectedElectrodes.projectedPosition(index), mySolution,
rightHandSideVector_.local(), solver_config,
dataTree.sub("solver.electrode_" + std::to_string(index)));
set_matrix_row(*transferMatrix, index, Dune::PDELab::Backend::native(mySolution));
}
});
return transferMatrix;
}
#endif
const typename Traits::FunctionSpace& functionSpace() const
{
return solver_->functionSpace();
}
private:
std::shared_ptr<typename Traits::SubTriangulation> subTriangulation_;
std::shared_ptr<typename Traits::Solver> solver_;
#if HAVE_TBB
tbb::enumerable_thread_specific<typename Traits::RangeDOFVector> rightHandSideVector_;
#else
typename Traits::RangeDOFVector rightHandSideVector_;
#endif
Dune::ParameterTree config_;
template <class SolverBackend>
void solve(SolverBackend& solverBackend, const typename Traits::ProjectedPosition& reference,
const typename Traits::ProjectedPosition& electrode,
typename Traits::DomainDOFVector& solution,
typename Traits::RangeDOFVector& rightHandSideVector,
const Dune::ParameterTree& config, DataTree dataTree = DataTree()) const
{
Dune::Timer timer;
rightHandSideVector = 0.0;
// assemble right hand side
auto rhsAssembler =
RHSFactory::template create<typename Traits::RangeDOFVector>(*solver_, config);
#if HAVE_TBB
{ // mutex, as the finite element map is not thread safe
tbb::mutex::scoped_lock lock(solver_->functionSpaceMutex());
rhsAssembler->bind(reference.element, reference.localPosition, electrode.element,
electrode.localPosition);
}
rhsAssembler->assembleRightHandSide(rightHandSideVector);
#else
rhsAssembler->bind(reference.element, reference.localPosition, electrode.element,
electrode.localPosition);
rhsAssembler->assembleRightHandSide(rightHandSideVector);
#endif
timer.stop();
dataTree.set("time_rhs_assembly", timer.lastElapsed());
timer.start();
// solve system
solver_->solve(solverBackend, rightHandSideVector, solution, config,
dataTree.sub("linear_system_solver"));
timer.stop();
dataTree.set("time_solution", timer.lastElapsed());
dataTree.set("time", timer.elapsed());
}
};
}
#endif // DUNEURO_UNFITTED_TRANSFER_MATRIX_SOLVER_HH
......@@ -35,10 +35,11 @@ namespace duneuro
public:
using Traits = UnfittedTransferMatrixUserTraits<S, SMF>;
UnfittedTransferMatrixUser(std::shared_ptr<typename Traits::SubTriangulation> subTriangulation,
std::shared_ptr<typename Traits::Solver> solver,
std::shared_ptr<typename Traits::ElementSearch> search,
const Dune::ParameterTree& config)
UnfittedTransferMatrixUser(
std::shared_ptr<const typename Traits::SubTriangulation> subTriangulation,
std::shared_ptr<const typename Traits::Solver> solver,
std::shared_ptr<const typename Traits::ElementSearch> search,
const Dune::ParameterTree& config)
: subTriangulation_(subTriangulation), solver_(solver), search_(search)
{
}
......@@ -105,14 +106,7 @@ namespace duneuro
{
using SVC = typename Traits::SparseRHSVector;
SVC rhs;
#if HAVE_TBB
{
tbb::mutex::scoped_lock lock(solver_->functionSpaceMutex());
sparseSourceModel_->assembleRightHandSide(rhs);
}
#else
sparseSourceModel_->assembleRightHandSide(rhs);
#endif
const auto blockSize = Traits::Solver::Traits::FunctionSpace::blockSize;
......@@ -135,23 +129,16 @@ namespace duneuro
} else {
*denseRHSVector_ = 0.0;
}
#if HAVE_TBB
{
tbb::mutex::scoped_lock lock(solver_->functionSpaceMutex());
denseSourceModel_->assembleRightHandSide(*denseRHSVector_);
}
#else
denseSourceModel_->assembleRightHandSide(*denseRHSVector_);
#endif
return matrix_dense_vector_product(transferMatrix,
Dune::PDELab::Backend::native(*denseRHSVector_));
}
private:
std::shared_ptr<typename Traits::SubTriangulation> subTriangulation_;
std::shared_ptr<typename Traits::Solver> solver_;
std::shared_ptr<typename Traits::ElementSearch> search_;
std::shared_ptr<const typename Traits::SubTriangulation> subTriangulation_;
std::shared_ptr<const typename Traits::Solver> solver_;
std::shared_ptr<const typename Traits::ElementSearch> search_;
VectorDensity density_;
std::shared_ptr<SourceModelInterface<typename Traits::DomainField, Traits::dimension,
typename Traits::SparseRHSVector>>
......
......@@ -28,8 +28,8 @@
#include <duneuro/eeg/eeg_forward_solver.hh>
#include <duneuro/eeg/electrode_projection_factory.hh>
#include <duneuro/eeg/fitted_transfer_matrix_rhs_factory.hh>
#include <duneuro/eeg/fitted_transfer_matrix_solver.hh>
#include <duneuro/eeg/fitted_transfer_matrix_user.hh>
#include <duneuro/eeg/transfer_matrix_solver.hh>
#include <duneuro/io/fitted_tensor_vtk_functor.hh>
#include <duneuro/io/volume_conductor_reader.hh>
#include <duneuro/io/vtk_writer.hh>
......@@ -104,7 +104,8 @@ namespace duneuro
nullptr)
, solverBackend_(solver_,
config.hasSub("solver") ? config.sub("solver") : Dune::ParameterTree())
, eegTransferMatrixSolver_(volumeConductorStorage_.get(), solver_)
, eegTransferMatrixSolver_(solver_, config.hasSub("solver") ? config.sub("solver") :
Dune::ParameterTree())
, megTransferMatrixSolver_(solver_, megSolver_)
, eegForwardSolver_(solver_)
{
......@@ -432,7 +433,7 @@ namespace duneuro
#else
typename Traits::SolverBackend solverBackend_;
#endif
FittedTransferMatrixSolver<typename Traits::Solver, typename Traits::TransferMatrixRHSFactory>
TransferMatrixSolver<typename Traits::Solver, typename Traits::TransferMatrixRHSFactory>
eegTransferMatrixSolver_;
FittedMEGTransferMatrixSolver<typename Traits::Solver> megTransferMatrixSolver_;
EEGForwardSolver<typename Traits::Solver, typename Traits::SourceModelFactory>
......
......@@ -20,9 +20,9 @@
#include <duneuro/eeg/cutfem_source_model_factory.hh>
#include <duneuro/eeg/eeg_forward_solver.hh>
#include <duneuro/eeg/projected_electrodes.hh>
#include <duneuro/eeg/transfer_matrix_solver.hh>
#include <duneuro/eeg/udg_source_model_factory.hh>
#include <duneuro/eeg/unfitted_transfer_matrix_rhs_factory.hh>
#include <duneuro/eeg/unfitted_transfer_matrix_solver.hh>
#include <duneuro/eeg/unfitted_transfer_matrix_user.hh>
#include <duneuro/io/refined_vtk_writer.hh>
#include <duneuro/io/vtk_functors.hh>
......@@ -76,7 +76,7 @@ namespace duneuro
using SourceModelFactory = typename SelectUnfittedSolver<solverType, dim, degree,
compartments>::SourceModelFactoryType;
using TransferMatrixRHSFactory = UnfittedTransferMatrixRHSFactory;
using EEGTransferMatrixSolver = UnfittedTransferMatrixSolver<Solver, TransferMatrixRHSFactory>;
using EEGTransferMatrixSolver = TransferMatrixSolver<Solver, TransferMatrixRHSFactory>;
using TransferMatrixUser = UnfittedTransferMatrixUser<Solver, SourceModelFactory>;
using SolverBackend =
typename SelectUnfittedSolver<solverType, dim, degree, compartments>::SolverBackendType;
......@@ -114,7 +114,7 @@ namespace duneuro
config.sub("solver")))
, solverBackend_(solver_,
config.hasSub("solver") ? config.sub("solver") : Dune::ParameterTree())
, eegTransferMatrixSolver_(subTriangulation_, solver_, config.sub("solver"))
, eegTransferMatrixSolver_(solver_, config.sub("solver"))
, eegForwardSolver_(solver_)
, conductivities_(config.get<std::vector<double>>("solver.conductivities"))
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment