Commit 06fd1478 authored by Andreas Dedner's avatar Andreas Dedner

[!112] add a plot function taking an 'on' parameter with values 'cells' and 'points'

Merge branch 'feature/addGeneralCellPlotting' into 'master'

See merge request [staging/dune-python!112]

  [staging/dune-python!112]: Nonestaging/dune-python/merge_requests/112
parents 1d040c82 f3b205ca
...@@ -194,6 +194,23 @@ ...@@ -194,6 +194,23 @@
} }
], ],
"metadata": { "metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"livereveal": { "livereveal": {
"center": false, "center": false,
"controls": false, "controls": false,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# # Integrating dune-functions # # Integrating dune-functions
# In[1]: # In[ ]:
import time import time
...@@ -95,6 +95,7 @@ rhs_dofs = numpy.zeros(taylorHoodBasis.dimension) ...@@ -95,6 +95,7 @@ rhs_dofs = numpy.zeros(taylorHoodBasis.dimension)
# In[ ]: # In[ ]:
from dune.generator import algorithm from dune.generator import algorithm
def onBoundary(x): def onBoundary(x):
return any(x[i] < 1e-10 or x[i] > (1-1e-10) for i in range(len(x))) return any(x[i] < 1e-10 or x[i] > (1-1e-10) for i in range(len(x)))
......
...@@ -261,8 +261,7 @@ ...@@ -261,8 +261,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from dune.plotting import plotPointData\n", "lgf.plot(figsize=(9,9), gridLines=None)"
"plotPointData(lgf, figsize=(9,9), gridLines=None)"
] ]
}, },
{ {
...@@ -274,6 +273,23 @@ ...@@ -274,6 +273,23 @@
} }
], ],
"metadata": { "metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"livereveal": { "livereveal": {
"center": false, "center": false,
"controls": false, "controls": false,
......
...@@ -25,7 +25,7 @@ import math ...@@ -25,7 +25,7 @@ import math
# In[ ]: # In[ ]:
from dune.grid import cartesianDomain from dune.grid import cartesianDomain, gridFunction
from dune.alugrid import aluConformGrid from dune.alugrid import aluConformGrid
vertices = numpy.array([(0,0), (1,0), (1,1), (0,1), (-1,1), (-1,0), (-1,-1), (0,-1)]) vertices = numpy.array([(0,0), (1,0), (1,1), (0,1), (-1,1), (-1,0), (-1,-1), (0,-1)])
triangles = numpy.array([(2,0,1), (0,2,3), (4,0,3), (0,4,5), (6,0,5), (0,6,7)]) triangles = numpy.array([(2,0,1), (0,2,3), (4,0,3), (0,4,5), (6,0,5), (0,6,7)])
...@@ -41,7 +41,7 @@ class LinearShapeFunction: ...@@ -41,7 +41,7 @@ class LinearShapeFunction:
self.ofs = ofs self.ofs = ofs
self.grad = grad self.grad = grad
def evaluate(self, local): def evaluate(self, local):
return self.ofs + sum([x*y for x, y in zip(self.grad, local)]) return [self.ofs + sum([x*y for x, y in zip(self.grad, local)])]
def gradient(self, local): def gradient(self, local):
return self.grad return self.grad
...@@ -62,7 +62,7 @@ for i in range(dim): ...@@ -62,7 +62,7 @@ for i in range(dim):
from dune.geometry import quadratureRules from dune.geometry import quadratureRules
from dune.istl import blockVector from dune.istl import blockVector
f = lambda v: sum(2.0 * x * (1 - x) for x in v) f = lambda v: [sum(2.0 * x * (1 - x) for x in v)]
dim, indexSet = aluView.dimension, aluView.indexSet dim, indexSet = aluView.dimension, aluView.indexSet
rhs = blockVector(indexSet.size(dim)) rhs = blockVector(indexSet.size(dim))
...@@ -83,10 +83,10 @@ for e in aluView.elements: ...@@ -83,10 +83,10 @@ for e in aluView.elements:
# In[ ]: # In[ ]:
from dune.istl import bcrsMatrix, BCRSMatrix11 from dune.istl import BuildMode, bcrsMatrix, BCRSMatrix
dim, indexSet = aluView.dimension, aluView.indexSet dim, indexSet = aluView.dimension, aluView.indexSet
matrix = bcrsMatrix((indexSet.size(dim), indexSet.size(dim)), 8, 0.1, BCRSMatrix11.implicit) matrix = bcrsMatrix((indexSet.size(dim), indexSet.size(dim)), 8, 0.1, BuildMode.implicit)
quadrature = quadratureRules(1) quadrature = quadratureRules(1)
for e in aluView.elements: for e in aluView.elements:
...@@ -145,13 +145,11 @@ _ = solver(u, rhs) ...@@ -145,13 +145,11 @@ _ = solver(u, rhs)
@gridFunction(aluView) @gridFunction(aluView)
def lgf( element, local ): def lgf( element, local ):
return [ sum( phi.evaluate(local) * u[indexSet.subIndex(element, i, dim)]\ return [ sum( phi.evaluate(local) * u[indexSet.subIndex(element, i, dim)] for i, phi in enumerate(p1ShapeFunctionSet)) ]
for i, phi in enumerate(p1ShapeFunctionSet)) ]
_ = aluView.writeVTK("fem2d", pointdata={"u": lgf}) _ = aluView.writeVTK("fem2d", pointdata={"u": lgf})
# In[ ]: # In[ ]:
from dune.plotting import plotPointData lgf.plot(figsize=(9,9), gridLines=None)
plotPointData(lgf, figsize=(9,9), gridLines=None)
...@@ -189,9 +189,8 @@ ...@@ -189,9 +189,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from dune.plotting import plotPointData\n",
"cgrid = yaspView.function(lambda e,p: [c[mapper.index(e)]])\n", "cgrid = yaspView.function(lambda e,p: [c[mapper.index(e)]])\n",
"plotPointData(cgrid, figsize=(9,9))" "cgrid.plot(figsize=(9,9),on=\"cells\")"
] ]
}, },
{ {
...@@ -234,8 +233,7 @@ ...@@ -234,8 +233,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from dune.plotting import plotPointData\n", "cgrid.plot( figsize=(9,9), on=\"cells\")"
"plotPointData(cgrid, figsize=(9,9))"
] ]
}, },
{ {
...@@ -259,7 +257,7 @@ ...@@ -259,7 +257,7 @@
"while t < 0.5:\n", "while t < 0.5:\n",
" t += evolve(yaspView, mapper, c, b, t)\n", " t += evolve(yaspView, mapper, c, b, t)\n",
"print(\"time used:\", time.time()-start)\n", "print(\"time used:\", time.time()-start)\n",
"plotPointData(cgrid, figsize=(9,9),gridLines=\"\")" "cgrid.plot(figsize=(9,9),gridLines=\"\",on=\"cells\")"
] ]
}, },
{ {
...@@ -271,6 +269,23 @@ ...@@ -271,6 +269,23 @@
} }
], ],
"metadata": { "metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"livereveal": { "livereveal": {
"center": false, "center": false,
"controls": false, "controls": false,
......
...@@ -110,9 +110,8 @@ print("time used:", time.time()-start) ...@@ -110,9 +110,8 @@ print("time used:", time.time()-start)
# In[ ]: # In[ ]:
from dune.plotting import plotPointData
cgrid = yaspView.function(lambda e,p: [c[mapper.index(e)]]) cgrid = yaspView.function(lambda e,p: [c[mapper.index(e)]])
plotPointData(cgrid, figsize=(9,9)) cgrid.plot(figsize=(9,9),on="cells")
# Let's do this on the C++ side instead: # Let's do this on the C++ side instead:
...@@ -136,8 +135,7 @@ print("time used:", time.time()-start) ...@@ -136,8 +135,7 @@ print("time used:", time.time()-start)
# In[ ]: # In[ ]:
from dune.plotting import plotPointData cgrid.plot( figsize=(9,9), on="cells")
plotPointData(cgrid, figsize=(9,9))
# Now that its a bit faster let us recompute the solution on a finer grid: # Now that its a bit faster let us recompute the solution on a finer grid:
...@@ -153,4 +151,4 @@ t = 0.0 ...@@ -153,4 +151,4 @@ t = 0.0
while t < 0.5: while t < 0.5:
t += evolve(yaspView, mapper, c, b, t) t += evolve(yaspView, mapper, c, b, t)
print("time used:", time.time()-start) print("time used:", time.time()-start)
plotPointData(cgrid, figsize=(9,9),gridLines="") cgrid.plot(figsize=(9,9),gridLines="",on="cells")
...@@ -97,9 +97,9 @@ def plot(self, function=None, *args, **kwargs): ...@@ -97,9 +97,9 @@ def plot(self, function=None, *args, **kwargs):
else: else:
try: try:
grid = function.grid grid = function.grid
dune.plotting.plotPointData(solution=function,*args,**kwargs) dune.plotting.plot(solution=function,*args,**kwargs)
except AttributeError: except AttributeError:
dune.plotting.plotPointData(solution=self.function(function),*args,**kwargs) dune.plotting.plot(solution=self.function(function),*args,**kwargs)
@deprecated("use the `gridFunction` decorator") @deprecated("use the `gridFunction` decorator")
def globalGridFunction(gv, evaluator): def globalGridFunction(gv, evaluator):
......
...@@ -34,15 +34,17 @@ def plotGrid(grid, gridLines="black", figure=None, ...@@ -34,15 +34,17 @@ def plotGrid(grid, gridLines="black", figure=None,
pyplot.show(block=block) pyplot.show(block=block)
def _plotPointData(fig, grid, solution, level=0, gridLines="black", def _plotData(fig, grid, solution, level=0, gridLines="black",
component=None, vectors=None, nofVectors=None, component=None, vectors=None, nofVectors=None,
xlim=None, ylim=None, clim=None, cmap=None, colorbar=True): xlim=None, ylim=None, clim=None, cmap=None, colorbar=True,
on="cell"):
if (gridLines is not None) and (gridLines != ""): if (gridLines is not None) and (gridLines != ""):
_plotGrid(fig, grid, gridLines=gridLines) _plotGrid(fig, grid, gridLines=gridLines)
if solution is not None: if solution is not None:
if not any(g.isNone for g in grid.indexSet.types(0)): if on == "points":
assert not any(gt.isNone for gt in grid.indexSet.types(0)), "Can't plot point data with polygonal grids, use `on=\"cells\" in plotting command"
triangulation = grid.triangulation(level) triangulation = grid.triangulation(level)
data = solution.pointData(level) data = solution.pointData(level)
try: try:
...@@ -152,7 +154,7 @@ def plotPointData(solution, level=0, gridLines="black", ...@@ -152,7 +154,7 @@ def plotPointData(solution, level=0, gridLines="black",
grid = solution grid = solution
solution = None solution = None
if not grid.dimension == 2: if not grid.dimension == 2:
print("inline plotting so far only available for 2d grids") raise ValueError("inline plotting so far only available for 2d grids")
return return
if figure is None: if figure is None:
...@@ -160,7 +162,33 @@ def plotPointData(solution, level=0, gridLines="black", ...@@ -160,7 +162,33 @@ def plotPointData(solution, level=0, gridLines="black",
show = True show = True
else: else:
show = False show = False
_plotPointData(figure,grid,solution,level,gridLines,None,vectors,nofVectors,xlim,ylim,clim,cmap,colorbar=colorbar) _plotData(figure,grid,solution,level,gridLines,None,
vectors,nofVectors,xlim,ylim,clim,cmap,
colorbar=colorbar,on="points")
if show:
pyplot.show(block=block)
def plotCellData(solution, level=0, gridLines="black",
vectors=None, nofVectors=None, figure=None,
xlim=None, ylim=None, clim=None, figsize=None, cmap=None,
colorbar=True):
try:
grid = solution.grid
except:
grid = solution
solution = None
if not grid.dimension == 2:
raise ValueError("inline plotting so far only available for 2d grids")
return
if figure is None:
figure = pyplot.figure(figsize=figsize)
show = True
else:
show = False
_plotData(figure,grid,solution,level,gridLines,None,vectors,nofVectors,xlim,ylim,clim,cmap,
colorbar=colorbar,on="cells")
if show: if show:
pyplot.show(block=block) pyplot.show(block=block)
...@@ -173,7 +201,7 @@ def plotComponents(solution, level=0, show=None, gridLines="black", figure=None, ...@@ -173,7 +201,7 @@ def plotComponents(solution, level=0, show=None, gridLines="black", figure=None,
grid = solution grid = solution
solution = None solution = None
if not grid.dimension == 2: if not grid.dimension == 2:
print("inline plotting so far only available for 2d grids") raise ValueError("inline plotting so far only available for 2d grids")
return return
if not show: if not show:
...@@ -187,15 +215,33 @@ def plotComponents(solution, level=0, show=None, gridLines="black", figure=None, ...@@ -187,15 +215,33 @@ def plotComponents(solution, level=0, show=None, gridLines="black", figure=None,
# first the grid if required # first the grid if required
if (gridLines is not None) and (gridLines != ""): if (gridLines is not None) and (gridLines != ""):
pyplot.subplot(subfig) pyplot.subplot(subfig)
_plotPointData(figure,grid,None,level,gridLines,None,False,None,xlim,ylim,clim,cmap) _plotData(figure,grid,None,level,gridLines,None,False,None,xlim,ylim,clim,cmap,
on="points")
# add the data # add the data
for p in show: for p in show:
pyplot.subplot(subfig+offset+p) pyplot.subplot(subfig+offset+p)
_plotPointData(figure,grid,solution,level,"",p,False,None,xlim,ylim,clim,cmap,False) _plotData(figure,grid,solution,level,"",p,False,None,xlim,ylim,clim,cmap,False,
on="points")
pyplot.show(block=block) pyplot.show(block=block)
def plot(solution,*args,**kwargs):
try:
grid = solution.grid
except:
grid = solution
defaultOn = "cells" if any(gt.isNone for gt in grid.indexSet.types(0)) else "points"
use = kwargs.pop("on",defaultOn)
if use == "points":
plotPointData(solution,*args,**kwargs)
elif use == "components-points":
plotComponents(solution,*args,**kwargs)
elif use == "cells":
plotCellData(solution,*args,**kwargs)
else:
raise ValueError("wrong value for 'on' parameter should be one of 'points','cells','components-points'")
def mayaviPointData(solution, level=0, component=0): def mayaviPointData(solution, level=0, component=0):
grid = solution.grid grid = solution.grid
from mayavi import mlab from mayavi import mlab
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment