// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
#define EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H

namespace Eigen {

/** \class TensorEvaluator
  * \ingroup CXX11_Tensor_Module
  *
  * \brief The tensor evaluator classes.
  *
  * These classes are responsible for the evaluation of the tensor expression.
  *
  * TODO: add support for more types of expressions, in particular expressions
  * leading to lvalues (slicing, reshaping, etc...)
  */
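
// Illustrative sketch (not part of the library) of how an executor typically drives an
// evaluator: construct it over an expression and a device, give it a chance to
// materialize sub-expressions, then pull coefficients (or packets) one by one, and
// finally clean up. `Expression`, `expr` and `output` are placeholder names for this
// example only.
//
//   Eigen::DefaultDevice device;
//   TensorEvaluator<const Expression, Eigen::DefaultDevice> evaluator(expr, device);
//   if (evaluator.evalSubExprsIfNeeded(NULL)) {
//     for (Index i = 0; i < evaluator.dimensions().TotalSize(); ++i) {
//       output[i] = evaluator.coeff(i);
//     }
//   }
//   evaluator.cleanup();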

// Generic evaluator
template<typename Derived, typename Device>
struct TensorEvaluator
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(const_cast<typename internal::traits<Derived>::template MakePointer<Scalar>::Type>(m.data())), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  // If a destination buffer is provided, copy the data into it directly and return false
  // to signal that no further per-coefficient evaluation is needed. Otherwise return true
  // so the caller evaluates through coeff()/packet().
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* dest) {
    if (dest) {
      m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize());
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return m_data[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    eigen_assert(m_data);
    return m_data[index];
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    return internal::pstoret<Scalar, PacketReturnType, StoreMode>(m_data + index, x);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<Scalar>::Type data() const { return m_data; }

  /// required by sycl in order to construct sycl buffer from raw pointer
  const Device& device() const { return m_device; }

 protected:
  typename internal::traits<Derived>::template MakePointer<Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};

namespace {
template <typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T loadConstant(const T* address) {
  return *address;
}
// Use the texture cache on CUDA devices whenever possible: __ldg routes read-only
// loads through the read-only data cache, available from compute capability 3.5 onwards.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float loadConstant(const float* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double loadConstant(const double* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Eigen::half loadConstant(const Eigen::half* address) {
  return Eigen::half(half_impl::raw_uint16_to_half(__ldg(&address->x)));
}
#endif
}

// Default evaluator for rvalues
template<typename Derived, typename Device>
struct TensorEvaluator<const Derived, Device>
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(m.data()), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization && data) {
      m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return loadConstant(m_data + index);
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    const Index index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_dims.IndexOfColMajor(coords)
                                                                                 : m_dims.IndexOfRowMajor(coords);
    return loadConstant(m_data + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<const Scalar>::Type data() const { return m_data; }

  /// added for sycl in order to construct the buffer from the sycl device
  const Device& device() const { return m_device; }

 protected:
  typename internal::traits<Derived>::template MakePointer<const Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};


// -------------------- CwiseNullaryOp --------------------

template<typename NullaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
{
  typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;

  enum {
    IsAligned = true,
    PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC
  TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper()
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_wrapper(m_functor, index);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_wrapper.template packetOp<PacketReturnType, Index>(m_functor, index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
  /// required by sycl in order to extract the accessor
  NullaryOp functor() const { return m_functor; }

 private:
  const NullaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
  const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper;
};


// -------------------- CwiseUnaryOp --------------------

template<typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
{
  typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()),
        m_argImpl(op.nestedExpression(), device)
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
    m_argImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_argImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_argImpl.coeff(index));
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
    return m_argImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
  /// required by sycl in order to extract the functor
  UnaryOp functor() const { return m_functor; }

 private:
  const UnaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
};


// -------------------- CwiseBinaryOp --------------------

template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
{
  typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
                   internal::functor_traits<BinaryOp>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()),
        m_leftImpl(op.lhsExpression(), device),
        m_rightImpl(op.rhsExpression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use right impl instead if right impl dimensions are known at compile time.
    return m_leftImpl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_leftImpl.coeff(index), m_rightImpl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index), m_rightImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
    return m_leftImpl.costPerCoeff(vectorized) +
           m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<RightArgType, Device>& right_impl() const { return m_rightImpl; }
  /// required by sycl in order to extract the accessor
  BinaryOp functor() const { return m_functor; }

 private:
  const BinaryOp m_functor;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};
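
// Illustrative note (not part of the library): the coefficient-wise evaluators above
// compose recursively, mirroring the expression tree. For a hypothetical expression
// such as `(a + b).sqrt()`, the unary evaluator's coeff() calls the binary evaluator's
// coeff(), which in turn pulls the coefficients of the two leaf evaluators:
//
//   // coeff(i) of sqrt(a + b) effectively expands to:
//   //   sqrt_functor(sum_functor(a_evaluator.coeff(i), b_evaluator.coeff(i)))
//
// packet() composes the same way through the functors' packetOp().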

// -------------------- CwiseTernaryOp --------------------

template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
{
  typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;

  enum {
    IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
    PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
                   internal::functor_traits<TernaryOp>::PacketAccess,
    Layout = TensorEvaluator<Arg1Type, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()),
        m_arg1Impl(op.arg1Expression(), device),
        m_arg2Impl(op.arg2Expression(), device),
        m_arg3Impl(op.arg3Expression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);

    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                         typename internal::traits<Arg2Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                         typename internal::traits<Arg3Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                         typename internal::traits<Arg2Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                         typename internal::traits<Arg3Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)

    eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use arg2 or arg3 dimensions if they are known at compile time.
    return m_arg1Impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_arg1Impl.evalSubExprsIfNeeded(NULL);
    m_arg2Impl.evalSubExprsIfNeeded(NULL);
    m_arg3Impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_arg1Impl.cleanup();
    m_arg2Impl.cleanup();
    m_arg3Impl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
                              m_arg2Impl.template packet<LoadMode>(index),
                              m_arg3Impl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
    return m_arg1Impl.costPerCoeff(vectorized) +
           m_arg2Impl.costPerCoeff(vectorized) +
           m_arg3Impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg1Type, Device>& arg1Impl() const { return m_arg1Impl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg2Type, Device>& arg2Impl() const { return m_arg2Impl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg3Type, Device>& arg3Impl() const { return m_arg3Impl; }

 private:
  const TernaryOp m_functor;
  TensorEvaluator<Arg1Type, Device> m_arg1Impl;
  TensorEvaluator<Arg2Type, Device> m_arg2Impl;
  TensorEvaluator<Arg3Type, Device> m_arg3Impl;
};


// -------------------- SelectOp --------------------

template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
  typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
  typedef typename XprType::Scalar Scalar;

  enum {
    IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &
                   internal::packet_traits<Scalar>::HasBlend,
    Layout = TensorEvaluator<IfArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_condImpl(op.ifExpression(), device),
        m_thenImpl(op.thenExpression(), device),
        m_elseImpl(op.elseExpression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ThenArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ElseArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
    eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use then or else impl instead if they happen to be known at compile time.
    return m_condImpl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_condImpl.evalSubExprsIfNeeded(NULL);
    m_thenImpl.evalSubExprsIfNeeded(NULL);
    m_elseImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_condImpl.cleanup();
    m_thenImpl.cleanup();
    m_elseImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    internal::Selector<PacketSize> select;
    for (Index i = 0; i < PacketSize; ++i) {
      select.select[i] = m_condImpl.coeff(index+i);
    }
    return internal::pblend(select,
                            m_thenImpl.template packet<LoadMode>(index),
                            m_elseImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_condImpl.costPerCoeff(vectorized) +
           m_thenImpl.costPerCoeff(vectorized)
           .cwiseMax(m_elseImpl.costPerCoeff(vectorized));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<IfArgType, Device>& cond_impl() const { return m_condImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ThenArgType, Device>& then_impl() const { return m_thenImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ElseArgType, Device>& else_impl() const { return m_elseImpl; }

 private:
  TensorEvaluator<IfArgType, Device> m_condImpl;
  TensorEvaluator<ThenArgType, Device> m_thenImpl;
  TensorEvaluator<ElseArgType, Device> m_elseImpl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H