From affef6f2d0a0feb6b7982b02c67b406ea5967122 Mon Sep 17 00:00:00 2001 From: Christian Engwer <christi@dune-project.org> Date: Fri, 20 Nov 2015 14:35:56 +0100 Subject: [PATCH] [test] test fieldmatrix with mixed precision --- dune/common/test/fmatrixtest.cc | 79 +++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/dune/common/test/fmatrixtest.cc b/dune/common/test/fmatrixtest.cc index a310195f4..1f66e71e2 100644 --- a/dune/common/test/fmatrixtest.cc +++ b/dune/common/test/fmatrixtest.cc @@ -148,34 +148,36 @@ int test_invert_solve() return ret + test_invert_solve<double,3>(A_data2, inv_data2, x2, b2); } -template<class K, int n, int m, class X, class Y> +template<class K, int n, int m, class X, class Y, class XT, class YT> void test_mult(FieldMatrix<K, n, m>& A, - X& v, Y& f) + X& v, Y& f, XT& vT, YT& fT) { // test the various matrix-vector products A.mv(v,f); - A.mtv(f,v); + A.mtv(fT,vT); A.umv(v,f); - A.umtv(f,v); - A.umhv(f,v); + A.umtv(fT,vT); + A.umhv(fT,vT); A.mmv(v,f); - A.mmtv(f,v); - A.mmhv(f,v); - K scalar = (K)(0.5); + A.mmtv(fT,vT); + A.mmhv(fT,vT); + using S = typename FieldTraits<Y>::field_type; + using S2 = typename FieldTraits<XT>::field_type; + S scalar = (S)(0.5); + S2 scalar2 = (S2)(0.5); A.usmv(scalar,v,f); - A.usmtv(scalar,f,v); - A.usmhv(scalar,f,v); + A.usmtv(scalar2,fT,vT); + A.usmhv(scalar2,fT,vT); } - -template<class K, int n, int m> +template<class K, class K2, class K3, int n, int m> void test_matrix() { typedef typename FieldMatrix<K,n,m>::size_type size_type; FieldMatrix<K,n,m> A; - FieldVector<K,n> f; - FieldVector<K,m> v; + FieldVector<K2,m> v; + FieldVector<K3,n> f; // assign matrix A=K(); @@ -184,11 +186,11 @@ void test_matrix() for (size_type j=0; j<m; j++) A[i][j] = i*j; // iterator matrix - typename FieldMatrix<K,n,m>::RowIterator rit = A.begin(); + auto rit = A.begin(); for (; rit!=A.end(); ++rit) { rit.index(); - typename FieldMatrix<K,n,m>::ColIterator cit = rit->begin(); + auto cit = rit->begin(); for (; cit!=rit->end(); ++cit) { cit.index(); @@ -203,8 +205,8 @@ void test_matrix() for (size_type i=0; i<v.dim(); i++) v[i] = i; // iterator vector - typename FieldVector<K,m>::iterator it = v.begin(); - typename FieldVector<K,m>::ConstIterator end = v.end(); + auto it = v.begin(); + auto end = v.end(); for (; it!=end; ++it) { it.index(); @@ -226,10 +228,10 @@ void test_matrix() A.umv(v,f); // check that mv and umv are doing the same thing { - FieldVector<K,n> res2(0); - FieldVector<K,n> res1; + FieldVector<K3,n> res2(0); + FieldVector<K3,n> res1; - FieldVector<K,m> b(1); + FieldVector<K2,m> b(1); A.mv(b, res1); A.umv(b, res2); @@ -241,8 +243,11 @@ void test_matrix() } { - FieldVector<K,m> v0 ( v ); - test_mult(A, v0, f ); + FieldVector<K2,m> v0 (v); + FieldVector<K3,n> f0 (f); + FieldVector<K3,m> vT (0); + FieldVector<K2,n> fT (0); + test_mult(A, v0, f0, vT, fT); } // { @@ -541,10 +546,10 @@ test_infinity_norms() } -template< class K, int rows, int cols > +template< class K, class K2, int rows, int cols > void test_interface() { - typedef CheckMatrixInterface::UseFieldVector< K, rows, cols > Traits; + typedef CheckMatrixInterface::UseFieldVector< K2, rows, cols > Traits; typedef Dune::FieldMatrix< K, rows, cols > FMatrix; FMatrix m( 1 ); @@ -573,19 +578,25 @@ int main() test_initialisation(); // test 1 x 1 matrices - test_interface<float, 1, 1>(); - test_matrix<float, 1, 1>(); + test_interface<float, float, 1, 1>(); + test_matrix<float, float, float, 1, 1>(); ScalarOperatorTest<float>(); - test_matrix<double, 1, 1>(); + test_matrix<double, double, double, 1, 1>(); ScalarOperatorTest<double>(); // test n x m matrices - test_interface<int, 10, 5>(); - test_matrix<int, 10, 5>(); - test_matrix<double, 5, 10>(); - test_interface<double, 5, 10>(); + test_interface<int, int, 10, 5>(); + test_matrix<int, int, int, 10, 5>(); + test_matrix<double, double, double, 5, 10>(); + test_interface<double, double, 5, 10>(); + // mixed precision + test_interface<float, float, 5, 10>(); + test_matrix<float, double, float, 5, 10>(); // test complex matrices - test_matrix<std::complex<float>, 1, 1>(); - test_matrix<std::complex<double>, 5, 10>(); + test_matrix<std::complex<float>, std::complex<float>, std::complex<float>, 1, 1>(); + test_matrix<std::complex<double>, std::complex<double>, std::complex<double>, 5, 10>(); + // test complex/real matrices mixed case + test_matrix<float, std::complex<float>, std::complex<float>, 1, 1>(); + test_matrix<std::complex<float>, float, std::complex<float>, 1, 1>(); #if HAVE_LAPACK // test eigemvalue computation test_ev<double>(); -- GitLab