• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12 
13 namespace Eigen {
14 
15 /** \class TensorExpr
16   * \ingroup CXX11_Tensor_Module
17   *
18   * \brief Tensor expression classes.
19   *
  * The TensorCwiseNullaryOp class applies a nullary operator to an expression.
21   * This is typically used to generate constants.
22   *
23   * The TensorCwiseUnaryOp class represents an expression where a unary operator
24   * (e.g. cwiseSqrt) is applied to an expression.
25   *
26   * The TensorCwiseBinaryOp class represents an expression where a binary
27   * operator (e.g. addition) is applied to a lhs and a rhs expression.
28   *
29   */
30 namespace internal {
31 template<typename NullaryOp, typename XprType>
32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
33     : traits<XprType>
34 {
35   typedef traits<XprType> XprTraits;
36   typedef typename XprType::Scalar Scalar;
37   typedef typename XprType::Nested XprTypeNested;
38   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
39   static const int NumDimensions = XprTraits::NumDimensions;
40   static const int Layout = XprTraits::Layout;
41 
42   enum {
43     Flags = 0
44   };
45 };
46 
47 }  // end namespace internal
48 
49 
50 
51 template<typename NullaryOp, typename XprType>
52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
53 {
54   public:
55     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
56     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57     typedef typename XprType::CoeffReturnType CoeffReturnType;
58     typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
59     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
60     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
61 
62     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
63         : m_xpr(xpr), m_functor(func) {}
64 
65     EIGEN_DEVICE_FUNC
66     const typename internal::remove_all<typename XprType::Nested>::type&
67     nestedExpression() const { return m_xpr; }
68 
69     EIGEN_DEVICE_FUNC
70     const NullaryOp& functor() const { return m_functor; }
71 
72   protected:
73     typename XprType::Nested m_xpr;
74     const NullaryOp m_functor;
75 };
76 
77 
78 
79 namespace internal {
80 template<typename UnaryOp, typename XprType>
81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
82     : traits<XprType>
83 {
84   // TODO(phli): Add InputScalar, InputPacket.  Check references to
85   // current Scalar/Packet to see if the intent is Input or Output.
86   typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
87   typedef traits<XprType> XprTraits;
88   typedef typename XprType::Nested XprTypeNested;
89   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
90   static const int NumDimensions = XprTraits::NumDimensions;
91   static const int Layout = XprTraits::Layout;
92 };
93 
94 template<typename UnaryOp, typename XprType>
95 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
96 {
97   typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
98 };
99 
100 template<typename UnaryOp, typename XprType>
101 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
102 {
103   typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
104 };
105 
106 }  // end namespace internal
107 
108 
109 
110 template<typename UnaryOp, typename XprType>
111 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
112 {
113   public:
114     // TODO(phli): Add InputScalar, InputPacket.  Check references to
115     // current Scalar/Packet to see if the intent is Input or Output.
116     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
117     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
118     typedef Scalar CoeffReturnType;
119     typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
120     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
121     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
122 
123     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
124       : m_xpr(xpr), m_functor(func) {}
125 
126     EIGEN_DEVICE_FUNC
127     const UnaryOp& functor() const { return m_functor; }
128 
129     /** \returns the nested expression */
130     EIGEN_DEVICE_FUNC
131     const typename internal::remove_all<typename XprType::Nested>::type&
132     nestedExpression() const { return m_xpr; }
133 
134   protected:
135     typename XprType::Nested m_xpr;
136     const UnaryOp m_functor;
137 };
138 
139 
140 namespace internal {
141 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
142 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
143 {
144   // Type promotion to handle the case where the types of the lhs and the rhs
145   // are different.
146   // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
147   // current Scalar/Packet to see if the intent is Inputs or Output.
148   typedef typename result_of<
149       BinaryOp(typename LhsXprType::Scalar,
150                typename RhsXprType::Scalar)>::type Scalar;
151   typedef traits<LhsXprType> XprTraits;
152   typedef typename promote_storage_type<
153       typename traits<LhsXprType>::StorageKind,
154       typename traits<RhsXprType>::StorageKind>::ret StorageKind;
155   typedef typename promote_index_type<
156       typename traits<LhsXprType>::Index,
157       typename traits<RhsXprType>::Index>::type Index;
158   typedef typename LhsXprType::Nested LhsNested;
159   typedef typename RhsXprType::Nested RhsNested;
160   typedef typename remove_reference<LhsNested>::type _LhsNested;
161   typedef typename remove_reference<RhsNested>::type _RhsNested;
162   static const int NumDimensions = XprTraits::NumDimensions;
163   static const int Layout = XprTraits::Layout;
164 
165   enum {
166     Flags = 0
167   };
168 };
169 
170 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
171 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
172 {
173   typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
174 };
175 
176 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
177 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
178 {
179   typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
180 };
181 
182 }  // end namespace internal
183 
184 
185 
186 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
187 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
188 {
189   public:
190     // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
191     // current Scalar/Packet to see if the intent is Inputs or Output.
192     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
193     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
194     typedef Scalar CoeffReturnType;
195     typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
196     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
197     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
198 
199     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
200         : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
201 
202     EIGEN_DEVICE_FUNC
203     const BinaryOp& functor() const { return m_functor; }
204 
205     /** \returns the nested expressions */
206     EIGEN_DEVICE_FUNC
207     const typename internal::remove_all<typename LhsXprType::Nested>::type&
208     lhsExpression() const { return m_lhs_xpr; }
209 
210     EIGEN_DEVICE_FUNC
211     const typename internal::remove_all<typename RhsXprType::Nested>::type&
212     rhsExpression() const { return m_rhs_xpr; }
213 
214   protected:
215     typename LhsXprType::Nested m_lhs_xpr;
216     typename RhsXprType::Nested m_rhs_xpr;
217     const BinaryOp m_functor;
218 };
219 
220 
221 namespace internal {
222 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
223 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
224 {
225   // Type promotion to handle the case where the types of the args are different.
226   typedef typename result_of<
227       TernaryOp(typename Arg1XprType::Scalar,
228                 typename Arg2XprType::Scalar,
229                 typename Arg3XprType::Scalar)>::type Scalar;
230   typedef traits<Arg1XprType> XprTraits;
231   typedef typename traits<Arg1XprType>::StorageKind StorageKind;
232   typedef typename traits<Arg1XprType>::Index Index;
233   typedef typename Arg1XprType::Nested Arg1Nested;
234   typedef typename Arg2XprType::Nested Arg2Nested;
235   typedef typename Arg3XprType::Nested Arg3Nested;
236   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
237   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
238   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
239   static const int NumDimensions = XprTraits::NumDimensions;
240   static const int Layout = XprTraits::Layout;
241 
242   enum {
243     Flags = 0
244   };
245 };
246 
247 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
248 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
249 {
250   typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
251 };
252 
253 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
254 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
255 {
256   typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
257 };
258 
259 }  // end namespace internal
260 
261 
262 
263 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
264 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
265 {
266   public:
267     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
268     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
269     typedef Scalar CoeffReturnType;
270     typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
271     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
272     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
273 
274     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
275         : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
276 
277     EIGEN_DEVICE_FUNC
278     const TernaryOp& functor() const { return m_functor; }
279 
280     /** \returns the nested expressions */
281     EIGEN_DEVICE_FUNC
282     const typename internal::remove_all<typename Arg1XprType::Nested>::type&
283     arg1Expression() const { return m_arg1_xpr; }
284 
285     EIGEN_DEVICE_FUNC
286     const typename internal::remove_all<typename Arg2XprType::Nested>::type&
287     arg2Expression() const { return m_arg2_xpr; }
288 
289     EIGEN_DEVICE_FUNC
290     const typename internal::remove_all<typename Arg3XprType::Nested>::type&
291     arg3Expression() const { return m_arg3_xpr; }
292 
293   protected:
294     typename Arg1XprType::Nested m_arg1_xpr;
295     typename Arg2XprType::Nested m_arg2_xpr;
296     typename Arg3XprType::Nested m_arg3_xpr;
297     const TernaryOp m_functor;
298 };
299 
300 
301 namespace internal {
302 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
303 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
304     : traits<ThenXprType>
305 {
306   typedef typename traits<ThenXprType>::Scalar Scalar;
307   typedef traits<ThenXprType> XprTraits;
308   typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
309                                         typename traits<ElseXprType>::StorageKind>::ret StorageKind;
310   typedef typename promote_index_type<typename traits<ElseXprType>::Index,
311                                       typename traits<ThenXprType>::Index>::type Index;
312   typedef typename IfXprType::Nested IfNested;
313   typedef typename ThenXprType::Nested ThenNested;
314   typedef typename ElseXprType::Nested ElseNested;
315   static const int NumDimensions = XprTraits::NumDimensions;
316   static const int Layout = XprTraits::Layout;
317 };
318 
319 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
320 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
321 {
322   typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
323 };
324 
325 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
326 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
327 {
328   typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
329 };
330 
331 }  // end namespace internal
332 
333 
334 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
335 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
336 {
337   public:
338     typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
339     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
340     typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
341                                                     typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
342     typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
343     typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
344     typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
345 
346     EIGEN_DEVICE_FUNC
347     TensorSelectOp(const IfXprType& a_condition,
348                    const ThenXprType& a_then,
349                    const ElseXprType& a_else)
350       : m_condition(a_condition), m_then(a_then), m_else(a_else)
351     { }
352 
353     EIGEN_DEVICE_FUNC
354     const IfXprType& ifExpression() const { return m_condition; }
355 
356     EIGEN_DEVICE_FUNC
357     const ThenXprType& thenExpression() const { return m_then; }
358 
359     EIGEN_DEVICE_FUNC
360     const ElseXprType& elseExpression() const { return m_else; }
361 
362   protected:
363     typename IfXprType::Nested m_condition;
364     typename ThenXprType::Nested m_then;
365     typename ElseXprType::Nested m_else;
366 };
367 
368 
369 } // end namespace Eigen
370 
371 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
372