Skip to content
Snippets Groups Projects
Commit fd835749 authored by Simon Praetorius's avatar Simon Praetorius
Browse files

Add a tensortraits type to combine old and new matrix/vector/tensor implementations

parent 94d3f18b
No related branches found
No related tags found
No related merge requests found
......@@ -104,6 +104,7 @@ install(FILES
tensordot.hh
tensormixin.hh
tensorspan.hh
tensortraits.hh
timer.hh
transpose.hh
tupleutility.hh
......
......@@ -12,6 +12,7 @@ namespace Dune::Concept::Archetypes {
template <std::size_t r>
struct Extents
{
using rank_type = std::size_t;
using index_type = std::size_t;
static constexpr std::size_t rank () { return r; }
......
......@@ -8,6 +8,7 @@
#include <array>
#include <concepts>
#include <dune/common/tensortraits.hh>
#include <dune/common/concepts/archetypes/tensor.hh>
namespace Dune::Concept {
......@@ -42,15 +43,15 @@ concept Extents = requires(E extents, std::size_t i)
* - `Dune::TensorSpan`
*/
template <class T>
concept Tensor = Extents<T> && requires(T tensor)
concept Tensor = requires(T tensor)
{
requires Extents<typename T::extents_type>;
{ tensor.extents() } -> std::convertible_to<typename T::extents_type>;
requires Extents<typename TensorTraits<T>::extents_type>;
{ TensorTraits<T>::extents(tensor) } -> std::convertible_to<typename TensorTraits<T>::extents_type>;
};
//! A `TensorWithRank` is a `Tensor` with given tensor-rank `rank`.
template <class T, std::size_t rank>
concept TensorWithRank = Tensor<T> && T::rank() == rank;
concept TensorWithRank = Tensor<T> && TensorTraits<T>::rank() == rank;
//! A `Vector` is a `Tensor` of rank 1.
template <class T>
......@@ -75,14 +76,14 @@ static_assert(Concept::Matrix<Archetypes::Tensor<double,2>>);
*/
template <class T>
concept RandomAccessTensor = Tensor<T> &&
requires(T tensor, std::array<typename T::index_type, T::rank()> indices)
requires(T tensor, std::array<typename TensorTraits<T>::index_type, TensorTraits<T>::rank()> indices)
{
tensor[indices];
};
//! A `RandomAccessTensorWithRank` is a `RandomAccessTensor` with given tensor-rank `rank`.
template <class T, std::size_t rank>
concept RandomAccessTensorWithRank = RandomAccessTensor<T> && T::rank() == rank;
concept RandomAccessTensorWithRank = RandomAccessTensor<T> && TensorTraits<T>::rank() == rank;
//! A `RandomAccessVector` is a `RandomAccessTensor` of rank 1.
template <class T>
......
......@@ -201,6 +201,17 @@ namespace Dune
return asImp().mat_access(i);
}
//! random access with array of indices
value_type & operator[] (std::array<size_type,2> i)
{
return asImp().mat_access(i[0])[i[1]];
}
const value_type & operator[] (std::array<size_type,2> i) const
{
return asImp().mat_access(i[0])[i[1]];
}
//! size method (number of rows)
constexpr size_type size() const
{
......
......@@ -303,6 +303,21 @@ namespace Dune {
return asImp()[i];
}
//! random access with array of indices
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
value_type & operator[] (std::array<SizeType,1> i)
{
return asImp()[i[0]];
}
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
const value_type & operator[] (std::array<SizeType,1> i) const
{
return asImp()[i[0]];
}
//! return reference to first element
constexpr value_type& front()
{
......
......@@ -52,6 +52,30 @@ namespace Dune
typedef typename FieldTraits<K>::real_type real_type;
};
template< class K >
struct TensorTraits< DynamicMatrix<K> >
{
using index_type = typename DenseMatVecTraits<DynamicMatrix<K>>::size_type;
using extents_type = Std::extents<index_type, Std::dynamic_extent, Std::dynamic_extent>;
using rank_type = typename extents_type::rank_type;
/// \brief Number of elements in all dimensions of the array, \related extents
static constexpr extents_type extents (const DynamicMatrix<K>& tensor) noexcept { return extents_type{tensor.N(), tensor.M()}; }
/// \brief Number of dimensions of the array
static constexpr rank_type rank () noexcept { return 2; }
/// \brief Number of dimension with dynamic size
static constexpr rank_type rank_dynamic () noexcept { return 2; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr std::size_t static_extent (rank_type /*r*/) noexcept { return Std::dynamic_extent; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr index_type extent (const DynamicMatrix<K>& tensor, rank_type r) noexcept { return r == 0 ? tensor.N() : tensor.M(); }
};
/** \brief Construct a matrix with a dynamic size.
*
* \tparam K is the field type (use float, double, complex, etc)
......
......@@ -13,13 +13,16 @@
#include <initializer_list>
#include <limits>
#include <utility>
#include <vector>
#include "boundschecking.hh"
#include "exceptions.hh"
#include "genericiterator.hh"
#include <dune/common/boundschecking.hh>
#include <dune/common/densevector.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/genericiterator.hh>
#include <dune/common/tensortraits.hh>
#include <dune/common/std/extents.hh>
#include <dune/common/std/span.hh>
#include <vector>
#include "densevector.hh"
namespace Dune {
......@@ -48,6 +51,31 @@ namespace Dune {
typedef typename FieldTraits< K >::real_type real_type;
};
template< class K, class Allocator >
struct TensorTraits< DynamicVector< K, Allocator > >
{
using index_type = typename DenseMatVecTraits<DynamicVector<K,Allocator>>::size_type;
using extents_type = Std::extents<index_type, Std::dynamic_extent>;
using rank_type = typename extents_type::rank_type;
/// \brief Number of elements in all dimensions of the array, \related extents
static constexpr extents_type extents (const DynamicVector<K,Allocator>& tensor) noexcept { return extents_type{tensor.size()}; }
/// \brief Number of dimensions of the array
static constexpr rank_type rank () noexcept { return 1; }
/// \brief Number of dimension with dynamic size
static constexpr rank_type rank_dynamic () noexcept { return 1; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr std::size_t static_extent (rank_type /*r*/) noexcept { return Std::dynamic_extent; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr index_type extent (const DynamicVector<K,Allocator>& tensor, rank_type /*r*/) noexcept { return tensor.size(); }
};
/** \brief Construct a vector with a dynamic size.
*
* \tparam K is the field type (use float, double, complex, etc)
......@@ -158,6 +186,21 @@ namespace Dune {
return _data[i];
}
//! random access with array of indices
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
K & operator[] (std::array<SizeType,1> i)
{
return _data[i[0]];
}
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
const K & operator[] (std::array<SizeType,1> i) const
{
return _data[i[0]];
}
//! return pointer to underlying array
K* data() noexcept
{
......
......@@ -19,6 +19,8 @@
#include <dune/common/promotiontraits.hh>
#include <dune/common/typetraits.hh>
#include <dune/common/matrixconcepts.hh>
#include <dune/common/tensortraits.hh>
#include <dune/common/std/extents.hh>
namespace Dune
{
......@@ -104,6 +106,30 @@ namespace Dune
typedef typename FieldTraits<K>::real_type real_type;
};
template< class K, int ROWS, int COLS >
struct TensorTraits< FieldMatrix<K,ROWS,COLS> >
{
using index_type = typename DenseMatVecTraits<FieldMatrix<K,ROWS,COLS>>::size_type;
using extents_type = Std::extents<std::size_t, std::size_t(ROWS),std::size_t(COLS)>;
using rank_type = typename extents_type::rank_type;
/// \brief Number of elements in all dimensions of the array, \related extents
static constexpr extents_type extents (const FieldMatrix<K,ROWS,COLS>& /*tensor*/) noexcept { return extents_type{}; }
/// \brief Number of dimensions of the array
static constexpr rank_type rank () noexcept { return 2; }
/// \brief Number of dimension with dynamic size
static constexpr rank_type rank_dynamic () noexcept { return 0; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr std::size_t static_extent (rank_type r) noexcept { return r == 0 ? ROWS : COLS; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr index_type extent (const FieldMatrix<K,ROWS,COLS>& /*tensor*/, rank_type r) noexcept { return r == 0 ? ROWS : COLS; }
};
/**
@brief A dense n x m matrix.
......
......@@ -19,8 +19,10 @@
#include <dune/common/ftraits.hh>
#include <dune/common/math.hh>
#include <dune/common/promotiontraits.hh>
#include <dune/common/tensortraits.hh>
#include <dune/common/typetraits.hh>
#include <dune/common/typeutilities.hh>
#include <dune/common/std/extents.hh>
namespace Dune {
......@@ -50,6 +52,30 @@ namespace Dune {
typedef typename FieldTraits<K>::real_type real_type;
};
template< class K, int SIZE >
struct TensorTraits< FieldVector<K,SIZE> >
{
using index_type = typename DenseMatVecTraits<FieldVector<K,SIZE>>::size_type;
using extents_type = Std::extents<std::size_t, std::size_t(SIZE)>;
using rank_type = typename extents_type::rank_type;
/// \brief Number of elements in all dimensions of the array, \related extents
static constexpr extents_type extents (const FieldVector<K,SIZE>& /*tensor*/) noexcept { return extents_type{}; }
/// \brief Number of dimensions of the array
static constexpr rank_type rank () noexcept { return 1; }
/// \brief Number of dimension with dynamic size
static constexpr rank_type rank_dynamic () noexcept { return 0; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr std::size_t static_extent (rank_type /*r*/) noexcept { return SIZE; }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr index_type extent (const FieldVector<K,SIZE>& /*tensor*/, rank_type /*r*/) noexcept { return SIZE; }
};
/**
* @brief TMP to check the size of a DenseVectors statically, if possible.
*
......@@ -209,6 +235,21 @@ namespace Dune {
return _data[i];
}
//! random access with array of indices
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
reference operator[] (std::array<SizeType,1> i)
{
return _data[i[0]];
}
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
const_reference operator[] (std::array<SizeType,1> i) const
{
return _data[i[0]];
}
//! Return pointer to underlying array
constexpr K* data () noexcept
{
......@@ -402,6 +443,21 @@ namespace Dune {
return _data;
}
//! random access with array of indices
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
reference operator[] (std::array<SizeType,1> i)
{
return _data;
}
template <class SizeType>
requires std::convertible_to<SizeType, size_type>
const_reference operator[] (std::array<SizeType,1> i) const
{
return _data;
}
//! return pointer to underlying array
constexpr K* data () noexcept
{
......
......@@ -14,6 +14,7 @@
#include <dune/common/integersequence.hh>
#include <dune/common/std/extents.hh>
#include <dune/common/rangeutilities.hh>
#include <dune/common/tensortraits.hh>
#include <dune/common/typetraits.hh>
#include <dune/common/concepts/tensor.hh>
#include <dune/common/std/extents.hh>
......@@ -103,20 +104,23 @@ for (int l = 0; l < a.extent(aSeq[0]); ++l)
template <std::size_t K = 0,
class A, class ASeq, class ASeqInv,
class B, class BSeq, class BSeqInv,
class C, class BinaryOp1, class BinaryOp2>
class C, class BinaryOp1, class BinaryOp2,
class ATraits = TensorTraits<A>,
class BTraits = TensorTraits<B>,
class CTraits = TensorTraits<C>>
constexpr DUNE_FORCE_INLINE
void tensorDotImpl (const A& a, ASeq aSeq, ASeqInv aSeqInv,
const B& b, BSeq bSeq, BSeqInv bSeqInv,
C& c, BinaryOp1 op1, BinaryOp2 op2,
std::array<typename A::index_type,A::rank()> aIndices = {},
std::array<typename B::index_type,B::rank()> bIndices = {},
std::array<typename C::index_type,C::rank()> cIndices = {})
std::array<typename ATraits::index_type,ATraits::rank()> aIndices = {},
std::array<typename BTraits::index_type,BTraits::rank()> bIndices = {},
std::array<typename CTraits::index_type,CTraits::rank()> cIndices = {})
{
if constexpr(aSeq.size() > 0 && bSeq.size() > 0) {
// first, loop over the contraction indices of A and B
constexpr std::size_t I = head(aSeq);
constexpr std::size_t J = head(bSeq);
for (typename A::index_type k = 0; k < a.extent(I); ++k) {
for (typename ATraits::index_type k = 0; k < ATraits::extent(a,I); ++k) {
aIndices[I] = k;
bIndices[J] = k;
tensorDotImpl<K>(a,tail(aSeq),aSeqInv,b,tail(bSeq),bSeqInv,c,op1,op2,aIndices,bIndices,cIndices);
......@@ -125,7 +129,7 @@ void tensorDotImpl (const A& a, ASeq aSeq, ASeqInv aSeqInv,
else if constexpr(aSeqInv.size() > 0) {
// second, loop over the remaining indices of tensor A
constexpr std::size_t I = head(aSeqInv);
for (typename A::index_type i = 0; i < a.extent(I); ++i) {
for (typename ATraits::index_type i = 0; i < ATraits::extent(a,I); ++i) {
aIndices[I] = i;
cIndices[K] = i;
tensorDotImpl<K+1>(a,aSeq,tail(aSeqInv),b,bSeq,bSeqInv,c,op1,op2,aIndices,bIndices,cIndices);
......@@ -134,7 +138,7 @@ void tensorDotImpl (const A& a, ASeq aSeq, ASeqInv aSeqInv,
else if constexpr(bSeqInv.size() > 0) {
// third, loop over the remaining indices of tensor B
constexpr std::size_t J = head(bSeqInv);
for (typename B::index_type j = 0; j < b.extent(J); ++j) {
for (typename BTraits::index_type j = 0; j < BTraits::extent(b,J); ++j) {
bIndices[J] = j;
cIndices[K] = j;
tensorDotImpl<K+1>(a,aSeq,aSeqInv,b,bSeq,tail(bSeqInv),c,op1,op2,aIndices,bIndices,cIndices);
......@@ -170,16 +174,20 @@ constexpr auto tensordotOut (const A& a, std::index_sequence<II...> aSeq,
{
static_assert(aSeq.size() == bSeq.size());
using ATraits = TensorTraits<A>;
using BTraits = TensorTraits<B>;
using CTraits = TensorTraits<C>;
// create integer sequences that do not include the contraction indices
const auto aSeqInv = difference<A::rank()>(aSeq); // {0,1,...A::rank()-1} \ {II...}
const auto bSeqInv = difference<B::rank()>(bSeq); // {0,1,...B::rank()-1} \ {JJ...}
static_assert(aSeqInv.size() + bSeqInv.size() == C::rank());
const auto aSeqInv = difference<ATraits::rank()>(aSeq); // {0,1,...A::rank()-1} \ {II...}
const auto bSeqInv = difference<BTraits::rank()>(bSeq); // {0,1,...B::rank()-1} \ {JJ...}
static_assert(aSeqInv.size() + bSeqInv.size() == CTraits::rank());
// the extents of a and the extents of b must be compatible to c
using EA = typename A::extents_type;
using EB = typename B::extents_type;
using EA = typename ATraits::extents_type;
using EB = typename BTraits::extents_type;
static_assert(Impl::checkStaticExtents<EA,EB>(aSeq, bSeq));
assert((Impl::checkExtents(a.extents(), aSeq, b.extents(), bSeq)));
assert((Impl::checkExtents(ATraits::extents(a), aSeq, BTraits::extents(b), bSeq)));
// Objects of `A` and `B` must be different from object `C`
assert((void*)(&a) != (void*)(&c) && (void*)(&b) != (void*)(&c));
......@@ -208,17 +216,21 @@ constexpr void tensordotOut (const A& a, const B& b, C& c,
std::integral_constant<std::size_t,N> axes = {},
BinaryOp1 op1 = {}, BinaryOp2 op2 = {})
{
using SeqI = typename StaticIntegralRange<std::size_t,A::rank(),A::rank()-N>::integer_sequence;
using InvSeqI = std::make_index_sequence<A::rank()-N>;
using ATraits = TensorTraits<A>;
using BTraits = TensorTraits<B>;
using CTraits = TensorTraits<C>;
using SeqI = typename StaticIntegralRange<std::size_t,ATraits::rank(),ATraits::rank()-N>::integer_sequence;
using InvSeqI = std::make_index_sequence<ATraits::rank()-N>;
using SeqJ = std::make_index_sequence<N>;
using InvSeqJ = typename StaticIntegralRange<std::size_t,B::rank(),N>::integer_sequence;
static_assert(InvSeqI::size() + InvSeqJ::size() == C::rank());
using InvSeqJ = typename StaticIntegralRange<std::size_t,BTraits::rank(),N>::integer_sequence;
static_assert(InvSeqI::size() + InvSeqJ::size() == CTraits::rank());
// the extents of a and the extents of b must be compatible to c
using EA = typename A::extents_type;
using EB = typename B::extents_type;
using EA = typename ATraits::extents_type;
using EB = typename BTraits::extents_type;
static_assert(Impl::checkStaticExtents<EA,EB>(SeqI{}, SeqJ{}));
assert((Impl::checkExtents(a.extents(), SeqI{}, b.extents(), SeqJ{})));
assert((Impl::checkExtents(ATraits::extents(a), SeqI{}, BTraits::extents(b), SeqJ{})));
// Objects of `A` and `B` must be different from object `C`
assert((void*)(&a) != (void*)(&c) && (void*)(&b) != (void*)(&c));
......@@ -252,23 +264,26 @@ constexpr auto tensordot (const A& a, std::index_sequence<II...> aSeq,
const B& b, std::index_sequence<JJ...> bSeq,
BinaryOp1 op1 = {}, BinaryOp2 op2 = {})
{
using ATraits = TensorTraits<A>;
using BTraits = TensorTraits<B>;
// the extents(II) of a and the extents(JJ) of b must match
using EA = typename A::extents_type;
using EB = typename B::extents_type;
using EA = typename ATraits::extents_type;
using EB = typename BTraits::extents_type;
static_assert(Impl::checkStaticExtents<EA,EB>(aSeq, bSeq));
assert((Impl::checkExtents(a.extents(), aSeq, b.extents(), bSeq)));
assert((Impl::checkExtents(ATraits::extents(a), aSeq, BTraits::extents(b), bSeq)));
// create integer sequences that do not include the contraction indices
const auto aSeqInv = difference<A::rank()>(aSeq); // {0,1,...A::rank()-1} \ {II...}
const auto bSeqInv = difference<B::rank()>(bSeq); // {0,1,...B::rank()-1} \ {JJ...}
const auto aSeqInv = difference<ATraits::rank()>(aSeq); // {0,1,...A::rank()-1} \ {II...}
const auto bSeqInv = difference<BTraits::rank()>(bSeq); // {0,1,...B::rank()-1} \ {JJ...}
// create result extents by collecting the extents of a and b that are not contracted
auto cExtents = Impl::concatExtents(
Impl::sliceExtents(a.extents(), aSeqInv),
Impl::sliceExtents(b.extents(), bSeqInv));
Impl::sliceExtents(ATraits::extents(a), aSeqInv),
Impl::sliceExtents(BTraits::extents(b), bSeqInv));
using VA = decltype(std::declval<A>()[std::array<typename A::index_type,A::rank()>{}]);
using VB = decltype(std::declval<B>()[std::array<typename B::index_type,B::rank()>{}]);
using VA = decltype(std::declval<A>()[std::array<typename ATraits::index_type,ATraits::rank()>{}]);
using VB = decltype(std::declval<B>()[std::array<typename BTraits::index_type,BTraits::rank()>{}]);
using V = std::invoke_result_t<BinaryOp2,VA,VB>;
auto c = Tensor{cExtents, V(0)};
Impl::tensorDotImpl(a,aSeq,aSeqInv,b,bSeq,bSeqInv,c,std::ref(op1),std::ref(op2));
......@@ -296,7 +311,8 @@ constexpr auto tensordot (const A& a, const B& b,
std::integral_constant<std::size_t,N> axes = {},
BinaryOp1 op1 = {}, BinaryOp2 op2 = {})
{
using SeqI = typename StaticIntegralRange<std::size_t,A::rank(),A::rank()-N>::integer_sequence;
using ATraits = TensorTraits<A>;
using SeqI = typename StaticIntegralRange<std::size_t,ATraits::rank(),ATraits::rank()-N>::integer_sequence;
using SeqJ = std::make_index_sequence<N>;
return tensordot(a,SeqI{},b,SeqJ{},op1,op2);
}
......
......@@ -651,6 +651,49 @@ constexpr bool operator== (const S& number, const TensorMixin<D,B>& rhs) noexcep
return number == rhs();
}
/** \brief Output stream overload for tensor types */
template <class D, class B>
std::ostream& operator<< (std::ostream& out, const Dune::TensorMixin<D,B>& tensor)
{
using extents_type = typename Dune::TensorMixin<D,B>::extents_type;
using index_type = typename Dune::TensorMixin<D,B>::index_type;
if constexpr(extents_type::rank() == 0) {
out << tensor();
} else if constexpr(extents_type::rank() == 1) {
out << "[";
for (index_type i = 0; i < tensor.extent(0); ++i)
out << tensor(i) << (i < tensor.extent(0)-1 ? ", " : "]");
} else if constexpr(extents_type::rank() == 2) {
out << "[\n";
for (index_type i = 0; i < tensor.extent(0); ++i) {
out << " [";
for (index_type j = 0; j < tensor.extent(1); ++j)
out << tensor(i,j) << (j < tensor.extent(1)-1 ? ", " : "]");
out << (i < tensor.extent(0)-1 ? ",\n" : "\n");
}
out << ']';
} else if constexpr(extents_type::rank() == 3) {
out << "[\n";
for (index_type i = 0; i < tensor.extent(0); ++i) {
out << " [\n";
for (index_type j = 0; j < tensor.extent(1); ++j) {
out << " [";
for (index_type k = 0; k < tensor.extent(2); ++k)
out << tensor(i,j,k) << (k < tensor.extent(2)-1 ? ", " : "]");
out << (j < tensor.extent(1)-1 ? ",\n" : "\n");
}
out << " ]";
out << (i < tensor.extent(0)-1 ? ",\n" : "\n");
}
out << ']';
} else {
out << "Tensor<" << extents_type::rank() << ">";
}
return out;
}
template <class D, class B>
struct FieldTraits< TensorMixin<D,B> >
{
......
// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
// SPDX-FileCopyrightInfo: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
#ifndef DUNE_COMMON_TENSORTRAITS_HH
#define DUNE_COMMON_TENSORTRAITS_HH
#include <cstddef>
namespace Dune {
template <class T>
struct TensorTraits
{
using extents_type = typename T::extents_type;
using index_type = typename extents_type::index_type;
using rank_type = typename extents_type::rank_type;
/// \brief Number of elements in all dimensions of the array, \related extents
static constexpr const extents_type& extents (const T& tensor) noexcept { return tensor.extents(); }
/// \brief Number of dimensions of the array
static constexpr rank_type rank () noexcept { return extents_type::rank(); }
/// \brief Number of dimension with dynamic size
static constexpr rank_type rank_dynamic () noexcept { return extents_type::rank_dynamic(); }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr std::size_t static_extent (rank_type r) noexcept { return extents_type::static_extent(r); }
/// \brief Number of elements in the r'th dimension of the tensor
static constexpr index_type extent (const T& tensor, rank_type r) noexcept { return extents(tensor).extent(r); }
};
} // end namespace Dune
#endif // DUNE_COMMON_TENSORTRAITS_HH
......@@ -3,6 +3,10 @@
// SPDX-FileCopyrightInfo: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
#include <dune/common/dynmatrix.hh>
#include <dune/common/dynvector.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/fvector.hh>
#include <dune/common/tensor.hh>
#include <dune/common/tensordot.hh>
#include <dune/common/tensorspan.hh>
......@@ -23,6 +27,11 @@ int main(int argc, char** argv)
auto dTensor32 = Dune::Tensor<double,Dune::dynamic,Dune::dynamic>{3,2};
auto dTensor234 = Dune::Tensor<double,Dune::dynamic,Dune::dynamic,Dune::dynamic>{2,3,4};
auto dVector2 = Dune::DynamicVector<double>{2};
auto dVector3 = Dune::DynamicVector<double>{3};
auto dMatrix23 = Dune::DynamicMatrix<double>{2,3};
auto dMatrix32 = Dune::DynamicMatrix<double>{3,2};
auto fTensor = Dune::Tensor<double>{};
auto fTensor2 = Dune::Tensor<double,2>{};
auto fTensor3 = Dune::Tensor<double,3>{};
......@@ -30,6 +39,11 @@ int main(int argc, char** argv)
auto fTensor32 = Dune::Tensor<double,3,2>{};
auto fTensor234 = Dune::Tensor<double,2,3,4>{};
auto fVector2 = Dune::FieldVector<double,2>{};
auto fVector3 = Dune::FieldVector<double,3>{};
auto fMatrix23 = Dune::FieldMatrix<double,2,3>{};
auto fMatrix32 = Dune::FieldMatrix<double,3,2>{};
// test dynamic tensors
{
auto d = tensordot<0>(dTensor,dTensor);
......@@ -54,6 +68,11 @@ int main(int argc, char** argv)
testSuite.check(d23.extent(1) == 3);
}
{
auto d = tensordot<1>(dVector2,dVector2);
testSuite.check(d.rank() == 0);
}
{
auto d = tensordot<2>(dTensor23,dTensor23);
testSuite.check(d.rank() == 0);
......@@ -101,6 +120,43 @@ int main(int argc, char** argv)
testSuite.check(d2332.extent(3) == 2);
}
{
auto d = tensordot<2>(dMatrix23,dMatrix23);
testSuite.check(d.rank() == 0);
auto d2 = tensordot<1>(dMatrix23,dVector3);
testSuite.check(d2.rank() == 1);
testSuite.check(d2.extent(0) == 2);
auto d3 = tensordot<1>(dVector2,dMatrix23);
testSuite.check(d3.rank() == 1);
testSuite.check(d3.extent(0) == 3);
auto d22 = tensordot<1>(dMatrix23,dMatrix32);
testSuite.check(d22.rank() == 2);
testSuite.check(d22.extent(0) == 2);
testSuite.check(d22.extent(1) == 2);
auto d223 = tensordot<0>(dVector2,dMatrix23);
testSuite.check(d223.rank() == 3);
testSuite.check(d223.extent(0) == 2);
testSuite.check(d223.extent(1) == 2);
testSuite.check(d223.extent(2) == 3);
auto d233 = tensordot<0>(dMatrix23,dVector3);
testSuite.check(d233.rank() == 3);
testSuite.check(d233.extent(0) == 2);
testSuite.check(d233.extent(1) == 3);
testSuite.check(d233.extent(2) == 3);
auto d2332 = tensordot<0>(dMatrix23,dMatrix32);
testSuite.check(d2332.rank() == 4);
testSuite.check(d2332.extent(0) == 2);
testSuite.check(d2332.extent(1) == 3);
testSuite.check(d2332.extent(2) == 3);
testSuite.check(d2332.extent(3) == 2);
}
// test mixed static/dynamic tensors
{
auto d = tensordot<0>(fTensor,dTensor);
......@@ -125,6 +181,16 @@ int main(int argc, char** argv)
testSuite.check(d23.extent(1) == 3);
}
{
auto d = tensordot<1>(fVector2,dVector2);
testSuite.check(d.rank() == 0);
auto d23 = tensordot<0>(fVector2,dVector3);
testSuite.check(d23.rank() == 2);
testSuite.check(d23.extent(0) == 2);
testSuite.check(d23.extent(1) == 3);
}
{
auto d = tensordot<2>(fTensor23,dTensor23);
testSuite.check(d.rank() == 0);
......@@ -172,6 +238,24 @@ int main(int argc, char** argv)
testSuite.check(d2332.extent(3) == 2);
}
{
auto d = tensordot<2>(fMatrix23,dTensor23);
testSuite.check(d.rank() == 0);
auto d2 = tensordot<1>(fMatrix23,dTensor3);
testSuite.check(d2.rank() == 1);
testSuite.check(d2.extent(0) == 2);
auto d3 = tensordot<1>(fVector2,dTensor23);
testSuite.check(d3.rank() == 1);
testSuite.check(d3.extent(0) == 3);
auto d22 = tensordot<1>(fMatrix23,dTensor32);
testSuite.check(d22.rank() == 2);
testSuite.check(d22.extent(0) == 2);
testSuite.check(d22.extent(1) == 2);
}
// test interaction with TensorSpan
{
auto dMat0 = tensordot<1>(fTensor23,dTensor32.toTensorSpan());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment