1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 12 13 namespace Eigen { 14 15 /** \class TensorExpr 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor expression classes. 19 * 20 * The TensorCwiseNullaryOp class applies a nullary operators to an expression. 21 * This is typically used to generate constants. 22 * 23 * The TensorCwiseUnaryOp class represents an expression where a unary operator 24 * (e.g. cwiseSqrt) is applied to an expression. 25 * 26 * The TensorCwiseBinaryOp class represents an expression where a binary 27 * operator (e.g. addition) is applied to a lhs and a rhs expression. 28 * 29 */ 30 namespace internal { 31 template<typename NullaryOp, typename XprType> 32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > 33 : traits<XprType> 34 { 35 typedef traits<XprType> XprTraits; 36 typedef typename XprType::Scalar Scalar; 37 typedef typename XprType::Nested XprTypeNested; 38 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 39 static const int NumDimensions = XprTraits::NumDimensions; 40 static const int Layout = XprTraits::Layout; 41 42 enum { 43 Flags = 0 44 }; 45 }; 46 47 } // end namespace internal 48 49 50 51 template<typename NullaryOp, typename XprType> 52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> 53 { 54 public: 55 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; 56 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 57 typedef typename XprType::CoeffReturnType CoeffReturnType; 58 typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested; 59 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; 60 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; 61 62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) 63 : m_xpr(xpr), m_functor(func) {} 64 65 EIGEN_DEVICE_FUNC 66 const typename internal::remove_all<typename XprType::Nested>::type& 67 nestedExpression() const { return m_xpr; } 68 69 EIGEN_DEVICE_FUNC 70 const NullaryOp& functor() const { return m_functor; } 71 72 protected: 73 typename XprType::Nested m_xpr; 74 const NullaryOp m_functor; 75 }; 76 77 78 79 namespace internal { 80 template<typename UnaryOp, typename XprType> 81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > 82 : traits<XprType> 83 { 84 // TODO(phli): Add InputScalar, InputPacket. Check references to 85 // current Scalar/Packet to see if the intent is Input or Output. 86 typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar; 87 typedef traits<XprType> XprTraits; 88 typedef typename XprType::Nested XprTypeNested; 89 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 90 static const int NumDimensions = XprTraits::NumDimensions; 91 static const int Layout = XprTraits::Layout; 92 }; 93 94 template<typename UnaryOp, typename XprType> 95 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense> 96 { 97 typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type; 98 }; 99 100 template<typename UnaryOp, typename XprType> 101 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type> 102 { 103 typedef TensorCwiseUnaryOp<UnaryOp, XprType> type; 104 }; 105 106 } // end namespace internal 107 108 109 110 template<typename UnaryOp, typename XprType> 111 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> 112 { 113 public: 114 // TODO(phli): Add InputScalar, InputPacket. Check references to 115 // current Scalar/Packet to see if the intent is Input or Output. 116 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar; 117 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 118 typedef Scalar CoeffReturnType; 119 typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested; 120 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind; 121 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index; 122 123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp()) 124 : m_xpr(xpr), m_functor(func) {} 125 126 EIGEN_DEVICE_FUNC 127 const UnaryOp& functor() const { return m_functor; } 128 129 /** \returns the nested expression */ 130 EIGEN_DEVICE_FUNC 131 const typename internal::remove_all<typename XprType::Nested>::type& 132 nestedExpression() const { return m_xpr; } 133 134 protected: 135 typename XprType::Nested m_xpr; 136 const UnaryOp m_functor; 137 }; 138 139 140 namespace internal { 141 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 142 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > 143 { 144 // Type promotion to handle the case where the types of the lhs and the rhs 145 // are different. 146 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 147 // current Scalar/Packet to see if the intent is Inputs or Output. 148 typedef typename result_of< 149 BinaryOp(typename LhsXprType::Scalar, 150 typename RhsXprType::Scalar)>::type Scalar; 151 typedef traits<LhsXprType> XprTraits; 152 typedef typename promote_storage_type< 153 typename traits<LhsXprType>::StorageKind, 154 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 155 typedef typename promote_index_type< 156 typename traits<LhsXprType>::Index, 157 typename traits<RhsXprType>::Index>::type Index; 158 typedef typename LhsXprType::Nested LhsNested; 159 typedef typename RhsXprType::Nested RhsNested; 160 typedef typename remove_reference<LhsNested>::type _LhsNested; 161 typedef typename remove_reference<RhsNested>::type _RhsNested; 162 static const int NumDimensions = XprTraits::NumDimensions; 163 static const int Layout = XprTraits::Layout; 164 165 enum { 166 Flags = 0 167 }; 168 }; 169 170 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 171 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense> 172 { 173 typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type; 174 }; 175 176 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 177 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type> 178 { 179 typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type; 180 }; 181 182 } // end namespace internal 183 184 185 186 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 187 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> 188 { 189 public: 190 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 191 // current Scalar/Packet to see if the intent is Inputs or Output. 192 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; 193 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 194 typedef Scalar CoeffReturnType; 195 typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; 196 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; 197 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; 198 199 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) 200 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} 201 202 EIGEN_DEVICE_FUNC 203 const BinaryOp& functor() const { return m_functor; } 204 205 /** \returns the nested expressions */ 206 EIGEN_DEVICE_FUNC 207 const typename internal::remove_all<typename LhsXprType::Nested>::type& 208 lhsExpression() const { return m_lhs_xpr; } 209 210 EIGEN_DEVICE_FUNC 211 const typename internal::remove_all<typename RhsXprType::Nested>::type& 212 rhsExpression() const { return m_rhs_xpr; } 213 214 protected: 215 typename LhsXprType::Nested m_lhs_xpr; 216 typename RhsXprType::Nested m_rhs_xpr; 217 const BinaryOp m_functor; 218 }; 219 220 221 namespace internal { 222 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 223 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> > 224 { 225 // Type promotion to handle the case where the types of the args are different. 226 typedef typename result_of< 227 TernaryOp(typename Arg1XprType::Scalar, 228 typename Arg2XprType::Scalar, 229 typename Arg3XprType::Scalar)>::type Scalar; 230 typedef traits<Arg1XprType> XprTraits; 231 typedef typename traits<Arg1XprType>::StorageKind StorageKind; 232 typedef typename traits<Arg1XprType>::Index Index; 233 typedef typename Arg1XprType::Nested Arg1Nested; 234 typedef typename Arg2XprType::Nested Arg2Nested; 235 typedef typename Arg3XprType::Nested Arg3Nested; 236 typedef typename remove_reference<Arg1Nested>::type _Arg1Nested; 237 typedef typename remove_reference<Arg2Nested>::type _Arg2Nested; 238 typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; 239 static const int NumDimensions = XprTraits::NumDimensions; 240 static const int Layout = XprTraits::Layout; 241 242 enum { 243 Flags = 0 244 }; 245 }; 246 247 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 248 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense> 249 { 250 typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type; 251 }; 252 253 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 254 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type> 255 { 256 typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type; 257 }; 258 259 } // end namespace internal 260 261 262 263 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 264 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors> 265 { 266 public: 267 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar; 268 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 269 typedef Scalar CoeffReturnType; 270 typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested; 271 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind; 272 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index; 273 274 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) 275 : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} 276 277 EIGEN_DEVICE_FUNC 278 const TernaryOp& functor() const { return m_functor; } 279 280 /** \returns the nested expressions */ 281 EIGEN_DEVICE_FUNC 282 const typename internal::remove_all<typename Arg1XprType::Nested>::type& 283 arg1Expression() const { return m_arg1_xpr; } 284 285 EIGEN_DEVICE_FUNC 286 const typename internal::remove_all<typename Arg2XprType::Nested>::type& 287 arg2Expression() const { return m_arg2_xpr; } 288 289 EIGEN_DEVICE_FUNC 290 const typename internal::remove_all<typename Arg3XprType::Nested>::type& 291 arg3Expression() const { return m_arg3_xpr; } 292 293 protected: 294 typename Arg1XprType::Nested m_arg1_xpr; 295 typename Arg2XprType::Nested m_arg2_xpr; 296 typename Arg3XprType::Nested m_arg3_xpr; 297 const TernaryOp m_functor; 298 }; 299 300 301 namespace internal { 302 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 303 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > 304 : traits<ThenXprType> 305 { 306 typedef typename traits<ThenXprType>::Scalar Scalar; 307 typedef traits<ThenXprType> XprTraits; 308 typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, 309 typename traits<ElseXprType>::StorageKind>::ret StorageKind; 310 typedef typename promote_index_type<typename traits<ElseXprType>::Index, 311 typename traits<ThenXprType>::Index>::type Index; 312 typedef typename IfXprType::Nested IfNested; 313 typedef typename ThenXprType::Nested ThenNested; 314 typedef typename ElseXprType::Nested ElseNested; 315 static const int NumDimensions = XprTraits::NumDimensions; 316 static const int Layout = XprTraits::Layout; 317 }; 318 319 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 320 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> 321 { 322 typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type; 323 }; 324 325 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 326 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> 327 { 328 typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type; 329 }; 330 331 } // end namespace internal 332 333 334 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 335 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors> 336 { 337 public: 338 typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; 339 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 340 typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType, 341 typename ElseXprType::CoeffReturnType>::ret CoeffReturnType; 342 typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested; 343 typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind; 344 typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index; 345 346 EIGEN_DEVICE_FUNC 347 TensorSelectOp(const IfXprType& a_condition, 348 const ThenXprType& a_then, 349 const ElseXprType& a_else) 350 : m_condition(a_condition), m_then(a_then), m_else(a_else) 351 { } 352 353 EIGEN_DEVICE_FUNC 354 const IfXprType& ifExpression() const { return m_condition; } 355 356 EIGEN_DEVICE_FUNC 357 const ThenXprType& thenExpression() const { return m_then; } 358 359 EIGEN_DEVICE_FUNC 360 const ElseXprType& elseExpression() const { return m_else; } 361 362 protected: 363 typename IfXprType::Nested m_condition; 364 typename ThenXprType::Nested m_then; 365 typename ElseXprType::Nested m_else; 366 }; 367 368 369 } // end namespace Eigen 370 371 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 372