• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/array2d.h"
32 #include "tensorflow/compiler/xla/array3d.h"
33 #include "tensorflow/compiler/xla/array4d.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/sparse_index_array.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/bitmap.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace xla {
51 
52 // Forward declare Literal and LiteralSlice class to be used by the creation
53 // methods in the base class.
54 class Literal;
55 class LiteralSlice;
56 
57 // Abstract base class for literals.
58 class LiteralBase {
59  public:
60   virtual ~LiteralBase() = 0;
61 
62   // Literals are equal if they have compatible shapes and the same data
63   // values. Layout is not compared.
64   bool operator==(const LiteralBase& other) const;
65   bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66 
67   // Returns the shape of the literal.
shape()68   const Shape& shape() const { return root_piece().subshape(); }
69 
70   // Serialize to proto.
71   LiteralProto ToProto() const;
72 
73   // Returns a Span of the array for this literal for the given NativeT
74   // (e.g., float). CHECKs if the subshape of the literal at the given
75   // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76   // to native type.
77   template <typename NativeT>
78   absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79 
80   // Returns a const pointer to the sparse index array. Returns nullptr if the
81   // literal is not a sparse array.
82   const SparseIndexArray* sparse_indices(
83       const ShapeIndex& shape_index = {}) const;
84 
85   // Returns a const pointer to (or size of) the underlying buffer holding the
86   // array at the given shape index. CHECKs if the subshape of the literal at
87   // the given ShapeIndex is not array.
88   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
89   int64 size_bytes(const ShapeIndex& shape_index = {}) const;
90 
91   // Returns this literal's data as a string. This literal must be a rank-1 U8
92   // array.
93   string GetR1U8AsString() const;
94 
95   // Returns a string representation of the literal value. The Shape of the
96   // literal is a prefix of the literal value in the string.
97 
98   // Warning: this function can take minutes for multi-million
99   // element Literals.
100   string ToString() const;
101 
102   // Returns a string representation of the literal value which does *not*
103   // include the shape string.
104   string ToStringWithoutShape() const;
105 
106   // Returns a string representation of the literal value which includes the
107   // shape string with its layout.
108   string ToStringWithLayout() const;
109 
110   // Gets an element in the literal at the given index. The multi_index is
111   // CHECKed against the dimension sizes.
112   template <typename NativeT>
113   NativeT Get(absl::Span<const int64> multi_index,
114               const ShapeIndex& shape_index) const;
115   // Overloads of Get for array literals. CHECKs if the literal is not
116   // array-shaped and dense.
117   template <typename NativeT>
118   NativeT Get(absl::Span<const int64> multi_index) const;
119 
120   // Returns the element value at index (0, ..., 0), however many zeroes are
121   // required for that index.
122   template <typename NativeT>
123   NativeT GetFirstElement() const;
124 
125   // As Get(), but determines the correct type and converts the value
126   // into text.
127   string GetAsString(absl::Span<const int64> multi_index,
128                      const ShapeIndex& shape_index = {}) const;
129   // As GetSparseElement(), but determines the correct type and converts the
130   // value into text.
131   string GetSparseElementAsString(int64 sparse_element_number,
132                                   const ShapeIndex& shape_index = {}) const;
133   // As Get(), but determines the correct type and converts the value into
134   // int64.  This literal must be an array.
135   StatusOr<int64> GetIntegralAsS64(absl::Span<const int64> multi_index) const;
136 
137   // Returns the multi-index of the element in a sparse literal at the given
138   // sparse element number.  The sparse element number is the position within
139   // the sparse array's list of (index, value) pairs, and is checked against the
140   // total number of (index, value) pairs in the sparse array.
141   absl::Span<const int64> GetSparseIndex(
142       int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
143 
144   // Returns the value of the element in a sparse literal at the given sparse
145   // element number.  The sparse element number is the position within the
146   // sparse array's list of (index, value) pairs, and is checked against the
147   // total number of (index, value) pairs in the sparse array.
148   template <typename NativeT>
149   NativeT GetSparseElement(int64 sparse_element_number,
150                            const ShapeIndex& shape_index = {}) const;
151 
152   // Invokes the "per cell" callback for each element in the provided
153   // literal with the element's indices and a string representation of
154   // the element's value.
155   //
156   // This function is useful if you want a polymorphic representation
157   // of the tensor's elements (turning it to a string for something
158   // like representation in a protobuf).
159   //
160   // This literal must have a dense layout.
161   void EachCellAsString(
162       const std::function<void(absl::Span<const int64> indices,
163                                const string& value)>& per_cell) const;
164   template <typename NativeT>
165   void EachCell(
166       std::function<void(absl::Span<const int64> indices, NativeT value)>
167           per_cell) const;
168 
169   // Returns whether every element in this literal is equal to value.
170   //
171   // value is an int8 because we expect this to be called with small
172   // compile-time constants (0, -1, etc.) and so that whatever value you pass
173   // can be represented exactly by floating-point types as small as 16 bits.
174   //
175   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
176   // are considered equal to true/false; other values are not considered equal
177   // to true. Also if this literal is not array-shaped false is returned.
178   bool IsAll(int8 value) const;
179 
180   // Like IsAll(int8), except we check whether the literal is
181   // equal to a particular floating-point number.
182   //
183   // If the literal is not a floating-point value, this always returns false.
184   //
185   // This casts value to the type of literal, then compares using ==.  The usual
186   // admonishments about floating-point equality checks apply.  We expect you to
187   // use this to check for values that can be expressed precisely as a float,
188   // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
189   bool IsAllFloat(float value) const;
190 
191   // Like IsAll(int8), except we check whether the literal is
192   // equal to a particular complex number.
193   //
194   // If the literal is not a complex value, this always returns false.
195   //
196   // This casts value to the type of literal, then compares using ==.  The usual
197   // admonishments about floating-point equality checks apply.  We expect you to
198   // use this to check for complex values that can be expressed precisely as
199   // float pairs e.g. (-0.5, 1.0).
200   //
201   // This literal must have a dense layout.
202   bool IsAllComplex(complex64 value) const;
203 
204   // Literal consists entirely of the first element of the literal.
205   bool IsAllFirst() const;
206 
207   // Literal consists entirely of an iota.
208   bool IsR1Iota() const;
209 
210   // Returns whether this literal is zero at the specified index. This literal
211   // must be an array with a dense layout.
212   bool IsZero(absl::Span<const int64> indices) const;
213 
214   // Returns the count of the elements in the array at the given shape index in
215   // this literal.
216   int64 element_count(const ShapeIndex& index = {}) const {
217     if (index.empty()) {
218       // Common case, avoid GetSubshape().
219       return ShapeUtil::ElementsIn(shape());
220     }
221     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
222   }
223 
224   // Returns the count of the elements in the sparse array at the given shape
225   // index in this literal, which will be no larger than
226   // LayoutUtil::MaxSparseElements(GetSubshape(shape(), index).layout()).
227   int64 sparse_element_count() const;
228 
229   // Compute a hash for this literal.  This literal must not be a sparse tensor
230   // or a tuple containing a sparse tensor.
231   size_t Hash() const;
232 
233   // Converts this literal to the given shape. Returns an error if the
234   // conversion is not possible.
235   StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
236 
237   // Converts this literal to another primitive type using a bitcast
238   // conversion. The to and from primitive types must have the same bit
239   // width. Returns an error if the conversion is not possible. This literal
240   // must be array-shaped.
241   StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
242 
243   // Converts this literal to another primitive type. Returns an error if the
244   // conversion is not possible. This literal must be array-shaped.
245   StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
246 
247   // Clones the underlying buffers into a new Literal.
248   Literal Clone() const;
249 
250   // TODO(b/67651157): The methods below which perform computation on Literals
251   // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
252   // evaluator code which operates on Literals.
253   //
254   // Creates a new value that has the equivalent value as this
255   // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
256   // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
257   // minor-to-major dimension layout and the value in the cell at any given
258   // logical index (i0, i1) will be the same.
259   //
260   // For tuple shaped literals, shape_index should be used to select the inner
261   // array that the new layout applies to.
262   //
263   // Note: this is useful when the client wants to ensure that a value placed in
264   // the XLA allocation tracker has a particular layout; for efficiency
265   // purposes or avoiding unimplemented operation/layout combinations.
266   Literal Relayout(const Layout& new_layout,
267                    const ShapeIndex& shape_index = {}) const;
268 
269   // An overload of Relayout which changes the layout of the entire shape rather
270   // than being limited to a single array within the shape.
271   Literal Relayout(const Shape& shape_with_layout) const;
272 
273   // Creates a new literal by reshaping this literal to have the given
274   // dimensions. The total number of elements must not change; The
275   // implementation currently only supports monotonic dim0-major layouts.
276   // This literal must be an array.
277   StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
278 
279   // Creates a new literal by broadcasting this literal with `dimensions` to
280   // yield a literal of shape `result_shape`.
281   StatusOr<Literal> Broadcast(const Shape& result_shape,
282                               absl::Span<const int64> dimensions) const;
283 
284   // Creates a new literal by reordering the dimensions of this literal.
285   // The given `permutation` must be a permutation of the dimension numbers
286   // in the original literal, and it specifies the order of the new dimensions
287   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
288   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
289   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
290   // This literal must be an array.
291   Literal Transpose(absl::Span<const int64> permutation) const;
292 
293   // Creates a sub-array from this literal by extracting the indices
294   // [start_index, limit_index) of each dimension. The result literal has the
295   // same rank and layout as for the given literal. The number of indices in
296   // start_indices and limit_indices must be the rank of the literal, and the
297   // indices follow the order of the dimensions.
298   // This literal must be an array.
299   Literal Slice(absl::Span<const int64> start_indices,
300                 absl::Span<const int64> limit_indices) const;
301 
302   // Creates a literal with a prepended dimension with bound "times"; e.g. a
303   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
304   // literal replicated four times.
305   // This literal must be an array.
306   template <typename NativeT>
307   Literal Replicate(int64 times) const;
308 
309   // Creates a new Literal object with the shape specified as parameter.
310   // The content of the literal values is the default value of the primitive
311   // type of literal itself (0 for numeric types, and false for predicates).
312   //
313   // Note: It's an antipattern to use this method then immediately call
314   // MutableLiteralBase::Populate on the result (since that results in zero
315   // initialization, then reinitialization). Consider if a call to
316   // absl::make_unique<Literal>(shape), followed by the call to
317   // MutableLiteralBase::Populate can be used instead.
318   static Literal CreateFromShape(const Shape& shape);
319 
320  protected:
321   // A data structure representing a subshape at a particular ShapeIndex within
322   // the literal. For array-shaped ShapeIndexes, this data structure holds the
323   // pointer to the memory allocated for the array data.
324   class Piece {
325    public:
326     // Returns the buffer holding the array data for this piece as an array
327     // slice. This piece must be array-shaped.
328     template <typename NativeT>
329     absl::Span<const NativeT> data() const;
330     template <typename NativeT>
331     absl::Span<NativeT> data();
332 
333     // Returns the buffer holding the array data for this piece as a void*. This
334     // piece must be array-shaped.
335     void* untyped_data();
336     const void* untyped_data() const;
337 
338     // Gets or sets an element in the array at the given index. The multi_index
339     // is CHECKed against the dimension sizes of the array.  This piece must be
340     // array-shaped.
341     template <typename NativeT>
342     NativeT Get(absl::Span<const int64> index) const;
343     template <typename NativeT>
344     void Set(absl::Span<const int64> index, NativeT value);
345 
346     // Gets/sets the buffer holding the array data.
buffer()347     char* buffer() const { return buffer_; }
set_buffer(char * buffer)348     void set_buffer(char* buffer) { buffer_ = buffer; }
349 
350     // The array of multi-indices that provide the locations of non-zero
351     // elements in a sparse array.  Only used if
352     // LayoutUtil::IsSparseArray(shape()) is true.
sparse_indices()353     SparseIndexArray* sparse_indices() const { return sparse_indices_; }
set_sparse_indices(SparseIndexArray * sparse_indices)354     void set_sparse_indices(SparseIndexArray* sparse_indices) {
355       sparse_indices_ = sparse_indices;
356     }
357 
358     // Gets or sets the subshape of this piece. This reference points to a
359     // subshape within the shape in the containing Literal (Literal::shape_).
subshape()360     const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)361     void set_subshape(const Shape* subshape) { subshape_ = subshape; }
362 
363     // Returns the size in bytes of the buffer holding the array data.
size_bytes()364     int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
365 
366     // Returns the number of elements in this piece's array.
element_count()367     int64 element_count() const {
368       // If this is a sparse array, use the number of elements represented by
369       // the indices in the associated SparseIndexArray.
370       return LayoutUtil::IsSparseArray(subshape())
371                  ? sparse_indices()->index_count()
372                  : ShapeUtil::ElementsIn(subshape());
373     }
374 
375     // Returns the child piece at 'index' of this piece.
child(int64 index)376     Piece& child(int64 index) { return children_[index]; }
377 
378     // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)379     void emplace_back(Piece child_piece) {
380       children_.emplace_back(std::move(child_piece));
381     }
382 
383     // Returns the size of children pieces of this piece.
children_size()384     int64 children_size() { return children_.size(); }
385 
386     // Visitor functions that recursively traverses the piece and calls the
387     // given function at each child piece. The function has the type:
388     //    void (const ShapeIndex& index, const Piece& piece)
389     template <typename Fn>
ForEachSubpiece(const Fn & func)390     void ForEachSubpiece(const Fn& func) const {
391       ShapeIndex index;
392       return ForEachHelper(
393                  [&func](const ShapeIndex& index, const Piece& piece) {
394                    func(index, piece);
395                    return Status::OK();
396                  },
397                  *this, &index)
398           .IgnoreError();
399     }
400     // Same as above, but the function has the type:
401     //    Status (const ShapeIndex& index, const Piece& piece)
402     // The first non-OK return value is returned by the function.
403     template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)404     Status ForEachSubpieceWithStatus(const Fn& func) const {
405       ShapeIndex index;
406       return ForEachHelper(func, *this, &index);
407     }
408     // Same as above, but the function has the type:
409     //    Bool (const ShapeIndex& index, const Piece& piece)
410     // The first non-true return value is returned by the function.
411     template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)412     bool ForEachSubpieceWithBool(const Fn& func) const {
413       ShapeIndex index;
414       return ForEachHelperBool(func, *this, &index);
415     }
416     // Same as above, but the function has the type:
417     //    Void (const ShapeIndex& index, Piece& piece)
418     template <typename Fn>
ForEachMutableSubpiece(const Fn & func)419     void ForEachMutableSubpiece(const Fn& func) {
420       ShapeIndex index;
421       return ForEachMutableHelper(
422                  [&func](const ShapeIndex& index, Piece* piece) {
423                    func(index, piece);
424                    return Status::OK();
425                  },
426                  const_cast<xla::LiteralBase::Piece*>(this), &index)
427           .IgnoreError();
428     }
429     // Same as above, but the function has the type:
430     //    Status (const ShapeIndex& index, Piece& piece)
431     // The first non-OK return value is returned by the function.
432     template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)433     Status ForEachMutableSubpieceWithStatus(const Fn& func) {
434       ShapeIndex index;
435       return ForEachMutableHelper(
436           func, const_cast<xla::LiteralBase::Piece*>(this), &index);
437     }
438 
439     // Returns true if this piece and 'other' contain the same data. This piece
440     // and 'other' must be array-shaped and compatible.
441     bool EqualElements(const Piece& other) const;
442 
443     // Writes the shape and data (if array-shaped) into the given proto.
444     void WriteToProto(LiteralProto* proto) const;
445 
446     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
447     // and src must be compatible.
448     Status CopyFrom(const Piece& src);
449 
450     // Copies the data from the given proto into this piece. The shape of this
451     // piece must be equal (not just compatible) to the shape of the proto.
452     Status CopyFromProto(const LiteralProto& proto);
453 
454     // Sorts the elements in a sparse array.
455     void SortSparseElements();
456 
457    private:
458     // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
459     // The first non-OK (or non-true) value is returned by the function.
460     // The callable 'func' has the same signature as described above in
461     // ForEachSubpiece*.
462     template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)463     Status ForEachHelper(const Fn& func, const Piece& piece,
464                          ShapeIndex* index) const {
465       TF_RETURN_IF_ERROR(func(*index, piece));
466       for (int64 i = 0; i < piece.children_.size(); ++i) {
467         index->push_back(i);
468         TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
469         index->pop_back();
470       }
471       return Status::OK();
472     }
473     template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)474     bool ForEachHelperBool(const Fn& func, const Piece& piece,
475                            ShapeIndex* index) const {
476       if (!func(*index, piece)) {
477         return false;
478       }
479       for (int64 i = 0; i < piece.children_.size(); ++i) {
480         index->push_back(i);
481         if (!ForEachHelperBool(func, piece.children_[i], index)) {
482           return false;
483         }
484         index->pop_back();
485       }
486       return true;
487     }
488     template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)489     Status ForEachMutableHelper(const Fn& func, Piece* piece,
490                                 ShapeIndex* index) {
491       TF_RETURN_IF_ERROR(func(*index, piece));
492       for (int64 i = 0; i < piece->children_.size(); ++i) {
493         index->push_back(i);
494         TF_RETURN_IF_ERROR(
495             ForEachMutableHelper(func, &piece->children_[i], index));
496         index->pop_back();
497       }
498       return Status::OK();
499     }
500 
501     // Recursive helper for EqualElements.
502     template <typename NativeT>
503     bool EqualElementsInternal(const Piece& other,
504                                std::vector<int64>* multi_index) const;
505 
506     // Helper for SortSparseElements that has the element type as a template
507     // parameter.
508     template <typename NativeT>
509     void SortSparseElementsInternal();
510 
511     // For array-shaped pieces, this is the buffer holding the literal data.
512     char* buffer_ = nullptr;
513 
514     // For sparse arrays, this is the array of indices.
515     SparseIndexArray* sparse_indices_ = nullptr;
516 
517     // The shape of piece. This points into the shape of the containing Literal
518     // (Literal::shape_).
519     const Shape* subshape_ = nullptr;
520 
521     // Children pieces for tuple shaped pieces.
522     std::vector<Piece> children_ = {};
523   };  // class Piece
524 
piece(const ShapeIndex & shape_index)525   const Piece& piece(const ShapeIndex& shape_index) const {
526     Piece* piece = &const_cast<Piece&>(root_piece());
527     for (const auto i : shape_index) {
528       DCHECK_GE(i, 0);
529       DCHECK_LT(i, piece->children_size());
530       piece = &piece->child(i);
531     }
532     return *piece;
533   }
534 
535   // Returns the piece at the root of the shape.
536   virtual const Piece& root_piece() const = 0;
537 
538   // LiteralSlice and Literal must access Pieces of other Literals.
539   friend class MutableLiteralBase;
540   friend class LiteralSlice;
541   friend class BorrowingLiteral;
542 
543  private:
544   template <typename NativeT>
545   Literal SliceInternal(const Shape& result_shape,
546                         absl::Span<const int64> start_indices) const;
547 };
548 
549 // Abstract base class representing a mutable literal in XLA.
550 class MutableLiteralBase : public LiteralBase {
551  public:
552   virtual ~MutableLiteralBase() = 0;
553 
554   // Returns a Span view of the array for this literal for the
555   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
556   // given ShapeIndex is not array. See primitive_util.h for the mapping from
557   // XLA type to native type.
558   template <typename NativeT>
559   absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
560   // Unhide const method from parent class.
561   using LiteralBase::data;
562 
563   // Returns a pointer to the sparse index array. Returns nullptr if the literal
564   // is not a sparse array.
565   SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
566 
567   // TODO(b/67651157): Remove this accessor. Literal users should not be able to
568   // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()569   Shape* mutable_shape_do_not_use() { return shape_.get(); }
570 
571   // Returns a pointer to the underlying buffer holding the array at the given
572   // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
573   // is not array.
574   void* untyped_data(const ShapeIndex& shape_index = {});
575   // Unhide const method from parent class.
576   using LiteralBase::untyped_data;
577 
578   // Populates a literal with a sparse layout with the given indices and values.
579   // Each index in the indices array is CHECKed against the dimensions in the
580   // literal's shape.  If sort is true, then the indices and values will be
581   // sorted.  If sort is false, then the indices and values are assumed to
582   // already be in sorted order.  See CreateSparse for an example of how data
583   // are populated.
584   template <typename NativeT>
585   void PopulateSparse(SparseIndexArray indices,
586                       absl::Span<const NativeT> values, bool sort = true);
587 
588   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
589   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
590   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
591   // rooted at 'src_shape_index', but need not be arrays.
592   Status CopyFrom(const LiteralSlice& src_literal,
593                   const ShapeIndex& dest_shape_index = {},
594                   const ShapeIndex& src_shape_index = {});
595 
596   // Copies the values from src_literal, starting at src_base shape indexes,
597   // to this literal, starting at dest_base, where the copy size in each
598   // dimension is specified by copy_size.
599   // The src_literal and this literal must have the same primitive type,
600   // src_base+copy_size must fit the source literal dimensions, as well as
601   // dest_base+copy_size must fit the destination literal dimensions.
602   // Note: if either src_literal or this literal contains dimensions with zero
603   // element, then copy_size must be 0 in these dimensions while the
604   // corresponding base indices being 0.
605   // This literal and 'src_literal' must be arrays.
606   Status CopySliceFrom(const LiteralSlice& src_literal,
607                        absl::Span<const int64> src_base,
608                        absl::Span<const int64> dest_base,
609                        absl::Span<const int64> copy_size);
610 
611   // Copies one element from src_literal[src_index] to (*this)[dest_index].
612   Status CopyElementFrom(const LiteralSlice& src_literal,
613                          absl::Span<const int64> src_index,
614                          absl::Span<const int64> dest_index);
615 
616   // Sets an element in the literal at the given index. The multi_index is
617   // CHECKed against the dimension sizes.
618   template <typename NativeT>
619   void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
620            NativeT value);
621   // Overloads of Set for array literals. CHECKs if the literal is not
622   // array-shaped and dense.
623   template <typename NativeT>
624   void Set(absl::Span<const int64> multi_index, NativeT value);
625 
626   // Appends the given element to the literal.  If the elements are not appended
627   // in sorted order, then SortSparseElements should be called before calling
628   // other methods.  This literal must have a sparse layout.
629   template <typename NativeT>
630   void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
631                            const ShapeIndex& shape_index = {});
632 
633   // Sorts the elements in a sparse array.
634   void SortSparseElements(const ShapeIndex& shape_index = {});
635 
636   // As Set(), but truncates `value` to the literal element type before storing.
637   // This literal must be an array.
638   Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
639 
640   // Populate this literal with the given values. Examples:
641   //
642   //   // Populate with floats.
643   //   Array2D<float> float_values = ...
644   //   literal.PopulateR2FromArray2D(float_values);
645   //
646   //   // Populate with int32s.
647   //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
648   //
649   // The shape and element type of this literal must match given values. For
650   // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
651   // array of S32.
652   template <typename NativeT>
653   void PopulateR1(absl::Span<const NativeT> values);
654   void PopulateR1(const tensorflow::core::Bitmap& values);
655   template <typename NativeT>
656   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
657   template <typename NativeT>
658   void PopulateFromArray(const Array<NativeT>& values);
659   template <typename NativeT>
660   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
661   template <typename NativeT>
662   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
663   template <typename NativeT>
664   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
665 
666   // Populates literal values by calling the generator function for every cell
667   // in this literal object.
668   //
669   // generator must be a callable of the type
670   // NativeT(absl::Span<int64> indexes) or compatible.
671   //
672   // This literal must have a dense layout.
673   template <typename NativeT, typename FnType>
674   Status Populate(const FnType& generator);
675 
676   // A parallel version of Populate(). This can be used if the generator is
677   // thread-safe and the values for the shape's different elements are
678   // independent.
679   template <typename NativeT, typename FnType>
680   Status PopulateParallel(const FnType& generator);
681 
682   // Fills this literal with the given value.
683   template <typename NativeT>
684   void PopulateWithValue(NativeT value);
685 
686   // This operation is the inverse of DecomposeTuple. The given elements are
687   // moved into the tuple elements of a new tuple-shaped Literal which is
688   // returned. Upon return, each of the Literals in 'elements' is set to a nil
689   // shape (empty tuple).
690   static Literal MoveIntoTuple(absl::Span<Literal> elements);
691 
692   // Deserializes a Literal from a proto.
693   static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
694 
695  protected:
696   // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)697   Piece& piece(const ShapeIndex& shape_index) {
698     return const_cast<Piece&>(LiteralBase::piece(shape_index));
699   }
700 
root_piece()701   Piece& root_piece() const override { return *root_piece_; };
702 
703   // Internal template helper for the Literal::CopySliceFrom(), matching its
704   // arguments one by one.
705   template <typename NativeT>
706   Status CopySliceFromInternal(const LiteralBase& src_literal,
707                                absl::Span<const int64> src_base,
708                                absl::Span<const int64> dest_base,
709                                absl::Span<const int64> copy_size);
710 
711   // Utility structure which is used to create the optimal configuration for
712   // a ShapeUtil::ForEachIndex() scan across two literals.
713   struct StrideConfig {
714     StrideConfig(const Shape& source_shape, const Shape& dest_shape,
715                  absl::Span<const int64> dimensions);
716 
717     // The dimensions of the stride operation. Essentially every dimension
718     // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
719     // steps.
720     absl::Span<const int64> dimensions;
721     DimensionVector base;
722     DimensionVector step;
723     int64 minor_dimension = 0;
724     // The size of the strides for source and destination. One of the two
725     // (the one looping through its most minor dimension) will be 1, while
726     // the other will be the stride size at the dimension matching the other
727     // shape most minor dimension being scanned.
728     int64 dest_stride = 1;
729     int64 source_stride = 1;
730     // The size of the inner loop on the most minor dimension.
731     int64 minor_loop_size = 1;
732   };
733 
734   // Literal class always owns the shape. The parent class borrows this shape.
735   std::unique_ptr<Shape> shape_;
736 
737   Piece* root_piece_ = nullptr;
738 
739   // Implementation details shared between Populate() and PopulateParallel()
740   template <typename NativeT, typename FnType>
741   Status PopulateInternal(const FnType& generator, bool parallel);
742 
743   friend class LiteralBase;
744   friend class MutableBorrowingLiteral;
745 };
746 std::ostream& operator<<(std::ostream& out, const Literal& literal);
747 
748 // The underlying buffer and shape is always owned by this class.
749 class Literal : public MutableLiteralBase {
750  public:
Literal()751   Literal() : Literal(ShapeUtil::MakeNil()) {}
752 
753   // Create a literal of the given shape. The literal is allocated sufficient
754   // memory to hold the shape. Memory is uninitialized.
755   explicit Literal(const Shape& shape);
756   virtual ~Literal();
757 
758   // Literals are moveable, but not copyable. To copy a literal use
759   // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
760   // of literals which can be expensive.
761   Literal(const Literal& other) = delete;
762   Literal& operator=(const Literal& other) = delete;
763   Literal(Literal&& other);
764   // 'allocate_arrays' indicates whether to allocate memory for the arrays in
765   // the shape. If false, buffer pointers inside of the Literal::Pieces are set
766   // to nullptr.
767   Literal(const Shape& shape, bool allocate_arrays);
768   Literal& operator=(Literal&& other);
769 
770   // Similar to CopyFrom, but with move semantincs. The subshape of this literal
771   // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
772   // (layouts and shapes must match), but need not be arrays. The memory
773   // allocated in this literal for the subshape at dest_shape_index is
774   // deallocated, and the respective buffers are replaced with those in
775   // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
776   virtual Status MoveFrom(Literal&& src_literal,
777                           const ShapeIndex& dest_shape_index = {});
778 
779   // Returns a vector containing the tuple elements of this Literal as separate
780   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
781   // elements are moved into the new Literals; no data is copied. Upon return
782   // this Literal is set to a nil shape (empty tuple)
783   std::vector<Literal> DecomposeTuple();
784 
785  private:
786   // Deallocate the buffers held by this literal.
787   void DeallocateBuffers();
788 
789   // Recursively sets the subshapes and buffers of all subpieces rooted at
790   // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
791   // the shape.
792   void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
793 };
794 
795 // The underlying buffer is not owned by this class and is always owned by
796 // others. The shape is not owned by this class and not mutable.
797 class MutableBorrowingLiteral : public MutableLiteralBase {
798  public:
799   virtual ~MutableBorrowingLiteral();
800 
MutableBorrowingLiteral()801   MutableBorrowingLiteral() : MutableLiteralBase() {}
802 
803   MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
804   MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
805 
806   // Implicit conversion constructors.
807   MutableBorrowingLiteral(const MutableLiteralBase& literal);
808   MutableBorrowingLiteral(MutableLiteralBase* literal);
809   MutableBorrowingLiteral(MutableBorrowingLiteral literal,
810                           const ShapeIndex& view_root);
811   MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
812 
813  private:
814   // Recursively copies the subtree from the `src_piece` at the given child
815   // index to the `dest_piece`. For buffers only the pointers are copied, but
816   // not the content.
817   void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
818                         Piece* dest_piece);
819 };
820 
821 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
822 // literal buffers always owned by others.
823 class LiteralSlice : public LiteralBase {
824  public:
LiteralSlice()825   LiteralSlice() : LiteralBase() {}
826 
827   // Implicit conversion constructors.
828   LiteralSlice(const LiteralBase& literal);
829   LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
830 
831  private:
root_piece()832   const Piece& root_piece() const override { return *root_piece_; };
833 
834   const Piece* root_piece_;  // Not owned.
835 };
836 
837 // A read-only Literal where the underlying buffers are never owned by this
838 // class.
839 class BorrowingLiteral : public LiteralBase {
840  public:
BorrowingLiteral()841   BorrowingLiteral() : LiteralBase() {}
842 
843   // 'src_buf_ptr' is not owned by this class and must outlive the
844   // lifetime of this class. It points to an appropirately sized buffer with
845   // data interpretered as indicated by 'shape'.
846   // This constructor is only used for array shapes.
847   BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
848   // Similar as above, except to be used for constructing non-nested tuples.
849   BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
850                    const Shape& shape);
851   // TODO(b/79707221): adding constructors for nested tuples as well.
852 
853  private:
854   // Recursively builds the subtree for the given piece and sets the subshapes
855   // of the given piece with the given shape.
856   void BuildPieceSubtree(const Shape& shape, Piece* piece);
857 
858   // Accessor for the root piece of this literal.
root_piece()859   const Piece& root_piece() const override { return root_piece_; };
860   Piece root_piece_;
861 
862   // Shape of this literal. Stored as unique_ptr such that the (default) move
863   // construction of this class would be trivially correct: the pointer to Shape
864   // root_piece_ stores will still point to the correct address.
865   std::unique_ptr<Shape> shape_;
866 };
867 
868 template <typename NativeT>
data()869 absl::Span<const NativeT> LiteralBase::Piece::data() const {
870   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
871   DCHECK_EQ(subshape().element_type(),
872             primitive_util::NativeToPrimitiveType<NativeT>())
873       << "Attempting to access "
874       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
875       << " type, but literal element type is "
876       << PrimitiveType_Name(subshape().element_type());
877   return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
878                                    element_count());
879 }
880 
881 template <typename NativeT>
data()882 absl::Span<NativeT> LiteralBase::Piece::data() {
883   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
884   DCHECK_EQ(subshape().element_type(),
885             primitive_util::NativeToPrimitiveType<NativeT>())
886       << "Attempting to access "
887       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
888       << " type, but literal element type is "
889       << PrimitiveType_Name(subshape().element_type());
890   return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
891                              element_count());
892 }
893 
894 template <typename NativeT>
Get(absl::Span<const int64> multi_index)895 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
896   CHECK(LayoutUtil::IsDenseArray(subshape()));
897   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
898       subshape(), multi_index)];
899 }
900 
901 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)902 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
903                              NativeT value) {
904   CHECK(LayoutUtil::IsDenseArray(subshape()));
905   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
906       subshape(), multi_index)] = value;
907 }
908 
909 template <typename NativeT>
data(const ShapeIndex & shape_index)910 absl::Span<const NativeT> LiteralBase::data(
911     const ShapeIndex& shape_index) const {
912   return piece(shape_index).data<NativeT>();
913 }
914 
915 template <typename NativeT>
data(const ShapeIndex & shape_index)916 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
917   return piece(shape_index).data<NativeT>();
918 }
919 
920 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)921 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
922                                 const ShapeIndex& shape_index) const {
923   return piece(shape_index).Get<NativeT>(multi_index);
924 }
925 
926 template <typename NativeT>
Get(absl::Span<const int64> multi_index)927 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
928   return root_piece().Get<NativeT>(multi_index);
929 }
930 
931 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)932 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
933                                     const ShapeIndex& shape_index,
934                                     NativeT value) {
935   return piece(shape_index).Set<NativeT>(multi_index, value);
936 }
937 
938 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)939 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
940                                     NativeT value) {
941   return root_piece().Set<NativeT>(multi_index, value);
942 }
943 
944 template <typename NativeT>
GetFirstElement()945 NativeT LiteralBase::GetFirstElement() const {
946   return data<NativeT>().at(0);
947 }
948 
949 template <typename NativeT>
GetSparseElement(int64 sparse_element_number,const ShapeIndex & shape_index)950 NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
951                                       const ShapeIndex& shape_index) const {
952   CHECK(
953       LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
954   return data<NativeT>(shape_index)[sparse_element_number];
955 }
956 
957 template <typename NativeT>
AppendSparseElement(absl::Span<const int64> multi_index,NativeT value,const ShapeIndex & shape_index)958 void MutableLiteralBase::AppendSparseElement(
959     absl::Span<const int64> multi_index, NativeT value,
960     const ShapeIndex& shape_index) {
961   Piece& p = piece(shape_index);
962   const Shape& subshape = p.subshape();
963   CHECK(LayoutUtil::IsSparseArray(subshape));
964   int64 rank = subshape.rank();
965   CHECK_EQ(multi_index.size(), rank);
966   for (int64 i = 0; i < rank; ++i) {
967     CHECK_GE(multi_index[i], 0);
968     CHECK_LT(multi_index[i], subshape.dimensions(i));
969   }
970   int64 last_element = p.sparse_indices()->index_count();
971   CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
972   p.sparse_indices()->Append(multi_index);
973   CHECK_LT(last_element, p.data<NativeT>().size());
974   p.data<NativeT>()[last_element] = value;
975 }
976 
977 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)978 void LiteralBase::EachCell(
979     std::function<void(absl::Span<const int64> indices, NativeT value)>
980         per_cell) const {
981   if (ShapeUtil::IsZeroElementArray(shape())) {
982     return;
983   }
984   std::vector<int64> indices(shape().rank(), 0);
985   do {
986     per_cell(indices, Get<NativeT>(indices));
987   } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
988 }
989 
990 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)991 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
992   CHECK(shape().IsArray());
993   CHECK_EQ(shape().rank(), 1);
994   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
995   CHECK_EQ(shape().element_type(),
996            primitive_util::NativeToPrimitiveType<NativeT>());
997   auto data_span = data<NativeT>();
998   std::copy(values.begin(), values.end(), data_span.begin());
999 }
1000 
1001 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1002 void MutableLiteralBase::PopulateR2(
1003     std::initializer_list<std::initializer_list<NativeT>> values) {
1004   CHECK(shape().IsArray());
1005   CHECK_EQ(shape().rank(), 2);
1006   CHECK_EQ(shape().element_type(),
1007            primitive_util::NativeToPrimitiveType<NativeT>());
1008 
1009   const int64 dim0_size = values.size();
1010   const int64 dim1_size = values.begin()->size();
1011   CHECK_EQ(dim0_size, shape().dimensions(0));
1012   CHECK_EQ(dim1_size, shape().dimensions(1));
1013 
1014   int64 dim0 = 0;
1015   for (auto inner_list : values) {
1016     int64 dim1 = 0;
1017     for (auto value : inner_list) {
1018       Set({dim0, dim1}, value);
1019       ++dim1;
1020     }
1021     CHECK_EQ(dim1_size, dim1);
1022     ++dim0;
1023   }
1024 }
1025 
1026 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1027 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1028   CHECK(shape().IsArray());
1029   CHECK_EQ(shape().element_type(),
1030            primitive_util::NativeToPrimitiveType<NativeT>());
1031   CHECK_EQ(shape().rank(), values.num_dimensions());
1032   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1033     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1034   }
1035   values.Each([this](absl::Span<const int64> indices, NativeT value) {
1036     this->Set(indices, value);
1037   });
1038 }
1039 
1040 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1041 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1042   PopulateFromArray(values);
1043 }
1044 
1045 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1046 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1047   PopulateFromArray(values);
1048 }
1049 
1050 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1051 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1052   PopulateFromArray(values);
1053 }
1054 
1055 template <typename NativeT>
PopulateSparse(SparseIndexArray indices,absl::Span<const NativeT> values,bool sort)1056 void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
1057                                         absl::Span<const NativeT> values,
1058                                         bool sort) {
1059   CHECK(LayoutUtil::IsSparseArray(shape()));
1060   int rank = shape().rank();
1061   CHECK_EQ(indices.rank(), rank);
1062   int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
1063   CHECK_LE(indices.max_indices(), max_elements);
1064   int64 num_elements = values.size();
1065   CHECK_LE(num_elements, max_elements);
1066   CHECK_EQ(num_elements, indices.index_count());
1067   auto root_data = root_piece().data<NativeT>();
1068   // Piece::data() returns a Span of size equal to the number of indices
1069   // in the SparseIndexArray. So there is no need to adjust the size of the data
1070   // here. It is enough to just copy the incoming values into the data buffer.
1071   std::copy(values.begin(), values.end(), root_data.begin());
1072   *this->root_piece().sparse_indices() = std::move(indices);
1073   if (sort) {
1074     auto root_data = this->root_piece().data<NativeT>();
1075     this->root_piece().sparse_indices()->SortWithValues(root_data);
1076   }
1077   DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
1078 }
1079 
1080 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)1081 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
1082                                             bool parallel) {
1083   const Shape& this_shape = shape();
1084   const int64 rank = this_shape.rank();
1085   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1086   TF_RET_CHECK(this_shape.element_type() ==
1087                primitive_util::NativeToPrimitiveType<NativeT>());
1088   absl::Span<NativeT> literal_data = data<NativeT>();
1089   if (rank > 0) {
1090     StrideConfig stride_config(this_shape, this_shape,
1091                                AsInt64Slice(this_shape.dimensions()));
1092     int64 minor_dimension_size =
1093         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1094 
1095     auto init_function = [&](absl::Span<const int64> indexes) {
1096       DimensionVector minor_scan_indexes(rank, 0);
1097       const int64 index =
1098           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1099       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1100       for (int64 i = 0; i < minor_dimension_size; ++i) {
1101         minor_scan_indexes[stride_config.minor_dimension] = i;
1102         literal_data.at(index + i) = generator(minor_scan_indexes);
1103       }
1104     };
1105     if (parallel) {
1106       ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1107                                       stride_config.dimensions,
1108                                       stride_config.step, init_function);
1109     } else {
1110       ShapeUtil::ForEachIndex(
1111           this_shape, stride_config.base, stride_config.dimensions,
1112           stride_config.step,
1113           [&init_function](absl::Span<const int64> indexes) {
1114             init_function(indexes);
1115             return true;
1116           });
1117     }
1118   } else {
1119     // For scalars.
1120     literal_data.at(0) = generator({});
1121   }
1122   return Status::OK();
1123 }
1124 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1125 Status MutableLiteralBase::Populate(const FnType& generator) {
1126   return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1127 }
1128 
1129 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1130 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1131   return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1132 }
1133 
1134 template <typename NativeT>
PopulateWithValue(NativeT value)1135 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1136   CHECK(shape().IsArray());
1137   CHECK_EQ(shape().element_type(),
1138            primitive_util::NativeToPrimitiveType<NativeT>());
1139   for (NativeT& element : data<NativeT>()) {
1140     element = value;
1141   }
1142 }
1143 
1144 template <typename NativeT>
Replicate(int64 times)1145 Literal LiteralBase::Replicate(int64 times) const {
1146   DimensionVector bounds = {times};
1147   bounds.reserve(shape().dimensions_size() + 1);
1148   for (int64 bound : shape().dimensions()) {
1149     bounds.push_back(bound);
1150   }
1151   Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1152   int64 elements = ShapeUtil::ElementsIn(literal.shape());
1153   if (elements == 0) {
1154     return literal;
1155   }
1156 
1157   DimensionVector output_indices(bounds.size(), 0);
1158   absl::Span<const int64> input_indices = output_indices;
1159   input_indices.remove_prefix(1);
1160 
1161   bool done = false;
1162   while (!done) {
1163     const auto element = Get<NativeT>(input_indices);
1164     literal.Set<NativeT>(output_indices, element);
1165 
1166     done = true;
1167     for (int n = 0; n < output_indices.size(); ++n) {
1168       ++output_indices[n];
1169       if (output_indices[n] < bounds[n]) {
1170         done = false;
1171         break;
1172       }
1173       output_indices[n] = 0;
1174     }
1175   }
1176   return literal;
1177 }
1178 
1179 }  // namespace xla
1180 
1181 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1182