1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Mehdi Goli Codeplay Software Ltd. 5 // Ralph Potter Codeplay Software Ltd. 6 // Luke Iwanski Codeplay Software Ltd. 7 // Contact: <eigen@codeplay.com> 8 // 9 // This Source Code Form is subject to the terms of the Mozilla 10 // Public License v. 2.0. If a copy of the MPL was not distributed 11 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 12 13 /***************************************************************** 14 * TensorSyclExprConstructor.h 15 * 16 * \brief: 17 * This file re-create an expression on the SYCL device in order 18 * to use the original tensor evaluator. 19 * 20 *****************************************************************/ 21 22 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP 23 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP 24 25 namespace Eigen { 26 namespace TensorSycl { 27 namespace internal { 28 /// this class is used by EvalToOp in order to create an lhs expression which is 29 /// a pointer from an accessor on device-only buffer 30 template <typename PtrType, size_t N, typename... Params> 31 struct EvalToLHSConstructor { 32 PtrType expr; EvalToLHSConstructorEvalToLHSConstructor33 EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t): expr((&(*(utility::tuple::get<N>(t).get_pointer())))) {} 34 }; 35 36 /// \struct ExprConstructor is used to reconstruct the expression on the device and 37 /// recreate the expression with MakeGlobalPointer containing the device address 38 /// space for the TensorMap pointers used in eval function. 39 /// It receives the original expression type, the functor of the node, the tuple 40 /// of accessors, and the device expression type to re-instantiate the 41 /// expression tree for the device 42 template <typename OrigExpr, typename IndexExpr, typename... Params> 43 struct ExprConstructor; 44 45 /// specialisation of the \ref ExprConstructor struct when the node type is 46 /// TensorMap 47 #define TENSORMAP(CVQual)\ 48 template <typename Scalar_, int Options_, int Options2_, int Options3_, int NumIndices_, typename IndexType_,\ 49 template <class> class MakePointer_, size_t N, typename... Params>\ 50 struct ExprConstructor< CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakeGlobalPointer>,\ 51 CVQual PlaceHolder<CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options3_, MakePointer_>, N>, Params...>{\ 52 typedef CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakeGlobalPointer> Type;\ 53 Type expr;\ 54 template <typename FuncDetector>\ 55 ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ 56 : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ 57 }; 58 59 TENSORMAP(const) 60 TENSORMAP() 61 #undef TENSORMAP 62 63 #define UNARYCATEGORY(CVQual)\ 64 template <template<class, class> class UnaryCategory, typename OP, typename OrigRHSExpr, typename RHSExpr, typename... Params>\ 65 struct ExprConstructor<CVQual UnaryCategory<OP, OrigRHSExpr>, CVQual UnaryCategory<OP, RHSExpr>, Params...> {\ 66 typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_type;\ 67 my_type rhsExpr;\ 68 typedef CVQual UnaryCategory<OP, typename my_type::Type> Type;\ 69 Type expr;\ 70 template <typename FuncDetector>\ 71 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 72 : rhsExpr(funcD.rhsExpr, t), expr(rhsExpr.expr, funcD.func) {}\ 73 }; 74 75 UNARYCATEGORY(const) 76 UNARYCATEGORY() 77 #undef UNARYCATEGORY 78 79 /// specialisation of the \ref ExprConstructor struct when the node type is 80 /// TensorBinaryOp 81 #define BINARYCATEGORY(CVQual)\ 82 template <template<class, class, class> class BinaryCategory, typename OP, typename OrigLHSExpr, typename OrigRHSExpr, typename LHSExpr,\ 83 typename RHSExpr, typename... Params>\ 84 struct ExprConstructor<CVQual BinaryCategory<OP, OrigLHSExpr, OrigRHSExpr>, CVQual BinaryCategory<OP, LHSExpr, RHSExpr>, Params...> {\ 85 typedef ExprConstructor<OrigLHSExpr, LHSExpr, Params...> my_left_type;\ 86 typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_right_type;\ 87 typedef CVQual BinaryCategory<OP, typename my_left_type::Type, typename my_right_type::Type> Type;\ 88 my_left_type lhsExpr;\ 89 my_right_type rhsExpr;\ 90 Type expr;\ 91 template <typename FuncDetector>\ 92 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 93 : lhsExpr(funcD.lhsExpr, t),rhsExpr(funcD.rhsExpr, t), expr(lhsExpr.expr, rhsExpr.expr, funcD.func) {}\ 94 }; 95 96 BINARYCATEGORY(const) 97 BINARYCATEGORY() 98 #undef BINARYCATEGORY 99 100 /// specialisation of the \ref ExprConstructor struct when the node type is 101 /// TensorCwiseTernaryOp 102 #define TERNARYCATEGORY(CVQual)\ 103 template <template <class, class, class, class> class TernaryCategory, typename OP, typename OrigArg1Expr, typename OrigArg2Expr,typename OrigArg3Expr,\ 104 typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename... Params>\ 105 struct ExprConstructor<CVQual TernaryCategory<OP, OrigArg1Expr, OrigArg2Expr, OrigArg3Expr>, CVQual TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Params...> {\ 106 typedef ExprConstructor<OrigArg1Expr, Arg1Expr, Params...> my_arg1_type;\ 107 typedef ExprConstructor<OrigArg2Expr, Arg2Expr, Params...> my_arg2_type;\ 108 typedef ExprConstructor<OrigArg3Expr, Arg3Expr, Params...> my_arg3_type;\ 109 typedef CVQual TernaryCategory<OP, typename my_arg1_type::Type, typename my_arg2_type::Type, typename my_arg3_type::Type> Type;\ 110 my_arg1_type arg1Expr;\ 111 my_arg2_type arg2Expr;\ 112 my_arg3_type arg3Expr;\ 113 Type expr;\ 114 template <typename FuncDetector>\ 115 ExprConstructor(FuncDetector &funcD,const utility::tuple::Tuple<Params...> &t)\ 116 : arg1Expr(funcD.arg1Expr, t), arg2Expr(funcD.arg2Expr, t), arg3Expr(funcD.arg3Expr, t), expr(arg1Expr.expr, arg2Expr.expr, arg3Expr.expr, funcD.func) {}\ 117 }; 118 119 TERNARYCATEGORY(const) 120 TERNARYCATEGORY() 121 #undef TERNARYCATEGORY 122 123 /// specialisation of the \ref ExprConstructor struct when the node type is 124 /// TensorCwiseSelectOp 125 #define SELECTOP(CVQual)\ 126 template <typename OrigIfExpr, typename OrigThenExpr, typename OrigElseExpr, typename IfExpr, typename ThenExpr, typename ElseExpr, typename... Params>\ 127 struct ExprConstructor< CVQual TensorSelectOp<OrigIfExpr, OrigThenExpr, OrigElseExpr>, CVQual TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Params...> {\ 128 typedef ExprConstructor<OrigIfExpr, IfExpr, Params...> my_if_type;\ 129 typedef ExprConstructor<OrigThenExpr, ThenExpr, Params...> my_then_type;\ 130 typedef ExprConstructor<OrigElseExpr, ElseExpr, Params...> my_else_type;\ 131 typedef CVQual TensorSelectOp<typename my_if_type::Type, typename my_then_type::Type, typename my_else_type::Type> Type;\ 132 my_if_type ifExpr;\ 133 my_then_type thenExpr;\ 134 my_else_type elseExpr;\ 135 Type expr;\ 136 template <typename FuncDetector>\ 137 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 138 : ifExpr(funcD.ifExpr, t), thenExpr(funcD.thenExpr, t), elseExpr(funcD.elseExpr, t), expr(ifExpr.expr, thenExpr.expr, elseExpr.expr) {}\ 139 }; 140 141 SELECTOP(const) 142 SELECTOP() 143 #undef SELECTOP 144 145 /// specialisation of the \ref ExprConstructor struct when the node type is 146 /// const TensorAssignOp 147 #define ASSIGN(CVQual)\ 148 template <typename OrigLHSExpr, typename OrigRHSExpr, typename LHSExpr, typename RHSExpr, typename... Params>\ 149 struct ExprConstructor<CVQual TensorAssignOp<OrigLHSExpr, OrigRHSExpr>, CVQual TensorAssignOp<LHSExpr, RHSExpr>, Params...> {\ 150 typedef ExprConstructor<OrigLHSExpr, LHSExpr, Params...> my_left_type;\ 151 typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_right_type;\ 152 typedef CVQual TensorAssignOp<typename my_left_type::Type, typename my_right_type::Type> Type;\ 153 my_left_type lhsExpr;\ 154 my_right_type rhsExpr;\ 155 Type expr;\ 156 template <typename FuncDetector>\ 157 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 158 : lhsExpr(funcD.lhsExpr, t), rhsExpr(funcD.rhsExpr, t), expr(lhsExpr.expr, rhsExpr.expr) {}\ 159 }; 160 161 ASSIGN(const) 162 ASSIGN() 163 #undef ASSIGN 164 /// specialisation of the \ref ExprConstructor struct when the node type is 165 /// TensorEvalToOp 166 #define EVALTO(CVQual)\ 167 template <typename OrigExpr, typename Expr, typename... Params>\ 168 struct ExprConstructor<CVQual TensorEvalToOp<OrigExpr, MakeGlobalPointer>, CVQual TensorEvalToOp<Expr>, Params...> {\ 169 typedef ExprConstructor<OrigExpr, Expr, Params...> my_expr_type;\ 170 typedef typename TensorEvalToOp<OrigExpr, MakeGlobalPointer>::PointerType my_buffer_type;\ 171 typedef CVQual TensorEvalToOp<typename my_expr_type::Type, MakeGlobalPointer> Type;\ 172 my_expr_type nestedExpression;\ 173 EvalToLHSConstructor<my_buffer_type, 0, Params...> buffer;\ 174 Type expr;\ 175 template <typename FuncDetector>\ 176 ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ 177 : nestedExpression(funcD.rhsExpr, t), buffer(t), expr(buffer.expr, nestedExpression.expr) {}\ 178 }; 179 180 EVALTO(const) 181 EVALTO() 182 #undef EVALTO 183 184 /// specialisation of the \ref ExprConstructor struct when the node type is 185 /// TensorForcedEvalOp 186 #define FORCEDEVAL(CVQual)\ 187 template <typename OrigExpr, typename DevExpr, size_t N, typename... Params>\ 188 struct ExprConstructor<CVQual TensorForcedEvalOp<OrigExpr, MakeGlobalPointer>,\ 189 CVQual PlaceHolder<CVQual TensorForcedEvalOp<DevExpr>, N>, Params...> {\ 190 typedef CVQual TensorMap<Tensor<typename TensorForcedEvalOp<DevExpr, MakeGlobalPointer>::Scalar,\ 191 TensorForcedEvalOp<DevExpr, MakeGlobalPointer>::NumDimensions, 0, typename TensorForcedEvalOp<DevExpr>::Index>, 0, MakeGlobalPointer> Type;\ 192 Type expr;\ 193 template <typename FuncDetector>\ 194 ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ 195 : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ 196 }; 197 198 FORCEDEVAL(const) 199 FORCEDEVAL() 200 #undef FORCEDEVAL 201 202 template <bool Conds, size_t X , size_t Y > struct ValueCondition { 203 static const size_t Res =X; 204 }; 205 template<size_t X, size_t Y> struct ValueCondition<false, X , Y> { 206 static const size_t Res =Y; 207 }; 208 209 /// specialisation of the \ref ExprConstructor struct when the node type is TensorReductionOp 210 #define SYCLREDUCTIONEXPR(CVQual)\ 211 template <typename OP, typename Dim, typename OrigExpr, typename DevExpr, size_t N, typename... Params>\ 212 struct ExprConstructor<CVQual TensorReductionOp<OP, Dim, OrigExpr, MakeGlobalPointer>,\ 213 CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dim, DevExpr>, N>, Params...> {\ 214 static const size_t NumIndices= ValueCondition< TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::NumDimensions==0, 1, TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::NumDimensions >::Res;\ 215 typedef CVQual TensorMap<Tensor<typename TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::Scalar,\ 216 NumIndices, 0, typename TensorReductionOp<OP, Dim, DevExpr>::Index>, 0, MakeGlobalPointer> Type;\ 217 Type expr;\ 218 template <typename FuncDetector>\ 219 ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ 220 : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ 221 }; 222 223 SYCLREDUCTIONEXPR(const) 224 SYCLREDUCTIONEXPR() 225 #undef SYCLREDUCTIONEXPR 226 227 /// template deduction for \ref ExprConstructor struct 228 template <typename OrigExpr, typename IndexExpr, typename FuncD, typename... Params> 229 auto createDeviceExpression(FuncD &funcD, const utility::tuple::Tuple<Params...> &t) 230 -> decltype(ExprConstructor<OrigExpr, IndexExpr, Params...>(funcD, t)) { 231 return ExprConstructor<OrigExpr, IndexExpr, Params...>(funcD, t); 232 } 233 234 } /// namespace TensorSycl 235 } /// namespace internal 236 } /// namespace Eigen 237 238 239 #endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP 240