• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_
18 
19 #include <algorithm>
20 #include <array>
21 #include <functional>
22 #include <initializer_list>
23 #include <iterator>
24 #include <memory>
25 #include <numeric>
26 #include <random>
27 #include <type_traits>
28 #include <vector>
29 
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/core/lib/core/bits.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 
42 namespace array_impl {
43 
// conjunction
//
// Performs a compile-time logical AND over the passed types, each of which
// must have a `::value` member convertible to `bool`. Evaluation
// short-circuits: once a type whose `::value` is false is found, the remaining
// types are never instantiated and the result derives from that type. With an
// empty pack the result is std::true_type.
//
// This metafunction is designed to be a drop-in replacement for the C++17
// `std::conjunction` metafunction.
template <typename... Ts>
struct conjunction;

// Recursive case. The explicit conversion to bool mirrors std::conjunction and
// lets `::value` members of non-bool type (e.g. an int other than 0 or 1) be
// used; passing them directly as a bool template argument would be rejected as
// a narrowing conversion.
template <typename T, typename... Ts>
struct conjunction<T, Ts...>
    : std::conditional<static_cast<bool>(T::value), conjunction<Ts...>,
                       T>::type {};

// Base case: the conjunction of nothing is true.
template <>
struct conjunction<> : std::true_type {};
62 
63 // A type trait that is valid when all elements in a parameter pack are of
64 // integral type. Not using an alias template to work around MSVC 14.00 bug.
65 template <typename... Ts>
66 struct pack_is_integral : conjunction<std::is_integral<Ts>...> {};
67 
// Checks a multi-dimensional position against per-dimension bounds. Returns
// true only if, for every i in [0, values.size()), values[i] lies in the
// half-open range [range_starts[i], range_ends[i]). The three containers are
// expected to have the same length; only values.size() is consulted.
template <typename C1, typename C2, typename C3>
bool all_inside_range(const C1& values, const C2& range_starts,
                      const C3& range_ends) {
  const size_t count = values.size();
  size_t pos = 0;
  while (pos < count) {
    // Same comparisons, in the same order, as a "below start or at/after end"
    // test; the second comparison is skipped when the first already fails.
    const bool inside = !(values[pos] < range_starts[pos]) &&
                        !(values[pos] >= range_ends[pos]);
    if (!inside) {
      return false;
    }
    ++pos;
  }
  return true;
}
81 
82 }  // namespace array_impl
83 
84 // General N dimensional array class with arbitrary value type.
85 template <typename T>
86 class Array {
87  public:
88   // Type inference can have a hard time parsing very deep initializer list
89   // nests, especially if one or more dimensions is one as the compiler just
90   // sees a single-element integer initializer. These typedefs allow casting
91   // explicitly with less typing.
92   using InitializerList1D = std::initializer_list<T>;
93   using InitializerList2D = std::initializer_list<InitializerList1D>;
94   using InitializerList3D = std::initializer_list<InitializerList2D>;
95   using InitializerList4D = std::initializer_list<InitializerList3D>;
96 
97   using value_type = T;
98 
99   // Creates a new array with the specified dimensions and initialized elements.
100   explicit Array(absl::Span<const int64> sizes)
101       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]()) {}
102 
103   // Creates a new array with the specified dimensions and specified value for
104   // every cell.
105   Array(absl::Span<const int64> sizes, T value)
106       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
107     Fill(value);
108   }
109 
110   // Creates a 2D array from the given nested initializer list. The outer
111   // initializer list is the first dimension, the inner is the second dimension.
112   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
113   Array(InitializerList2D values)
114       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
115     int64_t idx = 0;
116     for (const auto& it1 : values) {
117       for (const auto& it2 : it1) {
118         values_[idx] = it2;
119         ++idx;
120       }
121     }
122     CHECK(idx == num_elements());
123   }
124 
125   // Creates a 1D array of a floating-point type (half, bfloat16, float,
126   // or double) from an initializer list of float values.
127   template <typename T2, typename = typename std::enable_if<
128                              (std::is_same<T, Eigen::half>::value ||
129                               std::is_same<T, bfloat16>::value ||
130                               std::is_same<T, float>::value ||
131                               std::is_same<T, double>::value) &&
132                              std::is_same<T2, float>::value>::type>
133   Array(std::initializer_list<T2> values)
134       : Array(ToInt64Vector({values.size()})) {
135     int64_t idx = 0;
136     for (const auto& it1 : values) {
137       values_[idx] = static_cast<T>(it1);
138       ++idx;
139     }
140     CHECK(idx == num_elements());
141   }
142 
143   // Creates a 2D array of a floating-point type (half, bfloat16, float,
144   // or double) from an initializer list of float values.
145   template <typename T2, typename = typename std::enable_if<
146                              (std::is_same<T, Eigen::half>::value ||
147                               std::is_same<T, bfloat16>::value ||
148                               std::is_same<T, float>::value ||
149                               std::is_same<T, double>::value) &&
150                              std::is_same<T2, float>::value>::type>
151   Array(std::initializer_list<std::initializer_list<T2>> values)
152       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
153     int64_t idx = 0;
154     for (const auto& it1 : values) {
155       for (const auto& it2 : it1) {
156         values_[idx] = static_cast<T>(it2);
157         ++idx;
158       }
159     }
160     CHECK(idx == num_elements());
161   }
162 
163   // Creates a 3D array from the given nested initializer list. The outer
164   // initializer list is the first dimension, and so on.
165   Array(InitializerList3D values)
166       : Array(ToInt64Vector({values.size(), values.begin()->size(),
167                              values.begin()->begin()->size()})) {
168     int64_t idx = 0;
169     for (const auto& it1 : values) {
170       for (const auto& it2 : it1) {
171         for (const auto& it3 : it2) {
172           values_[idx] = it3;
173           ++idx;
174         }
175       }
176     }
177     CHECK(idx == num_elements());
178   }
179 
180   // Creates a 3D array of a floating-point type (half, bfloat16, float,
181   // or double) from an initializer list of float values.
182   template <typename T2, typename = typename std::enable_if<
183                              (std::is_same<T, Eigen::half>::value ||
184                               std::is_same<T, bfloat16>::value ||
185                               std::is_same<T, float>::value ||
186                               std::is_same<T, double>::value) &&
187                              std::is_same<T2, float>::value>::type>
188   Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
189             values)
190       : Array(ToInt64Vector({values.size(), values.begin()->size(),
191                              values.begin()->begin()->size()})) {
192     int64_t idx = 0;
193     for (const auto& it1 : values) {
194       for (const auto& it2 : it1) {
195         for (const auto& it3 : it2) {
196           values_[idx] = static_cast<T>(it3);
197           ++idx;
198         }
199       }
200     }
201     CHECK(idx == num_elements());
202   }
203 
204   // Creates a 4D array from the given nested initializer list. The outer
205   // initializer list is the first dimension, and so on.
206   Array(InitializerList4D values)
207       : Array(ToInt64Vector({values.size(), values.begin()->size(),
208                              values.begin()->begin()->size(),
209                              values.begin()->begin()->begin()->size()})) {
210     int64_t idx = 0;
211     for (const auto& it1 : values) {
212       for (const auto& it2 : it1) {
213         for (const auto& it3 : it2) {
214           for (const auto& it4 : it3) {
215             values_[idx] = it4;
216             ++idx;
217           }
218         }
219       }
220     }
221     CHECK(idx == num_elements());
222   }
223 
224   // Creates a 4D array of a floating-point type (half, bfloat16, float,
225   // or double) from an initializer list of float values.
226   template <typename T2, typename = typename std::enable_if<
227                              (std::is_same<T, Eigen::half>::value ||
228                               std::is_same<T, bfloat16>::value ||
229                               std::is_same<T, float>::value ||
230                               std::is_same<T, double>::value) &&
231                              std::is_same<T2, float>::value>::type>
232   Array(std::initializer_list<
233         std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
234             values)
235       : Array(ToInt64Vector({values.size(), values.begin()->size(),
236                              values.begin()->begin()->size(),
237                              values.begin()->begin()->begin()->size()})) {
238     int64_t idx = 0;
239     for (const auto& it1 : values) {
240       for (const auto& it2 : it1) {
241         for (const auto& it3 : it2) {
242           for (const auto& it4 : it3) {
243             values_[idx] = static_cast<T>(it4);
244             ++idx;
245           }
246         }
247       }
248     }
249     CHECK(idx == num_elements());
250   }
251 
252   Array(const Array<T>& other)
253       : sizes_(other.sizes_), values_(new T[num_elements()]) {
254     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
255               &values_[0]);
256   }
257 
258   Array<T>& operator=(const Array<T>& other) {
259     sizes_ = other.sizes_;
260     values_.reset(new T[num_elements()]);
261     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
262               &values_[0]);
263     return *this;
264   }
265 
266   // Fills the array with the specified value.
267   void Fill(const T& value) {
268     std::fill(&values_[0], &values_[0] + num_elements(), value);
269   }
270 
271   // Fills the array with sequentially increasing values.
272   void FillIota(const T& value) {
273     std::iota(&values_[0], &values_[0] + num_elements(), value);
274   }
275 
276   // Fills the array with a repeating sequence:
277   //   [value, value + 1, ..., value + length - 1, value, ... ]
278   void FillRepeatedIota(const T& value, int64_t length) {
279     for (int64_t i = 0; i < num_elements(); i += length) {
280       std::iota(&values_[i], &values_[std::min(i + length, num_elements())],
281                 value);
282     }
283   }
284 
285   // Fills the array with the sequence i*multiplier for i=0,1,...
286   void FillWithMultiples(const T& multiplier) {
287     for (int64_t i = 0; i < num_elements(); ++i) {
288       values_[i] = static_cast<T>(i) * multiplier;
289     }
290   }
291 
292   // Fills the array with random normal variables with the specified mean.
293   void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) {
294     FillRandomDouble(static_cast<double>(stddev), mean, seed);
295   }
296 
297   void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) {
298     std::mt19937 g(seed);
299     std::normal_distribution<double> distribution(mean, stddev);
300     for (int64_t i = 0; i < num_elements(); ++i) {
301       if (std::is_same<T, bool>()) {
302         values_[i] = static_cast<T>(distribution(g) > 0.0);
303       } else {
304         values_[i] = static_cast<T>(distribution(g));
305       }
306     }
307   }
308 
309   // Sets all the values in the array to values specified in the container.
310   template <typename Container = std::initializer_list<T>>
311   void SetValues(const Container& container) {
312     CHECK_EQ(std::distance(std::begin(container), std::end(container)),
313              num_elements());
314     std::copy(std::begin(container), std::end(container), &values_[0]);
315   }
316 
317   // Invokes a callback with the (indices, value_ptr) for each cell in the
318   // array.
319   void Each(std::function<void(absl::Span<const int64>, T*)> f) {
320     std::vector<int64> index(sizes_.size());
321     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
322       f(index, &values_[i]);
323     }
324   }
325 
326   // Invokes a callback with the (indices, value) for each cell in the array.
327   void Each(std::function<void(absl::Span<const int64>, T)> f) const {
328     std::vector<int64> index(sizes_.size());
329     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
330       f(index, values_[i]);
331     }
332   }
333 
334   // Invokes a callback with the (indices, value_ptr) for each cell in the
335   // array. If a callback returns a non-OK status, returns that else returns
336   // Status::OK().
337   Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
338     std::vector<int64> index(sizes_.size());
339     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
340       Status s = f(index, &values_[i]);
341       if (!s.ok()) {
342         return s;
343       }
344     }
345     return Status::OK();
346   }
347 
348   // Invokes a callback with the (indices, value) for each cell in the array.
349   // If a callback returns a non-OK status, returns that else returns
350   // Status::OK().
351   Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
352     std::vector<int64> index(sizes_.size());
353     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
354       Status s = f(index, values_[i]);
355       if (!s.ok()) {
356         return s;
357       }
358     }
359     return Status::OK();
360   }
361 
362   // Returns the value at the cell specified by the indexes. The number of
363   // arguments have to match with the number of dimensions for the array.
364   //
365   // The type trait is required to avoid this overload participating too
366   // eagerly; a parameter pack can take zero or more elements, so we must
367   // restrict this to only parameter packs that are all of integral type.
368   template <typename... Dims>
369   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
370                           const T&>::type
371   operator()(Dims... dims) const {
372     // We are using a std::array to avoid having to allocate memory in this
373     // function for performance reasons.
374     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
375     return values_[calculate_index(indexes)];
376   }
377 
378   // Returns the value at the cell specified by the indexes. The number of
379   // arguments have to match with the number of dimensions for the array.
380   template <typename... Dims>
381   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
382                           T&>::type
383   operator()(Dims... dims) {
384     // We are using a std::array to avoid having to allocate memory in this
385     // function for performance reasons.
386     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
387     return values_[calculate_index(indexes)];
388   }
389 
390   // Returns the value at the cell specified by the indexes. The number of
391   // arguments have to match with the number of dimensions for the array.
392   const T& operator()(absl::Span<const int64> indexes) const {
393     return values_[calculate_index(indexes)];
394   }
395 
396   // Returns the value at the cell specified by the indexes. The number of
397   // arguments have to match with the number of dimensions for the array.
398   T& operator()(absl::Span<const int64> indexes) {
399     return values_[calculate_index(indexes)];
400   }
401 
402   // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
403   // to the underlying storage of the array (similarly to std::vector::data()).
404   T* data() const {
405     // TODO(tberghammer): Get rid of the const_cast. Currently it is needed
406     // because the Eigen backend needs a non-const pointers even for reading
407     // from the array.
408     return const_cast<Array*>(this)->values_.get();
409   }
410 
411   // Returns the size of the dimension at the given index.
412   int64 dim(int64_t n) const {
413     const int64_t sizes_size = sizes_.size();
414     CHECK(n < sizes_size);
415     return sizes_[n];
416   }
417 
418   // Returns a vector containing the dimensions of the array.
419   const std::vector<int64>& dimensions() const { return sizes_; }
420 
421   int64 num_dimensions() const { return sizes_.size(); }
422 
423   // Returns the total number of elements in the array.
424   int64 num_elements() const {
425     return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
426                            std::multiplies<int64>());
427   }
428 
429   const T* begin() const { return &values_[0]; }
430   T* begin() { return &values_[0]; }
431   const T* end() const { return &values_[num_elements()]; }
432   T* end() { return &values_[num_elements()]; }
433 
434   bool operator==(const Array<T>& other) const {
435     if (sizes_.size() != other.sizes_.size()) {
436       return false;
437     }
438     for (int64_t i = 0, end = sizes_.size(); i < end; ++i) {
439       if (sizes_[i] != other.sizes_[i]) {
440         return false;
441       }
442     }
443     for (int64_t i = 0; i < num_elements(); ++i) {
444       if (values_[i] != other.values_[i]) {
445         return false;
446       }
447     }
448     return true;
449   }
450 
451   bool operator!=(const Array<T>& other) const { return !(*this == other); }
452 
453   // Performs the equivalent of a slice operation on this array.
454   Array<T> Slice(absl::Span<const int64> starts,
455                  absl::Span<const int64> limits) const {
456     CHECK_EQ(starts.size(), num_dimensions());
457     CHECK_EQ(limits.size(), num_dimensions());
458 
459     std::vector<int64> sizes;
460     std::transform(starts.begin(), starts.end(), limits.begin(),
461                    std::back_inserter(sizes),
462                    [](int64_t start, int64_t limit) { return limit - start; });
463     Array<T> result(sizes);
464 
465     std::vector<int64> index(sizes_.size());
466     int64_t slice_i = 0;
467     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
468       if (array_impl::all_inside_range(index, starts, limits)) {
469         // Even though the bounds of result are different to our bounds, we're
470         // iterating in the same order. So we can simply write successive linear
471         // indices instead of recalculating a multi-dimensional index.
472         result.values_[slice_i++] = values_[i];
473       }
474     }
475     return result;
476   }
477 
478   // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
479   void UpdateSlice(const Array<T>& from,
480                    absl::Span<const int64> start_indices) {
481     CHECK_EQ(from.num_dimensions(), num_dimensions());
482     std::vector<int64> limit_indices;
483     std::transform(start_indices.begin(), start_indices.end(),
484                    from.dimensions().begin(), std::back_inserter(limit_indices),
485                    std::plus<int64>{});
486     std::vector<int64> index(sizes_.size());
487     int64_t from_i = 0;
488     for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) {
489       if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
490         // Even though the bounds of from are different to our bounds, we're
491         // iterating in the same order. So we can simply write successive linear
492         // indices instead of recalculating a multi-dimensional index.
493         values_[i] = from.values_[from_i++];
494       }
495     }
496   }
497 
498   // Performs an in-place reshape, modifying the dimensions but not the
499   // underlying data.
500   void Reshape(absl::Span<const int64> new_dimensions) {
501     int64_t old_num_elements = num_elements();
502     sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
503     CHECK_EQ(num_elements(), old_num_elements);
504   }
505 
506   // Returns a string representation of the array suitable for debugging.
507   string ToString() const {
508     std::vector<string> pieces;
509     std::vector<int64> index(sizes_.size());
510     do {
511       // Emit leading spaces and opening square brackets
512       if (index.back() == 0) {
513         for (int64_t i = sizes_.size() - 1; i >= 0; --i) {
514           if (i == 0 || index[i - 1] != 0) {
515             for (int64_t j = 0; j < sizes_.size(); ++j) {
516               pieces.push_back(j < i ? " " : "[");
517             }
518             break;
519           }
520         }
521       }
522 
523       pieces.push_back(absl::StrCat(values_[calculate_index(index)]));
524 
525       // Emit comma if it isn't the last element
526       if (index.back() != sizes_.back() - 1) {
527         pieces.push_back(", ");
528       }
529 
530       // Emit closing square brackets
531       for (int64_t i = sizes_.size() - 1; i >= 0; --i) {
532         if (index[i] != sizes_[i] - 1) {
533           break;
534         }
535         pieces.push_back("]");
536         if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) {
537           pieces.push_back(",\n");
538         }
539       }
540     } while (next_index(&index));
541     return absl::StrJoin(pieces, "");
542   }
543 
544  private:
545   // Converts an initializer_list of type U to a vector of type int64. Used by
546   // the initializer list based constructors to convert the size type into int64
547   // to be passed to the size based constructor.
548   template <typename U>
549   static std::vector<int64> ToInt64Vector(
550       const std::initializer_list<U>& data) {
551     return std::vector<int64>(data.begin(), data.end());
552   }
553 
554   // Returns the linear index from the list of per-dimension indexes. Function
555   // is templated so can be used with an std::array from operator() to avoid
556   // memory allocation.
557   template <typename U>
558   int64 calculate_index(const U& indexes) const {
559     CHECK_EQ(sizes_.size(), indexes.size());
560     int64_t index = 0;
561     for (int64_t i = 0; i < sizes_.size(); ++i) {
562       index *= sizes_[i];
563       index += indexes[i];
564     }
565     DCHECK_LT(index, this->num_elements());
566     return index;
567   }
568 
569   // Advances the specified set of indexes and returns true if we haven't
570   // wrapped around (i.e. result isn't {0, 0, ...}).
571   bool next_index(std::vector<int64>* index) const {
572     CHECK_EQ(index->size(), sizes_.size());
573     for (int64_t i = sizes_.size() - 1; i >= 0; --i) {
574       (*index)[i]++;
575       if ((*index)[i] < sizes_[i]) {
576         return true;
577       }
578       (*index)[i] = 0;
579     }
580     return false;
581   }
582 
583   std::vector<int64> sizes_;
584   std::unique_ptr<T[]> values_;
585 };
586 
587 // Specialization of FillRandom() method for complex64 type. Uses real part of
588 // the stddev parameter as the standard deviation value.
589 template <>
590 void Array<complex64>::FillRandom(const complex64& stddev, const double mean,
591                                   const int seed);
592 
593 }  // namespace xla
594 
595 #endif  // TENSORFLOW_COMPILER_XLA_ARRAY_H_
596