// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
#define EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H

namespace Eigen {

/** \class TensorAssign
  * \ingroup CXX11_Tensor_Module
  *
  * \brief The tensor assignment class.
  *
  * This class represents the assignment of the values resulting from the evaluation of
  * the rhs expression to the memory locations denoted by the lhs expression.
  */
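//
// Illustrative usage (a sketch, not normative documentation): a tensor
// assignment like the one below is represented internally by a TensorAssignOp
// before being evaluated:
//
//   Eigen::Tensor<float, 2> lhs(3, 4);
//   Eigen::Tensor<float, 2> rhs(3, 4);
//   rhs.setRandom();
//   lhs = 2.0f * rhs;  // builds a TensorAssignOp, then evaluates it
//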
namespace internal {
template<typename LhsXprType, typename RhsXprType>
struct traits<TensorAssignOp<LhsXprType, RhsXprType> >
{
  typedef typename LhsXprType::Scalar Scalar;
  typedef typename traits<LhsXprType>::StorageKind StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const std::size_t NumDimensions = internal::traits<LhsXprType>::NumDimensions;
  static const int Layout = internal::traits<LhsXprType>::Layout;
  typedef typename traits<LhsXprType>::PointerType PointerType;

  enum {
    Flags = 0
  };
};

template<typename LhsXprType, typename RhsXprType>
struct eval<TensorAssignOp<LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorAssignOp<LhsXprType, RhsXprType>& type;
};

template<typename LhsXprType, typename RhsXprType>
struct nested<TensorAssignOp<LhsXprType, RhsXprType>, 1, typename eval<TensorAssignOp<LhsXprType, RhsXprType> >::type>
{
  typedef TensorAssignOp<LhsXprType, RhsXprType> type;
};

}  // end namespace internal


template<typename LhsXprType, typename RhsXprType>
class TensorAssignOp : public TensorBase<TensorAssignOp<LhsXprType, RhsXprType> >
{
  public:
    typedef typename Eigen::internal::traits<TensorAssignOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename LhsXprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorAssignOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorAssignOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorAssignOp>::Index Index;

    static const int NumDims = Eigen::internal::traits<TensorAssignOp>::NumDimensions;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorAssignOp(LhsXprType& lhs, const RhsXprType& rhs)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs) {}

    /** \returns the nested expressions */
    EIGEN_DEVICE_FUNC
    typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return *((typename internal::remove_all<typename LhsXprType::Nested>::type*)&m_lhs_xpr); }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename internal::remove_all<typename LhsXprType::Nested>::type& m_lhs_xpr;
    const typename internal::remove_all<typename RhsXprType::Nested>::type& m_rhs_xpr;
};

template<typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
{
  typedef TensorAssignOp<LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename TensorEvaluator<RightArgType, Device>::Dimensions Dimensions;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  static const int NumDims = XprType::NumDims;

  enum {
    IsAligned         = int(TensorEvaluator<LeftArgType, Device>::IsAligned) &
                        int(TensorEvaluator<RightArgType, Device>::IsAligned),
    PacketAccess      = int(TensorEvaluator<LeftArgType, Device>::PacketAccess) &
                        int(TensorEvaluator<RightArgType, Device>::PacketAccess),
    BlockAccess       = int(TensorEvaluator<LeftArgType, Device>::BlockAccess) &
                        int(TensorEvaluator<RightArgType, Device>::BlockAccess),
    PreferBlockAccess = int(TensorEvaluator<LeftArgType, Device>::PreferBlockAccess) |
                        int(TensorEvaluator<RightArgType, Device>::PreferBlockAccess),
    Layout            = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess         = TensorEvaluator<LeftArgType, Device>::RawAccess
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlock
      RightTensorBlock;
  //===--------------------------------------------------------------------===//

  TensorEvaluator(const XprType& op, const Device& device) :
      m_leftImpl(op.lhsExpression(), device),
      m_rightImpl(op.rhsExpression(), device)
  {
    EIGEN_STATIC_ASSERT(
        (static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
        YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // The dimensions of the lhs and the rhs tensors should be equal to prevent
    // overflows and ensure the result is fully initialized.
    // TODO: use left impl instead if right impl dimensions are known at compile time.
    return m_rightImpl.dimensions();
  }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    // If the lhs provides raw access to its storage area (i.e. if m_leftImpl.data() returns a
    // non-null value), attempt to evaluate the rhs expression in place. Returns true iff in-place
    // evaluation isn't supported and the caller still needs to manually assign the values
    // generated by the rhs to the lhs.
    return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data());
  }
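
  // Caller-side sketch (an assumption: this loosely mirrors the scalar path of
  // the executor that drives this evaluator, heavily simplified; `assign_op`,
  // `device` and `total_size` are illustrative names):
  //
  //   TensorEvaluator<const TensorAssignOp<Lhs, Rhs>, Device> eval(assign_op, device);
  //   if (eval.evalSubExprsIfNeeded(NULL)) {
  //     // In-place evaluation was not possible: assign coefficient by coefficient.
  //     for (Index i = 0; i < total_size; ++i) eval.evalScalar(i);
  //   }
  //   eval.cleanup();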

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done](bool) {
      m_rightImpl.evalSubExprsIfNeededAsync(
          m_leftImpl.data(), [done](bool need_assign) { done(need_assign); });
    });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalScalar(Index i) {
    m_leftImpl.coeffRef(i) = m_rightImpl.coeff(i);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalPacket(Index i) {
    const int LhsStoreMode = TensorEvaluator<LeftArgType, Device>::IsAligned ? Aligned : Unaligned;
    const int RhsLoadMode = TensorEvaluator<RightArgType, Device>::IsAligned ? Aligned : Unaligned;
    m_leftImpl.template writePacket<LhsStoreMode>(i, m_rightImpl.template packet<RhsLoadMode>(i));
  }
  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_leftImpl.coeff(index);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    return m_leftImpl.template packet<LoadMode>(index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    // We assume that evalPacket or evalScalar is called to perform the
    // assignment and account for the cost of the write here, but reduce the
    // left cost by one load because we are using m_leftImpl.coeffRef.
    TensorOpCost left = m_leftImpl.costPerCoeff(vectorized);
    return m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(
               numext::maxi(0.0, left.bytes_loaded() - sizeof(CoeffReturnType)),
               left.bytes_stored(), left.compute_cycles()) +
           TensorOpCost(0, sizeof(CoeffReturnType), 0, vectorized, PacketSize);
  }
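
  // Worked example (illustrative assumption: a plain float tensor on the lhs,
  // whose own costPerCoeff is {bytes_loaded: 4, bytes_stored: 0, cycles: 0}):
  // the middle term collapses to max(0, 4 - sizeof(float)) = 0 loaded bytes,
  // so the total is the rhs cost plus one 4-byte store per coefficient.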

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    return internal::TensorBlockResourceRequirements::merge(
        m_leftImpl.getResourceRequirements(),
        m_rightImpl.getResourceRequirements());
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlock(
      TensorBlockDesc& desc, TensorBlockScratch& scratch) {
    if (TensorEvaluator<LeftArgType, Device>::RawAccess &&
        m_leftImpl.data() != NULL) {
      // If the destination has raw data access, we pass it as a potential
      // destination for the block descriptor evaluation.
      desc.template AddDestinationBuffer<Layout>(
          /*dst_base=*/m_leftImpl.data() + desc.offset(),
          /*dst_strides=*/internal::strides<Layout>(m_leftImpl.dimensions()));
    }

    RightTensorBlock block = m_rightImpl.block(desc, scratch, /*root_of_expr_ast=*/true);
    // If the block was evaluated directly into the destination, there is no
    // need to do the assignment.
    if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) {
      m_leftImpl.writeBlock(desc, block);
    }
    block.cleanup();
  }

#ifdef EIGEN_USE_SYCL
  // Binds the placeholder accessors to a command group handler for SYCL.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_leftImpl.bind(cgh);
    m_rightImpl.bind(cgh);
  }
#endif

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_leftImpl.data(); }

 private:
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};

}  // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H