// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

/** \class TensorBroadcasting
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  *
  * Replicates the input expression a fixed number of times along each of its
  * dimensions, so dimension \c i of the result has size
  * <tt>input_dimension[i] * broadcast[i]</tt>.
  */
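// A minimal usage sketch (illustrative only; the op is normally created
// through TensorBase::broadcast()). Broadcasting a 2x3 tensor with factors
// {2, 2} tiles the input and evaluates to a 4x6 result:
//
//   Eigen::Tensor<float, 2> in(2, 3);
//   in.setRandom();
//   Eigen::array<int, 2> bcast({2, 2});
//   Eigen::Tensor<float, 2> out = in.broadcast(bcast);  // 4x6 tensor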
namespace internal {
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

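// Compile-time check used by the evaluator below: when the input dimensions
// describe an empty (or single-element) Sizes<> pack, the input is effectively
// a scalar and every output coefficient can be read from m_impl.coeff(0).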
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::size_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal



template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};


// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = true,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    : m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
    // and store the result in a scalar. Instead one should first reshape the scalar into an
    // N-D tensor (N >= 1) with a single element and then broadcast that.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }
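  // Illustrative example of the bookkeeping performed by the constructor
  // above: for a column-major 2x3 input broadcast by {2, 2},
  //   m_dimensions    = {4, 6}   (input_dims[i] * broadcast[i])
  //   m_inputStrides  = {1, 2}   (strides of the 2x3 input)
  //   m_outputStrides = {1, 4}   (strides of the 4x6 result)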

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }

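  // How the output-to-input index remapping below works (illustrative example,
  // continuing the column-major 2x3 input broadcast to 4x6): for output
  // index 10 (row 2, column 2 of the 4x6 result), the loop peels off
  // idx = 10 / 4 = 2 for dimension 1, contributing (2 % 3) * 2 = 4 to the
  // input index; the remaining index 10 - 2*4 = 2 is reduced modulo the
  // innermost input dimension, 2 % 2 = 0, giving inputIndex = 4, i.e. the
  // coefficient at row 0, column 2 of the input.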
  // TODO: attempt to speed this up. The integer divisions and modulo are slow
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }

  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
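  // When the requested packet would run past the end of the innermost input
  // dimension (i.e. it straddles a broadcast boundary), the coefficients are
  // gathered one by one through coeffColMajor/coeffRowMajor and assembled with
  // pload from an aligned local buffer instead.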
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

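  // Rough per-coefficient cost model: each of the NumDims - 1 outer output
  // dimensions contributes an integer division plus, depending on the static
  // broadcast/input sizes, either a multiply-add or a modulo, on top of the
  // cost of evaluating the wrapped expression.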
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (NumDims > 0) {
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

 protected:
  const Broadcast m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H