From e5c71cd75c8a44f9c4adda2f0e47764b2b238921 Mon Sep 17 00:00:00 2001
From: Markus Blatt <mblatt@dune-project.org>
Date: Thu, 9 Dec 2010 09:30:42 +0000
Subject: [PATCH] MPI_Gatherv -> MPI_Allgatherv (the latter is optimized on
 Blue Gene P). Use non-blocking communication to build the index sets for a
 considerable speedup.

[[Imported from SVN: r1425]]
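
Note for reviewers: a minimal, self-contained sketch of the MPI_Gatherv ->
MPI_Allgatherv switch (illustrative data only; in repartition.hh the gathered
arrays are the graph data xadj/adjncy, and the counts come from an
MPI_Allgather of noNeighbours):

    // Illustrative only: variable-length per-rank chunks gathered onto every rank.
    #include <mpi.h>
    #include <vector>

    int main(int argc, char** argv)
    {
      MPI_Init(&argc, &argv);
      int rank, npes;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);
      MPI_Comm_size(MPI_COMM_WORLD, &npes);

      // Each rank contributes a chunk of different length (here rank+1 entries).
      std::vector<int> local(rank+1, rank);
      int mySize = static_cast<int>(local.size());

      // Counts and displacements have to be known on every rank.
      std::vector<int> counts(npes), displ(npes+1, 0);
      MPI_Allgather(&mySize, 1, MPI_INT, counts.data(), 1, MPI_INT, MPI_COMM_WORLD);
      for(int i=0; i<npes; ++i)
        displ[i+1] = displ[i] + counts[i];

      // Unlike MPI_Gatherv there is no root argument: all ranks get the result.
      std::vector<int> global(displ[npes]);
      MPI_Allgatherv(local.data(), mySize, MPI_INT, global.data(),
                     counts.data(), displ.data(), MPI_INT, MPI_COMM_WORLD);

      MPI_Finalize();
      return 0;
    }

The only functional difference to MPI_Gatherv is that the result is available
on every rank, which is why the former rank-0-only post-processing blocks below
lose their if(rank==0) guards.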
---
 dune/istl/repartition.hh | 345 +++++++++++++++++++++++----------------
 1 file changed, 207 insertions(+), 138 deletions(-)
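
Reviewer note: a minimal sketch of the non-blocking exchange now used to build
the index sets (illustrative only; the sendTo/recvFrom pattern here is a
hypothetical ring, and a single int stands in for the MPI_Pack'ed
owner/overlap data):

    // Illustrative only: non-blocking sends with probed receives.
    #include <mpi.h>
    #include <cstddef>
    #include <iostream>
    #include <set>
    #include <vector>

    int main(int argc, char** argv)
    {
      MPI_Init(&argc, &argv);
      int rank, npes;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);
      MPI_Comm_size(MPI_COMM_WORLD, &npes);

      // Hypothetical pattern: send to the right neighbour, receive from the left.
      std::vector<int> sendTo(1, (rank+1)%npes);
      std::set<int> recvFrom;
      recvFrom.insert((rank+npes-1)%npes);

      // Post one non-blocking synchronous send per target process.
      std::vector<MPI_Request> requests(sendTo.size());
      int payload = rank; // stand-in for the packed owner/overlap buffer
      for(std::size_t i=0; i<sendTo.size(); ++i)
        MPI_Issend(&payload, 1, MPI_INT, sendTo[i], 99, MPI_COMM_WORLD, &requests[i]);

      // Receive in whatever order messages arrive; the probe yields the size.
      for(std::size_t n=recvFrom.size(); n>0; --n) {
        MPI_Status stat;
        MPI_Probe(MPI_ANY_SOURCE, 99, MPI_COMM_WORLD, &stat);
        int count;
        MPI_Get_count(&stat, MPI_INT, &count);
        std::vector<int> buf(count);
        MPI_Recv(buf.data(), count, MPI_INT, stat.MPI_SOURCE, 99, MPI_COMM_WORLD, &stat);
        std::cout<<rank<<" got "<<count<<" ints from "<<stat.MPI_SOURCE<<std::endl;
      }

      // The send buffers may only be reused after the sends have completed.
      MPI_Waitall(static_cast<int>(requests.size()), requests.data(), MPI_STATUSES_IGNORE);
      MPI_Finalize();
      return 0;
    }

Using MPI_Issend together with MPI_Probe lets every process send and receive
concurrently instead of serializing on rank order as before.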

diff --git a/dune/istl/repartition.hh b/dune/istl/repartition.hh
index 5ada1cc5c..4ba773274 100644
--- a/dune/istl/repartition.hh
+++ b/dune/istl/repartition.hh
@@ -850,17 +850,15 @@ namespace Dune
         // The diagonal entries are the number of nodes on the process.
         // The offdiagonal entries are the number of edges leading to other processes.
 
-        idxtype *xadj=new idxtype[2], *vwgt = new idxtype[1];
+        idxtype *xadj=new idxtype[2], *vwgt = 0;
         idxtype *vtxdist=new idxtype[oocomm.communicator().size()+1];
-        idxtype * adjncy=new idxtype[noNeighbours], *adjwgt = new idxtype[noNeighbours];
+        idxtype * adjncy=new idxtype[noNeighbours], *adjwgt = 0;
 
         // each process has exactly one vertex!
         for(int i=0; i<oocomm.communicator().size(); ++i)
           vtxdist[i]=i;
         vtxdist[oocomm.communicator().size()]=oocomm.communicator().size();
 
-
-        vwgt[0]=mat.N(); // weight is numer of rows TODO: Should actually be the nonzeros.
         xadj[0]=0;
         xadj[1]=noNeighbours;
 
@@ -897,26 +895,37 @@ namespace Dune
         typedef typename  IndexSet::LocalIndex LocalIndex;
 
         idxtype* adjp=adjncy;
+
+#ifdef USE_WEIGHTS
+        vwgt   = new idxtype[1];
+        vwgt[0] = mat.N(); // weight is number of rows; TODO: should actually be the number of nonzeros.
+
+        adjwgt = new idxtype[noNeighbours];
         idxtype* adjwp=adjwgt;
+#endif
 
         for(NeighbourIterator n= oocomm.remoteIndices().begin(); n !=  oocomm.remoteIndices().end();
             ++n)
           if(n->first != rank) {
             *adjp=n->first;
-            *adjwp=1; //edgecount[n->first];
             ++adjp;
+#ifdef USE_WEIGHTS
+            *adjwp=1; //edgecount[n->first];
             ++adjwp;
+#endif
           }
-
         assert(isValidGraph(vtxdist[rank+1]-vtxdist[rank],
                             vtxdist[oocomm.communicator().size()],
                             noNeighbours, xadj, adjncy, false));
 
-        int wgtflag=3, numflag=0, edgecut;
+        int wgtflag=0, numflag=0, edgecut;
+#ifdef USE_WEIGHTS
+        wgtflag=3;
+#endif
         float *tpwgts = new float[nparts];
         for(int i=0; i<nparts; ++i)
           tpwgts[i]=1.0/nparts;
-        int options[4] ={ 0,0,0,0};
+        int options[5] ={ 1,3,15,0,0};
         MPI_Comm comm=oocomm.communicator();
 
         Dune::dinfo<<rank<<" vtxdist: ";
@@ -925,8 +934,13 @@ namespace Dune
         print_carray(Dune::dinfo, xadj, 2);
         Dune::dinfo<<std::endl<<rank<<" adjncy: ";
         print_carray(Dune::dinfo, adjncy, noNeighbours);
+
+#ifdef USE_WEIGHTS
+        Dune::dinfo<<std::endl<<rank<<" vwgt: ";
+        print_carray(Dune::dinfo, vwgt, 1);
         Dune::dinfo<<std::endl<<rank<<" adjwgt: ";
         print_carray(Dune::dinfo, adjwgt, noNeighbours);
+#endif
         Dune::dinfo<<std::endl;
         oocomm.communicator().barrier();
         if(verbose && oocomm.communicator().rank()==0)
@@ -951,11 +965,10 @@ namespace Dune
         Timer time1;
         std::size_t gnoedges=0;
         int* noedges = 0;
-        if(rank==0)
-          noedges = new int[oocomm.communicator().size()];
+        noedges = new int[oocomm.communicator().size()];
         Dune::dverb<<"noNeighbours: "<<noNeighbours<<std::endl;
         // gather number of edges for each vertex.
-        oocomm.communicator().gather(&noNeighbours, noedges, 1, 0);
+        MPI_Allgather(&noNeighbours, 1, MPI_INT, noedges, 1, MPI_INT, oocomm.communicator());
 
         if(verbose && oocomm.communicator().rank()==0)
           std::cout<<"Gathering noedges took "<<time1.elapsed()<<std::endl;
@@ -975,7 +988,7 @@ namespace Dune
         std::size_t localNoVtx=vtxdist[rank+1]-vtxdist[rank];
         std::size_t gxadjlen = vtxdist[oocomm.communicator().size()]-vtxdist[0]+oocomm.communicator().size();
 
-        if(rank==0) {
+        {
           Dune::dinfo<<"noedges: ";
           print_carray(Dune::dinfo, noedges, oocomm.communicator().size());
           Dune::dinfo<<std::endl;
@@ -1022,9 +1035,11 @@ namespace Dune
                      <<" gnoedges: "<<gnoedges<<std::endl;
           gxadj = new idxtype[gxadjlen];
           gpart = new idxtype[noVertices];
+#ifdef USE_WEIGHTS
           gvwgt = new idxtype[noVertices];
-          gadjncy = new idxtype[gnoedges];
           gadjwgt = new idxtype[gnoedges];
+#endif
+          gadjncy = new idxtype[gnoedges];
         }
 
         if(verbose && oocomm.communicator().rank()==0)
@@ -1032,23 +1047,25 @@ namespace Dune
         time1.reset();
         // Communicate data
 
-        MPI_Gatherv(xadj,2,MPITraits<idxtype>::getType(),
-                    gxadj,noxs,xdispl,MPITraits<idxtype>::getType(),
-                    0,comm);
-        MPI_Gatherv(vwgt,localNoVtx,MPITraits<idxtype>::getType(),
-                    gvwgt,novs,vdispl,MPITraits<idxtype>::getType(),
-                    0,comm);
-        MPI_Gatherv(adjncy,noNeighbours,MPITraits<idxtype>::getType(),
-                    gadjncy,noedges,displ,MPITraits<idxtype>::getType(),
-                    0,comm);
-        MPI_Gatherv(adjwgt,noNeighbours,MPITraits<idxtype>::getType(),
-                    gadjwgt,noedges,displ,MPITraits<idxtype>::getType(),
-                    0,comm);
+        MPI_Allgatherv(xadj,2,MPITraits<idxtype>::getType(),
+                       gxadj,noxs,xdispl,MPITraits<idxtype>::getType(),
+                       comm);
+        MPI_Allgatherv(adjncy,noNeighbours,MPITraits<idxtype>::getType(),
+                       gadjncy,noedges,displ,MPITraits<idxtype>::getType(),
+                       comm);
+#ifdef USE_WEIGHTS
+        MPI_Allgatherv(adjwgt,noNeighbours,MPITraits<idxtype>::getType(),
+                       gadjwgt,noedges,displ,MPITraits<idxtype>::getType(),
+                       comm);
+        MPI_Allgatherv(vwgt,localNoVtx,MPITraits<idxtype>::getType(),
+                       gvwgt,novs,vdispl,MPITraits<idxtype>::getType(),
+                       comm);
+#endif
         if(verbose && oocomm.communicator().rank()==0)
           std::cout<<"Gathering global graph data took "<<time1.elapsed()<<std::endl;
         time1.reset();
 
-        if(rank==0) {
+        {
           // create the real gxadj array
           // i.e. shift entries and add displacements.
 
@@ -1068,24 +1085,28 @@ namespace Dune
           }
           Dune::dinfo<<std::endl<<"shifted xadj:";
           print_carray(Dune::dinfo, gxadj, noVertices+1);
-          Dune::dinfo<<std::endl<<" gvwgt: ";
-          print_carray(Dune::dinfo, gvwgt, noVertices);
           Dune::dinfo<<std::endl<<" gadjncy: ";
           print_carray(Dune::dinfo, gadjncy, gnoedges);
+#ifdef USE_WEIGHTS
+          Dune::dinfo<<std::endl<<" gvwgt: ";
+          print_carray(Dune::dinfo, gvwgt, noVertices);
           Dune::dinfo<<std::endl<<"adjwgt: ";
           print_carray(Dune::dinfo, gadjwgt, gnoedges);
           Dune::dinfo<<std::endl;
+#endif
           // everything should be fine now!!!
           if(verbose && oocomm.communicator().rank()==0)
             std::cout<<"Postprocesing global graph data took "<<time1.elapsed()<<std::endl;
           time1.reset();
+#ifndef NDEBUG
           assert(isValidGraph(noVertices, noVertices, gnoedges,
                               gxadj, gadjncy, true));
+#endif
 
           if(verbose && oocomm.communicator().rank()==0)
             std::cout<<"Creating grah one 1 process took "<<time.elapsed()<<std::endl;
           time.reset();
-
+          options[0]=1; options[1]=3; options[2]=1; options[3]=3; options[4]=3;
           // Call metis
           METIS_PartGraphKway(&noVertices, gxadj, gadjncy, gvwgt, gadjwgt, &wgtflag,
                               &numflag, &nparts, options, &edgecut, gpart);
@@ -1098,15 +1119,17 @@ namespace Dune
           print_carray(Dune::dinfo, gpart, noVertices);
 
           delete[] gxadj;
-          delete[] gvwgt;
           delete[] gadjncy;
+#ifdef USE_WEIGHTS
+          delete[] gvwgt;
           delete[] gadjwgt;
+#endif
         }
         // Scatter result
         MPI_Scatter(gpart, 1, MPITraits<idxtype>::getType(), part, 1,
                     MPITraits<idxtype>::getType(), 0, comm);
 
-        if(rank==0) {
+        {
           // release remaining memory
           delete[] gpart;
           delete[] noedges;
@@ -1116,10 +1139,12 @@ namespace Dune
 
 #endif
         delete[] xadj;
-        delete[] vwgt;
         delete[] vtxdist;
         delete[] adjncy;
+#ifdef USE_WEIGHTS
+        delete[] vwgt;
         delete[] adjwgt;
+#endif
         delete[] tpwgts;
       }
     }else{
@@ -1426,69 +1451,79 @@ namespace Dune
     // 4.1) Let's start...
     //
     int npes = oocomm.communicator().size();
-    int *sendTo = new int[npes];
-    int *recvFrom = new int[npes];
-    int *buf = new int[npes];
-    // init the buffers
-    for(int j=0; j<npes; j++) {
-      sendTo[j] = 0;
-      recvFrom[j] = 0;
-      buf[j] = 0;
-    }
+    int *sendTo = 0;
+    int noSendTo = 0;
+    std::set<int> recvFrom;
 
     // the max number of vertices is stored in the sendTo buffer,
     // not the number of vertices to send! Because the max number of Vtx
     // is used as the fixed buffer size by the MPI send/receive calls
 
-    // TODO: optimize buffer size
-    bool existentOnNextLevel=false;
-
     typedef typename std::vector<int>::const_iterator VIter;
-    int numOfVtx = oocomm.indexSet().size();
     int mype = oocomm.communicator().rank();
 
-    for(VIter i=setPartition.begin(), iend = setPartition.end(); i!=iend; ++i) {
-      if (*i!=mype) {
-        if (sendTo[*i]==0) {
-          sendTo[*i] = numOfVtx;
-          buf[*i] = numOfVtx;
-        }
-      }
-      else
-        existentOnNextLevel=true;
+    {
+      std::set<int> tsendTo;
+      for(VIter i=setPartition.begin(), iend = setPartition.end(); i!=iend; ++i)
+        tsendTo.insert(*i);
+
+      noSendTo = tsendTo.size();
+      sendTo = new int[noSendTo];
+      typedef std::set<int>::const_iterator iterator;
+      int idx=0;
+      for(iterator i=tsendTo.begin(); i != tsendTo.end(); ++i, ++idx)
+        sendTo[idx]=*i;
     }
 
-    // The own "send to" array is sent to the next process and so on.
-    // Each process receive such a array and pick up the
-    // corresponding "receive from" value. This value define the size
-    // of the buffer containing the vertices to receive by the next step.
-    // TODO: not really a ring communication
-    int pe=0;
-    int src = (mype-1+npes)%npes;
-    int dest = (mype+1)%npes;
-
-    MPI_Comm comm = oocomm.communicator();
-    MPI_Status status;
-
-    // ring communication, we need n-1 communication for n processors
-    for (int i=0; i<npes-1; i++) {
-      MPI_Sendrecv_replace(buf, npes, MPI_INT, dest, 0, src, 0, comm, &status);
-      // pe is the process of the actual received buffer
-      pe = ((mype-1-i)+npes)%npes;
-      recvFrom[pe] = buf[mype]; // pick up the "recv from" value for myself
-      if(recvFrom[pe]>0)
-        existentOnNextLevel=true;
-    }
-    delete[] buf;
+    // Gather how many target processes each rank sends to.
+    int* gnoSend = new int[oocomm.communicator().size()];
+    int* gsendToDispl = new int[oocomm.communicator().size()+1];
+
+    MPI_Allgather(&noSendTo, 1, MPI_INT, gnoSend, 1,
+                  MPI_INT, oocomm.communicator());
+
+    // calculate total receive message size
+    int totalNoRecv = 0;
+    for(int i=0; i<npes; ++i)
+      totalNoRecv += gnoSend[i];
+
+    int *gsendTo = new int[totalNoRecv];
+
+    // calculate displacement for allgatherv
+    gsendToDispl[0]=0;
+    for(int i=0; i<npes; ++i)
+      gsendToDispl[i+1]=gsendToDispl[i]+gnoSend[i];
+
+    // gather the data
+    MPI_Allgatherv(sendTo, noSendTo, MPI_INT, gsendTo, gnoSend, gsendToDispl,
+                   MPI_INT, oocomm.communicator());
+
+    // Extract from which processes we will receive data
+    for(int proc=0; proc < npes; ++proc)
+      for(int i=gsendToDispl[proc]; i < gsendToDispl[proc+1]; ++i)
+        if(gsendTo[i]==mype)
+          recvFrom.insert(proc);
+
+    bool existentOnNextLevel = recvFrom.size()>0;
+
+    // Delete memory
+    delete[] gnoSend;
+    delete[] gsendToDispl;
+    delete[] gsendTo;
+
 
 #ifdef DEBUG_REPART
-    std::cout<<mype<<": recvFrom: ";
-    for(int i=0; i<npes; i++) {
-      std::cout<<recvFrom[i]<<" ";
+    if(recvFrom.size()) {
+      std::cout<<mype<<": recvFrom: ";
+      typedef typename std::set<int>::const_iterator siter;
+      for(siter i=recvFrom.begin(); i!= recvFrom.end(); ++i) {
+        std::cout<<*i<<" ";
+      }
     }
+
     std::cout<<std::endl<<std::endl;
     std::cout<<mype<<": sendTo: ";
-    for(int i=0; i<npes; i++) {
+    for(int i=0; i<noSendTo; i++) {
       std::cout<<sendTo[i]<<" ";
     }
     std::cout<<std::endl<<std::endl;
@@ -1513,68 +1548,103 @@ namespace Dune
     std::set<GI> sendOverlapSet;
     std::set<int> myNeighbors;
 
-    getOwnerOverlapVec<OwnerSet>(graph, setPartition, oocomm.globalLookup(),
-                                 mype, mype, myOwnerVec, myOverlapSet, redistInf, myNeighbors);
-
-    for(int i=0; i < npes; ++i) {
-      // the rank of the process defines the sending order,
-      // so it starts naturally by 0
-
-      if (i==mype) {
-        for(int j=0; j < npes; ++j) {
-          if (sendTo[j]>0) {
-            // clear the vector for sending
-            sendOwnerVec.clear();
-            sendOverlapSet.clear();
-            // get all owner and overlap vertices for process j and save these
-            // in the vectors sendOwnerVec and sendOverlapSet
-            std::set<int> neighbors;
-            getOwnerOverlapVec<OwnerSet>(graph, setPartition, oocomm.globalLookup(),
-                                         mype, j, sendOwnerVec, sendOverlapSet, redistInf,
-                                         neighbors);
-            // +2, we need 2 integer more for the length of each part
-            // (owner/overlap) of the array
-            int buffersize=0;
-            int tsize;
-            MPI_Pack_size(1, MPITraits<std::size_t>::getType(), oocomm.communicator(), &buffersize);
-            MPI_Pack_size(sendOwnerVec.size(), MPITraits<GI>::getType(), oocomm.communicator(), &tsize);
-            buffersize +=tsize;
-            MPI_Pack_size(1, MPITraits<std::size_t>::getType(), oocomm.communicator(), &tsize);
-            buffersize +=tsize;
-            MPI_Pack_size(sendOverlapSet.size(), MPITraits<GI>::getType(), oocomm.communicator(), &tsize);
-            buffersize += tsize;
-            MPI_Pack_size(1, MPITraits<std::size_t>::getType(), oocomm.communicator(), &tsize);
-            buffersize += tsize;
-            MPI_Pack_size(neighbors.size(), MPI_INT, oocomm.communicator(), &tsize);
-            buffersize += tsize;
-
-            char* sendBuf = new char[buffersize];
-#ifdef DEBUG_REPART
-            std::cout<<mype<<" sending "<<sendOwnerVec.size()<<" owner and "<<
-            sendOverlapSet.size()<<" overlap to "<<j<<" buffersize="<<buffersize<<std::endl;
-#endif
-            createSendBuf(sendOwnerVec, sendOverlapSet, neighbors, sendBuf, buffersize, oocomm.communicator());
-            MPI_Send(sendBuf, buffersize, MPI_PACKED, j, 0, oocomm.communicator());
-            delete[] sendBuf;
-          }
-        }
-      } else { // All the other processes have to wait for receive...
-        if (recvFrom[i]>0) {
-          // Get buffer size
-          MPI_Probe(i, 0,oocomm.communicator(), &status);
-          int buffersize=0;
-          MPI_Get_count(&status, MPI_PACKED, &buffersize);
-          char* recvBuf = new char[buffersize];
+    //    getOwnerOverlapVec<OwnerSet>(graph, setPartition, oocomm.globalLookup(),
+    //				 mype, mype, myOwnerVec, myOverlapSet, redistInf, myNeighbors);
+
+    char **sendBuffers=new char*[noSendTo];
+    MPI_Request *requests = new MPI_Request[noSendTo];
+
+    // Create all messages to be sent
+    for(int i=0; i < noSendTo; ++i) {
+      // clear the vector for sending
+      sendOwnerVec.clear();
+      sendOverlapSet.clear();
+      // get all owner and overlap vertices for process sendTo[i] and save these
+      // in the vectors sendOwnerVec and sendOverlapSet
+      std::set<int> neighbors;
+      getOwnerOverlapVec<OwnerSet>(graph, setPartition, oocomm.globalLookup(),
+                                   mype, sendTo[i], sendOwnerVec, sendOverlapSet, redistInf,
+                                   neighbors);
+      // +2: we need 2 more integers for the length of each part
+      // (owner/overlap) of the array
+      int buffersize=0;
+      int tsize;
+      MPI_Pack_size(1, MPITraits<std::size_t>::getType(), oocomm.communicator(), &buffersize);
+      MPI_Pack_size(sendOwnerVec.size(), MPITraits<GI>::getType(), oocomm.communicator(), &tsize);
+      buffersize +=tsize;
+      MPI_Pack_size(1, MPITraits<std::size_t>::getType(), oocomm.communicator(), &tsize);
+      buffersize +=tsize;
+      MPI_Pack_size(sendOverlapSet.size(), MPITraits<GI>::getType(), oocomm.communicator(), &tsize);
+      buffersize += tsize;
+      MPI_Pack_size(1, MPITraits<std::size_t>::getType(), oocomm.communicator(), &tsize);
+      buffersize += tsize;
+      MPI_Pack_size(neighbors.size(), MPI_INT, oocomm.communicator(), &tsize);
+      buffersize += tsize;
+
+      sendBuffers[i] = new char[buffersize];
+
 #ifdef DEBUG_REPART
-          std::cout<<mype<<" receiving "<<recvFrom[i]<<" from "<<i<<" buffersize="<<buffersize<<std::endl;
+      std::cout<<mype<<" sending "<<sendOwnerVec.size()<<" owner and "<<
+      sendOverlapSet.size()<<" overlap to "<<sendTo[i]<<" buffersize="<<buffersize<<std::endl;
 #endif
-          MPI_Recv(recvBuf, buffersize, MPI_PACKED, i, 0, oocomm.communicator(), &status);
-          saveRecvBuf(recvBuf, buffersize, myOwnerVec, myOverlapSet, myNeighbors, redistInf, i, oocomm.communicator());
-          delete[] recvBuf;
-        }
+      createSendBuf(sendOwnerVec, sendOverlapSet, neighbors, sendBuffers[i], buffersize, oocomm.communicator());
+      MPI_Issend(sendBuffers[i], buffersize, MPI_PACKED, sendTo[i], 99, oocomm.communicator(), requests+i);
+    }
+
+
+    // Receive Messages
+    int noRecv = recvFrom.size();
+    int oldbuffersize=0;
+    char* recvBuf = 0;
+    while(noRecv>0) {
+      // probe for an incoming message
+      MPI_Status stat;
+      MPI_Probe(MPI_ANY_SOURCE, 99,  oocomm.communicator(), &stat);
+      int buffersize;
+      MPI_Get_count(&stat, MPI_PACKED, &buffersize);
+
+      if(oldbuffersize<buffersize) {
+        // buffer too small, reallocate
+        delete[] recvBuf;
+        recvBuf = new char[buffersize];
+        oldbuffersize = buffersize;
       }
+      MPI_Recv(recvBuf, buffersize, MPI_PACKED, stat.MPI_SOURCE, 99, oocomm.communicator(), &stat);
+      saveRecvBuf(recvBuf, buffersize, myOwnerVec, myOverlapSet, myNeighbors, redistInf,
+                  stat.MPI_SOURCE, oocomm.communicator());
+      --noRecv;
+    }
+
+    if(recvBuf)
+      delete[] recvBuf;
+
+    // Wait for the send messages to complete
+    MPI_Status *statuses = new MPI_Status[noSendTo];
+    int send = MPI_Waitall(noSendTo, requests, statuses);
+
+    // check for errors
+    if(send==MPI_ERR_IN_STATUS) {
+      std::cerr<<mype<<": Error in sending:"<<std::endl;
+      // Search for the error
+      for(int i=0; i< noSendTo; i++)
+        if(statuses[i].MPI_ERROR!=MPI_SUCCESS) {
+          char message[300];
+          int messageLength;
+          MPI_Error_string(statuses[i].MPI_ERROR, message, &messageLength);
+          std::cerr<<" source="<<statuses[i].MPI_SOURCE<<" message: ";
+          for(int j=0; j< messageLength; j++)
+            std::cerr<<message[j];
+        }
+      std::cerr<<std::endl;
     }
 
+    for(int i=0; i < noSendTo; ++i)
+      delete[] sendBuffers[i];
+
+    delete[] sendBuffers;
+    delete[] statuses;
+    delete[] requests;
+
     redistInf.setCommunicator(oocomm.communicator());
 
     //
@@ -1689,7 +1759,6 @@ namespace Dune
 
     // release the memory
     delete[] sendTo;
-    delete[] recvFrom;
 
 
 #ifdef PERF_REPART
-- 
GitLab