...
 
Commits (15)
auto f = [](Coord x, double u) -> int { return c(x)*u*u; };
auto u = U(x);
f.bind(e);
auto r = f(x,u);
derivative(f) == derivative(f, Component<0>());
auto dfdx == derivative(f, Component<0>());
auto dfdx == derivative(f, _1);
auto dfdu == derivative(f, Component<1>());
auto dfdu == derivative(f, _2);
auto fx = std::bind(f, _1, u);
auto fx_eval_u = std::bind(f, _1, std::bind(U,_1));
auto fu = std::bind(f, x, _1);
auto fux = std::bind(f, _2, _1);
f(x,u) == fu(x) == fx(u) == fux(u,x);
f(x,U(x)) == f(x,u) == fx_eval_u(x);
......@@ -44,7 +44,7 @@ public:
return y;
}
friend Polynomial derivative(const Polynomial& p)
friend Polynomial derivative(const Polynomial& p, DefaultDerivativeDirection = derivativeDirection::_default)
{
auto derivative = Polynomial();
derivative.coefficients_.resize(p.coefficients_.size()-1);
......
......@@ -21,7 +21,8 @@ public:
template<class K, int sinFactor, int cosFactor>
TrigonometricFunction<K, -cosFactor, sinFactor> derivative(const TrigonometricFunction<K, sinFactor, cosFactor>& f)
TrigonometricFunction<K, -cosFactor, sinFactor> derivative(const TrigonometricFunction<K, sinFactor, cosFactor>& f,
DefaultDerivativeDirection = derivativeDirection::_default)
{
return TrigonometricFunction<K, -cosFactor, sinFactor>();
}
......
// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_FUNCTIONS_COMMON_DERIVATIVEDIRECTION_HH
#define DUNE_FUNCTIONS_COMMON_DERIVATIVEDIRECTION_HH
#include <dune/functions/common/type_traits.hh>
#include "concept.hh"
namespace Dune {
namespace Functions {
template<int D>
struct DerivativeDirection :
public std::integral_constant<int, D>
{};
namespace derivativeDirection
{
extern DerivativeDirection<1> _default;
extern DerivativeDirection<1> _d1;
extern DerivativeDirection<2> _d2;
extern DerivativeDirection<3> _d3;
extern DerivativeDirection<4> _d4;
extern DerivativeDirection<5> _d5;
extern DerivativeDirection<6> _d6;
extern DerivativeDirection<7> _d7;
extern DerivativeDirection<8> _d8;
extern DerivativeDirection<9> _d9;
extern DerivativeDirection<10> _d10;
extern DerivativeDirection<10> _dN;
}
using DefaultDerivativeDirection = DerivativeDirection<1>;
/**
* A concept describing types that have a derivative(f,dir) method found by ADL
*/
template<int D>
struct HasFreeDerivative
{
template<class F>
auto require(F&& f) -> decltype(
derivative(f,DerivativeDirection<D>())
);
};
}
}
#endif // DUNE_FUNCTIONS_COMMON_DERIVATIVEDIRECTION_HH
......@@ -7,6 +7,7 @@
#include <dune/functions/common/type_traits.hh>
#include <dune/functions/common/defaultderivativetraits.hh>
#include <dune/functions/common/derivativedirection.hh>
#include <dune/functions/common/differentiablefunction_imp.hh>
#include <dune/functions/common/polymorphicsmallobject.hh>
#include <dune/functions/common/concept.hh>
......@@ -15,8 +16,6 @@
namespace Dune {
namespace Functions {
/**
* Default implementation is empty
* The actual implementation is only given if Signature is an type
......@@ -32,30 +31,34 @@ class DifferentiableFunction
* \brief Class storing differentiable functions using type erasure
*
*/
template<class Range, class Domain, template<class> class DerivativeTraits, size_t bufferSize>
class DifferentiableFunction< Range(Domain), DerivativeTraits, bufferSize>
template<typename Range, typename... Domain, template<class> class DerivativeTraits, size_t bufferSize>
class DifferentiableFunction< Range(Domain...), DerivativeTraits, bufferSize>
{
public:
/**
* \brief Signature of wrapped functions
*/
using Signature = Range(Domain);
/**
* \brief Raw signature of wrapped functions without possible const and reference qualifiers
*/
using RawSignature = typename SignatureTraits<Signature>::RawSignature;
using Signature = Range(Domain...);
/**
* \brief Signature of derivative of wrapped functions
*/
using DerivativeSignature = typename DerivativeTraits<RawSignature>::Range(Domain);
template<int P>
struct PartialDomain
{
static_assert( P <= sizeof...(Domain), "Derivative direction is greater than the number of parameters in the signature");
using type = typename std::tuple_element< P-1, std::tuple<Domain...> >::type;
};
template<typename P>
struct DerivativeInterface
{
using PartialDomain = P;
using PartialSignature = Range(PartialDomain);
using RawSignature = typename SignatureTraits<PartialSignature>::RawSignature;
using DerivativeRange = typename DerivativeTraits<RawSignature>::Range;
using DerivativeSignature = DerivativeRange(Domain...);
using type = DifferentiableFunction<DerivativeSignature, DerivativeTraits, bufferSize>;
};
/**
* \brief Wrapper type of returned derivatives
*/
using DerivativeInterface = DifferentiableFunction<DerivativeSignature, DerivativeTraits, bufferSize>;
using DerivativeInterfaces = std::tuple<typename DerivativeInterface<Domain>::type...>;
/**
* \brief Construct from function
......@@ -70,7 +73,7 @@ public:
*/
template<class F, disableCopyMove<DifferentiableFunction, F> = 0 >
DifferentiableFunction(F&& f) :
f_(Imp::DifferentiableFunctionWrapper<Signature, DerivativeInterface, typename std::decay<F>::type>(std::forward<F>(f)))
f_(Imp::DifferentiableFunctionWrapper<Signature, DerivativeInterfaces, typename std::decay<F>::type>(std::forward<F>(f)))
{}
DifferentiableFunction() = default;
......@@ -78,9 +81,9 @@ public:
/**
* \brief Evaluation of wrapped function
*/
Range operator() (const Domain& x) const
Range operator() (const Domain&... x) const
{
return f_.get().operator()(x);
return f_.get().operator()(x...);
}
/**
......@@ -88,13 +91,16 @@ public:
*
* This is a free function that will be found by ADL.
*/
friend DerivativeInterface derivative(const DifferentiableFunction& t)
template<int D = 1>
friend
typename DerivativeInterface< typename PartialDomain<D>::type >::type
derivative(const DifferentiableFunction& t, DerivativeDirection<D> dir = DerivativeDirection<D>())
{
return t.f_.get().derivative();
return t.f_.get().derivative(dir);
}
private:
PolymorphicSmallObject<Imp::DifferentiableFunctionWrapperBase<Signature, DerivativeInterface>, bufferSize > f_;
PolymorphicSmallObject<Imp::DifferentiableFunctionWrapperBase<Signature, DerivativeInterfaces>, bufferSize > f_;
};
......
......@@ -3,10 +3,14 @@
#ifndef DUNE_FUNCTIONS_COMMON_DIFFERENTIABLE_FUNCTION_IMP_HH
#define DUNE_FUNCTIONS_COMMON_DIFFERENTIABLE_FUNCTION_IMP_HH
#include <tuple>
#include <dune/common/exceptions.hh>
#include <dune/functions/common/signature.hh>
#include <dune/functions/common/type_traits.hh>
#include <dune/functions/common/interfaces.hh>
#include <dune/functions/common/derivativedirection.hh>
#include "concept.hh"
......@@ -14,64 +18,114 @@ namespace Dune {
namespace Functions {
namespace Imp {
/**
* A concept describing types that have a derivative() method found by ADL
*/
struct HasFreeDerivative
{
template<class F>
auto require(F&& f) -> decltype(
derivative(f)
);
};
template<class Dummy, class F,
template<class Dummy, class F, int D,
typename std::enable_if<
Dune::Functions::Concept::models< HasFreeDerivative, F>() , int>::type = 0>
auto derivativeIfImplemented(const F& f) -> decltype(derivative(f))
Dune::Functions::Concept::models< HasFreeDerivative<D>, F>() , int>::type = 0>
auto derivativeIfImplemented(const F& f, DerivativeDirection<D> d) -> decltype(derivative(f,d))
{
return derivative(f);
return derivative(f,d);
}
template<class Dummy, class F,
template<class Dummy, class F, int D,
typename std::enable_if<
not(Dune::Functions::Concept::models< HasFreeDerivative, F>()) , int>::type = 0>
Dummy derivativeIfImplemented(const F& f)
not(Dune::Functions::Concept::models< HasFreeDerivative<D>, F>()) , int>::type = 0>
Dummy derivativeIfImplemented(const F& f, DerivativeDirection<D> d)
{
DUNE_THROW(Dune::NotImplemented, "Derivative not implemented");
}
template<typename DerivativeInterfaces, int P>
class PartialDerivativeWrapperBase;
template<typename... DerivativeInterfaces>
class PartialDerivativeWrapperBase<std::tuple<DerivativeInterfaces...>, 0> {
public:
void derivative() const {};
};
template<typename... DerivativeInterfaces, int P>
class PartialDerivativeWrapperBase<std::tuple<DerivativeInterfaces...>, P> :
public PartialDerivativeWrapperBase<std::tuple<DerivativeInterfaces...>, P-1>
{
/**
* \brief Wrapper type of returned derivatives
*/
using DerivativeInterface = typename std::tuple_element< P-1, std::tuple<DerivativeInterfaces...> >::type;
public:
using PartialDerivativeWrapperBase<std::tuple<DerivativeInterfaces...>, P-1>::derivative;
template<class Signature, class DerivativeInterface>
class DifferentiableFunctionWrapperBase
{};
/**
* \brief compute partial derivative wrt P'th parameter
*/
virtual DerivativeInterface derivative(DerivativeDirection<P> d) const = 0;
template<class Range, class Domain, class DerivativeInterface>
class DifferentiableFunctionWrapperBase<Range(Domain), DerivativeInterface> :
public PolymorphicType<DifferentiableFunctionWrapperBase<Range(Domain), DerivativeInterface> >
};
template<class Signature, class DerivativePack>
class DifferentiableFunctionWrapperBase;
template<typename Range, typename... Domain, typename... DerivativeInterfaces>
class DifferentiableFunctionWrapperBase<Range(Domain...), std::tuple<DerivativeInterfaces...> > :
public PolymorphicType<DifferentiableFunctionWrapperBase<Range(Domain...), std::tuple<DerivativeInterfaces...> > >,
public PartialDerivativeWrapperBase<std::tuple<DerivativeInterfaces...>, sizeof...(DerivativeInterfaces)>
{
static_assert( sizeof...(Domain) == sizeof...(DerivativeInterfaces), "Type Mismatch");
public:
virtual Range operator() (const Domain&... x) const = 0;
};
virtual Range operator() (const Domain& x) const = 0;
virtual DerivativeInterface derivative() const = 0;
template<typename Signature, typename DerivativeInterfaces, int P, class WrapperImp>
class PartialDerivativeWrapper;
template<typename Signature, typename... DerivativeInterfaces, class WrapperImp>
class PartialDerivativeWrapper<Signature, std::tuple<DerivativeInterfaces...>, 0, WrapperImp> :
public DifferentiableFunctionWrapperBase<Signature, std::tuple<DerivativeInterfaces...> >
{
public:
void derivative() const {};
};
template<typename Signanture, typename... DerivativeInterfaces, int P, class WrapperImp>
class PartialDerivativeWrapper<Signanture, std::tuple<DerivativeInterfaces...>, P, WrapperImp> :
public PartialDerivativeWrapper<Signanture, std::tuple<DerivativeInterfaces...>, P-1, WrapperImp>
{
/**
* \brief Wrapper type of returned derivatives
*/
using DerivativeInterface = typename std::tuple_element< P-1, std::tuple<DerivativeInterfaces...> >::type;
public:
using PartialDerivativeWrapper<Signanture, std::tuple<DerivativeInterfaces...>, P-1, WrapperImp>::derivative;
/**
* \brief compute partial derivative wrt P'th parameter
*/
virtual DerivativeInterface derivative(DerivativeDirection<P> d) const
{
auto f_ = static_cast<const WrapperImp*>(this)->f_;
using FImp = decltype(f_);
return derivativeIfImplemented<DerivativeInterface, FImp>(f_,d);
};
};
template<class Signature, class DerivativeInterface, class FImp>
class DifferentiableFunctionWrapper
{};
template<class Signature, class DerivativeInterfaces, class FImp>
class DifferentiableFunctionWrapper;
template<class Range, class Domain, class DerivativeInterface, class FImp>
class DifferentiableFunctionWrapper< Range(Domain), DerivativeInterface, FImp> :
public DifferentiableFunctionWrapperBase<Range(Domain), DerivativeInterface>
template<typename Range, typename... Domain, typename... DerivativeInterfaces, class FImp>
class DifferentiableFunctionWrapper< Range(Domain...), std::tuple<DerivativeInterfaces...>, FImp> :
public PartialDerivativeWrapper< Range(Domain...), std::tuple<DerivativeInterfaces...>, sizeof...(DerivativeInterfaces),
DifferentiableFunctionWrapper< Range(Domain...), std::tuple<DerivativeInterfaces...>, FImp> >
{
static_assert( sizeof...(Domain) == sizeof...(DerivativeInterfaces), "Type Mismatch");
static_assert( sizeof...(Domain) > 0, "Type Mismatch");
public:
template<class F, disableCopyMove<DifferentiableFunctionWrapper, F> = 0>
......@@ -79,14 +133,9 @@ public:
f_(std::forward<F>(f))
{}
virtual Range operator() (const Domain& x) const
{
return f_(x);
}
virtual DerivativeInterface derivative() const
virtual Range operator() (const Domain&... x) const
{
return derivativeIfImplemented<DerivativeInterface, FImp>(f_);
return f_(x...);
}
virtual DifferentiableFunctionWrapper* clone() const
......@@ -104,7 +153,7 @@ public:
return new (buffer) DifferentiableFunctionWrapper(std::move(*this));
}
private:
// private:
FImp f_;
};
......
......@@ -77,7 +77,8 @@ public:
return f_(x);
}
friend Derivative derivative(const DifferentiableFunctionFromCallables& t)
friend Derivative derivative(const DifferentiableFunctionFromCallables& t,
DefaultDerivativeDirection = derivativeDirection::_default)
{
return t.df_;
}
......@@ -88,26 +89,6 @@ private:
};
template<class Signature, template<class> class DerivativeTraits=DefaultDerivativeTraits>
struct SignatureTag;
/**
* \brief Tag-class to encapsulate signature information
*
* \tparam Range range type
* \tparam Domain domain type
* \tparam DerivativeTraits traits template used to determine derivative traits
*/
template<class Range, class Domain, template<class> class DerivativeTraitsT>
struct SignatureTag<Range(Domain), DerivativeTraitsT>
{
using Signature = Range(Domain);
template<class T>
using DerivativeTraits = DerivativeTraitsT<T>;
};
/**
* \brief Create a DifferentiableFunction from callables
*
......
......@@ -4,27 +4,46 @@
#define DUNE_FUNCTIONS_COMMON_SIGNATURE_HH
#include <type_traits>
#include <dune/functions/common/defaultderivativetraits.hh>
namespace Dune {
namespace Functions {
template<class Signature>
struct SignatureTraits;
template<class R, class D>
struct SignatureTraits<R(D)>
template<typename R, typename... D>
struct SignatureTraits<R(D...)>
{
using Range = R;
using Domain = D;
using Range = R;
using Domains = std::tuple< D... >;
using RawRange = typename std::decay<Range>::type;
using RawDomains = std::tuple< typename std::decay<D>::type... >;
using RawRange = typename std::decay<Range>::type;
using RawDomain = typename std::decay<Domain>::type;
using RawSignature = RawRange(typename std::decay<D>::type...);
using RawSignature = RawRange(RawDomain);
enum { DomainSize = sizeof...(D) };
};
template<class Signature, template<class> class DerivativeTraits=DefaultDerivativeTraits>
struct SignatureTag;
/**
* \brief Tag-class to encapsulate signature information
*
* \tparam Range range type
* \tparam Domain domain type
* \tparam DerivativeTraits traits template used to determine derivative traits
*/
template<typename Range, typename... Domain, template<class> class DerivativeTraitsT>
struct SignatureTag<Range(Domain...), DerivativeTraitsT>
{
using Signature = Range(Domain...);
template<class T>
using DerivativeTraits = DerivativeTraitsT<T>;
};
} // namespace Functions
......
......@@ -15,6 +15,47 @@
#include <dune/functions/analyticfunctions/polynomial.hh>
#include <dune/functions/analyticfunctions/trigonometricfunction.hh>
template<int V>
class ConstantIntegerFunction
{
public:
double operator() (const double& x, const double& t) const
{
return V;
}
template<int D = 1>
friend ConstantIntegerFunction<0> derivative(const ConstantIntegerFunction& p,
Dune::Functions::DerivativeDirection<D> = Dune::Functions::DerivativeDirection<D>())
{
return ConstantIntegerFunction<0>();
}
};
class MultiParamTestFunction
{
public:
double operator() (const double& x, const double& t) const
{
return x-t;
}
friend ConstantIntegerFunction<1> derivative(const MultiParamTestFunction& p,
Dune::Functions::DerivativeDirection<1> = Dune::Functions::DerivativeDirection<1>())
{
return ConstantIntegerFunction<1>();
}
friend ConstantIntegerFunction<-1> derivative(const MultiParamTestFunction& p,
Dune::Functions::DerivativeDirection<2>)
{
return ConstantIntegerFunction<-1>();
}
};
//#include <dune/functions/common/callable.hh>
//#include "derivativecheck.hh"
......@@ -24,9 +65,13 @@
struct DifferentiableFunctionImplementableTest
{
template<class F>
static bool checkWithFunction(F&& f)
template<typename Range, typename... Domain, class F, int DiffDir, typename... Params>
static bool checkWithFunction(Dune::Functions::SignatureTag<Range(Domain...)>, F&& f,
Dune::Functions::DerivativeDirection<DiffDir> diffDir,
Params... params)
{
static_assert(sizeof...(Domain) == sizeof...(Params), "Number of parameters does not math the signature");
bool passed = true;
{
......@@ -35,7 +80,7 @@ struct DifferentiableFunctionImplementableTest
// passed = passed and DerivativeCheck<QuadraticPolynomial>::checkAllImplementedTrulyDerived(testFunction, 10);
// Test whether I can evaluate the function somewhere
std::cout << "Function value at x=5: " << f(5) << std::endl;
std::cout << "Function value at x=5: " << f(params...) << std::endl;
......@@ -43,50 +88,64 @@ struct DifferentiableFunctionImplementableTest
// Test whether I can evaluate the first derivative
auto df = derivative(f);
std::cout << "Derivative at x=5: " << df(5) << std::endl;
std::cout << "Derivative at x=5: " << df(params...) << std::endl;
// Test whether I can evaluate the second derivative through FunctionHandle
auto ddf = derivative(df);
std::cout << "Second derivative at x=5: " << ddf(5) << std::endl;
std::cout << "Second derivative at x=5: " << ddf(params...) << std::endl;
// Test whether I can evaluate the third derivative through FunctionHandle
auto dddf = derivative(ddf);
std::cout << "Third derivative at x=5: " << dddf(5) << std::endl;
std::cout << "Third derivative at x=5: " << dddf(params...) << std::endl;
std::cout << std::endl << "Check calling derivatives with explicit derivative direction" << std::endl;
// Test whether I can evaluate the derivative via direction parameter
auto dfdn = derivative(f,diffDir);
std::cout << "Derivative at x=5: " << dfdn(params...) << std::endl;
// Test whether I can evaluate the second derivative through FunctionHandle
auto ddfdn = derivative(df,diffDir);
std::cout << "Second derivative at x=5: " << ddfdn(params...) << std::endl;
// Test whether I can evaluate the third derivative through FunctionHandle
auto dddfdn = derivative(ddf,diffDir);
std::cout << "Third derivative at x=5: " << dddfdn(params...) << std::endl;
std::cout << std::endl << "Check calling derivatives through DifferentiableFunction object" << std::endl;
Dune::Functions::DifferentiableFunction<double(const double&)> fi = f;
Dune::Functions::DifferentiableFunction<Range(Domain...)> fi = f;
// Try to reassign wrapper
fi = f;
// Try assigning a default constructed wrapper
Dune::Functions::DifferentiableFunction<double(const double&)> fii;
Dune::Functions::DifferentiableFunction<Range(Domain...)> fii;
fii = fi;
// Try to copy wrapper
auto fiii = fii;
std::cout << "Function value at x=5: " << fiii(5) << std::endl;
std::cout << "Function value at x=5: " << fiii(params...) << std::endl;
// Test whether I can evaluate the first derivative
auto dfiii = derivative(fiii);
std::cout << "Derivative at x=5: " << dfiii(5) << std::endl;
std::cout << "Derivative at x=5: " << dfiii(params...) << std::endl;
// Test whether I can evaluate the second derivative through FunctionHandle
auto ddfiii = derivative(dfiii);
std::cout << "Second derivative at x=5: " << ddfiii(5) << std::endl;
std::cout << "Second derivative at x=5: " << ddfiii(params...) << std::endl;
// Test whether I can evaluate the third derivative through FunctionHandle
auto dddfiii = derivative(ddfiii);
std::cout << "Third derivative at x=5: " << dddfiii(5) << std::endl;
std::cout << "Third derivative at x=5: " << dddfiii(params...) << std::endl;
// Wrap as non-differentiable function
Dune::Functions::DifferentiableFunction<double(const double&)> g = [=] (const double& x) {return f(x);};
std::cout << "Function value at x=5: " << g(5) << std::endl;
Dune::Functions::DifferentiableFunction<Range(Domain...)> g = [=] (Domain... x) {return f(x...);};
std::cout << "Function value at x=5: " << g(params...) << std::endl;
try {
auto dg = derivative(g);
// auto dg2 = derivative(g,diffDir);
// std::cout << dg(params...) << " ... " << dg2(params...) << std::endl;
passed = false;
}
catch (Dune::NotImplemented e)
......@@ -119,16 +178,31 @@ struct DifferentiableFunctionImplementableTest
{
bool passed = true;
passed = passed and checkWithFunction(Dune::Functions::Polynomial<double>({1, 2, 3}));
passed = passed and checkWithFunction(Dune::Functions::TrigonometricFunction<double, 1, 0>());
double value = 5.5;
Dune::Functions::SignatureTag<double(const double&)> signature;
using Dune::Functions::derivativeDirection::_d1;
using Dune::Functions::derivativeDirection::_d2;
passed = passed and checkWithFunction(signature, Dune::Functions::Polynomial<double>({1, 2, 3}), _d1, value);
passed = passed and checkWithFunction(signature, Dune::Functions::TrigonometricFunction<double, 1, 0>(), _d1, value);
auto f = [](double x){ return std::sin(x);};
auto df = [](double x){ return std::cos(x);};
auto ddf = [](double x){ return -std::sin(x);};
auto dddf = [](double x){ return -std::cos(x);};
auto F = makeDifferentiableFunctionFromCallables(Dune::Functions::SignatureTag<double(double)>(), f, df, ddf, dddf);
passed = passed and checkWithFunction(F);
auto F = makeDifferentiableFunctionFromCallables(signature, f, df, ddf, dddf);
passed = passed and checkWithFunction(signature, F, _d1, value);
Dune::Functions::SignatureTag<double(const double&, const double&)> signature2;
// auto f2 = [](double x, double t){ return x-t; };
// auto df2dx = [](double x, double t){ return 1; };
// auto df2dt = [](double x, double t){ return -1; };
// auto F2 = makeDifferentiableFunctionFromCallables(signature2, f2, df2dx);
MultiParamTestFunction F2;
passed = passed and checkWithFunction(signature2, F2, _d1, 1.0, 2.0);
passed = passed and checkWithFunction(signature2, F2, _d2, 1.0, 2.0);
return passed;
}
......