From affef6f2d0a0feb6b7982b02c67b406ea5967122 Mon Sep 17 00:00:00 2001
From: Christian Engwer <christi@dune-project.org>
Date: Fri, 20 Nov 2015 14:35:56 +0100
Subject: [PATCH] [test] test fieldmatrix with mixed precision

---
 dune/common/test/fmatrixtest.cc | 79 +++++++++++++++++++--------------
 1 file changed, 45 insertions(+), 34 deletions(-)

diff --git a/dune/common/test/fmatrixtest.cc b/dune/common/test/fmatrixtest.cc
index a310195f4..1f66e71e2 100644
--- a/dune/common/test/fmatrixtest.cc
+++ b/dune/common/test/fmatrixtest.cc
@@ -148,34 +148,36 @@ int test_invert_solve()
   return ret + test_invert_solve<double,3>(A_data2, inv_data2, x2, b2);
 }
 
-template<class K, int n, int m, class X, class Y>
+template<class K, int n, int m, class X, class Y, class XT, class YT>
 void test_mult(FieldMatrix<K, n, m>& A,
-               X& v, Y& f)
+               X& v, Y& f, XT& vT, YT& fT)
 {
   // test the various matrix-vector products
   A.mv(v,f);
-  A.mtv(f,v);
+  A.mtv(fT,vT);
   A.umv(v,f);
-  A.umtv(f,v);
-  A.umhv(f,v);
+  A.umtv(fT,vT);
+  A.umhv(fT,vT);
   A.mmv(v,f);
-  A.mmtv(f,v);
-  A.mmhv(f,v);
-  K scalar = (K)(0.5);
+  A.mmtv(fT,vT);
+  A.mmhv(fT,vT);
+  using S = typename FieldTraits<Y>::field_type;
+  using S2 = typename FieldTraits<XT>::field_type;
+  S scalar = (S)(0.5);
+  S2 scalar2 = (S2)(0.5);
   A.usmv(scalar,v,f);
-  A.usmtv(scalar,f,v);
-  A.usmhv(scalar,f,v);
+  A.usmtv(scalar2,fT,vT);
+  A.usmhv(scalar2,fT,vT);
 }
 
-
-template<class K, int n, int m>
+template<class K, class K2, class K3, int n, int m>
 void test_matrix()
 {
   typedef typename FieldMatrix<K,n,m>::size_type size_type;
 
   FieldMatrix<K,n,m> A;
-  FieldVector<K,n> f;
-  FieldVector<K,m> v;
+  FieldVector<K2,m> v;
+  FieldVector<K3,n> f;
 
   // assign matrix
   A=K();
@@ -184,11 +186,11 @@ void test_matrix()
     for (size_type j=0; j<m; j++)
       A[i][j] = i*j;
   // iterator matrix
-  typename FieldMatrix<K,n,m>::RowIterator rit = A.begin();
+  auto rit = A.begin();
   for (; rit!=A.end(); ++rit)
   {
     rit.index();
-    typename FieldMatrix<K,n,m>::ColIterator cit = rit->begin();
+    auto cit = rit->begin();
     for (; cit!=rit->end(); ++cit)
     {
       cit.index();
@@ -203,8 +205,8 @@ void test_matrix()
   for (size_type i=0; i<v.dim(); i++)
     v[i] = i;
   // iterator vector
-  typename FieldVector<K,m>::iterator it = v.begin();
-  typename FieldVector<K,m>::ConstIterator end = v.end();
+  auto it = v.begin();
+  auto end = v.end();
   for (; it!=end; ++it)
   {
     it.index();
@@ -226,10 +228,10 @@ void test_matrix()
   A.umv(v,f);
   // check that mv and umv are doing the same thing
   {
-    FieldVector<K,n> res2(0);
-    FieldVector<K,n> res1;
+    FieldVector<K3,n> res2(0);
+    FieldVector<K3,n> res1;
 
-    FieldVector<K,m> b(1);
+    FieldVector<K2,m> b(1);
 
     A.mv(b, res1);
     A.umv(b, res2);
@@ -241,8 +243,11 @@ void test_matrix()
   }
 
   {
-    FieldVector<K,m> v0 ( v );
-    test_mult(A, v0, f );
+    FieldVector<K2,m> v0 (v);
+    FieldVector<K3,n> f0 (f);
+    FieldVector<K3,m> vT (0);
+    FieldVector<K2,n> fT (0);
+    test_mult(A, v0, f0, vT, fT);
   }
 
   // {
@@ -541,10 +546,10 @@ test_infinity_norms()
 }
 
 
-template< class K, int rows, int cols >
+template< class K, class K2, int rows, int cols >
 void test_interface()
 {
-  typedef CheckMatrixInterface::UseFieldVector< K, rows, cols > Traits;
+  typedef CheckMatrixInterface::UseFieldVector< K2, rows, cols > Traits;
   typedef Dune::FieldMatrix< K, rows, cols > FMatrix;
 
   FMatrix m( 1 );
@@ -573,19 +578,25 @@ int main()
     test_initialisation();
 
     // test 1 x 1 matrices
-    test_interface<float, 1, 1>();
-    test_matrix<float, 1, 1>();
+    test_interface<float, float, 1, 1>();
+    test_matrix<float, float, float, 1, 1>();
     ScalarOperatorTest<float>();
-    test_matrix<double, 1, 1>();
+    test_matrix<double, double, double, 1, 1>();
     ScalarOperatorTest<double>();
     // test n x m matrices
-    test_interface<int, 10, 5>();
-    test_matrix<int, 10, 5>();
-    test_matrix<double, 5, 10>();
-    test_interface<double, 5, 10>();
+    test_interface<int, int, 10, 5>();
+    test_matrix<int, int, int, 10, 5>();
+    test_matrix<double, double, double, 5, 10>();
+    test_interface<double, double, 5, 10>();
+    // mixed precision
+    test_interface<float, float, 5, 10>();
+    test_matrix<float, double, float, 5, 10>();
     // test complex matrices
-    test_matrix<std::complex<float>, 1, 1>();
-    test_matrix<std::complex<double>, 5, 10>();
+    test_matrix<std::complex<float>, std::complex<float>, std::complex<float>, 1, 1>();
+    test_matrix<std::complex<double>, std::complex<double>, std::complex<double>, 5, 10>();
+    // test complex/real matrices mixed case
+    test_matrix<float, std::complex<float>, std::complex<float>, 1, 1>();
+    test_matrix<std::complex<float>, float, std::complex<float>, 1, 1>();
 #if HAVE_LAPACK
     // test eigemvalue computation
     test_ev<double>();
-- 
GitLab