1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 12 13 namespace Eigen { 14 15 /** \class TensorCustomUnaryOp 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor custom class. 19 * 20 * 21 */ 22 namespace internal { 23 template<typename CustomUnaryFunc, typename XprType> 24 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > 25 { 26 typedef typename XprType::Scalar Scalar; 27 typedef typename XprType::StorageKind StorageKind; 28 typedef typename XprType::Index Index; 29 typedef typename XprType::Nested Nested; 30 typedef typename remove_reference<Nested>::type _Nested; 31 static const int NumDimensions = traits<XprType>::NumDimensions; 32 static const int Layout = traits<XprType>::Layout; 33 }; 34 35 template<typename CustomUnaryFunc, typename XprType> 36 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense> 37 { 38 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type; 39 }; 40 41 template<typename CustomUnaryFunc, typename XprType> 42 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > 43 { 44 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type; 45 }; 46 47 } // end namespace internal 48 49 50 51 template<typename CustomUnaryFunc, typename XprType> 52 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> 53 { 54 public: 55 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar; 56 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 57 typedef typename XprType::CoeffReturnType CoeffReturnType; 58 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested; 59 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind; 60 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index; 61 62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func) 63 : m_expr(expr), m_func(func) {} 64 65 EIGEN_DEVICE_FUNC 66 const CustomUnaryFunc& func() const { return m_func; } 67 68 EIGEN_DEVICE_FUNC 69 const typename internal::remove_all<typename XprType::Nested>::type& 70 expression() const { return m_expr; } 71 72 protected: 73 typename XprType::Nested m_expr; 74 const CustomUnaryFunc m_func; 75 }; 76 77 78 // Eval as rvalue 79 template<typename CustomUnaryFunc, typename XprType, typename Device> 80 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device> 81 { 82 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType; 83 typedef typename internal::traits<ArgType>::Index Index; 84 static const int NumDims = internal::traits<ArgType>::NumDimensions; 85 typedef DSizes<Index, NumDims> Dimensions; 86 typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar; 87 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; 88 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 89 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; 90 91 enum { 92 IsAligned = false, 93 PacketAccess = (internal::packet_traits<Scalar>::size > 1), 94 BlockAccess = false, 95 Layout = TensorEvaluator<XprType, Device>::Layout, 96 CoordAccess = false, // to be implemented 97 RawAccess = false 98 }; 99 100 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device) 101 : m_op(op), m_device(device), m_result(NULL) 102 { 103 m_dimensions = op.func().dimensions(op.expression()); 104 } 105 106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 107 108 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { 109 if (data) { 110 evalTo(data); 111 return false; 112 } else { 113 m_result = static_cast<CoeffReturnType*>( 114 m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 115 evalTo(m_result); 116 return true; 117 } 118 } 119 120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 121 if (m_result != NULL) { 122 m_device.deallocate(m_result); 123 m_result = NULL; 124 } 125 } 126 127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 128 return m_result[index]; 129 } 130 131 template<int LoadMode> 132 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { 133 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 134 } 135 136 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { 137 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. 138 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); 139 } 140 141 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } 142 143 protected: 144 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { 145 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result( 146 data, m_dimensions); 147 m_op.func().eval(m_op.expression(), result, m_device); 148 } 149 150 Dimensions m_dimensions; 151 const ArgType m_op; 152 const Device& m_device; 153 CoeffReturnType* m_result; 154 }; 155 156 157 158 /** \class TensorCustomBinaryOp 159 * \ingroup CXX11_Tensor_Module 160 * 161 * \brief Tensor custom class. 162 * 163 * 164 */ 165 namespace internal { 166 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 167 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > 168 { 169 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, 170 typename RhsXprType::Scalar>::ret Scalar; 171 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, 172 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; 173 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, 174 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 175 typedef typename promote_index_type<typename traits<LhsXprType>::Index, 176 typename traits<RhsXprType>::Index>::type Index; 177 typedef typename LhsXprType::Nested LhsNested; 178 typedef typename RhsXprType::Nested RhsNested; 179 typedef typename remove_reference<LhsNested>::type _LhsNested; 180 typedef typename remove_reference<RhsNested>::type _RhsNested; 181 static const int NumDimensions = traits<LhsXprType>::NumDimensions; 182 static const int Layout = traits<LhsXprType>::Layout; 183 }; 184 185 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 186 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense> 187 { 188 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type; 189 }; 190 191 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 192 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > 193 { 194 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type; 195 }; 196 197 } // end namespace internal 198 199 200 201 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 202 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors> 203 { 204 public: 205 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar; 206 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 207 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType; 208 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested; 209 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind; 210 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index; 211 212 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func) 213 214 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {} 215 216 EIGEN_DEVICE_FUNC 217 const CustomBinaryFunc& func() const { return m_func; } 218 219 EIGEN_DEVICE_FUNC 220 const typename internal::remove_all<typename LhsXprType::Nested>::type& 221 lhsExpression() const { return m_lhs_xpr; } 222 223 EIGEN_DEVICE_FUNC 224 const typename internal::remove_all<typename RhsXprType::Nested>::type& 225 rhsExpression() const { return m_rhs_xpr; } 226 227 protected: 228 typename LhsXprType::Nested m_lhs_xpr; 229 typename RhsXprType::Nested m_rhs_xpr; 230 const CustomBinaryFunc m_func; 231 }; 232 233 234 // Eval as rvalue 235 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device> 236 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device> 237 { 238 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType; 239 typedef typename internal::traits<XprType>::Index Index; 240 static const int NumDims = internal::traits<XprType>::NumDimensions; 241 typedef DSizes<Index, NumDims> Dimensions; 242 typedef typename XprType::Scalar Scalar; 243 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; 244 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 245 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; 246 247 enum { 248 IsAligned = false, 249 PacketAccess = (internal::packet_traits<Scalar>::size > 1), 250 BlockAccess = false, 251 Layout = TensorEvaluator<LhsXprType, Device>::Layout, 252 CoordAccess = false, // to be implemented 253 RawAccess = false 254 }; 255 256 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) 257 : m_op(op), m_device(device), m_result(NULL) 258 { 259 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression()); 260 } 261 262 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 263 264 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { 265 if (data) { 266 evalTo(data); 267 return false; 268 } else { 269 m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 270 evalTo(m_result); 271 return true; 272 } 273 } 274 275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 276 if (m_result != NULL) { 277 m_device.deallocate(m_result); 278 m_result = NULL; 279 } 280 } 281 282 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 283 return m_result[index]; 284 } 285 286 template<int LoadMode> 287 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { 288 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 289 } 290 291 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { 292 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. 293 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); 294 } 295 296 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } 297 298 protected: 299 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { 300 TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions); 301 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device); 302 } 303 304 Dimensions m_dimensions; 305 const XprType m_op; 306 const Device& m_device; 307 CoeffReturnType* m_result; 308 }; 309 310 311 } // end namespace Eigen 312 313 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 314