diff --git a/dune/common/diagonalmatrix.hh b/dune/common/diagonalmatrix.hh index 0f83bd6c5a9546c00cf3ce2931d6eba2aa8f7553..ae41c1f5f2c83aaac549fc529401ab96827ac5fb 100644 --- a/dune/common/diagonalmatrix.hh +++ b/dune/common/diagonalmatrix.hh @@ -20,10 +20,13 @@ #include <dune/common/boundschecking.hh> #include <dune/common/densematrix.hh> +#include <dune/common/dynmatrix.hh> #include <dune/common/exceptions.hh> #include <dune/common/fmatrix.hh> +#include <dune/common/ftraits.hh> #include <dune/common/fvector.hh> #include <dune/common/genericiterator.hh> +#include <dune/common/matrixconcepts.hh> #include <dune/common/typetraits.hh> @@ -473,7 +476,39 @@ namespace Dune { return result; } - + /** + * \brief Multiply a diagonal matrix with a dense matrix. + * + * The result of this multiplication is either a `FieldMatrix` if the + * `matrixB` is a matrix with static size, or a `DynamicMatrix`. This + * overload is deactivated for `matrixB` being a `FieldMatrix` since this + * is already covered by the corresponding overload of the `operator*` in + * the `FieldMatrix` class. + */ + template <class OtherMatrix, + std::enable_if_t<(Impl::IsDenseMatrix<OtherMatrix>::value), int> = 0, + std::enable_if_t<(not Impl::IsFieldMatrix<OtherMatrix>::value), int> = 0> + friend auto operator* ( const DiagonalMatrix& matrixA, + const OtherMatrix& matrixB) + { + using OtherField = typename FieldTraits<OtherMatrix>::field_type; + using F = typename PromotionTraits<field_type, OtherField>::PromotedType; + + auto result = [&]{ + if constexpr (Impl::IsStaticSizeMatrix_v<OtherMatrix>) { + static_assert(n == OtherMatrix::rows); + return FieldMatrix<F, n, OtherMatrix::cols>{}; + } else { + assert(n == matrixB.N()); + return DynamicMatrix<F>{n,matrixB.M()}; + } + }(); + + for (int i = 0; i < result.N(); ++i) + for (int j = 0; j < result.M(); ++j) + result[i][j] = matrixA.diagonal(i) * matrixB[i][j]; + return result; + } //===== sizes @@ -640,6 +675,31 @@ namespace Dune { return DiagonalMatrix<typename PromotionTraits<K,OtherScalar>::PromotedType, 1>{matrixA.diagonal(0)*matrixB.diagonal(0)}; } + template <class OtherMatrix, + std::enable_if_t<(Impl::IsDenseMatrix<OtherMatrix>::value), int> = 0, + std::enable_if_t<(not Impl::IsFieldMatrix<OtherMatrix>::value), int> = 0> + friend auto operator* ( const DiagonalMatrix& matrixA, + const OtherMatrix& matrixB) + { + using OtherField = typename FieldTraits<OtherMatrix>::field_type; + using F = typename PromotionTraits<K, OtherField>::PromotedType; + + auto result = [&]{ + if constexpr (Impl::IsStaticSizeMatrix_v<OtherMatrix>) { + static_assert(1 == OtherMatrix::rows); + return FieldMatrix<F, 1, OtherMatrix::cols>{}; + } else { + assert(1 == matrixB.N()); + return DynamicMatrix<F>{1,matrixB.M()}; + } + }(); + + for (int i = 0; i < result.N(); ++i) + for (int j = 0; j < result.M(); ++j) + result[i][j] = matrixA.diagonal(i) * matrixB[i][j]; + return result; + } + }; #endif diff --git a/dune/common/test/diagonalmatrixtest.cc b/dune/common/test/diagonalmatrixtest.cc index 8d244b690b46f78ab5969f0b6cbe9b6d75f8d1ae..e25711efcaf6ce160ff6aa16cb3026b4139b5977 100644 --- a/dune/common/test/diagonalmatrixtest.cc +++ b/dune/common/test/diagonalmatrixtest.cc @@ -9,9 +9,11 @@ #include <iostream> #include <algorithm> +#include <dune/common/dynmatrix.hh> #include <dune/common/exceptions.hh> #include <dune/common/fvector.hh> #include <dune/common/diagonalmatrix.hh> +#include <dune/common/transpose.hh> #include "checkmatrixinterface.hh" @@ -64,6 +66,19 @@ void test_matrix() DiagonalMatrix<K,n> AT = A.transposed(); if (AT != A) DUNE_THROW(FMatrixError, "Return value of DiagoalMatrix::transposed() incorrect!"); + + // check matrix-matrix multiplication + [[maybe_unused]] auto AA = A * A; + [[maybe_unused]] auto AF = A * AFM; + [[maybe_unused]] auto FA = AFM * A; + [[maybe_unused]] auto AFt = A * transposedView(AFM); + [[maybe_unused]] auto FtA = transposedView(AFM) * A; + + Dune::DynamicMatrix<K> ADM(n,n); + [[maybe_unused]] auto AD = A * ADM; + // [[maybe_unused]] auto DA = ADM * A; + [[maybe_unused]] auto ADt = A * transposedView(ADM); + // [[maybe_unused]] auto DtA = transposedView(ADM) * A; } template<class K, int n> diff --git a/dune/common/test/transposetest.cc b/dune/common/test/transposetest.cc index 910e86c317220b803f457f18ed2674304205ed13..63d1090788b16f02a2fe9fe574fe0643235fe2f8 100644 --- a/dune/common/test/transposetest.cc +++ b/dune/common/test/transposetest.cc @@ -260,7 +260,7 @@ int main() testFillDense(b); checkTranspose(suite,a); checkTranspose(suite,b); - // checkTransposeProduct(suite,a,b); + checkTransposeProduct(suite,a,b); } { @@ -270,7 +270,7 @@ int main() testFillDense(b); checkTranspose(suite,a); checkTranspose(suite,b); - // checkTransposeProduct(suite,a,b); + checkTransposeProduct(suite,a,b); } {