1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_ 17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_ 18 19 #include <algorithm> 20 #include <array> 21 #include <functional> 22 #include <initializer_list> 23 #include <iterator> 24 #include <memory> 25 #include <numeric> 26 #include <random> 27 #include <type_traits> 28 #include <vector> 29 30 #include "absl/strings/str_cat.h" 31 #include "absl/strings/str_join.h" 32 #include "absl/types/span.h" 33 #include "tensorflow/compiler/xla/status.h" 34 #include "tensorflow/compiler/xla/types.h" 35 #include "tensorflow/core/lib/core/bits.h" 36 #include "tensorflow/core/platform/logging.h" 37 #include "tensorflow/core/platform/macros.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace xla { 41 42 namespace array_impl { 43 44 // conjunction 45 // 46 // Performs a compile-time logical AND operation on the passed types (which 47 // must have `::value` members convertible to `bool`. Short-circuits if it 48 // encounters any `false` members (and does not compare the `::value` members 49 // of any remaining arguments). 50 // 51 // This metafunction is designed to be a drop-in replacement for the C++17 52 // `std::conjunction` metafunction. 53 template <typename... Ts> 54 struct conjunction; 55 56 template <typename T, typename... Ts> 57 struct conjunction<T, Ts...> 58 : std::conditional<T::value, conjunction<Ts...>, T>::type {}; 59 60 template <> 61 struct conjunction<> : std::true_type {}; 62 63 // A type trait that is valid when all elements in a parameter pack are of 64 // integral type. Not using an alias template to work around MSVC 14.00 bug. 65 template <typename... Ts> 66 struct pack_is_integral : conjunction<std::is_integral<Ts>...> {}; 67 68 // Compares three same-sized vectors elementwise. For each item in `values`, 69 // returns false if any of values[i] is outside the half-open range [starts[i], 70 // ends[i]). 71 template <typename C1, typename C2, typename C3> 72 bool all_inside_range(const C1& values, const C2& range_starts, 73 const C3& range_ends) { 74 for (size_t i = 0, e = values.size(); i < e; ++i) { 75 if (values[i] < range_starts[i] || values[i] >= range_ends[i]) { 76 return false; 77 } 78 } 79 return true; 80 } 81 82 } // namespace array_impl 83 84 // General N dimensional array class with arbitrary value type. 85 template <typename T> 86 class Array { 87 public: 88 // Type inference can have a hard time parsing very deep initializer list 89 // nests, especially if one or more dimensions is one as the compiler just 90 // sees a single-element integer initializer. These typedefs allow casting 91 // explicitly with less typing. 92 using InitializerList1D = std::initializer_list<T>; 93 using InitializerList2D = std::initializer_list<InitializerList1D>; 94 using InitializerList3D = std::initializer_list<InitializerList2D>; 95 using InitializerList4D = std::initializer_list<InitializerList3D>; 96 97 using value_type = T; 98 99 // Creates a new array with the specified dimensions and initialized elements. 100 explicit Array(absl::Span<const int64> sizes) 101 : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]()) {} 102 103 // Creates a new array with the specified dimensions and specified value for 104 // every cell. 105 Array(absl::Span<const int64> sizes, T value) 106 : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { 107 Fill(value); 108 } 109 110 // Creates a 2D array from the given nested initializer list. The outer 111 // initializer list is the first dimension, the inner is the second dimension. 112 // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. 113 Array(InitializerList2D values) 114 : Array(ToInt64Vector({values.size(), values.begin()->size()})) { 115 int64_t idx = 0; 116 for (const auto& it1 : values) { 117 for (const auto& it2 : it1) { 118 values_[idx] = it2; 119 ++idx; 120 } 121 } 122 CHECK(idx == num_elements()); 123 } 124 125 // Creates a 1D array of a floating-point type (half, bfloat16, float, 126 // or double) from an initializer list of float values. 127 template <typename T2, typename = typename std::enable_if< 128 (std::is_same<T, Eigen::half>::value || 129 std::is_same<T, bfloat16>::value || 130 std::is_same<T, float>::value || 131 std::is_same<T, double>::value) && 132 std::is_same<T2, float>::value>::type> 133 Array(std::initializer_list<T2> values) 134 : Array(ToInt64Vector({values.size()})) { 135 int64_t idx = 0; 136 for (const auto& it1 : values) { 137 values_[idx] = static_cast<T>(it1); 138 ++idx; 139 } 140 CHECK(idx == num_elements()); 141 } 142 143 // Creates a 2D array of a floating-point type (half, bfloat16, float, 144 // or double) from an initializer list of float values. 145 template <typename T2, typename = typename std::enable_if< 146 (std::is_same<T, Eigen::half>::value || 147 std::is_same<T, bfloat16>::value || 148 std::is_same<T, float>::value || 149 std::is_same<T, double>::value) && 150 std::is_same<T2, float>::value>::type> 151 Array(std::initializer_list<std::initializer_list<T2>> values) 152 : Array(ToInt64Vector({values.size(), values.begin()->size()})) { 153 int64_t idx = 0; 154 for (const auto& it1 : values) { 155 for (const auto& it2 : it1) { 156 values_[idx] = static_cast<T>(it2); 157 ++idx; 158 } 159 } 160 CHECK(idx == num_elements()); 161 } 162 163 // Creates a 3D array from the given nested initializer list. The outer 164 // initializer list is the first dimension, and so on. 165 Array(InitializerList3D values) 166 : Array(ToInt64Vector({values.size(), values.begin()->size(), 167 values.begin()->begin()->size()})) { 168 int64_t idx = 0; 169 for (const auto& it1 : values) { 170 for (const auto& it2 : it1) { 171 for (const auto& it3 : it2) { 172 values_[idx] = it3; 173 ++idx; 174 } 175 } 176 } 177 CHECK(idx == num_elements()); 178 } 179 180 // Creates a 3D array of a floating-point type (half, bfloat16, float, 181 // or double) from an initializer list of float values. 182 template <typename T2, typename = typename std::enable_if< 183 (std::is_same<T, Eigen::half>::value || 184 std::is_same<T, bfloat16>::value || 185 std::is_same<T, float>::value || 186 std::is_same<T, double>::value) && 187 std::is_same<T2, float>::value>::type> 188 Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>> 189 values) 190 : Array(ToInt64Vector({values.size(), values.begin()->size(), 191 values.begin()->begin()->size()})) { 192 int64_t idx = 0; 193 for (const auto& it1 : values) { 194 for (const auto& it2 : it1) { 195 for (const auto& it3 : it2) { 196 values_[idx] = static_cast<T>(it3); 197 ++idx; 198 } 199 } 200 } 201 CHECK(idx == num_elements()); 202 } 203 204 // Creates a 4D array from the given nested initializer list. The outer 205 // initializer list is the first dimension, and so on. 206 Array(InitializerList4D values) 207 : Array(ToInt64Vector({values.size(), values.begin()->size(), 208 values.begin()->begin()->size(), 209 values.begin()->begin()->begin()->size()})) { 210 int64_t idx = 0; 211 for (const auto& it1 : values) { 212 for (const auto& it2 : it1) { 213 for (const auto& it3 : it2) { 214 for (const auto& it4 : it3) { 215 values_[idx] = it4; 216 ++idx; 217 } 218 } 219 } 220 } 221 CHECK(idx == num_elements()); 222 } 223 224 // Creates a 4D array of a floating-point type (half, bfloat16, float, 225 // or double) from an initializer list of float values. 226 template <typename T2, typename = typename std::enable_if< 227 (std::is_same<T, Eigen::half>::value || 228 std::is_same<T, bfloat16>::value || 229 std::is_same<T, float>::value || 230 std::is_same<T, double>::value) && 231 std::is_same<T2, float>::value>::type> 232 Array(std::initializer_list< 233 std::initializer_list<std::initializer_list<std::initializer_list<T2>>>> 234 values) 235 : Array(ToInt64Vector({values.size(), values.begin()->size(), 236 values.begin()->begin()->size(), 237 values.begin()->begin()->begin()->size()})) { 238 int64_t idx = 0; 239 for (const auto& it1 : values) { 240 for (const auto& it2 : it1) { 241 for (const auto& it3 : it2) { 242 for (const auto& it4 : it3) { 243 values_[idx] = static_cast<T>(it4); 244 ++idx; 245 } 246 } 247 } 248 } 249 CHECK(idx == num_elements()); 250 } 251 252 Array(const Array<T>& other) 253 : sizes_(other.sizes_), values_(new T[num_elements()]) { 254 std::copy(&other.values_[0], &other.values_[0] + num_elements(), 255 &values_[0]); 256 } 257 258 Array<T>& operator=(const Array<T>& other) { 259 sizes_ = other.sizes_; 260 values_.reset(new T[num_elements()]); 261 std::copy(&other.values_[0], &other.values_[0] + num_elements(), 262 &values_[0]); 263 return *this; 264 } 265 266 // Fills the array with the specified value. 267 void Fill(const T& value) { 268 std::fill(&values_[0], &values_[0] + num_elements(), value); 269 } 270 271 // Fills the array with sequentially increasing values. 272 void FillIota(const T& value) { 273 std::iota(&values_[0], &values_[0] + num_elements(), value); 274 } 275 276 // Fills the array with a repeating sequence: 277 // [value, value + 1, ..., value + length - 1, value, ... ] 278 void FillRepeatedIota(const T& value, int64_t length) { 279 for (int64_t i = 0; i < num_elements(); i += length) { 280 std::iota(&values_[i], &values_[std::min(i + length, num_elements())], 281 value); 282 } 283 } 284 285 // Fills the array with the sequence i*multiplier for i=0,1,... 286 void FillWithMultiples(const T& multiplier) { 287 for (int64_t i = 0; i < num_elements(); ++i) { 288 values_[i] = static_cast<T>(i) * multiplier; 289 } 290 } 291 292 // Fills the array with random normal variables with the specified mean. 293 void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) { 294 FillRandomDouble(static_cast<double>(stddev), mean, seed); 295 } 296 297 void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) { 298 std::mt19937 g(seed); 299 std::normal_distribution<double> distribution(mean, stddev); 300 for (int64_t i = 0; i < num_elements(); ++i) { 301 if (std::is_same<T, bool>()) { 302 values_[i] = static_cast<T>(distribution(g) > 0.0); 303 } else { 304 values_[i] = static_cast<T>(distribution(g)); 305 } 306 } 307 } 308 309 // Sets all the values in the array to values specified in the container. 310 template <typename Container = std::initializer_list<T>> 311 void SetValues(const Container& container) { 312 CHECK_EQ(std::distance(std::begin(container), std::end(container)), 313 num_elements()); 314 std::copy(std::begin(container), std::end(container), &values_[0]); 315 } 316 317 // Invokes a callback with the (indices, value_ptr) for each cell in the 318 // array. 319 void Each(std::function<void(absl::Span<const int64>, T*)> f) { 320 std::vector<int64> index(sizes_.size()); 321 for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { 322 f(index, &values_[i]); 323 } 324 } 325 326 // Invokes a callback with the (indices, value) for each cell in the array. 327 void Each(std::function<void(absl::Span<const int64>, T)> f) const { 328 std::vector<int64> index(sizes_.size()); 329 for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { 330 f(index, values_[i]); 331 } 332 } 333 334 // Invokes a callback with the (indices, value_ptr) for each cell in the 335 // array. If a callback returns a non-OK status, returns that else returns 336 // Status::OK(). 337 Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) { 338 std::vector<int64> index(sizes_.size()); 339 for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { 340 Status s = f(index, &values_[i]); 341 if (!s.ok()) { 342 return s; 343 } 344 } 345 return Status::OK(); 346 } 347 348 // Invokes a callback with the (indices, value) for each cell in the array. 349 // If a callback returns a non-OK status, returns that else returns 350 // Status::OK(). 351 Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const { 352 std::vector<int64> index(sizes_.size()); 353 for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { 354 Status s = f(index, values_[i]); 355 if (!s.ok()) { 356 return s; 357 } 358 } 359 return Status::OK(); 360 } 361 362 // Returns the value at the cell specified by the indexes. The number of 363 // arguments have to match with the number of dimensions for the array. 364 // 365 // The type trait is required to avoid this overload participating too 366 // eagerly; a parameter pack can take zero or more elements, so we must 367 // restrict this to only parameter packs that are all of integral type. 368 template <typename... Dims> 369 typename std::enable_if<array_impl::pack_is_integral<Dims...>::value, 370 const T&>::type 371 operator()(Dims... dims) const { 372 // We are using a std::array to avoid having to allocate memory in this 373 // function for performance reasons. 374 std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}}; 375 return values_[calculate_index(indexes)]; 376 } 377 378 // Returns the value at the cell specified by the indexes. The number of 379 // arguments have to match with the number of dimensions for the array. 380 template <typename... Dims> 381 typename std::enable_if<array_impl::pack_is_integral<Dims...>::value, 382 T&>::type 383 operator()(Dims... dims) { 384 // We are using a std::array to avoid having to allocate memory in this 385 // function for performance reasons. 386 std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}}; 387 return values_[calculate_index(indexes)]; 388 } 389 390 // Returns the value at the cell specified by the indexes. The number of 391 // arguments have to match with the number of dimensions for the array. 392 const T& operator()(absl::Span<const int64> indexes) const { 393 return values_[calculate_index(indexes)]; 394 } 395 396 // Returns the value at the cell specified by the indexes. The number of 397 // arguments have to match with the number of dimensions for the array. 398 T& operator()(absl::Span<const int64> indexes) { 399 return values_[calculate_index(indexes)]; 400 } 401 402 // Low-level accessor for stuff like memcmp, handle with care. Returns pointer 403 // to the underlying storage of the array (similarly to std::vector::data()). 404 T* data() const { 405 // TODO(tberghammer): Get rid of the const_cast. Currently it is needed 406 // because the Eigen backend needs a non-const pointers even for reading 407 // from the array. 408 return const_cast<Array*>(this)->values_.get(); 409 } 410 411 // Returns the size of the dimension at the given index. 412 int64 dim(int64_t n) const { 413 const int64_t sizes_size = sizes_.size(); 414 CHECK(n < sizes_size); 415 return sizes_[n]; 416 } 417 418 // Returns a vector containing the dimensions of the array. 419 const std::vector<int64>& dimensions() const { return sizes_; } 420 421 int64 num_dimensions() const { return sizes_.size(); } 422 423 // Returns the total number of elements in the array. 424 int64 num_elements() const { 425 return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, 426 std::multiplies<int64>()); 427 } 428 429 const T* begin() const { return &values_[0]; } 430 T* begin() { return &values_[0]; } 431 const T* end() const { return &values_[num_elements()]; } 432 T* end() { return &values_[num_elements()]; } 433 434 bool operator==(const Array<T>& other) const { 435 if (sizes_.size() != other.sizes_.size()) { 436 return false; 437 } 438 for (int64_t i = 0, end = sizes_.size(); i < end; ++i) { 439 if (sizes_[i] != other.sizes_[i]) { 440 return false; 441 } 442 } 443 for (int64_t i = 0; i < num_elements(); ++i) { 444 if (values_[i] != other.values_[i]) { 445 return false; 446 } 447 } 448 return true; 449 } 450 451 bool operator!=(const Array<T>& other) const { return !(*this == other); } 452 453 // Performs the equivalent of a slice operation on this array. 454 Array<T> Slice(absl::Span<const int64> starts, 455 absl::Span<const int64> limits) const { 456 CHECK_EQ(starts.size(), num_dimensions()); 457 CHECK_EQ(limits.size(), num_dimensions()); 458 459 std::vector<int64> sizes; 460 std::transform(starts.begin(), starts.end(), limits.begin(), 461 std::back_inserter(sizes), 462 [](int64_t start, int64_t limit) { return limit - start; }); 463 Array<T> result(sizes); 464 465 std::vector<int64> index(sizes_.size()); 466 int64_t slice_i = 0; 467 for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { 468 if (array_impl::all_inside_range(index, starts, limits)) { 469 // Even though the bounds of result are different to our bounds, we're 470 // iterating in the same order. So we can simply write successive linear 471 // indices instead of recalculating a multi-dimensional index. 472 result.values_[slice_i++] = values_[i]; 473 } 474 } 475 return result; 476 } 477 478 // Performs the equivalent of a DynamicUpdateSlice in-place on this array. 479 void UpdateSlice(const Array<T>& from, 480 absl::Span<const int64> start_indices) { 481 CHECK_EQ(from.num_dimensions(), num_dimensions()); 482 std::vector<int64> limit_indices; 483 std::transform(start_indices.begin(), start_indices.end(), 484 from.dimensions().begin(), std::back_inserter(limit_indices), 485 std::plus<int64>{}); 486 std::vector<int64> index(sizes_.size()); 487 int64_t from_i = 0; 488 for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { 489 if (array_impl::all_inside_range(index, start_indices, limit_indices)) { 490 // Even though the bounds of from are different to our bounds, we're 491 // iterating in the same order. So we can simply write successive linear 492 // indices instead of recalculating a multi-dimensional index. 493 values_[i] = from.values_[from_i++]; 494 } 495 } 496 } 497 498 // Performs an in-place reshape, modifying the dimensions but not the 499 // underlying data. 500 void Reshape(absl::Span<const int64> new_dimensions) { 501 int64_t old_num_elements = num_elements(); 502 sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end()); 503 CHECK_EQ(num_elements(), old_num_elements); 504 } 505 506 // Returns a string representation of the array suitable for debugging. 507 string ToString() const { 508 std::vector<string> pieces; 509 std::vector<int64> index(sizes_.size()); 510 do { 511 // Emit leading spaces and opening square brackets 512 if (index.back() == 0) { 513 for (int64_t i = sizes_.size() - 1; i >= 0; --i) { 514 if (i == 0 || index[i - 1] != 0) { 515 for (int64_t j = 0; j < sizes_.size(); ++j) { 516 pieces.push_back(j < i ? " " : "["); 517 } 518 break; 519 } 520 } 521 } 522 523 pieces.push_back(absl::StrCat(values_[calculate_index(index)])); 524 525 // Emit comma if it isn't the last element 526 if (index.back() != sizes_.back() - 1) { 527 pieces.push_back(", "); 528 } 529 530 // Emit closing square brackets 531 for (int64_t i = sizes_.size() - 1; i >= 0; --i) { 532 if (index[i] != sizes_[i] - 1) { 533 break; 534 } 535 pieces.push_back("]"); 536 if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) { 537 pieces.push_back(",\n"); 538 } 539 } 540 } while (next_index(&index)); 541 return absl::StrJoin(pieces, ""); 542 } 543 544 private: 545 // Converts an initializer_list of type U to a vector of type int64. Used by 546 // the initializer list based constructors to convert the size type into int64 547 // to be passed to the size based constructor. 548 template <typename U> 549 static std::vector<int64> ToInt64Vector( 550 const std::initializer_list<U>& data) { 551 return std::vector<int64>(data.begin(), data.end()); 552 } 553 554 // Returns the linear index from the list of per-dimension indexes. Function 555 // is templated so can be used with an std::array from operator() to avoid 556 // memory allocation. 557 template <typename U> 558 int64 calculate_index(const U& indexes) const { 559 CHECK_EQ(sizes_.size(), indexes.size()); 560 int64_t index = 0; 561 for (int64_t i = 0; i < sizes_.size(); ++i) { 562 index *= sizes_[i]; 563 index += indexes[i]; 564 } 565 DCHECK_LT(index, this->num_elements()); 566 return index; 567 } 568 569 // Advances the specified set of indexes and returns true if we haven't 570 // wrapped around (i.e. result isn't {0, 0, ...}). 571 bool next_index(std::vector<int64>* index) const { 572 CHECK_EQ(index->size(), sizes_.size()); 573 for (int64_t i = sizes_.size() - 1; i >= 0; --i) { 574 (*index)[i]++; 575 if ((*index)[i] < sizes_[i]) { 576 return true; 577 } 578 (*index)[i] = 0; 579 } 580 return false; 581 } 582 583 std::vector<int64> sizes_; 584 std::unique_ptr<T[]> values_; 585 }; 586 587 // Specialization of FillRandom() method for complex64 type. Uses real part of 588 // the stddev parameter as the standard deviation value. 589 template <> 590 void Array<complex64>::FillRandom(const complex64& stddev, const double mean, 591 const int seed); 592 593 } // namespace xla 594 595 #endif // TENSORFLOW_COMPILER_XLA_ARRAY_H_ 596