Skip to content
Snippets Groups Projects
Verified Commit ad5c398f authored by Santiago Ospina De Los Ríos's avatar Santiago Ospina De Los Ríos
Browse files

Add SFINAE requirements to other specialization & allow custom vectors

parent e7cf80d1
No related branches found
No related tags found
1 merge request!546Improve UMFPack vector chooser
Pipeline #65358 failed
......@@ -172,18 +172,25 @@ namespace Dune {
namespace Impl
{
template<class M, class = void>
struct UMFPackVectorChooser
{};
struct UMFPackVectorChooser;
/** @brief The type of the domain of the solver */
template<class M> using UMFPackDomainType = typename UMFPackVectorChooser<M>::domain_type;
/** @brief The type of the range of the solver */
template<class M> using UMFPackRangeType = typename UMFPackVectorChooser<M>::range_type;
template<class M>
struct UMFPackVectorChooser<M, std::enable_if_t<(std::is_same<M,double>::value) || (std::is_same<M,std::complex<double> >::value)>>
struct UMFPackVectorChooser<M,
std::enable_if_t<(std::is_same<M,double>::value) || (std::is_same<M,std::complex<double> >::value)>>
{
using domain_type = M;
using range_type = M;
};
template<typename T, int n, int m>
struct UMFPackVectorChooser<FieldMatrix<T,n,m>>
struct UMFPackVectorChooser<FieldMatrix<T,n,m>,
std::enable_if_t<(std::is_same<T,double>::value) || (std::is_same<T,std::complex<double> >::value)>>
{
/** @brief The type of the domain of the solver */
using domain_type = FieldVector<T,m>;
......@@ -192,34 +199,35 @@ namespace Dune {
};
template<typename T, typename A>
struct UMFPackVectorChooser<BCRSMatrix<T,A> >
struct UMFPackVectorChooser<BCRSMatrix<T,A>,
std::void_t<UMFPackDomainType<T>, UMFPackRangeType<T>>>
{
using sub_domain_type = typename UMFPackVectorChooser<T>::domain_type;
using sub_range_type = typename UMFPackVectorChooser<T>::range_type;
/** @brief The type of the domain of the solver */
using domain_type = BlockVector<sub_domain_type, typename std::allocator_traits<A>::template rebind_alloc<sub_domain_type>>;
using domain_type = BlockVector<UMFPackDomainType<T>, typename std::allocator_traits<A>::template rebind_alloc<UMFPackDomainType<T>>>;
/** @brief The type of the range of the solver */
using range_type = BlockVector<sub_range_type, typename std::allocator_traits<A>::template rebind_alloc<sub_domain_type>>;
using range_type = BlockVector<UMFPackRangeType<T>, typename std::allocator_traits<A>::template rebind_alloc<UMFPackRangeType<T>>>;
};
// to make the `UMFPackVectorChooser` work with `MultiTypeBlockMatrix`, we need to add an intermediate step for the rows, which are typically `MultiTypeBlockVector`
template<typename FirstBlock, typename... Blocks>
struct UMFPackVectorChooser<MultiTypeBlockVector<FirstBlock, Blocks...> >
struct UMFPackVectorChooser<MultiTypeBlockVector<FirstBlock, Blocks...>,
std::void_t<UMFPackDomainType<FirstBlock>, UMFPackRangeType<FirstBlock>, UMFPackDomainType<Blocks>...>>
{
/** @brief The type of the domain of the solver */
using domain_type = MultiTypeBlockVector<typename UMFPackVectorChooser<FirstBlock>::domain_type,typename UMFPackVectorChooser<Blocks>::domain_type... >;
using domain_type = MultiTypeBlockVector<UMFPackDomainType<FirstBlock>, UMFPackDomainType<Blocks>...>;
/** @brief The type of the range of the solver */
using range_type = typename UMFPackVectorChooser<FirstBlock>::range_type;
using range_type = UMFPackRangeType<FirstBlock>;
};
// specialization for `MultiTypeBlockMatrix` with `MultiTypeBlockVector` rows
template<typename FirstRow, typename... Rows>
struct UMFPackVectorChooser<MultiTypeBlockMatrix<FirstRow, Rows...> >
struct UMFPackVectorChooser<MultiTypeBlockMatrix<FirstRow, Rows...>,
std::void_t<UMFPackDomainType<FirstRow>, UMFPackRangeType<FirstRow>, UMFPackRangeType<Rows>...>>
{
/** @brief The type of the domain of the solver */
using domain_type = typename UMFPackVectorChooser<FirstRow>::domain_type;
using domain_type = UMFPackDomainType<FirstRow>;
/** @brief The type of the range of the solver */
using range_type = MultiTypeBlockVector< typename UMFPackVectorChooser<FirstRow>::range_type, typename UMFPackVectorChooser<Rows>::range_type... >;
using range_type = MultiTypeBlockVector< UMFPackRangeType<FirstRow>, UMFPackRangeType<Rows>... >;
};
// dummy class to represent no BitVector
......@@ -242,15 +250,11 @@ namespace Dune {
*
* \note This will only work if dune-istl has been configured to use UMFPack
*/
template<typename M>
class UMFPack
: public InverseOperator<
typename Impl::UMFPackVectorChooser<M>::domain_type,
typename Impl::UMFPackVectorChooser<M>::range_type >
template<typename M, typename D = Impl::UMFPackDomainType<M>, typename R = Impl::UMFPackRangeType<M>>
class UMFPack : public InverseOperator<D,R>
{
using T = typename M::field_type;
public:
using size_type = SuiteSparse_long;
......@@ -262,9 +266,9 @@ namespace Dune {
/** @brief Type of an associated initializer class. */
using MatrixInitializer = ISTL::Impl::BCCSMatrixInitializer<M, size_type>;
/** @brief The type of the domain of the solver. */
using domain_type = typename Impl::UMFPackVectorChooser<M>::domain_type;
using domain_type = D;
/** @brief The type of the range of the solver. */
using range_type = typename Impl::UMFPackVectorChooser<M>::range_type;
using range_type = R;
//! Category of the solver (see SolverCategory::Category)
virtual SolverCategory::Category category() const
......@@ -789,18 +793,15 @@ namespace Dune {
struct UMFPackCreator {
template<class M> using DomainType = typename Impl::UMFPackVectorChooser<M>::domain_type;
template<class M> using RangeType = typename Impl::UMFPackVectorChooser<M>::range_type;
template<class TL, class M,class=void> struct isValidBlock : std::false_type{};
template<class TL, class M> struct isValidBlock<TL,M,
std::enable_if_t<
std::is_same_v<DomainType<M>, typename Dune::TypeListElement<1,TL>::type>
&& std::is_same_v<RangeType<M>, typename Dune::TypeListElement<2,TL>::type>
std::is_same_v<Impl::UMFPackDomainType<M>, typename Dune::TypeListElement<1,TL>::type>
&& std::is_same_v<Impl::UMFPackRangeType<M>, typename Dune::TypeListElement<2,TL>::type>
>> : std::true_type {};
template<typename TL, typename M>
std::shared_ptr<Dune::InverseOperator<DomainType<M>,RangeType<M>>>
std::shared_ptr<Dune::InverseOperator<Impl::UMFPackDomainType<M>,Impl::UMFPackRangeType<M>>>
operator() (TL /*tl*/, const M& mat, const Dune::ParameterTree& config,
std::enable_if_t<isValidBlock<TL, M>::value,int> = 0) const
{
......@@ -817,8 +818,8 @@ namespace Dune {
{
using D = typename Dune::TypeListElement<1,TL>::type;
using R = typename Dune::TypeListElement<2,TL>::type;
using DU = Std::detected_t<DomainType, M>;
using RU = Std::detected_t<RangeType, M>;
using DU = Std::detected_t< Impl::UMFPackDomainType, M>;
using RU = Std::detected_t< Impl::UMFPackRangeType, M>;
DUNE_THROW(UnsupportedType,
"Unsupported Types in UMFPack:\n"
"Matrix: " << className<M>() << ""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment