// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H

namespace Eigen {

/** \class TensorConcatenationOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor concatenation class.
  *
  * Expression that concatenates two tensors along a given axis. The two
  * inputs must have the same number of dimensions, the same layout, and
  * matching sizes on every dimension except the concatenation axis.
  */
namespace internal {
template<typename Axis, typename LhsXprType, typename RhsXprType>
struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename promote_storage_type<typename LhsXprType::Scalar,
                                        typename RhsXprType::Scalar>::ret Scalar;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
  static const int Layout = traits<LhsXprType>::Layout;
  enum { Flags = 0 };
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1, typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type>
{
  typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
};

}  // end namespace internal


template<typename Axis, typename LhsXprType, typename RhsXprType>
class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors>
{
  public:
    typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
    typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorConcatenationOp>::Index Index;
    typedef typename internal::nested<TensorConcatenationOp>::type Nested;
    typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                    typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
    typedef typename NumTraits<Scalar>::Real RealScalar;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

    EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }

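    // The assignment operators below make the expression writable: assigning
    // to a concatenation scatters the right-hand side across the two
    // underlying sub-expressions (see the lvalue evaluator at the end of
    // this file). A minimal usage sketch, assuming float tensors t1, t2, t3
    // with compatible shapes:
    //   t1.concatenate(t2, 0) = t3;  // writes t3 through into t1 and t2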
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
    {
      typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
      Assign assign(*this, other);
      internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
      return *this;
    }

    template<typename OtherDerived>
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
    {
      typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
      Assign assign(*this, other);
      internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
      return *this;
    }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Axis m_axis;
};


// Eval as rvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
  static const int RightNumDims = internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis())
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || NumDims == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);

    eigen_assert(0 <= m_axis && m_axis < NumDims);
    const Dimensions& lhs_dims = m_leftImpl.dimensions();
    const Dimensions& rhs_dims = m_rightImpl.dimensions();
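    // The output shape matches the inputs on every dimension except
    // m_axis, where the two extents add up. For example, concatenating
    // a 2x3 lhs with a 2x4 rhs along axis 1 yields a 2x7 result.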
    {
      int i = 0;
      for (; i < m_axis; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
      eigen_assert(lhs_dims[i] > 0);  // Now i == m_axis.
      eigen_assert(rhs_dims[i] > 0);
      m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
      for (++i; i < NumDims; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_leftStrides[0] = 1;
      m_rightStrides[0] = 1;
      m_outputStrides[0] = 1;

      for (int j = 1; j < NumDims; ++j) {
        m_leftStrides[j] = m_leftStrides[j-1] * lhs_dims[j-1];
        m_rightStrides[j] = m_rightStrides[j-1] * rhs_dims[j-1];
        m_outputStrides[j] = m_outputStrides[j-1] * m_dimensions[j-1];
      }
    } else {
      m_leftStrides[NumDims - 1] = 1;
      m_rightStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;

      for (int j = NumDims - 2; j >= 0; --j) {
        m_leftStrides[j] = m_leftStrides[j+1] * lhs_dims[j+1];
        m_rightStrides[j] = m_rightStrides[j+1] * rhs_dims[j+1];
        m_outputStrides[j] = m_outputStrides[j+1] * m_dimensions[j+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/)
  {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup()
  {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
  // See CL/76180724 comments for more ideas.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // Collect dimension-wise indices (subs).
    array<Index, NumDims> subs;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[0] = index;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[NumDims - 1] = index;
    }

    const Dimensions& left_dims = m_leftImpl.dimensions();
    if (subs[m_axis] < left_dims[m_axis]) {
      Index left_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        left_index = subs[0];
        for (int i = 1; i < NumDims; ++i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      } else {
        left_index = subs[NumDims - 1];
        for (int i = NumDims - 2; i >= 0; --i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      }
      return m_leftImpl.coeff(left_index);
    } else {
      subs[m_axis] -= left_dims[m_axis];
      const Dimensions& right_dims = m_rightImpl.dimensions();
      Index right_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        right_index = subs[0];
        for (int i = 1; i < NumDims; ++i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      } else {
        right_index = subs[NumDims - 1];
        for (int i = NumDims - 2; i >= 0; --i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      }
      return m_rightImpl.coeff(right_index);
    }
  }

  // TODO(phli): Add a real vectorization.
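  // The packet() path below is not truly vectorized: it gathers
  // packetSize coefficients one at a time through coeff() into an
  // aligned buffer and loads that buffer as a packet, so it is correct
  // but pays the div/mod cost of coeff() for every element.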
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = coeff(index+i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() +
                                           2 * TensorOpCost::MulCost<Index>() +
                                           TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::ModCost<Index>());
    const double lhs_size = m_leftImpl.dimensions().TotalSize();
    const double rhs_size = m_rightImpl.dimensions().TotalSize();
    return (lhs_size / (lhs_size + rhs_size)) *
               m_leftImpl.costPerCoeff(vectorized) +
           (rhs_size / (lhs_size + rhs_size)) *
               m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_leftStrides;
  array<Index, NumDims> m_rightStrides;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
  const Axis m_axis;
};

// Eval as lvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
    : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename Base::Dimensions Dimensions;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
      : Base(op, device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

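  // coeffRef maps a linear index into the concatenated output back to a
  // writable coefficient of either the lhs or the rhs sub-expression.
  // Only the ColMajor index order is handled here, which is why the
  // constructor above statically asserts ColMajor layout.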
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
  {
    // Collect dimension-wise indices (subs).
    array<Index, Base::NumDims> subs;
    for (int i = Base::NumDims - 1; i > 0; --i) {
      subs[i] = index / this->m_outputStrides[i];
      index -= subs[i] * this->m_outputStrides[i];
    }
    subs[0] = index;

    const Dimensions& left_dims = this->m_leftImpl.dimensions();
    if (subs[this->m_axis] < left_dims[this->m_axis]) {
      Index left_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
      }
      return this->m_leftImpl.coeffRef(left_index);
    } else {
      subs[this->m_axis] -= left_dims[this->m_axis];
      const Dimensions& right_dims = this->m_rightImpl.dimensions();
      Index right_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
      }
      return this->m_rightImpl.coeffRef(right_index);
    }
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    for (int i = 0; i < packetSize; ++i) {
      coeffRef(index+i) = values[i];
    }
  }
};

}  // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H