1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 12 13 namespace Eigen { 14 15 /** \class TensorExpr 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor expression classes. 19 * 20 * The TensorCwiseNullaryOp class applies a nullary operators to an expression. 21 * This is typically used to generate constants. 22 * 23 * The TensorCwiseUnaryOp class represents an expression where a unary operator 24 * (e.g. cwiseSqrt) is applied to an expression. 25 * 26 * The TensorCwiseBinaryOp class represents an expression where a binary 27 * operator (e.g. addition) is applied to a lhs and a rhs expression. 28 * 29 */ 30 namespace internal { 31 template<typename NullaryOp, typename XprType> 32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > 33 : traits<XprType> 34 { 35 typedef traits<XprType> XprTraits; 36 typedef typename XprType::Scalar Scalar; 37 typedef typename XprType::Nested XprTypeNested; 38 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 39 static const int NumDimensions = XprTraits::NumDimensions; 40 static const int Layout = XprTraits::Layout; 41 typedef typename XprTraits::PointerType PointerType; 42 enum { 43 Flags = 0 44 }; 45 }; 46 47 } // end namespace internal 48 49 50 51 template<typename NullaryOp, typename XprType> 52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> 53 { 54 public: 55 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; 56 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 57 typedef typename XprType::CoeffReturnType CoeffReturnType; 58 typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested; 59 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; 60 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; 61 62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) 63 : m_xpr(xpr), m_functor(func) {} 64 65 EIGEN_DEVICE_FUNC 66 const typename internal::remove_all<typename XprType::Nested>::type& 67 nestedExpression() const { return m_xpr; } 68 69 EIGEN_DEVICE_FUNC 70 const NullaryOp& functor() const { return m_functor; } 71 72 protected: 73 typename XprType::Nested m_xpr; 74 const NullaryOp m_functor; 75 }; 76 77 78 79 namespace internal { 80 template<typename UnaryOp, typename XprType> 81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > 82 : traits<XprType> 83 { 84 // TODO(phli): Add InputScalar, InputPacket. Check references to 85 // current Scalar/Packet to see if the intent is Input or Output. 86 typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar; 87 typedef traits<XprType> XprTraits; 88 typedef typename XprType::Nested XprTypeNested; 89 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 90 static const int NumDimensions = XprTraits::NumDimensions; 91 static const int Layout = XprTraits::Layout; 92 typedef typename TypeConversion<Scalar, 93 typename XprTraits::PointerType 94 >::type 95 PointerType; 96 }; 97 98 template<typename UnaryOp, typename XprType> 99 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense> 100 { 101 typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type; 102 }; 103 104 template<typename UnaryOp, typename XprType> 105 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type> 106 { 107 typedef TensorCwiseUnaryOp<UnaryOp, XprType> type; 108 }; 109 110 } // end namespace internal 111 112 113 114 template<typename UnaryOp, typename XprType> 115 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> 116 { 117 public: 118 // TODO(phli): Add InputScalar, InputPacket. Check references to 119 // current Scalar/Packet to see if the intent is Input or Output. 120 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar; 121 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 122 typedef Scalar CoeffReturnType; 123 typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested; 124 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind; 125 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index; 126 127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp()) 128 : m_xpr(xpr), m_functor(func) {} 129 130 EIGEN_DEVICE_FUNC 131 const UnaryOp& functor() const { return m_functor; } 132 133 /** \returns the nested expression */ 134 EIGEN_DEVICE_FUNC 135 const typename internal::remove_all<typename XprType::Nested>::type& 136 nestedExpression() const { return m_xpr; } 137 138 protected: 139 typename XprType::Nested m_xpr; 140 const UnaryOp m_functor; 141 }; 142 143 144 namespace internal { 145 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 146 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > 147 { 148 // Type promotion to handle the case where the types of the lhs and the rhs 149 // are different. 150 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 151 // current Scalar/Packet to see if the intent is Inputs or Output. 152 typedef typename result_of< 153 BinaryOp(typename LhsXprType::Scalar, 154 typename RhsXprType::Scalar)>::type Scalar; 155 typedef traits<LhsXprType> XprTraits; 156 typedef typename promote_storage_type< 157 typename traits<LhsXprType>::StorageKind, 158 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 159 typedef typename promote_index_type< 160 typename traits<LhsXprType>::Index, 161 typename traits<RhsXprType>::Index>::type Index; 162 typedef typename LhsXprType::Nested LhsNested; 163 typedef typename RhsXprType::Nested RhsNested; 164 typedef typename remove_reference<LhsNested>::type _LhsNested; 165 typedef typename remove_reference<RhsNested>::type _RhsNested; 166 static const int NumDimensions = XprTraits::NumDimensions; 167 static const int Layout = XprTraits::Layout; 168 typedef typename TypeConversion<Scalar, 169 typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val, 170 typename traits<LhsXprType>::PointerType, 171 typename traits<RhsXprType>::PointerType>::type 172 >::type 173 PointerType; 174 enum { 175 Flags = 0 176 }; 177 }; 178 179 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 180 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense> 181 { 182 typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type; 183 }; 184 185 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 186 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type> 187 { 188 typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type; 189 }; 190 191 } // end namespace internal 192 193 194 195 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 196 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> 197 { 198 public: 199 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 200 // current Scalar/Packet to see if the intent is Inputs or Output. 201 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; 202 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 203 typedef Scalar CoeffReturnType; 204 typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; 205 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; 206 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; 207 208 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) 209 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} 210 211 EIGEN_DEVICE_FUNC 212 const BinaryOp& functor() const { return m_functor; } 213 214 /** \returns the nested expressions */ 215 EIGEN_DEVICE_FUNC 216 const typename internal::remove_all<typename LhsXprType::Nested>::type& 217 lhsExpression() const { return m_lhs_xpr; } 218 219 EIGEN_DEVICE_FUNC 220 const typename internal::remove_all<typename RhsXprType::Nested>::type& 221 rhsExpression() const { return m_rhs_xpr; } 222 223 protected: 224 typename LhsXprType::Nested m_lhs_xpr; 225 typename RhsXprType::Nested m_rhs_xpr; 226 const BinaryOp m_functor; 227 }; 228 229 230 namespace internal { 231 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 232 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> > 233 { 234 // Type promotion to handle the case where the types of the args are different. 235 typedef typename result_of< 236 TernaryOp(typename Arg1XprType::Scalar, 237 typename Arg2XprType::Scalar, 238 typename Arg3XprType::Scalar)>::type Scalar; 239 typedef traits<Arg1XprType> XprTraits; 240 typedef typename traits<Arg1XprType>::StorageKind StorageKind; 241 typedef typename traits<Arg1XprType>::Index Index; 242 typedef typename Arg1XprType::Nested Arg1Nested; 243 typedef typename Arg2XprType::Nested Arg2Nested; 244 typedef typename Arg3XprType::Nested Arg3Nested; 245 typedef typename remove_reference<Arg1Nested>::type _Arg1Nested; 246 typedef typename remove_reference<Arg2Nested>::type _Arg2Nested; 247 typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; 248 static const int NumDimensions = XprTraits::NumDimensions; 249 static const int Layout = XprTraits::Layout; 250 typedef typename TypeConversion<Scalar, 251 typename conditional<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val, 252 typename traits<Arg2XprType>::PointerType, 253 typename traits<Arg3XprType>::PointerType>::type 254 >::type 255 PointerType; 256 enum { 257 Flags = 0 258 }; 259 }; 260 261 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 262 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense> 263 { 264 typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type; 265 }; 266 267 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 268 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type> 269 { 270 typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type; 271 }; 272 273 } // end namespace internal 274 275 276 277 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 278 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors> 279 { 280 public: 281 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar; 282 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 283 typedef Scalar CoeffReturnType; 284 typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested; 285 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind; 286 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index; 287 288 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) 289 : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} 290 291 EIGEN_DEVICE_FUNC 292 const TernaryOp& functor() const { return m_functor; } 293 294 /** \returns the nested expressions */ 295 EIGEN_DEVICE_FUNC 296 const typename internal::remove_all<typename Arg1XprType::Nested>::type& 297 arg1Expression() const { return m_arg1_xpr; } 298 299 EIGEN_DEVICE_FUNC 300 const typename internal::remove_all<typename Arg2XprType::Nested>::type& 301 arg2Expression() const { return m_arg2_xpr; } 302 303 EIGEN_DEVICE_FUNC 304 const typename internal::remove_all<typename Arg3XprType::Nested>::type& 305 arg3Expression() const { return m_arg3_xpr; } 306 307 protected: 308 typename Arg1XprType::Nested m_arg1_xpr; 309 typename Arg2XprType::Nested m_arg2_xpr; 310 typename Arg3XprType::Nested m_arg3_xpr; 311 const TernaryOp m_functor; 312 }; 313 314 315 namespace internal { 316 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 317 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > 318 : traits<ThenXprType> 319 { 320 typedef typename traits<ThenXprType>::Scalar Scalar; 321 typedef traits<ThenXprType> XprTraits; 322 typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, 323 typename traits<ElseXprType>::StorageKind>::ret StorageKind; 324 typedef typename promote_index_type<typename traits<ElseXprType>::Index, 325 typename traits<ThenXprType>::Index>::type Index; 326 typedef typename IfXprType::Nested IfNested; 327 typedef typename ThenXprType::Nested ThenNested; 328 typedef typename ElseXprType::Nested ElseNested; 329 static const int NumDimensions = XprTraits::NumDimensions; 330 static const int Layout = XprTraits::Layout; 331 typedef typename conditional<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val, 332 typename traits<ThenXprType>::PointerType, 333 typename traits<ElseXprType>::PointerType>::type PointerType; 334 }; 335 336 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 337 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> 338 { 339 typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type; 340 }; 341 342 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 343 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> 344 { 345 typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type; 346 }; 347 348 } // end namespace internal 349 350 351 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 352 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors> 353 { 354 public: 355 typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; 356 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 357 typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType, 358 typename ElseXprType::CoeffReturnType>::ret CoeffReturnType; 359 typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested; 360 typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind; 361 typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index; 362 363 EIGEN_DEVICE_FUNC 364 TensorSelectOp(const IfXprType& a_condition, 365 const ThenXprType& a_then, 366 const ElseXprType& a_else) 367 : m_condition(a_condition), m_then(a_then), m_else(a_else) 368 { } 369 370 EIGEN_DEVICE_FUNC 371 const IfXprType& ifExpression() const { return m_condition; } 372 373 EIGEN_DEVICE_FUNC 374 const ThenXprType& thenExpression() const { return m_then; } 375 376 EIGEN_DEVICE_FUNC 377 const ElseXprType& elseExpression() const { return m_else; } 378 379 protected: 380 typename IfXprType::Nested m_condition; 381 typename ThenXprType::Nested m_then; 382 typename ElseXprType::Nested m_else; 383 }; 384 385 386 } // end namespace Eigen 387 388 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 389