• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_
18 
19 #include <algorithm>
20 #include <array>
21 #include <functional>
22 #include <initializer_list>
23 #include <iterator>
24 #include <memory>
25 #include <numeric>
26 #include <random>
27 #include <type_traits>
28 #include <vector>
29 
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/core/lib/core/bits.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 
42 namespace array_impl {
43 
// conjunction
//
// Compile-time logical AND over a list of types, each of which must expose a
// `::value` member convertible to `bool`. Evaluation short-circuits: as soon
// as a type with a false `::value` is found, the remaining types are never
// instantiated and their `::value` members are never inspected.
//
// Intended as a drop-in replacement for the C++17 `std::conjunction`
// metafunction.

// Base case: an empty list of types is vacuously true.
template <typename... Ts>
struct conjunction : std::true_type {};

// Recursive case: if `Head::value` is true the answer is decided by the
// remaining types, otherwise `Head` itself is the (false) result.
template <typename Head, typename... Rest>
struct conjunction<Head, Rest...>
    : std::conditional<Head::value, conjunction<Rest...>, Head>::type {};
62 
63 // A type trait that is valid when all elements in a parameter pack are of
64 // integral type. Not using an alias template to work around MSVC 14.00 bug.
65 template <typename... Ts>
66 struct pack_is_integral : conjunction<std::is_integral<Ts>...> {};
67 
// Elementwise range test over three containers that are indexed in lockstep.
// Returns true only if, for every position i of `values`, values[i] lies in
// the half-open interval [range_starts[i], range_ends[i]). Only `values` is
// asked for its size; the other two containers are indexed over that same
// range, so they must be at least as long.
template <typename C1, typename C2, typename C3>
bool all_inside_range(const C1& values, const C2& range_starts,
                      const C3& range_ends) {
  const size_t count = values.size();
  size_t pos = 0;
  while (pos < count) {
    // Same comparisons, in the same order, as "below start or at/after end".
    const bool inside = !(values[pos] < range_starts[pos]) &&
                        !(values[pos] >= range_ends[pos]);
    if (!inside) {
      return false;
    }
    ++pos;
  }
  return true;
}
81 
82 }  // namespace array_impl
83 
84 // General N dimensional array class with arbitrary value type.
85 template <typename T>
86 class Array {
87  public:
88   // Type inference can have a hard time parsing very deep initializer list
89   // nests, especially if one or more dimensions is one as the compiler just
90   // sees a single-element integer initializer. These typedefs allow casting
91   // explicitly with less typing.
92   using InitializerList1D = std::initializer_list<T>;
93   using InitializerList2D = std::initializer_list<InitializerList1D>;
94   using InitializerList3D = std::initializer_list<InitializerList2D>;
95   using InitializerList4D = std::initializer_list<InitializerList3D>;
96 
97   using value_type = T;
98 
99   // Creates a new array with the specified dimensions.
100   explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {}
101 
102   // Creates a new array with the specified dimensions and specified value for
103   // every cell.
104   Array(absl::Span<const int64> sizes, T value)
105       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
106     Fill(value);
107   }
108 
109   // Creates a 2D array from the given nested initializer list. The outer
110   // initializer list is the first dimension, the inner is the second dimension.
111   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
112   Array(InitializerList2D values)
113       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
114     int64 idx = 0;
115     for (const auto& it1 : values) {
116       for (const auto& it2 : it1) {
117         values_[idx] = it2;
118         ++idx;
119       }
120     }
121     CHECK(idx == num_elements());
122   }
123 
124   // Creates a 1D array of a floating-point type (half, bfloat16, float,
125   // or double) from an initializer list of float values.
126   template <typename T2, typename = typename std::enable_if<
127                              (std::is_same<T, Eigen::half>::value ||
128                               std::is_same<T, bfloat16>::value ||
129                               std::is_same<T, float>::value ||
130                               std::is_same<T, double>::value) &&
131                              std::is_same<T2, float>::value>::type>
132   Array(std::initializer_list<T2> values)
133       : Array(ToInt64Vector({values.size()})) {
134     int64 idx = 0;
135     for (const auto& it1 : values) {
136       values_[idx] = static_cast<T>(it1);
137       ++idx;
138     }
139     CHECK(idx == num_elements());
140   }
141 
142   // Creates a 2D array of a floating-point type (half, bfloat16, float,
143   // or double) from an initializer list of float values.
144   template <typename T2, typename = typename std::enable_if<
145                              (std::is_same<T, Eigen::half>::value ||
146                               std::is_same<T, bfloat16>::value ||
147                               std::is_same<T, float>::value ||
148                               std::is_same<T, double>::value) &&
149                              std::is_same<T2, float>::value>::type>
150   Array(std::initializer_list<std::initializer_list<T2>> values)
151       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
152     int64 idx = 0;
153     for (const auto& it1 : values) {
154       for (const auto& it2 : it1) {
155         values_[idx] = static_cast<T>(it2);
156         ++idx;
157       }
158     }
159     CHECK(idx == num_elements());
160   }
161 
162   // Creates a 3D array from the given nested initializer list. The outer
163   // initializer list is the first dimension, and so on.
164   Array(InitializerList3D values)
165       : Array(ToInt64Vector({values.size(), values.begin()->size(),
166                              values.begin()->begin()->size()})) {
167     int64 idx = 0;
168     for (const auto& it1 : values) {
169       for (const auto& it2 : it1) {
170         for (const auto& it3 : it2) {
171           values_[idx] = it3;
172           ++idx;
173         }
174       }
175     }
176     CHECK(idx == num_elements());
177   }
178 
179   // Creates a 3D array of a floating-point type (half, bfloat16, float,
180   // or double) from an initializer list of float values.
181   template <typename T2, typename = typename std::enable_if<
182                              (std::is_same<T, Eigen::half>::value ||
183                               std::is_same<T, bfloat16>::value ||
184                               std::is_same<T, float>::value ||
185                               std::is_same<T, double>::value) &&
186                              std::is_same<T2, float>::value>::type>
187   Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
188             values)
189       : Array(ToInt64Vector({values.size(), values.begin()->size(),
190                              values.begin()->begin()->size()})) {
191     int64 idx = 0;
192     for (const auto& it1 : values) {
193       for (const auto& it2 : it1) {
194         for (const auto& it3 : it2) {
195           values_[idx] = static_cast<T>(it3);
196           ++idx;
197         }
198       }
199     }
200     CHECK(idx == num_elements());
201   }
202 
203   // Creates a 4D array from the given nested initializer list. The outer
204   // initializer list is the first dimension, and so on.
205   Array(InitializerList4D values)
206       : Array(ToInt64Vector({values.size(), values.begin()->size(),
207                              values.begin()->begin()->size(),
208                              values.begin()->begin()->begin()->size()})) {
209     int64 idx = 0;
210     for (const auto& it1 : values) {
211       for (const auto& it2 : it1) {
212         for (const auto& it3 : it2) {
213           for (const auto& it4 : it3) {
214             values_[idx] = it4;
215             ++idx;
216           }
217         }
218       }
219     }
220     CHECK(idx == num_elements());
221   }
222 
223   // Creates a 4D array of a floating-point type (half, bfloat16, float,
224   // or double) from an initializer list of float values.
225   template <typename T2, typename = typename std::enable_if<
226                              (std::is_same<T, Eigen::half>::value ||
227                               std::is_same<T, bfloat16>::value ||
228                               std::is_same<T, float>::value ||
229                               std::is_same<T, double>::value) &&
230                              std::is_same<T2, float>::value>::type>
231   Array(std::initializer_list<
232         std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
233             values)
234       : Array(ToInt64Vector({values.size(), values.begin()->size(),
235                              values.begin()->begin()->size(),
236                              values.begin()->begin()->begin()->size()})) {
237     int64 idx = 0;
238     for (const auto& it1 : values) {
239       for (const auto& it2 : it1) {
240         for (const auto& it3 : it2) {
241           for (const auto& it4 : it3) {
242             values_[idx] = static_cast<T>(it4);
243             ++idx;
244           }
245         }
246       }
247     }
248     CHECK(idx == num_elements());
249   }
250 
251   Array(const Array<T>& other)
252       : sizes_(other.sizes_), values_(new T[num_elements()]) {
253     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
254               &values_[0]);
255   }
256 
257   Array<T>& operator=(const Array<T>& other) {
258     sizes_ = other.sizes_;
259     values_.reset(new T[num_elements()]);
260     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
261               &values_[0]);
262     return *this;
263   }
264 
265   // Fills the array with the specified value.
266   void Fill(const T& value) {
267     std::fill(&values_[0], &values_[0] + num_elements(), value);
268   }
269 
270   // Fills the array with sequentially increasing values.
271   void FillIota(const T& value) {
272     std::iota(&values_[0], &values_[0] + num_elements(), value);
273   }
274 
275   // Fills the array with a repeating sequence:
276   //   [value, value + 1, ..., value + length - 1, value, ... ]
277   void FillRepeatedIota(const T& value, int64 length) {
278     for (int64 i = 0; i < num_elements(); i += length) {
279       std::iota(&values_[i], &values_[std::min(i + length, num_elements())],
280                 value);
281     }
282   }
283 
284   // Fills the array with the sequence i*multiplier for i=0,1,...
285   void FillWithMultiples(const T& multiplier) {
286     for (int64 i = 0; i < num_elements(); ++i) {
287       values_[i] = static_cast<T>(i) * multiplier;
288     }
289   }
290 
291   // Fills the array with random normal variables with the specified mean.
292   void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) {
293     FillRandomDouble(static_cast<double>(stddev), mean, seed);
294   }
295 
296   void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) {
297     std::mt19937 g(seed);
298     std::normal_distribution<double> distribution(mean, stddev);
299     for (int64 i = 0; i < num_elements(); ++i) {
300       if (std::is_same<T, bool>()) {
301         values_[i] = static_cast<T>(distribution(g) > 0.0);
302       } else {
303         values_[i] = static_cast<T>(distribution(g));
304       }
305     }
306   }
307 
308   // Sets all the values in the array to values specified in the container.
309   template <typename Container = std::initializer_list<T>>
310   void SetValues(const Container& container) {
311     CHECK_EQ(std::distance(std::begin(container), std::end(container)),
312              num_elements());
313     std::copy(std::begin(container), std::end(container), &values_[0]);
314   }
315 
316   // Invokes a callback with the (indices, value_ptr) for each cell in the
317   // array.
318   void Each(std::function<void(absl::Span<const int64>, T*)> f) {
319     std::vector<int64> index(sizes_.size());
320     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
321       f(index, &values_[i]);
322     }
323   }
324 
325   // Invokes a callback with the (indices, value) for each cell in the array.
326   void Each(std::function<void(absl::Span<const int64>, T)> f) const {
327     std::vector<int64> index(sizes_.size());
328     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
329       f(index, values_[i]);
330     }
331   }
332 
333   // Invokes a callback with the (indices, value_ptr) for each cell in the
334   // array. If a callback returns a non-OK status, returns that else returns
335   // Status::OK().
336   Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
337     std::vector<int64> index(sizes_.size());
338     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
339       Status s = f(index, &values_[i]);
340       if (!s.ok()) {
341         return s;
342       }
343     }
344     return Status::OK();
345   }
346 
347   // Invokes a callback with the (indices, value) for each cell in the array.
348   // If a callback returns a non-OK status, returns that else returns
349   // Status::OK().
350   Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
351     std::vector<int64> index(sizes_.size());
352     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
353       Status s = f(index, values_[i]);
354       if (!s.ok()) {
355         return s;
356       }
357     }
358     return Status::OK();
359   }
360 
361   // Returns the value at the cell specified by the indexes. The number of
362   // arguments have to match with the number of dimensions for the array.
363   //
364   // The type trait is required to avoid this overload participating too
365   // eagerly; a parameter pack can take zero or more elements, so we must
366   // restrict this to only parameter packs that are all of integral type.
367   template <typename... Dims>
368   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
369                           const T&>::type
370   operator()(Dims... dims) const {
371     // We are using a std::array to avoid having to allocate memory in this
372     // function for performance reasons.
373     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
374     return values_[calculate_index(indexes)];
375   }
376 
377   // Returns the value at the cell specified by the indexes. The number of
378   // arguments have to match with the number of dimensions for the array.
379   template <typename... Dims>
380   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
381                           T&>::type
382   operator()(Dims... dims) {
383     // We are using a std::array to avoid having to allocate memory in this
384     // function for performance reasons.
385     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
386     return values_[calculate_index(indexes)];
387   }
388 
389   // Returns the value at the cell specified by the indexes. The number of
390   // arguments have to match with the number of dimensions for the array.
391   const T& operator()(absl::Span<const int64> indexes) const {
392     return values_[calculate_index(indexes)];
393   }
394 
395   // Returns the value at the cell specified by the indexes. The number of
396   // arguments have to match with the number of dimensions for the array.
397   T& operator()(absl::Span<const int64> indexes) {
398     return values_[calculate_index(indexes)];
399   }
400 
401   // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
402   // to the underlying storage of the array (similarly to std::vector::data()).
403   T* data() const {
404     // TODO(tberghammer): Get rid of the const_cast. Currently it is needed
405     // because the Eigen backend needs a non-const pointers even for reading
406     // from the array.
407     return const_cast<Array*>(this)->values_.get();
408   }
409 
410   // Returns the size of the dimension at the given index.
411   int64 dim(int64 n) const {
412     const int64 sizes_size = sizes_.size();
413     CHECK(n < sizes_size);
414     return sizes_[n];
415   }
416 
417   // Returns a vector containing the dimensions of the array.
418   const std::vector<int64>& dimensions() const { return sizes_; }
419 
420   int64 num_dimensions() const { return sizes_.size(); }
421 
422   // Returns the total number of elements in the array.
423   int64 num_elements() const {
424     return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
425                            std::multiplies<int64>());
426   }
427 
428   const T* begin() const { return &values_[0]; }
429   T* begin() { return &values_[0]; }
430   const T* end() const { return &values_[num_elements()]; }
431   T* end() { return &values_[num_elements()]; }
432 
433   bool operator==(const Array<T>& other) const {
434     if (sizes_.size() != other.sizes_.size()) {
435       return false;
436     }
437     for (int64 i = 0, end = sizes_.size(); i < end; ++i) {
438       if (sizes_[i] != other.sizes_[i]) {
439         return false;
440       }
441     }
442     for (int64 i = 0; i < num_elements(); ++i) {
443       if (values_[i] != other.values_[i]) {
444         return false;
445       }
446     }
447     return true;
448   }
449 
450   bool operator!=(const Array<T>& other) const { return !(*this == other); }
451 
452   // Performs the equivalent of a slice operation on this array.
453   Array<T> Slice(absl::Span<const int64> starts,
454                  absl::Span<const int64> limits) const {
455     CHECK_EQ(starts.size(), num_dimensions());
456     CHECK_EQ(limits.size(), num_dimensions());
457 
458     std::vector<int64> sizes;
459     std::transform(starts.begin(), starts.end(), limits.begin(),
460                    std::back_inserter(sizes),
461                    [](int64 start, int64 limit) { return limit - start; });
462     Array<T> result(sizes);
463 
464     std::vector<int64> index(sizes_.size());
465     int64 slice_i = 0;
466     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
467       if (array_impl::all_inside_range(index, starts, limits)) {
468         // Even though the bounds of result are different to our bounds, we're
469         // iterating in the same order. So we can simply write successive linear
470         // indices instead of recalculating a multi-dimensional index.
471         result.values_[slice_i++] = values_[i];
472       }
473     }
474     return result;
475   }
476 
477   // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
478   void UpdateSlice(const Array<T>& from,
479                    absl::Span<const int64> start_indices) {
480     CHECK_EQ(from.num_dimensions(), num_dimensions());
481     std::vector<int64> limit_indices;
482     std::transform(start_indices.begin(), start_indices.end(),
483                    from.dimensions().begin(), std::back_inserter(limit_indices),
484                    std::plus<int64>{});
485     std::vector<int64> index(sizes_.size());
486     int64 from_i = 0;
487     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
488       if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
489         // Even though the bounds of from are different to our bounds, we're
490         // iterating in the same order. So we can simply write successive linear
491         // indices instead of recalculating a multi-dimensional index.
492         values_[i] = from.values_[from_i++];
493       }
494     }
495   }
496 
497   // Performs an in-place reshape, modifying the dimensions but not the
498   // underlying data.
499   void Reshape(absl::Span<const int64> new_dimensions) {
500     int64 old_num_elements = num_elements();
501     sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
502     CHECK_EQ(num_elements(), old_num_elements);
503   }
504 
505   // Returns a string representation of the array suitable for debugging.
506   string ToString() const {
507     std::vector<string> pieces;
508     std::vector<int64> index(sizes_.size());
509     do {
510       // Emit leading spaces and opening square brackets
511       if (index.back() == 0) {
512         for (int64 i = sizes_.size() - 1; i >= 0; --i) {
513           if (i == 0 || index[i - 1] != 0) {
514             for (int64 j = 0; j < sizes_.size(); ++j) {
515               pieces.push_back(j < i ? " " : "[");
516             }
517             break;
518           }
519         }
520       }
521 
522       pieces.push_back(absl::StrCat(values_[calculate_index(index)]));
523 
524       // Emit comma if it isn't the last element
525       if (index.back() != sizes_.back() - 1) {
526         pieces.push_back(", ");
527       }
528 
529       // Emit closing square brackets
530       for (int64 i = sizes_.size() - 1; i >= 0; --i) {
531         if (index[i] != sizes_[i] - 1) {
532           break;
533         }
534         pieces.push_back("]");
535         if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) {
536           pieces.push_back(",\n");
537         }
538       }
539     } while (next_index(&index));
540     return absl::StrJoin(pieces, "");
541   }
542 
543  private:
544   // Converts an initializer_list of type U to a vector of type int64. Used by
545   // the initializer list based constructors to convert the size type into int64
546   // to be passed to the size based constructor.
547   template <typename U>
548   static std::vector<int64> ToInt64Vector(
549       const std::initializer_list<U>& data) {
550     return std::vector<int64>(data.begin(), data.end());
551   }
552 
553   // Returns the linear index from the list of per-dimension indexes. Function
554   // is templated so can be used with an std::array from operator() to avoid
555   // memory allocation.
556   template <typename U>
557   int64 calculate_index(const U& indexes) const {
558     CHECK_EQ(sizes_.size(), indexes.size());
559     int64 index = 0;
560     for (int64 i = 0; i < sizes_.size(); ++i) {
561       index *= sizes_[i];
562       index += indexes[i];
563     }
564     DCHECK_LT(index, this->num_elements());
565     return index;
566   }
567 
568   // Advances the specified set of indexes and returns true if we haven't
569   // wrapped around (i.e. result isn't {0, 0, ...}).
570   bool next_index(std::vector<int64>* index) const {
571     CHECK_EQ(index->size(), sizes_.size());
572     for (int64 i = sizes_.size() - 1; i >= 0; --i) {
573       (*index)[i]++;
574       if ((*index)[i] < sizes_[i]) {
575         return true;
576       }
577       (*index)[i] = 0;
578     }
579     return false;
580   }
581 
582   std::vector<int64> sizes_;
583   std::unique_ptr<T[]> values_;
584 };
585 
586 // Specialization of FillRandom() method for complex64 type. Uses real part of
587 // the stddev parameter as the standard deviation value.
588 template <>
589 void Array<complex64>::FillRandom(const complex64& stddev, const double mean,
590                                   const int seed);
591 
592 }  // namespace xla
593 
594 #endif  // TENSORFLOW_COMPILER_XLA_ARRAY_H_
595