• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_
18 
19 #include <algorithm>
20 #include <array>
21 #include <functional>
22 #include <initializer_list>
23 #include <iterator>
24 #include <memory>
25 #include <numeric>
26 #include <random>
27 #include <type_traits>
28 #include <vector>
29 
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/core/lib/core/bits.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 
namespace array_impl {

// conjunction
//
// Compile-time logical AND over a pack of trait types. Every inspected type
// must expose a `::value` member convertible to `bool`. Evaluation stops at
// the first type whose `::value` is false: the `::value` members of the types
// that follow it are never looked at.
//
// Written to stand in for the C++17 `std::conjunction` metafunction.
template <typename... Traits>
struct conjunction;

// Base case: an empty pack is vacuously true.
template <>
struct conjunction<> : std::true_type {};

// Recursive case: when `Head::value` is true the answer is whatever the rest
// of the pack yields; otherwise the result is `Head` itself and the tail is
// never instantiated.
template <typename Head, typename... Tail>
struct conjunction<Head, Tail...>
    : std::conditional<Head::value, conjunction<Tail...>, Head>::type {};

// Trait whose `::value` is true exactly when every type in the parameter pack
// is an integral type (per `std::is_integral`).
template <typename... Args>
using pack_is_integral = conjunction<std::is_integral<Args>...>;

// Elementwise range test over three same-sized, indexable containers. Returns
// true only if, for every position i in `values`, values[i] lies inside the
// half-open interval [range_starts[i], range_ends[i]). The bound containers
// are indexed with every position of `values`, so they must be at least as
// long as `values`.
template <typename C1, typename C2, typename C3>
bool all_inside_range(const C1& values, const C2& range_starts,
                      const C3& range_ends) {
  const size_t count = values.size();
  for (size_t pos = 0; pos != count; ++pos) {
    // Lower bound is inclusive.
    const bool below_start = values[pos] < range_starts[pos];
    if (below_start) {
      return false;
    }
    // Upper bound is exclusive.
    const bool at_or_past_end = values[pos] >= range_ends[pos];
    if (at_or_past_end) {
      return false;
    }
  }
  return true;
}

}  // namespace array_impl
83 
84 // General N dimensional array class with arbitrary value type.
85 template <typename T>
86 class Array {
87  public:
88   // Type inference can have a hard time parsing very deep initializer list
89   // nests, especially if one or more dimensions is one as the compiler just
90   // sees a single-element integer initializer. These typedefs allow casting
91   // explicitly with less typing.
92   using InitializerList1D = std::initializer_list<T>;
93   using InitializerList2D = std::initializer_list<InitializerList1D>;
94   using InitializerList3D = std::initializer_list<InitializerList2D>;
95   using InitializerList4D = std::initializer_list<InitializerList3D>;
96 
97   using value_type = T;
98 
99   // Creates a new array with the specified dimensions.
100   explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {}
101 
102   // Creates a new array with the specified dimensions and specified value for
103   // every cell.
104   Array(absl::Span<const int64> sizes, T value)
105       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
106     Fill(value);
107   }
108 
109   // Creates a 2D array from the given nested initializer list. The outer
110   // initializer list is the first dimension, the inner is the second dimension.
111   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
112   Array(InitializerList2D values)
113       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
114     int64 idx = 0;
115     for (const auto& it1 : values) {
116       for (const auto& it2 : it1) {
117         values_[idx] = it2;
118         ++idx;
119       }
120     }
121     CHECK(idx == num_elements());
122   }
123 
124   // Creates a 1D array of a floating-point type (half, bfloat16, float,
125   // or double) from an initializer list of float values.
126   template <typename T2, typename = typename std::enable_if<
127                              (std::is_same<T, Eigen::half>::value ||
128                               std::is_same<T, bfloat16>::value ||
129                               std::is_same<T, float>::value ||
130                               std::is_same<T, double>::value) &&
131                              std::is_same<T2, float>::value>::type>
132   Array(std::initializer_list<T2> values)
133       : Array(ToInt64Vector({values.size()})) {
134     int64 idx = 0;
135     for (const auto& it1 : values) {
136       values_[idx] = static_cast<T>(it1);
137       ++idx;
138     }
139     CHECK(idx == num_elements());
140   }
141 
142   // Creates a 2D array of a floating-point type (half, bfloat16, float,
143   // or double) from an initializer list of float values.
144   template <typename T2, typename = typename std::enable_if<
145                              (std::is_same<T, Eigen::half>::value ||
146                               std::is_same<T, bfloat16>::value ||
147                               std::is_same<T, float>::value ||
148                               std::is_same<T, double>::value) &&
149                              std::is_same<T2, float>::value>::type>
150   Array(std::initializer_list<std::initializer_list<T2>> values)
151       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
152     int64 idx = 0;
153     for (const auto& it1 : values) {
154       for (const auto& it2 : it1) {
155         values_[idx] = static_cast<T>(it2);
156         ++idx;
157       }
158     }
159     CHECK(idx == num_elements());
160   }
161 
162   // Creates a 3D array from the given nested initializer list. The outer
163   // initializer list is the first dimension, and so on.
164   Array(InitializerList3D values)
165       : Array(ToInt64Vector({values.size(), values.begin()->size(),
166                              values.begin()->begin()->size()})) {
167     int64 idx = 0;
168     for (const auto& it1 : values) {
169       for (const auto& it2 : it1) {
170         for (const auto& it3 : it2) {
171           values_[idx] = it3;
172           ++idx;
173         }
174       }
175     }
176     CHECK(idx == num_elements());
177   }
178 
179   // Creates a 3D array of a floating-point type (half, bfloat16, float,
180   // or double) from an initializer list of float values.
181   template <typename T2, typename = typename std::enable_if<
182                              (std::is_same<T, Eigen::half>::value ||
183                               std::is_same<T, bfloat16>::value ||
184                               std::is_same<T, float>::value ||
185                               std::is_same<T, double>::value) &&
186                              std::is_same<T2, float>::value>::type>
187   Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
188             values)
189       : Array(ToInt64Vector({values.size(), values.begin()->size(),
190                              values.begin()->begin()->size()})) {
191     int64 idx = 0;
192     for (const auto& it1 : values) {
193       for (const auto& it2 : it1) {
194         for (const auto& it3 : it2) {
195           values_[idx] = static_cast<T>(it3);
196           ++idx;
197         }
198       }
199     }
200     CHECK(idx == num_elements());
201   }
202 
203   // Creates a 4D array from the given nested initializer list. The outer
204   // initializer list is the first dimension, and so on.
205   Array(InitializerList4D values)
206       : Array(ToInt64Vector({values.size(), values.begin()->size(),
207                              values.begin()->begin()->size(),
208                              values.begin()->begin()->begin()->size()})) {
209     int64 idx = 0;
210     for (const auto& it1 : values) {
211       for (const auto& it2 : it1) {
212         for (const auto& it3 : it2) {
213           for (const auto& it4 : it3) {
214             values_[idx] = it4;
215             ++idx;
216           }
217         }
218       }
219     }
220     CHECK(idx == num_elements());
221   }
222 
223   // Creates a 4D array of a floating-point type (half, bfloat16, float,
224   // or double) from an initializer list of float values.
225   template <typename T2, typename = typename std::enable_if<
226                              (std::is_same<T, Eigen::half>::value ||
227                               std::is_same<T, bfloat16>::value ||
228                               std::is_same<T, float>::value ||
229                               std::is_same<T, double>::value) &&
230                              std::is_same<T2, float>::value>::type>
231   Array(std::initializer_list<
232         std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
233             values)
234       : Array(ToInt64Vector({values.size(), values.begin()->size(),
235                              values.begin()->begin()->size(),
236                              values.begin()->begin()->begin()->size()})) {
237     int64 idx = 0;
238     for (const auto& it1 : values) {
239       for (const auto& it2 : it1) {
240         for (const auto& it3 : it2) {
241           for (const auto& it4 : it3) {
242             values_[idx] = static_cast<T>(it4);
243             ++idx;
244           }
245         }
246       }
247     }
248     CHECK(idx == num_elements());
249   }
250 
251   Array(const Array<T>& other)
252       : sizes_(other.sizes_), values_(new T[num_elements()]) {
253     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
254               &values_[0]);
255   }
256 
257   Array<T>& operator=(const Array<T>& other) {
258     sizes_ = other.sizes_;
259     values_.reset(new T[num_elements()]);
260     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
261               &values_[0]);
262     return *this;
263   }
264 
265   // Fills the array with the specified value.
266   void Fill(const T& value) {
267     std::fill(&values_[0], &values_[0] + num_elements(), value);
268   }
269 
270   // Fills the array with sequentially increasing values.
271   void FillIota(const T& value) {
272     std::iota(&values_[0], &values_[0] + num_elements(), value);
273   }
274 
275   // Fills the array with a repeating sequence:
276   //   [value, value + 1, ..., value + length - 1, value, ... ]
277   void FillRepeatedIota(const T& value, int64 length) {
278     for (int64 i = 0; i < num_elements(); i += length) {
279       std::iota(&values_[i], &values_[std::min(i + length, num_elements())],
280                 value);
281     }
282   }
283 
284   // Fills the array with the sequence i*multiplier for i=0,1,...
285   void FillWithMultiples(const T& multiplier) {
286     for (int64 i = 0; i < num_elements(); ++i) {
287       values_[i] = static_cast<T>(i) * multiplier;
288     }
289   }
290 
291   // Fills the array with random normal variables with the specified mean.
292   void FillRandom(const T& stddev, const double mean = 0.0,
293                   const int seed = 12345) {
294     std::mt19937 g(seed);
295     std::normal_distribution<double> distribution(mean,
296                                                   static_cast<double>(stddev));
297     for (int64 i = 0; i < num_elements(); ++i) {
298       values_[i] = static_cast<T>(distribution(g));
299     }
300   }
301 
302   // Sets all the values in the array to values specified in the container.
303   template <typename Container = std::initializer_list<T>>
304   void SetValues(const Container& container) {
305     CHECK_EQ(std::distance(std::begin(container), std::end(container)),
306              num_elements());
307     std::copy(std::begin(container), std::end(container), &values_[0]);
308   }
309 
310   // Invokes a callback with the (indices, value_ptr) for each cell in the
311   // array.
312   void Each(std::function<void(absl::Span<const int64>, T*)> f) {
313     std::vector<int64> index(sizes_.size());
314     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
315       f(index, &values_[i]);
316     }
317   }
318 
319   // Invokes a callback with the (indices, value) for each cell in the array.
320   void Each(std::function<void(absl::Span<const int64>, T)> f) const {
321     std::vector<int64> index(sizes_.size());
322     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
323       f(index, values_[i]);
324     }
325   }
326 
327   // Invokes a callback with the (indices, value_ptr) for each cell in the
328   // array. If a callback returns a non-OK status, returns that else returns
329   // Status::OK().
330   Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
331     std::vector<int64> index(sizes_.size());
332     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
333       Status s = f(index, &values_[i]);
334       if (!s.ok()) {
335         return s;
336       }
337     }
338     return Status::OK();
339   }
340 
341   // Invokes a callback with the (indices, value) for each cell in the array.
342   // If a callback returns a non-OK status, returns that else returns
343   // Status::OK().
344   Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
345     std::vector<int64> index(sizes_.size());
346     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
347       Status s = f(index, values_[i]);
348       if (!s.ok()) {
349         return s;
350       }
351     }
352     return Status::OK();
353   }
354 
355   // Returns the value at the cell specified by the indexes. The number of
356   // arguments have to match with the number of dimensions for the array.
357   //
358   // The type trait is required to avoid this overload participating too
359   // eagerly; a parameter pack can take zero or more elements, so we must
360   // restrict this to only parameter packs that are all of integral type.
361   template <typename... Dims>
362   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
363                           const T&>::type
364   operator()(Dims... dims) const {
365     // We are using a std::array to avoid having to allocate memory in this
366     // function for performance reasons.
367     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
368     return values_[calculate_index(indexes)];
369   }
370 
371   // Returns the value at the cell specified by the indexes. The number of
372   // arguments have to match with the number of dimensions for the array.
373   template <typename... Dims>
374   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
375                           T&>::type
376   operator()(Dims... dims) {
377     // We are using a std::array to avoid having to allocate memory in this
378     // function for performance reasons.
379     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
380     return values_[calculate_index(indexes)];
381   }
382 
383   // Returns the value at the cell specified by the indexes. The number of
384   // arguments have to match with the number of dimensions for the array.
385   const T& operator()(absl::Span<const int64> indexes) const {
386     return values_[calculate_index(indexes)];
387   }
388 
389   // Returns the value at the cell specified by the indexes. The number of
390   // arguments have to match with the number of dimensions for the array.
391   T& operator()(absl::Span<const int64> indexes) {
392     return values_[calculate_index(indexes)];
393   }
394 
395   // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
396   // to the underlying storage of the array (similarly to std::vector::data()).
397   T* data() const {
398     // TODO(tberghammer): Get rid of the const_cast. Currently it is needed
399     // because the Eigen backend needs a non-const pointers even for reading
400     // from the array.
401     return const_cast<Array*>(this)->values_.get();
402   }
403 
404   // Returns the size of the dimension at the given index.
405   int64 dim(int64 n) const {
406     CHECK(n < sizes_.size());
407     return sizes_[n];
408   }
409 
410   // Returns a vector containing the dimensions of the array.
411   const std::vector<int64>& dimensions() const { return sizes_; }
412 
413   int64 num_dimensions() const { return sizes_.size(); }
414 
415   // Returns the total number of elements in the array.
416   int64 num_elements() const {
417     return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
418                            std::multiplies<int64>());
419   }
420 
421   const T* begin() const { return &values_[0]; }
422   T* begin() { return &values_[0]; }
423   const T* end() const { return &values_[num_elements()]; }
424   T* end() { return &values_[num_elements()]; }
425 
426   bool operator==(const Array<T>& other) const {
427     if (sizes_.size() != other.sizes_.size()) {
428       return false;
429     }
430     for (int64 i = 0; i < sizes_.size(); ++i) {
431       if (sizes_[i] != other.sizes_[i]) {
432         return false;
433       }
434     }
435     for (int64 i = 0; i < num_elements(); ++i) {
436       if (values_[i] != other.values_[i]) {
437         return false;
438       }
439     }
440     return true;
441   }
442 
443   bool operator!=(const Array<T>& other) const { return !(*this == other); }
444 
445   // Performs the equivalent of a slice operation on this array.
446   Array<T> Slice(absl::Span<const int64> starts,
447                  absl::Span<const int64> limits) const {
448     CHECK_EQ(starts.size(), num_dimensions());
449     CHECK_EQ(limits.size(), num_dimensions());
450 
451     std::vector<int64> sizes;
452     std::transform(starts.begin(), starts.end(), limits.begin(),
453                    std::back_inserter(sizes),
454                    [](int64 start, int64 limit) { return limit - start; });
455     Array<T> result(sizes);
456 
457     std::vector<int64> index(sizes_.size());
458     int64 slice_i = 0;
459     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
460       if (array_impl::all_inside_range(index, starts, limits)) {
461         // Even though the bounds of result are different to our bounds, we're
462         // iterating in the same order. So we can simply write successive linear
463         // indices instead of recalculating a multi-dimensional index.
464         result.values_[slice_i++] = values_[i];
465       }
466     }
467     return result;
468   }
469 
470   // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
471   void UpdateSlice(const Array<T>& from,
472                    absl::Span<const int64> start_indices) {
473     CHECK_EQ(from.num_dimensions(), num_dimensions());
474     std::vector<int64> limit_indices;
475     std::transform(start_indices.begin(), start_indices.end(),
476                    from.dimensions().begin(), std::back_inserter(limit_indices),
477                    std::plus<int64>{});
478     std::vector<int64> index(sizes_.size());
479     int64 from_i = 0;
480     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
481       if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
482         // Even though the bounds of from are different to our bounds, we're
483         // iterating in the same order. So we can simply write successive linear
484         // indices instead of recalculating a multi-dimensional index.
485         values_[i] = from.values_[from_i++];
486       }
487     }
488   }
489 
490   // Performs an in-place reshape, modifying the dimensions but not the
491   // underlying data.
492   void Reshape(absl::Span<const int64> new_dimensions) {
493     int64 old_num_elements = num_elements();
494     sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
495     CHECK_EQ(num_elements(), old_num_elements);
496   }
497 
498   // Returns a string representation of the array suitable for debugging.
499   string ToString() const {
500     std::vector<string> pieces;
501     std::vector<int64> index(sizes_.size());
502     do {
503       // Emit leading spaces and opening square brackets
504       if (index.back() == 0) {
505         for (int64 i = sizes_.size() - 1; i >= 0; --i) {
506           if (i == 0 || index[i - 1] != 0) {
507             for (int64 j = 0; j < sizes_.size(); ++j) {
508               pieces.push_back(j < i ? " " : "[");
509             }
510             break;
511           }
512         }
513       }
514 
515       pieces.push_back(absl::StrCat(values_[calculate_index(index)]));
516 
517       // Emit comma if it isn't the last element
518       if (index.back() != sizes_.back() - 1) {
519         pieces.push_back(", ");
520       }
521 
522       // Emit closing square brackets
523       for (int64 i = sizes_.size() - 1; i >= 0; --i) {
524         if (index[i] != sizes_[i] - 1) {
525           break;
526         }
527         pieces.push_back("]");
528         if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) {
529           pieces.push_back(",\n");
530         }
531       }
532     } while (next_index(&index));
533     return absl::StrJoin(pieces, "");
534   }
535 
536  private:
537   // Converts an initializer_list of type U to a vector of type int64. Used by
538   // the initializer list based constructors to convert the size type into int64
539   // to be passed to the size based constructor.
540   template <typename U>
541   static std::vector<int64> ToInt64Vector(
542       const std::initializer_list<U>& data) {
543     return std::vector<int64>(data.begin(), data.end());
544   }
545 
546   // Returns the linear index from the list of per-dimension indexes. Function
547   // is templated so can be used with an std::array from operator() to avoid
548   // memory allocation.
549   template <typename U>
550   int64 calculate_index(const U& indexes) const {
551     CHECK_EQ(sizes_.size(), indexes.size());
552     int64 index = 0;
553     for (int64 i = 0; i < sizes_.size(); ++i) {
554       index *= sizes_[i];
555       index += indexes[i];
556     }
557     return index;
558   }
559 
560   // Advances the specified set of indexes and returns true if we haven't
561   // wrapped around (i.e. result isn't {0, 0, ...}).
562   bool next_index(std::vector<int64>* index) const {
563     CHECK_EQ(index->size(), sizes_.size());
564     for (int64 i = sizes_.size() - 1; i >= 0; --i) {
565       (*index)[i]++;
566       if ((*index)[i] < sizes_[i]) {
567         return true;
568       }
569       (*index)[i] = 0;
570     }
571     return false;
572   }
573 
574   std::vector<int64> sizes_;
575   std::unique_ptr<T[]> values_;
576 };
577 
578 }  // namespace xla
579 
580 #endif  // TENSORFLOW_COMPILER_XLA_ARRAY_H_
581