// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
//                    Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

namespace Eigen {
namespace internal {

/** \class TensorIndexTupleOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor + Index Tuple class: pairs each coefficient of the input
  * tensor with its flat index as a Tuple<Index, Scalar>.
  *
  */
template<typename XprType>
struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
{
  typedef const TensorIndexTupleOp<XprType>& type;
};

template<typename XprType>
struct nested<TensorIndexTupleOp<XprType>, 1,
              typename eval<TensorIndexTupleOp<XprType> >::type>
{
  typedef TensorIndexTupleOp<XprType> type;
};

}  // end namespace internal

template<typename XprType>
class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
    typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
        : m_xpr(expr) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
};
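
// Usage sketch (illustrative only): TensorIndexTupleOp is the building block
// behind argmax()/argmin(). It is normally obtained through
// TensorBase::index_tuples() (used by the reducer evaluator further down),
// which pairs every coefficient with its flat index. Assuming a float tensor:
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   // Each coefficient of `pairs` is a Tuple<Index, float>{flat_index, value}.
//   auto pairs = t.index_tuples();
//
// The evaluator below produces exactly these (index, value) pairs.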

// Eval as rvalue
template<typename ArgType, typename Device>
struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
{
  typedef TensorIndexTupleOp<ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;

  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
  static const int NumDims = internal::array_size<Dimensions>::value;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device) { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    return CoeffReturnType(index, m_impl.coeff(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  TensorEvaluator<ArgType, Device> m_impl;
};

namespace internal {

/** \class TensorTupleReducerOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>.
  *
  */
template<typename ReduceOp, typename Dims, typename XprType>
struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Index Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
  static const int Layout = XprTraits::Layout;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
              typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
};

}  // end namespace internal

template<typename ReduceOp, typename Dims, typename XprType>
class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
    typedef Index CoeffReturnType;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
                                                               const ReduceOp& reduce_op,
                                                               const int return_dim,
                                                               const Dims& reduce_dims)
        : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

    EIGEN_DEVICE_FUNC
    const ReduceOp& reduce_op() const { return m_reduce_op; }

    EIGEN_DEVICE_FUNC
    const Dims& reduce_dims() const { return m_reduce_dims; }

    EIGEN_DEVICE_FUNC
    int return_dim() const { return m_return_dim; }

  protected:
    typename XprType::Nested m_xpr;
    const ReduceOp m_reduce_op;
    const int m_return_dim;
    const Dims m_reduce_dims;
};
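
// Usage sketch (illustrative only, assuming the argmax()/argmin() helpers
// declared in TensorBase): a TensorTupleReducerOp is normally created
// indirectly rather than instantiated by hand, e.g.
//
//   Eigen::Tensor<float, 3> t(2, 3, 4);
//   t.setRandom();
//   Eigen::Tensor<Eigen::DenseIndex, 0> flat_argmax;
//   flat_argmax = t.argmax();    // flat index of the overall maximum
//   Eigen::Tensor<Eigen::DenseIndex, 2> argmax_dim1;
//   argmax_dim1 = t.argmax(1);   // coordinate of the max along dimension 1; dims (2, 4)
//
// Both expressions evaluate through the TensorEvaluator specialization below.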

// Eval as rvalue
template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>::Dimensions InputDimensions;
  static const int NumDims = internal::array_size<InputDimensions>::value;
  typedef array<Index, NumDims> StrideDims;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_orig_impl(op.expression(), device),
        m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
        m_return_dim(op.return_dim()) {

    gen_strides(m_orig_impl.dimensions(), m_strides);
    if (Layout == static_cast<int>(ColMajor)) {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
    } else {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
    }
    // Avoid indexing m_strides with a negative m_return_dim (the strides are
    // neither generated nor used in that case).
    m_stride_div = (m_return_dim >= 0) ? m_strides[m_return_dim] : 1;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
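
  // How coeff() below recovers a coordinate from a flat index (informal note):
  // v.first holds the flat index, within the *original* tensor, of the
  // coefficient selected by the tuple reduction. When m_return_dim < 0 the
  // flat index itself is returned (argmax()/argmin() over all dimensions).
  // Otherwise the coordinate along m_return_dim is extracted with the
  // precomputed strides:
  //   coord = (flat_index % m_stride_mod) / m_stride_div
  // Illustrative numbers: for a ColMajor tensor with dimensions (2, 3, 4) the
  // strides are (1, 2, 6); with m_return_dim == 1 this gives
  // m_stride_mod == 6 and m_stride_div == 2, so (flat_index % 6) / 2 is the
  // coordinate along dimension 1.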
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    const TupleType v = m_impl.coeff(index);
    return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = 1.0 +
        (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
    return m_orig_impl.costPerCoeff(vectorized) +
           m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
  }

 private:
  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
    if (m_return_dim < 0) {
      return;  // Won't be using the strides.
    }
    eigen_assert(m_return_dim < NumDims &&
                 "Asking to convert index to a dimension outside of the rank");

    // Calculate m_stride_div and m_stride_mod, which are used to
    // calculate the value of an index w.r.t. the m_return_dim.
    if (Layout == static_cast<int>(ColMajor)) {
      strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        strides[i] = strides[i-1] * dims[i-1];
      }
    } else {
      strides[NumDims-1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        strides[i] = strides[i+1] * dims[i+1];
      }
    }
  }

 protected:
  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
  const int m_return_dim;
  StrideDims m_strides;
  Index m_stride_mod;
  Index m_stride_div;
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H