• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/array2d.h"
32 #include "tensorflow/compiler/xla/array3d.h"
33 #include "tensorflow/compiler/xla/array4d.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/bitmap.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/macros.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/types.h"
48 
49 namespace xla {
50 
// Forward declarations. `Literal` is returned by value from many LiteralBase
// methods (Clone, Relayout, Slice, CreateFromShape, ...), and `LiteralSlice`
// is the source-argument type of the MutableLiteralBase copy methods
// (CopyFrom, CopySliceFrom, CopyElementFrom).
class Literal;
class LiteralSlice;
55 
56 // Abstract base class for literals.
57 class LiteralBase {
58  public:
59   virtual ~LiteralBase() = 0;
60 
61   // Literals are equal if they have compatible shapes and the same data
62   // values. Layout is not compared.
63   bool operator==(const LiteralBase& other) const;
64   bool operator!=(const LiteralBase& other) const { return !(*this == other); }
65 
66   // Returns the shape of the literal.
shape()67   const Shape& shape() const { return root_piece().subshape(); }
68 
69   // Serialize to proto.
70   LiteralProto ToProto() const;
71 
72   // Returns a Span of the array for this literal for the given NativeT
73   // (e.g., float). CHECKs if the subshape of the literal at the given
74   // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
75   // to native type.
76   template <typename NativeT>
77   absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
78 
79   // Returns a const pointer to (or size of) the underlying buffer holding the
80   // array at the given shape index. CHECKs if the subshape of the literal at
81   // the given ShapeIndex is not array.
82   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
83   int64 size_bytes(const ShapeIndex& shape_index = {}) const;
84 
85   // Returns this literal's data as a string. This literal must be a rank-1 U8
86   // array.
87   string GetR1U8AsString() const;
88 
89   // Returns a string representation of the literal value. The Shape of the
90   // literal is a prefix of the literal value in the string.
91 
92   // Warning: this function can take minutes for multi-million
93   // element Literals.
94   string ToString() const;
95 
96   // Returns a string representation of the literal value which does *not*
97   // include the shape string.
98   string ToStringWithoutShape() const;
99 
100   // Returns a string representation of the literal value which includes the
101   // shape string with its layout.does *not* include the shape string.
102   string ToStringWithLayout() const;
103 
104   // Gets an element in the literal at the given index. The multi_index is
105   // CHECKed against the dimension sizes.
106   template <typename NativeT>
107   NativeT Get(absl::Span<const int64> multi_index,
108               const ShapeIndex& shape_index) const;
109   // Overloads of Get for array literals. CHECKs if the literal is not
110   // array-shaped and dense.
111   template <typename NativeT>
112   NativeT Get(absl::Span<const int64> multi_index) const;
113 
114   // Returns the element value at index (0, ..., 0), however many zeroes are
115   // required for that index.
116   template <typename NativeT>
117   NativeT GetFirstElement() const;
118 
119   // As Get(), but determines the correct type and converts the value
120   // into text.
121   string GetAsString(absl::Span<const int64> multi_index,
122                      const ShapeIndex& shape_index = {}) const;
123 
124   // Return whether the value at the specified index is equal to the provided
125   // generic `value` (T must be an arithmetic type).
126   //
127   // Precondition: must be an array.
128   template <typename T>
129   typename std::enable_if<(std::is_arithmetic<T>::value ||
130                            std::is_same<T, Eigen::half>::value ||
131                            std::is_same<T, bfloat16>::value),
132                           bool>::type
IsEqualAt(absl::Span<const int64> multi_index,T value)133   IsEqualAt(absl::Span<const int64> multi_index, T value) const {
134     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
135       return *as_s64 == value;
136     }
137     complex128 as_complex128 = *GetAsComplex128(multi_index);
138     return as_complex128.imag() == 0 && as_complex128.real() == value;
139   }
140 
IsEqualAt(absl::Span<const int64> multi_index,complex128 value)141   bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
142     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
143       return *as_s64 == value.real() && value.imag() == 0;
144     }
145     auto as_complex128 = GetAsComplex128(multi_index);
146     return *as_complex128 == value;
147   }
148 
149   // As Get(), but determines the correct type and converts the value into
150   // int64.  This literal must be an array.
151   absl::optional<int64> GetIntegralAsS64(
152       absl::Span<const int64> multi_index) const;
153 
154   // As Get(), but determines the correct type, and converts the value into
155   // double. This literal must be an array.
156   absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
157 
158   // As Get(), but determines the correct type, and converts the value into
159   // complex128. All floating point types can be converted into complex128.
160   //
161   // This literal must be an array.
162   absl::optional<complex128> GetAsComplex128(
163       absl::Span<const int64> multi_index) const;
164 
165   // Invokes the "per cell" callback for each element in the provided
166   // literal with the element's indices and a string representation of
167   // the element's value.
168   //
169   // This function is useful if you want a polymorphic representation
170   // of the tensor's elements (turning it to a string for something
171   // like representation in a protobuf).
172   //
173   // This literal must have a dense layout.
174   void EachCellAsString(
175       const std::function<void(absl::Span<const int64> indices,
176                                const string& value)>& per_cell) const;
177   template <typename NativeT>
178   void EachCell(
179       std::function<void(absl::Span<const int64> indices, NativeT value)>
180           per_cell) const;
181 
182   // Returns whether every element in this literal is equal to value.
183   //
184   // value is an int8 because we expect this to be called with small
185   // compile-time constants (0, -1, etc.) and so that whatever value you pass
186   // can be represented exactly by floating-point types as small as 16 bits.
187   //
188   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
189   // are considered equal to true/false; other values are not considered equal
190   // to true. Also if this literal is not array-shaped false is returned.
191   bool IsAll(int8 value) const;
192 
193   // Like IsAll(const Literal&, int8), except we check whether the literal is
194   // equal to a particular floating-point number.
195   //
196   // If the literal is not a floating-point value, this always returns false.
197   //
198   // This casts value to the type of literal, then compares using ==.  The usual
199   // admonishments about floating-point equality checks apply.  We expect you to
200   // use this to check for values that can be expressed precisely as a float,
201   // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
202   bool IsAllFloat(float value) const;
203 
204   // Like IsAll(const Literal&, int8), except we check whether the literal is
205   // equal to a particular complex number.
206   //
207   // If the literal is not a complex value, this always returns false.
208   //
209   // This casts value to the type of literal, then compares using ==.  The usual
210   // admonishments about floating-point equality checks apply.  We expect you to
211   // use this to check for complex values that can be expressed precisely as
212   // float pairs e.g. (-0.5, 1.0).
213   //
214   // This literal must have a dense layout.
215   bool IsAllComplex(complex64 value) const;
216 
217   // Literal consists entirely of the first element of the literal.
218   bool IsAllFirst() const;
219 
220   // Literal consists entirely of an iota.
221   bool IsR1Iota() const;
222 
223   // Returns whether this literal is zero at the specified index. This literal
224   // must be an array with a dense layout.
225   bool IsZero(absl::Span<const int64> indices) const;
226 
227   // Returns the count of the elements in the array at the given shape index in
228   // this literal.
229   int64 element_count(const ShapeIndex& index = {}) const {
230     if (index.empty()) {
231       // Common case, avoid GetSubshape().
232       return ShapeUtil::ElementsIn(shape());
233     }
234     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
235   }
236 
237   // Compute a hash for this literal.
238   size_t Hash() const;
239 
240   // Converts this literal to the given shape. Returns an error is the
241   // conversion is not possible.
242   StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
243 
244   // Converts this literal to another primitive type using a bitcast
245   // conversion. The to and from primitive types must have the same bit
246   // width. Returns an error if the conversion is not possible. This literal
247   // must be array-shaped.
248   StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
249 
250   // Converts this literal to another primitive type. Returns an error if the
251   // conversion is not possible. This literal must be array-shaped.
252   StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
253 
254   // Clones the underlying buffers into a new Literal.
255   Literal Clone() const;
256 
257   // TODO(b/67651157): The methods below which perform computation on Literals
258   // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
259   // evaluator code which operates on Literals.
260   //
261   // Creates a new value that has the equivalent value as this
262   // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
263   // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
264   // minor-to-major dimension layout and the value in the cell at any given
265   // logical index (i0, i1) will be the same.
266   //
267   // For tuple shaped literals, shape_index should be used to select the inner
268   // array that the new layout applies to.
269   //
270   // Note: this is useful when the client wants to ensure that a value placed in
271   // the XLA allocation tracker has a particular layout; for efficiency
272   // purposes or avoiding unimplemented operation/layout combinations.
273   Literal Relayout(const Layout& new_layout,
274                    const ShapeIndex& shape_index = {}) const;
275 
276   // An overload of Relayout which changes the layout of the entire shape rather
277   // than being limited to a single array within the shape.
278   Literal Relayout(const Shape& shape_with_layout) const;
279 
280   // Creates a new literal by reshaping this literal to have the given
281   // dimensions. The total number of elements must not change; The
282   // implementation currently only supports monotonic dim0-major layouts.
283   // This literal must be an array.
284   StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
285 
286   // Creates a new literal by broadcasting this literal with `dimensions` to
287   // yield a literal of shape `result_shape`.
288   StatusOr<Literal> Broadcast(const Shape& result_shape,
289                               absl::Span<const int64> dimensions) const;
290 
291   // Creates a new literal by reordering the dimensions of this literal.
292   // The given `permutation` must be a permutation of the dimension numbers
293   // in the original literal, and it specifies the order of the new dimensions
294   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
295   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
296   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
297   // This literal must be an array.
298   Literal Transpose(absl::Span<const int64> permutation) const;
299 
300   // Creates a sub-array from this literal by extracting the indices
301   // [start_index, limit_index) of each dimension. The result literal has the
302   // same rank and layout as for the given literal. The number of indices in
303   // start_indices and limit_indices must be the rank of the literal, and the
304   // indices follow the order of the dimensions.
305   // This literal must be an array.
306   Literal Slice(absl::Span<const int64> start_indices,
307                 absl::Span<const int64> limit_indices) const;
308 
309   // Creates a literal with a prepended dimension with bound "times"; e.g. a
310   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
311   // literal replicated four times.
312   // This literal must be an array.
313   template <typename NativeT>
314   Literal Replicate(int64 times) const;
315 
316   // Creates a new Literal object with the shape specified as parameter.
317   // The content of the literal values is the default value of the primitive
318   // type of literal itself (0 for numeric types, and false for predicates).
319   //
320   // Note: It's an antipattern to use this method then immediately call
321   // MutableLiteralBase::Populate on the result (since that results in zero
322   // initialization, then reinitialization. Consider if a call to
323   // absl::make_unique<Literal>(shape), followed by the call to
324   // MutableLiteralBase::Populate can be used instead.
325   static Literal CreateFromShape(const Shape& shape);
326 
327  protected:
328   // A data structure representing a subshape at a particular ShapeIndex within
329   // the literal. For array-shaped ShapeIndexes, this data structure holds the
330   // pointer to the memory allocated for the array data.
331   class Piece {
332    public:
333     // Returns the buffer holding the array data for this piece as an array
334     // slice. This piece must be array-shaped.
335     template <typename NativeT>
336     absl::Span<const NativeT> data() const;
337     template <typename NativeT>
338     absl::Span<NativeT> data();
339 
340     // Returns the buffer holding the array data for this piece as a void*. This
341     // piece must be array-shaped.
342     void* untyped_data();
343     const void* untyped_data() const;
344 
345     // Gets or sets an element in the array at the given index. The multi_index
346     // is CHECKed against the dimension sizes of the array.  This piece must be
347     // array-shaped.
348     template <typename NativeT>
349     NativeT Get(absl::Span<const int64> index) const;
350     template <typename NativeT>
351     void Set(absl::Span<const int64> index, NativeT value);
352 
353     // Gets/sets the buffer holding the array data.
buffer()354     char* buffer() const { return buffer_; }
set_buffer(char * buffer)355     void set_buffer(char* buffer) { buffer_ = buffer; }
356 
357     // Gets or sets the subshape of this piece. This reference points to a
358     // subshape within the shape in the containing Literal (Literal::shape_).
subshape()359     const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)360     void set_subshape(const Shape* subshape) { subshape_ = subshape; }
361 
362     // Returns the size in bytes of the buffer holding the array data.
size_bytes()363     int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
364 
365     // Returns the number of elements in this piece's array.
element_count()366     int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
367 
368     // Returns the child piece at 'index' of this piece.
child(int64 index)369     Piece& child(int64 index) { return children_[index]; }
370 
371     // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)372     void emplace_back(Piece child_piece) {
373       children_.emplace_back(std::move(child_piece));
374     }
375 
376     // Returns the size of children pieces of this piece.
children_size()377     int64 children_size() { return children_.size(); }
378 
379     // Visitor functions that recursively traverses the piece and calls the
380     // given function at each child piece. The function has the type:
381     //    void (const ShapeIndex& index, const Piece& piece)
382     template <typename Fn>
ForEachSubpiece(const Fn & func)383     void ForEachSubpiece(const Fn& func) const {
384       ShapeIndex index;
385       return ForEachHelper(
386                  [&func](const ShapeIndex& index, const Piece& piece) {
387                    func(index, piece);
388                    return Status::OK();
389                  },
390                  *this, &index)
391           .IgnoreError();
392     }
393     // Same as above, but the function has the type:
394     //    Status (const ShapeIndex& index, const Piece& piece)
395     // The first non-OK return value is returned by the function.
396     template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)397     Status ForEachSubpieceWithStatus(const Fn& func) const {
398       ShapeIndex index;
399       return ForEachHelper(func, *this, &index);
400     }
401     // Same as above, but the function has the type:
402     //    Bool (const ShapeIndex& index, const Piece& piece)
403     // The first non-true return value is returned by the function.
404     template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)405     bool ForEachSubpieceWithBool(const Fn& func) const {
406       ShapeIndex index;
407       return ForEachHelperBool(func, *this, &index);
408     }
409     // Same as above, but the function has the type:
410     //    Void (const ShapeIndex& index, Piece& piece)
411     template <typename Fn>
ForEachMutableSubpiece(const Fn & func)412     void ForEachMutableSubpiece(const Fn& func) {
413       ShapeIndex index;
414       return ForEachMutableHelper(
415                  [&func](const ShapeIndex& index, Piece* piece) {
416                    func(index, piece);
417                    return Status::OK();
418                  },
419                  const_cast<xla::LiteralBase::Piece*>(this), &index)
420           .IgnoreError();
421     }
422     // Same as above, but the function has the type:
423     //    Status (const ShapeIndex& index, Piece& piece)
424     // The first non-OK return value is returned by the function.
425     template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)426     Status ForEachMutableSubpieceWithStatus(const Fn& func) {
427       ShapeIndex index;
428       return ForEachMutableHelper(
429           func, const_cast<xla::LiteralBase::Piece*>(this), &index);
430     }
431 
432     // Returns true if this piece and 'other' contain the same data. This piece
433     // and 'other' must be array-shaped and compatible.
434     bool EqualElements(const Piece& other) const;
435 
436     // Writes the shape and data (if array-shaped) into the given proto.
437     void WriteToProto(LiteralProto* proto) const;
438 
439     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
440     // and src must be compatible.
441     Status CopyFrom(const Piece& src);
442 
443     // Copies the data from the given proto into this piece. The shape of this
444     // piece must be equal (not just compatible) to the shape of the proto.
445     Status CopyFromProto(const LiteralProto& proto);
446 
447    private:
448     // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
449     // The first non-OK (or non-true) value is returned by the function.
450     // The callable 'func' has the same signature as described above in
451     // ForEachSubpiece*.
452     template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)453     Status ForEachHelper(const Fn& func, const Piece& piece,
454                          ShapeIndex* index) const {
455       TF_RETURN_IF_ERROR(func(*index, piece));
456       for (int64 i = 0; i < piece.children_.size(); ++i) {
457         index->push_back(i);
458         TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
459         index->pop_back();
460       }
461       return Status::OK();
462     }
463     template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)464     bool ForEachHelperBool(const Fn& func, const Piece& piece,
465                            ShapeIndex* index) const {
466       if (!func(*index, piece)) {
467         return false;
468       }
469       for (int64 i = 0; i < piece.children_.size(); ++i) {
470         index->push_back(i);
471         if (!ForEachHelperBool(func, piece.children_[i], index)) {
472           return false;
473         }
474         index->pop_back();
475       }
476       return true;
477     }
478     template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)479     Status ForEachMutableHelper(const Fn& func, Piece* piece,
480                                 ShapeIndex* index) {
481       TF_RETURN_IF_ERROR(func(*index, piece));
482       for (int64 i = 0; i < piece->children_.size(); ++i) {
483         index->push_back(i);
484         TF_RETURN_IF_ERROR(
485             ForEachMutableHelper(func, &piece->children_[i], index));
486         index->pop_back();
487       }
488       return Status::OK();
489     }
490 
491     // Recursive helper for EqualElements.
492     template <typename NativeT>
493     bool EqualElementsInternal(const Piece& other,
494                                std::vector<int64>* multi_index) const;
495 
496     // For array-shaped pieces, this is the buffer holding the literal data.
497     char* buffer_ = nullptr;
498 
499     // The shape of piece. This points into the shape of the containing Literal
500     // (Literal::shape_).
501     const Shape* subshape_ = nullptr;
502 
503     // Children pieces for tuple shaped pieces.
504     std::vector<Piece> children_ = {};
505   };  // class Piece
506 
piece(const ShapeIndex & shape_index)507   const Piece& piece(const ShapeIndex& shape_index) const {
508     Piece* piece = &const_cast<Piece&>(root_piece());
509     for (const auto i : shape_index) {
510       DCHECK_GE(i, 0);
511       DCHECK_LT(i, piece->children_size());
512       piece = &piece->child(i);
513     }
514     return *piece;
515   }
516 
517   // Returns the piece at the root of the shape.
518   virtual const Piece& root_piece() const = 0;
519 
520   // LiteralSlice and Literal must access Pieces of other Literals.
521   friend class MutableLiteralBase;
522   friend class LiteralSlice;
523   friend class BorrowingLiteral;
524 
525  private:
526   template <typename NativeT>
527   Literal SliceInternal(const Shape& result_shape,
528                         absl::Span<const int64> start_indices) const;
529 };
530 
531 // Abstract base class representing a mutable literal in XLA.
532 class MutableLiteralBase : public LiteralBase {
533  public:
534   virtual ~MutableLiteralBase() = 0;
535 
536   // Returns a Span view of the array for this literal for the
537   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
538   // given ShapeIndex is not array. See primitive_util.h for the mapping from
539   // XLA type to native type.
540   template <typename NativeT>
541   absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
542   // Unhide const method from parent class.
543   using LiteralBase::data;
544 
545   // TODO(b/67651157): Remove this accessor. Literal users should not be able to
546   // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()547   Shape* mutable_shape_do_not_use() { return shape_.get(); }
548 
549   // Returns a pointer to the underlying buffer holding the array at the given
550   // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
551   // is not array.
552   void* untyped_data(const ShapeIndex& shape_index = {});
553   // Unhide const method from parent class.
554   using LiteralBase::untyped_data;
555 
556   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
557   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
558   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
559   // rooted at 'src_shape_index', but need not be arrays.
560   Status CopyFrom(const LiteralSlice& src_literal,
561                   const ShapeIndex& dest_shape_index = {},
562                   const ShapeIndex& src_shape_index = {});
563 
564   // Copies the values from src_literal, starting at src_base shape indexes,
565   // to this literal, starting at dest_base, where the copy size in each
566   // dimension is specified by copy_size.
567   // The src_literal and this literal must have the same primitive type,
568   // src_base+copy_size must fit the source literal dimensions, as well as
569   // dest_base+copy_size must fit the destination literal dimensions.
570   // Note: if either src_literal or this literal contains dimensions with zero
571   // element, then copy_size must be 0 in these dimensions while the
572   // corresponding base indices being 0.
573   // This literal and 'src_literal' must be arrays.
574   Status CopySliceFrom(const LiteralSlice& src_literal,
575                        absl::Span<const int64> src_base,
576                        absl::Span<const int64> dest_base,
577                        absl::Span<const int64> copy_size);
578 
579   // Copies one element from src_literal[src_index] to (*this)[dest_index].
580   Status CopyElementFrom(const LiteralSlice& src_literal,
581                          absl::Span<const int64> src_index,
582                          absl::Span<const int64> dest_index);
583 
584   // Sets an element in the literal at the given index. The multi_index is
585   // CHECKed against the dimension sizes.
586   template <typename NativeT>
587   void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
588            NativeT value);
589   // Overloads of Set for array literals. CHECKs if the literal is not
590   // array-shaped and dense.
591   template <typename NativeT>
592   void Set(absl::Span<const int64> multi_index, NativeT value);
593 
594   // As Set(), but truncates `value` to the literal element type before storing.
595   // This literal must be an array.
596   Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
597 
598   // As Set(), but truncates `value` to the literal element type before storing.
599   // This literal must be an array.
600   Status SetFromDouble(absl::Span<const int64> multi_index, double value);
601 
602   // Populate this literal with the given values. Examples:
603   //
604   //   // Populate with floats.
605   //   Array2D<float> float_values = ...
606   //   literal.PopulateR2FromArray2D(values);
607   //
608   //   // Populate with int32s.
609   //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
610   //
611   // The shape and element type of this literal must match given values. For
612   // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
613   // array of S32.
614   template <typename NativeT>
615   void PopulateR1(absl::Span<const NativeT> values);
616   void PopulateR1(const tensorflow::core::Bitmap& values);
617   template <typename NativeT>
618   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
619   template <typename NativeT>
620   void PopulateFromArray(const Array<NativeT>& values);
621   template <typename NativeT>
622   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
623   template <typename NativeT>
624   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
625   template <typename NativeT>
626   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
627 
628   // Populates literal values by calling the generator function for every cell
629   // in this literal object.
630   //
631   // generator must be a callable of the type
632   // NativeT(absl::Span<int64> indexes) or compatible.
633   //
634   // This literal must have a dense layout.
635   template <typename NativeT, typename FnType>
636   Status Populate(const FnType& generator);
637 
638   // A parallel version of Populate(). This can be used if the generator is
639   // thread-safe and the values for the shape's different elements are
640   // independent.
641   template <typename NativeT, typename FnType>
642   Status PopulateParallel(const FnType& generator);
643 
644   // Fills this literal with the given value.
645   template <typename NativeT>
646   void PopulateWithValue(NativeT value);
647 
648   // This operation is the inverse of DecomposeTuple. The given elements are
649   // moved into the tuple elements of a new tuple-shaped Literal which is
650   // returned. Upon return, each of the Literals in 'elements' is set to a nil
651   // shape (empty tuple).
652   static Literal MoveIntoTuple(absl::Span<Literal> elements);
653 
654   // Serialize from a proto.
655   static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
656                                            bool prohibit_empty_literal = true);
657 
658  protected:
659   // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)660   Piece& piece(const ShapeIndex& shape_index) {
661     return const_cast<Piece&>(LiteralBase::piece(shape_index));
662   }
663 
root_piece()664   Piece& root_piece() const override { return *root_piece_; };
665 
666   // Internal template helper for the Literal::CopySliceFrom(), matching its
667   // arguments one by one.
668   template <typename NativeT>
669   Status CopySliceFromInternal(const LiteralBase& src_literal,
670                                absl::Span<const int64> src_base,
671                                absl::Span<const int64> dest_base,
672                                absl::Span<const int64> copy_size);
673 
674   // Utility structure which is used to create the optimal configuration for
675   // a ShapeUtil::ForEachIndex() scan across two literals.
676   struct StrideConfig {
677     StrideConfig(const Shape& source_shape, const Shape& dest_shape,
678                  absl::Span<const int64> dimensions);
679 
680     // The dimensions of the stride operation. Essentially every dimension
681     // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
682     // steps.
683     absl::Span<const int64> dimensions;
684     DimensionVector base;
685     DimensionVector step;
686     int64 minor_dimension = 0;
687     // The size of the strides for source and destination. One of the two
688     // (the one looping through its most minor dimension) will be 1, while
689     // the other will be the stride size at the dimension matching the other
690     // shape most minor dimension being scanned.
691     int64 dest_stride = 1;
692     int64 source_stride = 1;
693     // The size of the inner loop on the most minor dimension.
694     int64 minor_loop_size = 1;
695   };
696 
697   // Literal class always owns the shape. The parent class borrows this shape.
698   std::unique_ptr<Shape> shape_;
699 
700   Piece* root_piece_ = nullptr;
701 
702   // Implementation details shared between Populate() and PopulateParallel()
703   template <typename NativeT, typename FnType>
704   Status PopulateInternal(const FnType& generator, bool parallel);
705 
706   friend class LiteralBase;
707   friend class MutableBorrowingLiteral;
708 };
709 std::ostream& operator<<(std::ostream& out, const Literal& literal);
710 
711 // The underlying buffer and shape is always owned by this class.
712 class Literal : public MutableLiteralBase {
713  public:
Literal()714   Literal() : Literal(ShapeUtil::MakeNil()) {}
715 
716   // Create a literal of the given shape. The literal is allocated sufficient
717   // memory to hold the shape. Memory is uninitialized.
718   explicit Literal(const Shape& shape);
719   virtual ~Literal();
720 
721   // Literals are moveable, but not copyable. To copy a literal use
722   // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
723   // of literals which can be expensive.
724   Literal(const Literal& other) = delete;
725   Literal& operator=(const Literal& other) = delete;
726   Literal(Literal&& other);
727   // 'allocate_arrays' indicates whether to allocate memory for the arrays in
728   // the shape. If false, buffer pointers inside of the Literal::Pieces are set
729   // to nullptr.
730   Literal(const Shape& shape, bool allocate_arrays);
731   Literal& operator=(Literal&& other);
732 
733   // Similar to CopyFrom, but with move semantics. The subshape of this literal
734   // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
735   // (layouts and shapes must match), but need not be arrays. The memory
736   // allocated in this literal for the subshape at dest_shape_index is
737   // deallocated, and the respective buffers are replaced with those in
738   // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
739   virtual Status MoveFrom(Literal&& src_literal,
740                           const ShapeIndex& dest_shape_index = {});
741 
742   // Returns a vector containing the tuple elements of this Literal as separate
743   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
744   // elements are moved into the new Literals; no data is copied. Upon return
745   // this Literal is set to a nil shape (empty tuple)
746   std::vector<Literal> DecomposeTuple();
747 
748  private:
749   // Deallocate the buffers held by this literal.
750   void DeallocateBuffers();
751 
752   // Recursively sets the subshapes and buffers of all subpieces rooted at
753   // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
754   // the shape.
755   void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
756 };
757 
758 // The underlying buffer is not owned by this class and is always owned by
759 // others. The shape is not owned by this class and not mutable.
760 class MutableBorrowingLiteral : public MutableLiteralBase {
761  public:
762   virtual ~MutableBorrowingLiteral();
763 
MutableBorrowingLiteral()764   MutableBorrowingLiteral() : MutableLiteralBase() {}
765 
766   MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
767   MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
768 
769   // Implicit conversion constructors.
770   MutableBorrowingLiteral(MutableLiteralBase* literal);
771   MutableBorrowingLiteral(MutableBorrowingLiteral literal,
772                           const ShapeIndex& view_root);
773   MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
774 
775  private:
776   // Recursively copies the subtree from the `src_piece` at the given child
777   // index to the `dest_piece`. For buffers only the pointers are copied, but
778   // not the content.
779   void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
780                         Piece* dest_piece);
781 };
782 
783 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
784 // literal buffers always owned by others.
785 class LiteralSlice : public LiteralBase {
786  public:
LiteralSlice()787   LiteralSlice() : LiteralBase() {}
788 
789   // Implicit conversion constructors.
790   LiteralSlice(const LiteralBase& literal);
791   LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
792 
793  private:
root_piece()794   const Piece& root_piece() const override { return *root_piece_; };
795 
796   const Piece* root_piece_;  // Not owned.
797 };
798 
799 // A read-only Literal where the underlying buffers are never owned by this
800 // class.
801 class BorrowingLiteral : public LiteralBase {
802  public:
BorrowingLiteral()803   BorrowingLiteral() : LiteralBase() {}
804 
805   // 'src_buf_ptr' is not owned by this class and must outlive the
806   // lifetime of this class. It points to an appropriately sized buffer with
807   // data interpretered as indicated by 'shape'.
808   // This constructor is only used for array shapes.
809   BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
810   // Similar as above, except to be used for constructing non-nested tuples.
811   BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
812                    const Shape& shape);
813   // TODO(b/79707221): adding constructors for nested tuples as well.
814 
815  private:
816   // Recursively builds the subtree for the given piece and sets the subshapes
817   // of the given piece with the given shape.
818   void BuildPieceSubtree(const Shape& shape, Piece* piece);
819 
820   // Accessor for the root piece of this literal.
root_piece()821   const Piece& root_piece() const override { return root_piece_; };
822   Piece root_piece_;
823 
824   // Shape of this literal. Stored as unique_ptr such that the (default) move
825   // construction of this class would be trivially correct: the pointer to Shape
826   // root_piece_ stores will still point to the correct address.
827   std::unique_ptr<Shape> shape_;
828 };
829 
830 template <typename NativeT>
data()831 absl::Span<const NativeT> LiteralBase::Piece::data() const {
832   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
833   DCHECK_EQ(subshape().element_type(),
834             primitive_util::NativeToPrimitiveType<NativeT>())
835       << "Attempting to access "
836       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
837       << " type, but literal element type is "
838       << PrimitiveType_Name(subshape().element_type());
839   return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
840                                    element_count());
841 }
842 
843 template <typename NativeT>
data()844 absl::Span<NativeT> LiteralBase::Piece::data() {
845   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
846   DCHECK_EQ(subshape().element_type(),
847             primitive_util::NativeToPrimitiveType<NativeT>())
848       << "Attempting to access "
849       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
850       << " type, but literal element type is "
851       << PrimitiveType_Name(subshape().element_type());
852   return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
853                              element_count());
854 }
855 
856 template <typename NativeT>
Get(absl::Span<const int64> multi_index)857 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
858   CHECK(LayoutUtil::IsDenseArray(subshape()));
859   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
860       subshape(), multi_index)];
861 }
862 
863 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)864 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
865                              NativeT value) {
866   CHECK(LayoutUtil::IsDenseArray(subshape()));
867   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
868       subshape(), multi_index)] = value;
869 }
870 
871 template <typename NativeT>
data(const ShapeIndex & shape_index)872 absl::Span<const NativeT> LiteralBase::data(
873     const ShapeIndex& shape_index) const {
874   return piece(shape_index).data<NativeT>();
875 }
876 
877 template <typename NativeT>
data(const ShapeIndex & shape_index)878 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
879   return piece(shape_index).data<NativeT>();
880 }
881 
882 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)883 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
884                                 const ShapeIndex& shape_index) const {
885   return piece(shape_index).Get<NativeT>(multi_index);
886 }
887 
888 template <typename NativeT>
Get(absl::Span<const int64> multi_index)889 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
890   return root_piece().Get<NativeT>(multi_index);
891 }
892 
893 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)894 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
895                                     const ShapeIndex& shape_index,
896                                     NativeT value) {
897   return piece(shape_index).Set<NativeT>(multi_index, value);
898 }
899 
900 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)901 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
902                                     NativeT value) {
903   return root_piece().Set<NativeT>(multi_index, value);
904 }
905 
906 template <typename NativeT>
GetFirstElement()907 NativeT LiteralBase::GetFirstElement() const {
908   return data<NativeT>().at(0);
909 }
910 
911 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)912 void LiteralBase::EachCell(
913     std::function<void(absl::Span<const int64> indices, NativeT value)>
914         per_cell) const {
915   if (ShapeUtil::IsZeroElementArray(shape())) {
916     return;
917   }
918   std::vector<int64> indices(shape().rank(), 0);
919   do {
920     per_cell(indices, Get<NativeT>(indices));
921   } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
922 }
923 
924 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)925 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
926   CHECK(shape().IsArray());
927   CHECK_EQ(shape().rank(), 1);
928   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
929   CHECK_EQ(shape().element_type(),
930            primitive_util::NativeToPrimitiveType<NativeT>());
931   auto data_span = data<NativeT>();
932   std::copy(values.begin(), values.end(), data_span.begin());
933 }
934 
935 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)936 void MutableLiteralBase::PopulateR2(
937     std::initializer_list<std::initializer_list<NativeT>> values) {
938   CHECK(shape().IsArray());
939   CHECK_EQ(shape().rank(), 2);
940   CHECK_EQ(shape().element_type(),
941            primitive_util::NativeToPrimitiveType<NativeT>());
942 
943   const int64 dim0_size = values.size();
944   const int64 dim1_size = values.begin()->size();
945   CHECK_EQ(dim0_size, shape().dimensions(0));
946   CHECK_EQ(dim1_size, shape().dimensions(1));
947 
948   int64 dim0 = 0;
949   for (auto inner_list : values) {
950     int64 dim1 = 0;
951     for (auto value : inner_list) {
952       Set({dim0, dim1}, value);
953       ++dim1;
954     }
955     CHECK_EQ(dim1_size, dim1);
956     ++dim0;
957   }
958 }
959 
960 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)961 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
962   CHECK(shape().IsArray());
963   CHECK_EQ(shape().element_type(),
964            primitive_util::NativeToPrimitiveType<NativeT>());
965   CHECK_EQ(shape().rank(), values.num_dimensions());
966   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
967     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
968   }
969   values.Each([this](absl::Span<const int64> indices, NativeT value) {
970     this->Set(indices, value);
971   });
972 }
973 
974 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)975 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
976   PopulateFromArray(values);
977 }
978 
979 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)980 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
981   PopulateFromArray(values);
982 }
983 
984 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)985 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
986   PopulateFromArray(values);
987 }
988 
989 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)990 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
991                                             bool parallel) {
992   const Shape& this_shape = shape();
993   const int64 rank = this_shape.rank();
994   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
995   TF_RET_CHECK(this_shape.element_type() ==
996                primitive_util::NativeToPrimitiveType<NativeT>());
997   absl::Span<NativeT> literal_data = data<NativeT>();
998   if (rank > 0) {
999     StrideConfig stride_config(this_shape, this_shape,
1000                                AsInt64Slice(this_shape.dimensions()));
1001     int64 minor_dimension_size =
1002         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1003 
1004     auto init_function = [&](absl::Span<const int64> indexes) {
1005       DimensionVector minor_scan_indexes(rank, 0);
1006       const int64 index =
1007           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1008       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1009       for (int64 i = 0; i < minor_dimension_size; ++i) {
1010         minor_scan_indexes[stride_config.minor_dimension] = i;
1011         literal_data.at(index + i) = generator(minor_scan_indexes);
1012       }
1013     };
1014     if (parallel) {
1015       ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1016                                       stride_config.dimensions,
1017                                       stride_config.step, init_function);
1018     } else {
1019       ShapeUtil::ForEachIndex(
1020           this_shape, stride_config.base, stride_config.dimensions,
1021           stride_config.step,
1022           [&init_function](absl::Span<const int64> indexes) {
1023             init_function(indexes);
1024             return true;
1025           });
1026     }
1027   } else {
1028     // For scalars.
1029     literal_data.at(0) = generator({});
1030   }
1031   return Status::OK();
1032 }
1033 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1034 Status MutableLiteralBase::Populate(const FnType& generator) {
1035   return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1036 }
1037 
1038 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1039 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1040   return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1041 }
1042 
1043 template <typename NativeT>
PopulateWithValue(NativeT value)1044 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1045   CHECK(shape().IsArray());
1046   CHECK_EQ(shape().element_type(),
1047            primitive_util::NativeToPrimitiveType<NativeT>());
1048   for (NativeT& element : data<NativeT>()) {
1049     element = value;
1050   }
1051 }
1052 
1053 template <typename NativeT>
Replicate(int64 times)1054 Literal LiteralBase::Replicate(int64 times) const {
1055   DimensionVector bounds = {times};
1056   bounds.reserve(shape().dimensions_size() + 1);
1057   for (int64 bound : shape().dimensions()) {
1058     bounds.push_back(bound);
1059   }
1060   Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1061   int64 elements = ShapeUtil::ElementsIn(literal.shape());
1062   if (elements == 0) {
1063     return literal;
1064   }
1065 
1066   DimensionVector output_indices(bounds.size(), 0);
1067   absl::Span<const int64> input_indices = output_indices;
1068   input_indices.remove_prefix(1);
1069 
1070   bool done = false;
1071   while (!done) {
1072     const auto element = Get<NativeT>(input_indices);
1073     literal.Set<NativeT>(output_indices, element);
1074 
1075     done = true;
1076     for (int n = 0; n < output_indices.size(); ++n) {
1077       ++output_indices[n];
1078       if (output_indices[n] < bounds[n]) {
1079         done = false;
1080         break;
1081       }
1082       output_indices[n] = 0;
1083     }
1084   }
1085   return literal;
1086 }
1087 
1088 }  // namespace xla
1089 
1090 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1091