// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H

namespace Eigen {

namespace internal {

enum {
  Rhs = 0,
  Lhs = 1
};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */
template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
  enum {
    DirectOffsets = false
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
    eigen_assert(false && "unsupported");
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return m_tensor.template packet<LoadMode>(index);
  }

 private:
  const Tensor m_tensor;
};

template <typename Tensor> struct CoeffLoader<Tensor, true> {
  enum {
    DirectOffsets = true
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_data += offset;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
  }
 private:
  typedef typename Tensor::Scalar Scalar;
  const Scalar* m_data;
};

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous, int Alignment>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  enum {
    DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }
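  // computeIndex maps a (row, col) coordinate of the logical contraction
  // matrix back to a linear offset into the underlying tensor. The
  // non-contracting and contracting coordinates are first split into
  // per-dimension tensor coordinates using the cumulative m_ij_strides /
  // m_k_strides, then re-linearized with the tensor's own strides.
  //
  // Illustrative sketch (hypothetical sizes, not taken from this file):
  // for an Lhs with two non-contracting dimensions of sizes 4 and 5,
  // m_ij_strides = {1, 4}; for a contiguous column-major tensor,
  // m_nocontract_strides = {1, 4} as well. Then for row = 14:
  //   outer coord = 14 / 4 = 3, remainder 2
  //   linidx      = 3 * m_nocontract_strides[1] + 2 * m_nocontract_strides[0]
  //               = 3 * 4 + 2 = 14
  // The two stride arrays only differ (and the loops below do real work)
  // when the contraction has shuffled or reordered the tensor dimensions.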
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }
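  // computeIndexPair is the vectorization helper: it returns the linear
  // offsets of two coefficients that sit `distance` apart along the packet
  // direction. The packet loaders in BaseTensorContractionMapper below probe
  // (last - first) == packet_size - 1 to detect whether a whole packet is
  // contiguous in memory and can be read with a single vector load instead
  // of packet_size scalar gathers.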
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (i.e. when we're
    // dealing with the lhs with inner_dim_contiguous). This is because the
    // matrix-vector product relies on the stride when dealing with aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename Tensor::PacketReturnType Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
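  // loadPacket below tries three strategies, from cheapest to most general:
  //   1. inner dimension contiguous and not reordered: compute a single
  //      linear index and issue one vector load;
  //   2. otherwise probe the first and last coefficients of the packet via
  //      computeIndexPair; if they turn out to be packet_size - 1 apart,
  //      the packet is contiguous and one vector load still suffices;
  //   3. as a last resort, gather the coefficients pairwise into an aligned
  //      scratch buffer and load the packet from there.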
  template <int AlignmentType>
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index last = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (last - first) == (packet_size - 1)) {

      return this->m_tensor.template packet<AlignmentType>(first);
    }

    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(last);

    return pload<Packet>(data);
  }

  template <int AlignmentType>
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    const Index half_packet_size = unpacket_traits<HalfPacket>::size;
    if (half_packet_size == packet_size) {
      return loadPacket<AlignmentType>(i, j);
    }
    EIGEN_ALIGN_MAX Scalar data[half_packet_size];
    for (Index k = 0; k < half_packet_size; k++) {
      data[k] = operator()(i + k, j);
    }
    return pload<HalfPacket>(data);
  }
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename Tensor::PacketReturnType Packet;
  template <int> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<typename Tensor::PacketReturnType>(data);
  }
  template <int AlignmentType> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
    return loadPacket<AlignmentType>(i, j);
  }
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
 public:
  typedef typename Tensor::PacketReturnType Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;

  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
  typedef Self LinearMapper;

  enum {
    // We can use direct offsets iff the parent mapper supports them and we can compute the strides.
    // TODO: we should also enable direct offsets for the Rhs case.
    UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
  };
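  // When UseDirectOffsets is true, the submapper's (vert_offset, horiz_offset)
  // is folded once into the base mapper's raw data pointer (as
  // vert_offset + horiz_offset * stride() in the constructor below), so every
  // subsequent access can use its local (i, j) coordinates directly.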
  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
    // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
    // this offset every time we attempt to access a coefficient.
    if (UseDirectOffsets) {
      Index stride = m_base_mapper.stride();
      m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, 0);
    }
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, j);
    }
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<Alignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<Alignment>(i, j);
    }
    return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
    }
    return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
    if (UseDirectOffsets) {
      m_base_mapper.storePacket(i, 0, p);
      return;
    }
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    if (UseDirectOffsets) {
      return LinearMapper(m_base_mapper, i, j);
    }
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }
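  // The requested alignment is clamped below: a read may only claim Aligned
  // when both the caller's AlignmentType and this mapper's Alignment template
  // parameter are Aligned; otherwise it degrades to an Unaligned load.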
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
    return false;
  }

 private:
  ParentMapper m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};


template<typename Scalar_, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
  : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {

 public:
  typedef Scalar_ Scalar;
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
  typedef SubMapper VectorMapper;

  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
                                                 const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides,
                                                 const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }
};


}  // end namespace internal
}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H