// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
//                    Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

namespace Eigen {
namespace internal {

/** \class TensorIndexTupleOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor + Index Tuple class.
  */
template<typename XprType>
struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
{
  typedef const TensorIndexTupleOp<XprType> EIGEN_DEVICE_REF type;
};

template<typename XprType>
struct nested<TensorIndexTupleOp<XprType>, 1,
              typename eval<TensorIndexTupleOp<XprType> >::type>
{
  typedef TensorIndexTupleOp<XprType> type;
};

}  // end namespace internal

template<typename XprType>
class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
  typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
      : m_xpr(expr) {}

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
};

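// Illustrative sketch (not part of this header): user code normally reaches this
// expression through TensorBase::index_tuples(), which pairs every coefficient
// with its flat index. The tensor `t` below is a hypothetical example, and the
// exact expression type of the result is elided:
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   auto tuples = t.index_tuples();  // coefficient i evaluates to
//                                    // Tuple<Index, float>{i, value at flat index i}
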
// Eval as rvalue
template<typename ArgType, typename Device>
struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
{
  typedef TensorIndexTupleOp<ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;

  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
  static const int NumDims = internal::array_size<Dimensions>::value;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device) { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    return CoeffReturnType(index, m_impl.coeff(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

#ifdef EIGEN_USE_SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_impl.bind(cgh);
  }
#endif

 protected:
  TensorEvaluator<ArgType, Device> m_impl;
};

namespace internal {

/** \class TensorTupleReducerOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>.
  */
template<typename ReduceOp, typename Dims, typename XprType>
struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Index Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
  static const int Layout = XprTraits::Layout;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType> EIGEN_DEVICE_REF type;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
              typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
};

}  // end namespace internal

template<typename ReduceOp, typename Dims, typename XprType>
class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
  typedef Index CoeffReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
                                                             const ReduceOp& reduce_op,
                                                             const Index return_dim,
                                                             const Dims& reduce_dims)
      : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

  EIGEN_DEVICE_FUNC
  const ReduceOp& reduce_op() const { return m_reduce_op; }

  EIGEN_DEVICE_FUNC
  const Dims& reduce_dims() const { return m_reduce_dims; }

  EIGEN_DEVICE_FUNC
  Index return_dim() const { return m_return_dim; }

  protected:
    typename XprType::Nested m_xpr;
    const ReduceOp m_reduce_op;
    const Index m_return_dim;
    const Dims m_reduce_dims;
};

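// Illustrative sketch (assuming the argmax()/argmin() helpers declared on
// TensorBase elsewhere in this module, which wrap the expression in a
// TensorTupleReducerOp). The tensor `t` below is a hypothetical example:
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   // Coordinate of the maximum along dimension 1; that dimension is reduced away.
//   Eigen::Tensor<Eigen::DenseIndex, 1> max_cols = t.argmax(1);
//   // Flat index of the global maximum; the result has rank 0.
//   Eigen::Tensor<Eigen::DenseIndex, 0> max_flat = t.argmax();
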
// Eval as rvalue
template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>::Dimensions InputDimensions;
  static const int NumDims = internal::array_size<InputDimensions>::value;
  typedef array<Index, NumDims> StrideDims;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  typedef StorageMemory<TupleType, Device> TupleStorageMem;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
    Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_orig_impl(op.expression(), device),
        m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
        m_return_dim(op.return_dim())
  {
    gen_strides(m_orig_impl.dimensions(), m_strides);
    if (Layout == static_cast<int>(ColMajor)) {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
    } else {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
    }
    // If m_return_dim is not a valid index, fall back to a stride of 1;
    // otherwise the out-of-range access into m_strides can crash on Windows.
    m_stride_div = ((m_return_dim >= 0) &&
                    (m_return_dim < static_cast<Index>(m_strides.size())))
                   ? m_strides[m_return_dim] : 1;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

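  // coeff() below recovers the coordinate along m_return_dim from the flat
  // index stored in the reduced tuple: coord = (flat % m_stride_mod) / m_stride_div.
  // Worked example (illustrative numbers only): for a ColMajor 3x4 tensor and
  // m_return_dim == 1, the strides are {1, 3}, so m_stride_div == 3 and
  // m_stride_mod == total_size == 12; a flat index of 7 (row 1, column 2)
  // yields (7 % 12) / 3 == 2, i.e. the column of the winning coefficient.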
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    const TupleType v = m_impl.coeff(index);
    return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

#ifdef EIGEN_USE_SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_impl.bind(cgh);
    m_orig_impl.bind(cgh);
  }
#endif

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = 1.0 +
        (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
    return m_orig_impl.costPerCoeff(vectorized) +
           m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
  }

 private:
  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
    if (m_return_dim < 0) {
      return;  // Won't be using the strides.
    }
    eigen_assert(m_return_dim < NumDims &&
                 "Asking to convert index to a dimension outside of the rank");

    // Calculate m_stride_div and m_stride_mod, which are used to
    // calculate the value of an index w.r.t. the m_return_dim.
    if (Layout == static_cast<int>(ColMajor)) {
      strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        strides[i] = strides[i-1] * dims[i-1];
      }
    } else {
      strides[NumDims-1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        strides[i] = strides[i+1] * dims[i+1];
      }
    }
  }

 protected:
  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
  const Index m_return_dim;
  StrideDims m_strides;
  Index m_stride_mod;
  Index m_stride_div;
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H