1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_ 17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_ 18 19 #include <algorithm> 20 #include <array> 21 #include <functional> 22 #include <initializer_list> 23 #include <iterator> 24 #include <memory> 25 #include <numeric> 26 #include <random> 27 #include <type_traits> 28 #include <vector> 29 30 #include "absl/strings/str_cat.h" 31 #include "absl/strings/str_join.h" 32 #include "absl/types/span.h" 33 #include "tensorflow/compiler/xla/status.h" 34 #include "tensorflow/compiler/xla/types.h" 35 #include "tensorflow/core/lib/core/bits.h" 36 #include "tensorflow/core/platform/logging.h" 37 #include "tensorflow/core/platform/macros.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace xla { 41 42 namespace array_impl { 43 44 // conjunction 45 // 46 // Performs a compile-time logical AND operation on the passed types (which 47 // must have `::value` members convertible to `bool`. Short-circuits if it 48 // encounters any `false` members (and does not compare the `::value` members 49 // of any remaining arguments). 50 // 51 // This metafunction is designed to be a drop-in replacement for the C++17 52 // `std::conjunction` metafunction. 53 template <typename... Ts> 54 struct conjunction; 55 56 template <typename T, typename... Ts> 57 struct conjunction<T, Ts...> 58 : std::conditional<T::value, conjunction<Ts...>, T>::type {}; 59 60 template <> 61 struct conjunction<> : std::true_type {}; 62 63 // A type trait that is valid when all elements in a parameter pack are of 64 // integral type. Not using an alias template to work around MSVC 14.00 bug. 65 template <typename... Ts> 66 struct pack_is_integral : conjunction<std::is_integral<Ts>...> {}; 67 68 // Compares three same-sized vectors elementwise. For each item in `values`, 69 // returns false if any of values[i] is outside the half-open range [starts[i], 70 // ends[i]). 71 template <typename C1, typename C2, typename C3> 72 bool all_inside_range(const C1& values, const C2& range_starts, 73 const C3& range_ends) { 74 for (size_t i = 0, e = values.size(); i < e; ++i) { 75 if (values[i] < range_starts[i] || values[i] >= range_ends[i]) { 76 return false; 77 } 78 } 79 return true; 80 } 81 82 } // namespace array_impl 83 84 // General N dimensional array class with arbitrary value type. 85 template <typename T> 86 class Array { 87 public: 88 // Type inference can have a hard time parsing very deep initializer list 89 // nests, especially if one or more dimensions is one as the compiler just 90 // sees a single-element integer initializer. These typedefs allow casting 91 // explicitly with less typing. 92 using InitializerList1D = std::initializer_list<T>; 93 using InitializerList2D = std::initializer_list<InitializerList1D>; 94 using InitializerList3D = std::initializer_list<InitializerList2D>; 95 using InitializerList4D = std::initializer_list<InitializerList3D>; 96 97 using value_type = T; 98 99 // Creates a new array with the specified dimensions. 100 explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {} 101 102 // Creates a new array with the specified dimensions and specified value for 103 // every cell. 104 Array(absl::Span<const int64> sizes, T value) 105 : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { 106 Fill(value); 107 } 108 109 // Creates a 2D array from the given nested initializer list. The outer 110 // initializer list is the first dimension, the inner is the second dimension. 111 // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. 112 Array(InitializerList2D values) 113 : Array(ToInt64Vector({values.size(), values.begin()->size()})) { 114 int64 idx = 0; 115 for (const auto& it1 : values) { 116 for (const auto& it2 : it1) { 117 values_[idx] = it2; 118 ++idx; 119 } 120 } 121 CHECK(idx == num_elements()); 122 } 123 124 // Creates a 1D array of a floating-point type (half, bfloat16, float, 125 // or double) from an initializer list of float values. 126 template <typename T2, typename = typename std::enable_if< 127 (std::is_same<T, Eigen::half>::value || 128 std::is_same<T, bfloat16>::value || 129 std::is_same<T, float>::value || 130 std::is_same<T, double>::value) && 131 std::is_same<T2, float>::value>::type> 132 Array(std::initializer_list<T2> values) 133 : Array(ToInt64Vector({values.size()})) { 134 int64 idx = 0; 135 for (const auto& it1 : values) { 136 values_[idx] = static_cast<T>(it1); 137 ++idx; 138 } 139 CHECK(idx == num_elements()); 140 } 141 142 // Creates a 2D array of a floating-point type (half, bfloat16, float, 143 // or double) from an initializer list of float values. 144 template <typename T2, typename = typename std::enable_if< 145 (std::is_same<T, Eigen::half>::value || 146 std::is_same<T, bfloat16>::value || 147 std::is_same<T, float>::value || 148 std::is_same<T, double>::value) && 149 std::is_same<T2, float>::value>::type> 150 Array(std::initializer_list<std::initializer_list<T2>> values) 151 : Array(ToInt64Vector({values.size(), values.begin()->size()})) { 152 int64 idx = 0; 153 for (const auto& it1 : values) { 154 for (const auto& it2 : it1) { 155 values_[idx] = static_cast<T>(it2); 156 ++idx; 157 } 158 } 159 CHECK(idx == num_elements()); 160 } 161 162 // Creates a 3D array from the given nested initializer list. The outer 163 // initializer list is the first dimension, and so on. 164 Array(InitializerList3D values) 165 : Array(ToInt64Vector({values.size(), values.begin()->size(), 166 values.begin()->begin()->size()})) { 167 int64 idx = 0; 168 for (const auto& it1 : values) { 169 for (const auto& it2 : it1) { 170 for (const auto& it3 : it2) { 171 values_[idx] = it3; 172 ++idx; 173 } 174 } 175 } 176 CHECK(idx == num_elements()); 177 } 178 179 // Creates a 3D array of a floating-point type (half, bfloat16, float, 180 // or double) from an initializer list of float values. 181 template <typename T2, typename = typename std::enable_if< 182 (std::is_same<T, Eigen::half>::value || 183 std::is_same<T, bfloat16>::value || 184 std::is_same<T, float>::value || 185 std::is_same<T, double>::value) && 186 std::is_same<T2, float>::value>::type> 187 Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>> 188 values) 189 : Array(ToInt64Vector({values.size(), values.begin()->size(), 190 values.begin()->begin()->size()})) { 191 int64 idx = 0; 192 for (const auto& it1 : values) { 193 for (const auto& it2 : it1) { 194 for (const auto& it3 : it2) { 195 values_[idx] = static_cast<T>(it3); 196 ++idx; 197 } 198 } 199 } 200 CHECK(idx == num_elements()); 201 } 202 203 // Creates a 4D array from the given nested initializer list. The outer 204 // initializer list is the first dimension, and so on. 205 Array(InitializerList4D values) 206 : Array(ToInt64Vector({values.size(), values.begin()->size(), 207 values.begin()->begin()->size(), 208 values.begin()->begin()->begin()->size()})) { 209 int64 idx = 0; 210 for (const auto& it1 : values) { 211 for (const auto& it2 : it1) { 212 for (const auto& it3 : it2) { 213 for (const auto& it4 : it3) { 214 values_[idx] = it4; 215 ++idx; 216 } 217 } 218 } 219 } 220 CHECK(idx == num_elements()); 221 } 222 223 // Creates a 4D array of a floating-point type (half, bfloat16, float, 224 // or double) from an initializer list of float values. 225 template <typename T2, typename = typename std::enable_if< 226 (std::is_same<T, Eigen::half>::value || 227 std::is_same<T, bfloat16>::value || 228 std::is_same<T, float>::value || 229 std::is_same<T, double>::value) && 230 std::is_same<T2, float>::value>::type> 231 Array(std::initializer_list< 232 std::initializer_list<std::initializer_list<std::initializer_list<T2>>>> 233 values) 234 : Array(ToInt64Vector({values.size(), values.begin()->size(), 235 values.begin()->begin()->size(), 236 values.begin()->begin()->begin()->size()})) { 237 int64 idx = 0; 238 for (const auto& it1 : values) { 239 for (const auto& it2 : it1) { 240 for (const auto& it3 : it2) { 241 for (const auto& it4 : it3) { 242 values_[idx] = static_cast<T>(it4); 243 ++idx; 244 } 245 } 246 } 247 } 248 CHECK(idx == num_elements()); 249 } 250 251 Array(const Array<T>& other) 252 : sizes_(other.sizes_), values_(new T[num_elements()]) { 253 std::copy(&other.values_[0], &other.values_[0] + num_elements(), 254 &values_[0]); 255 } 256 257 Array<T>& operator=(const Array<T>& other) { 258 sizes_ = other.sizes_; 259 values_.reset(new T[num_elements()]); 260 std::copy(&other.values_[0], &other.values_[0] + num_elements(), 261 &values_[0]); 262 return *this; 263 } 264 265 // Fills the array with the specified value. 266 void Fill(const T& value) { 267 std::fill(&values_[0], &values_[0] + num_elements(), value); 268 } 269 270 // Fills the array with sequentially increasing values. 271 void FillIota(const T& value) { 272 std::iota(&values_[0], &values_[0] + num_elements(), value); 273 } 274 275 // Fills the array with a repeating sequence: 276 // [value, value + 1, ..., value + length - 1, value, ... ] 277 void FillRepeatedIota(const T& value, int64 length) { 278 for (int64 i = 0; i < num_elements(); i += length) { 279 std::iota(&values_[i], &values_[std::min(i + length, num_elements())], 280 value); 281 } 282 } 283 284 // Fills the array with the sequence i*multiplier for i=0,1,... 285 void FillWithMultiples(const T& multiplier) { 286 for (int64 i = 0; i < num_elements(); ++i) { 287 values_[i] = static_cast<T>(i) * multiplier; 288 } 289 } 290 291 // Fills the array with random normal variables with the specified mean. 292 void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) { 293 FillRandomDouble(static_cast<double>(stddev), mean, seed); 294 } 295 296 void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) { 297 std::mt19937 g(seed); 298 std::normal_distribution<double> distribution(mean, stddev); 299 for (int64 i = 0; i < num_elements(); ++i) { 300 if (std::is_same<T, bool>()) { 301 values_[i] = static_cast<T>(distribution(g) > 0.0); 302 } else { 303 values_[i] = static_cast<T>(distribution(g)); 304 } 305 } 306 } 307 308 // Sets all the values in the array to values specified in the container. 309 template <typename Container = std::initializer_list<T>> 310 void SetValues(const Container& container) { 311 CHECK_EQ(std::distance(std::begin(container), std::end(container)), 312 num_elements()); 313 std::copy(std::begin(container), std::end(container), &values_[0]); 314 } 315 316 // Invokes a callback with the (indices, value_ptr) for each cell in the 317 // array. 318 void Each(std::function<void(absl::Span<const int64>, T*)> f) { 319 std::vector<int64> index(sizes_.size()); 320 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 321 f(index, &values_[i]); 322 } 323 } 324 325 // Invokes a callback with the (indices, value) for each cell in the array. 326 void Each(std::function<void(absl::Span<const int64>, T)> f) const { 327 std::vector<int64> index(sizes_.size()); 328 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 329 f(index, values_[i]); 330 } 331 } 332 333 // Invokes a callback with the (indices, value_ptr) for each cell in the 334 // array. If a callback returns a non-OK status, returns that else returns 335 // Status::OK(). 336 Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) { 337 std::vector<int64> index(sizes_.size()); 338 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 339 Status s = f(index, &values_[i]); 340 if (!s.ok()) { 341 return s; 342 } 343 } 344 return Status::OK(); 345 } 346 347 // Invokes a callback with the (indices, value) for each cell in the array. 348 // If a callback returns a non-OK status, returns that else returns 349 // Status::OK(). 350 Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const { 351 std::vector<int64> index(sizes_.size()); 352 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 353 Status s = f(index, values_[i]); 354 if (!s.ok()) { 355 return s; 356 } 357 } 358 return Status::OK(); 359 } 360 361 // Returns the value at the cell specified by the indexes. The number of 362 // arguments have to match with the number of dimensions for the array. 363 // 364 // The type trait is required to avoid this overload participating too 365 // eagerly; a parameter pack can take zero or more elements, so we must 366 // restrict this to only parameter packs that are all of integral type. 367 template <typename... Dims> 368 typename std::enable_if<array_impl::pack_is_integral<Dims...>::value, 369 const T&>::type 370 operator()(Dims... dims) const { 371 // We are using a std::array to avoid having to allocate memory in this 372 // function for performance reasons. 373 std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}}; 374 return values_[calculate_index(indexes)]; 375 } 376 377 // Returns the value at the cell specified by the indexes. The number of 378 // arguments have to match with the number of dimensions for the array. 379 template <typename... Dims> 380 typename std::enable_if<array_impl::pack_is_integral<Dims...>::value, 381 T&>::type 382 operator()(Dims... dims) { 383 // We are using a std::array to avoid having to allocate memory in this 384 // function for performance reasons. 385 std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}}; 386 return values_[calculate_index(indexes)]; 387 } 388 389 // Returns the value at the cell specified by the indexes. The number of 390 // arguments have to match with the number of dimensions for the array. 391 const T& operator()(absl::Span<const int64> indexes) const { 392 return values_[calculate_index(indexes)]; 393 } 394 395 // Returns the value at the cell specified by the indexes. The number of 396 // arguments have to match with the number of dimensions for the array. 397 T& operator()(absl::Span<const int64> indexes) { 398 return values_[calculate_index(indexes)]; 399 } 400 401 // Low-level accessor for stuff like memcmp, handle with care. Returns pointer 402 // to the underlying storage of the array (similarly to std::vector::data()). 403 T* data() const { 404 // TODO(tberghammer): Get rid of the const_cast. Currently it is needed 405 // because the Eigen backend needs a non-const pointers even for reading 406 // from the array. 407 return const_cast<Array*>(this)->values_.get(); 408 } 409 410 // Returns the size of the dimension at the given index. 411 int64 dim(int64 n) const { 412 const int64 sizes_size = sizes_.size(); 413 CHECK(n < sizes_size); 414 return sizes_[n]; 415 } 416 417 // Returns a vector containing the dimensions of the array. 418 const std::vector<int64>& dimensions() const { return sizes_; } 419 420 int64 num_dimensions() const { return sizes_.size(); } 421 422 // Returns the total number of elements in the array. 423 int64 num_elements() const { 424 return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, 425 std::multiplies<int64>()); 426 } 427 428 const T* begin() const { return &values_[0]; } 429 T* begin() { return &values_[0]; } 430 const T* end() const { return &values_[num_elements()]; } 431 T* end() { return &values_[num_elements()]; } 432 433 bool operator==(const Array<T>& other) const { 434 if (sizes_.size() != other.sizes_.size()) { 435 return false; 436 } 437 for (int64 i = 0, end = sizes_.size(); i < end; ++i) { 438 if (sizes_[i] != other.sizes_[i]) { 439 return false; 440 } 441 } 442 for (int64 i = 0; i < num_elements(); ++i) { 443 if (values_[i] != other.values_[i]) { 444 return false; 445 } 446 } 447 return true; 448 } 449 450 bool operator!=(const Array<T>& other) const { return !(*this == other); } 451 452 // Performs the equivalent of a slice operation on this array. 453 Array<T> Slice(absl::Span<const int64> starts, 454 absl::Span<const int64> limits) const { 455 CHECK_EQ(starts.size(), num_dimensions()); 456 CHECK_EQ(limits.size(), num_dimensions()); 457 458 std::vector<int64> sizes; 459 std::transform(starts.begin(), starts.end(), limits.begin(), 460 std::back_inserter(sizes), 461 [](int64 start, int64 limit) { return limit - start; }); 462 Array<T> result(sizes); 463 464 std::vector<int64> index(sizes_.size()); 465 int64 slice_i = 0; 466 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 467 if (array_impl::all_inside_range(index, starts, limits)) { 468 // Even though the bounds of result are different to our bounds, we're 469 // iterating in the same order. So we can simply write successive linear 470 // indices instead of recalculating a multi-dimensional index. 471 result.values_[slice_i++] = values_[i]; 472 } 473 } 474 return result; 475 } 476 477 // Performs the equivalent of a DynamicUpdateSlice in-place on this array. 478 void UpdateSlice(const Array<T>& from, 479 absl::Span<const int64> start_indices) { 480 CHECK_EQ(from.num_dimensions(), num_dimensions()); 481 std::vector<int64> limit_indices; 482 std::transform(start_indices.begin(), start_indices.end(), 483 from.dimensions().begin(), std::back_inserter(limit_indices), 484 std::plus<int64>{}); 485 std::vector<int64> index(sizes_.size()); 486 int64 from_i = 0; 487 for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { 488 if (array_impl::all_inside_range(index, start_indices, limit_indices)) { 489 // Even though the bounds of from are different to our bounds, we're 490 // iterating in the same order. So we can simply write successive linear 491 // indices instead of recalculating a multi-dimensional index. 492 values_[i] = from.values_[from_i++]; 493 } 494 } 495 } 496 497 // Performs an in-place reshape, modifying the dimensions but not the 498 // underlying data. 499 void Reshape(absl::Span<const int64> new_dimensions) { 500 int64 old_num_elements = num_elements(); 501 sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end()); 502 CHECK_EQ(num_elements(), old_num_elements); 503 } 504 505 // Returns a string representation of the array suitable for debugging. 506 string ToString() const { 507 std::vector<string> pieces; 508 std::vector<int64> index(sizes_.size()); 509 do { 510 // Emit leading spaces and opening square brackets 511 if (index.back() == 0) { 512 for (int64 i = sizes_.size() - 1; i >= 0; --i) { 513 if (i == 0 || index[i - 1] != 0) { 514 for (int64 j = 0; j < sizes_.size(); ++j) { 515 pieces.push_back(j < i ? " " : "["); 516 } 517 break; 518 } 519 } 520 } 521 522 pieces.push_back(absl::StrCat(values_[calculate_index(index)])); 523 524 // Emit comma if it isn't the last element 525 if (index.back() != sizes_.back() - 1) { 526 pieces.push_back(", "); 527 } 528 529 // Emit closing square brackets 530 for (int64 i = sizes_.size() - 1; i >= 0; --i) { 531 if (index[i] != sizes_[i] - 1) { 532 break; 533 } 534 pieces.push_back("]"); 535 if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) { 536 pieces.push_back(",\n"); 537 } 538 } 539 } while (next_index(&index)); 540 return absl::StrJoin(pieces, ""); 541 } 542 543 private: 544 // Converts an initializer_list of type U to a vector of type int64. Used by 545 // the initializer list based constructors to convert the size type into int64 546 // to be passed to the size based constructor. 547 template <typename U> 548 static std::vector<int64> ToInt64Vector( 549 const std::initializer_list<U>& data) { 550 return std::vector<int64>(data.begin(), data.end()); 551 } 552 553 // Returns the linear index from the list of per-dimension indexes. Function 554 // is templated so can be used with an std::array from operator() to avoid 555 // memory allocation. 556 template <typename U> 557 int64 calculate_index(const U& indexes) const { 558 CHECK_EQ(sizes_.size(), indexes.size()); 559 int64 index = 0; 560 for (int64 i = 0; i < sizes_.size(); ++i) { 561 index *= sizes_[i]; 562 index += indexes[i]; 563 } 564 DCHECK_LT(index, this->num_elements()); 565 return index; 566 } 567 568 // Advances the specified set of indexes and returns true if we haven't 569 // wrapped around (i.e. result isn't {0, 0, ...}). 570 bool next_index(std::vector<int64>* index) const { 571 CHECK_EQ(index->size(), sizes_.size()); 572 for (int64 i = sizes_.size() - 1; i >= 0; --i) { 573 (*index)[i]++; 574 if ((*index)[i] < sizes_[i]) { 575 return true; 576 } 577 (*index)[i] = 0; 578 } 579 return false; 580 } 581 582 std::vector<int64> sizes_; 583 std::unique_ptr<T[]> values_; 584 }; 585 586 // Specialization of FillRandom() method for complex64 type. Uses real part of 587 // the stddev parameter as the standard deviation value. 588 template <> 589 void Array<complex64>::FillRandom(const complex64& stddev, const double mean, 590 const int seed); 591 592 } // namespace xla 593 594 #endif // TENSORFLOW_COMPILER_XLA_ARRAY_H_ 595