From 31f75b7b3c28c2896a55a42cef8d97073d88cf0e Mon Sep 17 00:00:00 2001
From: Markus Blatt <mblatt@dune-project.org>
Date: Fri, 26 Nov 2010 14:44:48 +0000
Subject: [PATCH] Made superlu work for float, too.

[[Imported from SVN: r1403]]
---
 dune/istl/superlu.hh          | 100 +++++++++++++++++++++++++++-------
 dune/istl/supermatrix.hh      |  34 +++++++++++-
 dune/istl/test/superlutest.cc |  34 ++++++++----
 3 files changed, 133 insertions(+), 35 deletions(-)

diff --git a/dune/istl/superlu.hh b/dune/istl/superlu.hh
index 3bd934f9e..6ddc1714d 100644
--- a/dune/istl/superlu.hh
+++ b/dune/istl/superlu.hh
@@ -12,8 +12,10 @@
 #endif
 #ifdef SUPERLU_POST_2005_VERSION
 #include "slu_ddefs.h"
+//#include "slu_sdefs.h"
 #else
 #include "dsp_defs.h"
+//#include "fsp_defs.h"
 #endif
 #include "solvers.hh"
 #include "supermatrix.hh"
@@ -104,7 +106,8 @@ namespace Dune
     /**
      *  \copydoc InverseOperator::apply(X&,Y&,double,InverseOperatorResult&)
      */
-    void apply (domain_type& x, range_type& b, double reduction, InverseOperatorResult& res)
+    void apply (domain_type& x, range_type& b, typename Dune::FieldTraits<T>::real_type reduction,
+                InverseOperatorResult& res)
     {
       apply(x,b,res);
     }
@@ -235,33 +238,33 @@ namespace Dune
     perm_c = new int[mat.M()];
     perm_r = new int[mat.N()];
     etree  = new int[mat.M()];
-    R = new double[mat.N()];
-    C = new double[mat.M()];
+    R = new T[mat.N()];
+    C = new T[mat.M()];
 
     set_default_options(&options);
     // Do the factorization
     B.ncol=0;
     B.Stype=SLU_DN;
-    B.Dtype=SLU_D;
+    B.Dtype= static_cast<Dtype_t>(GetSuperLUType<T>::type);
     B.Mtype= SLU_GE;
     DNformat fakeFormat;
     fakeFormat.lda=mat.N();
     B.Store=&fakeFormat;
     X.Stype=SLU_DN;
-    X.Dtype=SLU_D;
+    X.Dtype=static_cast<Dtype_t>(GetSuperLUType<T>::type);
     X.Mtype= SLU_GE;
     X.ncol=0;
     X.Store=&fakeFormat;
 
-    double rpg, rcond, ferr, berr;
+    T rpg, rcond, ferr, berr;
     int info;
     mem_usage_t memusage;
     SuperLUStat_t stat;
 
     StatInit(&stat);
-    dgssvx(&options, &static_cast<SuperMatrix&>(mat), perm_c, perm_r, etree, &equed, R, C,
-           &L, &U, work, lwork, &B, &X, &rpg, &rcond, &ferr,
-           &berr, &memusage, &stat, &info);
+    applySuperLU(&options, &static_cast<SuperMatrix&>(mat), perm_c, perm_r, etree, &equed, R, C,
+                 &L, &U, work, lwork, &B, &X, &rpg, &rcond, &ferr,
+                 &berr, &memusage, &stat, &info);
 
     if(verbose) {
       dinfo<<"LU factorization: dgssvx() returns info "<< info<<std::endl;
@@ -320,6 +323,58 @@ namespace Dune
     options.Fact = FACTORED;
   }
 
+  inline void createDenseSuperLUMatrix(SuperMatrix* B, int rows, int cols, double* b, int size,
+                                       Stype_t stype, Mtype_t mtype)
+  { // Double precision overload (fixes Dtype to SLU_D). inline: defined in a header, avoids ODR clashes.
+    dCreate_Dense_Matrix(B, rows, cols, b, size, stype, SLU_D, mtype);
+  }
+
+  // Unfortunately SuperLU uses a lot of copy and paste in its headers.
+  // This results in some structs being declared in the headers of both the
+  // float AND double versions. To get around this we only include the double
+  // version and declare the functions of the other versions as extern.
+  extern "C"
+  {
+    // Single precision SuperLU routines, declared by hand because the float header is not included.
+    void sCreate_Dense_Matrix(SuperMatrix* B, int rows, int cols, float* b, int size,
+                              Stype_t stype, Dtype_t dtype, Mtype_t mtype);
+
+    // Float counterpart of dgssvx; called by the float overload of applySuperLU.
+    void sgssvx(superlu_options_t *options, SuperMatrix *mat, int *permc, int *permr, int *etree,
+                char *equed, float *R, float *C, SuperMatrix *L, SuperMatrix *U,
+                void *work, int lwork, SuperMatrix *B, SuperMatrix *X,
+                float *rpg, float *rcond, float *ferr, float *berr,
+                mem_usage_t *memusage, SuperLUStat_t *stat, int *info);
+  }
+
+  inline void createDenseSuperLUMatrix(SuperMatrix* B, int rows, int cols, float* b, int size,
+                                       Stype_t stype, Mtype_t mtype)
+  { // Single precision overload (fixes Dtype to SLU_S). inline: defined in a header, avoids ODR clashes.
+    sCreate_Dense_Matrix(B, rows, cols, b, size, stype, SLU_S, mtype);
+  }
+
+  inline void applySuperLU(superlu_options_t *options, SuperMatrix *mat, int *permc, int *permr, int *etree,
+                           char *equed, double *R, double *C, SuperMatrix *L, SuperMatrix *U,
+                           void *work, int lwork, SuperMatrix *B, SuperMatrix *X,
+                           double *rpg, double *rcond, double *ferr, double *berr,
+                           mem_usage_t *memusage, SuperLUStat_t *stat, int *info)
+  { // Double precision overload: forwards every argument unchanged to dgssvx. inline: header-defined.
+    dgssvx(options, mat, permc, permr, etree, equed, R, C,
+           L, U, work, lwork, B, X, rpg, rcond, ferr, berr,
+           memusage, stat, info);
+  }
+
+
+  inline void applySuperLU(superlu_options_t *options, SuperMatrix *mat, int *permc, int *permr, int *etree,
+                           char *equed, float *R, float *C, SuperMatrix *L, SuperMatrix *U,
+                           void *work, int lwork, SuperMatrix *B, SuperMatrix *X,
+                           float *rpg, float *rcond, float *ferr, float *berr,
+                           mem_usage_t *memusage, SuperLUStat_t *stat, int *info)
+  { // Single precision overload: forwards every argument unchanged to sgssvx. inline: header-defined.
+    sgssvx(options, mat, permc, permr, etree, equed, R, C,
+           L, U, work, lwork, B, X, rpg, rcond, ferr, berr,
+           memusage, stat, info);
+  }
   template<typename T, typename A, int n, int m>
   void SuperLU<BCRSMatrix<FieldMatrix<T,n,m>,A> >
   ::apply(domain_type& x, range_type& b, InverseOperatorResult& res)
@@ -328,15 +383,18 @@ namespace Dune
       DUNE_THROW(ISTLError, "Matrix of SuperLU is null!");
 
     if(first) {
-      dCreate_Dense_Matrix(&B, mat.N(), 1,  reinterpret_cast<T*>(&b[0]), mat.N(), SLU_DN, SLU_D, SLU_GE);
-      dCreate_Dense_Matrix(&X, mat.N(), 1,  reinterpret_cast<T*>(&x[0]), mat.N(), SLU_DN, SLU_D, SLU_GE);
+      assert(mat.N()<=static_cast<std::size_t>(std::numeric_limits<int>::max()));
+      createDenseSuperLUMatrix(&B, mat.N(), 1,  reinterpret_cast<T*>(&b[0]),
+                               mat.N(), SLU_DN, SLU_GE);
+      createDenseSuperLUMatrix(&X, mat.N(), 1,  reinterpret_cast<T*>(&x[0]),
+                               mat.N(), SLU_DN, SLU_GE);
       first=false;
     }else{
       ((DNformat*) B.Store)->nzval=&b[0];
       ((DNformat*)X.Store)->nzval=&x[0];
     }
 
-    double rpg, rcond, ferr, berr;
+    T rpg, rcond, ferr, berr;
     int info;
     mem_usage_t memusage;
     SuperLUStat_t stat;
@@ -350,9 +408,9 @@ namespace Dune
      */
     options.IterRefine=DOUBLE;
 
-    dgssvx(&options, &static_cast<SuperMatrix&>(mat), perm_c, perm_r, etree, &equed, R, C,
-           &L, &U, work, lwork, &B, &X, &rpg, &rcond, &ferr, &berr,
-           &memusage, &stat, &info);
+    applySuperLU(&options, &static_cast<SuperMatrix&>(mat), perm_c, perm_r, etree, &equed, R, C,
+                 &L, &U, work, lwork, &B, &X, &rpg, &rcond, &ferr, &berr,
+                 &memusage, &stat, &info);
 
     res.iterations=1;
 
@@ -398,15 +456,15 @@ namespace Dune
       DUNE_THROW(ISTLError, "Matrix of SuperLU is null!");
 
     if(first) {
-      dCreate_Dense_Matrix(&B, mat.N(), 1,  b, mat.N(), SLU_DN, SLU_D, SLU_GE);
-      dCreate_Dense_Matrix(&X, mat.N(), 1,  x, mat.N(), SLU_DN, SLU_D, SLU_GE);
+      createDenseSuperLUMatrix(&B, mat.N(), 1,  b, mat.N(), SLU_DN, SLU_GE);
+      createDenseSuperLUMatrix(&X, mat.N(), 1,  x, mat.N(), SLU_DN, SLU_GE);
       first=false;
     }else{
       ((DNformat*) B.Store)->nzval=b;
       ((DNformat*)X.Store)->nzval=x;
     }
 
-    double rpg, rcond, ferr, berr;
+    T rpg, rcond, ferr, berr;
     int info;
     mem_usage_t memusage;
     SuperLUStat_t stat;
@@ -415,9 +473,9 @@ namespace Dune
 
     options.IterRefine=DOUBLE;
 
-    dgssvx(&options, &static_cast<SuperMatrix&>(mat), perm_c, perm_r, etree, &equed, R, C,
-           &L, &U, work, lwork, &B, &X, &rpg, &rcond, &ferr, &berr,
-           &memusage, &stat, &info);
+    applySuperLU(&options, &static_cast<SuperMatrix&>(mat), perm_c, perm_r, etree, &equed, R, C,
+                 &L, &U, work, lwork, &B, &X, &rpg, &rcond, &ferr, &berr,
+                 &memusage, &stat, &info);
 
     if(verbose) {
       dinfo<<"Triangular solve: dgssvx() returns info "<< info<<std::endl;
diff --git a/dune/istl/supermatrix.hh b/dune/istl/supermatrix.hh
index 2b3851a30..cc91743ed 100644
--- a/dune/istl/supermatrix.hh
+++ b/dune/istl/supermatrix.hh
@@ -498,13 +498,43 @@ namespace Dune
     }
   }
 
+  // Unfortunately SuperLU uses a lot of copy and paste in its headers.
+  // This results in some structs being declared in the headers of both the
+  // float AND double versions. To get around this we only include the double
+  // version and declare the functions of the other versions as extern.
+  extern "C"
+  {
+    // Single precision counterpart of dCreate_CompCol_Matrix, declared by hand (float header not included).
+    void
+    sCreate_CompCol_Matrix(SuperMatrix *A, int m, int n, int nnz,
+                           float *nzval, int *rowind, int *colptr,
+                           Stype_t stype, Dtype_t dtype, Mtype_t mtype);
+  }
+
+
+  inline void createCompColSuperMatrix(SuperMatrix *A, int m, int n, int nnz,
+                                       double *nzval, int *rowind, int *colptr,
+                                       Stype_t stype, Mtype_t mtype)
+  { // Double precision overload (fixes Dtype to SLU_D). inline: defined in a header, avoids ODR clashes.
+    dCreate_CompCol_Matrix(A, m, n, nnz, nzval, rowind, colptr, stype,
+                           SLU_D, mtype);
+  }
+
+  inline void createCompColSuperMatrix(SuperMatrix *A, int m, int n, int nnz,
+                                       float *nzval, int *rowind, int *colptr,
+                                       Stype_t stype, Mtype_t mtype)
+  { // Single precision overload (fixes Dtype to SLU_S). inline: defined in a header, avoids ODR clashes.
+    sCreate_CompCol_Matrix(A, m, n, nnz, nzval, rowind, colptr, stype,
+                           SLU_S, mtype);
+  }
+
   template<class T, class A, int n, int m>
   void SuperMatrixInitializer<BCRSMatrix<FieldMatrix<T,n,m>,A> >::createMatrix() const
   {
     delete[] marker;
     marker=0;
-    dCreate_CompCol_Matrix(&mat->A, mat->N_, mat->M_, mat->colstart[cols],
-                           mat->values, mat->rowindex, mat->colstart, SLU_NC, static_cast<Dtype_t>(GetSuperLUType<T>::type), SLU_GE);
+    createCompColSuperMatrix(&mat->A, mat->N_, mat->M_, mat->colstart[cols],
+                             mat->values, mat->rowindex, mat->colstart, SLU_NC, SLU_GE);
   }
 
   template<class F, class MRS>
diff --git a/dune/istl/test/superlutest.cc b/dune/istl/test/superlutest.cc
index 1b797e434..9ab152070 100644
--- a/dune/istl/test/superlutest.cc
+++ b/dune/istl/test/superlutest.cc
@@ -8,20 +8,12 @@
 #include <laplacian.hh>
 #include <dune/common/timer.hh>
 #include <dune/istl/superlu.hh>
-int main(int argc, char** argv)
-{
 
-  const int BS=1;
-  std::size_t N=100;
-
-  if(argc>1)
-    N = atoi(argv[1]);
-  std::cout<<"testing for N="<<N<<" BS="<<1<<std::endl;
-
-
-  typedef Dune::FieldMatrix<double,BS,BS> MatrixBlock;
+template<class T, int BS>
+void testSuperLU(std::size_t N){
+  typedef Dune::FieldMatrix<T,BS,BS> MatrixBlock;
   typedef Dune::BCRSMatrix<MatrixBlock> BCRSMat;
-  typedef Dune::FieldVector<double,BS> VectorBlock;
+  typedef Dune::FieldVector<T,BS> VectorBlock;
   typedef Dune::BlockVector<VectorBlock> Vector;
   typedef Dune::MatrixAdapter<BCRSMat,Vector,Vector> Operator;
 
@@ -55,3 +47,21 @@ int main(int argc, char** argv)
   solver1.apply(x,b, res);
 
 }
+
+int main(int argc, char** argv)
+{
+  // Block size used for the matrix and vector blocks, and the default problem size.
+  const int BS=1;
+  std::size_t N=100;
+  // An optional first command line argument overrides the problem size.
+  if(argc>1)
+    N = atoi(argv[1]);
+  // Report the actual block size constant instead of a hard-coded literal.
+  std::cout<<"testing for N="<<N<<" BS="<<BS<<std::endl;
+
+  // Run the same test once per real field type.
+  testSuperLU<double,BS>(N);
+  testSuperLU<float,BS>(N);
+
+  // NOTE(review): complex test is disabled; presumably complex is not supported yet - confirm.
+  //testSuperLU<std::complex<double>,BS>(N);
+}
-- 
GitLab