diff --git a/dune/grid/test/test-parallel-ug.cc b/dune/grid/test/test-parallel-ug.cc
index 430c60960f0061aa039f94978d2c8b3881d3c5c2..2fba71e3bc74c9dfa3689e2d22d9fcbdc080bbac 100644
--- a/dune/grid/test/test-parallel-ug.cc
+++ b/dune/grid/test/test-parallel-ug.cc
@@ -364,23 +364,38 @@ public:
   }
 };
 
+template<typename Grid>
 class LoadBalance
 {
-  template<class Grid, class Vector, int commCodim>
+  const static int dimension = Grid::dimension;
+  using ctype = typename Grid::ctype;
+  using Position = Dune::FieldVector<ctype, dimension>;
+  using GridView = typename Grid::LeafGridView;
+  using IdSet = typename Grid::LocalIdSet;
+  using Data = std::map<typename IdSet::IdType, Position>;
+
+  using Codims = std::bitset<dimension+1>;
+
   class LBDataHandle
-    : public Dune::CommDataHandleIF<LBDataHandle<Grid, Vector, commCodim>,
-          typename Vector::value_type>
+    : public Dune::CommDataHandleIF<LBDataHandle,
+          typename Data::mapped_type>
   {
   public:
-    typedef typename Vector::value_type DataType;
-    typedef Dune::CommDataHandleIF<LBDataHandle<Grid, Vector, commCodim>,
-        DataType> ParentType;
+    typedef typename Data::mapped_type DataType;
+
+  public:
 
     bool contains (int dim, int codim) const
-    { return (codim == commCodim); }
+    {
+      assert(dim == dimension);
+      return codims_.test(codim);
+    }
 
     bool fixedSize (int dim, int codim) const
-    { return true; }
+    {
+      assert(dim == dimension);
+      return true;
+    }
 
     template<class Entity>
     size_t size (Entity& entity) const
@@ -391,81 +406,144 @@ class LoadBalance
     template<class MessageBuffer, class Entity>
     void gather (MessageBuffer& buff, const Entity& entity) const
     {
-      int index = grid_.leafGridView().indexSet().index(entity);
-      buff.write(dataVector_[index]);
+      const auto& id = idSet_.id(entity);
+      buff.write(data_.at(id));
     }
 
     template<class MessageBuffer, class Entity>
-    void scatter (MessageBuffer& buff, const Entity& entity, size_t n)
+    void scatter (MessageBuffer& buff, const Entity& entity, size_t)
     {
-      if (dataVector_.size() != grid_.leafGridView().size(commCodim))
-        dataVector_.resize(grid_.leafGridView().size(commCodim));
-
-      int index = grid_.leafGridView().indexSet().index(entity);
-      buff.read(dataVector_[index]);
+      const auto& id = idSet_.id(entity);
+      buff.read(data_[id]);
     }
 
-    LBDataHandle (Grid& grid, Vector& dataVector)
-      : grid_(grid), dataVector_(dataVector)
+    LBDataHandle (const IdSet& idSet, Data& data, const Codims& codims)
+      : idSet_(idSet)
+      , data_(data)
+      , codims_(codims)
     {}
 
   private:
-    Grid& grid_;
-    Vector& dataVector_;
+    const IdSet& idSet_;
+    Data& data_;
+    const Codims codims_;
   };
 
-public:
-  template <class Grid>
-  static void test(Grid& grid)
+  template<typename=void>
+  static Codims toBitset(Codims codims = {})
+    { return codims; }
+
+  template<int codim, int... codimensions, typename=void>
+  static Codims toBitset(Codims codims = {})
+    { return toBitset<codimensions...>(codims.set(codim)); }
+
+  template<typename=void>
+  static void fillVector(const GridView&, Data&) {}
+
+  template<int codim, int... codimensions, typename=void>
+  static void fillVector(const GridView& gv, Data& data)
   {
-    const int dim = Grid::dimension;
-    const int commCodim = dim;
-    typedef typename Grid::ctype ctype;
+    std::cout << "Filling vector for codim " << codim << "\n";
 
-    // define the vector containing the data to be balanced
-    typedef Dune::FieldVector<ctype, dim> Position;
-    std::vector<Position> dataVector(grid.leafGridView().size(commCodim));
+    const auto& idSet = gv.grid().localIdSet();
 
-    // fill the data vector
-    const auto& gv = grid.leafGridView();
-    for (const auto& entity : entities(gv, Dune::Codim<commCodim>(),
+    for (const auto& entity : entities(gv, Dune::Codim<codim>(),
                                        Dune::Partitions::interiorBorder)) {
-      int index = gv.indexSet().index(entity);
+      const auto& id = idSet.id(entity);
 
       // assign the position of the entity to the entry in the vector
-      dataVector[index] = entity.geometry().center();
+      data[id] = entity.geometry().center();
     }
 
-    // balance the grid and the data
-    LBDataHandle<Grid, std::vector<Position>, commCodim> dataHandle(grid, dataVector);
-    grid.loadBalance(dataHandle);
+    fillVector<codimensions...>(gv, data);
+  }
 
-    // check for correctness
-    for (const auto& entity : entities(gv, Dune::Codim<commCodim>(),
-                                       Dune::Partitions::interiorBorder)) {
-      int index = gv.indexSet().index(entity);
+  template<typename=void>
+  static bool checkVector(const GridView&, const Data&)
+    { return true; }
+
+  template<int codim, int... codimensions, typename=void>
+  static bool checkVector(const GridView& gv, const Data& data)
+  {
+    std::cout << "Checking vector for codim " << codim << "\n";
 
-      const auto position = entity.geometry().center();
+    const auto& idSet = gv.grid().localIdSet();
+
+    for (const auto& entity : entities(gv, Dune::Codim<codim>(),
+                                       Dune::Partitions::interiorBorder)) {
+      const auto& id = idSet.id(entity);
+      const auto& commPos = data.at(id);
+      const auto& realPos = entity.geometry().center();
 
       // compare the position with the balanced data
-      for (int k = 0; k < dim; k++)
+      for (int k = 0; k < dimension; k++)
       {
-        if (Dune::FloatCmp::ne(dataVector[index][k], position[k]))
+        if (Dune::FloatCmp::ne(commPos[k], realPos[k]))
         {
           DUNE_THROW(Dune::ParallelError,
-                     gv.comm().rank() << ": position " << position
+                     gv.comm().rank() << ": position " << realPos
                                       << " does not coincide with communicated data "
-                                      << dataVector[index]);
+                                      << commPos);
         }
       }
     }
 
+    return checkVector<codimensions...>(gv, data);
+  }
+
+public:
+  template<int... codimensions>
+  static void test(Grid& grid)
+  {
+    const Codims codims = toBitset<codimensions...>();
+    const auto& gv = grid.leafGridView();
+
+    // define the vector containing the data to be balanced
+    Data data;
+
+    LBDataHandle dataHandle(grid.localIdSet(), data, codims);
+
+    // fill the data vector
+    fillVector<codimensions...>(gv, data);
+
+    // balance the grid and the data
+    grid.loadBalance(dataHandle);
+
+    // check for correctness
+    checkVector<codimensions...>(gv, data);
+
     std::cout << gv.comm().rank()
               << ": load balancing with data was successful." << std::endl;
   }
 };
 
-template <int dim>
+template<typename Grid>
+std::shared_ptr<Grid>
+setupGrid(bool simplexGrid, bool localRefinement, int refinementDim, bool refineUpperPart)
+{
+  const static int dim = Grid::dimension;
+  StructuredGridFactory<Grid> structuredGridFactory;
+
+  Dune::FieldVector<double,dim> lowerLeft(0);
+  Dune::FieldVector<double,dim> upperRight(1);
+  std::array<unsigned int, dim> numElements;
+  std::fill(numElements.begin(), numElements.end(), 4);
+  if (simplexGrid)
+    return structuredGridFactory.createSimplexGrid(lowerLeft, upperRight, numElements);
+  else
+    return structuredGridFactory.createCubeGrid(lowerLeft, upperRight, numElements);
+}
+
+template<int dim, int... codimensions>
+void testLoadBalance(bool simplexGrid, bool localRefinement, int refinementDim, bool refineUpperPart)
+{
+  using Grid = UGGrid<dim>;
+  auto grid = setupGrid<Grid>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+  LoadBalance<Grid>::template test<codimensions...>(*grid);
+  // LoadBalance<Grid>::template test<codimensions...>(*grid);
+}
+
+template<int dim>
 void testParallelUG(bool simplexGrid, bool localRefinement, int refinementDim, bool refineUpperPart)
 {
   std::cout << "Testing parallel UGGrid for " << dim << "D\n";
@@ -475,23 +553,12 @@ void testParallelUG(bool simplexGrid, bool localRefinement, int refinementDim, b
   ////////////////////////////////////////////////////////////
 
   typedef UGGrid<dim> GridType;
-
-  StructuredGridFactory<GridType> structuredGridFactory;
-
-  Dune::FieldVector<double,dim> lowerLeft(0);
-  Dune::FieldVector<double,dim> upperRight(1);
-  std::array<unsigned int, dim> numElements;
-  std::fill(numElements.begin(), numElements.end(), 4);
-  std::shared_ptr<GridType> grid;
-  if (simplexGrid)
-    grid = structuredGridFactory.createSimplexGrid(lowerLeft, upperRight, numElements);
-  else
-    grid = structuredGridFactory.createCubeGrid(lowerLeft, upperRight, numElements);
+  auto grid = setupGrid<GridType>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
 
   //////////////////////////////////////////////////////
   // Distribute the grid
   //////////////////////////////////////////////////////
-  LoadBalance::test(*grid);
+  grid->loadBalance();
 
   std::cout << "Process " << grid->comm().rank() + 1
             << " has " << grid->size(0)
@@ -646,10 +713,20 @@ int main (int argc , char **argv) try
   for (const bool simplexGrid : {false, true}) {
     for (const bool localRefinement : {false, true}) {
       for (const bool refineUpperPart : {false, true}) {
-        for (const int refinementDim : {0,1})
+        for (const int refinementDim : {0,1}) {
           testParallelUG<2>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
-        for (const int refinementDim : {0,1,2})
+          testLoadBalance<2>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<2, 0>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<2, 2>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<2, 0, 2>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+        }
+        for (const int refinementDim : {0,1,2}) {
           testParallelUG<3>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<3>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<3, 0>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<3, 3>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+          testLoadBalance<3, 0, 3>(simplexGrid, localRefinement, refinementDim, refineUpperPart);
+        }
       }
     }
   }