// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Ke Yang <yangke@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H
#define EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H

namespace Eigen {

/** \class TensorInflation
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor inflation class.
  *
  * Inflates the input tensor by inserting strides[i]-1 zero coefficients
  * between consecutive coefficients along the i-th dimension.
  */
namespace internal {
template<typename Strides, typename XprType>
struct traits<TensorInflationOp<Strides, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename Strides, typename XprType>
struct eval<TensorInflationOp<Strides, XprType>, Eigen::Dense>
{
  typedef const TensorInflationOp<Strides, XprType>& type;
};

template<typename Strides, typename XprType>
struct nested<TensorInflationOp<Strides, XprType>, 1, typename eval<TensorInflationOp<Strides, XprType> >::type>
{
  typedef TensorInflationOp<Strides, XprType> type;
};

}  // end namespace internal

template<typename Strides, typename XprType>
class TensorInflationOp : public TensorBase<TensorInflationOp<Strides, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorInflationOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorInflationOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorInflationOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorInflationOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorInflationOp(const XprType& expr, const Strides& strides)
        : m_xpr(expr), m_strides(strides) {}

    EIGEN_DEVICE_FUNC
    const Strides& strides() const { return m_strides; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Strides m_strides;
};
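
// A minimal usage sketch (illustrative, not part of this header): inflation
// inserts strides[i]-1 zeros between consecutive coefficients along the i-th
// dimension, so a 1-d tensor {a, b, c} inflated with stride 2 becomes
// {a, 0, b, 0, c}. Assuming the usual TensorBase::inflate() entry point and an
// Eigen::array of per-dimension strides:
//
//   Eigen::Tensor<float, 2> input(3, 4);
//   input.setRandom();
//   Eigen::array<Eigen::DenseIndex, 2> strides{{2, 3}};
//   Eigen::Tensor<float, 2> inflated = input.inflate(strides);
//   // inflated is (3-1)*2+1 = 5 by (4-1)*3+1 = 10; inflated(2*i, 3*j) equals
//   // input(i, j), and every other coefficient is zero.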

// Eval as rvalue
template<typename Strides, typename ArgType, typename Device>
struct TensorEvaluator<const TensorInflationOp<Strides, ArgType>, Device>
{
  typedef TensorInflationOp<Strides, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device), m_strides(op.strides())
  {
    m_dimensions = m_impl.dimensions();
    // Expand each dimension to the inflated dimension.
    for (int i = 0; i < NumDims; ++i) {
      m_dimensions[i] = (m_dimensions[i] - 1) * op.strides()[i] + 1;
    }

    // Remember the strides for fast division.
    for (int i = 0; i < NumDims; ++i) {
      m_fastStrides[i] = internal::TensorIntDivisor<Index>(m_strides[i]);
    }

    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_outputStrides[0] = 1;
      m_inputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
      }
    } else {  // RowMajor
      m_outputStrides[NumDims-1] = 1;
      m_inputStrides[NumDims-1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
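
  // Worked example of the output-to-input mapping computed below (1-d input of
  // size 3, stride 2, so the inflated output has size 5):
  //
  //   output index : 0    1    2    3    4
  //   hole?        : no   yes  no   yes  no
  //   input index  : 0    -    1    -    2
  //
  // A per-dimension output coordinate falls into a hole whenever it is not a
  // multiple of the stride, which is checked as idx != (idx / stride) * stride
  // using the precomputed fast integer divisors.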

  // Computes the input index given the output index. Returns true if the output
  // index doesn't fall into a hole.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool getInputIndex(Index index, Index* inputIndex) const
  {
    eigen_assert(index < dimensions().TotalSize());
    *inputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        const Index idx = index / m_outputStrides[i];
        if (idx != idx / m_fastStrides[i] * m_strides[i]) {
          return false;
        }
        *inputIndex += idx / m_strides[i] * m_inputStrides[i];
        index -= idx * m_outputStrides[i];
      }
      if (index != index / m_fastStrides[0] * m_strides[0]) {
        return false;
      }
      *inputIndex += index / m_strides[0];
      return true;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        const Index idx = index / m_outputStrides[i];
        if (idx != idx / m_fastStrides[i] * m_strides[i]) {
          return false;
        }
        *inputIndex += idx / m_strides[i] * m_inputStrides[i];
        index -= idx * m_outputStrides[i];
      }
      if (index != index / m_fastStrides[NumDims-1] * m_strides[NumDims-1]) {
        return false;
      }
      *inputIndex += index / m_strides[NumDims - 1];
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    Index inputIndex = 0;
    if (getInputIndex(index, &inputIndex)) {
      return m_impl.coeff(inputIndex);
    } else {
      return Scalar(0);
    }
  }

  // TODO(yangke): optimize this function so that we can detect and produce
  // all-zero packets
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    for (int i = 0; i < PacketSize; ++i) {
      values[i] = coeff(index+i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (3 * TensorOpCost::DivCost<Index>() +
                                           3 * TensorOpCost::MulCost<Index>() +
                                           2 * TensorOpCost::AddCost<Index>());
    const double input_size = m_impl.dimensions().TotalSize();
    const double output_size = m_dimensions.TotalSize();
    if (output_size == 0)
      return TensorOpCost();
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(sizeof(CoeffReturnType) * input_size / output_size, 0,
                        compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
  const Strides m_strides;
  array<internal::TensorIntDivisor<Index>, NumDims> m_fastStrides;
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H