diff --git a/dune/grid/test/test-parallel-ug.cc b/dune/grid/test/test-parallel-ug.cc index 430c60960f0061aa039f94978d2c8b3881d3c5c2..2fba71e3bc74c9dfa3689e2d22d9fcbdc080bbac 100644 --- a/dune/grid/test/test-parallel-ug.cc +++ b/dune/grid/test/test-parallel-ug.cc @@ -364,23 +364,38 @@ public: } }; +template<typename Grid> class LoadBalance { - template<class Grid, class Vector, int commCodim> + const static int dimension = Grid::dimension; + using ctype = typename Grid::ctype; + using Position = Dune::FieldVector<ctype, dimension>; + using GridView = typename Grid::LeafGridView; + using IdSet = typename Grid::LocalIdSet; + using Data = std::map<typename IdSet::IdType, Position>; + + using Codims = std::bitset<dimension+1>; + class LBDataHandle - : public Dune::CommDataHandleIF<LBDataHandle<Grid, Vector, commCodim>, - typename Vector::value_type> + : public Dune::CommDataHandleIF<LBDataHandle, + typename Data::mapped_type> { public: - typedef typename Vector::value_type DataType; - typedef Dune::CommDataHandleIF<LBDataHandle<Grid, Vector, commCodim>, - DataType> ParentType; + typedef typename Data::mapped_type DataType; + + public: bool contains (int dim, int codim) const - { return (codim == commCodim); } + { + assert(dim == dimension); + return codims_.test(codim); + } bool fixedSize (int dim, int codim) const - { return true; } + { + assert(dim == dimension); + return true; + } template<class Entity> size_t size (Entity& entity) const @@ -391,81 +406,144 @@ class LoadBalance template<class MessageBuffer, class Entity> void gather (MessageBuffer& buff, const Entity& entity) const { - int index = grid_.leafGridView().indexSet().index(entity); - buff.write(dataVector_[index]); + const auto& id = idSet_.id(entity); + buff.write(data_.at(id)); } template<class MessageBuffer, class Entity> - void scatter (MessageBuffer& buff, const Entity& entity, size_t n) + void scatter (MessageBuffer& buff, const Entity& entity, size_t) { - if (dataVector_.size() != grid_.leafGridView().size(commCodim)) - dataVector_.resize(grid_.leafGridView().size(commCodim)); - - int index = grid_.leafGridView().indexSet().index(entity); - buff.read(dataVector_[index]); + const auto& id = idSet_.id(entity); + buff.read(data_[id]); } - LBDataHandle (Grid& grid, Vector& dataVector) - : grid_(grid), dataVector_(dataVector) + LBDataHandle (const IdSet& idSet, Data& data, const Codims& codims) + : idSet_(idSet) + , data_(data) + , codims_(codims) {} private: - Grid& grid_; - Vector& dataVector_; + const IdSet& idSet_; + Data& data_; + const Codims codims_; }; -public: - template <class Grid> - static void test(Grid& grid) + template<typename=void> + static Codims toBitset(Codims codims = {}) + { return codims; } + + template<int codim, int... codimensions, typename=void> + static Codims toBitset(Codims codims = {}) + { return toBitset<codimensions...>(codims.set(codim)); } + + template<typename=void> + static void fillVector(const GridView&, Data&) {} + + template<int codim, int... codimensions, typename=void> + static void fillVector(const GridView& gv, Data& data) { - const int dim = Grid::dimension; - const int commCodim = dim; - typedef typename Grid::ctype ctype; + std::cout << "Filling vector for codim " << codim << "\n"; - // define the vector containing the data to be balanced - typedef Dune::FieldVector<ctype, dim> Position; - std::vector<Position> dataVector(grid.leafGridView().size(commCodim)); + const auto& idSet = gv.grid().localIdSet(); - // fill the data vector - const auto& gv = grid.leafGridView(); - for (const auto& entity : entities(gv, Dune::Codim<commCodim>(), + for (const auto& entity : entities(gv, Dune::Codim<codim>(), Dune::Partitions::interiorBorder)) { - int index = gv.indexSet().index(entity); + const auto& id = idSet.id(entity); // assign the position of the entity to the entry in the vector - dataVector[index] = entity.geometry().center(); + data[id] = entity.geometry().center(); } - // balance the grid and the data - LBDataHandle<Grid, std::vector<Position>, commCodim> dataHandle(grid, dataVector); - grid.loadBalance(dataHandle); + fillVector<codimensions...>(gv, data); + } - // check for correctness - for (const auto& entity : entities(gv, Dune::Codim<commCodim>(), - Dune::Partitions::interiorBorder)) { - int index = gv.indexSet().index(entity); + template<typename=void> + static bool checkVector(const GridView&, const Data&) + { return true; } + + template<int codim, int... codimensions, typename=void> + static bool checkVector(const GridView& gv, const Data& data) + { + std::cout << "Checking vector for codim " << codim << "\n"; - const auto position = entity.geometry().center(); + const auto& idSet = gv.grid().localIdSet(); + + for (const auto& entity : entities(gv, Dune::Codim<codim>(), + Dune::Partitions::interiorBorder)) { + const auto& id = idSet.id(entity); + const auto& commPos = data.at(id); + const auto& realPos = entity.geometry().center(); // compare the position with the balanced data - for (int k = 0; k < dim; k++) + for (int k = 0; k < dimension; k++) { - if (Dune::FloatCmp::ne(dataVector[index][k], position[k])) + if (Dune::FloatCmp::ne(commPos[k], realPos[k])) { DUNE_THROW(Dune::ParallelError, - gv.comm().rank() << ": position " << position + gv.comm().rank() << ": position " << realPos << " does not coincide with communicated data " - << dataVector[index]); + << commPos); } } } + return checkVector<codimensions...>(gv, data); + } + +public: + template<int... codimensions> + static void test(Grid& grid) + { + const Codims codims = toBitset<codimensions...>(); + const auto& gv = grid.leafGridView(); + + // define the vector containing the data to be balanced + Data data; + + LBDataHandle dataHandle(grid.localIdSet(), data, codims); + + // fill the data vector + fillVector<codimensions...>(gv, data); + + // balance the grid and the data + grid.loadBalance(dataHandle); + + // check for correctness + checkVector<codimensions...>(gv, data); + std::cout << gv.comm().rank() << ": load balancing with data was successful." << std::endl; } }; -template <int dim> +template<typename Grid> +std::shared_ptr<Grid> +setupGrid(bool simplexGrid, bool localRefinement, int refinementDim, bool refineUpperPart) +{ + const static int dim = Grid::dimension; + StructuredGridFactory<Grid> structuredGridFactory; + + Dune::FieldVector<double,dim> lowerLeft(0); + Dune::FieldVector<double,dim> upperRight(1); + std::array<unsigned int, dim> numElements; + std::fill(numElements.begin(), numElements.end(), 4); + if (simplexGrid) + return structuredGridFactory.createSimplexGrid(lowerLeft, upperRight, numElements); + else + return structuredGridFactory.createCubeGrid(lowerLeft, upperRight, numElements); +} + +template<int dim, int... codimensions> +void testLoadBalance(bool simplexGrid, bool localRefinement, int refinementDim, bool refineUpperPart) +{ + using Grid = UGGrid<dim>; + auto grid = setupGrid<Grid>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + LoadBalance<Grid>::template test<codimensions...>(*grid); + // LoadBalance<Grid>::template test<codimensions...>(*grid); +} + +template<int dim> void testParallelUG(bool simplexGrid, bool localRefinement, int refinementDim, bool refineUpperPart) { std::cout << "Testing parallel UGGrid for " << dim << "D\n"; @@ -475,23 +553,12 @@ void testParallelUG(bool simplexGrid, bool localRefinement, int refinementDim, b //////////////////////////////////////////////////////////// typedef UGGrid<dim> GridType; - - StructuredGridFactory<GridType> structuredGridFactory; - - Dune::FieldVector<double,dim> lowerLeft(0); - Dune::FieldVector<double,dim> upperRight(1); - std::array<unsigned int, dim> numElements; - std::fill(numElements.begin(), numElements.end(), 4); - std::shared_ptr<GridType> grid; - if (simplexGrid) - grid = structuredGridFactory.createSimplexGrid(lowerLeft, upperRight, numElements); - else - grid = structuredGridFactory.createCubeGrid(lowerLeft, upperRight, numElements); + auto grid = setupGrid<GridType>(simplexGrid, localRefinement, refinementDim, refineUpperPart); ////////////////////////////////////////////////////// // Distribute the grid ////////////////////////////////////////////////////// - LoadBalance::test(*grid); + grid->loadBalance(); std::cout << "Process " << grid->comm().rank() + 1 << " has " << grid->size(0) @@ -646,10 +713,20 @@ int main (int argc , char **argv) try for (const bool simplexGrid : {false, true}) { for (const bool localRefinement : {false, true}) { for (const bool refineUpperPart : {false, true}) { - for (const int refinementDim : {0,1}) + for (const int refinementDim : {0,1}) { testParallelUG<2>(simplexGrid, localRefinement, refinementDim, refineUpperPart); - for (const int refinementDim : {0,1,2}) + testLoadBalance<2>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<2, 0>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<2, 2>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<2, 0, 2>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + } + for (const int refinementDim : {0,1,2}) { testParallelUG<3>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<3>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<3, 0>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<3, 3>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + testLoadBalance<3, 0, 3>(simplexGrid, localRefinement, refinementDim, refineUpperPart); + } } } }