• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_STRIDING_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_STRIDING_H
12 
13 namespace Eigen {
14 
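/** \class TensorStridingOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor striding class.
  *
  * Expression that keeps every strides[i]-th coefficient of the wrapped
  * expression along dimension i, starting at index 0. Each output dimension
  * therefore has ceil(input_dim[i] / strides[i]) coefficients. The expression
  * can be read from and, when the wrapped expression provides coeffRef(),
  * assigned to.
  */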
22 namespace internal {
23 template<typename Strides, typename XprType>
24 struct traits<TensorStridingOp<Strides, XprType> > : public traits<XprType>
25 {
26   typedef typename XprType::Scalar Scalar;
27   typedef traits<XprType> XprTraits;
28   typedef typename XprTraits::StorageKind StorageKind;
29   typedef typename XprTraits::Index Index;
30   typedef typename XprType::Nested Nested;
31   typedef typename remove_reference<Nested>::type _Nested;
32   static const int NumDimensions = XprTraits::NumDimensions;
33   static const int Layout = XprTraits::Layout;
34 };
35 
36 template<typename Strides, typename XprType>
37 struct eval<TensorStridingOp<Strides, XprType>, Eigen::Dense>
38 {
39   typedef const TensorStridingOp<Strides, XprType>& type;
40 };
41 
42 template<typename Strides, typename XprType>
43 struct nested<TensorStridingOp<Strides, XprType>, 1, typename eval<TensorStridingOp<Strides, XprType> >::type>
44 {
45   typedef TensorStridingOp<Strides, XprType> type;
46 };
47 
48 }  // end namespace internal
49 
50 
51 
52 template<typename Strides, typename XprType>
53 class TensorStridingOp : public TensorBase<TensorStridingOp<Strides, XprType> >
54 {
55   public:
56   typedef typename Eigen::internal::traits<TensorStridingOp>::Scalar Scalar;
57   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58   typedef typename XprType::CoeffReturnType CoeffReturnType;
59   typedef typename Eigen::internal::nested<TensorStridingOp>::type Nested;
60   typedef typename Eigen::internal::traits<TensorStridingOp>::StorageKind StorageKind;
61   typedef typename Eigen::internal::traits<TensorStridingOp>::Index Index;
62 
63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorStridingOp(const XprType& expr, const Strides& dims)
64       : m_xpr(expr), m_dims(dims) {}
65 
66     EIGEN_DEVICE_FUNC
67     const Strides& strides() const { return m_dims; }
68 
69     EIGEN_DEVICE_FUNC
70     const typename internal::remove_all<typename XprType::Nested>::type&
71     expression() const { return m_xpr; }
72 
73     EIGEN_DEVICE_FUNC
74     EIGEN_STRONG_INLINE TensorStridingOp& operator = (const TensorStridingOp& other)
75     {
76       typedef TensorAssignOp<TensorStridingOp, const TensorStridingOp> Assign;
77       Assign assign(*this, other);
78       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
79       return *this;
80     }
81 
82     template<typename OtherDerived>
83     EIGEN_DEVICE_FUNC
84     EIGEN_STRONG_INLINE TensorStridingOp& operator = (const OtherDerived& other)
85     {
86       typedef TensorAssignOp<TensorStridingOp, const OtherDerived> Assign;
87       Assign assign(*this, other);
88       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
89       return *this;
90     }
91 
92   protected:
93     typename XprType::Nested m_xpr;
94     const Strides m_dims;
95 };
96 
97 
98 // Eval as rvalue
99 template<typename Strides, typename ArgType, typename Device>
100 struct TensorEvaluator<const TensorStridingOp<Strides, ArgType>, Device>
101 {
102   typedef TensorStridingOp<Strides, ArgType> XprType;
103   typedef typename XprType::Index Index;
104   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
105   typedef DSizes<Index, NumDims> Dimensions;
106   typedef typename XprType::Scalar Scalar;
107   typedef typename XprType::CoeffReturnType CoeffReturnType;
108   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
109   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
110 
111   enum {
112     IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false,
113     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
114     Layout = TensorEvaluator<ArgType, Device>::Layout,
115     CoordAccess = false,  // to be implemented
116     RawAccess = false
117   };
118 
119   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
120       : m_impl(op.expression(), device)
121   {
122     m_dimensions = m_impl.dimensions();
123     for (int i = 0; i < NumDims; ++i) {
124       m_dimensions[i] = ceilf(static_cast<float>(m_dimensions[i]) / op.strides()[i]);
125     }
126 
127     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
128     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
129       m_outputStrides[0] = 1;
130       m_inputStrides[0] = 1;
131       for (int i = 1; i < NumDims; ++i) {
132         m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
133         m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
134         m_inputStrides[i-1] *= op.strides()[i-1];
135       }
136       m_inputStrides[NumDims-1] *= op.strides()[NumDims-1];
137     } else {  // RowMajor
138       m_outputStrides[NumDims-1] = 1;
139       m_inputStrides[NumDims-1] = 1;
140       for (int i = NumDims - 2; i >= 0; --i) {
141         m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
142         m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
143         m_inputStrides[i+1] *= op.strides()[i+1];
144       }
145       m_inputStrides[0] *= op.strides()[0];
146     }
147   }
148 
149   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
150 
151   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
152     m_impl.evalSubExprsIfNeeded(NULL);
153     return true;
154   }
155   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
156     m_impl.cleanup();
157   }
158 
159   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
160   {
161     return m_impl.coeff(srcCoeff(index));
162   }
163 
164   template<int LoadMode>
165   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
166   {
167     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
168     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
169 
170     Index inputIndices[] = {0, 0};
171     Index indices[] = {index, index + PacketSize - 1};
172     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
173       for (int i = NumDims - 1; i > 0; --i) {
174         const Index idx0 = indices[0] / m_outputStrides[i];
175         const Index idx1 = indices[1] / m_outputStrides[i];
176         inputIndices[0] += idx0 * m_inputStrides[i];
177         inputIndices[1] += idx1 * m_inputStrides[i];
178         indices[0] -= idx0 * m_outputStrides[i];
179         indices[1] -= idx1 * m_outputStrides[i];
180       }
181       inputIndices[0] += indices[0] * m_inputStrides[0];
182       inputIndices[1] += indices[1] * m_inputStrides[0];
183     } else {  // RowMajor
184       for (int i = 0; i < NumDims - 1; ++i) {
185         const Index idx0 = indices[0] / m_outputStrides[i];
186         const Index idx1 = indices[1] / m_outputStrides[i];
187         inputIndices[0] += idx0 * m_inputStrides[i];
188         inputIndices[1] += idx1 * m_inputStrides[i];
189         indices[0] -= idx0 * m_outputStrides[i];
190         indices[1] -= idx1 * m_outputStrides[i];
191       }
192       inputIndices[0] += indices[0] * m_inputStrides[NumDims-1];
193       inputIndices[1] += indices[1] * m_inputStrides[NumDims-1];
194     }
195     if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
196       PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
197       return rslt;
198     }
199     else {
200       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
201       values[0] = m_impl.coeff(inputIndices[0]);
202       values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
203       for (int i = 1; i < PacketSize-1; ++i) {
204         values[i] = coeff(index+i);
205       }
206       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
207       return rslt;
208     }
209   }
210 
211   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
212     double compute_cost = (NumDims - 1) * (TensorOpCost::AddCost<Index>() +
213                                            TensorOpCost::MulCost<Index>() +
214                                            TensorOpCost::DivCost<Index>()) +
215         TensorOpCost::MulCost<Index>();
216     if (vectorized) {
217       compute_cost *= 2;  // packet() computes two indices
218     }
219     const int innerDim = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? 0 : (NumDims - 1);
220     return m_impl.costPerCoeff(vectorized && m_inputStrides[innerDim] == 1) +
221         // Computation is not vectorized per se, but it is done once per packet.
222         TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
223   }
224 
225   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
226 
227  protected:
228   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index srcCoeff(Index index) const
229   {
230     Index inputIndex = 0;
231     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
232       for (int i = NumDims - 1; i > 0; --i) {
233         const Index idx = index / m_outputStrides[i];
234         inputIndex += idx * m_inputStrides[i];
235         index -= idx * m_outputStrides[i];
236       }
237       inputIndex += index * m_inputStrides[0];
238     } else {  // RowMajor
239       for (int i = 0; i < NumDims - 1; ++i) {
240         const Index idx = index / m_outputStrides[i];
241         inputIndex += idx * m_inputStrides[i];
242         index -= idx * m_outputStrides[i];
243       }
244       inputIndex += index * m_inputStrides[NumDims-1];
245     }
246     return inputIndex;
247   }
248 
249   Dimensions m_dimensions;
250   array<Index, NumDims> m_outputStrides;
251   array<Index, NumDims> m_inputStrides;
252   TensorEvaluator<ArgType, Device> m_impl;
253 };
254 
255 
256 // Eval as lvalue
257 template<typename Strides, typename ArgType, typename Device>
258 struct TensorEvaluator<TensorStridingOp<Strides, ArgType>, Device>
259     : public TensorEvaluator<const TensorStridingOp<Strides, ArgType>, Device>
260 {
261   typedef TensorStridingOp<Strides, ArgType> XprType;
262   typedef TensorEvaluator<const XprType, Device> Base;
263   //  typedef typename XprType::Index Index;
264   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
265   //  typedef DSizes<Index, NumDims> Dimensions;
266 
267   enum {
268     IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false,
269     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
270     Layout = TensorEvaluator<ArgType, Device>::Layout,
271     CoordAccess = false,  // to be implemented
272     RawAccess = false
273   };
274 
275   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
276       : Base(op, device) { }
277 
278   typedef typename XprType::Index Index;
279   typedef typename XprType::Scalar Scalar;
280   typedef typename XprType::CoeffReturnType CoeffReturnType;
281   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
282   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
283 
284   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
285   {
286     return this->m_impl.coeffRef(this->srcCoeff(index));
287   }
288 
289   template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
290   void writePacket(Index index, const PacketReturnType& x)
291   {
292     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
293     eigen_assert(index+PacketSize-1 < this->dimensions().TotalSize());
294 
295     Index inputIndices[] = {0, 0};
296     Index indices[] = {index, index + PacketSize - 1};
297     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
298       for (int i = NumDims - 1; i > 0; --i) {
299         const Index idx0 = indices[0] / this->m_outputStrides[i];
300         const Index idx1 = indices[1] / this->m_outputStrides[i];
301         inputIndices[0] += idx0 * this->m_inputStrides[i];
302         inputIndices[1] += idx1 * this->m_inputStrides[i];
303         indices[0] -= idx0 * this->m_outputStrides[i];
304         indices[1] -= idx1 * this->m_outputStrides[i];
305       }
306       inputIndices[0] += indices[0] * this->m_inputStrides[0];
307       inputIndices[1] += indices[1] * this->m_inputStrides[0];
308     } else {  // RowMajor
309       for (int i = 0; i < NumDims - 1; ++i) {
310         const Index idx0 = indices[0] / this->m_outputStrides[i];
311         const Index idx1 = indices[1] / this->m_outputStrides[i];
312         inputIndices[0] += idx0 * this->m_inputStrides[i];
313         inputIndices[1] += idx1 * this->m_inputStrides[i];
314         indices[0] -= idx0 * this->m_outputStrides[i];
315         indices[1] -= idx1 * this->m_outputStrides[i];
316       }
317       inputIndices[0] += indices[0] * this->m_inputStrides[NumDims-1];
318       inputIndices[1] += indices[1] * this->m_inputStrides[NumDims-1];
319     }
320     if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
321       this->m_impl.template writePacket<Unaligned>(inputIndices[0], x);
322     }
323     else {
324       EIGEN_ALIGN_MAX Scalar values[PacketSize];
325       internal::pstore<Scalar, PacketReturnType>(values, x);
326       this->m_impl.coeffRef(inputIndices[0]) = values[0];
327       this->m_impl.coeffRef(inputIndices[1]) = values[PacketSize-1];
328       for (int i = 1; i < PacketSize-1; ++i) {
329         this->coeffRef(index+i) = values[i];
330       }
331     }
332   }
333 };
334 
335 
336 } // end namespace Eigen
337 
338 #endif // EIGEN_CXX11_TENSOR_TENSOR_STRIDING_H
339