Skip to content
Snippets Groups Projects
Commit 10175300 authored by Christian Engwer's avatar Christian Engwer
Browse files

Merge branch '70-how-to-pass-a-python-grid-function-back-to-c' into 'master'

Resolve "How to pass a python grid function back to C++?"

Closes #70

See merge request !358
parents b48b1659 c16aa0cb
Branches master
No related tags found
No related merge requests found
Pipeline #45145 failed
......@@ -7,6 +7,8 @@
#include <type_traits>
#include <utility>
#include <dune/common/classname.hh>
#include <dune/functions/functionspacebases/defaultglobalbasis.hh>
#include <dune/python/common/dimrange.hh>
......@@ -18,7 +20,7 @@
#include <dune/python/pybind11/complex.h>
#include <dune/python/pybind11/pybind11.h>
#include <dune/python/pybind11/stl.h>
namespace Dune
{
......@@ -56,27 +58,46 @@ namespace Dune
pybind11::object obj_;
};
template<typename K, unsigned int n>
struct RangeType
{
using type = Dune::FieldVector< K, n >;
static void registerRange(pybind11::module scope)
{
registerFieldVector<K,n>(scope);
}
};
template<typename K>
struct RangeType<K,1>
{
using type = K;
static void registerRange(pybind11::module scope) {} // nothing to register, as K is a basic type
};
template< class GlobalBasis, class... options >
DUNE_EXPORT void registerGlobalBasis ( pybind11::module module, pybind11::class_< GlobalBasis, options... > &cls )
{
using pybind11::operator""_a;
typedef Dune::TypeTree::HybridTreePath<> DefaultTreePath;
using GridView = typename GlobalBasis::GridView;
using DefaultTreePath = Dune::TypeTree::HybridTreePath<>;
const std::size_t dimRange = DimRange< typename GlobalBasis::PreBasis::Node >::value;
const std::size_t dimWorld = GridView::dimensionworld;
cls.def( pybind11::init( [] ( const typename GlobalBasis::GridView &gridView ) { return new GlobalBasis( gridView ); } ), pybind11::keep_alive< 1, 2 >() );
cls.def( pybind11::init( [] ( const GridView &gridView ) { return new GlobalBasis( gridView ); } ), pybind11::keep_alive< 1, 2 >() );
cls.def( "__len__", [](const GlobalBasis& self) { return self.dimension(); } );
cls.def_property_readonly( "dimRange", [] ( pybind11::handle self ) { return pybind11::int_( dimRange ); } );
cls.def_property( "gridView",
[](const GlobalBasis& basis) { return basis.gridView(); },
[](GlobalBasis& basis, const typename GlobalBasis::GridView& gridView) { basis.update(gridView); });
[](GlobalBasis& basis, const GridView& gridView) { basis.update(gridView); });
typedef LocalViewWrapper< GlobalBasis > LocalView;
auto includes = IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto lv = insertClass< LocalView >( module, "LocalView",
GenerateTypeName("Dune::Python::LocalViewWrapper", MetaType<GlobalBasis>()),
IncludeFiles{"dune/python/functions/globalbasis.hh"}).first;
includes).first;
lv.def( "bind", &LocalView::bind );
lv.def( "unbind", &LocalView::unbind );
lv.def( "index", [] ( const LocalView &localView, int index ) { return localView.index( index ); });
......@@ -92,10 +113,36 @@ namespace Dune
cls.def( "interpolate", &Dune::Python::Functions::interpolate<GlobalBasis, bool> );
cls.def( "interpolate", &Dune::Python::Functions::interpolate<GlobalBasis, int> );
typedef Dune::FieldVector< double, dimRange > Range;
typedef Dune::Functions::DiscreteGlobalBasisFunction< GlobalBasis, HierarchicPythonVector< double >, DefaultNodeToRangeMap< GlobalBasis, DefaultTreePath >, Range > DiscreteFunction;
auto clsDiscreteFunction = insertClass< DiscreteFunction >( module, "DiscreteFunction", GenerateTypeName( cls, "DiscreteFunction" ) );
registerDiscreteFunction( module, clsDiscreteFunction.first );
using Range = typename RangeType< double, dimRange >::type;
RangeType< double, dimRange >::registerRange(module);
using Domain = Dune::FieldVector< double, dimWorld >;
registerFieldVector<double,dimWorld>(module);
using DiscreteFunction = Dune::Functions::DiscreteGlobalBasisFunction< GlobalBasis, HierarchicPythonVector< double >, DefaultNodeToRangeMap< GlobalBasis, DefaultTreePath >, Range >;
// register the HierarchicPythonVector
Dune::Python::addToTypeRegistry<HierarchicPythonVector<double>>(
GenerateTypeName("Dune::Python::HierarchicPythonVector", MetaType<double>()),
{"dune/python/functions/discretefunction.hh"}
);
// and add the DiscreteFunction to our module
auto clsDiscreteFunction = insertClass< DiscreteFunction >( module, "DiscreteFunction",
GenerateTypeName( "Dune::Functions::DiscreteGlobalBasisFunction",
MetaType<GlobalBasis>(),
MetaType<HierarchicPythonVector< double >>(),
"Dune::Python::DefaultNodeToRangeMap< " + Dune::Python::findInTypeRegistry<GlobalBasis>().first->second.name + ", Dune::TypeTree::HybridTreePath<> >",
MetaType<Range>()
), includes);
// register the GridViewFunction and register the implicit conversion
Dune::Python::addToTypeRegistry<Range(Domain)>(GenerateTypeName(className<Range(Domain)>()));
using GridViewFunction = Dune::Functions::GridViewFunction<Range(Domain), GridView>;
auto clsGridViewFunction = insertClass< GridViewFunction >( module, "GridViewFunction",
GenerateTypeName( "Dune::Functions::GridViewFunction",
MetaType<Range(Domain)>(),
MetaType<GridView>()
), includes);
clsGridViewFunction.first.def(pybind11::init<DiscreteFunction>());
pybind11::implicitly_convertible<DiscreteFunction, GridViewFunction>();
registerDiscreteFunction<GlobalBasis>( module, clsDiscreteFunction.first );
cls.def("asFunction", [] ( GlobalBasis &self, pybind11::buffer dofVector ) {
auto nodeToRangeMapPtr =
......@@ -109,6 +156,7 @@ namespace Dune
vectorPtr,
nodeToRangeMapPtr);
}, pybind11::keep_alive< 0, 1 >(), pybind11::keep_alive< 0, 2 >(), "dofVector"_a );
}
} // namespace Python
......
......@@ -10,3 +10,7 @@ dune_python_add_test(NAME pypoisson
SCRIPT poisson.py
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
LABELS quick)
dune_python_add_test(NAME pyfunction
SCRIPT function.py
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
LABELS quick)
import numpy as np
import dune.grid
from dune.grid import cartesianDomain
from dune.grid import yaspGrid
import dune.functions
from dune.generator import algorithm
import io
# we try to pass different dune-functions functions from the python side to C++
code="""
#include <dune/common/classname.hh>
#include <dune/grid/yaspgrid.hh>
#include <dune/functions/gridfunctions/gridfunction.hh>
#include <dune/functions/gridfunctions/discreteglobalbasisfunction.hh>
#include <dune/functions/functionspacebases/lagrangebasis.hh>
#include <dune/python/functions/hierarchicvectorwrapper.hh>
#include <dune/python/functions/globalbasis.hh>
#include <dune/python/pybind11/pybind11.h>
#include <dune/python/pybind11/stl.h>
static const int dim = 2;
using Grid = Dune::YaspGrid<dim, Dune::EquidistantOffsetCoordinates<double, dim>>;
using GV = typename Grid::LeafGridView;
using Signature = double(Dune::FieldVector<double,2>);
using GridViewFunction = Dune::Functions::GridViewFunction<Signature,GV>;
template<typename B, typename V, typename NTRE, typename R>
std::string callScalar(const Dune::Functions::DiscreteGlobalBasisFunction<B,V,NTRE,R> & f)
{
return "DiscreteGlobalBasisFunction<...>";
}
std::string callScalar(const GridViewFunction & f)
{
return "GridViewFunction<GV>";
}
std::string callVector(const std::vector<GridViewFunction> & f)
{
return "vector<GridViewFunction<GV>>";
}
"""
class CppTypeInfo:
def __init__(self,name,includes):
self.cppTypeName = name
self.cppIncludes = includes
scalarTypeName = "GridViewFunction"
vectorTypeName = "std::vector<GridViewFunction>"
callScalar = algorithm.load('callScalar', io.StringIO(code), CppTypeInfo(scalarTypeName,[]));
callVector = algorithm.load('callVector', io.StringIO(code), CppTypeInfo(vectorTypeName,[]));
# number of grid elements (in one direction)
gridSize = 4
# create a YaspGrid of the unit square
domain = cartesianDomain([0,0],[1,1],[gridSize,gridSize])
grid = yaspGrid(domain, dimgrid=2)
# create a nodal Lagrange FE basis of order 1 and order 2
basis1 = dune.functions.defaultGlobalBasis(grid, dune.functions.Lagrange(order=1))
basis2 = dune.functions.defaultGlobalBasis(grid, dune.functions.Lagrange(order=2))
# create a DOF vectors
N1 = len(basis1)
x1 = np.ndarray(N1)
N2 = len(basis2)
x2 = np.ndarray(N2)
# interpolate data
basis1.interpolate(x1, lambda x : np.linalg.norm(x-np.array([0.5,0.5]))-0.3)
basis2.interpolate(x2, lambda x : np.linalg.norm(x-np.array([0.5,0.5]))-0.3)
# create a grid function
f1 = basis1.asFunction(x1)
f2 = basis2.asFunction(x2)
# calling the concrete interface requires the lagrange basis header
incLagrange = "#include <dune/functions/functionspacebases/lagrangebasis.hh>\n"
callConcrete1 = algorithm.load('callScalar', io.StringIO(incLagrange + code), f1);
callConcrete2 = algorithm.load('callScalar', io.StringIO(incLagrange + code), f2);
print ("Calling into C++")
print ("Called as " + callConcrete1(f1))
print ("Called as " + callConcrete2(f2))
# try to call as GridViewFunction
print ("Called as " + callScalar(f1))
print ("Called as " + callScalar(f2))
print ("Called as " + callVector([f1,f1]))
print ("Called as " + callVector([f1,f2]))
print ("Called as " + callVector([f2,f2]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment