Skip to content
Snippets Groups Projects
Commit 721eae1d authored by Jö Fahlke's avatar Jö Fahlke
Browse files

[Simd] Introduce replacement functions for && and ||.

These operators already had shaky support in Vc.  Vectorclass does not support
them at all.
parent 63f6a7c9
Branches
Tags
1 merge request!193Extended SIMD interface
Pipeline #
......@@ -14,6 +14,7 @@
#include <cstddef>
#include <dune/common/rangeutilities.hh>
#include <dune/common/simd/base.hh>
#include <dune/common/simd/interface.hh>
......@@ -112,6 +113,36 @@ namespace Dune {
return m;
}
//! implements Simd::mask()
template<class V>
Mask<V> mask(ADLTag<0, std::is_same<V, Mask<V> >::value>,
const V &v)
{
return v;
}
//! implements Simd::mask()
template<class V>
Mask<V> mask(ADLTag<0, !std::is_same<V, Mask<V> >::value>,
const V &v)
{
return v != V(Scalar<V>(0));
}
//! implements Simd::maskOr()
template<class V1, class V2>
auto maskOr(ADLTag<0>, const V1 &v1, const V2 &v2)
{
return Simd::mask(v1) || Simd::mask(v2);
}
//! implements Simd::maskAnd()
template<class V1, class V2>
auto maskAnd(ADLTag<0>, const V1 &v1, const V2 &v2)
{
return Simd::mask(v1) && Simd::mask(v2);
}
//! @} Overloadable and default functions
//! @} Group SIMDAbstract
} // namespace Overloads
......
......@@ -405,6 +405,36 @@ namespace Dune {
return min(Overloads::ADLTag<6>{}, v);
}
//! Convert to mask, analogue of bool(s) for scalars
/**
* Implemented by `Overloads::mask()`.
*/
template<class V>
auto mask(const V &v)
{
return mask(Overloads::ADLTag<6>{}, v);
}
//! Logic or of masks
/**
* Implemented by `Overloads::maskOr()`.
*/
template<class V1, class V2>
auto maskOr(const V1 &v1, const V2 &v2)
{
return maskOr(Overloads::ADLTag<6>{}, v1, v2);
}
//! Logic and of masks
/**
* Implemented by `Overloads::maskAnd()`.
*/
template<class V1, class V2>
auto maskAnd(const V1 &v1, const V2 &v2)
{
return maskAnd(Overloads::ADLTag<6>{}, v1, v2);
}
//! @}
/** @name Syntactic Sugar
......
......@@ -18,6 +18,7 @@
#include <dune/common/classname.hh>
#include <dune/common/simd/simd.hh>
#include <dune/common/std/type_traits.hh>
#include <dune/common/typetraits.hh>
#include <dune/common/unused.hh>
......@@ -211,6 +212,9 @@ namespace Dune {
return T(5);
}
template<class Op, class... Args>
using ScalarResult =
decltype(std::declval<Op>().scalar(std::declval<Args>()...));
template<class Call>
using CanCall = Impl::CanCall<Call>;
......@@ -598,11 +602,32 @@ namespace Dune {
struct OpInfix##NAME \
{ \
template<class V1, class V2> \
auto operator()(V1&& v1, V2&& v2) const \
-> decltype(std::forward<V1>(v1) SYMBOL std::forward<V2>(v2)) \
decltype(auto) vector(V1&& v1, V2&& v2) const \
{ \
return std::forward<V1>(v1) SYMBOL std::forward<V2>(v2); \
} \
template<class S1, class S2> \
auto scalar(S1&& s1, S2&& s2) const \
-> decltype(std::forward<S1>(s1) SYMBOL std::forward<S2>(s2)) \
{ \
return std::forward<S1>(s1) SYMBOL std::forward<S2>(s2); \
} \
}
#define DUNE_SIMD_REPL_OP(NAME, REPLFN, SYMBOL) \
struct OpInfix##NAME \
{ \
template<class V1, class V2> \
decltype(auto) vector(V1&& v1, V2&& v2) const \
{ \
return Simd::REPLFN(std::forward<V1>(v1), std::forward<V2>(v2)); \
} \
template<class S1, class S2> \
auto scalar(S1&& s1, S2&& s2) const \
-> decltype(std::forward<S1>(s1) SYMBOL std::forward<S2>(s2)) \
{ \
return std::forward<S1>(s1) SYMBOL std::forward<S2>(s2); \
} \
}
DUNE_SIMD_INFIX_OP(Mul, * );
......@@ -627,8 +652,10 @@ namespace Dune {
DUNE_SIMD_INFIX_OP(BitXor, ^ );
DUNE_SIMD_INFIX_OP(BitOr, | );
DUNE_SIMD_INFIX_OP(LogicAnd, && );
DUNE_SIMD_INFIX_OP(LogicOr, || );
// Those are not supported in any meaningful way by vectorclass
// We need to test replacement functions maskAnd() and maskOr() instead.
DUNE_SIMD_REPL_OP(LogicAnd, maskAnd, && );
DUNE_SIMD_REPL_OP(LogicOr, maskOr, || );
DUNE_SIMD_INFIX_OP(Assign, = );
DUNE_SIMD_INFIX_OP(AssignMul, *= );
......@@ -643,6 +670,7 @@ namespace Dune {
DUNE_SIMD_INFIX_OP(AssignOr, |= );
#undef DUNE_SIMD_INFIX_OP
#undef DUNE_SIMD_REPL_OP
// just used as a tag
struct OpInfixComma {};
......@@ -697,8 +725,9 @@ namespace Dune {
template<class V1, class V2, class Op>
std::enable_if_t<
CanCall<Op(decltype(lane(0, std::declval<V1>())),
decltype(lane(0, std::declval<V2>())))>::value>
Std::is_detected_v<ScalarResult, Op,
decltype(lane(0, std::declval<V1>())),
decltype(lane(0, std::declval<V2>()))> >
checkBinaryOpVV(Op op)
{
#define DUNE_SIMD_OPNAME (className<Op(V1, V2)>())
......@@ -713,7 +742,8 @@ namespace Dune {
// copy the arguments in case V1 or V2 are references
auto arg1 = val1;
auto arg2 = val2;
auto &&result = op(static_cast<V1>(arg1), static_cast<V2>(arg2));
auto &&result =
op.vector(static_cast<V1>(arg1), static_cast<V2>(arg2));
using T = Scalar<std::decay_t<decltype(result)> >;
for(std::size_t l = 0; l < lanes(val1); ++l)
{
......@@ -721,8 +751,8 @@ namespace Dune {
// `static_cast` around the `op()` is necessary
DUNE_SIMD_CHECK_OP
(lane(l, result)
== static_cast<T>(op(lane(l, static_cast<V1>(val1)),
lane(l, static_cast<V2>(val2)))));
== static_cast<T>(op.scalar(lane(l, static_cast<V1>(val1)),
lane(l, static_cast<V2>(val2)))));
}
// op might modify val1 and val2, verify that any such
// modification also happens in the vector case
......@@ -736,8 +766,9 @@ namespace Dune {
template<class V1, class V2, class Op>
std::enable_if_t<
!CanCall<Op(decltype(lane(0, std::declval<V1>())),
decltype(lane(0, std::declval<V2>())))>::value>
!Std::is_detected_v<ScalarResult, Op,
decltype(lane(0, std::declval<V1>())),
decltype(lane(0, std::declval<V2>()))> >
checkBinaryOpVV(Op op)
{
// log_ << "No " << className<Op(decltype(lane(0, std::declval<V1>())),
......@@ -788,7 +819,8 @@ namespace Dune {
template<class T1, class V2, class Op>
std::enable_if_t<
CanCall<Op(T1, decltype(lane(0, std::declval<V2>())))>::value>
Std::is_detected_v<ScalarResult, Op, T1,
decltype(lane(0, std::declval<V2>()))> >
checkBinaryOpSV(Op op)
{
#define DUNE_SIMD_OPNAME (className<Op(T1, V2)>())
......@@ -812,9 +844,11 @@ namespace Dune {
auto varg1 = vval1;
auto varg2 = vval2;
auto &&sresult = op(static_cast<T1>(sarg1), static_cast<V2>(sarg2));
auto &&sresult =
op.vector(static_cast<T1>(sarg1), static_cast<V2>(sarg2));
using TS = Scalar<std::decay_t<decltype(sresult)> >;
auto &&vresult = op(static_cast<V1>(varg1), static_cast<V2>(varg2));
auto &&vresult =
op.vector(static_cast<V1>(varg1), static_cast<V2>(varg2));
using TV = Scalar<std::decay_t<decltype(vresult)> >;
for(std::size_t l = 0; l < lanes<std::decay_t<V1> >(); ++l)
{
......@@ -822,12 +856,12 @@ namespace Dune {
// `static_cast` around the `op()` is necessary
DUNE_SIMD_CHECK_OP
(lane(l, sresult)
== static_cast<TS>(op( static_cast<T1>(sval1),
lane(l, static_cast<V2>(sval2)))));
== static_cast<TS>(op.scalar( static_cast<T1>(sval1),
lane(l, static_cast<V2>(sval2)))));
DUNE_SIMD_CHECK_OP
(lane(l, vresult)
== static_cast<TV>(op(lane(l, static_cast<V1>(vval1)),
lane(l, static_cast<V2>(vval2)))));
== static_cast<TV>(op.scalar(lane(l, static_cast<V1>(vval1)),
lane(l, static_cast<V2>(vval2)))));
// cross check
DUNE_SIMD_CHECK_OP(lane(l, sresult) == lane(l, vresult));
}
......@@ -848,7 +882,8 @@ namespace Dune {
template<class T1, class V2, class Op>
std::enable_if_t<
!CanCall<Op(T1, decltype(lane(0, std::declval<V2>())))>::value>
!Std::is_detected_v<ScalarResult, Op, T1,
decltype(lane(0, std::declval<V2>()))> >
checkBinaryOpSV(Op op)
{
// log_ << "No "
......@@ -902,7 +937,8 @@ namespace Dune {
template<class V1, class T2, class Op>
std::enable_if_t<
CanCall<Op(decltype(lane(0, std::declval<V1>())), T2)>::value>
Std::is_detected_v<ScalarResult, Op,
decltype(lane(0, std::declval<V1>())), T2> >
checkBinaryOpVS(Op op)
{
#define DUNE_SIMD_OPNAME (className<Op(V1, T2)>())
......@@ -926,9 +962,11 @@ namespace Dune {
auto varg1 = vval1;
auto varg2 = vval2;
auto &&sresult = op(static_cast<V1>(sarg1), static_cast<T2>(sarg2));
auto &&sresult =
op.vector(static_cast<V1>(sarg1), static_cast<T2>(sarg2));
using TS = Scalar<std::decay_t<decltype(sresult)> >;
auto &&vresult = op(static_cast<V1>(varg1), static_cast<V2>(varg2));
auto &&vresult =
op.vector(static_cast<V1>(varg1), static_cast<V2>(varg2));
using TV = Scalar<std::decay_t<decltype(vresult)> >;
for(std::size_t l = 0; l < lanes<std::decay_t<V1> >(); ++l)
{
......@@ -936,12 +974,12 @@ namespace Dune {
// `static_cast` around the `op()` is necessary
DUNE_SIMD_CHECK_OP
(lane(l, sresult)
== static_cast<TS>(op(lane(l, static_cast<V1>(sval1)),
static_cast<T2>(sval2) )));
== static_cast<TS>(op.scalar(lane(l, static_cast<V1>(sval1)),
static_cast<T2>(sval2) )));
DUNE_SIMD_CHECK_OP
(lane(l, vresult)
== static_cast<TV>(op(lane(l, static_cast<V1>(vval1)),
lane(l, static_cast<V2>(vval2)))));
== static_cast<TV>(op.scalar(lane(l, static_cast<V1>(vval1)),
lane(l, static_cast<V2>(vval2)))));
// cross check
DUNE_SIMD_CHECK_OP(lane(l, sresult) == lane(l, vresult));
}
......@@ -962,7 +1000,8 @@ namespace Dune {
template<class V1, class T2, class Op>
std::enable_if_t<
!CanCall<Op(decltype(lane(0, std::declval<V1>())), T2)>::value>
!Std::is_detected_v<ScalarResult, Op,
decltype(lane(0, std::declval<V1>())), T2> >
checkBinaryOpVS(Op op)
{
// log_ << "No "
......@@ -1060,8 +1099,8 @@ namespace Dune {
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixBitXor );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixBitOr );
DUNE_SIMD_BINARY_OPCHECK( , , , InfixLogicAnd );
DUNE_SIMD_BINARY_OPCHECK( , , , InfixLogicOr );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixLogicAnd );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixLogicOr );
DUNE_SIMD_BINARY_OPCHECK( , VV, VS, InfixAssign );
DUNE_SIMD_BINARY_OPCHECK( , VV, VS, InfixAssignMul );
......@@ -1117,8 +1156,8 @@ namespace Dune {
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixBitXor );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixBitOr );
DUNE_SIMD_BINARY_OPCHECK( , , , InfixLogicAnd );
DUNE_SIMD_BINARY_OPCHECK( , , , InfixLogicOr );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixLogicAnd );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixLogicOr );
DUNE_SIMD_BINARY_OPCHECK( , VV, VS, InfixAssign );
DUNE_SIMD_BINARY_OPCHECK( , VV, VS, InfixAssignMul );
......@@ -1176,8 +1215,8 @@ namespace Dune {
DUNE_SIMD_BINARY_OPCHECK( , VV, , InfixBitXor );
DUNE_SIMD_BINARY_OPCHECK( , VV, , InfixBitOr );
DUNE_SIMD_BINARY_OPCHECK( , VV, , InfixLogicAnd );
DUNE_SIMD_BINARY_OPCHECK( , VV, , InfixLogicOr );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixLogicAnd );
DUNE_SIMD_BINARY_OPCHECK(SV, VV, VS, InfixLogicOr );
DUNE_SIMD_BINARY_OPCHECK( , VV, , InfixAssign );
DUNE_SIMD_BINARY_OPCHECK( , , , InfixAssignMul );
......@@ -1536,6 +1575,9 @@ namespace Dune {
"must not be references, and must not include "
"cv-qualifiers");
static_assert(std::is_same<M, Mask<M> >::value,
"Mask must be their own mask types.");
// check whether the test for this type already started
if(maskSeen_.emplace(typeid (M)).second == false)
{
......
......@@ -500,6 +500,42 @@ namespace Dune {
return !Vc::any_of(!mask);
}
//! implements Simd::maskAnd()
template<class S1, class V2>
auto maskAnd(ADLTag<5, std::is_same<Mask<S1>, bool>::value &&
VcImpl::IsVector<V2>::value>,
const S1 &s1, const V2 &v2)
{
return Simd::Mask<V2>(Simd::mask(s1)) && Simd::mask(v2);
}
//! implements Simd::maskAnd()
template<class V1, class S2>
auto maskAnd(ADLTag<5, VcImpl::IsVector<V1>::value &&
std::is_same<Mask<S2>, bool>::value>,
const V1 &v1, const S2 &s2)
{
return Simd::mask(v1) && Simd::Mask<V1>(Simd::mask(s2));
}
//! implements Simd::maskOr()
template<class S1, class V2>
auto maskOr(ADLTag<5, std::is_same<Mask<S1>, bool>::value &&
VcImpl::IsVector<V2>::value>,
const S1 &s1, const V2 &v2)
{
return Simd::Mask<V2>(Simd::mask(s1)) || Simd::mask(v2);
}
//! implements Simd::maskOr()
template<class V1, class S2>
auto maskOr(ADLTag<5, VcImpl::IsVector<V1>::value &&
std::is_same<Mask<S2>, bool>::value>,
const V1 &v1, const S2 &s2)
{
return Simd::mask(v1) || Simd::Mask<V1>(Simd::mask(s2));
}
//! @} group SIMDVc
} // namespace Overloads
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment