Skip to content
Snippets Groups Projects
Commit 5e74ef3e authored by Simon Praetorius's avatar Simon Praetorius
Browse files

Fix for DiagonalMatrix * OtherMatrix

parent 95b4edaa
No related branches found
No related tags found
1 merge request!1479Fix for DiagonalMatrix * OtherMatrix
......@@ -20,10 +20,13 @@
#include <dune/common/boundschecking.hh>
#include <dune/common/densematrix.hh>
#include <dune/common/dynmatrix.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/ftraits.hh>
#include <dune/common/fvector.hh>
#include <dune/common/genericiterator.hh>
#include <dune/common/matrixconcepts.hh>
#include <dune/common/typetraits.hh>
......@@ -473,7 +476,39 @@ namespace Dune {
return result;
}
/**
* \brief Multiply a diagonal matrix with a dense matrix.
*
* The result of this multiplication is either a `FieldMatrix` if the
* `matrixB` is a matrix with static size, or a `DynamicMatrix`. This
* overload is deactivated for `matrixB` being a `FieldMatrix` since this
* is already covered by the corresponding overload of the `operator*` in
* the `FieldMatrix` class.
*/
template <class OtherMatrix,
std::enable_if_t<(Impl::IsDenseMatrix<OtherMatrix>::value), int> = 0,
std::enable_if_t<(not Impl::IsFieldMatrix<OtherMatrix>::value), int> = 0>
friend auto operator* ( const DiagonalMatrix& matrixA,
const OtherMatrix& matrixB)
{
using OtherField = typename FieldTraits<OtherMatrix>::field_type;
using F = typename PromotionTraits<field_type, OtherField>::PromotedType;
auto result = [&]{
if constexpr (Impl::IsStaticSizeMatrix_v<OtherMatrix>) {
static_assert(n == OtherMatrix::rows);
return FieldMatrix<F, n, OtherMatrix::cols>{};
} else {
assert(n == matrixB.N());
return DynamicMatrix<F>{n,matrixB.M()};
}
}();
for (int i = 0; i < result.N(); ++i)
for (int j = 0; j < result.M(); ++j)
result[i][j] = matrixA.diagonal(i) * matrixB[i][j];
return result;
}
//===== sizes
......@@ -640,6 +675,31 @@ namespace Dune {
return DiagonalMatrix<typename PromotionTraits<K,OtherScalar>::PromotedType, 1>{matrixA.diagonal(0)*matrixB.diagonal(0)};
}
template <class OtherMatrix,
std::enable_if_t<(Impl::IsDenseMatrix<OtherMatrix>::value), int> = 0,
std::enable_if_t<(not Impl::IsFieldMatrix<OtherMatrix>::value), int> = 0>
friend auto operator* ( const DiagonalMatrix& matrixA,
const OtherMatrix& matrixB)
{
using OtherField = typename FieldTraits<OtherMatrix>::field_type;
using F = typename PromotionTraits<K, OtherField>::PromotedType;
auto result = [&]{
if constexpr (Impl::IsStaticSizeMatrix_v<OtherMatrix>) {
static_assert(1 == OtherMatrix::rows);
return FieldMatrix<F, 1, OtherMatrix::cols>{};
} else {
assert(1 == matrixB.N());
return DynamicMatrix<F>{1,matrixB.M()};
}
}();
for (int i = 0; i < result.N(); ++i)
for (int j = 0; j < result.M(); ++j)
result[i][j] = matrixA.diagonal(i) * matrixB[i][j];
return result;
}
};
#endif
......
......@@ -9,9 +9,11 @@
#include <iostream>
#include <algorithm>
#include <dune/common/dynmatrix.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/fvector.hh>
#include <dune/common/diagonalmatrix.hh>
#include <dune/common/transpose.hh>
#include "checkmatrixinterface.hh"
......@@ -64,6 +66,19 @@ void test_matrix()
DiagonalMatrix<K,n> AT = A.transposed();
if (AT != A)
DUNE_THROW(FMatrixError, "Return value of DiagoalMatrix::transposed() incorrect!");
// check matrix-matrix multiplication
[[maybe_unused]] auto AA = A * A;
[[maybe_unused]] auto AF = A * AFM;
[[maybe_unused]] auto FA = AFM * A;
[[maybe_unused]] auto AFt = A * transposedView(AFM);
[[maybe_unused]] auto FtA = transposedView(AFM) * A;
Dune::DynamicMatrix<K> ADM(n,n);
[[maybe_unused]] auto AD = A * ADM;
// [[maybe_unused]] auto DA = ADM * A;
[[maybe_unused]] auto ADt = A * transposedView(ADM);
// [[maybe_unused]] auto DtA = transposedView(ADM) * A;
}
template<class K, int n>
......
......@@ -260,7 +260,7 @@ int main()
testFillDense(b);
checkTranspose(suite,a);
checkTranspose(suite,b);
// checkTransposeProduct(suite,a,b);
checkTransposeProduct(suite,a,b);
}
{
......@@ -270,7 +270,7 @@ int main()
testFillDense(b);
checkTranspose(suite,a);
checkTranspose(suite,b);
// checkTransposeProduct(suite,a,b);
checkTransposeProduct(suite,a,b);
}
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment