• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
12 
13 namespace Eigen {
14 
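/** \class TensorPatchOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor patch class.
  *
  * Extracts every patch of size \c patch_dims that fits entirely inside the
  * input tensor. The result has one more dimension than the input: one
  * dimension per patch coefficient axis, plus one dimension that enumerates
  * the patches (the last dimension in ColMajor layout, the first one in
  * RowMajor layout).
  */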
22 namespace internal {
23 template<typename PatchDim, typename XprType>
24 struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
25 {
26   typedef typename XprType::Scalar Scalar;
27   typedef traits<XprType> XprTraits;
28   typedef typename XprTraits::StorageKind StorageKind;
29   typedef typename XprTraits::Index Index;
30   typedef typename XprType::Nested Nested;
31   typedef typename remove_reference<Nested>::type _Nested;
32   static const int NumDimensions = XprTraits::NumDimensions + 1;
33   static const int Layout = XprTraits::Layout;
34 };
35 
36 template<typename PatchDim, typename XprType>
37 struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
38 {
39   typedef const TensorPatchOp<PatchDim, XprType>& type;
40 };
41 
42 template<typename PatchDim, typename XprType>
43 struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
44 {
45   typedef TensorPatchOp<PatchDim, XprType> type;
46 };
47 
48 }  // end namespace internal
49 
50 
51 
52 template<typename PatchDim, typename XprType>
53 class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
54 {
55   public:
56   typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
57   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58   typedef typename XprType::CoeffReturnType CoeffReturnType;
59   typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
60   typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
61   typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;
62 
63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
64       : m_xpr(expr), m_patch_dims(patch_dims) {}
65 
66     EIGEN_DEVICE_FUNC
67     const PatchDim& patch_dims() const { return m_patch_dims; }
68 
69     EIGEN_DEVICE_FUNC
70     const typename internal::remove_all<typename XprType::Nested>::type&
71     expression() const { return m_xpr; }
72 
73   protected:
74     typename XprType::Nested m_xpr;
75     const PatchDim m_patch_dims;
76 };
77 
78 
79 // Eval as rvalue
80 template<typename PatchDim, typename ArgType, typename Device>
81 struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
82 {
83   typedef TensorPatchOp<PatchDim, ArgType> XprType;
84   typedef typename XprType::Index Index;
85   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
86   typedef DSizes<Index, NumDims> Dimensions;
87   typedef typename XprType::Scalar Scalar;
88   typedef typename XprType::CoeffReturnType CoeffReturnType;
89   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
91 
92 
93   enum {
94     IsAligned = false,
95     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
96     Layout = TensorEvaluator<ArgType, Device>::Layout,
97     CoordAccess = false,
98     RawAccess = false
99  };
100 
101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
102       : m_impl(op.expression(), device)
103   {
104     Index num_patches = 1;
105     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
106     const PatchDim& patch_dims = op.patch_dims();
107     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
108       for (int i = 0; i < NumDims-1; ++i) {
109         m_dimensions[i] = patch_dims[i];
110         num_patches *= (input_dims[i] - patch_dims[i] + 1);
111       }
112       m_dimensions[NumDims-1] = num_patches;
113 
114       m_inputStrides[0] = 1;
115       m_patchStrides[0] = 1;
116       for (int i = 1; i < NumDims-1; ++i) {
117         m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
118         m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
119       }
120       m_outputStrides[0] = 1;
121       for (int i = 1; i < NumDims; ++i) {
122         m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
123       }
124     } else {
125       for (int i = 0; i < NumDims-1; ++i) {
126         m_dimensions[i+1] = patch_dims[i];
127         num_patches *= (input_dims[i] - patch_dims[i] + 1);
128       }
129       m_dimensions[0] = num_patches;
130 
131       m_inputStrides[NumDims-2] = 1;
132       m_patchStrides[NumDims-2] = 1;
133       for (int i = NumDims-3; i >= 0; --i) {
134         m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
135         m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
136       }
137       m_outputStrides[NumDims-1] = 1;
138       for (int i = NumDims-2; i >= 0; --i) {
139         m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
140       }
141     }
142   }
143 
144   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
145 
146   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
147     m_impl.evalSubExprsIfNeeded(NULL);
148     return true;
149   }
150 
151   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
152     m_impl.cleanup();
153   }
154 
155   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
156   {
157     Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
158     // Find the location of the first element of the patch.
159     Index patchIndex = index / m_outputStrides[output_stride_index];
160     // Find the offset of the element wrt the location of the first element.
161     Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
162     Index inputIndex = 0;
163     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
164       for (int i = NumDims - 2; i > 0; --i) {
165         const Index patchIdx = patchIndex / m_patchStrides[i];
166         patchIndex -= patchIdx * m_patchStrides[i];
167         const Index offsetIdx = patchOffset / m_outputStrides[i];
168         patchOffset -= offsetIdx * m_outputStrides[i];
169         inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
170       }
171     } else {
172       for (int i = 0; i < NumDims - 2; ++i) {
173         const Index patchIdx = patchIndex / m_patchStrides[i];
174         patchIndex -= patchIdx * m_patchStrides[i];
175         const Index offsetIdx = patchOffset / m_outputStrides[i+1];
176         patchOffset -= offsetIdx * m_outputStrides[i+1];
177         inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
178       }
179     }
180     inputIndex += (patchIndex + patchOffset);
181     return m_impl.coeff(inputIndex);
182   }
183 
184   template<int LoadMode>
185   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
186   {
187     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
188     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
189 
190     Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
191     Index indices[2] = {index, index + PacketSize - 1};
192     Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
193                              indices[1] / m_outputStrides[output_stride_index]};
194     Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
195                              indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};
196 
197     Index inputIndices[2] = {0, 0};
198     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
199       for (int i = NumDims - 2; i > 0; --i) {
200         const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
201                                    patchIndices[1] / m_patchStrides[i]};
202         patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
203         patchIndices[1] -= patchIdx[1] * m_patchStrides[i];
204 
205         const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
206                                     patchOffsets[1] / m_outputStrides[i]};
207         patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
208         patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];
209 
210         inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
211         inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
212       }
213     } else {
214       for (int i = 0; i < NumDims - 2; ++i) {
215         const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
216                                    patchIndices[1] / m_patchStrides[i]};
217         patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
218         patchIndices[1] -= patchIdx[1] * m_patchStrides[i];
219 
220         const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
221                                     patchOffsets[1] / m_outputStrides[i+1]};
222         patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
223         patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];
224 
225         inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
226         inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
227       }
228     }
229     inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
230     inputIndices[1] += (patchIndices[1] + patchOffsets[1]);
231 
232     if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
233       PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
234       return rslt;
235     }
236     else {
237       EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
238       values[0] = m_impl.coeff(inputIndices[0]);
239       values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
240       for (int i = 1; i < PacketSize-1; ++i) {
241         values[i] = coeff(index+i);
242       }
243       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
244       return rslt;
245     }
246   }
247 
248   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
249     const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() +
250                                            TensorOpCost::MulCost<Index>() +
251                                            2 * TensorOpCost::AddCost<Index>());
252     return m_impl.costPerCoeff(vectorized) +
253            TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
254   }
255 
256   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
257 
258  protected:
259   Dimensions m_dimensions;
260   array<Index, NumDims> m_outputStrides;
261   array<Index, NumDims-1> m_inputStrides;
262   array<Index, NumDims-1> m_patchStrides;
263 
264   TensorEvaluator<ArgType, Device> m_impl;
265 };
266 
267 } // end namespace Eigen
268 
269 #endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
270