1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
12 
13 namespace Eigen {
14 
15 /** \class TensorBroadcasting
16   * \ingroup CXX11_Tensor_Module
17   *
18   * \brief Tensor broadcasting class.
19   *
20   *
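  * This op is created by TensorBase::broadcast(): each dimension of the input
  * expression is tiled by the corresponding broadcast factor, so output
  * dimension i has size input_dimension[i] * broadcast[i].
  *
  * A minimal usage sketch (the tensor sizes below are illustrative only):
  *
  * \code
  * Eigen::Tensor<float, 2> in(2, 3);
  * in.setRandom();
  * Eigen::array<Eigen::Index, 2> bcast = {{2, 3}};
  * // Tile twice along dim 0 and three times along dim 1 -> a 4 x 9 tensor.
  * Eigen::Tensor<float, 2> out = in.broadcast(bcast);
  * \endcode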
21   */
22 namespace internal {
23 template<typename Broadcast, typename XprType>
24 struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
25 {
26   typedef typename XprType::Scalar Scalar;
27   typedef traits<XprType> XprTraits;
28   typedef typename XprTraits::StorageKind StorageKind;
29   typedef typename XprTraits::Index Index;
30   typedef typename XprType::Nested Nested;
31   typedef typename remove_reference<Nested>::type _Nested;
32   static const int NumDimensions = XprTraits::NumDimensions;
33   static const int Layout = XprTraits::Layout;
34   typedef typename XprTraits::PointerType PointerType;
35 };
36 
37 template<typename Broadcast, typename XprType>
38 struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
39 {
40   typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
41 };
42 
43 template<typename Broadcast, typename XprType>
44 struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
45 {
46   typedef TensorBroadcastingOp<Broadcast, XprType> type;
47 };
48 
49 template <typename Dims>
50 struct is_input_scalar {
51   static const bool value = false;
52 };
53 template <>
54 struct is_input_scalar<Sizes<> > {
55   static const bool value = true;
56 };
57 #ifndef EIGEN_EMULATE_CXX11_META_H
58 template <typename std::ptrdiff_t... Indices>
59 struct is_input_scalar<Sizes<Indices...> > {
60   static const bool value = (Sizes<Indices...>::total_size == 1);
61 };
62 #endif
63 
64 }  // end namespace internal
65 
66 
67 
68 template<typename Broadcast, typename XprType>
69 class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
70 {
71   public:
72   typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
73   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
74   typedef typename XprType::CoeffReturnType CoeffReturnType;
75   typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
76   typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
77   typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;
78 
79   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
80       : m_xpr(expr), m_broadcast(broadcast) {}
81 
82     EIGEN_DEVICE_FUNC
83     const Broadcast& broadcast() const { return m_broadcast; }
84 
85     EIGEN_DEVICE_FUNC
86     const typename internal::remove_all<typename XprType::Nested>::type&
87     expression() const { return m_xpr; }
88 
89   protected:
90     typename XprType::Nested m_xpr;
91     const Broadcast m_broadcast;
92 };
93 
94 
95 // Eval as rvalue
96 template<typename Broadcast, typename ArgType, typename Device>
97 struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
98 {
99   typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
100   typedef typename XprType::Index Index;
101   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
102   typedef DSizes<Index, NumDims> Dimensions;
103   typedef typename XprType::Scalar Scalar;
104   typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
105   typedef typename XprType::CoeffReturnType CoeffReturnType;
106   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
107   static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
108   protected: // all the non-static fields must have the same access control, otherwise the TensorEvaluator won't be standard layout.
109   bool isCopy, nByOne, oneByN;
110   public:
111   typedef StorageMemory<CoeffReturnType, Device> Storage;
112   typedef typename Storage::Type EvaluatorPointerType;
113 
114   enum {
115     IsAligned         = TensorEvaluator<ArgType, Device>::IsAligned,
116     PacketAccess      = TensorEvaluator<ArgType, Device>::PacketAccess,
117     BlockAccess       = TensorEvaluator<ArgType, Device>::BlockAccess,
118     PreferBlockAccess = true,
119     Layout            = TensorEvaluator<ArgType, Device>::Layout,
120     RawAccess         = false
121   };
122 
123   typedef typename internal::remove_const<Scalar>::type ScalarNoConst;
124 
125   // We do block-based broadcasting using a trick with 2x the tensor rank and
126   // 0 strides. See the block() method implementation for details.
127   typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;
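  // For instance (a sketch): a rank-2 block with copy sizes [d0, d1] is viewed
  // as a rank-4 block with sizes [d0, b0, d1, b1] (b0, b1 being the broadcast
  // factors) and input strides [s0, 0, s1, 0]; the 0 strides make every
  // broadcast copy re-read the same input coefficients.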
128 
129   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
130   typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
131   typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
132 
133   typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
134       ArgTensorBlock;
135 
136   typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
137                                                      Layout, Index>
138       TensorBlock;
139   //===--------------------------------------------------------------------===//
140 
141   EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
142       : isCopy(false), nByOne(false), oneByN(false),
143         m_device(device), m_broadcast(op.broadcast()), m_impl(op.expression(), device)
144   {
145 
146     // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
147     // and store the result in a scalar. Instead one should first reshape the scalar into an N-D
148     // tensor (N >= 1) with a single element and then broadcast.
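    // A sketch of the intended pattern (names and the size `n` are illustrative):
    //   Eigen::Tensor<float, 0> s;  // rank-0 scalar
    //   auto v = s.reshape(Eigen::array<Eigen::Index, 1>{{1}})
    //             .broadcast(Eigen::array<Eigen::Index, 1>{{n}});  // 1-D, n copies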
149     EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
150     const InputDimensions& input_dims = m_impl.dimensions();
151     isCopy = true;
152     for (int i = 0; i < NumDims; ++i) {
153       eigen_assert(input_dims[i] > 0);
154       m_dimensions[i] = input_dims[i] * m_broadcast[i];
155       if (m_broadcast[i] != 1) {
156         isCopy = false;
157       }
158     }
159 
160     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
161       m_inputStrides[0] = 1;
162       m_outputStrides[0] = 1;
163       for (int i = 1; i < NumDims; ++i) {
164         m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
165         m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
166       }
167     } else {
168       m_inputStrides[NumDims-1] = 1;
169       m_outputStrides[NumDims-1] = 1;
170       for (int i = NumDims-2; i >= 0; --i) {
171         m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
172         m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
173       }
174     }
175 
176     if (input_dims[0] == 1) {
177       oneByN = true;
178       for (int i = 1; i < NumDims; ++i) {
179         if (m_broadcast[i] != 1) {
180           oneByN = false;
181           break;
182         }
183       }
184     } else if (input_dims[NumDims-1] == 1) {
185       nByOne = true;
186       for (int i = 0; i < NumDims-1; ++i) {
187         if (m_broadcast[i] != 1) {
188           nByOne = false;
189           break;
190         }
191       }
192     }
193 
194     // Handle special formats like NCHW: the input shape is '[1, N..., 1]' and
195     // the broadcast shape is '[N, 1..., N]'.
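    // Illustrative case (shapes are hypothetical): an input of shape [1, C, 1]
    // broadcast with factors [N, 1, W] yields an [N, C, W] tensor. Both outer
    // input dimensions are 1 and every middle broadcast factor is 1, so the
    // check below marks the op as both oneByN and nByOne.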
196     if (!oneByN && !nByOne) {
197       if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
198         nByOne = true;
199         oneByN = true;
200         for (int i = 1; i < NumDims-1; ++i) {
201           if (m_broadcast[i] != 1) {
202             nByOne = false;
203             oneByN = false;
204             break;
205           }
206         }
207       }
208     }
209   }
210 
211   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
212 
213   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
214     m_impl.evalSubExprsIfNeeded(NULL);
215     return true;
216   }
217 
218 #ifdef EIGEN_USE_THREADS
219   template <typename EvalSubExprsCallback>
220   EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
221       EvaluatorPointerType, EvalSubExprsCallback done) {
222     m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
223   }
224 #endif  // EIGEN_USE_THREADS
225 
226   EIGEN_STRONG_INLINE void cleanup() {
227     m_impl.cleanup();
228   }
229 
230   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
231   {
232     if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
233       return m_impl.coeff(0);
234     }
235 
236     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
237       if (isCopy) {
238         return m_impl.coeff(index);
239       } else {
240         return coeffColMajor(index);
241       }
242     } else {
243       if (isCopy) {
244         return m_impl.coeff(index);
245       } else {
246         return coeffRowMajor(index);
247       }
248     }
249   }
250 
251   // TODO: attempt to speed this up. The integer divisions and modulo are slow
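  // Illustrative example (values not from this file): for a ColMajor input of
  // dimensions [2, 3] and broadcast factors [2, 1], the output has dimensions
  // [4, 3] and output strides [1, 4]. Output index 5 (coords (1, 1)) gives
  // idx = 5 / 4 = 1 for dim 1 and a remainder of 1 for dim 0, so
  // inputIndex = (1 % 3) * 2 + (1 % 2) = 3, i.e. input coords (1, 1).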
252   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
253     Index inputIndex = 0;
254     EIGEN_UNROLL_LOOP
255     for (int i = NumDims - 1; i > 0; --i) {
256       const Index idx = index / m_outputStrides[i];
257       if (internal::index_statically_eq<Broadcast>(i, 1)) {
258         eigen_assert(idx < m_impl.dimensions()[i]);
259         inputIndex += idx * m_inputStrides[i];
260       } else {
261         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
262           eigen_assert(idx % m_impl.dimensions()[i] == 0);
263         } else {
264           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
265         }
266       }
267       index -= idx * m_outputStrides[i];
268     }
269     if (internal::index_statically_eq<Broadcast>(0, 1)) {
270       eigen_assert(index < m_impl.dimensions()[0]);
271       inputIndex += index;
272     } else {
273       if (internal::index_statically_eq<InputDimensions>(0, 1)) {
274         eigen_assert(index % m_impl.dimensions()[0] == 0);
275       } else {
276         inputIndex += (index % m_impl.dimensions()[0]);
277       }
278     }
279     return inputIndex;
280   }
281 
282   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
283   {
284     return m_impl.coeff(indexColMajor(index));
285   }
286 
287   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
288     Index inputIndex = 0;
289     EIGEN_UNROLL_LOOP
290     for (int i = 0; i < NumDims - 1; ++i) {
291       const Index idx = index / m_outputStrides[i];
292       if (internal::index_statically_eq<Broadcast>(i, 1)) {
293         eigen_assert(idx < m_impl.dimensions()[i]);
294         inputIndex += idx * m_inputStrides[i];
295       } else {
296         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
297           eigen_assert(idx % m_impl.dimensions()[i] == 0);
298         } else {
299           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
300         }
301       }
302       index -= idx * m_outputStrides[i];
303     }
304     if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
305       eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
306       inputIndex += index;
307     } else {
308       if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
309         eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
310       } else {
311         inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
312       }
313     }
314     return inputIndex;
315   }
316 
317   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
318   {
319     return m_impl.coeff(indexRowMajor(index));
320   }
321 
322   template<int LoadMode>
323   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
324   {
325     if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
326       return internal::pset1<PacketReturnType>(m_impl.coeff(0));
327     }
328 
329     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
330       if (isCopy) {
331         #ifdef EIGEN_GPU_COMPILE_PHASE
332         // See PR 437: on NVIDIA P100 and K20m we observed a 3-4x speedup by enforcing
333         // unaligned loads here. The reason is unclear though.
334         return m_impl.template packet<Unaligned>(index);
335         #else
336         return m_impl.template packet<LoadMode>(index);
337         #endif
338       } else if (oneByN && !nByOne) {
339         return packetNByOne<LoadMode>(index);
340       } else if (!oneByN && nByOne) {
341         return packetOneByN<LoadMode>(index);
342       } else if (oneByN && nByOne) {
343         return packetOneByNByOne<LoadMode>(index);
344       } else {
345         return packetColMajor<LoadMode>(index);
346       }
347     } else {
348       if (isCopy) {
349         #ifdef EIGEN_GPU_COMPILE_PHASE
350         // See above.
351         return m_impl.template packet<Unaligned>(index);
352         #else
353         return m_impl.template packet<LoadMode>(index);
354         #endif
355       } else if (oneByN && !nByOne) {
356         return packetOneByN<LoadMode>(index);
357       } else if (!oneByN && nByOne) {
358         return packetNByOne<LoadMode>(index);
359       } else if (oneByN && nByOne) {
360         return packetOneByNByOne<LoadMode>(index);
361       } else {
362         return packetRowMajor<LoadMode>(index);
363       }
364     }
365   }
366 
367   template<int LoadMode>
368   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
369   (Index index) const
370   {
371     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
372     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
373 
374     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
375     Index startDim, endDim;
376     Index inputIndex, outputOffset, batchedIndex;
377 
378     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
379       startDim = NumDims - 1;
380       endDim = 1;
381     } else {
382       startDim = 0;
383       endDim = NumDims - 2;
384     }
385 
386     batchedIndex = index % m_outputStrides[startDim];
387     inputIndex   = batchedIndex / m_outputStrides[endDim];
388     outputOffset = batchedIndex % m_outputStrides[endDim];
389 
390     if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
391       values[0] = m_impl.coeff(inputIndex);
392       return internal::pload1<PacketReturnType>(values);
393     } else {
394       EIGEN_UNROLL_LOOP
395       for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
396         if (outputOffset + cur < m_outputStrides[endDim]) {
397           values[i] = m_impl.coeff(inputIndex);
398         } else {
399           ++inputIndex;
400           inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
401           values[i] = m_impl.coeff(inputIndex);
402           outputOffset = 0;
403           cur = 0;
404         }
405       }
406       return internal::pload<PacketReturnType>(values);
407     }
408   }
409 
410   template<int LoadMode>
411   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
412   {
413     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
414     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
415 
416     Index dim, inputIndex;
417 
418     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
419       dim = NumDims - 1;
420     } else {
421       dim = 0;
422     }
423 
424     inputIndex = index % m_inputStrides[dim];
425     if (inputIndex + PacketSize <= m_inputStrides[dim]) {
426       return m_impl.template packet<Unaligned>(inputIndex);
427     } else {
428       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
429       EIGEN_UNROLL_LOOP
430       for (int i = 0; i < PacketSize; ++i) {
431         if (inputIndex > m_inputStrides[dim]-1) {
432           inputIndex = 0;
433         }
434         values[i] = m_impl.coeff(inputIndex++);
435       }
436       return internal::pload<PacketReturnType>(values);
437     }
438   }
439 
440   template<int LoadMode>
441   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
442   {
443     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
444     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
445 
446     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
447     Index dim, inputIndex, outputOffset;
448 
449     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
450       dim = 1;
451     } else {
452       dim = NumDims - 2;
453     }
454 
455     inputIndex   = index / m_outputStrides[dim];
456     outputOffset = index % m_outputStrides[dim];
457     if (outputOffset + PacketSize <= m_outputStrides[dim]) {
458       values[0] = m_impl.coeff(inputIndex);
459       return internal::pload1<PacketReturnType>(values);
460     } else {
461       EIGEN_UNROLL_LOOP
462       for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
463         if (outputOffset + cur < m_outputStrides[dim]) {
464           values[i] = m_impl.coeff(inputIndex);
465         } else {
466           values[i] = m_impl.coeff(++inputIndex);
467           outputOffset = 0;
468           cur = 0;
469         }
470       }
471       return internal::pload<PacketReturnType>(values);
472     }
473   }
474 
475   // Ignore the LoadMode and always use unaligned loads since we can't guarantee
476   // the alignment at compile time.
477   template<int LoadMode>
478   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
479   {
480     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
481     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
482 
483     const Index originalIndex = index;
484 
485     Index inputIndex = 0;
486     EIGEN_UNROLL_LOOP
487     for (int i = NumDims - 1; i > 0; --i) {
488       const Index idx = index / m_outputStrides[i];
489       if (internal::index_statically_eq<Broadcast>(i, 1)) {
490         eigen_assert(idx < m_impl.dimensions()[i]);
491         inputIndex += idx * m_inputStrides[i];
492       } else {
493         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
494           eigen_assert(idx % m_impl.dimensions()[i] == 0);
495         } else {
496           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
497         }
498       }
499       index -= idx * m_outputStrides[i];
500     }
501     Index innermostLoc;
502     if (internal::index_statically_eq<Broadcast>(0, 1)) {
503       eigen_assert(index < m_impl.dimensions()[0]);
504       innermostLoc = index;
505     } else {
506       if (internal::index_statically_eq<InputDimensions>(0, 1)) {
507         eigen_assert(index % m_impl.dimensions()[0] == 0);
508         innermostLoc = 0;
509       } else {
510         innermostLoc = index % m_impl.dimensions()[0];
511       }
512     }
513     inputIndex += innermostLoc;
514 
515     // TODO: this could be extended to the second dimension if we're not
516     // broadcasting along the first dimension, and so on.
517     if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
518       return m_impl.template packet<Unaligned>(inputIndex);
519     } else {
520       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
521       values[0] = m_impl.coeff(inputIndex);
522       EIGEN_UNROLL_LOOP
523       for (int i = 1; i < PacketSize; ++i) {
524         if (innermostLoc + i < m_impl.dimensions()[0]) {
525           values[i] = m_impl.coeff(inputIndex+i);
526         } else {
527           values[i] = coeffColMajor(originalIndex+i);
528         }
529       }
530       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
531       return rslt;
532     }
533   }
534 
535   template<int LoadMode>
536   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
537   {
538     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
539     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
540 
541     const Index originalIndex = index;
542 
543     Index inputIndex = 0;
544     EIGEN_UNROLL_LOOP
545     for (int i = 0; i < NumDims - 1; ++i) {
546       const Index idx = index / m_outputStrides[i];
547       if (internal::index_statically_eq<Broadcast>(i, 1)) {
548         eigen_assert(idx < m_impl.dimensions()[i]);
549         inputIndex += idx * m_inputStrides[i];
550       } else {
551         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
552           eigen_assert(idx % m_impl.dimensions()[i] == 0);
553         } else {
554           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
555         }
556       }
557       index -= idx * m_outputStrides[i];
558     }
559     Index innermostLoc;
560     if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
561       eigen_assert(index < m_impl.dimensions()[NumDims-1]);
562       innermostLoc = index;
563     } else {
564       if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
565         eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
566         innermostLoc = 0;
567       } else {
568         innermostLoc = index % m_impl.dimensions()[NumDims-1];
569       }
570     }
571     inputIndex += innermostLoc;
572 
573     // TODO: this could be extended to the second dimension if we're not
574     // broadcasting along the first dimension, and so on.
575     if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
576       return m_impl.template packet<Unaligned>(inputIndex);
577     } else {
578       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
579       values[0] = m_impl.coeff(inputIndex);
580       EIGEN_UNROLL_LOOP
581       for (int i = 1; i < PacketSize; ++i) {
582         if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
583           values[i] = m_impl.coeff(inputIndex+i);
584         } else {
585           values[i] = coeffRowMajor(originalIndex+i);
586         }
587       }
588       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
589       return rslt;
590     }
591   }
592 
593   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
594   costPerCoeff(bool vectorized) const {
595     double compute_cost = TensorOpCost::AddCost<Index>();
596     if (!isCopy && NumDims > 0) {
597       EIGEN_UNROLL_LOOP
598       for (int i = NumDims - 1; i > 0; --i) {
599         compute_cost += TensorOpCost::DivCost<Index>();
600         if (internal::index_statically_eq<Broadcast>(i, 1)) {
601           compute_cost +=
602               TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
603         } else {
604           if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
605             compute_cost += TensorOpCost::MulCost<Index>() +
606                             TensorOpCost::ModCost<Index>() +
607                             TensorOpCost::AddCost<Index>();
608           }
609         }
610         compute_cost +=
611             TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
612       }
613     }
614     return m_impl.costPerCoeff(vectorized) +
615            TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
616   }
617 
618   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
619   internal::TensorBlockResourceRequirements getResourceRequirements() const {
620     // TODO(wuke): Targeting L1 size is 30% faster than targeting L{-1} on large
621     // tensors. But this might need further tuning.
622     const size_t target_size = m_device.firstLevelCacheSize();
623     return internal::TensorBlockResourceRequirements::merge(
624         m_impl.getResourceRequirements(),
625         internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
626   }
627 
628   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
629   block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
630           bool /*root_of_expr_ast*/ = false) const {
631     BlockBroadcastingParams params = blockBroadcastingParams(desc);
632 
633     if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
634       return emptyBlock();
635     }
636 
637     // Prepare storage for the materialized broadcasting result.
638     const typename TensorBlock::Storage block_storage =
639         TensorBlock::prepareStorage(desc, scratch);
640     ScalarNoConst* materialized_output = block_storage.data();
641 
642     // We potentially will need to materialize input blocks.
643     size_t materialized_input_size = 0;
644     ScalarNoConst* materialized_input = NULL;
645 
646     // Initialize the block broadcasting iterator state for the outer dimensions
647     // (outer with regard to the bcast dimension). Dimensions in this array are
648     // always in innermost -> outermost order (ColMajor layout).
649     array<BlockBroadcastingIteratorState, NumDims> it;
650     int idx = 0;
651 
652     for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
653       const Index dim = IsColMajor ? i : NumDims - 1 - i;
654       it[idx].size = params.output_dims[dim];
655       it[idx].count = 0;
656       it[idx].output_stride = m_outputStrides[dim];
657       it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
658       idx++;
659     }
660 
661     // Write output into the beginning of `materialized_output`.
662     Index output_offset = 0;
663 
664     // We will fill the output block by broadcasting along the bcast dim and
665     // iterating over the outer dimensions.
666     const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();
667 
668     for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
669       ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
670       Index bcast_offset = desc.offset() + output_offset;
671 
672       // Broadcast along the bcast dimension.
673       num_output_coeffs += BroadcastBlockAlongBcastDim(
674           params, bcast_offset, scratch, bcast_output, &materialized_input,
675           &materialized_input_size);
676 
677       // Switch to the next outer dimension.
678       for (int j = 0; j < idx; ++j) {
679         if (++it[j].count < it[j].size) {
680           output_offset += it[j].output_stride;
681           break;
682         }
683         it[j].count = 0;
684         output_offset -= it[j].output_span;
685       }
686     }
687 
688     return block_storage.AsTensorMaterializedBlock();
689   }
690 
691   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
692 
693   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
694 
695   Broadcast functor() const { return m_broadcast; }
696 #ifdef EIGEN_USE_SYCL
697   // binding placeholder accessors to a command group handler for SYCL
698   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
699       cl::sycl::handler& cgh) const {
700     m_impl.bind(cgh);
701   }
702 #endif
703  private:
704   static const bool IsColMajor =
705       static_cast<int>(Layout) == static_cast<int>(ColMajor);
706 
707   // We will build the general-case block broadcasting on top of a broadcasting
708   // primitive that broadcasts only the inner dimension(s) along the first
709   // dimension smaller than the input size (this is called the `bcast_dim`).
710   //
711   // Example:
712   //           dim:  0  1  2   (ColMajor)
713   //    input size: [9, 3, 6]
714   //    block size: [9, 2, 6]
715   //
716   // We will compute broadcasted block by iterating over the outer dimensions
717   // before `bcast_dim` (only dimension `2` in this example) and computing
718   // broadcasts along the `bcast_dim` (dimension `1` in this example).
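  // For the example above (and assuming the broadcast factor along dimension 0
  // is 1), blockBroadcastingParams() would compute inner_dim_count = 1,
  // inner_dim_size = 9, bcast_dim = 1 and bcast_dim_size = 2.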
719 
720   // BlockBroadcastingParams holds precomputed parameters for broadcasting a
721   // single block along the broadcasting dimension. Sizes and strides along the
722   // `bcast_dim` might be invalid; they will be adjusted later in
723   // `BroadcastBlockAlongBcastDim`.
724   struct BlockBroadcastingParams {
725     Dimensions input_dims;      // input expression dimensions
726     Dimensions output_dims;     // output block sizes
727     Dimensions output_strides;  // output block strides
728 
729     int inner_dim_count;   // count inner dimensions matching in size
730     int bcast_dim;         // broadcasting dimension index
731     Index bcast_dim_size;  // broadcasting dimension size
732     Index inner_dim_size;  // inner dimensions size
733 
734     // Block sizes and strides for the input block where all dimensions before
735     // `bcast_dim` are equal to `1`.
736     Dimensions input_block_sizes;
737     Dimensions input_block_strides;
738 
739     // Block sizes and strides for blocks with extra dimensions and strides `0`.
740     BroadcastDimensions bcast_block_sizes;
741     BroadcastDimensions bcast_block_strides;
742     BroadcastDimensions bcast_input_strides;
743   };
744 
745   struct BlockBroadcastingIteratorState {
746     Index size;
747     Index count;
748     Index output_stride;
749     Index output_span;
750   };
751 
752   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
753   blockBroadcastingParams(TensorBlockDesc& desc) const {
754     BlockBroadcastingParams params;
755 
756     params.input_dims = Dimensions(m_impl.dimensions());
757 
758     // Output block sizes and strides.
759     params.output_dims = desc.dimensions();
760     params.output_strides = internal::strides<Layout>(params.output_dims);
761 
762     // Find the broadcasting dimension (the first dimension with output size
763     // smaller than the input size).
764     params.bcast_dim = 0;
765     params.bcast_dim_size = 1;
766     params.inner_dim_size = 1;
767 
768     // Count the number of inner dimensions that have the same size in the block
769     // and in the broadcast expression.
770     params.inner_dim_count = 0;
771 
772     for (int i = 0; i < NumDims; ++i) {
773       const int dim = IsColMajor ? i : NumDims - i - 1;
774 
775       if (params.output_dims[dim] == m_dimensions[dim]) {
776         params.inner_dim_size *= params.output_dims[dim];
777         ++params.inner_dim_count;
778         continue;
779       }
780 
781       // First non-matching dimension is the broadcasting dimension.
782       eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
783       params.bcast_dim = dim;
784       params.bcast_dim_size = params.output_dims[dim];
785       break;
786     }
787 
788     // Calculate the input block size for looking into the input.
789     for (int i = 0; i < params.inner_dim_count; ++i) {
790       const int dim = IsColMajor ? i : NumDims - i - 1;
791       params.input_block_sizes[dim] = params.input_dims[dim];
792     }
793     for (int i = params.inner_dim_count; i < NumDims; ++i) {
794       const int dim = IsColMajor ? i : NumDims - i - 1;
795       params.input_block_sizes[dim] = 1;
796     }
797     params.input_block_strides =
798         internal::strides<Layout>(params.input_block_sizes);
799 
800     // Broadcast with the 0-stride trick: Create 1 extra dim for each
801     // broadcast, set the input stride to 0.
802     //
803     // When ColMajor:
804     //
805     // - bcast_block_sizes:
806     //   [d_0, b_0, d_1, b_1, ...]
807     //
808     // - bcast_block_strides:
809     //   [output_block_strides[0], output_block_strides[0] * d_0,
810     //    output_block_strides[1], output_block_strides[1] * d_1,
811     //   ...]
812     //
813     // - bcast_input_strides:
814     //   [input_block_strides[0], 0,
815     //    input_block_strides[1], 0,
816     //   ...].
817     //
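    // Concrete sketch (numbers are illustrative): with inner input dimension
    // d_0 = 9, broadcast factor b_0 = 2 and output_strides[0] = 1, the first two
    // entries become bcast_block_sizes = [9, 2, ...],
    // bcast_block_strides = [1, 9, ...] and bcast_input_strides = [1, 0, ...].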
818     for (int i = 0; i < params.inner_dim_count; ++i) {
819       const int dim = IsColMajor ? i : NumDims - i - 1;
820 
821       const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
822       const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;
823 
824       params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
825       params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
826       params.bcast_block_strides[copy_dim] = params.output_strides[dim];
827       params.bcast_block_strides[broadcast_dim] =
828           params.output_strides[dim] * params.input_dims[dim];
829       params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
830       params.bcast_input_strides[broadcast_dim] = 0;
831     }
832 
833     for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
834       const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
835       params.bcast_block_sizes[dim] = 1;
836       params.bcast_block_strides[dim] = 0;
837       params.bcast_input_strides[dim] = 0;
838     }
839 
840     return params;
841   }
842 
843   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
844     DSizes<Index, NumDims> dimensions;
845     for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
846     return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
847   }
848 
849   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
850       BlockBroadcastingParams params, Index bcast_offset,
851       TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
852       ScalarNoConst** materialized_input,
853       size_t* materialized_input_size) const {
854     if (params.bcast_dim_size == 1) {
855       // We just need one block read using the values already set up above.
856       return BroadcastBlock(
857           params.input_block_sizes, params.input_block_strides,
858           params.bcast_block_sizes, params.bcast_block_strides,
859           params.bcast_input_strides, bcast_offset, 0, scratch,
860           materialized_output, materialized_input, materialized_input_size);
861 
862     } else if (params.input_dims[params.bcast_dim] == 1) {
863       // Broadcast bcast dimension (< NumDims) by bcast_dim_size.
864       const int broadcast_bcast_dim =
865           IsColMajor ? 2 * params.inner_dim_count + 1
866                      : 2 * NumDims - 2 * params.inner_dim_count - 2;
867 
868       params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
869       params.bcast_input_strides[broadcast_bcast_dim] = 0;
870       params.bcast_block_strides[broadcast_bcast_dim] =
871           params.output_strides[params.bcast_dim];
872 
873       return BroadcastBlock(
874           params.input_block_sizes, params.input_block_strides,
875           params.bcast_block_sizes, params.bcast_block_strides,
876           params.bcast_input_strides, bcast_offset, 0, scratch,
877           materialized_output, materialized_input, materialized_input_size);
878 
879     } else {
880       // Keep track of the total number of coefficients written to the
881       // output block.
882       Index num_output_coeffs = 0;
883 
884       // The general case. Let's denote the output block as
885       //
886       //   x[..., a:a+bcast_dim_size, :, ..., :]
887       //
888       // where a:a+bcast_dim_size is a slice on the bcast_dim dimension
889       // (< NumDims). We need to split the a:a+bcast_dim_size into possibly 3
890       // sub-blocks:
891       //
892       // (1) a:b, where b is the smallest multiple of
893       //     input_dims[bcast_dim_start] in [a, a+bcast_dim_size].
894       //
895       // (2) b:c, where c is the largest multiple of input_dims[bcast_dim_start]
896       //     in [a, a+bcast_dim_size].
897       //
898       // (3) c:a+bcast_dim_size .
899       //
900       // Or, when b and c do not exist, we just need to process the whole block
901       // together.
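      // Numeric sketch (values are illustrative): with input_dims[bcast_dim] = 3,
      // a = 2 and bcast_dim_size = 8, the slice 2:10 splits into the head 2:3
      // (b = 3), the middle 3:9 (c = 9, i.e. two full copies of the input along
      // the bcast dim) and the tail 9:10.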
902 
903       // Find a.
904       const Index bcast_dim_left_index =
905           bcast_offset / m_outputStrides[params.bcast_dim];
906 
907       // Find b and c.
908       const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];
909 
910       // First multiple of input_bcast_dim_size at or after a. This is b when it
911       // is <= bcast_dim_left_index + bcast_dim_size.
912       const Index first_multiple =
913           divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
914           input_bcast_dim_size;
915 
916       if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
917         // b exists, so does c. Find it.
918         const Index last_multiple =
919             (bcast_dim_left_index + params.bcast_dim_size) /
920             input_bcast_dim_size * input_bcast_dim_size;
921         const int copy_bcast_dim =
922             IsColMajor ? 2 * params.inner_dim_count
923                        : 2 * NumDims - 2 * params.inner_dim_count - 1;
924         const int broadcast_bcast_dim =
925             IsColMajor ? 2 * params.inner_dim_count + 1
926                        : 2 * NumDims - 2 * params.inner_dim_count - 2;
927 
928         if (first_multiple > bcast_dim_left_index) {
929           const Index head_size = first_multiple - bcast_dim_left_index;
930           params.input_block_sizes[params.bcast_dim] = head_size;
931           params.bcast_block_sizes[copy_bcast_dim] = head_size;
932           params.bcast_input_strides[copy_bcast_dim] =
933               params.input_block_strides[params.bcast_dim];
934           params.bcast_block_strides[copy_bcast_dim] =
935               params.output_strides[params.bcast_dim];
936           params.bcast_block_sizes[broadcast_bcast_dim] = 1;
937           params.bcast_input_strides[broadcast_bcast_dim] = 0;
938           params.bcast_block_strides[broadcast_bcast_dim] =
939               params.output_strides[params.bcast_dim] *
940               params.input_dims[params.bcast_dim];
941 
942           num_output_coeffs += BroadcastBlock(
943               params.input_block_sizes, params.input_block_strides,
944               params.bcast_block_sizes, params.bcast_block_strides,
945               params.bcast_input_strides, bcast_offset, 0, scratch,
946               materialized_output, materialized_input, materialized_input_size);
947         }
948         if (first_multiple < last_multiple) {
949           params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
950           params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
951           params.bcast_input_strides[copy_bcast_dim] =
952               params.input_block_strides[params.bcast_dim];
953           params.bcast_block_strides[copy_bcast_dim] =
954               params.output_strides[params.bcast_dim];
955           params.bcast_block_sizes[broadcast_bcast_dim] =
956               (last_multiple - first_multiple) / input_bcast_dim_size;
957           params.bcast_input_strides[broadcast_bcast_dim] = 0;
958           params.bcast_block_strides[broadcast_bcast_dim] =
959               params.output_strides[params.bcast_dim] *
960               params.input_dims[params.bcast_dim];
961           const Index offset = (first_multiple - bcast_dim_left_index) *
962                                m_outputStrides[params.bcast_dim];
963 
964           num_output_coeffs += BroadcastBlock(
965               params.input_block_sizes, params.input_block_strides,
966               params.bcast_block_sizes, params.bcast_block_strides,
967               params.bcast_input_strides, bcast_offset, offset, scratch,
968               materialized_output, materialized_input, materialized_input_size);
969         }
970         if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
971           const Index tail_size =
972               bcast_dim_left_index + params.bcast_dim_size - last_multiple;
973           params.input_block_sizes[params.bcast_dim] = tail_size;
974           params.bcast_block_sizes[copy_bcast_dim] = tail_size;
975           params.bcast_input_strides[copy_bcast_dim] =
976               params.input_block_strides[params.bcast_dim];
977           params.bcast_block_strides[copy_bcast_dim] =
978               params.output_strides[params.bcast_dim];
979           params.bcast_block_sizes[broadcast_bcast_dim] = 1;
980           params.bcast_input_strides[broadcast_bcast_dim] = 0;
981           params.bcast_block_strides[broadcast_bcast_dim] =
982               params.output_strides[params.bcast_dim] *
983               params.input_dims[params.bcast_dim];
984           const Index offset = (last_multiple - bcast_dim_left_index) *
985                                m_outputStrides[params.bcast_dim];
986 
987           num_output_coeffs += BroadcastBlock(
988               params.input_block_sizes, params.input_block_strides,
989               params.bcast_block_sizes, params.bcast_block_strides,
990               params.bcast_input_strides, bcast_offset, offset, scratch,
991               materialized_output, materialized_input, materialized_input_size);
992         }
993       } else {
994         // b and c do not exist.
995         const int copy_bcast_dim =
996             IsColMajor ? 2 * params.inner_dim_count
997                        : 2 * NumDims - 2 * params.inner_dim_count - 1;
998         params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
999         params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
1000         params.bcast_input_strides[copy_bcast_dim] =
1001             params.input_block_strides[params.bcast_dim];
1002         params.bcast_block_strides[copy_bcast_dim] =
1003             params.output_strides[params.bcast_dim];
1004 
1005         num_output_coeffs += BroadcastBlock(
1006             params.input_block_sizes, params.input_block_strides,
1007             params.bcast_block_sizes, params.bcast_block_strides,
1008             params.bcast_input_strides, bcast_offset, 0, scratch,
1009             materialized_output, materialized_input, materialized_input_size);
1010       }
1011 
1012       return num_output_coeffs;
1013     }
1014   }
1015 
1016   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
1017       const Dimensions& input_block_sizes,
1018       const Dimensions& input_block_strides,
1019       const BroadcastDimensions& bcast_block_sizes,
1020       const BroadcastDimensions& bcast_block_strides,
1021       const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
1022       Index offset, TensorBlockScratch& scratch,
1023       ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
1024       size_t* materialized_input_size) const {
1025     // ---------------------------------------------------------------------- //
1026     // Tensor block descriptor for reading block from the input.
1027     const Index input_offset = bcast_offset + offset;
1028     TensorBlockDesc input_desc(
1029         IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
1030         input_block_sizes);
1031 
1032     ArgTensorBlock input_block = m_impl.block(input_desc, scratch);
1033 
1034     // ---------------------------------------------------------------------- //
1035     // Materialize input block into a temporary memory buffer only if it's not
1036     // already available in the arg block.
1037     const ScalarNoConst* input_buffer = NULL;
1038 
1039     if (input_block.data() != NULL) {
1040       // Input block already has raw data, there is no need to materialize it.
1041       input_buffer = input_block.data();
1042 
1043     } else {
1044       // Otherwise we have to do block assignment into a temporary buffer.
1045 
1046       // Maybe reuse previously allocated buffer, or allocate a new one with a
1047       // scratch allocator.
1048       const size_t input_total_size = input_block_sizes.TotalSize();
1049       if (*materialized_input == NULL ||
1050           *materialized_input_size < input_total_size) {
1051         *materialized_input_size = input_total_size;
1052         void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
1053         *materialized_input = static_cast<ScalarNoConst*>(mem);
1054       }
1055 
1056       typedef internal::TensorBlockAssignment<
1057           ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
1058           TensorBlockAssignment;
1059 
1060       TensorBlockAssignment::Run(
1061           TensorBlockAssignment::target(input_block_sizes, input_block_strides,
1062                                         *materialized_input),
1063           input_block.expr());
1064 
1065       input_buffer = *materialized_input;
1066     }
1067 
1068     // ---------------------------------------------------------------------- //
1069     // Copy data from materialized input block to the materialized output, using
1070     // given broadcast strides (strides with zeroes).
1071     typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
1072         TensorBlockIO;
1073 
1074     typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
1075     typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
1076                                       materialized_output + offset);
1077 
1078     return TensorBlockIO::Copy(dst, src);
1079   }
1080 
1081 protected:
1082   const Device EIGEN_DEVICE_REF m_device;
1083   const typename internal::remove_reference<Broadcast>::type m_broadcast;
1084   Dimensions m_dimensions;
1085   array<Index, NumDims> m_outputStrides;
1086   array<Index, NumDims> m_inputStrides;
1087   TensorEvaluator<ArgType, Device> m_impl;
1088 };
1089 
1090 
1091 } // end namespace Eigen
1092 
1093 #endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
1094