Skip to content
Snippets Groups Projects
Commit bdf74dca authored by Christian Engwer's avatar Christian Engwer
Browse files

[densematrix] use the correct scaling parameter for usmv and friends

In the current state the scaling parameter for matrix-vector products
(e.g. the alpha parameter for DenseMatrix<...>::usmv) is the
field_type of the matrix. I you consider mixed precision, like in the
xblas library, you would actually want the scaling to be of the same
accuracy as the output vector; an other reason is vectorization for
one matrix and multiple vectors.

The patch contains two changes:

a) the FieldTraits are extended, so that the contain specializations
   for C and C++ vectors
b) the thre matrix implementations are changed, such that they deduce
   the type as FieldTraits<Y>::field_type, where Y is the type of the
   result vector.

In the case of identity matrix the old implementation was laso completely wrong, as it used
Y::axpy to implement usmv and Y::axpy expects alpha to be Y::value_type, so that we already
had a type clash, which just went unnoticed, as we never used mixed-types.
parent d33190bc
No related branches found
No related tags found
No related merge requests found
......@@ -511,7 +511,8 @@ namespace Dune
//! y += alpha A x
template<class X, class Y>
void usmv (const field_type& alpha, const X& x, Y& y) const
void usmv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
#ifdef DUNE_FMatrix_WITH_CHECKING
if (x.N()!=M()) DUNE_THROW(FMatrixError,"index out of range");
......@@ -524,7 +525,8 @@ namespace Dune
//! y += alpha A^T x
template<class X, class Y>
void usmtv (const field_type& alpha, const X& x, Y& y) const
void usmtv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
#ifdef DUNE_FMatrix_WITH_CHECKING
if (x.N()!=N()) DUNE_THROW(FMatrixError,"index out of range");
......@@ -538,7 +540,8 @@ namespace Dune
//! y += alpha A^H x
template<class X, class Y>
void usmhv (const field_type& alpha, const X& x, Y& y) const
void usmhv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
#ifdef DUNE_FMatrix_WITH_CHECKING
if (x.N()!=N()) DUNE_THROW(FMatrixError,"index out of range");
......
......@@ -364,7 +364,8 @@ namespace Dune {
//! y += alpha A x
template<class X, class Y>
void usmv (const K& alpha, const X& x, Y& y) const
void usmv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
#ifdef DUNE_FMatrix_WITH_CHECKING
if (x.N()!=M()) DUNE_THROW(FMatrixError,"index out of range");
......@@ -376,7 +377,8 @@ namespace Dune {
//! y += alpha A^T x
template<class X, class Y>
void usmtv (const K& alpha, const X& x, Y& y) const
void usmtv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
#ifdef DUNE_FMatrix_WITH_CHECKING
if (x.N()!=N()) DUNE_THROW(FMatrixError,"index out of range");
......@@ -388,7 +390,8 @@ namespace Dune {
//! y += alpha A^H x
template<class X, class Y>
void usmhv (const K& alpha, const X& x, Y& y) const
void usmhv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
#ifdef DUNE_FMatrix_WITH_CHECKING
if (x.N()!=N()) DUNE_THROW(FMatrixError,"index out of range");
......
......@@ -8,6 +8,7 @@
*/
#include <complex>
#include <vector>
namespace Dune {
......@@ -41,6 +42,20 @@ namespace Dune {
typedef T real_type;
};
template<class T, unsigned int N>
struct FieldTraits< T[N] >
{
typedef typename FieldTraits<T>::field_type field_type;
typedef typename FieldTraits<T>::real_type real_type;
};
template<class T>
struct FieldTraits< std::vector<T> >
{
typedef typename FieldTraits<T>::field_type field_type;
typedef typename FieldTraits<T>::real_type real_type;
};
} // end namespace Dune
#endif // DUNE_FTRAITS_HH
......@@ -102,21 +102,24 @@ namespace Dune
/** \copydoc Dune::DenseMatrix::usmv */
template< class X, class Y >
void usmv ( const field_type &alpha, const X &x, Y &y ) const
void usmv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
y.axpy( alpha, x );
}
/** \copydoc Dune::DenseMatrix::usmtv */
template< class X, class Y >
void usmtv ( const field_type &alpha, const X &x, Y &y ) const
void usmtv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
y.axpy( alpha, x );
}
/** \copydoc Dune::DenseMatrix::usmhv */
template< class X, class Y >
void usmhv ( const field_type &alpha, const X &x, Y &y ) const
void usmhv (const typename FieldTraits<Y>::field_type & alpha,
const X& x, Y& y) const
{
y.axpy( alpha, x );
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment