// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
#define EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H

namespace Eigen {

/** \class TensorAssign
  * \ingroup CXX11_Tensor_Module
  *
  * \brief The tensor assignment class.
  *
  * This class represents the assignment of the values resulting from the evaluation of
  * the rhs expression to the memory locations denoted by the lhs expression.
  */
namespace internal {
template<typename LhsXprType, typename RhsXprType>
struct traits<TensorAssignOp<LhsXprType, RhsXprType> >
{
  typedef typename LhsXprType::Scalar Scalar;
  typedef typename traits<LhsXprType>::StorageKind StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const std::size_t NumDimensions = internal::traits<LhsXprType>::NumDimensions;
  static const int Layout = internal::traits<LhsXprType>::Layout;
  typedef typename traits<LhsXprType>::PointerType PointerType;

  enum {
    Flags = 0
  };
};

template<typename LhsXprType, typename RhsXprType>
struct eval<TensorAssignOp<LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorAssignOp<LhsXprType, RhsXprType>& type;
};

template<typename LhsXprType, typename RhsXprType>
struct nested<TensorAssignOp<LhsXprType, RhsXprType>, 1, typename eval<TensorAssignOp<LhsXprType, RhsXprType> >::type>
{
  typedef TensorAssignOp<LhsXprType, RhsXprType> type;
};

}  // end namespace internal


template<typename LhsXprType, typename RhsXprType>
class TensorAssignOp : public TensorBase<TensorAssignOp<LhsXprType, RhsXprType> >
{
  public:
    typedef typename Eigen::internal::traits<TensorAssignOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename LhsXprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorAssignOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorAssignOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorAssignOp>::Index Index;

    static const int NumDims = Eigen::internal::traits<TensorAssignOp>::NumDimensions;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorAssignOp(LhsXprType& lhs, const RhsXprType& rhs)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs) {}

    /** \returns the nested expressions */
    EIGEN_DEVICE_FUNC
    typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return *((typename internal::remove_all<typename LhsXprType::Nested>::type*)&m_lhs_xpr); }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename internal::remove_all<typename LhsXprType::Nested>::type& m_lhs_xpr;
    const typename internal::remove_all<typename RhsXprType::Nested>::type& m_rhs_xpr;
};
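
// A minimal usage sketch (illustrative only, not part of this header's API
// surface): assigning one tensor expression to another makes TensorBase build
// a TensorAssignOp and hand it to a device-specific executor. Assuming two
// float tensors of matching shape:
//
//   Eigen::Tensor<float, 2> a(3, 4), b(3, 4);
//   b.setRandom();
//   a = b * 2.0f;  // evaluated as TensorAssignOp(a, b * 2.0f)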

template<typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
{
  typedef TensorAssignOp<LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename TensorEvaluator<RightArgType, Device>::Dimensions Dimensions;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  static const int NumDims = XprType::NumDims;

  enum {
    IsAligned         = int(TensorEvaluator<LeftArgType, Device>::IsAligned) &
                        int(TensorEvaluator<RightArgType, Device>::IsAligned),
    PacketAccess      = int(TensorEvaluator<LeftArgType, Device>::PacketAccess) &
                        int(TensorEvaluator<RightArgType, Device>::PacketAccess),
    BlockAccess       = int(TensorEvaluator<LeftArgType, Device>::BlockAccess) &
                        int(TensorEvaluator<RightArgType, Device>::BlockAccess),
    PreferBlockAccess = int(TensorEvaluator<LeftArgType, Device>::PreferBlockAccess) |
                        int(TensorEvaluator<RightArgType, Device>::PreferBlockAccess),
    Layout            = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess         = TensorEvaluator<LeftArgType, Device>::RawAccess
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlock
      RightTensorBlock;
  //===--------------------------------------------------------------------===//

  TensorEvaluator(const XprType& op, const Device& device) :
      m_leftImpl(op.lhsExpression(), device),
      m_rightImpl(op.rhsExpression(), device)
  {
    EIGEN_STATIC_ASSERT(
        (static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
        YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // The dimensions of the lhs and the rhs tensors should be equal to prevent
    // overflows and ensure the result is fully initialized.
    // TODO: use left impl instead if right impl dimensions are known at compile time.
    return m_rightImpl.dimensions();
  }
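
  // How a device executor typically drives this evaluator (an illustrative
  // sketch, not the actual TensorExecutor code): when evalSubExprsIfNeeded()
  // below returns true, the executor walks the flattened index space, using
  // evalPacket() for the vectorizable prefix and evalScalar() for the
  // remaining tail:
  //
  //   const Index size = evaluator.dimensions().TotalSize();
  //   const Index vectorized_size = (size / PacketSize) * PacketSize;
  //   for (Index i = 0; i < vectorized_size; i += PacketSize)
  //     evaluator.evalPacket(i);
  //   for (Index i = vectorized_size; i < size; ++i)
  //     evaluator.evalScalar(i);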

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    // If the lhs provides raw access to its storage area (i.e. if m_leftImpl.data() returns a non
    // null value), attempt to evaluate the rhs expression in place. Returns true iff in place
    // evaluation isn't supported and the caller still needs to manually assign the values generated
    // by the rhs to the lhs.
    return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data());
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done](bool) {
      m_rightImpl.evalSubExprsIfNeededAsync(
          m_leftImpl.data(), [done](bool need_assign) { done(need_assign); });
    });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalScalar(Index i) {
    m_leftImpl.coeffRef(i) = m_rightImpl.coeff(i);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalPacket(Index i) {
    const int LhsStoreMode = TensorEvaluator<LeftArgType, Device>::IsAligned ? Aligned : Unaligned;
    const int RhsLoadMode = TensorEvaluator<RightArgType, Device>::IsAligned ? Aligned : Unaligned;
    m_leftImpl.template writePacket<LhsStoreMode>(i, m_rightImpl.template packet<RhsLoadMode>(i));
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_leftImpl.coeff(index);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    return m_leftImpl.template packet<LoadMode>(index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    // We assume that evalPacket or evalScalar is called to perform the
    // assignment and account for the cost of the write here, but reduce left
    // cost by one load because we are using m_leftImpl.coeffRef.
    TensorOpCost left = m_leftImpl.costPerCoeff(vectorized);
    return m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(
               numext::maxi(0.0, left.bytes_loaded() - sizeof(CoeffReturnType)),
               left.bytes_stored(), left.compute_cycles()) +
           TensorOpCost(0, sizeof(CoeffReturnType), 0, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    return internal::TensorBlockResourceRequirements::merge(
        m_leftImpl.getResourceRequirements(),
        m_rightImpl.getResourceRequirements());
  }
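
  // Block-wise evaluation: if the lhs exposes raw storage, its buffer is
  // advertised to the rhs block evaluation as a potential destination, so the
  // rhs can materialize its result directly into the output and skip the
  // explicit writeBlock() call below.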
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlock(
      TensorBlockDesc& desc, TensorBlockScratch& scratch) {
    if (TensorEvaluator<LeftArgType, Device>::RawAccess &&
        m_leftImpl.data() != NULL) {
      // If destination has raw data access, we pass it as a potential
      // destination for a block descriptor evaluation.
      desc.template AddDestinationBuffer<Layout>(
          /*dst_base=*/m_leftImpl.data() + desc.offset(),
          /*dst_strides=*/internal::strides<Layout>(m_leftImpl.dimensions()));
    }

    RightTensorBlock block = m_rightImpl.block(desc, scratch, /*root_of_expr_ast=*/true);
    // If the block was evaluated directly into the destination buffer, there is
    // no need to do the assignment explicitly.
    if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) {
      m_leftImpl.writeBlock(desc, block);
    }
    block.cleanup();
  }

#ifdef EIGEN_USE_SYCL
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_leftImpl.bind(cgh);
    m_rightImpl.bind(cgh);
  }
#endif

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_leftImpl.data(); }

 private:
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H