Skip to content
Snippets Groups Projects
Commit 985300be authored by Carsten Gräser's avatar Carsten Gräser
Browse files

[cleanup] Simplify MultiTypeBlockMatrix::mv, mmv, umv, usmv

The new implementation is based on Hybrid::forEach
parent 72b1d692
No related branches found
No related tags found
1 merge request!45Use hybridutilities
......@@ -71,86 +71,6 @@ namespace Dune {
/**
@brief Matrix-vector multiplication
This class implements matrix vector multiplication for MultiTypeBlockMatrix/MultiTypeBlockVector types
*/
template<int crow, int remain_rows, int ccol, int remain_cols,
typename TVecY, typename TMatrix, typename TVecX>
class MultiTypeBlockMatrix_VectMul {
public:
/** \brief y += A x
*/
static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
std::get<ccol>( std::get<crow>(A) ).umv( std::get<ccol>(x), std::get<crow>(y) );
MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::umv(y, A, x);
}
/** \brief y -= A x
*/
static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
std::get<ccol>( std::get<crow>(A) ).mmv( std::get<ccol>(x), std::get<crow>(y) );
MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::mmv(y, A, x);
}
/** \brief y += alpha A x
* \tparam AlphaType Type used for the scalar factor 'alpha'
*/
template<typename AlphaType>
static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
std::get<ccol>( std::get<crow>(A) ).usmv(alpha, std::get<ccol>(x), std::get<crow>(y) );
MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
}
};
//specialization for remain_cols = 0
template<int crow, int remain_rows,int ccol, typename TVecY,
typename TMatrix, typename TVecX>
class MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol,0,TVecY,TMatrix,TVecX> { //start iteration over next row
public:
/**
* do y += A x in next row
*/
static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,TMatrix::M(),TVecY,TMatrix,TVecX>::umv(y, A, x);
}
/**
* do y -= A x in next row
*/
static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,TMatrix::M(),TVecY,TMatrix,TVecX>::mmv(y, A, x);
}
template <typename AlphaType>
static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,TMatrix::M(),TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
}
};
//specialization for remain_rows = 0
template<int crow, int ccol, int remain_cols, typename TVecY,
typename TMatrix, typename TVecX>
class MultiTypeBlockMatrix_VectMul<crow,0,ccol,remain_cols,TVecY,TMatrix,TVecX> {
//end recursion
public:
static void umv(TVecY&, const TMatrix&, const TVecX&) {}
static void mmv(TVecY&, const TMatrix&, const TVecX&) {}
template<typename AlphaType>
static void usmv(const AlphaType&, TVecY&, const TMatrix&, const TVecX&) {}
};
/**
@brief A Matrix class to support different block types
......@@ -240,7 +160,7 @@ namespace Dune {
static_assert(X::size() == M(), "length of x does not match row length");
static_assert(Y::size() == N(), "length of y does not match row count");
y = 0; //reset y (for mv uses umv)
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
umv(x,y);
}
/** \brief y += A x
......@@ -249,7 +169,12 @@ namespace Dune {
void umv (const X& x, Y& y) const {
static_assert(X::size() == M(), "length of x does not match row length");
static_assert(Y::size() == N(), "length of y does not match row count");
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
using namespace Dune::Hybrid;
forEach(integralRange(Hybrid::size(y)), [&](auto&& i) {
forEach(integralRange(Hybrid::size(x)), [&](auto&& j) {
(*this)[i][j].umv(x[j], y[i]);
});
});
}
/** \brief y -= A x
......@@ -258,7 +183,12 @@ namespace Dune {
void mmv (const X& x, Y& y) const {
static_assert(X::size() == M(), "length of x does not match row length");
static_assert(Y::size() == N(), "length of y does not match row count");
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::mmv(y, *this, x); //iterate over all matrix elements
using namespace Dune::Hybrid;
forEach(integralRange(Hybrid::size(y)), [&](auto&& i) {
forEach(integralRange(Hybrid::size(x)), [&](auto&& j) {
(*this)[i][j].mmv(x[j], y[i]);
});
});
}
/** \brief y += alpha A x
......@@ -267,7 +197,12 @@ namespace Dune {
void usmv (const AlphaType& alpha, const X& x, Y& y) const {
static_assert(X::size() == M(), "length of x does not match row length");
static_assert(Y::size() == N(), "length of y does not match row count");
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::usmv(alpha,y, *this, x); //iterate over all matrix elements
using namespace Dune::Hybrid;
forEach(integralRange(Hybrid::size(y)), [&](auto&& i) {
forEach(integralRange(Hybrid::size(x)), [&](auto&& j) {
(*this)[i][j].usmv(alpha, x[j], y[i]);
});
});
}
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment