• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/array2d.h"
33 #include "tensorflow/compiler/xla/array3d.h"
34 #include "tensorflow/compiler/xla/array4d.h"
35 #include "tensorflow/compiler/xla/index_util.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/bitmap.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace xla {
51 
52 // Forward declare Literal and LiteralSlice class to be used by the creation
53 // methods in the base class.
54 class Literal;
55 class LiteralSlice;
56 
57 // Abstract base class for literals.
58 class LiteralBase {
59  public:
60   virtual ~LiteralBase() = 0;
61 
62   // Literals are equal if they have compatible shapes and the same data
63   // values. Layout is not compared.
64   bool operator==(const LiteralBase& other) const;
65   bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66 
67   // Returns the shape of the literal.
shape()68   const Shape& shape() const { return root_piece().subshape(); }
69 
70   // Serialize to proto.
71   LiteralProto ToProto() const;
72 
73   // Returns a Span of the array for this literal for the given NativeT
74   // (e.g., float). CHECKs if the subshape of the literal at the given
75   // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76   // to native type.
77   template <typename NativeT>
78   absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79 
80   // Returns a const pointer to (or size of) the underlying buffer holding the
81   // array at the given shape index. CHECKs if the subshape of the literal at
82   // the given ShapeIndex is not array.
83   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
84   int64 size_bytes(const ShapeIndex& shape_index = {}) const;
85 
86   // Returns this literal's data as a string. This literal must be a rank-1 U8
87   // array.
88   string GetR1U8AsString() const;
89 
90   // Returns a string representation of the literal value. The Shape of the
91   // literal is a prefix of the literal value in the string.
92 
93   // Warning: this function can take minutes for multi-million
94   // element Literals.
95   string ToString() const;
96 
97   // Similar to ToString, but return the result in a compact
98   // one-line form.
99   string ToStringOneline() const;
100 
101   // Returns a string representation of the literal value which does *not*
102   // include the shape string.
103   string ToStringWithoutShape() const;
104 
105   // Similar to ToStringWithoutShape, but return the result in a compact
106   // one-line form.
107   string ToStringWithoutShapeOneline() const;
108 
109   // Returns a string representation of the literal value which includes the
110   // shape string with its layout.does *not* include the shape string.
111   string ToStringWithLayout() const;
112 
113   // Similar to ToStringWithLayout, but return the result in a compact
114   // one-line form.
115   string ToStringWithLayoutOneline() const;
116 
117   // Gets an element in the literal at the given index. The multi_index is
118   // CHECKed against the dimension sizes.
119   template <typename NativeT>
120   NativeT Get(absl::Span<const int64> multi_index,
121               const ShapeIndex& shape_index) const;
122   // Overloads of Get for array literals. CHECKs if the literal is not
123   // array-shaped and dense.
124   template <typename NativeT>
125   NativeT Get(absl::Span<const int64> multi_index) const;
126 
127   // Get the dynamic size on dim_index in the literal at the given shape_index.
128   int32 GetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index) const;
129   int32 GetDynamicSize(int64_t dim_index) const;
130 
131   // Returns the element value at index (0, ..., 0), however many zeroes are
132   // required for that index.
133   template <typename NativeT>
134   NativeT GetFirstElement() const;
135 
136   // As above but returns any integer type casted to an int64.
137   absl::optional<int64> GetFirstInteger() const;
138 
139   // As Get(), but determines the correct type and converts the value
140   // into text.
141   string GetAsString(absl::Span<const int64> multi_index,
142                      const ShapeIndex& shape_index = {}) const;
143 
144   // Return whether the value at the specified index is equal to the provided
145   // generic `value` (T must be an arithmetic type).
146   //
147   // Precondition: must be an array.
148   template <typename T>
149   typename std::enable_if<(std::is_arithmetic<T>::value ||
150                            std::is_same<T, Eigen::half>::value ||
151                            std::is_same<T, bfloat16>::value),
152                           bool>::type
IsEqualAt(absl::Span<const int64> multi_index,T value)153   IsEqualAt(absl::Span<const int64> multi_index, T value) const {
154     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
155       return *as_s64 == value;
156     }
157     complex128 as_complex128 = *GetAsComplex128(multi_index);
158     return as_complex128.imag() == 0 && as_complex128.real() == value;
159   }
160 
IsEqualAt(absl::Span<const int64> multi_index,complex128 value)161   bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
162     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
163       return *as_s64 == value.real() && value.imag() == 0;
164     }
165     auto as_complex128 = GetAsComplex128(multi_index);
166     return *as_complex128 == value;
167   }
168 
169   // As Get(), but determines the correct type and converts the value into
170   // int64.  This literal must be an array.
171   absl::optional<int64> GetIntegralAsS64(
172       absl::Span<const int64> multi_index) const;
173 
174   // As Get(), but determines the correct type, and converts the value into
175   // double. This literal must be an array.
176   absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
177 
178   // As Get(), but determines the correct type, and converts the value into
179   // complex128. All floating point types can be converted into complex128.
180   //
181   // This literal must be an array.
182   absl::optional<complex128> GetAsComplex128(
183       absl::Span<const int64> multi_index) const;
184 
185   // Invokes the "per cell" callback for each element in the provided
186   // literal with the element's indices and a string representation of
187   // the element's value.
188   //
189   // This function is useful if you want a polymorphic representation
190   // of the tensor's elements (turning it to a string for something
191   // like representation in a protobuf).
192   //
193   // This literal must have a dense layout.
194   void EachCellAsString(
195       const std::function<void(absl::Span<const int64> indices,
196                                const string& value)>& per_cell) const;
197   template <typename NativeT>
198   void EachCell(
199       std::function<void(absl::Span<const int64> indices, NativeT value)>
200           per_cell) const;
201 
202   // Returns whether every element in this literal is equal to value.
203   //
204   // value is an int8 because we expect this to be called with small
205   // compile-time constants (0, -1, etc.) and so that whatever value you pass
206   // can be represented exactly by floating-point types as small as 16 bits.
207   //
208   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
209   // are considered equal to true/false; other values are not considered equal
210   // to true. Also if this literal is not array-shaped false is returned.
211   bool IsAll(int8_t value) const;
212 
213   // Like IsAll(const Literal&, int8), except we check whether the literal is
214   // equal to a particular floating-point number.
215   //
216   // If the literal is not a floating-point value, this always returns false.
217   //
218   // This casts value to the type of literal, then compares using ==.  The usual
219   // admonishments about floating-point equality checks apply.  We expect you to
220   // use this to check for values that can be expressed precisely as a float,
221   // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
222   bool IsAllFloat(float value) const;
223 
224   // Like IsAll(const Literal&, int8), except we check whether the literal is
225   // equal to a particular complex number.
226   //
227   // If the literal is not a complex value, this always returns false.
228   //
229   // This casts value to the type of literal, then compares using ==.  The usual
230   // admonishments about floating-point equality checks apply.  We expect you to
231   // use this to check for complex values that can be expressed precisely as
232   // float pairs e.g. (-0.5, 1.0).
233   //
234   // This literal must have a dense layout.
235   bool IsAllComplex(complex64 value) const;
236 
237   // Literal consists entirely of the first element of the literal.
238   bool IsAllFirst() const;
239 
240   // Literal consists entirely of an iota.
241   bool IsR1Iota() const;
242 
243   // Returns the stride if the literal is a strided iota.
244   absl::optional<int64> IsR1StridedIota() const;
245 
246   // Returns whether this literal is zero at the specified index. This literal
247   // must be an array with a dense layout.
248   bool IsZero(absl::Span<const int64> indices) const;
249 
250   // Returns the count of the elements in the array at the given shape index in
251   // this literal.
252   int64 element_count(const ShapeIndex& index = {}) const {
253     if (index.empty()) {
254       // Common case, avoid GetSubshape().
255       return ShapeUtil::ElementsIn(shape());
256     }
257     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
258   }
259 
260   // Compute a hash for this literal.
261   size_t Hash() const;
262 
263   // Converts this literal to the given shape. Returns an error is the
264   // conversion is not possible.
265   StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
266 
267   // Converts this literal to another primitive type using a bitcast
268   // conversion. The to and from primitive types must have the same bit
269   // width. Returns an error if the conversion is not possible. This literal
270   // must be array-shaped.
271   StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
272 
273   // Converts this literal to another primitive type. Returns an error if the
274   // conversion is not possible. This literal must be array-shaped.
275   StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
276 
277   // Clones the underlying buffers into a new Literal.
278   Literal Clone() const;
279 
280   // TODO(b/67651157): The methods below which perform computation on Literals
281   // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
282   // evaluator code which operates on Literals.
283   //
284   // Creates a new value that has the equivalent value as this
285   // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
286   // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
287   // minor-to-major dimension layout and the value in the cell at any given
288   // logical index (i0, i1) will be the same.
289   //
290   // For tuple shaped literals, shape_index should be used to select the inner
291   // array that the new layout applies to.
292   //
293   // Note: this is useful when the client wants to ensure that a value placed in
294   // the XLA allocation tracker has a particular layout; for efficiency
295   // purposes or avoiding unimplemented operation/layout combinations.
296   Literal Relayout(const Layout& new_layout,
297                    const ShapeIndex& shape_index = {}) const;
298 
299   // An overload of Relayout which changes the layout of the entire shape rather
300   // than being limited to a single array within the shape.
301   Literal Relayout(const Shape& shape_with_layout) const;
302 
303   // Generate a new literal whose static sizes are equal to the previous
304   // literal's dynamic sizes.
305   Literal ToStatic() const;
306 
307   // Expand a static literal into a new one with a bounded dyanmic literal. The
308   // static dimensions of the original literal becomes dynamic dimensions of the
309   // new literal, where the argument `bounded_shape` becomes the bounded shape
310   // of the new literal.
311   //
312   // Precondition: bounded_shape.is_dynamic()
313   Literal ToBoundedDynamic(const Shape& bounded_shape) const;
314 
315   // Creates a new literal by reshaping this literal to have the given
316   // dimensions. The total number of elements must not change; The
317   // implementation currently only supports monotonic dim0-major layouts.
318   // This literal must be an array.
319   StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
320 
321   // Creates a new literal by broadcasting this literal with `dimensions` to
322   // yield a literal of shape `result_shape`.
323   StatusOr<Literal> Broadcast(const Shape& result_shape,
324                               absl::Span<const int64> dimensions) const;
325 
326   // Creates a new literal by reordering the dimensions of this literal.
327   // The given `permutation` must be a permutation of the dimension numbers
328   // in the original literal, and it specifies the order of the new dimensions
329   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
330   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
331   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
332   // This literal must be an array.
333   Literal Transpose(absl::Span<const int64> permutation) const;
334 
335   // Creates a sub-array from this literal by extracting the indices
336   // [start_index, limit_index) of each dimension. The result literal has the
337   // same rank and layout as for the given literal. The number of indices in
338   // start_indices and limit_indices must be the rank of the literal, and the
339   // indices follow the order of the dimensions.
340   // This literal must be an array.
341   Literal Slice(absl::Span<const int64> start_indices,
342                 absl::Span<const int64> limit_indices) const;
343 
344   // Creates a literal with a prepended dimension with bound "times"; e.g. a
345   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
346   // literal replicated four times.
347   // This literal must be an array.
348   template <typename NativeT>
349   Literal Replicate(int64_t times) const;
350 
351   // Creates a new Literal object with the shape specified as parameter.
352   // The content of the literal values is the default value of the primitive
353   // type of literal itself (0 for numeric types, and false for predicates).
354   //
355   // Note: It's an antipattern to use this method then immediately call
356   // MutableLiteralBase::Populate on the result (since that results in zero
357   // initialization, then reinitialization. Consider if a call to
358   // absl::make_unique<Literal>(shape), followed by the call to
359   // MutableLiteralBase::Populate can be used instead.
360   static Literal CreateFromShape(const Shape& shape);
361 
362  protected:
363   // A data structure representing a subshape at a particular ShapeIndex within
364   // the literal. For array-shaped ShapeIndexes, this data structure holds the
365   // pointer to the memory allocated for the array data.
366   class Piece {
367    public:
368     // Returns the buffer holding the array data for this piece as an array
369     // slice. This piece must be array-shaped.
370     template <typename NativeT>
371     absl::Span<const NativeT> data() const;
372     template <typename NativeT>
373     absl::Span<NativeT> data();
374 
375     // Returns the buffer holding the array data for this piece as a void*. This
376     // piece must be array-shaped.
377     void* untyped_data();
378     const void* untyped_data() const;
379 
380     // Gets or sets an element in the array at the given index. The multi_index
381     // is CHECKed against the dimension sizes of the array.  This piece must be
382     // array-shaped.
383     template <typename NativeT>
384     NativeT Get(absl::Span<const int64> index) const;
385     template <typename NativeT>
386     void Set(absl::Span<const int64> index, NativeT value);
387 
388     int32 GetDynamicSize(int64_t dim_index) const;
389     void SetDynamicSize(int64_t dim_index, int32_t size);
390     // Gets/sets the buffer holding the array data.
buffer()391     char* buffer() const { return buffer_; }
set_buffer(char * buffer)392     void set_buffer(char* buffer) { buffer_ = buffer; }
393 
394     // Gets/sets the buffer holding dynamic sizes.
dynamic_size_buffer()395     int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
set_dynamic_size_buffer(int32 * dynamic_size_buffer)396     void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
397       dynamic_size_buffer_ = dynamic_size_buffer;
398     }
399 
dynamic_size_buffer_bytes()400     int64 dynamic_size_buffer_bytes() const {
401       return subshape().dimensions_size() * sizeof(int32);
402     }
403 
404     // Gets or sets the subshape of this piece. This reference points to a
405     // subshape within the shape in the containing Literal (Literal::shape_).
subshape()406     const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)407     void set_subshape(const Shape* subshape) { subshape_ = subshape; }
408 
409     // Returns the size in bytes of the buffer holding the array data.
size_bytes()410     int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
411 
412     // Returns the number of elements in this piece's array.
element_count()413     int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
414 
415     // Returns the child piece at 'index' of this piece.
child(int64_t index)416     Piece& child(int64_t index) { return children_[index]; }
417 
418     // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)419     void emplace_back(Piece child_piece) {
420       children_.emplace_back(std::move(child_piece));
421     }
422 
423     // Returns the size of children pieces of this piece.
children_size()424     int64 children_size() { return children_.size(); }
425 
426     // Visitor functions that recursively traverses the piece and calls the
427     // given function at each child piece. The function has the type:
428     //    void (const ShapeIndex& index, const Piece& piece)
429     template <typename Fn>
ForEachSubpiece(const Fn & func)430     void ForEachSubpiece(const Fn& func) const {
431       ShapeIndex index;
432       return ForEachHelper(
433                  [&func](const ShapeIndex& index, const Piece& piece) {
434                    func(index, piece);
435                    return Status::OK();
436                  },
437                  *this, &index)
438           .IgnoreError();
439     }
440     // Same as above, but the function has the type:
441     //    Status (const ShapeIndex& index, const Piece& piece)
442     // The first non-OK return value is returned by the function.
443     template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)444     Status ForEachSubpieceWithStatus(const Fn& func) const {
445       ShapeIndex index;
446       return ForEachHelper(func, *this, &index);
447     }
448     // Same as above, but the function has the type:
449     //    Bool (const ShapeIndex& index, const Piece& piece)
450     // The first non-true return value is returned by the function.
451     template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)452     bool ForEachSubpieceWithBool(const Fn& func) const {
453       ShapeIndex index;
454       return ForEachHelperBool(func, *this, &index);
455     }
456     // Same as above, but the function has the type:
457     //    Void (const ShapeIndex& index, Piece& piece)
458     template <typename Fn>
ForEachMutableSubpiece(const Fn & func)459     void ForEachMutableSubpiece(const Fn& func) {
460       ShapeIndex index;
461       return ForEachMutableHelper(
462                  [&func](const ShapeIndex& index, Piece* piece) {
463                    func(index, piece);
464                    return Status::OK();
465                  },
466                  const_cast<xla::LiteralBase::Piece*>(this), &index)
467           .IgnoreError();
468     }
469     // Same as above, but the function has the type:
470     //    Status (const ShapeIndex& index, Piece& piece)
471     // The first non-OK return value is returned by the function.
472     template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)473     Status ForEachMutableSubpieceWithStatus(const Fn& func) {
474       ShapeIndex index;
475       return ForEachMutableHelper(
476           func, const_cast<xla::LiteralBase::Piece*>(this), &index);
477     }
478 
479     // Returns true if this piece and 'other' contain the same data. This piece
480     // and 'other' must be array-shaped and compatible. If a literal has dynamic
481     // shape, comparison is done only for the valid elements.
482     bool EqualElements(const Piece& other) const;
483 
484     // Returns true if this piece and other pieces have the same dynamic
485     // dimension sizes.
486     bool EqualDynamicSize(const Piece& other) const;
487 
488     // Writes the shape and data (if array-shaped) into the given proto.
489     void WriteToProto(LiteralProto* proto) const;
490 
491     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
492     // and src must be compatible. If only_dynamic_bound is true, only elements
493     // within dynamic bounds will be copied.
494     Status CopyFrom(const Piece& src, bool only_dynamic_bound);
495 
496     // Copies the data from the given proto into this piece. The shape of this
497     // piece must be equal (not just compatible) to the shape of the proto.
498     Status CopyFromProto(const LiteralProto& proto);
499 
500    private:
501     // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
502     // The first non-OK (or non-true) value is returned by the function.
503     // The callable 'func' has the same signature as described above in
504     // ForEachSubpiece*.
505     template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)506     Status ForEachHelper(const Fn& func, const Piece& piece,
507                          ShapeIndex* index) const {
508       TF_RETURN_IF_ERROR(func(*index, piece));
509       for (int64_t i = 0; i < piece.children_.size(); ++i) {
510         index->push_back(i);
511         TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
512         index->pop_back();
513       }
514       return Status::OK();
515     }
516     template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)517     bool ForEachHelperBool(const Fn& func, const Piece& piece,
518                            ShapeIndex* index) const {
519       if (!func(*index, piece)) {
520         return false;
521       }
522       for (int64_t i = 0; i < piece.children_.size(); ++i) {
523         index->push_back(i);
524         if (!ForEachHelperBool(func, piece.children_[i], index)) {
525           return false;
526         }
527         index->pop_back();
528       }
529       return true;
530     }
531     template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)532     Status ForEachMutableHelper(const Fn& func, Piece* piece,
533                                 ShapeIndex* index) {
534       TF_RETURN_IF_ERROR(func(*index, piece));
535       for (int64_t i = 0; i < piece->children_.size(); ++i) {
536         index->push_back(i);
537         TF_RETURN_IF_ERROR(
538             ForEachMutableHelper(func, &piece->children_[i], index));
539         index->pop_back();
540       }
541       return Status::OK();
542     }
543 
544     // Recursive helper for EqualElements.
545     template <typename NativeT>
546     bool EqualElementsInternal(const Piece& other,
547                                std::vector<int64>* multi_index) const;
548 
549     // Internal helper to copy elements from another given piece
550     template <typename NativeT>
551     void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
552 
553     // For array-shaped pieces, this is the buffer holding the literal data.
554     char* buffer_ = nullptr;
555 
556     int32* dynamic_size_buffer_ = nullptr;
557 
558     // The shape of piece. This points into the shape of the containing Literal
559     // (Literal::shape_).
560     const Shape* subshape_ = nullptr;
561 
562     // Children pieces for tuple shaped pieces.
563     std::vector<Piece> children_ = {};
564   };  // class Piece
565 
piece(const ShapeIndex & shape_index)566   const Piece& piece(const ShapeIndex& shape_index) const {
567     Piece* piece = &const_cast<Piece&>(root_piece());
568     for (const auto i : shape_index) {
569       DCHECK_GE(i, 0);
570       DCHECK_LT(i, piece->children_size());
571       piece = &piece->child(i);
572     }
573     return *piece;
574   }
575 
576   // Returns the piece at the root of the shape.
577   virtual const Piece& root_piece() const = 0;
578 
579   // LiteralSlice and Literal must access Pieces of other Literals.
580   friend class MutableLiteralBase;
581   friend class LiteralSlice;
582   friend class BorrowingLiteral;
583 
584  private:
585   template <typename NativeT>
586   Literal SliceInternal(const Shape& result_shape,
587                         absl::Span<const int64> start_indices) const;
588 };
589 
590 // Abstract base class representing a mutable literal in XLA.
591 class MutableLiteralBase : public LiteralBase {
592  public:
593   virtual ~MutableLiteralBase() = 0;
594 
595   // Returns a Span view of the array for this literal for the
596   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
597   // given ShapeIndex is not array. See primitive_util.h for the mapping from
598   // XLA type to native type.
599   template <typename NativeT>
600   absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
601   // Unhide const method from parent class.
602   using LiteralBase::data;
603 
604   // TODO(b/67651157): Remove this accessor. Literal users should not be able to
605   // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()606   Shape* mutable_shape_do_not_use() { return shape_.get(); }
607 
608   // Set the dynamic size on dim_index in the literal at the given shape_index.
609   void SetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index,
610                       int32_t size);
611   void SetDynamicSize(int64_t dim_index, int32_t size);
612 
613   // Returns a pointer to the underlying buffer holding the array at the given
614   // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
615   // is not array.
616   void* untyped_data(const ShapeIndex& shape_index = {});
617   // Unhide const method from parent class.
618   using LiteralBase::untyped_data;
619 
620   template <typename NativeT>
621   void MutableEachCell(
622       std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
623           per_cell);
624 
625   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
626   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
627   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
628   // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
629   // is true, only elements within dynamic bounds will be copied.
630   Status CopyFrom(const LiteralSlice& src_literal,
631                   const ShapeIndex& dest_shape_index = {},
632                   const ShapeIndex& src_shape_index = {},
633                   bool only_dynamic_bound = false);
634 
635   // Copies the values from src_literal, starting at src_base shape indexes,
636   // to this literal, starting at dest_base, where the copy size in each
637   // dimension is specified by copy_size.
638   // The src_literal and this literal must have the same primitive type,
639   // src_base+copy_size must fit the source literal dimensions, as well as
640   // dest_base+copy_size must fit the destination literal dimensions.
641   // Note: if either src_literal or this literal contains dimensions with zero
642   // element, then copy_size must be 0 in these dimensions while the
643   // corresponding base indices being 0.
644   // This literal and 'src_literal' must be arrays.
645   Status CopySliceFrom(const LiteralSlice& src_literal,
646                        absl::Span<const int64> src_base,
647                        absl::Span<const int64> dest_base,
648                        absl::Span<const int64> copy_size);
649 
650   // Copies one element from src_literal[src_index] to (*this)[dest_index].
651   Status CopyElementFrom(const LiteralSlice& src_literal,
652                          absl::Span<const int64> src_index,
653                          absl::Span<const int64> dest_index);
654 
655   // Sets an element in the literal at the given index. The multi_index is
656   // CHECKed against the dimension sizes.
657   template <typename NativeT>
658   void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
659            NativeT value);
660   // Overloads of Set for array literals. CHECKs if the literal is not
661   // array-shaped and dense.
662   template <typename NativeT>
663   void Set(absl::Span<const int64> multi_index, NativeT value);
664 
665   // As Set(), but truncates `value` to the literal element type before storing.
666   // This literal must be an array.
667   Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64_t value);
668 
669   // As Set(), but truncates `value` to the literal element type before storing.
670   // This literal must be an array.
671   Status SetFromDouble(absl::Span<const int64> multi_index, double value);
672 
673   // Populate this literal with the given values. Examples:
674   //
675   //   // Populate with floats.
676   //   Array2D<float> float_values = ...
  //   literal.PopulateR2FromArray2D(float_values);
678   //
679   //   // Populate with int32s.
680   //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
681   //
682   // The shape and element type of this literal must match given values. For
683   // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
684   // array of S32.
685   template <typename NativeT>
686   void PopulateR1(absl::Span<const NativeT> values);
687   void PopulateR1(const tensorflow::core::Bitmap& values);
688   template <typename NativeT>
689   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
690   template <typename NativeT>
691   void PopulateFromArray(const Array<NativeT>& values);
692   template <typename NativeT>
693   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
694   template <typename NativeT>
695   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
696   template <typename NativeT>
697   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
698 
699   // Populates literal values by calling the generator function for every cell
700   // in this literal object.
701   //
702   // generator must be a callable of the type
703   // NativeT(absl::Span<int64> indexes) or compatible.
704   //
705   // This literal must have a dense layout.
706   template <typename NativeT, typename FnType>
707   Status Populate(const FnType& generator);
708 
709   // A parallel version of Populate(). This can be used if the generator is
710   // thread-safe and the values for the shape's different elements are
711   // independent.
712   template <typename NativeT, typename FnType>
713   Status PopulateParallel(const FnType& generator);
714 
715   // Fills this literal with the given value.
716   template <typename NativeT>
717   void PopulateWithValue(NativeT value);
718 
719   // This operation is the inverse of DecomposeTuple. The given elements are
720   // moved into the tuple elements of a new tuple-shaped Literal which is
721   // returned. Upon return, each of the Literals in 'elements' is set to a nil
722   // shape (empty tuple).
723   static Literal MoveIntoTuple(absl::Span<Literal> elements);
724 
  // Creates a literal by deserializing it from a proto.
726   static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
727                                            bool prohibit_empty_literal = true);
728 
729  protected:
730   // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)731   Piece& piece(const ShapeIndex& shape_index) {
732     return const_cast<Piece&>(LiteralBase::piece(shape_index));
733   }
734 
root_piece()735   Piece& root_piece() const override { return *root_piece_; };
736 
737   // Internal template helper for the Literal::CopySliceFrom(), matching its
738   // arguments one by one.
739   template <typename NativeT>
740   Status CopySliceFromInternal(const LiteralBase& src_literal,
741                                absl::Span<const int64> src_base,
742                                absl::Span<const int64> dest_base,
743                                absl::Span<const int64> copy_size);
744 
// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
struct StrideConfig {
  // Builds the scan configuration for copying a region of size `dimensions`
  // from `source_shape` into `dest_shape`.
  StrideConfig(const Shape& source_shape, const Shape& dest_shape,
               absl::Span<const int64> dimensions);

  // The dimensions of the stride operation. Essentially every dimension
  // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
  // steps.
  // This is a non-owning view; the storage it refers to must outlive the
  // config.
  absl::Span<const int64> dimensions;
  // Per-dimension starting index of the scan.
  DimensionVector base;
  // Per-dimension increment of the scan.
  DimensionVector step;
  // The dimension traversed by the inner loop (the most minor dimension
  // being scanned).
  int64 minor_dimension = 0;
  // The size of the strides for source and destination. One of the two
  // (the one looping through its most minor dimension) will be 1, while
  // the other will be the stride size at the dimension matching the other
  // shape most minor dimension being scanned.
  int64 dest_stride = 1;
  int64 source_stride = 1;
  // The size of the inner loop on the most minor dimension.
  int64 minor_loop_size = 1;
};
767 
768   // Literal class always owns the shape. The parent class borrows this shape.
769   std::unique_ptr<Shape> shape_;
770 
771   Piece* root_piece_ = nullptr;
772 
773   // Implementation details shared between Populate() and PopulateParallel()
774   template <typename NativeT, typename FnType>
775   Status PopulateInternal(const FnType& generator, bool parallel);
776 
777   friend class LiteralBase;
778   friend class MutableBorrowingLiteral;
779 };
780 std::ostream& operator<<(std::ostream& out, const Literal& literal);
781 
// A literal whose underlying buffers and shape are always owned by this class.
class Literal : public MutableLiteralBase {
 public:
  // Creates a literal with a nil shape (empty tuple).
  Literal() : Literal(ShapeUtil::MakeNil()) {}

  // Create a literal of the given shape. The literal is allocated sufficient
  // memory to hold the shape. Memory is uninitialized.
  explicit Literal(const Shape& shape);
  virtual ~Literal();

  // Literals are movable, but not copyable. To copy a literal use
  // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
  // of literals, which can be expensive.
  Literal(const Literal& other) = delete;
  Literal& operator=(const Literal& other) = delete;
  // NOTE(review): the move operations are not declared noexcept; adding it
  // would also require updating their out-of-line definitions.
  Literal(Literal&& other);
  // 'allocate_arrays' indicates whether to allocate memory for the arrays in
  // the shape. If false, buffer pointers inside of the Literal::Pieces are set
  // to nullptr.
  Literal(const Shape& shape, bool allocate_arrays);
  Literal& operator=(Literal&& other);

  // Similar to CopyFrom, but with move semantics. The subshape of this literal
  // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
  // (layouts and shapes must match), but need not be arrays. The memory
  // allocated in this literal for the subshape at dest_shape_index is
  // deallocated, and the respective buffers are replaced with those in
  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
  virtual Status MoveFrom(Literal&& src_literal,
                          const ShapeIndex& dest_shape_index = {});

  // Returns a vector containing the tuple elements of this Literal as separate
  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
  // elements are moved into the new Literals; no data is copied. Upon return
  // this Literal is set to a nil shape (empty tuple).
  std::vector<Literal> DecomposeTuple();

  // Returns a subliteral specified by given shape_index. No data is copied, the
  // current literal becomes invalid after this function call.
  Literal SubLiteral(ShapeIndexView shape_index);

 private:
  // Deallocate the buffers held by this literal.
  void DeallocateBuffers();

  // Recursively sets the subshapes and buffers of all subpieces rooted at
  // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays
  // in the shape.
  void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};
832 
// A mutable literal whose underlying buffers are never owned by this class;
// they are always owned by others. The shape is likewise not owned by this
// class and is not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
 public:
  virtual ~MutableBorrowingLiteral();

  // Creates a literal that does not yet reference any piece or buffer.
  MutableBorrowingLiteral() : MutableLiteralBase() {}

  // NOTE(review): copies presumably share the borrowed buffers rather than
  // duplicating element data; confirm against the out-of-line definitions.
  MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
  MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);

  // Implicit conversion constructors.
  // Borrows from `literal`, which is not owned and must outlive this object.
  MutableBorrowingLiteral(MutableLiteralBase* literal);
  // Borrows the part of `literal` rooted at `view_root`.
  MutableBorrowingLiteral(MutableBorrowingLiteral literal,
                          const ShapeIndex& view_root);
  // Wraps the buffer at `src_buf_ptr`, interpreted according to `shape`.
  MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);

  // Create a literal from a list of buffers and a shape.
  // Produces a tuple literal if `shape` is a tuple type.
  MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);

 private:
  // Recursively copies the subtree from the `src_piece` at the given child
  // index to the `dest_piece`. For buffers only the pointers are copied, but
  // not the content.
  void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
                        Piece* dest_piece);
};
861 
862 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
863 // literal buffers always owned by others.
864 class LiteralSlice : public LiteralBase {
865  public:
LiteralSlice()866   LiteralSlice() : LiteralBase() {}
867 
868   // Implicit conversion constructors.
869   LiteralSlice(const LiteralBase& literal);
870   LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
871 
872  private:
root_piece()873   const Piece& root_piece() const override { return *root_piece_; };
874 
875   const Piece* root_piece_;  // Not owned.
876 };
877 
878 // A read-only Literal where the underlying buffers are never owned by this
879 // class.
880 class BorrowingLiteral : public LiteralBase {
881  public:
BorrowingLiteral()882   BorrowingLiteral() : LiteralBase() {}
883 
884   // 'src_buf_ptr' is not owned by this class and must outlive the
885   // lifetime of this class. It points to an appropriately sized buffer with
886   // data interpretered as indicated by 'shape'.
887   // This constructor is only used for array shapes.
888   BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
889   // Similar as above, except to be used for constructing non-nested tuples.
890   BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
891                    const Shape& shape);
892   // TODO(b/79707221): adding constructors for nested tuples as well.
893 
894  private:
895   // Recursively builds the subtree for the given piece and sets the subshapes
896   // of the given piece with the given shape.
897   void BuildPieceSubtree(const Shape& shape, Piece* piece);
898 
899   // Accessor for the root piece of this literal.
root_piece()900   const Piece& root_piece() const override { return root_piece_; };
901   Piece root_piece_;
902 
903   // Shape of this literal. Stored as unique_ptr such that the (default) move
904   // construction of this class would be trivially correct: the pointer to Shape
905   // root_piece_ stores will still point to the correct address.
906   std::unique_ptr<Shape> shape_;
907 };
908 
909 template <typename NativeT>
data()910 absl::Span<const NativeT> LiteralBase::Piece::data() const {
911   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
912   DCHECK_EQ(subshape().element_type(),
913             primitive_util::NativeToPrimitiveType<NativeT>())
914       << "Attempting to access "
915       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
916       << " type, but literal element type is "
917       << PrimitiveType_Name(subshape().element_type());
918   return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
919                                    element_count());
920 }
921 
922 template <typename NativeT>
data()923 absl::Span<NativeT> LiteralBase::Piece::data() {
924   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
925   DCHECK_EQ(subshape().element_type(),
926             primitive_util::NativeToPrimitiveType<NativeT>())
927       << "Attempting to access "
928       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
929       << " type, but literal element type is "
930       << PrimitiveType_Name(subshape().element_type());
931   return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
932                              element_count());
933 }
934 
935 template <typename NativeT>
Get(absl::Span<const int64> multi_index)936 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
937   CHECK(LayoutUtil::IsDenseArray(subshape())) << subshape();
938   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
939       subshape(), multi_index)];
940 }
941 
942 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)943 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
944                              NativeT value) {
945   CHECK(LayoutUtil::IsDenseArray(subshape()));
946   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
947       subshape(), multi_index)] = value;
948 }
949 
950 template <typename NativeT>
data(const ShapeIndex & shape_index)951 absl::Span<const NativeT> LiteralBase::data(
952     const ShapeIndex& shape_index) const {
953   return piece(shape_index).data<NativeT>();
954 }
955 
956 template <typename NativeT>
data(const ShapeIndex & shape_index)957 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
958   return piece(shape_index).data<NativeT>();
959 }
960 
961 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)962 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
963                                 const ShapeIndex& shape_index) const {
964   return piece(shape_index).Get<NativeT>(multi_index);
965 }
966 
967 template <typename NativeT>
Get(absl::Span<const int64> multi_index)968 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
969   return root_piece().Get<NativeT>(multi_index);
970 }
971 
972 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)973 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
974                                     const ShapeIndex& shape_index,
975                                     NativeT value) {
976   return piece(shape_index).Set<NativeT>(multi_index, value);
977 }
978 
979 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)980 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
981                                     NativeT value) {
982   return root_piece().Set<NativeT>(multi_index, value);
983 }
984 
985 template <typename NativeT>
GetFirstElement()986 NativeT LiteralBase::GetFirstElement() const {
987   return data<NativeT>().at(0);
988 }
989 
990 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)991 void LiteralBase::EachCell(
992     std::function<void(absl::Span<const int64> indices, NativeT value)>
993         per_cell) const {
994   if (ShapeUtil::IsZeroElementArray(shape())) {
995     return;
996   }
997   std::vector<int64> indices(shape().rank(), 0);
998 
999   Shape shape_dynamic = shape();
1000   for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1001     shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1002   }
1003   do {
1004     per_cell(indices, Get<NativeT>(indices));
1005   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1006 }
1007 
1008 template <typename NativeT>
MutableEachCell(std::function<NativeT (absl::Span<const int64> indices,NativeT value)> per_cell)1009 void MutableLiteralBase::MutableEachCell(
1010     std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
1011         per_cell) {
1012   if (ShapeUtil::IsZeroElementArray(shape())) {
1013     return;
1014   }
1015   std::vector<int64> indices(shape().rank(), 0);
1016 
1017   Shape shape_dynamic = shape();
1018   for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1019     shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1020   }
1021   do {
1022     Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
1023   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1024 }
1025 
1026 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)1027 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
1028   CHECK(shape().IsArray());
1029   CHECK_EQ(shape().rank(), 1);
1030   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
1031   CHECK_EQ(shape().element_type(),
1032            primitive_util::NativeToPrimitiveType<NativeT>());
1033   auto data_span = data<NativeT>();
1034   std::copy(values.begin(), values.end(), data_span.begin());
1035 }
1036 
1037 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1038 void MutableLiteralBase::PopulateR2(
1039     std::initializer_list<std::initializer_list<NativeT>> values) {
1040   CHECK(shape().IsArray());
1041   CHECK_EQ(shape().rank(), 2);
1042   CHECK_EQ(shape().element_type(),
1043            primitive_util::NativeToPrimitiveType<NativeT>());
1044 
1045   const int64_t dim0_size = values.size();
1046   const int64_t dim1_size = values.begin()->size();
1047   CHECK_EQ(dim0_size, shape().dimensions(0));
1048   CHECK_EQ(dim1_size, shape().dimensions(1));
1049 
1050   int64_t dim0 = 0;
1051   for (auto inner_list : values) {
1052     int64_t dim1 = 0;
1053     for (auto value : inner_list) {
1054       Set({dim0, dim1}, value);
1055       ++dim1;
1056     }
1057     CHECK_EQ(dim1_size, dim1);
1058     ++dim0;
1059   }
1060 }
1061 
1062 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1063 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1064   CHECK(shape().IsArray());
1065   CHECK_EQ(shape().element_type(),
1066            primitive_util::NativeToPrimitiveType<NativeT>());
1067   CHECK_EQ(shape().rank(), values.num_dimensions());
1068   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1069     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1070   }
1071   values.Each([this](absl::Span<const int64> indices, NativeT value) {
1072     this->Set(indices, value);
1073   });
1074 }
1075 
1076 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1077 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1078   PopulateFromArray(values);
1079 }
1080 
1081 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1082 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1083   PopulateFromArray(values);
1084 }
1085 
1086 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1087 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1088   PopulateFromArray(values);
1089 }
1090 
1091 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)1092 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
1093                                             bool parallel) {
1094   const Shape& this_shape = shape();
1095   const int64_t rank = this_shape.rank();
1096   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1097   TF_RET_CHECK(this_shape.element_type() ==
1098                primitive_util::NativeToPrimitiveType<NativeT>());
1099   absl::Span<NativeT> literal_data = data<NativeT>();
1100   if (rank > 0) {
1101     StrideConfig stride_config(this_shape, this_shape,
1102                                AsInt64Slice(this_shape.dimensions()));
1103     int64_t minor_dimension_size =
1104         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1105 
1106     auto init_function = [&](absl::Span<const int64> indexes) {
1107       DimensionVector minor_scan_indexes(rank, 0);
1108       const int64_t index =
1109           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1110       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1111       for (int64_t i = 0; i < minor_dimension_size; ++i) {
1112         minor_scan_indexes[stride_config.minor_dimension] = i;
1113         literal_data.at(index + i) = generator(minor_scan_indexes);
1114       }
1115     };
1116     if (parallel) {
1117       ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1118                                       stride_config.dimensions,
1119                                       stride_config.step, init_function);
1120     } else {
1121       ShapeUtil::ForEachIndex(
1122           this_shape, stride_config.base, stride_config.dimensions,
1123           stride_config.step,
1124           [&init_function](absl::Span<const int64> indexes) {
1125             init_function(indexes);
1126             return true;
1127           });
1128     }
1129   } else {
1130     // For scalars.
1131     literal_data.at(0) = generator({});
1132   }
1133   return Status::OK();
1134 }
1135 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1136 Status MutableLiteralBase::Populate(const FnType& generator) {
1137   return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1138 }
1139 
1140 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1141 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1142   return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1143 }
1144 
1145 template <typename NativeT>
PopulateWithValue(NativeT value)1146 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1147   CHECK(shape().IsArray());
1148   CHECK_EQ(shape().element_type(),
1149            primitive_util::NativeToPrimitiveType<NativeT>());
1150   for (NativeT& element : data<NativeT>()) {
1151     element = value;
1152   }
1153 }
1154 
1155 template <typename NativeT>
Replicate(int64_t times)1156 Literal LiteralBase::Replicate(int64_t times) const {
1157   DimensionVector bounds = {times};
1158   bounds.reserve(shape().dimensions_size() + 1);
1159   for (int64_t bound : shape().dimensions()) {
1160     bounds.push_back(bound);
1161   }
1162   Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1163   int64_t elements = ShapeUtil::ElementsIn(literal.shape());
1164   if (elements == 0) {
1165     return literal;
1166   }
1167 
1168   DimensionVector output_indices(bounds.size(), 0);
1169   absl::Span<const int64> input_indices = output_indices;
1170   input_indices.remove_prefix(1);
1171 
1172   bool done = false;
1173   while (!done) {
1174     const auto element = Get<NativeT>(input_indices);
1175     literal.Set<NativeT>(output_indices, element);
1176 
1177     done = true;
1178     for (int n = 0; n < output_indices.size(); ++n) {
1179       ++output_indices[n];
1180       if (output_indices[n] < bounds[n]) {
1181         done = false;
1182         break;
1183       }
1184       output_indices[n] = 0;
1185     }
1186   }
1187   return literal;
1188 }
1189 
1190 }  // namespace xla
1191 
1192 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1193