Commit 4274b4ae authored by Andreas Nüßing's avatar Andreas Nüßing

[TransferMatrixUser] unify fitted and unfitted

Again, there were two basically identical versions of the transfer
matrix user, one for fitted and the other one for unfitted solvers.
We merge both classes into a single one which is then parameterized by
the solver and the source model factory.
parent bfd5f465
#ifndef DUNEURO_FITTED_TRANSFER_MATRIX_USER_HH
#define DUNEURO_FITTED_TRANSFER_MATRIX_USER_HH
#include <dune/common/parametertree.hh>
#include <dune/common/timer.hh>
#include <duneuro/common/dg_solver.hh>
#include <duneuro/common/dipole.hh>
#include <duneuro/common/flags.hh>
#include <duneuro/common/make_dof_vector.hh>
#include <duneuro/common/matrix_utilities.hh>
#include <duneuro/common/sparse_vector_container.hh>
#include <duneuro/common/vector_density.hh>
#include <duneuro/eeg/dg_source_model_factory.hh>
#include <duneuro/io/data_tree.hh>
namespace duneuro
{
template <class S>
struct FittedTransferMatrixUserTraits {
static const unsigned int dimension = S::Traits::dimension;
using VolumeConductor = typename S::Traits::VolumeConductor;
using EEGForwardSolver = S;
using DenseRHSVector = typename EEGForwardSolver::Traits::RangeDOFVector;
using SparseRHSVector = SparseVectorContainer<typename DenseRHSVector::ContainerIndex,
typename DenseRHSVector::ElementType>;
using CoordinateField = typename VolumeConductor::ctype;
using Coordinate = Dune::FieldVector<CoordinateField, dimension>;
using DipoleType = Dipole<CoordinateField, dimension>;
using DomainField = typename EEGForwardSolver::Traits::DomainDOFVector::field_type;
using ElementSearch = KDTreeElementSearch<typename VolumeConductor::GridView>;
};
template <class S, class SMF>
class FittedTransferMatrixUser
{
public:
using Traits = FittedTransferMatrixUserTraits<S>;
FittedTransferMatrixUser(std::shared_ptr<typename Traits::VolumeConductor> volumeConductor,
std::shared_ptr<typename Traits::ElementSearch> search,
std::shared_ptr<typename Traits::EEGForwardSolver> solver)
: volumeConductor_(volumeConductor), search_(search), solver_(solver)
{
}
void setSourceModel(const Dune::ParameterTree& config, const Dune::ParameterTree& solverConfig,
DataTree dataTree = DataTree())
{
sparseSourceModel_.reset();
denseSourceModel_.reset();
density_ = source_model_default_density(config);
if (density_ == VectorDensity::sparse) {
sparseSourceModel_ = SMF::template createSparse<typename Traits::SparseRHSVector>(
*solver_, config, solverConfig);
} else {
denseSourceModel_ = SMF::template createDense<typename Traits::DenseRHSVector>(
*solver_, config, solverConfig);
}
}
void bind(const typename Traits::DipoleType& dipole, DataTree dataTree = DataTree())
{
if (density_ == VectorDensity::sparse) {
if (!sparseSourceModel_) {
DUNE_THROW(Dune::Exception, "source model not set");
}
sparseSourceModel_->bind(dipole, dataTree);
} else {
if (!denseSourceModel_) {
DUNE_THROW(Dune::Exception, "source model not set");
}
denseSourceModel_->bind(dipole, dataTree);
}
}
void postProcessPotential(const std::vector<typename Traits::Coordinate>& projectedElectrodes,
std::vector<typename Traits::DomainField>& potential)
{
if (projectedElectrodes.size() != potential.size()) {
DUNE_THROW(duneuro::IllegalArgumentException,
"number of electrodes ("
<< projectedElectrodes.size()
<< ") does not match number of entries in the potential vector ("
<< potential.size() << ")");
}
if (density_ == VectorDensity::sparse) {
if (!sparseSourceModel_) {
DUNE_THROW(Dune::Exception, "source model not set");
}
sparseSourceModel_->postProcessSolution(projectedElectrodes, potential);
} else {
if (!denseSourceModel_) {
DUNE_THROW(Dune::Exception, "source model not set");
}
denseSourceModel_->postProcessSolution(projectedElectrodes, potential);
}
}
template <class M>
std::vector<typename Traits::DomainField> solve(const M& transferMatrix,
DataTree dataTree = DataTree()) const
{
Dune::Timer timer;
std::vector<typename Traits::DomainField> result;
if (density_ == VectorDensity::sparse) {
dataTree.set("density", "sparse");
result = solveSparse(transferMatrix);
} else {
dataTree.set("density", "dense");
result = solveDense(transferMatrix);
}
dataTree.set("time", timer.elapsed());
return result;
}
template <class M>
std::vector<typename Traits::DomainField> solveSparse(const M& transferMatrix) const
{
using SVC = typename Traits::SparseRHSVector;
SVC rhs;
if (!sparseSourceModel_) {
DUNE_THROW(Dune::Exception, "source model not set");
}
sparseSourceModel_->assembleRightHandSide(rhs);
const auto blockSize =
Traits::EEGForwardSolver::Traits::FunctionSpace::GFS::Traits::Backend::blockSize;
if (blockSize == 1) {
return matrix_sparse_vector_product(transferMatrix, rhs,
[](const typename SVC::Index& c) { return c[0]; });
} else {
return matrix_sparse_vector_product(
transferMatrix, rhs,
[blockSize](const typename SVC::Index& c) { return c[1] * blockSize + c[0]; });
}
}
template <class M>
std::vector<typename Traits::DomainField> solveDense(const M& transferMatrix) const
{
if (!denseRHSVector_) {
denseRHSVector_ = make_range_dof_vector(*solver_, 0.0);
} else {
*denseRHSVector_ = 0.0;
}
if (!denseSourceModel_) {
DUNE_THROW(Dune::Exception, "source model not set");
}
denseSourceModel_->assembleRightHandSide(*denseRHSVector_);
return matrix_dense_vector_product(transferMatrix,
Dune::PDELab::Backend::native(*denseRHSVector_));
}
private:
std::shared_ptr<typename Traits::VolumeConductor> volumeConductor_;
std::shared_ptr<typename Traits::ElementSearch> search_;
std::shared_ptr<typename Traits::EEGForwardSolver> solver_;
VectorDensity density_;
std::shared_ptr<SourceModelInterface<typename Traits::DomainField, Traits::dimension,
typename Traits::SparseRHSVector>>
sparseSourceModel_;
std::shared_ptr<SourceModelInterface<typename Traits::DomainField, Traits::dimension,
typename Traits::DenseRHSVector>>
denseSourceModel_;
mutable std::shared_ptr<typename Traits::DenseRHSVector> denseRHSVector_;
};
}
#endif // DUNEURO_FITTED_TRANSFER_MATRIX_USER_HH
#ifndef DUNEURO_UNFITTED_TRANSFER_MATRIX_USER_HH
#define DUNEURO_UNFITTED_TRANSFER_MATRIX_USER_HH
#ifndef DUNEURO_TRANSFER_MATRIX_USER_HH
#define DUNEURO_TRANSFER_MATRIX_USER_HH
#include <dune/common/parametertree.hh>
#include <dune/common/timer.hh>
......@@ -15,32 +15,26 @@
namespace duneuro
{
template <class S, class SMF>
struct UnfittedTransferMatrixUserTraits {
struct TransferMatrixUserTraits {
using Solver = S;
using SubTriangulation = typename Solver::Traits::SubTriangulation;
static const unsigned int dimension = SubTriangulation::dim;
static const unsigned int dimension = S::Traits::dimension;
using DenseRHSVector = typename Solver::Traits::RangeDOFVector;
using SparseRHSVector = SparseVectorContainer<typename DenseRHSVector::ContainerIndex,
typename DenseRHSVector::ElementType>;
using CoordinateFieldType = typename SubTriangulation::ctype;
using CoordinateFieldType = typename S::Traits::CoordinateFieldType;
using Coordinate = Dune::FieldVector<CoordinateFieldType, dimension>;
using DipoleType = Dipole<CoordinateFieldType, dimension>;
using DomainField = typename Solver::Traits::DomainDOFVector::field_type;
using ElementSearch = KDTreeElementSearch<typename SubTriangulation::BaseT::GridView>;
};
template <class S, class SMF>
class UnfittedTransferMatrixUser
class TransferMatrixUser
{
public:
using Traits = UnfittedTransferMatrixUserTraits<S, SMF>;
UnfittedTransferMatrixUser(
std::shared_ptr<const typename Traits::SubTriangulation> subTriangulation,
std::shared_ptr<const typename Traits::Solver> solver,
std::shared_ptr<const typename Traits::ElementSearch> search,
const Dune::ParameterTree& config)
: subTriangulation_(subTriangulation), solver_(solver), search_(search)
using Traits = TransferMatrixUserTraits<S, SMF>;
explicit TransferMatrixUser(std::shared_ptr<const typename Traits::Solver> solver)
: solver_(solver)
{
}
......@@ -108,7 +102,7 @@ namespace duneuro
SVC rhs;
sparseSourceModel_->assembleRightHandSide(rhs);
const auto blockSize = Traits::Solver::Traits::FunctionSpace::blockSize;
const auto blockSize = S::Traits::FunctionSpace::GFS::Traits::Backend::blockSize;
std::vector<typename Traits::DomainField> output;
if (blockSize == 1) {
......@@ -136,9 +130,7 @@ namespace duneuro
}
private:
std::shared_ptr<const typename Traits::SubTriangulation> subTriangulation_;
std::shared_ptr<const typename Traits::Solver> solver_;
std::shared_ptr<const typename Traits::ElementSearch> search_;
VectorDensity density_;
std::shared_ptr<SourceModelInterface<typename Traits::DomainField, Traits::dimension,
typename Traits::SparseRHSVector>>
......@@ -150,4 +142,4 @@ namespace duneuro
};
}
#endif // DUNEURO_UNFITTED_TRANSFER_MATRIX_USER_HH
#endif // DUNEURO_TRANSFER_MATRIX_USER_HH
......@@ -28,8 +28,8 @@
#include <duneuro/eeg/eeg_forward_solver.hh>
#include <duneuro/eeg/electrode_projection_factory.hh>
#include <duneuro/eeg/fitted_transfer_matrix_rhs_factory.hh>
#include <duneuro/eeg/fitted_transfer_matrix_user.hh>
#include <duneuro/eeg/transfer_matrix_solver.hh>
#include <duneuro/eeg/transfer_matrix_user.hh>
#include <duneuro/io/fitted_tensor_vtk_functor.hh>
#include <duneuro/io/volume_conductor_reader.hh>
#include <duneuro/io/vtk_writer.hh>
......@@ -321,8 +321,7 @@ namespace duneuro
{
std::vector<std::vector<double>> result(dipoles.size());
using User =
FittedTransferMatrixUser<typename Traits::Solver, typename Traits::SourceModelFactory>;
using User = TransferMatrixUser<typename Traits::Solver, typename Traits::SourceModelFactory>;
#if HAVE_TBB
tbb::task_scheduler_init init(config.hasKey("numberOfThreads") ?
......@@ -330,7 +329,7 @@ namespace duneuro
tbb::task_scheduler_init::automatic);
tbb::parallel_for(tbb::blocked_range<std::size_t>(0, dipoles.size()),
[&](const tbb::blocked_range<std::size_t>& range) {
User myUser(volumeConductorStorage_.get(), elementSearch_, solver_);
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = range.begin(); index != range.end(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......@@ -346,7 +345,7 @@ namespace duneuro
}
});
#else
User myUser(volumeConductorStorage_.get(), elementSearch_, solver_);
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = 0; index < dipoles.size(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......@@ -371,8 +370,7 @@ namespace duneuro
{
std::vector<std::vector<double>> result(dipoles.size());
using User =
FittedTransferMatrixUser<typename Traits::Solver, typename Traits::SourceModelFactory>;
using User = TransferMatrixUser<typename Traits::Solver, typename Traits::SourceModelFactory>;
#if HAVE_TBB
tbb::task_scheduler_init init(config.hasKey("numberOfThreads") ?
......@@ -380,7 +378,7 @@ namespace duneuro
tbb::task_scheduler_init::automatic);
tbb::parallel_for(tbb::blocked_range<std::size_t>(0, dipoles.size()),
[&](const tbb::blocked_range<std::size_t>& range) {
User myUser(volumeConductorStorage_.get(), elementSearch_, solver_);
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = range.begin(); index != range.end(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......@@ -389,7 +387,7 @@ namespace duneuro
}
});
#else
User myUser(volumeConductorStorage_.get(), elementSearch_, solver_);
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = 0; index < dipoles.size(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......
......@@ -21,9 +21,9 @@
#include <duneuro/eeg/eeg_forward_solver.hh>
#include <duneuro/eeg/projected_electrodes.hh>
#include <duneuro/eeg/transfer_matrix_solver.hh>
#include <duneuro/eeg/transfer_matrix_user.hh>
#include <duneuro/eeg/udg_source_model_factory.hh>
#include <duneuro/eeg/unfitted_transfer_matrix_rhs_factory.hh>
#include <duneuro/eeg/unfitted_transfer_matrix_user.hh>
#include <duneuro/io/refined_vtk_writer.hh>
#include <duneuro/io/vtk_functors.hh>
#include <duneuro/meeg/meeg_driver_interface.hh>
......@@ -77,7 +77,7 @@ namespace duneuro
compartments>::SourceModelFactoryType;
using TransferMatrixRHSFactory = UnfittedTransferMatrixRHSFactory;
using EEGTransferMatrixSolver = TransferMatrixSolver<Solver, TransferMatrixRHSFactory>;
using TransferMatrixUser = UnfittedTransferMatrixUser<Solver, SourceModelFactory>;
using TransferMatrixUser = duneuro::TransferMatrixUser<Solver, SourceModelFactory>;
using SolverBackend =
typename SelectUnfittedSolver<solverType, dim, degree, compartments>::SolverBackendType;
......@@ -108,7 +108,8 @@ namespace duneuro
, domain_(levelSetGridView_, data_.levelSetData, config.sub("domain"))
, subTriangulation_(std::make_shared<typename Traits::SubTriangulation>(
fundamentalGridView_, levelSetGridView_, domain_.getDomainConfiguration(),
config.get<bool>("udg.force_refinement", false)))
config.get<bool>("udg.force_refinement", false),
config.get<double>("udg.value_tolerance", 1e-8)))
, elementSearch_(std::make_shared<typename Traits::ElementSearch>(fundamentalGridView_))
, solver_(std::make_shared<typename Traits::Solver>(subTriangulation_, elementSearch_,
config.sub("solver")))
......@@ -268,8 +269,7 @@ namespace duneuro
tbb::task_scheduler_init::automatic);
tbb::parallel_for(tbb::blocked_range<std::size_t>(0, dipoles.size(), grainSize),
[&](const tbb::blocked_range<std::size_t>& range) {
User myUser(subTriangulation_, solver_, elementSearch_,
config.sub("solver"));
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = range.begin(); index != range.end(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......@@ -285,7 +285,7 @@ namespace duneuro
}
});
#else
User myUser(subTriangulation_, solver_, elementSearch_, config.sub("solver"));
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = 0; index < dipoles.size(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......@@ -319,8 +319,7 @@ namespace duneuro
tbb::task_scheduler_init::automatic);
tbb::parallel_for(tbb::blocked_range<std::size_t>(0, dipoles.size(), grainSize),
[&](const tbb::blocked_range<std::size_t>& range) {
User myUser(subTriangulation_, solver_, elementSearch_,
config.sub("solver"));
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = range.begin(); index != range.end(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......@@ -329,7 +328,7 @@ namespace duneuro
}
});
#else
User myUser(subTriangulation_, solver_, elementSearch_, config.sub("solver"));
User myUser(solver_);
myUser.setSourceModel(config.sub("source_model"), config_.sub("solver"));
for (std::size_t index = 0; index < dipoles.size(); ++index) {
auto dt = dataTree.sub("dipole_" + std::to_string(index));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment