// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
#define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H

namespace Eigen {

/** \class TensorPatch
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor patch class.
  *
  * Extracts every patch of the given dimensions from the input tensor. The
  * result has one extra dimension that indexes the individual patches.
  *
  */
namespace internal {
template<typename PatchDim, typename XprType>
struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions + 1;
  static const int Layout = XprTraits::Layout;
};

template<typename PatchDim, typename XprType>
struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
{
  typedef const TensorPatchOp<PatchDim, XprType>& type;
};

template<typename PatchDim, typename XprType>
struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
{
  typedef TensorPatchOp<PatchDim, XprType> type;
};

}  // end namespace internal



template<typename PatchDim, typename XprType>
class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
        : m_xpr(expr), m_patch_dims(patch_dims) {}

    EIGEN_DEVICE_FUNC
    const PatchDim& patch_dims() const { return m_patch_dims; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const PatchDim m_patch_dims;
};


// Eval as rvalue
template<typename PatchDim, typename ArgType, typename Device>
struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
{
  typedef TensorPatchOp<PatchDim, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;


  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
  };
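
  // Illustrative example: for a ColMajor input of dimensions [5, 7] and
  // patch_dims [3, 3], each dimension admits (input_dim - patch_dim + 1)
  // patch origins, so the constructor below computes
  // num_patches = 3 * 5 = 15 and output dimensions [3, 3, 15]: the first
  // NumDims-1 dimensions index within a patch, the last one indexes the
  // patch itself.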

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    Index num_patches = 1;
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const PatchDim& patch_dims = op.patch_dims();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[NumDims-1] = num_patches;

      m_inputStrides[0] = 1;
      m_patchStrides[0] = 1;
      for (int i = 1; i < NumDims-1; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
      }
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i+1] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[0] = num_patches;

      m_inputStrides[NumDims-2] = 1;
      m_patchStrides[NumDims-2] = 1;
      for (int i = NumDims-3; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
      }
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    // Find the location of the first element of the patch.
    Index patchIndex = index / m_outputStrides[output_stride_index];
    // Find the offset of the element wrt the location of the first element.
    Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
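    // Worked example (illustrative): with the ColMajor [3, 3, 15] output from
    // the example above, m_outputStrides is [1, 3, 9], so index 40 splits into
    // patchIndex = 40 / 9 = 4 and patchOffset = 40 - 4 * 9 = 4. The loops
    // below then peel off one coordinate per dimension: patchIdx locates the
    // patch origin in the input while offsetIdx is the coordinate inside the
    // patch, and both are mapped through the same input stride.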
    Index inputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i];
        patchOffset -= offsetIdx * m_outputStrides[i];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    } else {
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i+1];
        patchOffset -= offsetIdx * m_outputStrides[i+1];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    }
    inputIndex += (patchIndex + patchOffset);
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    Index indices[2] = {index, index + PacketSize - 1};
    Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
                             indices[1] / m_outputStrides[output_stride_index]};
    Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
                             indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};

    Index inputIndices[2] = {0, 0};
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
                                    patchOffsets[1] / m_outputStrides[i]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    } else {
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
                                    patchOffsets[1] / m_outputStrides[i+1]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    }
    inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
    inputIndices[1] += (patchIndices[1] + patchOffsets[1]);
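
    // Only the first and last coefficients of the packet are mapped exactly.
    // When they turn out to be contiguous in the input, the whole packet can
    // be loaded in one go; otherwise the coefficients are gathered one at a
    // time through coeff().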
    if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
      PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
      return rslt;
    }
    else {
      EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
      values[0] = m_impl.coeff(inputIndices[0]);
      values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
      for (int i = 1; i < PacketSize-1; ++i) {
        values[i] = coeff(index+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::MulCost<Index>() +
                                           2 * TensorOpCost::AddCost<Index>());
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims-1> m_inputStrides;
  array<Index, NumDims-1> m_patchStrides;

  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
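
// Usage sketch (illustrative; assumes the TensorBase::extract_patches entry
// point, which is how this expression is normally constructed):
//
//   Eigen::Tensor<float, 2> input(5, 7);
//   input.setRandom();
//   Eigen::array<ptrdiff_t, 2> patch_dims;
//   patch_dims[0] = 3;
//   patch_dims[1] = 3;
//   Eigen::Tensor<float, 3> patches = input.extract_patches(patch_dims);
//   // With the default ColMajor layout, patches has dimensions [3, 3, 15].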