• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
7     http://www.apache.org/licenses/LICENSE-2.0
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
19 #include <algorithm>
20 #include <array>
21 #include <functional>
22 #include <initializer_list>
23 #include <iterator>
24 #include <memory>
25 #include <numeric>
26 #include <random>
27 #include <type_traits>
28 #include <vector>
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/platform/logging.h"
38 namespace xla {
namespace array_impl {

// conjunction
//
// Performs a compile-time logical AND operation on the passed types (which
// must have `::value` members convertible to `bool`). Short-circuits if it
// encounters any `false` members (and does not compare the `::value` members
// of any remaining arguments).
//
// This metafunction is designed to be a drop-in replacement for the C++17
// `std::conjunction` metafunction: the result derives from the first argument
// whose `::value` converts to `false`, or from the last argument if none does,
// or from `std::true_type` when the pack is empty.
template <typename... Ts>
struct conjunction;

// Recursive case. `::value` is explicitly converted to bool so that traits
// whose value type is not `bool` (e.g. integral_constant<int, 2>) do not form
// an ill-formed narrowing conversion in the template argument.
template <typename T, typename... Ts>
struct conjunction<T, Ts...>
    : std::conditional<static_cast<bool>(T::value), conjunction<Ts...>,
                       T>::type {};

// Single-argument base case: derive from the argument itself, matching
// std::conjunction<B>, so the final trait's `::value` type is preserved.
template <typename T>
struct conjunction<T> : T {};

// Empty pack: the conjunction of nothing is true.
template <>
struct conjunction<> : std::true_type {};

// A type trait that is valid when all elements in a parameter pack are of
// integral type. Not using an alias template to work around MSVC 14.00 bug.
template <typename... Ts>
struct pack_is_integral : conjunction<std::is_integral<Ts>...> {};

// Compares three containers elementwise. Returns true iff every values[i] lies
// in the half-open range [range_starts[i], range_ends[i]). Iterates over
// values.size() entries, so `range_starts` and `range_ends` must be at least
// that long.
template <typename C1, typename C2, typename C3>
bool all_inside_range(const C1& values, const C2& range_starts,
                      const C3& range_ends) {
  for (size_t i = 0, e = values.size(); i < e; ++i) {
    if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
      return false;
    }
  }
  return true;
}

}  // namespace array_impl
82 // General N dimensional array class with arbitrary value type.
83 template <typename T>
84 class Array {
85  public:
86   // Type inference can have a hard time parsing very deep initializer list
87   // nests, especially if one or more dimensions is one as the compiler just
88   // sees a single-element integer initializer. These typedefs allow casting
89   // explicitly with less typing.
90   using InitializerList1D = std::initializer_list<T>;
91   using InitializerList2D = std::initializer_list<InitializerList1D>;
92   using InitializerList3D = std::initializer_list<InitializerList2D>;
93   using InitializerList4D = std::initializer_list<InitializerList3D>;
95   using value_type = T;
97   // Creates a new array with the specified dimensions and initialized elements.
98   explicit Array(absl::Span<const int64_t> sizes)
99       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]()) {}
101   // Creates a new array with the specified dimensions and specified value for
102   // every cell.
103   Array(absl::Span<const int64_t> sizes, T value)
104       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
105     Fill(value);
106   }
108   // Creates a 2D array from the given nested initializer list. The outer
109   // initializer list is the first dimension, the inner is the second dimension.
110   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
111   Array(InitializerList2D values)
112       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
113     int64_t idx = 0;
114     for (const auto& it1 : values) {
115       for (const auto& it2 : it1) {
116         values_[idx] = it2;
117         ++idx;
118       }
119     }
120     CHECK(idx == num_elements());
121   }
123   // Creates a 1D array of a floating-point type (half, bfloat16, float,
124   // or double) from an initializer list of float values.
125   template <typename T2, typename = typename std::enable_if<
126                              (std::is_same<T, Eigen::half>::value ||
127                               std::is_same<T, bfloat16>::value ||
128                               std::is_same<T, float>::value ||
129                               std::is_same<T, double>::value) &&
130                              std::is_same<T2, float>::value>::type>
131   Array(std::initializer_list<T2> values)
132       : Array(ToInt64Vector({values.size()})) {
133     int64_t idx = 0;
134     for (const auto& it1 : values) {
135       values_[idx] = static_cast<T>(it1);
136       ++idx;
137     }
138     CHECK(idx == num_elements());
139   }
141   // Creates a 2D array of a floating-point type (half, bfloat16, float,
142   // or double) from an initializer list of float values.
143   template <typename T2, typename = typename std::enable_if<
144                              (std::is_same<T, Eigen::half>::value ||
145                               std::is_same<T, bfloat16>::value ||
146                               std::is_same<T, float>::value ||
147                               std::is_same<T, double>::value) &&
148                              std::is_same<T2, float>::value>::type>
149   Array(std::initializer_list<std::initializer_list<T2>> values)
150       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
151     int64_t idx = 0;
152     for (const auto& it1 : values) {
153       for (const auto& it2 : it1) {
154         values_[idx] = static_cast<T>(it2);
155         ++idx;
156       }
157     }
158     CHECK(idx == num_elements());
159   }
161   // Creates a 3D array from the given nested initializer list. The outer
162   // initializer list is the first dimension, and so on.
163   Array(InitializerList3D values)
164       : Array(ToInt64Vector({values.size(), values.begin()->size(),
165                              values.begin()->begin()->size()})) {
166     int64_t idx = 0;
167     for (const auto& it1 : values) {
168       for (const auto& it2 : it1) {
169         for (const auto& it3 : it2) {
170           values_[idx] = it3;
171           ++idx;
172         }
173       }
174     }
175     CHECK(idx == num_elements());
176   }
178   // Creates a 3D array of a floating-point type (half, bfloat16, float,
179   // or double) from an initializer list of float values.
180   template <typename T2, typename = typename std::enable_if<
181                              (std::is_same<T, Eigen::half>::value ||
182                               std::is_same<T, bfloat16>::value ||
183                               std::is_same<T, float>::value ||
184                               std::is_same<T, double>::value) &&
185                              std::is_same<T2, float>::value>::type>
186   Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
187             values)
188       : Array(ToInt64Vector({values.size(), values.begin()->size(),
189                              values.begin()->begin()->size()})) {
190     int64_t idx = 0;
191     for (const auto& it1 : values) {
192       for (const auto& it2 : it1) {
193         for (const auto& it3 : it2) {
194           values_[idx] = static_cast<T>(it3);
195           ++idx;
196         }
197       }
198     }
199     CHECK(idx == num_elements());
200   }
202   // Creates a 4D array from the given nested initializer list. The outer
203   // initializer list is the first dimension, and so on.
204   Array(InitializerList4D values)
205       : Array(ToInt64Vector({values.size(), values.begin()->size(),
206                              values.begin()->begin()->size(),
207                              values.begin()->begin()->begin()->size()})) {
208     int64_t idx = 0;
209     for (const auto& it1 : values) {
210       for (const auto& it2 : it1) {
211         for (const auto& it3 : it2) {
212           for (const auto& it4 : it3) {
213             values_[idx] = it4;
214             ++idx;
215           }
216         }
217       }
218     }
219     CHECK(idx == num_elements());
220   }
222   // Creates a 4D array of a floating-point type (half, bfloat16, float,
223   // or double) from an initializer list of float values.
224   template <typename T2, typename = typename std::enable_if<
225                              (std::is_same<T, Eigen::half>::value ||
226                               std::is_same<T, bfloat16>::value ||
227                               std::is_same<T, float>::value ||
228                               std::is_same<T, double>::value) &&
229                              std::is_same<T2, float>::value>::type>
230   Array(std::initializer_list<
231         std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
232             values)
233       : Array(ToInt64Vector({values.size(), values.begin()->size(),
234                              values.begin()->begin()->size(),
235                              values.begin()->begin()->begin()->size()})) {
236     int64_t idx = 0;
237     for (const auto& it1 : values) {
238       for (const auto& it2 : it1) {
239         for (const auto& it3 : it2) {
240           for (const auto& it4 : it3) {
241             values_[idx] = static_cast<T>(it4);
242             ++idx;
243           }
244         }
245       }
246     }
247     CHECK(idx == num_elements());
248   }
250   Array(const Array<T>& other)
251       : sizes_(other.sizes_), values_(new T[num_elements()]) {
252     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
253               &values_[0]);
254   }
256   Array(Array<T>&& other)
257       : sizes_(std::move(other.sizes_)), values_(std::move(other.values_)) {}
259   Array<T>& operator=(const Array<T>& other) {
260     sizes_ = other.sizes_;
261     values_.reset(new T[num_elements()]);
262     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
263               &values_[0]);
264     return *this;
265   }
267   Array<T>& operator=(Array<T>&& other) {
268     sizes_ = std::move(other.sizes_);
269     values_ = std::move(other.values_);
270     return *this;
271   }
273   // Fills the array with the specified value.
274   void Fill(const T& value) {
275     std::fill(&values_[0], &values_[0] + num_elements(), value);
276   }
278   // Fills the array with sequentially increasing values.
279   void FillIota(const T& value) {
280     std::iota(&values_[0], &values_[0] + num_elements(), value);
281   }
283   // Fills the array with a repeating sequence:
284   //   [value, value + 1, ..., value + length - 1, value, ... ]
285   void FillRepeatedIota(const T& value, int64_t length) {
286     for (int64_t i = 0; i < num_elements(); i += length) {
287       std::iota(&values_[i], &values_[std::min(i + length, num_elements())],
288                 value);
289     }
290   }
292   // Fills the array with the sequence i*multiplier for i=0,1,...
293   void FillWithMultiples(const T& multiplier) {
294     for (int64_t i = 0; i < num_elements(); ++i) {
295       values_[i] = static_cast<T>(i) * multiplier;
296     }
297   }
299   // Fills the array with random normal variables with the specified mean.
300   void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) {
301     FillRandomDouble(static_cast<double>(stddev), mean, seed);
302   }
304   void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) {
305     std::mt19937 g(seed);
306     std::normal_distribution<double> distribution(mean, stddev);
307     for (int64_t i = 0; i < num_elements(); ++i) {
308       if (std::is_same<T, bool>()) {
309         values_[i] = static_cast<T>(distribution(g) > 0.0);
310       } else {
311         values_[i] = static_cast<T>(distribution(g));
312       }
313     }
314   }
316   // Fills the array with random uniform variables in the [min_value, max_value]
317   // range. Defined for integral types.
318   template <typename = typename std::enable_if<std::is_integral<T>::value>>
319   void FillRandomUniform(const T& min_value, const T& max_value,
320                          int seed = 12345) {
321     std::mt19937 g(seed);
322     std::uniform_int_distribution<T> distribution(min_value, max_value);
323     for (int64_t i = 0; i < num_elements(); ++i) {
324       values_[i] = static_cast<T>(distribution(g));
325     }
326   }
328   // Fills the array with random uniform variables that's either True or False.
329   // Defined for boolean type.
330   void FillRandomBool(int seed = 12345) {
331     std::mt19937 g(seed);
332     std::uniform_int_distribution<int32_t> distribution(0, 1);
333     for (int64_t i = 0; i < num_elements(); ++i) {
334       values_[i] = static_cast<bool>(distribution(g));
335     }
336   }
338   // Sets all the values in the array to values specified in the container.
339   template <typename Container = std::initializer_list<T>>
340   void SetValues(const Container& container) {
341     CHECK_EQ(std::distance(std::begin(container), std::end(container)),
342              num_elements());
343     std::copy(std::begin(container), std::end(container), &values_[0]);
344   }
346   // Invokes a callback with the (indices, value_ptr) for each cell in the
347   // array.
348   void Each(std::function<void(absl::Span<const int64_t>, T*)> f) {
349     std::vector<int64_t> index(sizes_.size());
350     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
351       f(index, &values_[i]);
352     }
353   }
355   // Invokes a callback with the (indices, value) for each cell in the array.
356   void Each(std::function<void(absl::Span<const int64_t>, T)> f) const {
357     std::vector<int64_t> index(sizes_.size());
358     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
359       f(index, values_[i]);
360     }
361   }
363   // Invokes a callback with the (indices, value_ptr) for each cell in the
364   // array. If a callback returns a non-OK status, returns that else returns
365   // Status::OK().
366   Status EachStatus(std::function<Status(absl::Span<const int64_t>, T*)> f) {
367     std::vector<int64_t> index(sizes_.size());
368     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
369       Status s = f(index, &values_[i]);
370       if (!s.ok()) {
371         return s;
372       }
373     }
374     return OkStatus();
375   }
377   // Invokes a callback with the (indices, value) for each cell in the array.
378   // If a callback returns a non-OK status, returns that else returns
379   // Status::OK().
380   Status EachStatus(
381       std::function<Status(absl::Span<const int64_t>, T)> f) const {
382     std::vector<int64_t> index(sizes_.size());
383     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
384       Status s = f(index, values_[i]);
385       if (!s.ok()) {
386         return s;
387       }
388     }
389     return OkStatus();
390   }
392   // Returns the value at the cell specified by the indexes. The number of
393   // arguments have to match with the number of dimensions for the array.
394   //
395   // The type trait is required to avoid this overload participating too
396   // eagerly; a parameter pack can take zero or more elements, so we must
397   // restrict this to only parameter packs that are all of integral type.
398   template <typename... Dims>
399   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
400                           const T&>::type
401   operator()(Dims... dims) const {
402     // We are using a std::array to avoid having to allocate memory in this
403     // function for performance reasons.
404     std::array<int64_t, sizeof...(dims)> indexes{
405         {static_cast<int64_t>(dims)...}};
406     return values_[calculate_index(indexes)];
407   }
409   // Returns the value at the cell specified by the indexes. The number of
410   // arguments have to match with the number of dimensions for the array.
411   template <typename... Dims>
412   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
413                           T&>::type
414   operator()(Dims... dims) {
415     // We are using a std::array to avoid having to allocate memory in this
416     // function for performance reasons.
417     std::array<int64_t, sizeof...(dims)> indexes{
418         {static_cast<int64_t>(dims)...}};
419     return values_[calculate_index(indexes)];
420   }
422   // Returns the value at the cell specified by the indexes. The number of
423   // arguments have to match with the number of dimensions for the array.
424   const T& operator()(absl::Span<const int64_t> indexes) const {
425     return values_[calculate_index(indexes)];
426   }
428   // Returns the value at the cell specified by the indexes. The number of
429   // arguments have to match with the number of dimensions for the array.
430   T& operator()(absl::Span<const int64_t> indexes) {
431     return values_[calculate_index(indexes)];
432   }
434   // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
435   // to the underlying storage of the array (similarly to std::vector::data()).
436   T* data() const {
437     // TODO(tberghammer): Get rid of the const_cast. Currently it is needed
438     // because the Eigen backend needs a non-const pointers even for reading
439     // from the array.
440     return const_cast<Array*>(this)->values_.get();
441   }
443   // Returns the size of the dimension at the given index.
444   int64_t dim(int64_t n) const {
445     const int64_t sizes_size = sizes_.size();
446     CHECK(n < sizes_size);
447     return sizes_[n];
448   }
450   // Returns a vector containing the dimensions of the array.
451   const std::vector<int64_t>& dimensions() const { return sizes_; }
453   int64_t num_dimensions() const { return sizes_.size(); }
455   // Returns the total number of elements in the array.
456   int64_t num_elements() const {
457     return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
458                            std::multiplies<int64_t>());
459   }
461   const T* begin() const { return &values_[0]; }
462   T* begin() { return &values_[0]; }
463   const T* end() const { return &values_[num_elements()]; }
464   T* end() { return &values_[num_elements()]; }
466   bool operator==(const Array<T>& other) const {
467     if (sizes_.size() != other.sizes_.size()) {
468       return false;
469     }
470     for (int64_t i = 0, end = sizes_.size(); i < end; ++i) {
471       if (sizes_[i] != other.sizes_[i]) {
472         return false;
473       }
474     }
475     for (int64_t i = 0; i < num_elements(); ++i) {
476       if (values_[i] != other.values_[i]) {
477         return false;
478       }
479     }
480     return true;
481   }
483   bool operator!=(const Array<T>& other) const { return !(*this == other); }
485   // Performs the equivalent of a slice operation on this array.
486   Array<T> Slice(absl::Span<const int64_t> starts,
487                  absl::Span<const int64_t> limits) const {
488     CHECK_EQ(starts.size(), num_dimensions());
489     CHECK_EQ(limits.size(), num_dimensions());
491     std::vector<int64_t> sizes;
492     std::transform(starts.begin(), starts.end(), limits.begin(),
493                    std::back_inserter(sizes),
494                    [](int64_t start, int64_t limit) { return limit - start; });
495     Array<T> result(sizes);
497     std::vector<int64_t> index(sizes_.size());
498     int64_t slice_i = 0;
499     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
500       if (array_impl::all_inside_range(index, starts, limits)) {
501         // Even though the bounds of result are different to our bounds, we're
502         // iterating in the same order. So we can simply write successive linear
503         // indices instead of recalculating a multi-dimensional index.
504         result.values_[slice_i++] = values_[i];
505       }
506     }
507     return result;
508   }
510   // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
511   void UpdateSlice(const Array<T>& from,
512                    absl::Span<const int64_t> start_indices) {
513     CHECK_EQ(from.num_dimensions(), num_dimensions());
514     std::vector<int64_t> limit_indices;
515     std::transform(start_indices.begin(), start_indices.end(),
516                    from.dimensions().begin(), std::back_inserter(limit_indices),
517                    std::plus<int64_t>{});
518     std::vector<int64_t> index(sizes_.size());
519     int64_t from_i = 0;
520     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
521       if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
522         // Even though the bounds of from are different to our bounds, we're
523         // iterating in the same order. So we can simply write successive linear
524         // indices instead of recalculating a multi-dimensional index.
525         values_[i] = from.values_[from_i++];
526       }
527     }
528   }
530   // Performs an in-place reshape, modifying the dimensions but not the
531   // underlying data.
532   void Reshape(absl::Span<const int64_t> new_dimensions) {
533     int64_t old_num_elements = num_elements();
534     sizes_ = std::vector<int64_t>(new_dimensions.begin(), new_dimensions.end());
535     CHECK_EQ(num_elements(), old_num_elements);
536   }
538   // Performs a permutation of dimensions.
539   void TransposeDimensions(absl::Span<const int64_t> permutation) {
540     std::vector<int64_t> permuted_dims(permutation.size());
541     for (int64_t i = 0; i < permutation.size(); ++i) {
542       permuted_dims[i] = this->dim(permutation[i]);
543     }
544     Array<T> permuted(permuted_dims);
545     std::vector<int64_t> src_indices(sizes_.size(), -1);
546     permuted.Each([&](absl::Span<const int64_t> indices, int64_t* value) {
547       CHECK_EQ(sizes_.size(), indices.size());
548       for (int64_t i = 0; i < sizes_.size(); ++i) {
549         src_indices[permutation[i]] = indices[i];
550       }
551       *value = (*this)(src_indices);
552     });
553     *this = std::move(permuted);
554   }
556   template <typename H>
557   friend H AbslHashValue(H h, const Array& array) {
558     return H::combine(std::move(h), absl::MakeSpan(array.begin(), array.end()),
559                       array.dimensions());
560   }
562   // Returns a string representation of the array suitable for debugging.
563   std::string ToString() const {
564     if (sizes_.empty()) {
565       return "";
566     }
567     std::vector<std::string> pieces;
568     std::vector<int64_t> index(sizes_.size());
569     do {
570       // Emit leading spaces and opening square brackets
571       if (index.back() == 0) {
572         for (int64_t i = sizes_.size() - 1; i >= 0; --i) {
573           if (i == 0 || index[i - 1] != 0) {
574             for (int64_t j = 0; j < sizes_.size(); ++j) {
575               pieces.push_back(j < i ? " " : "[");
576             }
577             break;
578           }
579         }
580       }
581       int value_index = calculate_index(index);
582       if (value_index < num_elements()) {
583         pieces.push_back(absl::StrCat(values_[value_index]));
584       }
586       // Emit comma if it isn't the last element
587       if (index.back() < sizes_.back() - 1) {
588         pieces.push_back(", ");
589       }
591       // Emit closing square brackets
592       for (int64_t i = sizes_.size() - 1; i >= 0; --i) {
593         if (index[i] < sizes_[i] - 1) {
594           break;
595         }
596         pieces.push_back("]");
597         if (i != 0 && index[i - 1] < sizes_[i - 1] - 1) {
598           pieces.push_back(",\n");
599         }
600       }
601     } while (next_index(&index));
602     return absl::StrJoin(pieces, "");
603   }
605  private:
606   // Converts an initializer_list of type U to a vector of type int64_t. Used by
607   // the initializer list based constructors to convert the size type into
608   // int64_t to be passed to the size based constructor.
609   template <typename U>
610   static std::vector<int64_t> ToInt64Vector(
611       const std::initializer_list<U>& data) {
612     return std::vector<int64_t>(data.begin(), data.end());
613   }
615   // Returns the linear index from the list of per-dimension indexes. Function
616   // is templated so can be used with an std::array from operator() to avoid
617   // memory allocation.
618   // The returned value may be larger than or equal to the number of elements if
619   // the indexes exceed the array's corresponding dimension size.
620   template <typename U>
621   int64_t calculate_index(const U& indexes) const {
622     CHECK_EQ(sizes_.size(), indexes.size());
623     int64_t index = 0;
624     for (int64_t i = 0; i < sizes_.size(); ++i) {
625       index *= sizes_[i];
626       index += indexes[i];
627     }
628     return index;
629   }
631   // Advances the specified set of indexes and returns true if we haven't
632   // wrapped around (i.e. result isn't {0, 0, ...}).
633   bool next_index(std::vector<int64_t>* index) const {
634     CHECK_EQ(index->size(), sizes_.size());
635     for (int64_t i = sizes_.size() - 1; i >= 0; --i) {
636       (*index)[i]++;
637       if ((*index)[i] < sizes_[i]) {
638         return true;
639       }
640       (*index)[i] = 0;
641     }
642     return false;
643   }
645   std::vector<int64_t> sizes_;
646   std::unique_ptr<T[]> values_;
647 };
649 // Specialization of FillRandom() method for complex64 type. Uses real part of
650 // the stddev parameter as the standard deviation value.
651 template <>
652 void Array<complex64>::FillRandom(const complex64& stddev, const double mean,
653                                   const int seed);
655 }  // namespace xla