bvector.hh 9.48 KiB
// SPDX-FileCopyrightText: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
#ifndef DUNE_PYTHON_ISTL_BVECTOR_HH
#define DUNE_PYTHON_ISTL_BVECTOR_HH
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <dune/common/typeutilities.hh>
// typeregistry.hh is required for insertClass, which the scope/clsName overload of registerBlockVector below relies on
#include <dune/python/common/typeregistry.hh>
#include <dune/python/common/fvecmatregistry.hh>
#include <dune/python/common/string.hh>
#include <dune/python/common/vector.hh>
#include <dune/python/istl/iterator.hh>
#include <dune/python/pybind11/operators.h>
#include <dune/python/pybind11/pybind11.h>
#include <dune/istl/bvector.hh>
#include <dune/istl/blocklevel.hh>
namespace Dune
{
namespace Python
{
namespace detail
{
// Copy one axis of a strided buffer into a fixed-size FieldVector.
// ptr points at the first entry of this axis; shape[0]/strides[0] describe the axis
// (extent in entries, stride in bytes). The entries are reinterpreted as K, so the caller
// must already have verified that the buffer format matches K.
// Throws pybind11::value_error if the buffer extent differs from the compile-time size n.
template< class K, int n >
inline static void copy ( const char *ptr, const ssize_t *shape, const ssize_t *strides, Dune::FieldVector< K, n > &v )
{
  const ssize_t extent = static_cast< ssize_t >( n );
  // a FieldVector cannot be resized, so the buffer must match exactly
  if( shape[ 0 ] != extent )
    throw pybind11::value_error( "Invalid buffer size: " + std::to_string( *shape ) + " (should be: " + std::to_string( n ) + ")." );

  const ssize_t stride = strides[ 0 ];
  const char *entry = ptr;
  for( ssize_t k = 0; k < extent; ++k, entry += stride )
    v[ k ] = *reinterpret_cast< const K * >( entry );
}
template< class B, class A >
inline static void copy ( const char *ptr, const ssize_t *shape, const ssize_t *strides, Dune::BlockVector< B, A > &v )
{
v.resize( *shape );
for( ssize_t i = 0; i < *shape; ++i )
copy( ptr + i*(*strides), shape+1, strides+1, v[ i ] );
}
// Fill a block vector from a Python buffer (e.g. a numpy array).
// The buffer must have the same element format as BlockVector::field_type. Its number of
// dimensions must equal the block level of BlockVector: one axis per nesting level, so a
// BlockVector< FieldVector< K, n > > needs a two-dimensional buffer of shape (N, n).
// Throws pybind11::value_error on a format or dimension mismatch; the recursive copy
// overloads may also throw on a size mismatch of fixed-size blocks.
template< class BlockVector >
inline static void copy ( pybind11::buffer buffer, BlockVector &v )
{
  typedef typename BlockVector::field_type field_type;
  typedef typename BlockVector::size_type size_type;

  pybind11::buffer_info info = buffer.request();
  if( info.format != pybind11::format_descriptor< field_type >::format() )
    throw pybind11::value_error( "Incompatible buffer format." );

  // The previous message claimed "one-dimensional buffers", but the check has always
  // required ndim == blockLevel, which is > 1 for nested block vectors.
  const size_type expectedDim = blockLevel< BlockVector >();
  if( size_type( info.ndim ) != expectedDim )
    throw pybind11::value_error( "Block vectors of block level " + std::to_string( expectedDim )
                                 + " can only be initialized from " + std::to_string( expectedDim )
                                 + "-dimensional buffers (got " + std::to_string( info.ndim ) + ")." );

  copy( static_cast< const char * >( info.ptr ), info.shape.data(), info.strides.data(), v );
}
// blockVectorGetItem
// ------------------
// Return a Python object that references block `index` of v (no copy is made).
// vObj is the Python object wrapping v; keep_alive_impl ties it to the returned object,
// so the vector is not destroyed while the block reference is still alive.
// Throws pybind11::index_error if index is not a valid block index of v.
template< class BlockVector >
inline static pybind11::object blockVectorGetItem ( const pybind11::object &vObj, BlockVector &v, typename BlockVector::size_type index )
{
  auto it = v.find( index );
  if( it != v.end() )
  {
    pybind11::object block = pybind11::cast( *it, pybind11::return_value_policy::reference );
    pybind11::detail::keep_alive_impl( block, vObj );
    return block;
  }
  throw pybind11::index_error( "Index " + std::to_string( index ) + " does not exist in block vector." );
}
} // namespace detail
// to_string
// ---------
// String representation of a block vector: "(b0, b1, ...)", where each block is
// formatted by the matching to_string overload. Only enabled for types derived from
// Imp::block_vector_unmanaged.
template< class X >
inline static auto to_string ( const X &x )
  -> std::enable_if_t< std::is_base_of< Imp::block_vector_unmanaged< typename X::block_type, typename X::size_type >, X >::value, std::string >
{
  auto blockToString = [] ( auto &&block ) { return to_string( block ); };
  return "(" + join( ", ", blockToString, x.begin(), x.end() ) + ")";
}
// registerBlockVector
// --------------------
// Add the generic block vector interface to an existing pybind11 class. This covers
// assignment, copying, item access (single index, iterable of indices, slice assignment),
// len(), the one-tensor interface, iterators and the arithmetic operators. Constructors
// are not registered here.
template< class BlockVector, class... options >
inline void registerBlockVector ( pybind11::class_< BlockVector, options... > cls )
{
  typedef typename BlockVector::field_type field_type;
  typedef typename BlockVector::block_type block_type;
  typedef typename BlockVector::size_type size_type;

  using pybind11::operator""_a;

  cls.def( "assign", [] ( BlockVector &self, const BlockVector &x ) { self = x; }, "x"_a );
  cls.def( "copy", [] ( const BlockVector &self ) { return new BlockVector( self ); } );

  // v[i]: reference to block i (keeps v alive)
  cls.def( "__getitem__", [] ( const pybind11::object &self, size_type index ) {
      return detail::blockVectorGetItem( self, pybind11::cast< BlockVector & >( self ), index );
    } );
  // v[indices]: tuple of block references. Results are collected in a list first so that
  // iterables without len() (e.g. generators) are accepted as well.
  cls.def( "__getitem__", [] ( const pybind11::object &self, pybind11::iterable index ) {
      BlockVector &v = pybind11::cast< BlockVector & >( self );
      pybind11::list refs;
      for( pybind11::handle i : index )
        refs.append( detail::blockVectorGetItem( self, v, pybind11::cast< size_type >( i ) ) );
      return pybind11::tuple( refs );
    } );

  cls.def( "__setitem__", [] ( BlockVector &self, size_type index, block_type value ) {
      auto pos = self.find( index );
      if( pos != self.end() )
        *pos = value;
      else
        throw pybind11::index_error();
    } );
  // v[slice] = values: the number of values must match the slice length exactly.
  cls.def( "__setitem__", [] ( BlockVector &self, pybind11::slice index, pybind11::iterable value ) {
      // Use the signed overload of compute() so that negative steps (e.g. v[::-1]) work.
      // compute() returns false with a Python error set for invalid slices (e.g. step 0).
      pybind11::ssize_t start = 0, stop = 0, step = 0, length = 0;
      if( !index.compute( static_cast< pybind11::ssize_t >( self.N() ), &start, &stop, &step, &length ) )
        throw pybind11::error_already_set();

      pybind11::ssize_t count = 0;
      for( pybind11::handle v : value )
      {
        if( count >= length )
          throw pybind11::value_error( "too many values passed" );
        auto pos = self.find( static_cast< size_type >( start + count*step ) );
        if( pos != self.end() )
          *pos = pybind11::cast< block_type >( v );
        else
          throw pybind11::index_error();
        ++count;
      }
      if( count < length )
        throw pybind11::value_error( "too few values passed" );
    } );

  cls.def( "__len__", [] ( const BlockVector &self ) { return self.N(); } );

  detail::registerOneTensorInterface( cls );
  detail::registerISTLIterators( cls );

  // in-place operators return self, so Python keeps the same object
  cls.def( "__iadd__", [] ( BlockVector &self, const BlockVector& x ) -> BlockVector & { self += x; return self; } );
  cls.def( "__isub__", [] ( BlockVector &self, const BlockVector& x ) -> BlockVector & { self -= x; return self; } );
  cls.def( "__imul__", [] ( BlockVector &self, field_type x ) -> BlockVector & { self *= x; return self; } );
  cls.def( "__idiv__", [] ( BlockVector &self, field_type x ) -> BlockVector & { self /= x; return self; } );
  cls.def( "__itruediv__", [] ( BlockVector &self, field_type x ) -> BlockVector & { self /= x; return self; } );

  // binary operators return a newly allocated vector (ownership passes to Python)
  cls.def( "__add__", [] ( const BlockVector &self, const BlockVector &x ) { BlockVector *copy = new BlockVector( self ); *copy += x; return copy; } );
  cls.def( "__sub__", [] ( const BlockVector &self, const BlockVector &x ) { BlockVector *copy = new BlockVector( self ); *copy -= x; return copy; } );
  cls.def( "__div__", [] ( const BlockVector &self, field_type x ) { BlockVector *copy = new BlockVector( self ); *copy /= x; return copy; } );
  cls.def( "__truediv__", [] ( const BlockVector &self, field_type x ) { BlockVector *copy = new BlockVector( self ); *copy /= x; return copy; } );
  cls.def( "__mul__", [] ( const BlockVector &self, field_type x ) { BlockVector *copy = new BlockVector( self ); *copy *= x; return copy; } );
  cls.def( "__rmul__", [] ( const BlockVector &self, field_type x ) { BlockVector *copy = new BlockVector( self ); *copy *= x; return copy; } );
}
// registerBlockVector
// --------------------
// Used by the new (generator-based) bindings, which support arbitrary block sizes.
// The scope argument is ignored here because the generator already takes the scope into account.
// NOTE(review): the original comment also said that creating a dune.istl BlockVector this way does not define the remaining bindings; the intent is unclear, so confirm.
// Register the full BlockVector binding on cls: the generic interface from
// registerBlockVector( cls ) plus constructors (default, size, buffer), buffer assignment,
// the read-only `capacity` property and `resize`.
// The scope argument is unused here.
template< class BlockVector, class ... options >
void registerBlockVector ( pybind11::handle /*scope*/, pybind11::class_<BlockVector, options ... > cls )
{
  typedef typename BlockVector::size_type size_type;
  using pybind11::operator""_a;

  registerBlockVector( cls );

  cls.def( pybind11::init( [] () { return new BlockVector(); } ) );
  cls.def( pybind11::init( [] ( size_type size ) { return new BlockVector( size ); } ), "size"_a );
  cls.def( pybind11::init( [] ( pybind11::buffer buffer ) {
      // Own the new vector until the copy succeeds. Previously a rejected buffer
      // (detail::copy throwing) leaked the raw allocation.
      std::unique_ptr< BlockVector > self = std::make_unique< BlockVector >();
      detail::copy( buffer, *self );
      return self.release();
    } ) );

  // cls.def( "__str__", [] ( const BlockVector &self ) { return to_string( self ); } );

  cls.def( "assign", [] ( BlockVector &self, pybind11::buffer buffer ) { detail::copy( buffer, self ); }, "buffer"_a );
  cls.def_property_readonly( "capacity", [] ( const BlockVector &self ) { return self.capacity(); } );
  cls.def( "resize", [] ( BlockVector &self, size_type size ) { self.resize( size ); }, "size"_a );
}
// This overload creates the Python class itself (via insertClass); it is needed so that run.algorithm can work properly.
// Insert a Python class named clsName for BlockVector into scope and return it.
// The class is registered with a generated C++ type name and the include files needed to
// recreate it. The bindings are only added when insertClass reports (via .second) that they
// should be; presumably this means the class was newly inserted, so confirm.
// NOTE(review): the generated type name always spells the field type as `double` (and omits
// the allocator), regardless of BlockVector::field_type. Confirm this is intended.
template< class BlockVector >
inline pybind11::class_< BlockVector > registerBlockVector ( pybind11::handle scope, const char *clsName = "BlockVector" )
{
  const int blockSize = BlockVector::block_type::dimension;
  const std::string typeName = "Dune::BlockVector< Dune::FieldVector< double, " + std::to_string( blockSize ) + " > >";

  auto inserted = Dune::Python::insertClass< BlockVector >( scope, clsName,
                                                            Dune::Python::GenerateTypeName( typeName ),
                                                            Dune::Python::IncludeFiles{ "dune/istl/bvector.hh", "dune/python/istl/bvector.hh" } );
  if( inserted.second )
    registerBlockVector( scope, inserted.first );
  return inserted.first;
}
} // namespace Python
} // namespace Dune
#endif // #ifndef DUNE_PYTHON_ISTL_BVECTOR_HH