Skip to content
Snippets Groups Projects
Commit ef35ce99 authored by Oliver Sander's avatar Oliver Sander
Browse files

Introduce methods N() and M() to return the number of matrix rows and columns

These methods should be there in analogy to all other ISTL matrices.
Besides, this patch makes the code use them internally, which makes
for a few nice simplifications.
parent 1462583d
No related branches found
No related tags found
No related merge requests found
......@@ -168,22 +168,19 @@ namespace Dune {
* do y += A x in next row
*/
static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
static const int rowlen = mpl::at_c<TMatrix,crow>::type::size();
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::umv(y, A, x);
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,TMatrix::M(),TVecY,TMatrix,TVecX>::umv(y, A, x);
}
/**
* do y -= A x in next row
*/
static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
static const int rowlen = mpl::at_c<TMatrix,crow>::type::size();
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::mmv(y, A, x);
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,TMatrix::M(),TVecY,TMatrix,TVecX>::mmv(y, A, x);
}
template <typename AlphaType>
static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
static const int rowlen = mpl::at_c<TMatrix,crow>::type::size();
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,TMatrix::M(),TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
}
};
......@@ -226,51 +223,63 @@ namespace Dune {
typedef typename T1::field_type field_type;
/** \brief Return the number of matrix rows */
static DUNE_CONSTEXPR std::size_t N()
{
return mpl::size<type>::value;
}
/** \brief Return the number of matrix columns */
static DUNE_CONSTEXPR std::size_t M()
{
return T1::size();
}
/**
* assignment operator
*/
template<typename T>
void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<mpl::size<type>::value,type,T>::equalize(*this, newval); }
void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<N(),type,T>::equalize(*this, newval); }
/** \brief y = A x
*/
template<typename X, typename Y>
void mv (const X& x, Y& y) const {
static_assert(x.size() == T1::size(), "length of x does not match row length");
static_assert(y.size() == mpl::size<type>::value, "length of y does not match row count");
static_assert(x.size() == M(), "length of x does not match row length");
static_assert(y.size() == N(), "length of y does not match row count");
y = 0; //reset y (for mv uses umv)
MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,T1::size(),Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
}
/** \brief y += A x
*/
template<typename X, typename Y>
void umv (const X& x, Y& y) const {
static_assert(x.size() == T1::size(), "length of x does not match row length");
static_assert(y.size() == mpl::size<type>::value, "length of y does not match row count");
static_assert(x.size() == M(), "length of x does not match row length");
static_assert(y.size() == N(), "length of y does not match row count");
MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,T1::size(),Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
}
/** \brief y -= A x
*/
template<typename X, typename Y>
void mmv (const X& x, Y& y) const {
static_assert(x.size() == T1::size(), "length of x does not match row length");
static_assert(y.size() == mpl::size<type>::value, "length of y does not match row count");
static_assert(x.size() == M(), "length of x does not match row length");
static_assert(y.size() == N(), "length of y does not match row count");
MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,T1::size(),Y,type,X>::mmv(y, *this, x); //iterate over all matrix elements
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::mmv(y, *this, x); //iterate over all matrix elements
}
/** \brief y += alpha A x
*/
template<typename AlphaType, typename X, typename Y>
void usmv (const AlphaType& alpha, const X& x, Y& y) const {
static_assert(x.size() == T1::size(), "length of x does not match row length");
static_assert(y.size() == mpl::size<type>::value, "length of y does not match row count");
static_assert(x.size() == M(), "length of x does not match row length");
static_assert(y.size() == N(), "length of y does not match row count");
MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,T1::size(),Y,type,X>::usmv(alpha,y, *this, x); //iterate over all matrix elements
MultiTypeBlockMatrix_VectMul<0,N(),0,M(),Y,type,X>::usmv(alpha,y, *this, x); //iterate over all matrix elements
}
......@@ -288,9 +297,7 @@ namespace Dune {
template<typename T1, typename T2, typename T3, typename T4, typename T5,
typename T6, typename T7, typename T8, typename T9>
std::ostream& operator<< (std::ostream& s, const MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>& m) {
static const int i = mpl::size<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::value; //row count
static const int j = mpl::size< typename mpl::at_c<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>,0>::type >::value; //col count of first row
MultiTypeBlockMatrix_Print<0,i,0,j,MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m);
MultiTypeBlockMatrix_Print<0,m.N(),0,m.M(),MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m);
return s;
}
......@@ -382,7 +389,7 @@ namespace Dune {
static void bsorf(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
auto rhs = std::get<crow> (b);
MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::at_c<TMatrix,crow>::type::size()>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
MultiTypeBlockMatrix_Solver_Col<I,crow,0,TMatrix::M()>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
//solve on blocklevel I-1
algmeta_itsteps<I-1>::bsorf(std::get<crow>( fusion::at_c<crow>(A)), std::get<crow>(v),rhs,w);
std::get<crow>(x).axpy(w,std::get<crow>(v));
......@@ -403,7 +410,7 @@ namespace Dune {
static void bsorb(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
auto rhs = std::get<crow> (b);
MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::at_c<TMatrix,crow>::type::size()>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
MultiTypeBlockMatrix_Solver_Col<I,crow,0, TMatrix::M()>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
//solve on blocklevel I-1
algmeta_itsteps<I-1>::bsorb(std::get<crow>( fusion::at_c<crow>(A)), std::get<crow>(v),rhs,w);
std::get<crow>(x).axpy(w,std::get<crow>(v));
......@@ -425,7 +432,7 @@ namespace Dune {
static void dbjac(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
auto rhs = std::get<crow> (b);
MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::at_c<TMatrix,crow>::type::size()>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
MultiTypeBlockMatrix_Solver_Col<I,crow,0, TMatrix::M()>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
//solve on blocklevel I-1
algmeta_itsteps<I-1>::dbjac(std::get<crow>( fusion::at_c<crow>(A)), std::get<crow>(v),rhs,w);
MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbjac(A,x,v,b,w); //next row
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment