• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12 
13 namespace Eigen {
14 
/** \class TensorExpr
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression classes.
  *
  * The TensorCwiseNullaryOp class applies a nullary operator to an expression.
  * This is typically used to generate constants.
  *
  * The TensorCwiseUnaryOp class represents an expression where a unary operator
  * (e.g. cwiseSqrt) is applied to an expression.
  *
  * The TensorCwiseBinaryOp class represents an expression where a binary
  * operator (e.g. addition) is applied to a lhs and a rhs expression.
  *
  * The TensorCwiseTernaryOp class represents an expression where a ternary
  * operator is applied to three argument expressions.
  *
  * The TensorSelectOp class represents an expression that selects between a
  * "then" and an "else" expression according to a condition expression.
  */
30 namespace internal {
31 template<typename NullaryOp, typename XprType>
32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
33     : traits<XprType>
34 {
35   typedef traits<XprType> XprTraits;
36   typedef typename XprType::Scalar Scalar;
37   typedef typename XprType::Nested XprTypeNested;
38   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
39   static const int NumDimensions = XprTraits::NumDimensions;
40   static const int Layout = XprTraits::Layout;
41   typedef typename XprTraits::PointerType PointerType;
42   enum {
43     Flags = 0
44   };
45 };
46 
47 }  // end namespace internal
48 
49 
50 
51 template<typename NullaryOp, typename XprType>
52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
53 {
54   public:
55     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
56     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57     typedef typename XprType::CoeffReturnType CoeffReturnType;
58     typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
59     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
60     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
61 
62     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
63         : m_xpr(xpr), m_functor(func) {}
64 
65     EIGEN_DEVICE_FUNC
66     const typename internal::remove_all<typename XprType::Nested>::type&
67     nestedExpression() const { return m_xpr; }
68 
69     EIGEN_DEVICE_FUNC
70     const NullaryOp& functor() const { return m_functor; }
71 
72   protected:
73     typename XprType::Nested m_xpr;
74     const NullaryOp m_functor;
75 };
76 
77 
78 
79 namespace internal {
80 template<typename UnaryOp, typename XprType>
81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
82     : traits<XprType>
83 {
84   // TODO(phli): Add InputScalar, InputPacket.  Check references to
85   // current Scalar/Packet to see if the intent is Input or Output.
86   typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
87   typedef traits<XprType> XprTraits;
88   typedef typename XprType::Nested XprTypeNested;
89   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
90   static const int NumDimensions = XprTraits::NumDimensions;
91   static const int Layout = XprTraits::Layout;
92   typedef typename TypeConversion<Scalar,
93                                   typename XprTraits::PointerType
94                                   >::type
95                                   PointerType;
96 };
97 
98 template<typename UnaryOp, typename XprType>
99 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
100 {
101   typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
102 };
103 
104 template<typename UnaryOp, typename XprType>
105 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
106 {
107   typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
108 };
109 
110 }  // end namespace internal
111 
112 
113 
114 template<typename UnaryOp, typename XprType>
115 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
116 {
117   public:
118     // TODO(phli): Add InputScalar, InputPacket.  Check references to
119     // current Scalar/Packet to see if the intent is Input or Output.
120     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
121     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
122     typedef Scalar CoeffReturnType;
123     typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
124     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
125     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
126 
127     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
128       : m_xpr(xpr), m_functor(func) {}
129 
130     EIGEN_DEVICE_FUNC
131     const UnaryOp& functor() const { return m_functor; }
132 
133     /** \returns the nested expression */
134     EIGEN_DEVICE_FUNC
135     const typename internal::remove_all<typename XprType::Nested>::type&
136     nestedExpression() const { return m_xpr; }
137 
138   protected:
139     typename XprType::Nested m_xpr;
140     const UnaryOp m_functor;
141 };
142 
143 
144 namespace internal {
145 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
146 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
147 {
148   // Type promotion to handle the case where the types of the lhs and the rhs
149   // are different.
150   // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
151   // current Scalar/Packet to see if the intent is Inputs or Output.
152   typedef typename result_of<
153       BinaryOp(typename LhsXprType::Scalar,
154                typename RhsXprType::Scalar)>::type Scalar;
155   typedef traits<LhsXprType> XprTraits;
156   typedef typename promote_storage_type<
157       typename traits<LhsXprType>::StorageKind,
158       typename traits<RhsXprType>::StorageKind>::ret StorageKind;
159   typedef typename promote_index_type<
160       typename traits<LhsXprType>::Index,
161       typename traits<RhsXprType>::Index>::type Index;
162   typedef typename LhsXprType::Nested LhsNested;
163   typedef typename RhsXprType::Nested RhsNested;
164   typedef typename remove_reference<LhsNested>::type _LhsNested;
165   typedef typename remove_reference<RhsNested>::type _RhsNested;
166   static const int NumDimensions = XprTraits::NumDimensions;
167   static const int Layout = XprTraits::Layout;
168   typedef typename TypeConversion<Scalar,
169                                   typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
170                                                       typename traits<LhsXprType>::PointerType,
171                                                       typename traits<RhsXprType>::PointerType>::type
172                                   >::type
173                                   PointerType;
174   enum {
175     Flags = 0
176   };
177 };
178 
179 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
180 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
181 {
182   typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
183 };
184 
185 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
186 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
187 {
188   typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
189 };
190 
191 }  // end namespace internal
192 
193 
194 
195 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
196 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
197 {
198   public:
199     // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
200     // current Scalar/Packet to see if the intent is Inputs or Output.
201     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
202     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
203     typedef Scalar CoeffReturnType;
204     typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
205     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
206     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
207 
208     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
209         : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
210 
211     EIGEN_DEVICE_FUNC
212     const BinaryOp& functor() const { return m_functor; }
213 
214     /** \returns the nested expressions */
215     EIGEN_DEVICE_FUNC
216     const typename internal::remove_all<typename LhsXprType::Nested>::type&
217     lhsExpression() const { return m_lhs_xpr; }
218 
219     EIGEN_DEVICE_FUNC
220     const typename internal::remove_all<typename RhsXprType::Nested>::type&
221     rhsExpression() const { return m_rhs_xpr; }
222 
223   protected:
224     typename LhsXprType::Nested m_lhs_xpr;
225     typename RhsXprType::Nested m_rhs_xpr;
226     const BinaryOp m_functor;
227 };
228 
229 
230 namespace internal {
231 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
232 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
233 {
234   // Type promotion to handle the case where the types of the args are different.
235   typedef typename result_of<
236       TernaryOp(typename Arg1XprType::Scalar,
237                 typename Arg2XprType::Scalar,
238                 typename Arg3XprType::Scalar)>::type Scalar;
239   typedef traits<Arg1XprType> XprTraits;
240   typedef typename traits<Arg1XprType>::StorageKind StorageKind;
241   typedef typename traits<Arg1XprType>::Index Index;
242   typedef typename Arg1XprType::Nested Arg1Nested;
243   typedef typename Arg2XprType::Nested Arg2Nested;
244   typedef typename Arg3XprType::Nested Arg3Nested;
245   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
246   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
247   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
248   static const int NumDimensions = XprTraits::NumDimensions;
249   static const int Layout = XprTraits::Layout;
250   typedef typename TypeConversion<Scalar,
251                                   typename conditional<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val,
252                                                       typename traits<Arg2XprType>::PointerType,
253                                                       typename traits<Arg3XprType>::PointerType>::type
254                                   >::type
255                                   PointerType;
256   enum {
257     Flags = 0
258   };
259 };
260 
261 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
262 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
263 {
264   typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
265 };
266 
267 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
268 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
269 {
270   typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
271 };
272 
273 }  // end namespace internal
274 
275 
276 
277 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
278 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
279 {
280   public:
281     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
282     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
283     typedef Scalar CoeffReturnType;
284     typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
285     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
286     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
287 
288     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
289         : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
290 
291     EIGEN_DEVICE_FUNC
292     const TernaryOp& functor() const { return m_functor; }
293 
294     /** \returns the nested expressions */
295     EIGEN_DEVICE_FUNC
296     const typename internal::remove_all<typename Arg1XprType::Nested>::type&
297     arg1Expression() const { return m_arg1_xpr; }
298 
299     EIGEN_DEVICE_FUNC
300     const typename internal::remove_all<typename Arg2XprType::Nested>::type&
301     arg2Expression() const { return m_arg2_xpr; }
302 
303     EIGEN_DEVICE_FUNC
304     const typename internal::remove_all<typename Arg3XprType::Nested>::type&
305     arg3Expression() const { return m_arg3_xpr; }
306 
307   protected:
308     typename Arg1XprType::Nested m_arg1_xpr;
309     typename Arg2XprType::Nested m_arg2_xpr;
310     typename Arg3XprType::Nested m_arg3_xpr;
311     const TernaryOp m_functor;
312 };
313 
314 
315 namespace internal {
316 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
317 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
318     : traits<ThenXprType>
319 {
320   typedef typename traits<ThenXprType>::Scalar Scalar;
321   typedef traits<ThenXprType> XprTraits;
322   typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
323                                         typename traits<ElseXprType>::StorageKind>::ret StorageKind;
324   typedef typename promote_index_type<typename traits<ElseXprType>::Index,
325                                       typename traits<ThenXprType>::Index>::type Index;
326   typedef typename IfXprType::Nested IfNested;
327   typedef typename ThenXprType::Nested ThenNested;
328   typedef typename ElseXprType::Nested ElseNested;
329   static const int NumDimensions = XprTraits::NumDimensions;
330   static const int Layout = XprTraits::Layout;
331   typedef typename conditional<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val,
332                                typename traits<ThenXprType>::PointerType,
333                                typename traits<ElseXprType>::PointerType>::type PointerType;
334 };
335 
336 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
337 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
338 {
339   typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
340 };
341 
342 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
343 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
344 {
345   typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
346 };
347 
348 }  // end namespace internal
349 
350 
351 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
352 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
353 {
354   public:
355     typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
356     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
357     typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
358                                                     typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
359     typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
360     typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
361     typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
362 
363     EIGEN_DEVICE_FUNC
364     TensorSelectOp(const IfXprType& a_condition,
365                    const ThenXprType& a_then,
366                    const ElseXprType& a_else)
367       : m_condition(a_condition), m_then(a_then), m_else(a_else)
368     { }
369 
370     EIGEN_DEVICE_FUNC
371     const IfXprType& ifExpression() const { return m_condition; }
372 
373     EIGEN_DEVICE_FUNC
374     const ThenXprType& thenExpression() const { return m_then; }
375 
376     EIGEN_DEVICE_FUNC
377     const ElseXprType& elseExpression() const { return m_else; }
378 
379   protected:
380     typename IfXprType::Nested m_condition;
381     typename ThenXprType::Nested m_then;
382     typename ElseXprType::Nested m_else;
383 };
384 
385 
386 } // end namespace Eigen
387 
388 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
389