// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

/** \class TensorBroadcasting
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  *
  */
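// A minimal usage sketch (illustrative only, not part of the library): the
// broadcasting expression is normally created through TensorBase::broadcast(),
// which returns a TensorBroadcastingOp. Each dimension of the input is
// repeated by the corresponding broadcast factor:
//
//   Eigen::Tensor<float, 2> input(2, 3);
//   input.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast = {{2, 4}};
//   Eigen::Tensor<float, 2> output = input.broadcast(bcast);  // dims (4, 12)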
namespace internal {
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal


template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
      : m_xpr(expr), m_broadcast(broadcast) {}

  EIGEN_DEVICE_FUNC
  const Broadcast& broadcast() const { return m_broadcast; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

  protected:
  typename XprType::Nested m_xpr;
  const Broadcast m_broadcast;
};


// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
 protected:  // all the non-static fields must have the same access control, otherwise the TensorEvaluator won't be standard layout.
  bool isCopy, nByOne, oneByN;
 public:
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
    PreferBlockAccess = true,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess = false
  };

  typedef typename internal::remove_const<Scalar>::type ScalarNoConst;

  // We do block based broadcasting using a trick with 2x tensor rank and 0
  // strides. See block method implementation for details.
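  // For example (ColMajor), a [2, 3] input broadcast by [3, 1] can be viewed
  // as a rank-4 block of sizes [2, 3, 3, 1], where the two inserted
  // "broadcast" dimensions read the input with stride 0, so the same input
  // coefficients are copied repeatedly (see bcast_block_sizes and
  // bcast_input_strides in BlockBroadcastingParams below).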
  typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
      ArgTensorBlock;

  typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
                                                     Layout, Index>
      TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : isCopy(false), nByOne(false), oneByN(false),
        m_device(device), m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
    // and store the result in a scalar. Instead one should reshape the scalar into an N-D
    // tensor (with N >= 1) of 1 element first and then broadcast.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    isCopy = true;
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * m_broadcast[i];
      if (m_broadcast[i] != 1) {
        isCopy = false;
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }

    if (input_dims[0] == 1) {
      oneByN = true;
      for (int i = 1; i < NumDims; ++i) {
        if (m_broadcast[i] != 1) {
          oneByN = false;
          break;
        }
      }
    } else if (input_dims[NumDims-1] == 1) {
      nByOne = true;
      for (int i = 0; i < NumDims-1; ++i) {
        if (m_broadcast[i] != 1) {
          nByOne = false;
          break;
        }
      }
    }

    // Handle special formats like NCHW: the input shape is '[1, N..., 1]' and
    // the broadcast shape is '[N, 1..., N]'.
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
        nByOne = true;
        oneByN = true;
        for (int i = 1; i < NumDims-1; ++i) {
          if (m_broadcast[i] != 1) {
            nByOne = false;
            oneByN = false;
            break;
          }
        }
      }
    }
  }
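  // Examples of the fast-path detection above (assumed shapes):
  //  - input [1, 5], broadcast [4, 1]: oneByN is set, the input has extent 1
  //    along dimension 0 and is replicated only along that dimension.
  //  - input [5, 1], broadcast [1, 4]: nByOne is set, the input has extent 1
  //    along the last dimension and is replicated only along it.
  //  - input [1, 5, 1], broadcast [4, 1, 3]: both flags are set (the
  //    NCHW-like case handled above).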
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffColMajor(index);
      }
    } else {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffRowMajor(index);
      }
    }
  }

  // TODO: attempt to speed this up. The integer divisions and modulo are slow.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    return m_impl.coeff(indexColMajor(index));
  }
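  // Worked example for indexColMajor (assumed ColMajor, input [2, 3],
  // broadcast [1, 2]): the output is [2, 6], m_outputStrides == [1, 2] and
  // m_inputStrides == [1, 2]. For output index 9, i.e. coefficient (1, 4):
  //   i == 1: idx = 9 / 2 = 4, inputIndex += (4 % 3) * 2 = 2, index -> 1
  //   i == 0: inputIndex += 1 % 2 = 1
  // so output coefficient (1, 4) reads input index 3, i.e. coefficient (1, 1).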
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    return m_impl.coeff(indexRowMajor(index));
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
#ifdef EIGEN_GPU_COMPILE_PHASE
        // See PR 437: on NVIDIA P100 and K20m we observed a x3-4 speed up by enforcing
        // unaligned loads here. The reason is unclear though.
        return m_impl.template packet<Unaligned>(index);
#else
        return m_impl.template packet<LoadMode>(index);
#endif
      } else if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
      if (isCopy) {
#ifdef EIGEN_GPU_COMPILE_PHASE
        // See above.
        return m_impl.template packet<Unaligned>(index);
#else
        return m_impl.template packet<LoadMode>(index);
#endif
      } else if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
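  // Worked example for the gather above (assumed ColMajor, input [1, 5, 1],
  // broadcast [3, 1, 2], output [3, 5, 2], PacketSize == 4): for index 13,
  // batchedIndex = 13, inputIndex = 4 and outputOffset = 1. The packet does
  // not fit inside a single replicated column, so the loop produces
  // [in(4), in(4), in(0), in(0)], wrapping inputIndex back to 0 after the
  // last input coefficient.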
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    Index dim, inputIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      dim = NumDims - 1;
    } else {
      dim = 0;
    }

    inputIndex = index % m_inputStrides[dim];
    if (inputIndex + PacketSize <= m_inputStrides[dim]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (inputIndex > m_inputStrides[dim]-1) {
          inputIndex = 0;
        }
        values[i] = m_impl.coeff(inputIndex++);
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
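  // Worked example for the wrap-around gather above (assumed ColMajor, input
  // [3, 1], broadcast [1, 4], output [3, 4], PacketSize == 4): for index 7,
  // inputIndex starts at 7 % 3 == 1 and the packet does not fit in the
  // remaining input span, so the loop produces [in(1), in(2), in(0), in(1)].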
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index dim, inputIndex, outputOffset;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      dim = 1;
    } else {
      dim = NumDims - 2;
    }

    inputIndex = index / m_outputStrides[dim];
    outputOffset = index % m_outputStrides[dim];
    if (outputOffset + PacketSize <= m_outputStrides[dim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[dim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          values[i] = m_impl.coeff(++inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
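  // Worked example for the gather above (assumed ColMajor, input [1, 5],
  // broadcast [3, 1], output [3, 5], PacketSize == 4): for index 7,
  // inputIndex = 7 / 3 = 2 and outputOffset = 1. The packet crosses into the
  // next replicated column, so the loop produces [in(2), in(2), in(3), in(3)].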
  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[0]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffColMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffRowMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
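  // Worked example for the generic gather above (assumed ColMajor, input
  // [2, 3], broadcast [2, 1], output [4, 3], PacketSize == 4): a packet at
  // output index 2 covers output coefficients (2,0), (3,0), (0,1), (1,1).
  // The first two are read from the contiguous inner run as coeff(0) and
  // coeff(1); the remaining two fall back to coeffColMajor() and read
  // coeff(2) and coeff(3).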
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (!isCopy && NumDims > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    // TODO(wuke): Targeting L1 size is 30% faster than targeting L{-1} on large
    // tensors. But this might need further tuning.
    const size_t target_size = m_device.firstLevelCacheSize();
    return internal::TensorBlockResourceRequirements::merge(
        m_impl.getResourceRequirements(),
        internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
        bool /*root_of_expr_ast*/ = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage =
        TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We potentially will need to materialize input blocks.
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Initialize block broadcasting iterator state for the outer dimensions
    // (outer with regard to the bcast dimension). Dimensions in this array are
    // always in inner_most -> outer_most order (col major layout).
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Write output into the beginning of `materialized_output`.
    Index output_offset = 0;

    // We fill the output block by broadcasting along the bcast dim and
    // iterating over the outer dimensions.
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension.
      num_output_coeffs += BroadcastBlockAlongBcastDim(
          params, bcast_offset, scratch, bcast_output, &materialized_input,
          &materialized_input_size);

      // Switch to the next outer dimension.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }
  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

#ifdef EIGEN_USE_SYCL
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
      cl::sycl::handler& cgh) const {
    m_impl.bind(cgh);
  }
#endif

 private:
  static const bool IsColMajor =
      static_cast<int>(Layout) == static_cast<int>(ColMajor);

  // We build the general case of block broadcasting on top of a broadcasting
  // primitive that broadcasts only along the inner dimension(s) and along the
  // first dimension whose block size is smaller than the input size (this
  // dimension is called `bcast_dim`).
  //
  // Example:
  //   dim:        0  1  2   (ColMajor)
  //   input size: [9, 3, 6]
  //   block size: [9, 2, 6]
  //
  // We compute the broadcasted block by iterating over the outer dimensions
  // beyond `bcast_dim` (only dimension `2` in this example) and computing
  // broadcasts along the `bcast_dim` (dimension `1` in this example).

  // BlockBroadcastingParams holds precomputed parameters for broadcasting a
  // single block along the broadcasting dimension. Sizes and strides along the
  // `bcast_dim` might be invalid, they will be adjusted later in
  // `BroadcastBlockAlongBcastDim`.
  struct BlockBroadcastingParams {
    Dimensions input_dims;      // input expression dimensions
    Dimensions output_dims;     // output block sizes
    Dimensions output_strides;  // output block strides

    int inner_dim_count;   // number of inner dimensions matching in size
    int bcast_dim;         // broadcasting dimension index
    Index bcast_dim_size;  // broadcasting dimension size
    Index inner_dim_size;  // inner dimensions size

    // Block sizes and strides for the input block where all dimensions before
    // `bcast_dim` are equal to `1`.
    Dimensions input_block_sizes;
    Dimensions input_block_strides;

    // Block sizes and strides for blocks with extra dimensions and strides `0`.
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };

  struct BlockBroadcastingIteratorState {
    Index size;
    Index count;
    Index output_stride;
    Index output_span;
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
  blockBroadcastingParams(TensorBlockDesc& desc) const {
    BlockBroadcastingParams params;

    params.input_dims = Dimensions(m_impl.dimensions());

    // Output block sizes and strides.
    params.output_dims = desc.dimensions();
    params.output_strides = internal::strides<Layout>(params.output_dims);

    // Find the broadcasting dimension (first dimension with output size
    // smaller than the input size).
    params.bcast_dim = 0;
    params.bcast_dim_size = 1;
    params.inner_dim_size = 1;

    // Count the number of inner dimensions that have the same size in the
    // block and in the broadcast expression.
    params.inner_dim_count = 0;

    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      if (params.output_dims[dim] == m_dimensions[dim]) {
        params.inner_dim_size *= params.output_dims[dim];
        ++params.inner_dim_count;
        continue;
      }

      // The first non-matching dimension is the broadcasting dimension.
      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
      params.bcast_dim = dim;
      params.bcast_dim_size = params.output_dims[dim];
      break;
    }

    // Calculate the input block size for looking into the input.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = params.input_dims[dim];
    }
    for (int i = params.inner_dim_count; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = 1;
    }
    params.input_block_strides =
        internal::strides<Layout>(params.input_block_sizes);

    // Broadcast with the 0-stride trick: create 1 extra dim for each
    // broadcast and set the input stride to 0.
    //
    // When ColMajor:
    //
    //  - bcast_block_sizes:
    //      [d_0, b_0, d_1, b_1, ...]
    //
    //  - bcast_block_strides:
    //      [output_block_strides[0], output_block_strides[0] * d_0,
    //       output_block_strides[1], output_block_strides[1] * d_1,
    //       ...]
    //
    //  - bcast_input_strides:
    //      [input_block_strides[0], 0,
    //       input_block_strides[1], 0,
    //       ...].
    //
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;

      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
      params.bcast_block_strides[broadcast_dim] =
          params.output_strides[dim] * params.input_dims[dim];
      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
      params.bcast_input_strides[broadcast_dim] = 0;
    }

    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
      params.bcast_block_sizes[dim] = 1;
      params.bcast_block_strides[dim] = 0;
      params.bcast_input_strides[dim] = 0;
    }

    return params;
  }
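  // Example of the computed parameters (assumed ColMajor, input [9, 3, 6],
  // broadcast [1, 2, 1], i.e. output [9, 6, 6], requested block [9, 2, 6]):
  //   inner_dim_count = 1, inner_dim_size = 9,
  //   bcast_dim = 1, bcast_dim_size = 2,
  //   input_block_sizes   = [9, 1, 1],
  //   bcast_block_sizes   = [9, 1, 1, 1, 1, 1],
  //   bcast_block_strides = [1, 9, 0, 0, 0, 0],
  //   bcast_input_strides = [1, 0, 0, 0, 0, 0].
  // The entries along `bcast_dim` are refined later in
  // BroadcastBlockAlongBcastDim.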
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
    DSizes<Index, NumDims> dimensions;
    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
    return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
      BlockBroadcastingParams params, Index bcast_offset,
      TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
      ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    if (params.bcast_dim_size == 1) {
      // We just need one block read using the ready-set values above.
      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else if (params.input_dims[params.bcast_dim] == 1) {
      // Broadcast bcast dimension (< NumDims) by bcast_dim_size.
      const int broadcast_bcast_dim =
          IsColMajor ? 2 * params.inner_dim_count + 1
                     : 2 * NumDims - 2 * params.inner_dim_count - 2;

      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
      params.bcast_input_strides[broadcast_bcast_dim] = 0;
      params.bcast_block_strides[broadcast_bcast_dim] =
          params.output_strides[params.bcast_dim];

      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else {
      // Keep track of the total number of coefficients written to the
      // output block.
      Index num_output_coeffs = 0;

      // The general case. Let's denote the output block as
      //
      //   x[..., a:a+bcast_dim_size, :, ..., :]
      //
      // where a:a+bcast_dim_size is a slice on the bcast_dim dimension
      // (< NumDims). We need to split the a:a+bcast_dim_size into possibly 3
      // sub-blocks:
      //
      // (1) a:b, where b is the smallest multiple of
      //     input_dims[bcast_dim_start] in [a, a+bcast_dim_size].
      //
      // (2) b:c, where c is the largest multiple of input_dims[bcast_dim_start]
      //     in [a, a+bcast_dim_size].
      //
      // (3) c:a+bcast_dim_size .
      //
      // Or, when b and c do not exist, we just need to process the whole block
      // together.
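      // Numeric illustration (assumed values): with input_dims[bcast_dim] == 3
      // and an output slice a:a+bcast_dim_size == 1:8, we get b == 3 and
      // c == 6, so the slice is processed as a head 1:3, a middle 3:6 made of
      // one whole input copy, and a tail 6:8.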
      // Find a.
      const Index bcast_dim_left_index =
          bcast_offset / m_outputStrides[params.bcast_dim];

      // Find b and c.
      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];

      // First multiple after a. This is b when <= bcast_dim_left_index +
      // bcast_dim_size.
      const Index first_multiple =
          divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
          input_bcast_dim_size;

      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
        // b exists, so does c. Find it.
        const Index last_multiple =
            (bcast_dim_left_index + params.bcast_dim_size) /
            input_bcast_dim_size * input_bcast_dim_size;
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        const int broadcast_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count + 1
                       : 2 * NumDims - 2 * params.inner_dim_count - 2;

        if (first_multiple > bcast_dim_left_index) {
          const Index head_size = first_multiple - bcast_dim_left_index;
          params.input_block_sizes[params.bcast_dim] = head_size;
          params.bcast_block_sizes[copy_bcast_dim] = head_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, 0, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (first_multiple < last_multiple) {
          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] =
              (last_multiple - first_multiple) / input_bcast_dim_size;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (first_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
          const Index tail_size =
              bcast_dim_left_index + params.bcast_dim_size - last_multiple;
          params.input_block_sizes[params.bcast_dim] = tail_size;
          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (last_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
      } else {
        // b and c do not exist.
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
        params.bcast_input_strides[copy_bcast_dim] =
            params.input_block_strides[params.bcast_dim];
        params.bcast_block_strides[copy_bcast_dim] =
            params.output_strides[params.bcast_dim];

        num_output_coeffs += BroadcastBlock(
            params.input_block_sizes, params.input_block_strides,
            params.bcast_block_sizes, params.bcast_block_strides,
            params.bcast_input_strides, bcast_offset, 0, scratch,
            materialized_output, materialized_input, materialized_input_size);
      }

      return num_output_coeffs;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
      const Dimensions& input_block_sizes,
      const Dimensions& input_block_strides,
      const BroadcastDimensions& bcast_block_sizes,
      const BroadcastDimensions& bcast_block_strides,
      const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
      Index offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    // ---------------------------------------------------------------------- //
    // Tensor block descriptor for reading block from the input.
    const Index input_offset = bcast_offset + offset;
    TensorBlockDesc input_desc(
        IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
        input_block_sizes);

    ArgTensorBlock input_block = m_impl.block(input_desc, scratch);

    // ---------------------------------------------------------------------- //
    // Materialize input block into a temporary memory buffer only if it's not
    // already available in the arg block.
    const ScalarNoConst* input_buffer = NULL;

    if (input_block.data() != NULL) {
      // Input block already has raw data, there is no need to materialize it.
      input_buffer = input_block.data();

    } else {
      // Otherwise we have to do block assignment into a temporary buffer.

      // Maybe reuse previously allocated buffer, or allocate a new one with a
      // scratch allocator.
      const size_t input_total_size = input_block_sizes.TotalSize();
      if (*materialized_input == NULL ||
          *materialized_input_size < input_total_size) {
        *materialized_input_size = input_total_size;
        void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
        *materialized_input = static_cast<ScalarNoConst*>(mem);
      }

      typedef internal::TensorBlockAssignment<
          ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
          TensorBlockAssignment;

      TensorBlockAssignment::Run(
          TensorBlockAssignment::target(input_block_sizes, input_block_strides,
                                        *materialized_input),
          input_block.expr());

      input_buffer = *materialized_input;
    }

    // ---------------------------------------------------------------------- //
    // Copy data from the materialized input block to the materialized output,
    // using the given broadcast strides (strides with zeroes).
    typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
        TensorBlockIO;

    typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
    typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
                                    materialized_output + offset);

    return TensorBlockIO::Copy(dst, src);
  }

 protected:
  const Device EIGEN_DEVICE_REF m_device;
  const typename internal::remove_reference<Broadcast>::type m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H