• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18 
19 #include <algorithm>
20 #include <functional>
21 #include <initializer_list>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <optional>
26 #include <ostream>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30 #include <vector>
31 
32 #include "absl/strings/string_view.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/array2d.h"
35 #include "tensorflow/compiler/xla/array3d.h"
36 #include "tensorflow/compiler/xla/array4d.h"
37 #include "tensorflow/compiler/xla/index_util.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/primitive_util.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/core/bitmap.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/protobuf.h"
49 
50 namespace xla {
51 
52 // Forward declare Literal and LiteralSlice class to be used by the creation
53 // methods in the base class.
54 class Literal;
55 class LiteralSlice;
56 
57 // Abstract base class for literals.
58 class LiteralBase {
59  public:
60   virtual ~LiteralBase() = 0;
61 
62   // Literals are equal if they have compatible shapes and the same data
63   // values. Layout is not compared.
64   bool operator==(const LiteralBase& other) const;
65   bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66 
67   // Returns the shape of the literal.
shape()68   const Shape& shape() const { return root_piece().subshape(); }
69 
70   // Serialize to proto.
71   LiteralProto ToProto() const;
72 
73   // Returns a Span of the array for this literal for the given NativeT
74   // (e.g., float). CHECKs if the subshape of the literal at the given
75   // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76   // to native type.
77   template <typename NativeT>
78   absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79 
80   // Returns a const pointer to (or size of) the underlying buffer holding the
81   // array at the given shape index. CHECKs if the subshape of the literal at
82   // the given ShapeIndex is not array.
83   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
84   int64_t size_bytes(const ShapeIndex& shape_index = {}) const;
85 
86   // Returns this literal's data as a string. This literal must be a rank-1 U8
87   // array.
88   std::string GetR1U8AsString() const;
89 
90   // Returns a string representation of the literal value. The Shape of the
91   // literal is a prefix of the literal value in the string.
92 
93   // Warning: this function can take minutes for multi-million
94   // element Literals.
95   std::string ToString() const;
96 
97   // Similar to ToString, but return the result in a compact
98   // one-line form.
99   std::string ToStringOneline() const;
100 
101   // Returns a string representation of the literal value which does *not*
102   // include the shape string.
103   std::string ToStringWithoutShape() const;
104 
105   // Similar to ToStringWithoutShape, but return the result in a compact
106   // one-line form.
107   std::string ToStringWithoutShapeOneline() const;
108 
109   // Returns a string representation of the literal value which includes the
110   // shape string with its layout.
111   std::string ToStringWithLayout() const;
112 
113   // Similar to ToStringWithLayout, but return the result in a compact
114   // one-line form.
115   std::string ToStringWithLayoutOneline() const;
116 
117   // Gets an element in the literal at the given index. The multi_index is
118   // CHECKed against the dimension sizes.
119   template <typename NativeT>
120   NativeT Get(absl::Span<const int64_t> multi_index,
121               const ShapeIndex& shape_index) const;
122   // Overloads of Get for array literals. CHECKs if the literal is not
123   // array-shaped and dense.
124   template <typename NativeT>
125   NativeT Get(absl::Span<const int64_t> multi_index) const;
126 
127   // Get the dynamic size on dim_index in the literal at the given shape_index.
128   int32_t GetDynamicSize(int64_t dim_index,
129                          const ShapeIndex& shape_index) const;
130   int32_t GetDynamicSize(int64_t dim_index) const;
131 
132   // Returns the element value at index (0, ..., 0), however many zeroes are
133   // required for that index.
134   template <typename NativeT>
135   NativeT GetFirstElement() const;
136 
137   // As above but returns any integer type casted to an int64_t.
138   std::optional<int64_t> GetFirstInteger() const;
139 
140   // As Get(), but determines the correct type and converts the value
141   // into text.
142   std::string GetAsString(absl::Span<const int64_t> multi_index,
143                           const ShapeIndex& shape_index = {}) const;
144 
145   // Return whether the value at the specified index is equal to the provided
146   // generic `value` (T must be an arithmetic type).
147   //
148   // Precondition: must be an array.
149   template <typename T>
150   typename std::enable_if<(std::is_arithmetic<T>::value ||
151                            std::is_same<T, Eigen::half>::value ||
152                            std::is_same<T, bfloat16>::value),
153                           bool>::type
IsEqualAt(absl::Span<const int64_t> multi_index,T value)154   IsEqualAt(absl::Span<const int64_t> multi_index, T value) const {
155     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
156       return *as_s64 == value;
157     }
158     complex128 as_complex128 = *GetAsComplex128(multi_index);
159     return as_complex128.imag() == 0 && as_complex128.real() == value;
160   }
161 
IsEqualAt(absl::Span<const int64_t> multi_index,complex128 value)162   bool IsEqualAt(absl::Span<const int64_t> multi_index,
163                  complex128 value) const {
164     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
165       return *as_s64 == value.real() && value.imag() == 0;
166     }
167     auto as_complex128 = GetAsComplex128(multi_index);
168     return *as_complex128 == value;
169   }
170 
171   // As Get(), but determines the correct type and converts the value into
172   // int64_t.  This literal must be an array.
173   std::optional<int64_t> GetIntegralAsS64(
174       absl::Span<const int64_t> multi_index) const;
175 
176   // As Get(), but determines the correct type, and converts the value into
177   // double. This literal must be an array.
178   std::optional<double> GetAsDouble(
179       absl::Span<const int64_t> multi_index) const;
180 
181   // As Get(), but determines the correct type, and converts the value into
182   // complex128. All floating point types can be converted into complex128.
183   //
184   // This literal must be an array.
185   std::optional<complex128> GetAsComplex128(
186       absl::Span<const int64_t> multi_index) const;
187 
188   // Invokes the "per cell" callback for each element in the provided
189   // literal with the element's indices and a string representation of
190   // the element's value.
191   //
192   // This function is useful if you want a polymorphic representation
193   // of the tensor's elements (turning it to a string for something
194   // like representation in a protobuf).
195   //
196   // This literal must have a dense layout.
197   void EachCellAsString(
198       const std::function<void(absl::Span<const int64_t> indices,
199                                const std::string& value)>& per_cell) const;
200   template <typename NativeT>
201   void EachCell(
202       std::function<void(absl::Span<const int64_t> indices, NativeT value)>
203           per_cell) const;
204 
205   // Checks whether all of this literal's values are equal to the given scalar
206   // literal.
207   //
208   // If `this` is not an array (e.g. it's a tuple), returns false.  This is
209   // simpler than trying to handle subshapes here, and it's almost always what
210   // you want.
211   //
212   // Preconditions:
213   //  - `scalar` is a scalar.
214   //  - `scalar` has the same element-type as `this`.
215   bool IsAll(const Literal& scalar) const;
216 
217   // Returns whether every element in this literal is equal to value.
218   //
219   // value is an int8_t because we expect this to be called with small
220   // compile-time constants (0, -1, etc.) and so that whatever value you pass
221   // can be represented exactly by floating-point types as small as 16 bits.
222   //
223   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
224   // are considered equal to true/false; other values are not considered equal
225   // to true.
226   //
227   // Returns false if this literal is not array-shaped.
228   bool IsAll(int8_t value) const;
229 
230   // Like IsAll(int8_t), except we check whether the literal is equal to a
231   // particular floating-point or complex number.
232   //
233   // Returns false if this literal is not a floating-point / complex value, or
234   // if it's not an array.
235   //
236   // This casts value to the type of literal, then compares using ==, with the
237   // caveat that NaNs are considered equal.  The usual admonishments about
238   // floating-point equality checks apply.  We expect you to use this to check
239   // for values that can be expressed precisely as a float, e.g. -0.5.
240   bool IsAllFloat(float value) const;
241   bool IsAllComplex(complex64 value) const;
242 
243   // Determines if this literal consists entirely of the first element of the
244   // literal.
245   //
246   // Returns false if this literal is not an array.
247   bool IsAllFirst() const;
248 
249   // Literal consists entirely of an iota.
250   bool IsR1Iota() const;
251 
252   // Returns the stride if the literal is a strided iota.
253   std::optional<int64_t> IsR1StridedIota() const;
254 
255   // Returns whether this literal is zero at the specified index. This literal
256   // must be an array with a dense layout.
257   bool IsZero(absl::Span<const int64_t> indices) const;
258 
259   // Returns the count of the elements in the array at the given shape index in
260   // this literal.
261   int64_t element_count(const ShapeIndex& index = {}) const {
262     if (index.empty()) {
263       // Common case, avoid GetSubshape().
264       return ShapeUtil::ElementsIn(shape());
265     }
266     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
267   }
268 
269   // Compute a hash for this literal.
270   template <typename H>
AbslHashValue(H state,const LiteralBase & value)271   friend H AbslHashValue(H state, const LiteralBase& value) {
272     return LiteralBase::Hash(std::move(state), value);
273   }
274 
275   template <typename H, bool kIsLayoutSensitive = true,
276             int64_t kByteLimit = std::numeric_limits<int64_t>::max()>
Hash(H state,const LiteralBase & literal)277   static H Hash(H state, const LiteralBase& literal) {
278     state =
279         Shape::Hash<H, kIsLayoutSensitive>(std::move(state), literal.shape());
280 
281     ShapeUtil::ForEachSubshape(
282         literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) {
283           if (!subshape.IsArray()) {
284             return;
285           }
286 
287           CHECK(LayoutUtil::IsDenseArray(subshape));
288           auto data = absl::MakeConstSpan(
289               static_cast<const char*>(literal.untyped_data(index)),
290               std::min(kByteLimit, literal.size_bytes(index)));
291           state = H::combine(std::move(state), data);
292         });
293 
294     return std::move(state);
295   }
296 
297   // Converts this literal to the given shape. Returns an error if the
298   // conversion is not possible.
299   StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
300 
301   // Converts this literal to another primitive type using a bitcast
302   // conversion. Returns an error if the conversion is not possible. This
303   // literal must be array-shaped.
304   StatusOr<Literal> BitcastConvert(const Shape& dest_shape) const;
305 
306   // Converts this literal to another primitive type. Returns an error if the
307   // conversion is not possible. This literal must be array-shaped.
308   StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
309 
310   // Clones the underlying buffers into a new Literal.
311   Literal Clone() const;
312   std::unique_ptr<Literal> CloneToUnique() const;
313 
314   // TODO(b/67651157): The methods below which perform computation on Literals
315   // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
316   // evaluator code which operates on Literals.
317   //
318   // Creates a new value that has the equivalent value as this
319   // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
320   // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
321   // minor-to-major dimension layout and the value in the cell at any given
322   // logical index (i0, i1) will be the same.
323   //
324   // For tuple shaped literals, shape_index should be used to select the inner
325   // array that the new layout applies to.
326   //
327   // Note: this is useful when the client wants to ensure that a value placed in
328   // the XLA allocation tracker has a particular layout; for efficiency
329   // purposes or avoiding unimplemented operation/layout combinations.
330   Literal Relayout(const Layout& new_layout,
331                    const ShapeIndex& shape_index = {}) const;
332 
333   // An overload of Relayout which changes the layout of the entire shape rather
334   // than being limited to a single array within the shape.
335   Literal Relayout(const Shape& shape_with_layout) const;
336 
337   // Generate a new literal whose static sizes are equal to the previous
338   // literal's dynamic sizes.
339   Literal ToStatic() const;
340 
341   // Expand a static literal into a new one with a bounded dynamic literal. The
342   // static dimensions of the original literal becomes dynamic dimensions of the
343   // new literal, where the argument `bounded_shape` becomes the bounded shape
344   // of the new literal.
345   //
346   // Precondition: bounded_shape.is_dynamic()
347   Literal ToBoundedDynamic(const Shape& bounded_shape) const;
348 
349   // Creates a new literal by reshaping this literal to have the given
350   // dimensions. The total number of elements must not change; The
351   // implementation currently only supports monotonic dim0-major layouts.
352   // This literal must be an array.
353   StatusOr<Literal> Reshape(absl::Span<const int64_t> dimensions) const;
354 
355   // Creates a new literal by broadcasting this literal with `dimensions` to
356   // yield a literal of shape `result_shape`.
357   StatusOr<Literal> Broadcast(const Shape& result_shape,
358                               absl::Span<const int64_t> dimensions) const;
359 
360   // Creates a new literal by reordering the dimensions of this literal.
361   // The given `permutation` must be a permutation of the dimension numbers
362   // in the original literal, and it specifies the order of the new dimensions
363   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
364   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
365   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
366   // This literal must be an array.
367   Literal Transpose(absl::Span<const int64_t> permutation) const;
368 
369   // Creates a sub-array from this literal by extracting the indices
370   // [start_index, limit_index) of each dimension. The result literal has the
371   // same rank and layout as for the given literal. The number of indices in
372   // start_indices and limit_indices must be the rank of the literal, and the
373   // indices follow the order of the dimensions.
374   // This literal must be an array.
375   Literal Slice(absl::Span<const int64_t> start_indices,
376                 absl::Span<const int64_t> limit_indices) const;
377 
378   // Creates a literal with a prepended dimension with bound "times"; e.g. a
379   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
380   // literal replicated four times.
381   // This literal must be an array.
382   template <typename NativeT>
383   Literal Replicate(int64_t times) const;
384 
385   // Returns true if the leaf arrays of the literal within the given shape index
386   // are all determined.
387   // See comments on ArrayValueState for detailed explanation.
388   bool IsDetermined(const ShapeIndex& shape_index = {}) const;
389 
390   // Returns true if the leaf arrays of the literal within the given shape index
391   // are all known.
392   // See comments on ArrayValueState for detailed explanation.
393   bool IsKnown(const ShapeIndex& shape_index = {}) const;
394 
395   // Creates a new Literal object with the shape specified as parameter.
396   // The content of the literal values is the default value of the primitive
397   // type of literal itself (0 for numeric types, and false for predicates).
398   //
399   // Note: It's an antipattern to use this method then immediately call
400   // MutableLiteralBase::Populate on the result (since that results in zero
401   // initialization, then reinitialization). Consider if a call to
402   // std::make_unique<Literal>(shape), followed by the call to
403   // MutableLiteralBase::Populate can be used instead.
404   static Literal CreateFromShape(const Shape& shape);
405 
406   // WARNING: These two functions are only supposed to be used by HloEvaluator.
407   // The rest of XLA assumes all literals are known.
408   // Similar to CreateFromShape() but marks all leaf arrays as unknown.
409   static Literal CreateFromShapeWithUnknownLeafArrays(const Shape& shape);
410   // Similar to CreateFromShape() but marks all leaf arrays as undetermined.
411   static Literal CreateFromShapeWithUndeterminedLeafArrays(const Shape& shape);
412 
413  protected:
414   // Array literals could be in one of the following three states:
415   //   1) Known: we have evaluated and known the value of the array literal.
416   //   2) Unknown: we have tried to evaluate the array literal, but its value
417   //               cannot be evaluated statically.
418   //   3) Undetermined: we haven't tried to evaluate the array literal.
419   //  Unknown and Undetermined states are only meant to be used within
420   //  HloEvaluator. The rest of XLA assumes array literals are all known.
421   //  Literals that are unknown or undetermined can be copied from, using
422   //  CopyFrom and Clone, or moved from using move constructor. Accessing values
423   //  of such literals causes undefined behavior.
424   enum class ArrayValueState { kKnown = 0, kUnknown = 1, kUndetermined = 2 };
425 
426   // A data structure representing a subshape at a particular ShapeIndex within
427   // the literal. For array-shaped ShapeIndexes, this data structure holds the
428   // pointer to the memory allocated for the array data.
429   class Piece {
430    public:
431     ArrayValueState get_array_value_state() const;
432     void set_array_value_state(ArrayValueState state);
433     // Returns the buffer holding the array data for this piece as an array
434     // slice. This piece must be array-shaped.
435     template <typename NativeT>
436     absl::Span<const NativeT> data() const;
437     template <typename NativeT>
438     absl::Span<NativeT> data();
439 
440     // Returns the buffer holding the array data for this piece as a void*. This
441     // piece must be array-shaped.
442     void* untyped_data();
443     const void* untyped_data() const;
444 
445     // Gets or sets an element in the array at the given index. The multi_index
446     // is CHECKed against the dimension sizes of the array.  This piece must be
447     // array-shaped.
448     template <typename NativeT>
449     NativeT Get(absl::Span<const int64_t> index) const;
450     template <typename NativeT>
451     void Set(absl::Span<const int64_t> index, NativeT value);
452 
453     int32_t GetDynamicSize(int64_t dim_index) const;
454     void SetDynamicSize(int64_t dim_index, int32_t size);
455     void AllocateBuffers();
456     void DeallocateBuffers();
457     // Gets/sets the buffer holding the array data.
buffer()458     const char* buffer() const { return std::visit(BufferVisitor{}, rep_); }
buffer()459     char* buffer() {
460       return const_cast<char*>(const_cast<const Piece*>(this)->buffer());
461     }
set_buffer(char * buffer)462     void set_buffer(char* buffer) {
463       CHECK(subshape_->IsArray());
464       auto* array_rep = std::holds_alternative<Uninitialized>(rep_)
465                             ? &rep_.emplace<ArrayRep>()
466                             : GetArrayRep();
467       DCHECK(array_rep);
468       array_rep->data = buffer;
469     }
MoveDataFrom(Piece & from)470     void MoveDataFrom(Piece& from) {
471       DCHECK(!std::holds_alternative<ArrayRep>(rep_));
472       DCHECK(!std::holds_alternative<TupleRep>(rep_));
473       if (auto* array_rep = from.GetArrayRep()) {
474         rep_.emplace<ArrayRep>().data = array_rep->data;
475       } else if (auto* inlined_rep = from.GetInlinedRep()) {
476         std::memcpy(rep_.emplace<InlinedRep>().data, inlined_rep->data,
477                     from.total_bytes());
478       }
479       from.rep_.emplace<Uninitialized>();
480     }
481 
482     // Gets/sets the buffer holding dynamic sizes.
dynamic_size_buffer()483     const int32_t* dynamic_size_buffer() const {
484       return reinterpret_cast<const int32_t*>(buffer() + size_bytes());
485     }
dynamic_size_buffer()486     int32_t* dynamic_size_buffer() {
487       return const_cast<int32_t*>(
488           const_cast<const Piece*>(this)->dynamic_size_buffer());
489     }
490 
dynamic_size_buffer_bytes()491     int64_t dynamic_size_buffer_bytes() const {
492       return subshape().dimensions_size() * sizeof(int32_t);
493     }
494 
495     // Gets or sets the subshape of this piece. This reference points to a
496     // subshape within the shape in the containing Literal (Literal::shape_).
subshape()497     const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)498     void set_subshape(const Shape* subshape) {
499       subshape_ = subshape;
500       if (std::holds_alternative<Uninitialized>(rep_)) {
501         if (subshape_->IsTuple()) {
502           rep_.emplace<TupleRep>();
503         }
504       }
505     }
506 
507     // Returns the size in bytes of the buffer holding the array data.
size_bytes()508     int64_t size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
509 
510     // Total size in bytes, including the dynamic size addition.
511     //
512     // The shape can become dynamic after this literal is allocated, so we
513     // over-allocate the margin for the dynamic shape description in case we
514     // need it.
total_bytes()515     int64_t total_bytes() const {
516       return size_bytes() + dynamic_size_buffer_bytes();
517     }
518 
519     // Returns the number of elements in this piece's array.
element_count()520     int64_t element_count() const { return ShapeUtil::ElementsIn(subshape()); }
521 
522     // Returns the child piece at 'index' of this piece.
child(int64_t index)523     Piece& child(int64_t index) {
524       return const_cast<Piece&>(const_cast<const Piece*>(this)->child(index));
525     }
child(int64_t index)526     const Piece& child(int64_t index) const {
527       auto* tuple_rep = GetTupelRep();
528       DCHECK(tuple_rep);
529       return tuple_rep->children[index];
530     }
531 
532     // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)533     void emplace_back(Piece child_piece) {
534       auto* tuple_rep = GetTupelRep();
535       DCHECK(tuple_rep);
536       tuple_rep->children.emplace_back(std::move(child_piece));
537     }
538 
539     // Returns the size of children pieces of this piece.
children_size()540     int64_t children_size() {
541       if (auto* tuple_rep = GetTupelRep()) {
542         return tuple_rep->children.size();
543       }
544       return 0;
545     }
546 
547     // Visitor functions that recursively traverses the piece and calls the
548     // given function at each child piece. The function has the type:
549     //    void (const ShapeIndex& index, const Piece& piece)
550     template <typename Fn>
ForEachSubpiece(const Fn & func)551     void ForEachSubpiece(const Fn& func) const {
552       ShapeIndex index;
553       return ForEachHelper(
554                  [&func](const ShapeIndex& index, const Piece& piece) {
555                    func(index, piece);
556                    return OkStatus();
557                  },
558                  *this, &index)
559           .IgnoreError();
560     }
561     // Same as above, but the function has the type:
562     //    Status (const ShapeIndex& index, const Piece& piece)
563     // The first non-OK return value is returned by the function.
564     template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)565     Status ForEachSubpieceWithStatus(const Fn& func) const {
566       ShapeIndex index;
567       return ForEachHelper(func, *this, &index);
568     }
569     // Same as above, but the function has the type:
570     //    Bool (const ShapeIndex& index, const Piece& piece)
571     // The first non-true return value is returned by the function.
572     template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)573     bool ForEachSubpieceWithBool(const Fn& func) const {
574       ShapeIndex index;
575       return ForEachHelperBool(func, *this, &index);
576     }
577     // Same as above, but the function has the type:
578     //    void (const ShapeIndex& index, Piece* piece)
579     template <typename Fn>
ForEachMutableSubpiece(const Fn & func)580     void ForEachMutableSubpiece(const Fn& func) {
581       ShapeIndex index;
582       return ForEachMutableHelper(
583                  [&func](const ShapeIndex& index, Piece* piece) {
584                    func(index, piece);
585                    return OkStatus();
586                  },
587                  const_cast<xla::LiteralBase::Piece*>(this), &index)
588           .IgnoreError();
589     }
590     // Same as above, but the function has the type:
591     //    Status (const ShapeIndex& index, Piece* piece)
592     // The first non-OK return value is returned by the function.
593     template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)594     Status ForEachMutableSubpieceWithStatus(const Fn& func) {
595       ShapeIndex index;
596       return ForEachMutableHelper(
597           func, const_cast<xla::LiteralBase::Piece*>(this), &index);
598     }
599 
600     // Checks whether all elements of this Piece are equal to the given literal.
601     //
602     // Returns false if this Piece is not an array.
603     //
604     // Preconditions:
605     //  - `scalar` is a scalar.
606     //  - `scalar`'s type matches that of `this`.
607     bool IsAll(const Literal& scalar) const;
608 
609     // Returns true if this piece and 'other' contain the same data. This piece
610     // and 'other' must be array-shaped and compatible. If a literal has dynamic
611     // shape, comparison is done only for the valid elements.
612     bool EqualElements(const Piece& other) const;
613 
614     // Returns true if this piece and other pieces have the same dynamic
615     // dimension sizes.
616     bool EqualDynamicSize(const Piece& other) const;
617 
618     // Writes the shape and data (if array-shaped) into the given proto.
619     void WriteToProto(LiteralProto* proto) const;
620 
621     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
622     // and src must be compatible. If only_dynamic_bound is true, only elements
623     // within dynamic bounds will be copied.
624     Status CopyFrom(const Piece& src, bool only_dynamic_bound);
625 
626     // Copies the data from the given proto into this piece. The shape of this
627     // piece must be equal (not just compatible) to the shape of the proto.
628     Status CopyFromProto(const LiteralProto& proto);
629 
630     // See comments on ArrayValueState for detailed explanation.
631     bool IsDetermined() const;
632 
633     bool IsKnown() const;
634 
635    private:
636     // Uninitialized state representation.
637     struct Uninitialized {};
638     // Out of line array storage.
639     union ArrayRep {
640       char* data;
641     };
642     struct TupleRep {
643       // Children pieces for tuple shaped pieces.
644       std::vector<Piece> children = {};
645     };
646 
647     // Use just so many bytes that we don't increase the sizeof(Piece).
648     static inline constexpr size_t kMaxInlinedBytes =
649         std::max(sizeof(ArrayRep), sizeof(TupleRep));
650 
651     // Inlined array storage.
652     struct InlinedRep {
653       char data[kMaxInlinedBytes];
654     };
655 
656     // Helper visitor to access the buffer in the representation variant.
657     struct BufferVisitor {
operatorBufferVisitor658       char* operator()(Uninitialized&) { return nullptr; }
operatorBufferVisitor659       const char* operator()(const Uninitialized&) const { return nullptr; }
operatorBufferVisitor660       char* operator()(TupleRep&) { return nullptr; }
operatorBufferVisitor661       const char* operator()(const TupleRep&) const { return nullptr; }
operatorBufferVisitor662       char* operator()(InlinedRep& rep) { return rep.data; }
operatorBufferVisitor663       const char* operator()(const InlinedRep& rep) const { return rep.data; }
operatorBufferVisitor664       char* operator()(ArrayRep& rep) { return rep.data; }
operatorBufferVisitor665       const char* operator()(const ArrayRep& rep) const { return rep.data; }
666     };
667 
GetInlinedRep()668     const InlinedRep* GetInlinedRep() const {
669       return std::get_if<InlinedRep>(&rep_);
670     }
GetInlinedRep()671     InlinedRep* GetInlinedRep() { return std::get_if<InlinedRep>(&rep_); }
672 
GetArrayRep()673     const ArrayRep* GetArrayRep() const { return std::get_if<ArrayRep>(&rep_); }
GetArrayRep()674     ArrayRep* GetArrayRep() { return std::get_if<ArrayRep>(&rep_); }
675 
GetTupelRep()676     const TupleRep* GetTupelRep() const { return std::get_if<TupleRep>(&rep_); }
GetTupelRep()677     TupleRep* GetTupelRep() { return std::get_if<TupleRep>(&rep_); }
678     // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
679     // The first non-OK (or non-true) value is returned by the function.
680     // The callable 'func' has the same signature as described above in
681     // ForEachSubpiece*.
682     template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)683     Status ForEachHelper(const Fn& func, const Piece& piece,
684                          ShapeIndex* index) const {
685       TF_RETURN_IF_ERROR(func(*index, piece));
686       if (auto* tuple_rep = piece.GetTupelRep()) {
687         for (int64_t i = 0; i < tuple_rep->children.size(); ++i) {
688           index->push_back(i);
689           TF_RETURN_IF_ERROR(
690               ForEachHelper(func, tuple_rep->children[i], index));
691           index->pop_back();
692         }
693       }
694       return OkStatus();
695     }
696     template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)697     bool ForEachHelperBool(const Fn& func, const Piece& piece,
698                            ShapeIndex* index) const {
699       if (!func(*index, piece)) {
700         return false;
701       }
702       if (auto* tuple_rep = piece.GetTupelRep()) {
703         for (int64_t i = 0; i < tuple_rep->children.size(); ++i) {
704           index->push_back(i);
705           if (!ForEachHelperBool(func, tuple_rep->children[i], index)) {
706             return false;
707           }
708           index->pop_back();
709         }
710       }
711       return true;
712     }
713     template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)714     Status ForEachMutableHelper(const Fn& func, Piece* piece,
715                                 ShapeIndex* index) {
716       TF_RETURN_IF_ERROR(func(*index, piece));
717       if (auto* tuple_rep = piece->GetTupelRep()) {
718         for (int64_t i = 0; i < tuple_rep->children.size(); ++i) {
719           index->push_back(i);
720           TF_RETURN_IF_ERROR(
721               ForEachMutableHelper(func, &tuple_rep->children[i], index));
722           index->pop_back();
723         }
724       }
725       return OkStatus();
726     }
727 
    // Recursive helper for EqualElements. Compares against 'other' using
    // 'multi_index' as a caller-provided scratch index that the helper may
    // modify (it is passed by pointer).
    // NOTE(review): presumably 'multi_index' accumulates one entry per
    // dimension as the recursion descends -- confirm against the definition.
    template <typename NativeT>
    bool EqualElementsInternal(const Piece& other,
                               std::vector<int64_t>* multi_index) const;
732 
    // Internal helper to copy elements of type NativeT from another given
    // piece ('src') into this piece.
    // NOTE(review): the name suggests only elements inside the dynamic bounds
    // are copied -- confirm against the definition.
    template <typename NativeT>
    void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
736 
    // Storage representation of this piece. Holds exactly one of:
    // Uninitialized, InlinedRep, ArrayRep or TupleRep. A default-constructed
    // variant holds its first alternative, Uninitialized.
    std::variant<Uninitialized, InlinedRep, ArrayRep, TupleRep> rep_;

    // The shape of piece. This points into the shape of the containing Literal
    // (Literal::shape_). Null until it is set.
    const Shape* subshape_ = nullptr;

    // State of the array values of this piece; defaults to kKnown.
    ArrayValueState array_value_state_ = ArrayValueState::kKnown;
745   };  // class Piece
746 
piece(const ShapeIndex & shape_index)747   const Piece& piece(const ShapeIndex& shape_index) const {
748     Piece* piece = &const_cast<Piece&>(root_piece());
749     for (const auto i : shape_index) {
750       DCHECK_GE(i, 0);
751       DCHECK_LT(i, piece->children_size());
752       piece = &piece->child(i);
753     }
754     return *piece;
755   }
756 
  // Returns the piece at the root of the shape. Pure virtual: each concrete
  // literal class provides its own root piece.
  virtual const Piece& root_piece() const = 0;
759 
  // LiteralSlice and Literal must access Pieces of other Literals. Friendship
  // is granted to MutableLiteralBase, LiteralSlice and BorrowingLiteral.
  friend class MutableLiteralBase;
  friend class LiteralSlice;
  friend class BorrowingLiteral;
764 
765  private:
  // Typed slicing helper: takes the shape of the result and the per-dimension
  // start indices, and returns the result as a new Literal.
  // NOTE(review): presumably the per-type implementation behind the public
  // slicing entry point -- confirm against the definition.
  template <typename NativeT>
  Literal SliceInternal(const Shape& result_shape,
                        absl::Span<const int64_t> start_indices) const;
769 };
770 
771 // Abstract base class representing a mutable literal in XLA.
772 class MutableLiteralBase : public LiteralBase {
773  public:
774   virtual ~MutableLiteralBase() = 0;
775 
776   // Returns a Span view of the array for this literal for the
777   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
778   // given ShapeIndex is not array. See primitive_util.h for the mapping from
779   // XLA type to native type.
780   template <typename NativeT>
781   absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
782   // Unhide const method from parent class.
783   using LiteralBase::data;
784 
785   // TODO(b/67651157): Remove this accessor. Literal users should not be able to
786   // mutate the shape as this can produce malformed Literals.
787   Shape* mutable_shape_do_not_use();
788 
789   // Set the dynamic size on dim_index in the literal at the given shape_index.
790   void SetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index,
791                       int32_t size);
792   void SetDynamicSize(int64_t dim_index, int32_t size);
793 
794   // Returns a pointer to the underlying buffer holding the array at the given
795   // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
796   // is not array.
797   void* untyped_data(const ShapeIndex& shape_index = {});
798   // Unhide const method from parent class.
799   using LiteralBase::untyped_data;
800 
801   template <typename NativeT>
802   void MutableEachCell(
803       std::function<NativeT(absl::Span<const int64_t> indices, NativeT value)>
804           per_cell);
805 
806   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
807   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
808   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
809   // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
810   // is true, only elements within dynamic bounds will be copied.
811   Status CopyFrom(const LiteralSlice& src_literal,
812                   const ShapeIndex& dest_shape_index = {},
813                   const ShapeIndex& src_shape_index = {},
814                   bool only_dynamic_bound = false);
815 
816   // Copies the values from src_literal, starting at src_base shape indexes,
817   // to this literal, starting at dest_base, where the copy size in each
818   // dimension is specified by copy_size.
819   // The src_literal and this literal must have the same primitive type,
820   // src_base+copy_size must fit the source literal dimensions, as well as
821   // dest_base+copy_size must fit the destination literal dimensions.
822   // Note: if either src_literal or this literal contains dimensions with zero
823   // element, then copy_size must be 0 in these dimensions while the
824   // corresponding base indices being 0.
825   // This literal and 'src_literal' must be arrays.
826   Status CopySliceFrom(const LiteralSlice& src_literal,
827                        absl::Span<const int64_t> src_base,
828                        absl::Span<const int64_t> dest_base,
829                        absl::Span<const int64_t> copy_size);
830 
831   // Copies one element from src_literal[src_index] to (*this)[dest_index].
832   Status CopyElementFrom(const LiteralSlice& src_literal,
833                          absl::Span<const int64_t> src_index,
834                          absl::Span<const int64_t> dest_index);
835 
836   // Sets an element in the literal at the given index. The multi_index is
837   // CHECKed against the dimension sizes.
838   template <typename NativeT>
839   void Set(absl::Span<const int64_t> multi_index, const ShapeIndex& shape_index,
840            NativeT value);
841   // Overloads of Set for array literals. CHECKs if the literal is not
842   // array-shaped and dense.
843   template <typename NativeT>
844   void Set(absl::Span<const int64_t> multi_index, NativeT value);
845 
846   // As Set(), but truncates `value` to the literal element type before storing.
847   // This literal must be an array.
848   Status SetIntegralAsS64(absl::Span<const int64_t> multi_index, int64_t value);
849 
850   // As Set(), but truncates `value` to the literal element type before storing.
851   // This literal must be an array.
852   Status SetFromDouble(absl::Span<const int64_t> multi_index, double value);
853 
854   // Populate this literal with the given values. Examples:
855   //
856   //   // Populate with floats.
857   //   Array2D<float> float_values = ...
858   //   literal.PopulateR2FromArray2D(values);
859   //
860   //   // Populate with int32s.
861   //   literal.PopulateR2<int32_t>({{1, 2}, {3, 4}});
862   //
863   // The shape and element type of this literal must match given values. For
864   // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
865   // array of S32.
866   template <typename NativeT>
867   void PopulateR1(absl::Span<const NativeT> values);
868   void PopulateR1(const tensorflow::core::Bitmap& values);
869   template <typename NativeT>
870   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
871   template <typename NativeT>
872   void PopulateFromArray(const Array<NativeT>& values);
873   template <typename NativeT>
874   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
875   template <typename NativeT>
876   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
877   template <typename NativeT>
878   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
879 
880   // Populates literal values by calling the generator function for every cell
881   // in this literal object.
882   //
883   // generator must be a callable of the type
884   // NativeT(absl::Span<const int64_t> indexes) or compatible.
885   //
886   // This literal must have a dense layout.
887   template <typename NativeT, typename FnType>
888   Status Populate(const FnType& generator);
889 
890   // A parallel version of Populate(). This can be used if the generator is
891   // thread-safe and the values for the shape's different elements are
892   // independent.
893   template <typename NativeT, typename FnType>
894   Status PopulateParallel(const FnType& generator);
895 
896   // Fills this literal with the given value.
897   template <typename NativeT>
898   void PopulateWithValue(NativeT value);
899 
900   // This operation is the inverse of DecomposeTuple. The given elements are
901   // moved into the tuple elements of a new tuple-shaped Literal which is
902   // returned. Upon return, each of the Literals in 'elements' is set to a nil
903   // shape (empty tuple).
904   static Literal MoveIntoTuple(absl::Span<Literal> elements);
905 
906   // Serialize from a proto.
907   static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
908                                            bool prohibit_empty_literal = true);
909 
910  protected:
911   // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)912   Piece& piece(const ShapeIndex& shape_index) {
913     return const_cast<Piece&>(LiteralBase::piece(shape_index));
914   }
915 
mutable_root_piece()916   Piece& mutable_root_piece() { return const_cast<Piece&>(root_piece()); }
917 
918   // Internal template helper for the Literal::CopySliceFrom(), matching its
919   // arguments one by one.
920   template <typename NativeT>
921   Status CopySliceFromInternal(const LiteralBase& src_literal,
922                                absl::Span<const int64_t> src_base,
923                                absl::Span<const int64_t> dest_base,
924                                absl::Span<const int64_t> copy_size);
925 
926   // Utility structure which is used to create the optimal configuration for
927   // a ShapeUtil::ForEachIndex() scan across two literals.
928   struct StrideConfig {
929     StrideConfig(const Shape& source_shape, const Shape& dest_shape,
930                  absl::Span<const int64_t> dimensions);
931 
932     // The dimensions of the stride operation. Essentially every dimension
933     // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
934     // steps.
935     absl::Span<const int64_t> dimensions;
936     DimensionVector base;
937     DimensionVector step;
938     int64_t minor_dimension = 0;
939     // The size of the strides for source and destination. One of the two
940     // (the one looping through its most minor dimension) will be 1, while
941     // the other will be the stride size at the dimension matching the other
942     // shape most minor dimension being scanned.
943     int64_t dest_stride = 1;
944     int64_t source_stride = 1;
945     // The size of the inner loop on the most minor dimension.
946     int64_t minor_loop_size = 1;
947   };
948 
949   // A unique_ptr like class which may or may not have ownership of its pointer.
950   // The literal may or may not own the storage of the shape. Creating/copying a
951   // shape can incur significant overhead which in many case we'd like to avoid,
952   // esp. for small literals.
953   class MaybeOwningShapePtr {
954    public:
955     MaybeOwningShapePtr() = default;
MaybeOwningShapePtr(std::unique_ptr<Shape> unique)956     explicit MaybeOwningShapePtr(std::unique_ptr<Shape> unique)
957         : ptr_and_owning_bit_(TakeUnique(std::move(unique))) {}
958 
MaybeOwningShapePtr(const Shape * borrowed)959     explicit MaybeOwningShapePtr(const Shape* borrowed)
960         : ptr_and_owning_bit_(Borrow(borrowed)) {}
961 
~MaybeOwningShapePtr()962     ~MaybeOwningShapePtr() { MaybeDeleteOwned(); }
963 
get()964     const Shape* get() const {
965       return reinterpret_cast<const Shape*>(ptr_and_owning_bit_ & kPointerMask);
966     }
967     Shape* get_mutable(bool ensure_owned = false) {
968       const Shape* const_ptr = get();
969       // TODO(b/67651157): Remove this copy on write logic and combine get() and
970       // get_mutable() once we remove mutable_shape_do_not_use().
971       if (const_ptr && !OwnsPtr()) {
972         ptr_and_owning_bit_ = TakeUnique(std::make_unique<Shape>(*const_ptr));
973         const_ptr = get();
974       }
975       DCHECK(OwnsPtr());
976       return const_cast<Shape*>(const_ptr);
977     }
978     const Shape* operator->() const { return get(); }
979     const Shape& operator*() const { return *get(); }
980 
981     MaybeOwningShapePtr& operator=(std::unique_ptr<Shape> unique) {
982       MaybeDeleteOwned();
983       ptr_and_owning_bit_ = TakeUnique(std::move(std::move(unique)));
984       return *this;
985     }
986 
987     MaybeOwningShapePtr& operator=(const Shape* borrowed) {
988       MaybeDeleteOwned();
989       ptr_and_owning_bit_ = Borrow(borrowed);
990       return *this;
991     }
992 
993     MaybeOwningShapePtr& operator=(MaybeOwningShapePtr&& other) {
994       using std::swap;
995       swap(ptr_and_owning_bit_, other.ptr_and_owning_bit_);
996       return *this;
997     }
998 
999     MaybeOwningShapePtr(const MaybeOwningShapePtr&) = delete;
MaybeOwningShapePtr(MaybeOwningShapePtr && other)1000     MaybeOwningShapePtr(MaybeOwningShapePtr&& other)
1001         : ptr_and_owning_bit_(other.ptr_and_owning_bit_) {
1002       other.ptr_and_owning_bit_ = 0;
1003     }
1004 
Clone()1005     MaybeOwningShapePtr Clone() const {
1006       const Shape* ptr = get();
1007       if (ptr && OwnsPtr()) {
1008         return MaybeOwningShapePtr(std::make_unique<Shape>(*ptr));
1009       }
1010       return MaybeOwningShapePtr(ptr);
1011     }
1012 
1013    private:
1014     enum : uint64_t {
1015       kOwningBitMask = 1UL,
1016       kPointerMask = ~kOwningBitMask,
1017     };
TakeUnique(std::unique_ptr<Shape> unique)1018     static intptr_t TakeUnique(std::unique_ptr<Shape> unique) {
1019       Shape* released = unique.release();
1020       DCHECK_EQ(reinterpret_cast<intptr_t>(released) & kOwningBitMask, 0);
1021       return reinterpret_cast<intptr_t>(released) | kOwningBitMask;
1022     }
1023 
Borrow(const Shape * borrowed)1024     static intptr_t Borrow(const Shape* borrowed) {
1025       DCHECK_EQ(reinterpret_cast<intptr_t>(borrowed) & kOwningBitMask, 0);
1026       return reinterpret_cast<intptr_t>(borrowed);
1027     }
1028 
OwnsPtr()1029     bool OwnsPtr() const { return kOwningBitMask & ptr_and_owning_bit_; }
1030 
MaybeDeleteOwned()1031     void MaybeDeleteOwned() {
1032       if (OwnsPtr()) {
1033         delete get();
1034       }
1035     }
1036 
1037     intptr_t ptr_and_owning_bit_ = 0;
1038   };
1039 
1040   // The parent class borrows this shape.
1041   MaybeOwningShapePtr shape_;
1042 
1043   // Implementation details shared between Populate() and PopulateParallel()
1044   template <typename NativeT, typename FnType>
1045   Status PopulateInternal(const FnType& generator, bool parallel);
1046 
1047   friend class LiteralBase;
1048   friend class MutableBorrowingLiteral;
1049 };
// Stream-insertion operator for Literal: writes 'literal' to 'out' and returns
// 'out'. Only declared here; the output format is determined by the
// definition, which is not part of this header section.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
1051 
1052 // The underlying buffer and shape is always owned by this class.
1053 class Literal : public MutableLiteralBase {
1054  public:
1055   Literal();
1056 
1057   // Create a literal of the given shape. The literal is allocated sufficient
1058   // memory to hold the shape. Memory is uninitialized.
1059   explicit Literal(const Shape& shape);
1060   virtual ~Literal();
1061 
1062   // Literals are moveable, but not copyable. To copy a literal use
1063   // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
1064   // of literals which can be expensive.
1065   Literal(const Literal& other) = delete;
1066   Literal& operator=(const Literal& other) = delete;
1067   Literal(Literal&& other);
1068   // 'allocate_arrays' indicates whether to allocate memory for the arrays in
1069   // the shape. If false, buffer pointers inside of the Literal::Pieces are set
1070   // to nullptr.
1071   Literal(const Shape& shape, bool allocate_arrays,
1072           ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
1073   Literal& operator=(Literal&& other);
1074 
1075   // Similar to CopyFrom, but with move semantics. The subshape of this literal
1076   // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
1077   // (layouts and shapes must match), but need not be arrays. The memory
1078   // allocated in this literal for the subshape at dest_shape_index is
1079   // deallocated, and the respective buffers are replaced with those in
1080   // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
1081   virtual Status MoveFrom(Literal&& src_literal,
1082                           const ShapeIndex& dest_shape_index = {});
1083 
1084   // Returns a vector containing the tuple elements of this Literal as separate
1085   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
1086   // elements are moved into the new Literals; no data is copied. Upon return
1087   // this Literal is set to a nil shape (empty tuple)
1088   //
1089   // TODO(jlebar): Because this function invalidates `this`, it should be
1090   // ref-qualified with &&.
1091   std::vector<Literal> DecomposeTuple();
1092 
1093   // Returns a subliteral specified by given shape_index. No data is copied, the
1094   // current literal becomes invalid after this function call.
1095   //
1096   // TODO(jlebar): Because this function invalidates `this`, it should be
1097   // ref-qualified with &&.
1098   Literal SubLiteral(ShapeIndexView shape_index);
1099 
1100  private:
1101   friend class LiteralBase;
1102   friend class MutableLiteralBase;
root_piece()1103   const Piece& root_piece() const override { return root_piece_; };
1104   // Deallocate the buffers held by this literal.
1105   void DeallocateBuffers();
1106 
1107   // Recursively sets the subshapes and buffers of all subpieces rooted at
1108   // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
1109   // the shape.
1110   void SetPiece(
1111       const Shape& shape, Piece* piece, bool allocate_arrays,
1112       ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
1113   Piece root_piece_;
1114 };
1115 
1116 // The underlying buffer is not owned by this class and is always owned by
1117 // others. The shape is not owned by this class and not mutable.
1118 class MutableBorrowingLiteral : public MutableLiteralBase {
1119  public:
1120   virtual ~MutableBorrowingLiteral();
1121 
MutableBorrowingLiteral()1122   MutableBorrowingLiteral() : MutableLiteralBase() {}
1123 
1124   MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
1125   MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
1126 
1127   // Implicit conversion constructors.
1128   MutableBorrowingLiteral(MutableLiteralBase* literal);
1129   MutableBorrowingLiteral(MutableBorrowingLiteral literal,
1130                           const ShapeIndex& view_root);
1131   MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
1132 
1133   // Create a literal from a list of buffers and a shape.
1134   // Returns a tuple literal if `shape` is a tuple type.
1135   MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);
1136 
1137  private:
root_piece()1138   const Piece& root_piece() const override { return *root_piece_; };
1139   // Recursively copies the subtree from the `src_piece` at the given child
1140   // index to the `dest_piece`. For buffers only the pointers are copied, but
1141   // not the content.
1142   void CopyPieceSubtree(const Shape& shape, const Piece* src_piece,
1143                         Piece* dest_piece);
1144   Piece* root_piece_ = nullptr;
1145 };
1146 
1147 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
1148 // literal buffers always owned by others.
1149 class LiteralSlice : public LiteralBase {
1150  public:
LiteralSlice()1151   LiteralSlice() : LiteralBase() {}
1152 
1153   // Implicit conversion constructors.
1154   LiteralSlice(const LiteralBase& literal);
1155   LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
1156 
1157  private:
root_piece()1158   const Piece& root_piece() const override { return *root_piece_; };
1159 
1160   const Piece* root_piece_;  // Not owned.
1161 };
1162 
1163 // A read-only Literal where the underlying buffers are never owned by this
1164 // class.
1165 class BorrowingLiteral : public LiteralBase {
1166  public:
BorrowingLiteral()1167   BorrowingLiteral() : LiteralBase() {}
1168 
1169   // 'src_buf_ptr' is not owned by this class and must outlive the
1170   // lifetime of this class. It points to an appropriately sized buffer with
1171   // data interpretered as indicated by 'shape'.
1172   // This constructor is only used for array shapes.
1173   BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
1174   // Similar as above, except to be used for constructing non-nested tuples.
1175   BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
1176                    const Shape& shape);
1177   // TODO(b/79707221): adding constructors for nested tuples as well.
1178 
1179  private:
1180   // Recursively builds the subtree for the given piece and sets the subshapes
1181   // of the given piece with the given shape.
1182   void BuildPieceSubtree(const Shape& shape, Piece* piece);
1183 
1184   // Accessor for the root piece of this literal.
root_piece()1185   const Piece& root_piece() const override { return root_piece_; };
1186   Piece root_piece_;
1187 
1188   // Shape of this literal. Stored as unique_ptr such that the (default) move
1189   // construction of this class would be trivially correct: the pointer to Shape
1190   // root_piece_ stores will still point to the correct address.
1191   std::unique_ptr<Shape> shape_;
1192 };
1193 
1194 template <typename NativeT>
data()1195 absl::Span<const NativeT> LiteralBase::Piece::data() const {
1196   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1197   DCHECK_EQ(subshape().element_type(),
1198             primitive_util::NativeToPrimitiveType<NativeT>())
1199       << "Attempting to access "
1200       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
1201       << " type, but literal element type is "
1202       << PrimitiveType_Name(subshape().element_type());
1203   return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
1204                                    element_count());
1205 }
1206 
1207 template <typename NativeT>
data()1208 absl::Span<NativeT> LiteralBase::Piece::data() {
1209   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1210   DCHECK_EQ(subshape().element_type(),
1211             primitive_util::NativeToPrimitiveType<NativeT>())
1212       << "Attempting to access "
1213       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
1214       << " type, but literal element type is "
1215       << PrimitiveType_Name(subshape().element_type());
1216   return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
1217                              element_count());
1218 }
1219 
1220 template <typename NativeT>
Get(absl::Span<const int64_t> multi_index)1221 NativeT LiteralBase::Piece::Get(absl::Span<const int64_t> multi_index) const {
1222   CHECK(LayoutUtil::IsDenseArray(subshape())) << subshape();
1223   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
1224       subshape(), multi_index)];
1225 }
1226 
1227 template <typename NativeT>
Set(absl::Span<const int64_t> multi_index,NativeT value)1228 void LiteralBase::Piece::Set(absl::Span<const int64_t> multi_index,
1229                              NativeT value) {
1230   CHECK(LayoutUtil::IsDenseArray(subshape()));
1231   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
1232       subshape(), multi_index)] = value;
1233 }
1234 
1235 template <typename NativeT>
data(const ShapeIndex & shape_index)1236 absl::Span<const NativeT> LiteralBase::data(
1237     const ShapeIndex& shape_index) const {
1238   return piece(shape_index).data<NativeT>();
1239 }
1240 
1241 template <typename NativeT>
data(const ShapeIndex & shape_index)1242 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
1243   return piece(shape_index).data<NativeT>();
1244 }
1245 
1246 template <typename NativeT>
Get(absl::Span<const int64_t> multi_index,const ShapeIndex & shape_index)1247 inline NativeT LiteralBase::Get(absl::Span<const int64_t> multi_index,
1248                                 const ShapeIndex& shape_index) const {
1249   return piece(shape_index).Get<NativeT>(multi_index);
1250 }
1251 
1252 template <typename NativeT>
Get(absl::Span<const int64_t> multi_index)1253 inline NativeT LiteralBase::Get(absl::Span<const int64_t> multi_index) const {
1254   return root_piece().Get<NativeT>(multi_index);
1255 }
1256 
1257 template <typename NativeT>
Set(absl::Span<const int64_t> multi_index,const ShapeIndex & shape_index,NativeT value)1258 inline void MutableLiteralBase::Set(absl::Span<const int64_t> multi_index,
1259                                     const ShapeIndex& shape_index,
1260                                     NativeT value) {
1261   return piece(shape_index).Set<NativeT>(multi_index, value);
1262 }
1263 
1264 template <typename NativeT>
Set(absl::Span<const int64_t> multi_index,NativeT value)1265 inline void MutableLiteralBase::Set(absl::Span<const int64_t> multi_index,
1266                                     NativeT value) {
1267   return mutable_root_piece().Set<NativeT>(multi_index, value);
1268 }
1269 
1270 template <typename NativeT>
GetFirstElement()1271 NativeT LiteralBase::GetFirstElement() const {
1272   return data<NativeT>().at(0);
1273 }
1274 
1275 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64_t> indices,NativeT value)> per_cell)1276 void LiteralBase::EachCell(
1277     std::function<void(absl::Span<const int64_t> indices, NativeT value)>
1278         per_cell) const {
1279   if (ShapeUtil::IsZeroElementArray(shape())) {
1280     return;
1281   }
1282   std::vector<int64_t> indices(shape().rank(), 0);
1283 
1284   Shape shape_dynamic = shape();
1285   for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1286     shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1287   }
1288   do {
1289     per_cell(indices, Get<NativeT>(indices));
1290   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1291 }
1292 
1293 template <typename NativeT>
MutableEachCell(std::function<NativeT (absl::Span<const int64_t> indices,NativeT value)> per_cell)1294 void MutableLiteralBase::MutableEachCell(
1295     std::function<NativeT(absl::Span<const int64_t> indices, NativeT value)>
1296         per_cell) {
1297   if (ShapeUtil::IsZeroElementArray(shape())) {
1298     return;
1299   }
1300   std::vector<int64_t> indices(shape().rank(), 0);
1301   Shape shape_dynamic = shape();
1302   for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1303     shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1304   }
1305   do {
1306     Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
1307   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1308 }
1309 
1310 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)1311 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
1312   CHECK(shape().IsArray());
1313   CHECK_EQ(shape().rank(), 1);
1314   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
1315   CHECK_EQ(shape().element_type(),
1316            primitive_util::NativeToPrimitiveType<NativeT>());
1317   auto data_span = data<NativeT>();
1318   std::copy(values.begin(), values.end(), data_span.begin());
1319 }
1320 
1321 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1322 void MutableLiteralBase::PopulateR2(
1323     std::initializer_list<std::initializer_list<NativeT>> values) {
1324   CHECK(shape().IsArray());
1325   CHECK_EQ(shape().rank(), 2);
1326   CHECK_EQ(shape().element_type(),
1327            primitive_util::NativeToPrimitiveType<NativeT>());
1328 
1329   const int64_t dim0_size = values.size();
1330   const int64_t dim1_size = values.begin()->size();
1331   CHECK_EQ(dim0_size, shape().dimensions(0));
1332   CHECK_EQ(dim1_size, shape().dimensions(1));
1333 
1334   int64_t dim0 = 0;
1335   for (auto inner_list : values) {
1336     int64_t dim1 = 0;
1337     for (auto value : inner_list) {
1338       Set({dim0, dim1}, value);
1339       ++dim1;
1340     }
1341     CHECK_EQ(dim1_size, dim1);
1342     ++dim0;
1343   }
1344 }
1345 
1346 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1347 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1348   CHECK(shape().IsArray());
1349   CHECK_EQ(shape().element_type(),
1350            primitive_util::NativeToPrimitiveType<NativeT>());
1351   CHECK_EQ(shape().rank(), values.num_dimensions());
1352   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1353     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1354   }
1355   values.Each([this](absl::Span<const int64_t> indices, NativeT value) {
1356     this->Set(indices, value);
1357   });
1358 }
1359 
1360 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1361 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1362   PopulateFromArray(values);
1363 }
1364 
1365 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1366 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1367   PopulateFromArray(values);
1368 }
1369 
1370 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1371 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1372   PopulateFromArray(values);
1373 }
1374 
1375 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)1376 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
1377                                             bool parallel) {
1378   const Shape& this_shape = shape();
1379   const int64_t rank = this_shape.rank();
1380   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1381   TF_RET_CHECK(this_shape.element_type() ==
1382                primitive_util::NativeToPrimitiveType<NativeT>())
1383       << "Failing to populate literal with element type "
1384       << primitive_util::LowercasePrimitiveTypeName(this_shape.element_type())
1385       << " using data of type "
1386       << primitive_util::LowercasePrimitiveTypeName(
1387              primitive_util::NativeToPrimitiveType<NativeT>());
1388   absl::Span<NativeT> literal_data = data<NativeT>();
1389   if (rank > 0) {
1390     StrideConfig stride_config(this_shape, this_shape, this_shape.dimensions());
1391     int64_t minor_dimension_size =
1392         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1393 
1394     auto init_function = [&](absl::Span<const int64_t> indexes, int thread_id) {
1395       DimensionVector minor_scan_indexes(rank, 0);
1396       const int64_t index =
1397           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1398       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1399       for (int64_t i = 0; i < minor_dimension_size; ++i) {
1400         minor_scan_indexes[stride_config.minor_dimension] = i;
1401         literal_data.at(index + i) = generator(minor_scan_indexes, thread_id);
1402       }
1403     };
1404     if (parallel) {
1405       ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1406                                       stride_config.dimensions,
1407                                       stride_config.step, init_function);
1408     } else {
1409       ShapeUtil::ForEachIndex(
1410           this_shape, stride_config.base, stride_config.dimensions,
1411           stride_config.step,
1412           [&init_function](absl::Span<const int64_t> indexes) {
1413             init_function(indexes, /*thread_id=*/-1);
1414             return true;
1415           });
1416     }
1417   } else {
1418     // For scalars.
1419     literal_data.at(0) = generator({}, /*thread_id=*/-1);
1420   }
1421   return OkStatus();
1422 }
1423 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1424 Status MutableLiteralBase::Populate(const FnType& generator) {
1425   return PopulateInternal<NativeT>(
1426       [&](absl::Span<const int64_t> indexes, int /*thread_id*/) {
1427         return generator(indexes);
1428       },
1429       /*parallel=*/false);
1430 }
1431 
1432 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1433 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1434   return PopulateInternal<NativeT>(
1435       [&](absl::Span<const int64_t> indexes, int thread_id) {
1436         return generator(indexes, thread_id);
1437       },
1438       /*parallel=*/true);
1439 }
1440 
1441 template <typename NativeT>
PopulateWithValue(NativeT value)1442 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1443   CHECK(shape().IsArray());
1444   CHECK_EQ(shape().element_type(),
1445            primitive_util::NativeToPrimitiveType<NativeT>());
1446   for (NativeT& element : data<NativeT>()) {
1447     element = value;
1448   }
1449 }
1450 
1451 template <typename NativeT>
Replicate(int64_t times)1452 Literal LiteralBase::Replicate(int64_t times) const {
1453   DimensionVector bounds = {times};
1454   bounds.reserve(shape().dimensions_size() + 1);
1455   for (int64_t bound : shape().dimensions()) {
1456     bounds.push_back(bound);
1457   }
1458   Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1459   int64_t elements = ShapeUtil::ElementsIn(literal.shape());
1460   if (elements == 0) {
1461     return literal;
1462   }
1463 
1464   DimensionVector output_indices(bounds.size(), 0);
1465   absl::Span<const int64_t> input_indices = output_indices;
1466   input_indices.remove_prefix(1);
1467 
1468   bool done = false;
1469   while (!done) {
1470     const auto element = Get<NativeT>(input_indices);
1471     literal.Set<NativeT>(output_indices, element);
1472 
1473     done = true;
1474     for (int n = 0; n < output_indices.size(); ++n) {
1475       ++output_indices[n];
1476       if (output_indices[n] < bounds[n]) {
1477         done = false;
1478         break;
1479       }
1480       output_indices[n] = 0;
1481     }
1482   }
1483   return literal;
1484 }
1485 
1486 }  // namespace xla
1487 
1488 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1489