1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_ 17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_ 18 19 #include <algorithm> 20 #include <array> 21 #include <functional> 22 #include <initializer_list> 23 #include <iterator> 24 #include <memory> 25 #include <numeric> 26 #include <random> 27 #include <type_traits> 28 #include <vector> 29 30 #include "absl/strings/str_cat.h" 31 #include "absl/strings/str_join.h" 32 #include "absl/types/span.h" 33 #include "tensorflow/compiler/xla/status.h" 34 #include "tensorflow/compiler/xla/types.h" 35 #include "tensorflow/core/lib/core/bits.h" 36 #include "tensorflow/core/platform/logging.h" 37 #include "tensorflow/core/platform/macros.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace xla { 41 42 namespace array_impl { 43 44 // conjunction 45 // 46 // Performs a compile-time logical AND operation on the passed types (which 47 // must have `::value` members convertible to `bool`. Short-circuits if it 48 // encounters any `false` members (and does not compare the `::value` members 49 // of any remaining arguments). 50 // 51 // This metafunction is designed to be a drop-in replacement for the C++17 52 // `std::conjunction` metafunction. 53 template <typename... Ts> 54 struct conjunction; 55 56 template <typename T, typename... Ts> 57 struct conjunction<T, Ts...> 58 : std::conditional<T::value, conjunction<Ts...>, T>::type {}; 59 60 template <> 61 struct conjunction<> : std::true_type {}; 62 63 // A type trait that is valid when all elements in a parameter pack are of 64 // integral type. 65 template <typename... T> 66 using pack_is_integral = conjunction<std::is_integral<T>...>; 67 68 // Compares three same-sized vectors elementwise. For each item in `values`, 69 // returns false if any of values[i] is outside the half-open range [starts[i], 70 // ends[i]). 71 template <typename C1, typename C2, typename C3> 72 bool all_inside_range(const C1& values, const C2& range_starts, 73 const C3& range_ends) { 74 for (size_t i = 0, e = values.size(); i < e; ++i) { 75 if (values[i] < range_starts[i] || values[i] >= range_ends[i]) { 76 return false; 77 } 78 } 79 return true; 80 } 81 82 } // namespace array_impl 83 84 // General N dimensional array class with arbitrary value type. 85 template <typename T> 86 class Array { 87 public: 88 // Type inference can have a hard time parsing very deep initializer list 89 // nests, especially if one or more dimensions is one as the compiler just 90 // sees a single-element integer initializer. These typedefs allow casting 91 // explicitly with less typing. 92 using InitializerList1D = std::initializer_list<T>; 93 using InitializerList2D = std::initializer_list<InitializerList1D>; 94 using InitializerList3D = std::initializer_list<InitializerList2D>; 95 using InitializerList4D = std::initializer_list<InitializerList3D>; 96 97 using value_type = T; 98 99 // Creates a new array with the specified dimensions. 100 explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {} 101 102 // Creates a new array with the specified dimensions and specified value for 103 // every cell. 104 Array(absl::Span<const int64> sizes, T value) 105 : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { 106 Fill(value); 107 } 108 109 // Creates a 2D array from the given nested initializer list. The outer 110 // initializer list is the first dimension, the inner is the second dimension. 111 // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. 112 Array(InitializerList2D values) 113 : Array(ToInt64Vector({values.size(), values.begin()->size()})) { 114 int64 idx = 0; 115 for (const auto& it1 : values) { 116 for (const auto& it2 : it1) { 117 values_[idx] = it2; 118 ++idx; 119 } 120 } 121 CHECK(idx == num_elements()); 122 } 123 124 // Creates a 1D array of a floating-point type (half, bfloat16, float, 125 // or double) from an initializer list of float values. 126 template <typename T2, typename = typename std::enable_if< 127 (std::is_same<T, Eigen::half>::value || 128 std::is_same<T, bfloat16>::value || 129 std::is_same<T, float>::value || 130 std::is_same<T, double>::value) && 131 std::is_same<T2, float>::value>::type> 132 Array(std::initializer_list<T2> values) 133 : Array(ToInt64Vector({values.size()})) { 134 int64 idx = 0; 135 for (const auto& it1 : values) { 136 values_[idx] = static_cast<T>(it1); 137 ++idx; 138 } 139 CHECK(idx == num_elements()); 140 } 141 142 // Creates a 2D array of a floating-point type (half, bfloat16, float, 143 // or double) from an initializer list of float values. 144 template <typename T2, typename = typename std::enable_if< 145 (std::is_same<T, Eigen::half>::value || 146 std::is_same<T, bfloat16>::value || 147 std::is_same<T, float>::value || 148 std::is_same<T, double>::value) && 149 std::is_same<T2, float>::value>::type> 150 Array(std::initializer_list<std::initializer_list<T2>> values) 151 : Array(ToInt64Vector({values.size(), values.begin()->size()})) { 152 int64 idx = 0; 153 for (const auto& it1 : values) { 154 for (const auto& it2 : it1) { 155 values_[idx] = static_cast<T>(it2); 156 ++idx; 157 } 158 } 159 CHECK(idx == num_elements()); 160 } 161 162 // Creates a 3D array from the given nested initializer list. The outer 163 // initializer list is the first dimension, and so on. 164 Array(InitializerList3D values) 165 : Array(ToInt64Vector({values.size(), values.begin()->size(), 166 values.begin()->begin()->size()})) { 167 int64 idx = 0; 168 for (const auto& it1 : values) { 169 for (const auto& it2 : it1) { 170 for (const auto& it3 : it2) { 171 values_[idx] = it3; 172 ++idx; 173 } 174 } 175 } 176 CHECK(idx == num_elements()); 177 } 178 179 // Creates a 3D array of a floating-point type (half, bfloat16, float, 180 // or double) from an initializer list of float values. 181 template <typename T2, typename = typename std::enable_if< 182 (std::is_same<T, Eigen::half>::value || 183 std::is_same<T, bfloat16>::value || 184 std::is_same<T, float>::value || 185 std::is_same<T, double>::value) && 186 std::is_same<T2, float>::value>::type> 187 Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>> 188 values) 189 : Array(ToInt64Vector({values.size(), values.begin()->size(), 190 values.begin()->begin()->size()})) { 191 int64 idx = 0; 192 for (const auto& it1 : values) { 193 for (const auto& it2 : it1) { 194 for (const auto& it3 : it2) { 195 values_[idx] = static_cast<T>(it3); 196 ++idx; 197 } 198 } 199 } 200 CHECK(idx == num_elements()); 201 } 202 203 // Creates a 4D array from the given nested initializer list. The outer 204 // initializer list is the first dimension, and so on. 205 Array(InitializerList4D values) 206 : Array(ToInt64Vector({values.size(), values.begin()->size(), 207 values.begin()->begin()->size(), 208 values.begin()->begin()->begin()->size()})) { 209 int64 idx = 0; 210 for (const auto& it1 : values) { 211 for (const auto& it2 : it1) { 212 for (const auto& it3 : it2) { 213 for (const auto& it4 : it3) { 214 values_[idx] = it4; 215 ++idx; 216 } 217 } 218 } 219 } 220 CHECK(idx == num_elements()); 221 } 222 223 // Creates a 4D array of a floating-point type (half, bfloat16, float, 224 // or double) from an initializer list of float values. 225 template <typename T2, typename = typename std::enable_if< 226 (std::is_same<T, Eigen::half>::value || 227 std::is_same<T, bfloat16>::value || 228 std::is_same<T, float>::value || 229 std::is_same<T, double>::value) && 230 std::is_same<T2, float>::value>::type> 231 Array(std::initializer_list< 232 std::initializer_list<std::initializer_list<std::initializer_list<T2>>>> 233 values) 234 : Array(ToInt64Vector({values.size(), values.begin()->size(), 235 values.begin()->begin()->size(), 236 values.begin()->begin()->begin()->size()})) { 237 int64 idx = 0; 238 for (const auto& it1 : values) { 239 for (const auto& it2 : it1) { 240 for (const auto& it3 : it2) { 241 for (const auto& it4 : it3) { 242 values_[idx] = static_cast<T>(it4); 243 ++idx; 244 } 245 } 246 } 247 } 248 CHECK(idx == num_elements()); 249 } 250 251 Array(const Array<T>& other) 252 : sizes_(other.sizes_), values_(new T[num_elements()]) { 253 std::copy(&other.values_[0], &other.values_[0] + num_elements(), 254 &values_[0]); 255 } 256 257 Array<T>& operator=(const Array<T>& other) { 258 sizes_ = other.sizes_; 259 values_.reset(new T[num_elements()]); 260 std::copy(&other.values_[0], &other.values_[0] + num_elements(), 261 &values_[0]); 262 return *this; 263 } 264 265 // Fills the array with the specified value. 266 void Fill(const T& value) { 267 std::fill(&values_[0], &values_[0] + num_elements(), value); 268 } 269 270 // Fills the array with sequentially increasing values. 271 void FillIota(const T& value) { 272 std::iota(&values_[0], &values_[0] + num_elements(), value); 273 } 274 275 // Fills the array with a repeating sequence: 276 // [value, value + 1, ..., value + length - 1, value, ... ] 277 void FillRepeatedIota(const T& value, int64 length) { 278 for (int64 i = 0; i < num_elements(); i += length) { 279 std::iota(&values_[i], &values_[std::min(i + length, num_elements())], 280 value); 281 } 282 } 283 284 // Fills the array with the sequence i*multiplier for i=0,1,... 285 void FillWithMultiples(const T& multiplier) { 286 for (int64 i = 0; i < num_elements(); ++i) { 287 values_[i] = static_cast<T>(i) * multiplier; 288 } 289 } 290 291 // Fills the array with random normal variables with the specified mean. 292 void FillRandom(const T& stddev, const double mean = 0.0, 293 const int seed = 12345) { 294 std::mt19937 g(seed); 295 std::normal_distribution<double> distribution(mean, 296 static_cast<double>(stddev)); 297 for (int64 i = 0; i < num_elements(); ++i) { 298 values_[i] = static_cast<T>(distribution(g)); 299 } 300 } 301 302 // Sets all the values in the array to values specified in the container. 303 template <typename Container = std::initializer_list<T>> 304 void SetValues(const Container& container) { 305 CHECK_EQ(std::distance(std::begin(container), std::end(container)), 306 num_elements()); 307 std::copy(std::begin(container), std::end(container), &values_[0]); 308 } 309 310 // Invokes a callback with the (indices, value_ptr) for each cell in the 311 // array. 312 void Each(std::function<void(absl::Span<const int64>, T*)> f) { 313 std::vector<int64> index(sizes_.size()); 314 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 315 f(index, &values_[i]); 316 } 317 } 318 319 // Invokes a callback with the (indices, value) for each cell in the array. 320 void Each(std::function<void(absl::Span<const int64>, T)> f) const { 321 std::vector<int64> index(sizes_.size()); 322 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 323 f(index, values_[i]); 324 } 325 } 326 327 // Invokes a callback with the (indices, value_ptr) for each cell in the 328 // array. If a callback returns a non-OK status, returns that else returns 329 // Status::OK(). 330 Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) { 331 std::vector<int64> index(sizes_.size()); 332 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 333 Status s = f(index, &values_[i]); 334 if (!s.ok()) { 335 return s; 336 } 337 } 338 return Status::OK(); 339 } 340 341 // Invokes a callback with the (indices, value) for each cell in the array. 342 // If a callback returns a non-OK status, returns that else returns 343 // Status::OK(). 344 Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const { 345 std::vector<int64> index(sizes_.size()); 346 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 347 Status s = f(index, values_[i]); 348 if (!s.ok()) { 349 return s; 350 } 351 } 352 return Status::OK(); 353 } 354 355 // Returns the value at the cell specified by the indexes. The number of 356 // arguments have to match with the number of dimensions for the array. 357 // 358 // The type trait is required to avoid this overload participating too 359 // eagerly; a parameter pack can take zero or more elements, so we must 360 // restrict this to only parameter packs that are all of integral type. 361 template <typename... Dims> 362 typename std::enable_if<array_impl::pack_is_integral<Dims...>::value, 363 const T&>::type 364 operator()(Dims... dims) const { 365 // We are using a std::array to avoid having to allocate memory in this 366 // function for performance reasons. 367 std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}}; 368 return values_[calculate_index(indexes)]; 369 } 370 371 // Returns the value at the cell specified by the indexes. The number of 372 // arguments have to match with the number of dimensions for the array. 373 template <typename... Dims> 374 typename std::enable_if<array_impl::pack_is_integral<Dims...>::value, 375 T&>::type 376 operator()(Dims... dims) { 377 // We are using a std::array to avoid having to allocate memory in this 378 // function for performance reasons. 379 std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}}; 380 return values_[calculate_index(indexes)]; 381 } 382 383 // Returns the value at the cell specified by the indexes. The number of 384 // arguments have to match with the number of dimensions for the array. 385 const T& operator()(absl::Span<const int64> indexes) const { 386 return values_[calculate_index(indexes)]; 387 } 388 389 // Returns the value at the cell specified by the indexes. The number of 390 // arguments have to match with the number of dimensions for the array. 391 T& operator()(absl::Span<const int64> indexes) { 392 return values_[calculate_index(indexes)]; 393 } 394 395 // Low-level accessor for stuff like memcmp, handle with care. Returns pointer 396 // to the underlying storage of the array (similarly to std::vector::data()). 397 T* data() const { 398 // TODO(tberghammer): Get rid of the const_cast. Currently it is needed 399 // because the Eigen backend needs a non-const pointers even for reading 400 // from the array. 401 return const_cast<Array*>(this)->values_.get(); 402 } 403 404 // Returns the size of the dimension at the given index. 405 int64 dim(int64 n) const { 406 CHECK(n < sizes_.size()); 407 return sizes_[n]; 408 } 409 410 // Returns a vector containing the dimensions of the array. 411 const std::vector<int64>& dimensions() const { return sizes_; } 412 413 int64 num_dimensions() const { return sizes_.size(); } 414 415 // Returns the total number of elements in the array. 416 int64 num_elements() const { 417 return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, 418 std::multiplies<int64>()); 419 } 420 421 const T* begin() const { return &values_[0]; } 422 T* begin() { return &values_[0]; } 423 const T* end() const { return &values_[num_elements()]; } 424 T* end() { return &values_[num_elements()]; } 425 426 bool operator==(const Array<T>& other) const { 427 if (sizes_.size() != other.sizes_.size()) { 428 return false; 429 } 430 for (int64 i = 0; i < sizes_.size(); ++i) { 431 if (sizes_[i] != other.sizes_[i]) { 432 return false; 433 } 434 } 435 for (int64 i = 0; i < num_elements(); ++i) { 436 if (values_[i] != other.values_[i]) { 437 return false; 438 } 439 } 440 return true; 441 } 442 443 bool operator!=(const Array<T>& other) const { return !(*this == other); } 444 445 // Performs the equivalent of a slice operation on this array. 446 Array<T> Slice(absl::Span<const int64> starts, 447 absl::Span<const int64> limits) const { 448 CHECK_EQ(starts.size(), num_dimensions()); 449 CHECK_EQ(limits.size(), num_dimensions()); 450 451 std::vector<int64> sizes; 452 std::transform(starts.begin(), starts.end(), limits.begin(), 453 std::back_inserter(sizes), 454 [](int64 start, int64 limit) { return limit - start; }); 455 Array<T> result(sizes); 456 457 std::vector<int64> index(sizes_.size()); 458 int64 slice_i = 0; 459 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 460 if (array_impl::all_inside_range(index, starts, limits)) { 461 // Even though the bounds of result are different to our bounds, we're 462 // iterating in the same order. So we can simply write successive linear 463 // indices instead of recalculating a multi-dimensional index. 464 result.values_[slice_i++] = values_[i]; 465 } 466 } 467 return result; 468 } 469 470 // Performs the equivalent of a DynamicUpdateSlice in-place on this array. 471 void UpdateSlice(const Array<T>& from, 472 absl::Span<const int64> start_indices) { 473 CHECK_EQ(from.num_dimensions(), num_dimensions()); 474 std::vector<int64> limit_indices; 475 std::transform(start_indices.begin(), start_indices.end(), 476 from.dimensions().begin(), std::back_inserter(limit_indices), 477 std::plus<int64>{}); 478 std::vector<int64> index(sizes_.size()); 479 int64 from_i = 0; 480 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 481 if (array_impl::all_inside_range(index, start_indices, limit_indices)) { 482 // Even though the bounds of from are different to our bounds, we're 483 // iterating in the same order. So we can simply write successive linear 484 // indices instead of recalculating a multi-dimensional index. 485 values_[i] = from.values_[from_i++]; 486 } 487 } 488 } 489 490 // Performs an in-place reshape, modifying the dimensions but not the 491 // underlying data. 492 void Reshape(absl::Span<const int64> new_dimensions) { 493 int64 old_num_elements = num_elements(); 494 sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end()); 495 CHECK_EQ(num_elements(), old_num_elements); 496 } 497 498 // Returns a string representation of the array suitable for debugging. 499 string ToString() const { 500 std::vector<string> pieces; 501 std::vector<int64> index(sizes_.size()); 502 do { 503 // Emit leading spaces and opening square brackets 504 if (index.back() == 0) { 505 for (int64 i = sizes_.size() - 1; i >= 0; --i) { 506 if (i == 0 || index[i - 1] != 0) { 507 for (int64 j = 0; j < sizes_.size(); ++j) { 508 pieces.push_back(j < i ? " " : "["); 509 } 510 break; 511 } 512 } 513 } 514 515 pieces.push_back(absl::StrCat(values_[calculate_index(index)])); 516 517 // Emit comma if it isn't the last element 518 if (index.back() != sizes_.back() - 1) { 519 pieces.push_back(", "); 520 } 521 522 // Emit closing square brackets 523 for (int64 i = sizes_.size() - 1; i >= 0; --i) { 524 if (index[i] != sizes_[i] - 1) { 525 break; 526 } 527 pieces.push_back("]"); 528 if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) { 529 pieces.push_back(",\n"); 530 } 531 } 532 } while (next_index(&index)); 533 return absl::StrJoin(pieces, ""); 534 } 535 536 private: 537 // Converts an initializer_list of type U to a vector of type int64. Used by 538 // the initializer list based constructors to convert the size type into int64 539 // to be passed to the size based constructor. 540 template <typename U> 541 static std::vector<int64> ToInt64Vector( 542 const std::initializer_list<U>& data) { 543 return std::vector<int64>(data.begin(), data.end()); 544 } 545 546 // Returns the linear index from the list of per-dimension indexes. Function 547 // is templated so can be used with an std::array from operator() to avoid 548 // memory allocation. 549 template <typename U> 550 int64 calculate_index(const U& indexes) const { 551 CHECK_EQ(sizes_.size(), indexes.size()); 552 int64 index = 0; 553 for (int64 i = 0; i < sizes_.size(); ++i) { 554 index *= sizes_[i]; 555 index += indexes[i]; 556 } 557 return index; 558 } 559 560 // Advances the specified set of indexes and returns true if we haven't 561 // wrapped around (i.e. result isn't {0, 0, ...}). 562 bool next_index(std::vector<int64>* index) const { 563 CHECK_EQ(index->size(), sizes_.size()); 564 for (int64 i = sizes_.size() - 1; i >= 0; --i) { 565 (*index)[i]++; 566 if ((*index)[i] < sizes_[i]) { 567 return true; 568 } 569 (*index)[i] = 0; 570 } 571 return false; 572 } 573 574 std::vector<int64> sizes_; 575 std::unique_ptr<T[]> values_; 576 }; 577 578 } // namespace xla 579 580 #endif // TENSORFLOW_COMPILER_XLA_ARRAY_H_ 581