Skip to content
Snippets Groups Projects
Commit a8132037 authored by Andreas Dedner's avatar Andreas Dedner
Browse files

imporve handling of includes in `importclass.py`

parent aa84555d
No related branches found
No related tags found
1 merge request!790Feature/add python bindings
import io
classExportCode="""
#include <cmath>
from io import StringIO
classACode="""
struct MyClassA {
MyClassA(int a,int b) : a_(a), b_(b) {}
int a_,b_;
};
"""
classBCode="""
#include <cmath>
template <class T> struct MyClassB {
MyClassB(T &t, int p) : a_(std::pow(t.a_,p)), b_(std::pow(t.b_,p)) {}
int a_,b_;
};
"""
runCode="""
template <class T> int run(T &t) {
return t.a_ * t.b_;
}
......@@ -19,16 +23,15 @@ def test_class_export():
from dune.generator.algorithm import run
from dune.generator import path
from dune.typeregistry import generateTypeName
code = io.StringIO(classExportCode)
cls = load("MyClassA",code,10,20)
print( run("run","myclass.hh",cls) )
cls = load("MyClassA",StringIO(classACode),10,20)
assert run("run",StringIO(runCode),cls) == 10*20
clsName,includes = generateTypeName("MyClassB",cls)
cls = load(clsName,[codes]+includes,cls,2)
print( run("run","myclass.hh",cls) )
cls = load(clsName,StringIO(classBCode),cls,2)
assert run("run",StringIO(runCode),cls) == 10**2*20**2
if __name__ == "__main__":
try:
from dune.common.module import get_dune_py_dir
_ = get_dune_py_dir()
test_class_export()
except:
except ImportError:
pass
......@@ -11,27 +11,21 @@ def load(className, includeFiles, *args):
source = '#include <config.h>\n\n'
source += '#define USING_DUNE_PYTHON 1\n\n'
includes = []
if isString(includeFiles):
if not os.path.dirname(includeFiles):
with open(includeFiles, "r") as include:
source += include.read()
source += "\n"
else:
source += "#include <"+includeFiles+">\n"
includes += [includeFiles]
elif hasattr(includeFiles,"readable"): # for IOString
with includeFiles as include:
source += include.read()
source += "\n"
elif isinstance(includeFiles, list):
for includefile in includeFiles:
if isString(includeFiles) or hasattr(includeFiles,"readable"):
includeFiles = [includeFiles]
for includefile in includeFiles:
if isString(includefile):
if not os.path.dirname(includefile):
with open(includefile, "r") as include:
source += include.read()
source += "\n"
else:
source += "#include <"+includefile+">\n"
includes += [includefile]
else:
source += "#include <"+includefile+">\n"
includes += [includefile]
elif hasattr(includefile,"readable"): # for IOString
with includefile as include:
source += include.read()
source += "\n"
argTypes = []
for arg in args:
......@@ -52,6 +46,7 @@ def load(className, includeFiles, *args):
source += "PYBIND11_MODULE( " + moduleName + ", module )\n"
source += "{\n"
includes += ["python/dune/generated/"+moduleName+".cc"]
source += " auto cls = Dune::Python::insertClass< "+className+\
" >( module, \"cls\","+\
"Dune::Python::GenerateTypeName(\""+className+"\"),"+\
......@@ -63,4 +58,9 @@ def load(className, includeFiles, *args):
source += ");\n"
source += "}"
source = "#ifndef def_"+moduleName+\
"\n#define def_"+moduleName+"\n"+\
source+\
"\n#endif\n"
return builder.load(moduleName, source, signature).cls(*args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment