• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Mehdi Goli    Codeplay Software Ltd.
5 // Ralph Potter  Codeplay Software Ltd.
6 // Luke Iwanski  Codeplay Software Ltd.
7 // Contact: <eigen@codeplay.com>
8 //
9 // This Source Code Form is subject to the terms of the Mozilla
10 // Public License v. 2.0. If a copy of the MPL was not distributed
11 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
12 
13 /*****************************************************************
14  * TensorSyclPlaceHolderExpr.h
15  *
16  * \brief:
17  *  This is the specialisation of the placeholder expression based on the
18  * operation type
19  *
20 *****************************************************************/
21 
22 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP
23 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP
24 
25 namespace Eigen {
26 namespace TensorSycl {
27 namespace internal {
28 
29 /// \struct PlaceHolder
30 /// \brief PlaceHolder is used to replace the \ref TensorMap in the expression
31 /// tree.
32 /// PlaceHolder contains the order of the leaf node in the expression tree.
33 template <typename Scalar, size_t N>
34 struct PlaceHolder {
35   static constexpr size_t I = N;
36   typedef Scalar Type;
37 };
38 
/// \struct PlaceHolderExpression
/// \brief Creates the PlaceHolder expression: a copy of the expression type
/// \p Expr in which every leaf node (TensorMap, TensorForcedEvalOp,
/// TensorReductionOp) has been replaced with a PlaceHolder.
/// \tparam Expr the expression type to transform.
/// \tparam N the leaf index handed to the rightmost leaf of \p Expr; operands
/// further to the left receive smaller indices (see CalculateIndex).
/// The nested \c Type member of each specialization is the transformed type.
template <typename Expr, size_t N>
struct PlaceHolderExpression;
45 
46 template<size_t N, typename... Args>
47 struct CalculateIndex;
48 
49 template<size_t N, typename Arg>
50 struct CalculateIndex<N, Arg>{
51   typedef typename PlaceHolderExpression<Arg, N>::Type ArgType;
52   typedef utility::tuple::Tuple<ArgType> ArgsTuple;
53 };
54 
55 template<size_t N, typename Arg1, typename Arg2>
56 struct CalculateIndex<N, Arg1, Arg2>{
57   static const size_t Arg2LeafCount = LeafCount<Arg2>::Count;
58   typedef typename PlaceHolderExpression<Arg1, N - Arg2LeafCount>::Type Arg1Type;
59   typedef typename PlaceHolderExpression<Arg2, N>::Type Arg2Type;
60   typedef utility::tuple::Tuple<Arg1Type, Arg2Type> ArgsTuple;
61 };
62 
63 template<size_t N, typename Arg1, typename Arg2, typename Arg3>
64 struct CalculateIndex<N, Arg1, Arg2, Arg3> {
65   static const size_t Arg3LeafCount = LeafCount<Arg3>::Count;
66   static const size_t Arg2LeafCount = LeafCount<Arg2>::Count;
67   typedef typename PlaceHolderExpression<Arg1, N - Arg3LeafCount - Arg2LeafCount>::Type Arg1Type;
68   typedef typename PlaceHolderExpression<Arg2, N - Arg3LeafCount>::Type Arg2Type;
69   typedef typename PlaceHolderExpression<Arg3, N>::Type Arg3Type;
70   typedef utility::tuple::Tuple<Arg1Type, Arg2Type, Arg3Type> ArgsTuple;
71 };
72 
73 template<template<class...> class Category , class OP, class TPL>
74 struct CategoryHelper;
75 
76 template<template<class...> class Category , class OP, class ...T >
77 struct CategoryHelper<Category, OP, utility::tuple::Tuple<T...> > {
78   typedef Category<OP, T... > Type;
79 };
80 
81 template<template<class...> class Category , class ...T >
82 struct CategoryHelper<Category, NoOP, utility::tuple::Tuple<T...> > {
83   typedef Category<T... > Type;
84 };
85 
86 /// specialisation of the \ref PlaceHolderExpression when the node is
87 /// TensorCwiseNullaryOp, TensorCwiseUnaryOp, TensorBroadcastingOp, TensorCwiseBinaryOp,  TensorCwiseTernaryOp
88 #define OPEXPRCATEGORY(CVQual)\
89 template <template <class, class... > class Category, typename OP, typename... SubExpr, size_t N>\
90 struct PlaceHolderExpression<CVQual Category<OP, SubExpr...>, N>{\
91   typedef CVQual typename CategoryHelper<Category, OP, typename CalculateIndex<N, SubExpr...>::ArgsTuple>::Type Type;\
92 };
93 
94 OPEXPRCATEGORY(const)
95 OPEXPRCATEGORY()
96 #undef OPEXPRCATEGORY
97 
98 /// specialisation of the \ref PlaceHolderExpression when the node is
99 /// TensorCwiseSelectOp
100 #define SELECTEXPR(CVQual)\
101 template <typename IfExpr, typename ThenExpr, typename ElseExpr, size_t N>\
102 struct PlaceHolderExpression<CVQual TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, N> {\
103   typedef CVQual typename CategoryHelper<TensorSelectOp, NoOP, typename CalculateIndex<N, IfExpr, ThenExpr, ElseExpr>::ArgsTuple>::Type Type;\
104 };
105 
106 SELECTEXPR(const)
107 SELECTEXPR()
108 #undef SELECTEXPR
109 
110 /// specialisation of the \ref PlaceHolderExpression when the node is
111 /// TensorAssignOp
112 #define ASSIGNEXPR(CVQual)\
113 template <typename LHSExpr, typename RHSExpr, size_t N>\
114 struct PlaceHolderExpression<CVQual TensorAssignOp<LHSExpr, RHSExpr>, N> {\
115   typedef CVQual typename CategoryHelper<TensorAssignOp, NoOP, typename CalculateIndex<N, LHSExpr, RHSExpr>::ArgsTuple>::Type Type;\
116 };
117 
118 ASSIGNEXPR(const)
119 ASSIGNEXPR()
120 #undef ASSIGNEXPR
121 
122 /// specialisation of the \ref PlaceHolderExpression when the node is
123 /// TensorMap
124 #define TENSORMAPEXPR(CVQual)\
125 template <typename Scalar_, int Options_, int Options2_, int NumIndices_, typename IndexType_, template <class> class MakePointer_, size_t N>\
126 struct PlaceHolderExpression< CVQual TensorMap< Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakePointer_>, N> {\
127   typedef CVQual PlaceHolder<CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakePointer_>, N> Type;\
128 };
129 
130 TENSORMAPEXPR(const)
131 TENSORMAPEXPR()
132 #undef TENSORMAPEXPR
133 
134 /// specialisation of the \ref PlaceHolderExpression when the node is
135 /// TensorForcedEvalOp
136 #define FORCEDEVAL(CVQual)\
137 template <typename Expr, size_t N>\
138 struct PlaceHolderExpression<CVQual TensorForcedEvalOp<Expr>, N> {\
139   typedef CVQual PlaceHolder<CVQual TensorForcedEvalOp<Expr>, N> Type;\
140 };
141 
142 FORCEDEVAL(const)
143 FORCEDEVAL()
144 #undef FORCEDEVAL
145 
146 /// specialisation of the \ref PlaceHolderExpression when the node is
147 /// TensorEvalToOp
148 #define EVALTO(CVQual)\
149 template <typename Expr, size_t N>\
150 struct PlaceHolderExpression<CVQual TensorEvalToOp<Expr>, N> {\
151   typedef CVQual TensorEvalToOp<typename CalculateIndex <N, Expr>::ArgType> Type;\
152 };
153 
154 EVALTO(const)
155 EVALTO()
156 #undef EVALTO
157 
158 
159 /// specialisation of the \ref PlaceHolderExpression when the node is
160 /// TensorReductionOp
161 #define SYCLREDUCTION(CVQual)\
162 template <typename OP, typename Dims, typename Expr, size_t N>\
163 struct PlaceHolderExpression<CVQual TensorReductionOp<OP, Dims, Expr>, N>{\
164   typedef CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dims,Expr>, N> Type;\
165 };
166 SYCLREDUCTION(const)
167 SYCLREDUCTION()
168 #undef SYCLREDUCTION
169 
170 /// template deduction for \ref PlaceHolderExpression struct
171 template <typename Expr>
172 struct createPlaceHolderExpression {
173   static const size_t TotalLeaves = LeafCount<Expr>::Count;
174   typedef typename PlaceHolderExpression<Expr, TotalLeaves - 1>::Type Type;
175 };
176 
177 }  // internal
178 }  // TensorSycl
179 }  // namespace Eigen
180 
181 #endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP
182