diff --git a/dune/curvedgeometry/curvedgeometry.hh b/dune/curvedgeometry/curvedgeometry.hh
index 47ef316776a2581902849a753c22eca909ab3595..1aa7c555f9ef88a834a051b2d7b9a05de6b97791 100644
--- a/dune/curvedgeometry/curvedgeometry.hh
+++ b/dune/curvedgeometry/curvedgeometry.hh
@@ -10,6 +10,7 @@
 #include <dune/common/fmatrix.hh>
 #include <dune/common/fvector.hh>
 #include <dune/common/typetraits.hh>
+#include <dune/common/std/type_traits.hh>
 
 #include <dune/geometry/affinegeometry.hh>
 #include <dune/geometry/quadraturerules.hh>
@@ -22,33 +23,6 @@ namespace Dune
 {
   namespace Impl
   {
-    template <class T, int n, int m>
-    FieldMatrix<T,n,m> outerProduct(const FieldVector<T,n>& a, const FieldVector<T,m>& b)
-    {
-      FieldMatrix<T,n,m> res;
-      for (int i = 0; i < n; ++i)
-        for (int j = 0; j < m; ++j)
-          res[i][j] = a[i] * b[j];
-      return res;
-    }
-
-    template <class T, int n, int m>
-    FieldMatrix<T,n,m> outerProduct(const FieldMatrix<T,n,1>& a, const FieldVector<T,m>& b)
-    {
-      FieldMatrix<T,n,m> res;
-      for (int i = 0; i < n; ++i)
-        for (int j = 0; j < m; ++j)
-          res[i][j] = a[i][0] * b[j];
-      return res;
-    }
-
-    template <class T, int n, int m,
-      std::enable_if_t<(n > 1), int> = 0>
-    FieldMatrix<T,n,m> outerProduct(const FieldMatrix<T,1,n>& a, const FieldVector<T,m>& b)
-    {
-      return outerProduct(a[0], b);
-    }
-
     template <class T, int n, int m>
     void outerProductAccumulate(const FieldVector<T,n>& a, const FieldVector<T,m>& b, FieldMatrix<T,n,m>& res)
     {
@@ -85,8 +59,9 @@ namespace Dune
    *  This structure provides the default values.
    *
    *  \tparam  ct  coordinate type
+   *  \tparam  LFECache  A LocalFiniteElementVariantCache implementation, e.g. LagrangeLocalFiniteElementCache
    */
-  template <class ct>
+  template <class ct, class LFECache>
   struct CurvedGeometryTraits
   {
     /// \brief helper structure containing some matrix routines. See affinegeometry.hh
@@ -98,26 +73,7 @@ namespace Dune
     /// \brief maximal number of Newton iteration in `geometry.local(global)`
     static int maxIteration () { return 100; }
 
-    /// \brief template specifying the storage for the vertices
-    /**
-     *  Internally, the CurvedGeometry needs to store the lagrange vertices of the
-     *  geometry.
-     *
-     *  \tparam  coorddim   coordinate dimension
-     */
-    template <int coorddim>
-    struct VertexStorage
-    {
-      using type = std::vector<FieldVector<ct, coorddim>>;
-    };
-
-    /// \brief will there be only one geometry type for a dimension?
-    template <int dim>
-    struct hasSingleGeometryType
-    {
-      static const bool v = false;
-      static const unsigned int topologyId = ~0u;
-    };
+    using LocalFiniteElementCache = LFECache;
   };
 
 
@@ -131,28 +87,23 @@ namespace Dune
    *  \tparam  ct      coordinate type
    *  \tparam  mydim   geometry dimension
    *  \tparam  cdim    coordinate dimension
-   *  \tparam  polynomial_order
+   *  \tparam  Traits  Parametrization of the geometry, see \ref CurvedGeometryTraits
    *
    *  The requirements on the traits are documented along with their default,
    *  CurvedGeometryTraits.
    */
-  template <class ct, int mydim, int cdim, int polynomial_order>
+  template <class ct, int mydim, int cdim, class Traits>
   class CurvedGeometry
   {
   public:
     /// coordinate type
     using ctype = ct;
 
-    using Traits = CurvedGeometryTraits<ctype>;
-
     /// geometry dimension
     static const int mydimension = mydim;
     /// coordinate dimension
     static const int coorddimension = cdim;
 
-    /// Polynomial order of geometry
-    static const int order = polynomial_order;
-
     /// type of local coordinates
     using LocalCoordinate = FieldVector<ctype, mydimension>;
     /// type of global coordinates
@@ -176,43 +127,91 @@ namespace Dune
   protected:
     using MatrixHelper = typename Traits::MatrixHelper;
 
-    using LocalFECache = LagrangeLocalFiniteElementCache<ctype, ctype, mydimension, order>;
+    using LocalFECache = typename Traits::LocalFiniteElementCache;
     using LocalFiniteElement = typename LocalFECache::FiniteElementType;
     using LocalBasis = typename LocalFiniteElement::Traits::LocalBasisType;
     using LocalBasisTraits = typename LocalBasis::Traits;
 
+  protected:
+    CurvedGeometry (const ReferenceElement& refElement)
+      : refElement_(refElement)
+      , localFECache_()
+      , localFE_(localFECache_.get(refElement.type()))
+    {}
+
+
   public:
     /// \brief constructor
     /**
      *  \param[in]  refElement  reference element for the geometry
-     *  \param[in]  corners     corners to store internally
+     *  \param[in]  vertices    vertices to store internally
      *
      *  \note The type of vertices is actually a template argument.
      *        It is only required that the internal vertex storage can be
      *        constructed from this object.
      */
-    template <class Vertices>
-    CurvedGeometry (const ReferenceElement& refElement, const Vertices& vertices)
-      : refElement_(refElement)
-      , vertices_(vertices)
-      , localFECache_()
-      , localBasis_(localFECache_.get(refElement.type()).localBasis())
-    {}
+    CurvedGeometry (const ReferenceElement& refElement, std::vector<GlobalCoordinate> vertices)
+      : CurvedGeometry(refElement)
+    {
+      vertices_ = std::move(vertices);
+      assert(localFE_.size() == vertices_.size());
+    }
+
+    template <class Parametrization,
+      std::enable_if_t<Std::is_callable<Parametrization(LocalCoordinate), GlobalCoordinate>::value, bool> = true>
+    CurvedGeometry (const ReferenceElement& refElement, Parametrization&& param)
+      : CurvedGeometry(refElement)
+    {
+      auto const& localInterpolation = localFE_.localInterpolation();
+      localInterpolation.interpolate(param, vertices_);
+    }
 
     /// \brief constructor
     /**
-     *  \param[in]  gt          geometry type
-     *  \param[in]  corners     corners to store internally
+     *  \param[in]  gt     geometry type
+     *  \param[in]  param  either a vector of vertices, or a functor that can be used to construct the vertices
      */
-    template <class Vertices>
-    CurvedGeometry (Dune::GeometryType gt, const Vertices& vertices)
-      : CurvedGeometry(ReferenceElements::general(gt), vertices)
+    template <class Parametrization>
+    CurvedGeometry (Dune::GeometryType gt, Parametrization&& param)
+      : CurvedGeometry(ReferenceElements::general(gt), std::forward<Parametrization>(param))
+    {}
+
+    /// \brief Copy constructor
+    CurvedGeometry (const CurvedGeometry& that)
+      : CurvedGeometry(that.refElement_, that.vertices_)
+    {}
+
+    /// \brief Move constructor
+    CurvedGeometry (CurvedGeometry&& that)
+      : CurvedGeometry(that.refElement_, std::move(that.vertices_))
     {}
 
+    /// \brief Copy assignment operator
+    CurvedGeometry& operator=(const CurvedGeometry& that)
+    {
+      assert(refElement_ == that.refElement_);
+      vertices_ = that.vertices_;
+      return *this;
+    }
+
+    /// \brief Move assignment operator
+    CurvedGeometry& operator=(CurvedGeometry&& that)
+    {
+      assert(refElement_ == that.refElement_);
+      vertices_ = std::move(that.vertices_);
+      return *this;
+    }
+
+    /// \brief obtain the polynomial order of the parametrization
+    int order () const
+    {
+      return localBasis().order();
+    }
+
     /// \brief is this mapping affine?
     bool affine () const
     {
-      return refElement_.type().isSimplex() && int(vertices_.size()) == corners();
+      return refElement_.type().isSimplex() && order() == 1;
     }
 
     /// \brief obtain the name of the reference element
@@ -221,7 +220,7 @@ namespace Dune
       return refElement_.type();
     }
 
-    /// \brief obtain number of corners of the corresponding reference element */
+    /// \brief obtain number of corners of the corresponding reference element
     int corners () const
     {
       return refElement_.size(mydimension);
@@ -257,11 +256,8 @@ namespace Dune
      */
     GlobalCoordinate global (const LocalCoordinate& local) const
     {
-      if (mydimension == 0)
-        return vertices_[0];
-
       thread_local std::vector<typename LocalBasisTraits::RangeType> shapeValues;
-      localBasis_.evaluateFunction(local, shapeValues);
+      localBasis().evaluateFunction(local, shapeValues);
       assert(shapeValues.size() == vertices_.size());
 
       GlobalCoordinate out(0);
@@ -293,12 +289,11 @@ namespace Dune
       for (int i = 0; i < Traits::maxIteration(); ++i)
       {
         // Newton's method: DF^n dx^n = F^n, x^{n+1} -= dx^n
-        const GlobalCoordinate dglobal = (*this).global( x ) - globalCoord;
-        const bool invertible =
-          MatrixHelper::template xTRightInvA<mydimension, coorddimension>(jacobianTransposed(x), dglobal, dx);
+        const GlobalCoordinate dglobal = global(x) - globalCoord;
+        const bool invertible = MatrixHelper::xTRightInvA(jacobianTransposed(x), dglobal, dx);
 
         if (!invertible)
-          return LocalCoordinate(std::numeric_limits<ctype>::max());
+          break;
 
         // update x with correction
         x -= dx;
@@ -306,6 +301,7 @@ namespace Dune
         // for affine mappings only one iteration is needed
         if (affineMapping)
           break;
+
         if (dx.two_norm2() < tolerance)
           break;
       }
@@ -343,7 +339,7 @@ namespace Dune
      */
     ctype integrationElement (const LocalCoordinate& local) const
     {
-      return MatrixHelper::template sqrtDetAAT<mydimension, coorddimension>(jacobianTransposed(local));
+      return MatrixHelper::sqrtDetAAT(jacobianTransposed(local));
     }
 
     /// \brief Obtain the volume of the mapping's image
@@ -352,8 +348,8 @@ namespace Dune
      */
     Volume volume () const
     {
-      int p = 2*order; // TODO: needs to be checked
-      auto const& quadRule = QuadratureRules<ctype, mydimension>::rule(type(), p);
+      const int p = 2*localBasis().order(); // TODO: needs to be checked
+      const auto& quadRule = QuadratureRules<ctype, mydimension>::rule(type(), p);
       Volume vol(0);
       for (auto const& qp : quadRule)
         vol += integrationElement(qp.position()) * qp.weight();
@@ -373,7 +369,7 @@ namespace Dune
     JacobianTransposed jacobianTransposed (const LocalCoordinate& local) const
     {
       std::vector<typename LocalBasisTraits::JacobianType> shapeJacobians;
-      localBasis_.evaluateJacobian(local, shapeJacobians);
+      localBasis().evaluateJacobian(local, shapeJacobians);
       assert(shapeJacobians.size() == vertices_.size());
 
       JacobianTransposed out(0);
@@ -401,7 +397,7 @@ namespace Dune
       return geometry.refElement();
     }
 
-    auto const& vertices () const
+    const std::vector<GlobalCoordinate>& vertices () const
     {
       return vertices_;
     }
@@ -412,6 +408,11 @@ namespace Dune
       return refElement_;
     }
 
+    const LocalBasis& localBasis() const
+    {
+      return localFE_.localBasis();
+    }
+
     GlobalCoordinate normalEdge (const LocalCoordinate& local, const JacobianTransposed& J) const
     {
       GlobalCoordinate res{
@@ -433,13 +434,16 @@ namespace Dune
 
   private:
     ReferenceElement refElement_;
-    std::vector<GlobalCoordinate> vertices_;
-    //typename Traits::template VertexStorage<coorddimension>::type vertices_;
-
     LocalFECache localFECache_;
-    LocalBasis const& localBasis_;
+    LocalFiniteElement const& localFE_;
+
+    std::vector<GlobalCoordinate> vertices_;
   };
 
+  template <class ctype, int mydim, int cdim, int order>
+  using LagrangeCurvedGeometry = CurvedGeometry<ctype,mydim,cdim,
+    CurvedGeometryTraits<ctype, LagrangeLocalFiniteElementCache<ctype, ctype, mydim, order>> >;
+
 } // namespace Dune
 
 
diff --git a/src/dune-curvedgeometry.cc b/src/dune-curvedgeometry.cc
index 920d4a9df584f09076ce97adf16d7448b041a804..8a7e697ded51113e460da9cd591993e2e1ce75ab 100644
--- a/src/dune-curvedgeometry.cc
+++ b/src/dune-curvedgeometry.cc
@@ -9,6 +9,7 @@
 #include <iostream>
 
 #include <dune/common/parallel/mpihelper.hh>
+#include <dune/common/test/testsuite.hh>
 #include <dune/curvedgeometry/curvedgeometry.hh>
 #include <dune/geometry/quadraturerules.hh>
 #include <dune/grid/yaspgrid.hh>
@@ -21,13 +22,14 @@ int main(int argc, char** argv)
   using namespace Dune;
   MPIHelper& helper = MPIHelper::instance(argc, argv);
 
-  YaspGrid<2> grid({1.0,1.0}, {4,4});
+  YaspGrid<2> grid({8.0,8.0}, {32,32});
   using LocalCoordinate  = FieldVector<double,2>;
   using GlobalCoordinate = FieldVector<double,2>;
   using WorldCoordinate  = FieldVector<double,3>;
 
-  using LocalFECache = LagrangeLocalFiniteElementCache<double, double, 2, order>;
-  LocalFECache localFeCache;
+  using Geometry = LagrangeCurvedGeometry<double, 2, 3, order>;
+
+  FieldMatrix<double,2,2> I{{1,0},{0,1}};
 
   // coordinate projection
   auto project = [](GlobalCoordinate const& global) -> WorldCoordinate
@@ -37,22 +39,27 @@ int main(int argc, char** argv)
              std::sin(global[0])*std::cos(global[1]) };
   };
 
+  TestSuite test("curved geometry");
+
   for (auto const& e : elements(grid.leafGridView()))
   {
     // projection from local coordinates
     auto X = [project,geo=e.geometry()](LocalCoordinate const& local) -> WorldCoordinate { return project(geo.global(local)); };
 
-    // construct lagrange vertices
-    std::vector<WorldCoordinate> vertices;
-    localFeCache.get(e.type()).localInterpolation().interpolate(X, vertices);
-
     // create a curved geometry
-    CurvedGeometry<double,2,3,order> geometry(e.type(), vertices);
+    Geometry geometry(e.type(), X);
 
     auto const& quadRule = QuadratureRules<double,2>::rule(e.type(), 4);
     for (auto const& qp : quadRule) {
       auto Jt = geometry.jacobianTransposed(qp.position());
       auto Jtinv = geometry.jacobianInverseTransposed(qp.position());
+
+      FieldMatrix<double, 2, 2> res;
+      FMatrixHelp::multMatrix(Jt, Jtinv, res);
+      res -= I;
+      test.check(res.frobenius_norm() < std::sqrt(std::numeric_limits<double>::epsilon()), "J^-1 * J == I");
     }
   }
+
+  return test.report() ? 0 : 1;
 }