• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12 
13 namespace Eigen {
14 
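/** \class TensorCustomUnaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression whose value is computed by a user-supplied unary functor.
  *
  * The functor must provide dimensions(expr), which returns the dimensions of
  * the result, and eval(expr, output, device), which writes the result into
  * the supplied output tensor map.
  */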
22 namespace internal {
23 template<typename CustomUnaryFunc, typename XprType>
24 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
25 {
26   typedef typename XprType::Scalar Scalar;
27   typedef typename XprType::StorageKind StorageKind;
28   typedef typename XprType::Index Index;
29   typedef typename XprType::Nested Nested;
30   typedef typename remove_reference<Nested>::type _Nested;
31   static const int NumDimensions = traits<XprType>::NumDimensions;
32   static const int Layout = traits<XprType>::Layout;
33 };
34 
35 template<typename CustomUnaryFunc, typename XprType>
36 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
37 {
38   typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
39 };
40 
41 template<typename CustomUnaryFunc, typename XprType>
42 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
43 {
44   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
45 };
46 
47 }  // end namespace internal
48 
49 
50 
51 template<typename CustomUnaryFunc, typename XprType>
52 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
53 {
54   public:
55   typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
56   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57   typedef typename XprType::CoeffReturnType CoeffReturnType;
58   typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
59   typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
60   typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
61 
62   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
63       : m_expr(expr), m_func(func) {}
64 
65   EIGEN_DEVICE_FUNC
66   const CustomUnaryFunc& func() const { return m_func; }
67 
68   EIGEN_DEVICE_FUNC
69   const typename internal::remove_all<typename XprType::Nested>::type&
70   expression() const { return m_expr; }
71 
72   protected:
73     typename XprType::Nested m_expr;
74     const CustomUnaryFunc m_func;
75 };
76 
77 
78 // Eval as rvalue
79 template<typename CustomUnaryFunc, typename XprType, typename Device>
80 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
81 {
82   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
83   typedef typename internal::traits<ArgType>::Index Index;
84   static const int NumDims = internal::traits<ArgType>::NumDimensions;
85   typedef DSizes<Index, NumDims> Dimensions;
86   typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
87   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
88   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
89   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
90 
91   enum {
92     IsAligned = false,
93     PacketAccess = (internal::packet_traits<Scalar>::size > 1),
94     BlockAccess = false,
95     Layout = TensorEvaluator<XprType, Device>::Layout,
96     CoordAccess = false,  // to be implemented
97     RawAccess = false
98   };
99 
100   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
101       : m_op(op), m_device(device), m_result(NULL)
102   {
103     m_dimensions = op.func().dimensions(op.expression());
104   }
105 
106   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
107 
108   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
109     if (data) {
110       evalTo(data);
111       return false;
112     } else {
113       m_result = static_cast<CoeffReturnType*>(
114           m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
115       evalTo(m_result);
116       return true;
117     }
118   }
119 
120   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
121     if (m_result != NULL) {
122       m_device.deallocate(m_result);
123       m_result = NULL;
124     }
125   }
126 
127   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
128     return m_result[index];
129   }
130 
131   template<int LoadMode>
132   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
133     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
134   }
135 
136   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
137     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
138     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
139   }
140 
141   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
142 
143  protected:
144   EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
145     TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
146         data, m_dimensions);
147     m_op.func().eval(m_op.expression(), result, m_device);
148   }
149 
150   Dimensions m_dimensions;
151   const ArgType m_op;
152   const Device& m_device;
153   CoeffReturnType* m_result;
154 };
155 
156 
157 
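/** \class TensorCustomBinaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression whose value is computed by a user-supplied binary functor.
  *
  * The functor must provide dimensions(lhs, rhs), which returns the dimensions
  * of the result, and eval(lhs, rhs, output, device), which writes the result
  * into the supplied output tensor map.
  */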
165 namespace internal {
166 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
167 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
168 {
169   typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
170                                                   typename RhsXprType::Scalar>::ret Scalar;
171   typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
172                                                   typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
173   typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
174                                         typename traits<RhsXprType>::StorageKind>::ret StorageKind;
175   typedef typename promote_index_type<typename traits<LhsXprType>::Index,
176                                       typename traits<RhsXprType>::Index>::type Index;
177   typedef typename LhsXprType::Nested LhsNested;
178   typedef typename RhsXprType::Nested RhsNested;
179   typedef typename remove_reference<LhsNested>::type _LhsNested;
180   typedef typename remove_reference<RhsNested>::type _RhsNested;
181   static const int NumDimensions = traits<LhsXprType>::NumDimensions;
182   static const int Layout = traits<LhsXprType>::Layout;
183 };
184 
185 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
186 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
187 {
188   typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
189 };
190 
191 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
192 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
193 {
194   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
195 };
196 
197 }  // end namespace internal
198 
199 
200 
201 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
202 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
203 {
204   public:
205   typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
206   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
207   typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
208   typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
209   typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
210   typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
211 
212   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
213 
214       : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
215 
216   EIGEN_DEVICE_FUNC
217   const CustomBinaryFunc& func() const { return m_func; }
218 
219   EIGEN_DEVICE_FUNC
220   const typename internal::remove_all<typename LhsXprType::Nested>::type&
221   lhsExpression() const { return m_lhs_xpr; }
222 
223   EIGEN_DEVICE_FUNC
224   const typename internal::remove_all<typename RhsXprType::Nested>::type&
225   rhsExpression() const { return m_rhs_xpr; }
226 
227   protected:
228     typename LhsXprType::Nested m_lhs_xpr;
229     typename RhsXprType::Nested m_rhs_xpr;
230     const CustomBinaryFunc m_func;
231 };
232 
233 
234 // Eval as rvalue
235 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
236 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
237 {
238   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
239   typedef typename internal::traits<XprType>::Index Index;
240   static const int NumDims = internal::traits<XprType>::NumDimensions;
241   typedef DSizes<Index, NumDims> Dimensions;
242   typedef typename XprType::Scalar Scalar;
243   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
244   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
245   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
246 
247   enum {
248     IsAligned = false,
249     PacketAccess = (internal::packet_traits<Scalar>::size > 1),
250     BlockAccess = false,
251     Layout = TensorEvaluator<LhsXprType, Device>::Layout,
252     CoordAccess = false,  // to be implemented
253     RawAccess = false
254   };
255 
256   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
257       : m_op(op), m_device(device), m_result(NULL)
258   {
259     m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
260   }
261 
262   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
263 
264   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
265     if (data) {
266       evalTo(data);
267       return false;
268     } else {
269       m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
270       evalTo(m_result);
271       return true;
272     }
273   }
274 
275   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
276     if (m_result != NULL) {
277       m_device.deallocate(m_result);
278       m_result = NULL;
279     }
280   }
281 
282   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
283     return m_result[index];
284   }
285 
286   template<int LoadMode>
287   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
288     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
289   }
290 
291   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
292     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
293     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
294   }
295 
296   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
297 
298  protected:
299   EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
300     TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
301     m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
302   }
303 
304   Dimensions m_dimensions;
305   const XprType m_op;
306   const Device& m_device;
307   CoeffReturnType* m_result;
308 };
309 
310 
311 } // end namespace Eigen
312 
313 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
314