1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18
19 #include <algorithm>
20 #include <functional>
21 #include <initializer_list>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <optional>
26 #include <ostream>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30 #include <vector>
31
32 #include "absl/strings/string_view.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/array2d.h"
35 #include "tensorflow/compiler/xla/array3d.h"
36 #include "tensorflow/compiler/xla/array4d.h"
37 #include "tensorflow/compiler/xla/index_util.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/primitive_util.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/core/bitmap.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/protobuf.h"
49
50 namespace xla {
51
52 // Forward declare Literal and LiteralSlice class to be used by the creation
53 // methods in the base class.
54 class Literal;
55 class LiteralSlice;
56
57 // Abstract base class for literals.
58 class LiteralBase {
59 public:
60 virtual ~LiteralBase() = 0;
61
62 // Literals are equal if they have compatible shapes and the same data
63 // values. Layout is not compared.
64 bool operator==(const LiteralBase& other) const;
65 bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66
67 // Returns the shape of the literal.
shape()68 const Shape& shape() const { return root_piece().subshape(); }
69
70 // Serialize to proto.
71 LiteralProto ToProto() const;
72
73 // Returns a Span of the array for this literal for the given NativeT
74 // (e.g., float). CHECKs if the subshape of the literal at the given
75 // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76 // to native type.
77 template <typename NativeT>
78 absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79
80 // Returns a const pointer to (or size of) the underlying buffer holding the
81 // array at the given shape index. CHECKs if the subshape of the literal at
82 // the given ShapeIndex is not array.
83 const void* untyped_data(const ShapeIndex& shape_index = {}) const;
84 int64_t size_bytes(const ShapeIndex& shape_index = {}) const;
85
86 // Returns this literal's data as a string. This literal must be a rank-1 U8
87 // array.
88 std::string GetR1U8AsString() const;
89
90 // Returns a string representation of the literal value. The Shape of the
91 // literal is a prefix of the literal value in the string.
92
93 // Warning: this function can take minutes for multi-million
94 // element Literals.
95 std::string ToString() const;
96
97 // Similar to ToString, but return the result in a compact
98 // one-line form.
99 std::string ToStringOneline() const;
100
101 // Returns a string representation of the literal value which does *not*
102 // include the shape string.
103 std::string ToStringWithoutShape() const;
104
105 // Similar to ToStringWithoutShape, but return the result in a compact
106 // one-line form.
107 std::string ToStringWithoutShapeOneline() const;
108
109 // Returns a string representation of the literal value which includes the
110 // shape string with its layout.
111 std::string ToStringWithLayout() const;
112
113 // Similar to ToStringWithLayout, but return the result in a compact
114 // one-line form.
115 std::string ToStringWithLayoutOneline() const;
116
117 // Gets an element in the literal at the given index. The multi_index is
118 // CHECKed against the dimension sizes.
119 template <typename NativeT>
120 NativeT Get(absl::Span<const int64_t> multi_index,
121 const ShapeIndex& shape_index) const;
122 // Overloads of Get for array literals. CHECKs if the literal is not
123 // array-shaped and dense.
124 template <typename NativeT>
125 NativeT Get(absl::Span<const int64_t> multi_index) const;
126
127 // Get the dynamic size on dim_index in the literal at the given shape_index.
128 int32_t GetDynamicSize(int64_t dim_index,
129 const ShapeIndex& shape_index) const;
130 int32_t GetDynamicSize(int64_t dim_index) const;
131
132 // Returns the element value at index (0, ..., 0), however many zeroes are
133 // required for that index.
134 template <typename NativeT>
135 NativeT GetFirstElement() const;
136
137 // As above but returns any integer type cast to an int64_t.
138 std::optional<int64_t> GetFirstInteger() const;
139
140 // As Get(), but determines the correct type and converts the value
141 // into text.
142 std::string GetAsString(absl::Span<const int64_t> multi_index,
143 const ShapeIndex& shape_index = {}) const;
144
145 // Return whether the value at the specified index is equal to the provided
146 // generic `value` (T must be an arithmetic type).
147 //
148 // Precondition: must be an array.
149 template <typename T>
150 typename std::enable_if<(std::is_arithmetic<T>::value ||
151 std::is_same<T, Eigen::half>::value ||
152 std::is_same<T, bfloat16>::value),
153 bool>::type
IsEqualAt(absl::Span<const int64_t> multi_index,T value)154 IsEqualAt(absl::Span<const int64_t> multi_index, T value) const {
155 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
156 return *as_s64 == value;
157 }
158 complex128 as_complex128 = *GetAsComplex128(multi_index);
159 return as_complex128.imag() == 0 && as_complex128.real() == value;
160 }
161
IsEqualAt(absl::Span<const int64_t> multi_index,complex128 value)162 bool IsEqualAt(absl::Span<const int64_t> multi_index,
163 complex128 value) const {
164 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
165 return *as_s64 == value.real() && value.imag() == 0;
166 }
167 auto as_complex128 = GetAsComplex128(multi_index);
168 return *as_complex128 == value;
169 }
170
171 // As Get(), but determines the correct type and converts the value into
172 // int64_t. This literal must be an array.
173 std::optional<int64_t> GetIntegralAsS64(
174 absl::Span<const int64_t> multi_index) const;
175
176 // As Get(), but determines the correct type, and converts the value into
177 // double. This literal must be an array.
178 std::optional<double> GetAsDouble(
179 absl::Span<const int64_t> multi_index) const;
180
181 // As Get(), but determines the correct type, and converts the value into
182 // complex128. All floating point types can be converted into complex128.
183 //
184 // This literal must be an array.
185 std::optional<complex128> GetAsComplex128(
186 absl::Span<const int64_t> multi_index) const;
187
188 // Invokes the "per cell" callback for each element in the provided
189 // literal with the element's indices and a string representation of
190 // the element's value.
191 //
192 // This function is useful if you want a polymorphic representation
193 // of the tensor's elements (turning it to a string for something
194 // like representation in a protobuf).
195 //
196 // This literal must have a dense layout.
197 void EachCellAsString(
198 const std::function<void(absl::Span<const int64_t> indices,
199 const std::string& value)>& per_cell) const;
200 template <typename NativeT>
201 void EachCell(
202 std::function<void(absl::Span<const int64_t> indices, NativeT value)>
203 per_cell) const;
204
205 // Checks whether all of this literal's values are equal to the given scalar
206 // literal.
207 //
208 // If `this` is not an array (e.g. it's a tuple), returns false. This is
209 // simpler than trying to handle subshapes here, and it's almost always what
210 // you want.
211 //
212 // Preconditions:
213 // - `scalar` is a scalar.
214 // - `scalar` has the same element-type as `this`.
215 bool IsAll(const Literal& scalar) const;
216
217 // Returns whether every element in this literal is equal to value.
218 //
219 // value is an int8_t because we expect this to be called with small
220 // compile-time constants (0, -1, etc.) and so that whatever value you pass
221 // can be represented exactly by floating-point types as small as 16 bits.
222 //
223 // If value doesn't fit in this literal's type, returns false. Values of 1/0
224 // are considered equal to true/false; other values are not considered equal
225 // to true.
226 //
227 // Returns false if this literal is not array-shaped.
228 bool IsAll(int8_t value) const;
229
230 // Like IsAll(int8_t), except we check whether the literal is equal to a
231 // particular floating-point or complex number.
232 //
233 // Returns false if this literal is not a floating-point / complex value, or
234 // if it's not an array.
235 //
236 // This casts value to the type of literal, then compares using ==, with the
237 // caveat that NaNs are considered equal. The usual admonishments about
238 // floating-point equality checks apply. We expect you to use this to check
239 // for values that can be expressed precisely as a float, e.g. -0.5.
240 bool IsAllFloat(float value) const;
241 bool IsAllComplex(complex64 value) const;
242
243 // Determines if this literal consists entirely of the first element of the
244 // literal.
245 //
246 // Returns false if this literal is not an array.
247 bool IsAllFirst() const;
248
249 // Literal consists entirely of an iota.
250 bool IsR1Iota() const;
251
252 // Returns the stride if the literal is a strided iota.
253 std::optional<int64_t> IsR1StridedIota() const;
254
255 // Returns whether this literal is zero at the specified index. This literal
256 // must be an array with a dense layout.
257 bool IsZero(absl::Span<const int64_t> indices) const;
258
259 // Returns the count of the elements in the array at the given shape index in
260 // this literal.
261 int64_t element_count(const ShapeIndex& index = {}) const {
262 if (index.empty()) {
263 // Common case, avoid GetSubshape().
264 return ShapeUtil::ElementsIn(shape());
265 }
266 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
267 }
268
269 // Compute a hash for this literal.
270 template <typename H>
AbslHashValue(H state,const LiteralBase & value)271 friend H AbslHashValue(H state, const LiteralBase& value) {
272 return LiteralBase::Hash(std::move(state), value);
273 }
274
275 template <typename H, bool kIsLayoutSensitive = true,
276 int64_t kByteLimit = std::numeric_limits<int64_t>::max()>
Hash(H state,const LiteralBase & literal)277 static H Hash(H state, const LiteralBase& literal) {
278 state =
279 Shape::Hash<H, kIsLayoutSensitive>(std::move(state), literal.shape());
280
281 ShapeUtil::ForEachSubshape(
282 literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) {
283 if (!subshape.IsArray()) {
284 return;
285 }
286
287 CHECK(LayoutUtil::IsDenseArray(subshape));
288 auto data = absl::MakeConstSpan(
289 static_cast<const char*>(literal.untyped_data(index)),
290 std::min(kByteLimit, literal.size_bytes(index)));
291 state = H::combine(std::move(state), data);
292 });
293
294 return std::move(state);
295 }
296
297 // Converts this literal to the given shape. Returns an error if the
298 // conversion is not possible.
299 StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
300
301 // Converts this literal to another primitive type using a bitcast
302 // conversion. Returns an error if the conversion is not possible. This
303 // literal must be array-shaped.
304 StatusOr<Literal> BitcastConvert(const Shape& dest_shape) const;
305
306 // Converts this literal to another primitive type. Returns an error if the
307 // conversion is not possible. This literal must be array-shaped.
308 StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
309
310 // Clones the underlying buffers into a new Literal.
311 Literal Clone() const;
312 std::unique_ptr<Literal> CloneToUnique() const;
313
314 // TODO(b/67651157): The methods below which perform computation on Literals
315 // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
316 // evaluator code which operates on Literals.
317 //
318 // Creates a new value that has the equivalent value as this
319 // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
320 // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
321 // minor-to-major dimension layout and the value in the cell at any given
322 // logical index (i0, i1) will be the same.
323 //
324 // For tuple shaped literals, shape_index should be used to select the inner
325 // array that the new layout applies to.
326 //
327 // Note: this is useful when the client wants to ensure that a value placed in
328 // the XLA allocation tracker has a particular layout; for efficiency
329 // purposes or avoiding unimplemented operation/layout combinations.
330 Literal Relayout(const Layout& new_layout,
331 const ShapeIndex& shape_index = {}) const;
332
333 // An overload of Relayout which changes the layout of the entire shape rather
334 // than being limited to a single array within the shape.
335 Literal Relayout(const Shape& shape_with_layout) const;
336
337 // Generate a new literal whose static sizes are equal to the previous
338 // literal's dynamic sizes.
339 Literal ToStatic() const;
340
341 // Expand a static literal into a new one with a bounded dynamic literal. The
342 // static dimensions of the original literal become dynamic dimensions of the
343 // new literal, where the argument `bounded_shape` becomes the bounded shape
344 // of the new literal.
345 //
346 // Precondition: bounded_shape.is_dynamic()
347 Literal ToBoundedDynamic(const Shape& bounded_shape) const;
348
349 // Creates a new literal by reshaping this literal to have the given
350 // dimensions. The total number of elements must not change; The
351 // implementation currently only supports monotonic dim0-major layouts.
352 // This literal must be an array.
353 StatusOr<Literal> Reshape(absl::Span<const int64_t> dimensions) const;
354
355 // Creates a new literal by broadcasting this literal with `dimensions` to
356 // yield a literal of shape `result_shape`.
357 StatusOr<Literal> Broadcast(const Shape& result_shape,
358 absl::Span<const int64_t> dimensions) const;
359
360 // Creates a new literal by reordering the dimensions of this literal.
361 // The given `permutation` must be a permutation of the dimension numbers
362 // in the original literal, and it specifies the order of the new dimensions
363 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
364 // For example, a transpose call on a literal of shape [3 x 8 x 4] and
365 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
366 // This literal must be an array.
367 Literal Transpose(absl::Span<const int64_t> permutation) const;
368
369 // Creates a sub-array from this literal by extracting the indices
370 // [start_index, limit_index) of each dimension. The result literal has the
371 // same rank and layout as for the given literal. The number of indices in
372 // start_indices and limit_indices must be the rank of the literal, and the
373 // indices follow the order of the dimensions.
374 // This literal must be an array.
375 Literal Slice(absl::Span<const int64_t> start_indices,
376 absl::Span<const int64_t> limit_indices) const;
377
378 // Creates a literal with a prepended dimension with bound "times"; e.g. a
379 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
380 // literal replicated four times.
381 // This literal must be an array.
382 template <typename NativeT>
383 Literal Replicate(int64_t times) const;
384
385 // Returns true if the leaf arrays of the literal within the given shape index
386 // are all determined.
387 // See comments on ArrayValueState for detailed explanation.
388 bool IsDetermined(const ShapeIndex& shape_index = {}) const;
389
390 // Returns true if the leaf arrays of the literal within the given shape index
391 // are all known.
392 // See comments on ArrayValueState for detailed explanation.
393 bool IsKnown(const ShapeIndex& shape_index = {}) const;
394
395 // Creates a new Literal object with the shape specified as parameter.
396 // The content of the literal values is the default value of the primitive
397 // type of literal itself (0 for numeric types, and false for predicates).
398 //
399 // Note: It's an antipattern to use this method then immediately call
400 // MutableLiteralBase::Populate on the result (since that results in zero
401 // initialization, then reinitialization). Consider if a call to
402 // std::make_unique<Literal>(shape), followed by the call to
403 // MutableLiteralBase::Populate can be used instead.
404 static Literal CreateFromShape(const Shape& shape);
405
406 // WARNING: These two functions are only supposed to be used by HloEvaluator.
407 // The rest of XLA assumes all literals are known.
408 // Similar to CreateFromShape() but marks all leaf arrays as unknown.
409 static Literal CreateFromShapeWithUnknownLeafArrays(const Shape& shape);
410 // Similar to CreateFromShape() but marks all leaf arrays as undetermined.
411 static Literal CreateFromShapeWithUndeterminedLeafArrays(const Shape& shape);
412
413 protected:
414 // Array literals could be in one of the following three states:
415 // 1) Known: we have evaluated the array literal and know its value.
416 // 2) Unknown: we have tried to evaluate the array literal, but its value
417 // cannot be evaluated statically.
418 // 3) Undetermined: we haven't tried to evaluate the array literal.
419 // Unknown and Undetermined states are only meant to be used within
420 // HloEvaluator. The rest of XLA assumes array literals are all known.
421 // Literals that are unknown or undetermined can be copied from, using
422 // CopyFrom and Clone, or moved from using move constructor. Accessing values
423 // of such literals causes undefined behavior.
424 enum class ArrayValueState { kKnown = 0, kUnknown = 1, kUndetermined = 2 };
425
426 // A data structure representing a subshape at a particular ShapeIndex within
427 // the literal. For array-shaped ShapeIndexes, this data structure holds the
428 // pointer to the memory allocated for the array data.
429 class Piece {
430 public:
431 ArrayValueState get_array_value_state() const;
432 void set_array_value_state(ArrayValueState state);
433 // Returns the buffer holding the array data for this piece as an array
434 // slice. This piece must be array-shaped.
435 template <typename NativeT>
436 absl::Span<const NativeT> data() const;
437 template <typename NativeT>
438 absl::Span<NativeT> data();
439
440 // Returns the buffer holding the array data for this piece as a void*. This
441 // piece must be array-shaped.
442 void* untyped_data();
443 const void* untyped_data() const;
444
445 // Gets or sets an element in the array at the given index. The multi_index
446 // is CHECKed against the dimension sizes of the array. This piece must be
447 // array-shaped.
448 template <typename NativeT>
449 NativeT Get(absl::Span<const int64_t> index) const;
450 template <typename NativeT>
451 void Set(absl::Span<const int64_t> index, NativeT value);
452
453 int32_t GetDynamicSize(int64_t dim_index) const;
454 void SetDynamicSize(int64_t dim_index, int32_t size);
455 void AllocateBuffers();
456 void DeallocateBuffers();
457 // Gets/sets the buffer holding the array data.
buffer()458 const char* buffer() const { return std::visit(BufferVisitor{}, rep_); }
buffer()459 char* buffer() {
460 return const_cast<char*>(const_cast<const Piece*>(this)->buffer());
461 }
set_buffer(char * buffer)462 void set_buffer(char* buffer) {
463 CHECK(subshape_->IsArray());
464 auto* array_rep = std::holds_alternative<Uninitialized>(rep_)
465 ? &rep_.emplace<ArrayRep>()
466 : GetArrayRep();
467 DCHECK(array_rep);
468 array_rep->data = buffer;
469 }
MoveDataFrom(Piece & from)470 void MoveDataFrom(Piece& from) {
471 DCHECK(!std::holds_alternative<ArrayRep>(rep_));
472 DCHECK(!std::holds_alternative<TupleRep>(rep_));
473 if (auto* array_rep = from.GetArrayRep()) {
474 rep_.emplace<ArrayRep>().data = array_rep->data;
475 } else if (auto* inlined_rep = from.GetInlinedRep()) {
476 std::memcpy(rep_.emplace<InlinedRep>().data, inlined_rep->data,
477 from.total_bytes());
478 }
479 from.rep_.emplace<Uninitialized>();
480 }
481
482 // Gets/sets the buffer holding dynamic sizes.
dynamic_size_buffer()483 const int32_t* dynamic_size_buffer() const {
484 return reinterpret_cast<const int32_t*>(buffer() + size_bytes());
485 }
dynamic_size_buffer()486 int32_t* dynamic_size_buffer() {
487 return const_cast<int32_t*>(
488 const_cast<const Piece*>(this)->dynamic_size_buffer());
489 }
490
dynamic_size_buffer_bytes()491 int64_t dynamic_size_buffer_bytes() const {
492 return subshape().dimensions_size() * sizeof(int32_t);
493 }
494
495 // Gets or sets the subshape of this piece. This reference points to a
496 // subshape within the shape in the containing Literal (Literal::shape_).
subshape()497 const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)498 void set_subshape(const Shape* subshape) {
499 subshape_ = subshape;
500 if (std::holds_alternative<Uninitialized>(rep_)) {
501 if (subshape_->IsTuple()) {
502 rep_.emplace<TupleRep>();
503 }
504 }
505 }
506
507 // Returns the size in bytes of the buffer holding the array data.
size_bytes()508 int64_t size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
509
510 // Total size in bytes, including the dynamic size addition.
511 //
512 // The shape can become dynamic after this literal is allocated, so we
513 // over-allocate the margin for the dynamic shape description in case we
514 // need it.
total_bytes()515 int64_t total_bytes() const {
516 return size_bytes() + dynamic_size_buffer_bytes();
517 }
518
519 // Returns the number of elements in this piece's array.
element_count()520 int64_t element_count() const { return ShapeUtil::ElementsIn(subshape()); }
521
522 // Returns the child piece at 'index' of this piece.
child(int64_t index)523 Piece& child(int64_t index) {
524 return const_cast<Piece&>(const_cast<const Piece*>(this)->child(index));
525 }
child(int64_t index)526 const Piece& child(int64_t index) const {
527 auto* tuple_rep = GetTupelRep();
528 DCHECK(tuple_rep);
529 return tuple_rep->children[index];
530 }
531
532 // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)533 void emplace_back(Piece child_piece) {
534 auto* tuple_rep = GetTupelRep();
535 DCHECK(tuple_rep);
536 tuple_rep->children.emplace_back(std::move(child_piece));
537 }
538
539 // Returns the size of children pieces of this piece.
children_size()540 int64_t children_size() {
541 if (auto* tuple_rep = GetTupelRep()) {
542 return tuple_rep->children.size();
543 }
544 return 0;
545 }
546
547 // Visitor functions that recursively traverses the piece and calls the
548 // given function at each child piece. The function has the type:
549 // void (const ShapeIndex& index, const Piece& piece)
550 template <typename Fn>
ForEachSubpiece(const Fn & func)551 void ForEachSubpiece(const Fn& func) const {
552 ShapeIndex index;
553 return ForEachHelper(
554 [&func](const ShapeIndex& index, const Piece& piece) {
555 func(index, piece);
556 return OkStatus();
557 },
558 *this, &index)
559 .IgnoreError();
560 }
561 // Same as above, but the function has the type:
562 // Status (const ShapeIndex& index, const Piece& piece)
563 // The first non-OK return value is returned by the function.
564 template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)565 Status ForEachSubpieceWithStatus(const Fn& func) const {
566 ShapeIndex index;
567 return ForEachHelper(func, *this, &index);
568 }
569 // Same as above, but the function has the type:
570 // Bool (const ShapeIndex& index, const Piece& piece)
571 // The first non-true return value is returned by the function.
572 template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)573 bool ForEachSubpieceWithBool(const Fn& func) const {
574 ShapeIndex index;
575 return ForEachHelperBool(func, *this, &index);
576 }
577 // Same as above, but the function has the type:
578 // Void (const ShapeIndex& index, Piece& piece)
579 template <typename Fn>
ForEachMutableSubpiece(const Fn & func)580 void ForEachMutableSubpiece(const Fn& func) {
581 ShapeIndex index;
582 return ForEachMutableHelper(
583 [&func](const ShapeIndex& index, Piece* piece) {
584 func(index, piece);
585 return OkStatus();
586 },
587 const_cast<xla::LiteralBase::Piece*>(this), &index)
588 .IgnoreError();
589 }
590 // Same as above, but the function has the type:
591 // Status (const ShapeIndex& index, Piece& piece)
592 // The first non-OK return value is returned by the function.
593 template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)594 Status ForEachMutableSubpieceWithStatus(const Fn& func) {
595 ShapeIndex index;
596 return ForEachMutableHelper(
597 func, const_cast<xla::LiteralBase::Piece*>(this), &index);
598 }
599
600 // Checks whether all elements of this Piece are equal to the given literal.
601 //
602 // Returns false if this Piece is not an array.
603 //
604 // Preconditions:
605 // - `scalar` is a scalar.
606 // - `scalar`'s type matches that of `this`.
607 bool IsAll(const Literal& scalar) const;
608
609 // Returns true if this piece and 'other' contain the same data. This piece
610 // and 'other' must be array-shaped and compatible. If a literal has dynamic
611 // shape, comparison is done only for the valid elements.
612 bool EqualElements(const Piece& other) const;
613
614 // Returns true if this piece and other pieces have the same dynamic
615 // dimension sizes.
616 bool EqualDynamicSize(const Piece& other) const;
617
618 // Writes the shape and data (if array-shaped) into the given proto.
619 void WriteToProto(LiteralProto* proto) const;
620
621 // Copy the data from 'src' into this piece's buffer. Shapes of this piece
622 // and src must be compatible. If only_dynamic_bound is true, only elements
623 // within dynamic bounds will be copied.
624 Status CopyFrom(const Piece& src, bool only_dynamic_bound);
625
626 // Copies the data from the given proto into this piece. The shape of this
627 // piece must be equal (not just compatible) to the shape of the proto.
628 Status CopyFromProto(const LiteralProto& proto);
629
630 // See comments on ArrayValueState for detailed explanation.
631 bool IsDetermined() const;
632
633 bool IsKnown() const;
634
635 private:
636 // Uninitialized state representation.
637 struct Uninitialized {};
638 // Out of line array storage.
639 union ArrayRep {
640 char* data;
641 };
642 struct TupleRep {
643 // Children pieces for tuple shaped pieces.
644 std::vector<Piece> children = {};
645 };
646
647 // Use just so many bytes that we don't increase the sizeof(Piece).
648 static inline constexpr size_t kMaxInlinedBytes =
649 std::max(sizeof(ArrayRep), sizeof(TupleRep));
650
651 // Inlined array storage.
652 struct InlinedRep {
653 char data[kMaxInlinedBytes];
654 };
655
656 // Helper visiter to access the buffer in the representation variant.
657 struct BufferVisitor {
operatorBufferVisitor658 char* operator()(Uninitialized&) { return nullptr; }
operatorBufferVisitor659 const char* operator()(const Uninitialized&) const { return nullptr; }
operatorBufferVisitor660 char* operator()(TupleRep&) { return nullptr; }
operatorBufferVisitor661 const char* operator()(const TupleRep&) const { return nullptr; }
operatorBufferVisitor662 char* operator()(InlinedRep& rep) { return rep.data; }
operatorBufferVisitor663 const char* operator()(const InlinedRep& rep) const { return rep.data; }
operatorBufferVisitor664 char* operator()(ArrayRep& rep) { return rep.data; }
operatorBufferVisitor665 const char* operator()(const ArrayRep& rep) const { return rep.data; }
666 };
667
GetInlinedRep()668 const InlinedRep* GetInlinedRep() const {
669 return std::get_if<InlinedRep>(&rep_);
670 }
GetInlinedRep()671 InlinedRep* GetInlinedRep() { return std::get_if<InlinedRep>(&rep_); }
672
GetArrayRep()673 const ArrayRep* GetArrayRep() const { return std::get_if<ArrayRep>(&rep_); }
GetArrayRep()674 ArrayRep* GetArrayRep() { return std::get_if<ArrayRep>(&rep_); }
675
GetTupelRep()676 const TupleRep* GetTupelRep() const { return std::get_if<TupleRep>(&rep_); }
GetTupelRep()677 TupleRep* GetTupelRep() { return std::get_if<TupleRep>(&rep_); }
678 // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
679 // The first non-OK (or non-true) value is returned by the function.
680 // The callable 'func' has the same signature as described above in
681 // ForEachSubpiece*.
682 template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)683 Status ForEachHelper(const Fn& func, const Piece& piece,
684 ShapeIndex* index) const {
685 TF_RETURN_IF_ERROR(func(*index, piece));
686 if (auto* tuple_rep = piece.GetTupelRep()) {
687 for (int64_t i = 0; i < tuple_rep->children.size(); ++i) {
688 index->push_back(i);
689 TF_RETURN_IF_ERROR(
690 ForEachHelper(func, tuple_rep->children[i], index));
691 index->pop_back();
692 }
693 }
694 return OkStatus();
695 }
696 template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)697 bool ForEachHelperBool(const Fn& func, const Piece& piece,
698 ShapeIndex* index) const {
699 if (!func(*index, piece)) {
700 return false;
701 }
702 if (auto* tuple_rep = piece.GetTupelRep()) {
703 for (int64_t i = 0; i < tuple_rep->children.size(); ++i) {
704 index->push_back(i);
705 if (!ForEachHelperBool(func, tuple_rep->children[i], index)) {
706 return false;
707 }
708 index->pop_back();
709 }
710 }
711 return true;
712 }
713 template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)714 Status ForEachMutableHelper(const Fn& func, Piece* piece,
715 ShapeIndex* index) {
716 TF_RETURN_IF_ERROR(func(*index, piece));
717 if (auto* tuple_rep = piece->GetTupelRep()) {
718 for (int64_t i = 0; i < tuple_rep->children.size(); ++i) {
719 index->push_back(i);
720 TF_RETURN_IF_ERROR(
721 ForEachMutableHelper(func, &tuple_rep->children[i], index));
722 index->pop_back();
723 }
724 }
725 return OkStatus();
726 }
727
728 // Recursive helper for EqualElements.
729 template <typename NativeT>
730 bool EqualElementsInternal(const Piece& other,
731 std::vector<int64_t>* multi_index) const;
732
733 // Internal helper to copy elements from another given piece
734 template <typename NativeT>
735 void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
736
737 // Storage representation of this piece.
738 std::variant<Uninitialized, InlinedRep, ArrayRep, TupleRep> rep_;
739
740 // The shape of piece. This points into the shape of the containing Literal
741 // (Literal::shape_).
742 const Shape* subshape_ = nullptr;
743
744 ArrayValueState array_value_state_ = ArrayValueState::kKnown;
745 }; // class Piece
746
piece(const ShapeIndex & shape_index)747 const Piece& piece(const ShapeIndex& shape_index) const {
748 Piece* piece = &const_cast<Piece&>(root_piece());
749 for (const auto i : shape_index) {
750 DCHECK_GE(i, 0);
751 DCHECK_LT(i, piece->children_size());
752 piece = &piece->child(i);
753 }
754 return *piece;
755 }
756
757 // Returns the piece at the root of the shape.
758 virtual const Piece& root_piece() const = 0;
759
760 // LiteralSlice and Literal must access Pieces of other Literals.
761 friend class MutableLiteralBase;
762 friend class LiteralSlice;
763 friend class BorrowingLiteral;
764
765 private:
766 template <typename NativeT>
767 Literal SliceInternal(const Shape& result_shape,
768 absl::Span<const int64_t> start_indices) const;
769 };
770
771 // Abstract base class representing a mutable literal in XLA.
772 class MutableLiteralBase : public LiteralBase {
773 public:
774 virtual ~MutableLiteralBase() = 0;
775
776 // Returns a Span view of the array for this literal for the
777 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
778 // given ShapeIndex is not array. See primitive_util.h for the mapping from
779 // XLA type to native type.
780 template <typename NativeT>
781 absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
782 // Unhide const method from parent class.
783 using LiteralBase::data;
784
785 // TODO(b/67651157): Remove this accessor. Literal users should not be able to
786 // mutate the shape as this can produce malformed Literals.
787 Shape* mutable_shape_do_not_use();
788
789 // Set the dynamic size on dim_index in the literal at the given shape_index.
790 void SetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index,
791 int32_t size);
792 void SetDynamicSize(int64_t dim_index, int32_t size);
793
794 // Returns a pointer to the underlying buffer holding the array at the given
795 // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
796 // is not array.
797 void* untyped_data(const ShapeIndex& shape_index = {});
798 // Unhide const method from parent class.
799 using LiteralBase::untyped_data;
800
801 template <typename NativeT>
802 void MutableEachCell(
803 std::function<NativeT(absl::Span<const int64_t> indices, NativeT value)>
804 per_cell);
805
806 // Copy values from 'src_literal' rooted at 'src_shape_index' into this
807 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
808 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
809 // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
810 // is true, only elements within dynamic bounds will be copied.
811 Status CopyFrom(const LiteralSlice& src_literal,
812 const ShapeIndex& dest_shape_index = {},
813 const ShapeIndex& src_shape_index = {},
814 bool only_dynamic_bound = false);
815
816 // Copies the values from src_literal, starting at src_base shape indexes,
817 // to this literal, starting at dest_base, where the copy size in each
818 // dimension is specified by copy_size.
819 // The src_literal and this literal must have the same primitive type,
820 // src_base+copy_size must fit the source literal dimensions, as well as
821 // dest_base+copy_size must fit the destination literal dimensions.
822 // Note: if either src_literal or this literal contains dimensions with zero
823 // element, then copy_size must be 0 in these dimensions while the
824 // corresponding base indices being 0.
825 // This literal and 'src_literal' must be arrays.
826 Status CopySliceFrom(const LiteralSlice& src_literal,
827 absl::Span<const int64_t> src_base,
828 absl::Span<const int64_t> dest_base,
829 absl::Span<const int64_t> copy_size);
830
831 // Copies one element from src_literal[src_index] to (*this)[dest_index].
832 Status CopyElementFrom(const LiteralSlice& src_literal,
833 absl::Span<const int64_t> src_index,
834 absl::Span<const int64_t> dest_index);
835
836 // Sets an element in the literal at the given index. The multi_index is
837 // CHECKed against the dimension sizes.
838 template <typename NativeT>
839 void Set(absl::Span<const int64_t> multi_index, const ShapeIndex& shape_index,
840 NativeT value);
841 // Overloads of Set for array literals. CHECKs if the literal is not
842 // array-shaped and dense.
843 template <typename NativeT>
844 void Set(absl::Span<const int64_t> multi_index, NativeT value);
845
846 // As Set(), but truncates `value` to the literal element type before storing.
847 // This literal must be an array.
848 Status SetIntegralAsS64(absl::Span<const int64_t> multi_index, int64_t value);
849
850 // As Set(), but truncates `value` to the literal element type before storing.
851 // This literal must be an array.
852 Status SetFromDouble(absl::Span<const int64_t> multi_index, double value);
853
854 // Populate this literal with the given values. Examples:
855 //
856 // // Populate with floats.
857 // Array2D<float> float_values = ...
858 // literal.PopulateR2FromArray2D(values);
859 //
860 // // Populate with int32s.
861 // literal.PopulateR2<int32_t>({{1, 2}, {3, 4}});
862 //
863 // The shape and element type of this literal must match given values. For
864 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
865 // array of S32.
866 template <typename NativeT>
867 void PopulateR1(absl::Span<const NativeT> values);
868 void PopulateR1(const tensorflow::core::Bitmap& values);
869 template <typename NativeT>
870 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
871 template <typename NativeT>
872 void PopulateFromArray(const Array<NativeT>& values);
873 template <typename NativeT>
874 void PopulateR2FromArray2D(const Array2D<NativeT>& values);
875 template <typename NativeT>
876 void PopulateR3FromArray3D(const Array3D<NativeT>& values);
877 template <typename NativeT>
878 void PopulateR4FromArray4D(const Array4D<NativeT>& values);
879
880 // Populates literal values by calling the generator function for every cell
881 // in this literal object.
882 //
883 // generator must be a callable of the type
884 // NativeT(absl::Span<const int64_t> indexes) or compatible.
885 //
886 // This literal must have a dense layout.
887 template <typename NativeT, typename FnType>
888 Status Populate(const FnType& generator);
889
890 // A parallel version of Populate(). This can be used if the generator is
891 // thread-safe and the values for the shape's different elements are
892 // independent.
893 template <typename NativeT, typename FnType>
894 Status PopulateParallel(const FnType& generator);
895
896 // Fills this literal with the given value.
897 template <typename NativeT>
898 void PopulateWithValue(NativeT value);
899
900 // This operation is the inverse of DecomposeTuple. The given elements are
901 // moved into the tuple elements of a new tuple-shaped Literal which is
902 // returned. Upon return, each of the Literals in 'elements' is set to a nil
903 // shape (empty tuple).
904 static Literal MoveIntoTuple(absl::Span<Literal> elements);
905
906 // Serialize from a proto.
907 static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
908 bool prohibit_empty_literal = true);
909
910 protected:
911 // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)912 Piece& piece(const ShapeIndex& shape_index) {
913 return const_cast<Piece&>(LiteralBase::piece(shape_index));
914 }
915
mutable_root_piece()916 Piece& mutable_root_piece() { return const_cast<Piece&>(root_piece()); }
917
918 // Internal template helper for the Literal::CopySliceFrom(), matching its
919 // arguments one by one.
920 template <typename NativeT>
921 Status CopySliceFromInternal(const LiteralBase& src_literal,
922 absl::Span<const int64_t> src_base,
923 absl::Span<const int64_t> dest_base,
924 absl::Span<const int64_t> copy_size);
925
926 // Utility structure which is used to create the optimal configuration for
927 // a ShapeUtil::ForEachIndex() scan across two literals.
928 struct StrideConfig {
929 StrideConfig(const Shape& source_shape, const Shape& dest_shape,
930 absl::Span<const int64_t> dimensions);
931
932 // The dimensions of the stride operation. Essentially every dimension
933 // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
934 // steps.
935 absl::Span<const int64_t> dimensions;
936 DimensionVector base;
937 DimensionVector step;
938 int64_t minor_dimension = 0;
939 // The size of the strides for source and destination. One of the two
940 // (the one looping through its most minor dimension) will be 1, while
941 // the other will be the stride size at the dimension matching the other
942 // shape most minor dimension being scanned.
943 int64_t dest_stride = 1;
944 int64_t source_stride = 1;
945 // The size of the inner loop on the most minor dimension.
946 int64_t minor_loop_size = 1;
947 };
948
949 // A unique_ptr like class which may or may not have ownership of its pointer.
950 // The literal may or may not own the storage of the shape. Creating/copying a
951 // shape can incur significant overhead which in many case we'd like to avoid,
952 // esp. for small literals.
953 class MaybeOwningShapePtr {
954 public:
955 MaybeOwningShapePtr() = default;
MaybeOwningShapePtr(std::unique_ptr<Shape> unique)956 explicit MaybeOwningShapePtr(std::unique_ptr<Shape> unique)
957 : ptr_and_owning_bit_(TakeUnique(std::move(unique))) {}
958
MaybeOwningShapePtr(const Shape * borrowed)959 explicit MaybeOwningShapePtr(const Shape* borrowed)
960 : ptr_and_owning_bit_(Borrow(borrowed)) {}
961
~MaybeOwningShapePtr()962 ~MaybeOwningShapePtr() { MaybeDeleteOwned(); }
963
get()964 const Shape* get() const {
965 return reinterpret_cast<const Shape*>(ptr_and_owning_bit_ & kPointerMask);
966 }
967 Shape* get_mutable(bool ensure_owned = false) {
968 const Shape* const_ptr = get();
969 // TODO(b/67651157): Remove this copy on write logic and combine get() and
970 // get_mutable() once we remove mutable_shape_do_not_use().
971 if (const_ptr && !OwnsPtr()) {
972 ptr_and_owning_bit_ = TakeUnique(std::make_unique<Shape>(*const_ptr));
973 const_ptr = get();
974 }
975 DCHECK(OwnsPtr());
976 return const_cast<Shape*>(const_ptr);
977 }
978 const Shape* operator->() const { return get(); }
979 const Shape& operator*() const { return *get(); }
980
981 MaybeOwningShapePtr& operator=(std::unique_ptr<Shape> unique) {
982 MaybeDeleteOwned();
983 ptr_and_owning_bit_ = TakeUnique(std::move(std::move(unique)));
984 return *this;
985 }
986
987 MaybeOwningShapePtr& operator=(const Shape* borrowed) {
988 MaybeDeleteOwned();
989 ptr_and_owning_bit_ = Borrow(borrowed);
990 return *this;
991 }
992
993 MaybeOwningShapePtr& operator=(MaybeOwningShapePtr&& other) {
994 using std::swap;
995 swap(ptr_and_owning_bit_, other.ptr_and_owning_bit_);
996 return *this;
997 }
998
999 MaybeOwningShapePtr(const MaybeOwningShapePtr&) = delete;
MaybeOwningShapePtr(MaybeOwningShapePtr && other)1000 MaybeOwningShapePtr(MaybeOwningShapePtr&& other)
1001 : ptr_and_owning_bit_(other.ptr_and_owning_bit_) {
1002 other.ptr_and_owning_bit_ = 0;
1003 }
1004
Clone()1005 MaybeOwningShapePtr Clone() const {
1006 const Shape* ptr = get();
1007 if (ptr && OwnsPtr()) {
1008 return MaybeOwningShapePtr(std::make_unique<Shape>(*ptr));
1009 }
1010 return MaybeOwningShapePtr(ptr);
1011 }
1012
1013 private:
1014 enum : uint64_t {
1015 kOwningBitMask = 1UL,
1016 kPointerMask = ~kOwningBitMask,
1017 };
TakeUnique(std::unique_ptr<Shape> unique)1018 static intptr_t TakeUnique(std::unique_ptr<Shape> unique) {
1019 Shape* released = unique.release();
1020 DCHECK_EQ(reinterpret_cast<intptr_t>(released) & kOwningBitMask, 0);
1021 return reinterpret_cast<intptr_t>(released) | kOwningBitMask;
1022 }
1023
Borrow(const Shape * borrowed)1024 static intptr_t Borrow(const Shape* borrowed) {
1025 DCHECK_EQ(reinterpret_cast<intptr_t>(borrowed) & kOwningBitMask, 0);
1026 return reinterpret_cast<intptr_t>(borrowed);
1027 }
1028
OwnsPtr()1029 bool OwnsPtr() const { return kOwningBitMask & ptr_and_owning_bit_; }
1030
MaybeDeleteOwned()1031 void MaybeDeleteOwned() {
1032 if (OwnsPtr()) {
1033 delete get();
1034 }
1035 }
1036
1037 intptr_t ptr_and_owning_bit_ = 0;
1038 };
1039
1040 // The parent class borrows this shape.
1041 MaybeOwningShapePtr shape_;
1042
1043 // Implementation details shared between Populate() and PopulateParallel()
1044 template <typename NativeT, typename FnType>
1045 Status PopulateInternal(const FnType& generator, bool parallel);
1046
1047 friend class LiteralBase;
1048 friend class MutableBorrowingLiteral;
1049 };
1050 std::ostream& operator<<(std::ostream& out, const Literal& literal);
1051
1052 // The underlying buffer and shape is always owned by this class.
1053 class Literal : public MutableLiteralBase {
1054 public:
1055 Literal();
1056
1057 // Create a literal of the given shape. The literal is allocated sufficient
1058 // memory to hold the shape. Memory is uninitialized.
1059 explicit Literal(const Shape& shape);
1060 virtual ~Literal();
1061
1062 // Literals are moveable, but not copyable. To copy a literal use
1063 // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
1064 // of literals which can be expensive.
1065 Literal(const Literal& other) = delete;
1066 Literal& operator=(const Literal& other) = delete;
1067 Literal(Literal&& other);
1068 // 'allocate_arrays' indicates whether to allocate memory for the arrays in
1069 // the shape. If false, buffer pointers inside of the Literal::Pieces are set
1070 // to nullptr.
1071 Literal(const Shape& shape, bool allocate_arrays,
1072 ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
1073 Literal& operator=(Literal&& other);
1074
1075 // Similar to CopyFrom, but with move semantics. The subshape of this literal
1076 // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
1077 // (layouts and shapes must match), but need not be arrays. The memory
1078 // allocated in this literal for the subshape at dest_shape_index is
1079 // deallocated, and the respective buffers are replaced with those in
1080 // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
1081 virtual Status MoveFrom(Literal&& src_literal,
1082 const ShapeIndex& dest_shape_index = {});
1083
1084 // Returns a vector containing the tuple elements of this Literal as separate
1085 // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
1086 // elements are moved into the new Literals; no data is copied. Upon return
1087 // this Literal is set to a nil shape (empty tuple)
1088 //
1089 // TODO(jlebar): Because this function invalidates `this`, it should be
1090 // ref-qualified with &&.
1091 std::vector<Literal> DecomposeTuple();
1092
1093 // Returns a subliteral specified by given shape_index. No data is copied, the
1094 // current literal becomes invalid after this function call.
1095 //
1096 // TODO(jlebar): Because this function invalidates `this`, it should be
1097 // ref-qualified with &&.
1098 Literal SubLiteral(ShapeIndexView shape_index);
1099
1100 private:
1101 friend class LiteralBase;
1102 friend class MutableLiteralBase;
root_piece()1103 const Piece& root_piece() const override { return root_piece_; };
1104 // Deallocate the buffers held by this literal.
1105 void DeallocateBuffers();
1106
1107 // Recursively sets the subshapes and buffers of all subpieces rooted at
1108 // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
1109 // the shape.
1110 void SetPiece(
1111 const Shape& shape, Piece* piece, bool allocate_arrays,
1112 ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
1113 Piece root_piece_;
1114 };
1115
1116 // The underlying buffer is not owned by this class and is always owned by
1117 // others. The shape is not owned by this class and not mutable.
1118 class MutableBorrowingLiteral : public MutableLiteralBase {
1119 public:
1120 virtual ~MutableBorrowingLiteral();
1121
MutableBorrowingLiteral()1122 MutableBorrowingLiteral() : MutableLiteralBase() {}
1123
1124 MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
1125 MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
1126
1127 // Implicit conversion constructors.
1128 MutableBorrowingLiteral(MutableLiteralBase* literal);
1129 MutableBorrowingLiteral(MutableBorrowingLiteral literal,
1130 const ShapeIndex& view_root);
1131 MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
1132
1133 // Create a literal from a list of buffers and a shape.
1134 // Returns a tuple literal if `shape` is a tuple type.
1135 MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);
1136
1137 private:
root_piece()1138 const Piece& root_piece() const override { return *root_piece_; };
1139 // Recursively copies the subtree from the `src_piece` at the given child
1140 // index to the `dest_piece`. For buffers only the pointers are copied, but
1141 // not the content.
1142 void CopyPieceSubtree(const Shape& shape, const Piece* src_piece,
1143 Piece* dest_piece);
1144 Piece* root_piece_ = nullptr;
1145 };
1146
1147 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
1148 // literal buffers always owned by others.
1149 class LiteralSlice : public LiteralBase {
1150 public:
LiteralSlice()1151 LiteralSlice() : LiteralBase() {}
1152
1153 // Implicit conversion constructors.
1154 LiteralSlice(const LiteralBase& literal);
1155 LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
1156
1157 private:
root_piece()1158 const Piece& root_piece() const override { return *root_piece_; };
1159
1160 const Piece* root_piece_; // Not owned.
1161 };
1162
1163 // A read-only Literal where the underlying buffers are never owned by this
1164 // class.
1165 class BorrowingLiteral : public LiteralBase {
1166 public:
BorrowingLiteral()1167 BorrowingLiteral() : LiteralBase() {}
1168
1169 // 'src_buf_ptr' is not owned by this class and must outlive the
1170 // lifetime of this class. It points to an appropriately sized buffer with
1171 // data interpretered as indicated by 'shape'.
1172 // This constructor is only used for array shapes.
1173 BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
1174 // Similar as above, except to be used for constructing non-nested tuples.
1175 BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
1176 const Shape& shape);
1177 // TODO(b/79707221): adding constructors for nested tuples as well.
1178
1179 private:
1180 // Recursively builds the subtree for the given piece and sets the subshapes
1181 // of the given piece with the given shape.
1182 void BuildPieceSubtree(const Shape& shape, Piece* piece);
1183
1184 // Accessor for the root piece of this literal.
root_piece()1185 const Piece& root_piece() const override { return root_piece_; };
1186 Piece root_piece_;
1187
1188 // Shape of this literal. Stored as unique_ptr such that the (default) move
1189 // construction of this class would be trivially correct: the pointer to Shape
1190 // root_piece_ stores will still point to the correct address.
1191 std::unique_ptr<Shape> shape_;
1192 };
1193
1194 template <typename NativeT>
data()1195 absl::Span<const NativeT> LiteralBase::Piece::data() const {
1196 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1197 DCHECK_EQ(subshape().element_type(),
1198 primitive_util::NativeToPrimitiveType<NativeT>())
1199 << "Attempting to access "
1200 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
1201 << " type, but literal element type is "
1202 << PrimitiveType_Name(subshape().element_type());
1203 return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
1204 element_count());
1205 }
1206
1207 template <typename NativeT>
data()1208 absl::Span<NativeT> LiteralBase::Piece::data() {
1209 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1210 DCHECK_EQ(subshape().element_type(),
1211 primitive_util::NativeToPrimitiveType<NativeT>())
1212 << "Attempting to access "
1213 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
1214 << " type, but literal element type is "
1215 << PrimitiveType_Name(subshape().element_type());
1216 return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
1217 element_count());
1218 }
1219
1220 template <typename NativeT>
Get(absl::Span<const int64_t> multi_index)1221 NativeT LiteralBase::Piece::Get(absl::Span<const int64_t> multi_index) const {
1222 CHECK(LayoutUtil::IsDenseArray(subshape())) << subshape();
1223 return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
1224 subshape(), multi_index)];
1225 }
1226
1227 template <typename NativeT>
Set(absl::Span<const int64_t> multi_index,NativeT value)1228 void LiteralBase::Piece::Set(absl::Span<const int64_t> multi_index,
1229 NativeT value) {
1230 CHECK(LayoutUtil::IsDenseArray(subshape()));
1231 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
1232 subshape(), multi_index)] = value;
1233 }
1234
1235 template <typename NativeT>
data(const ShapeIndex & shape_index)1236 absl::Span<const NativeT> LiteralBase::data(
1237 const ShapeIndex& shape_index) const {
1238 return piece(shape_index).data<NativeT>();
1239 }
1240
1241 template <typename NativeT>
data(const ShapeIndex & shape_index)1242 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
1243 return piece(shape_index).data<NativeT>();
1244 }
1245
1246 template <typename NativeT>
Get(absl::Span<const int64_t> multi_index,const ShapeIndex & shape_index)1247 inline NativeT LiteralBase::Get(absl::Span<const int64_t> multi_index,
1248 const ShapeIndex& shape_index) const {
1249 return piece(shape_index).Get<NativeT>(multi_index);
1250 }
1251
1252 template <typename NativeT>
Get(absl::Span<const int64_t> multi_index)1253 inline NativeT LiteralBase::Get(absl::Span<const int64_t> multi_index) const {
1254 return root_piece().Get<NativeT>(multi_index);
1255 }
1256
1257 template <typename NativeT>
Set(absl::Span<const int64_t> multi_index,const ShapeIndex & shape_index,NativeT value)1258 inline void MutableLiteralBase::Set(absl::Span<const int64_t> multi_index,
1259 const ShapeIndex& shape_index,
1260 NativeT value) {
1261 return piece(shape_index).Set<NativeT>(multi_index, value);
1262 }
1263
1264 template <typename NativeT>
Set(absl::Span<const int64_t> multi_index,NativeT value)1265 inline void MutableLiteralBase::Set(absl::Span<const int64_t> multi_index,
1266 NativeT value) {
1267 return mutable_root_piece().Set<NativeT>(multi_index, value);
1268 }
1269
1270 template <typename NativeT>
GetFirstElement()1271 NativeT LiteralBase::GetFirstElement() const {
1272 return data<NativeT>().at(0);
1273 }
1274
1275 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64_t> indices,NativeT value)> per_cell)1276 void LiteralBase::EachCell(
1277 std::function<void(absl::Span<const int64_t> indices, NativeT value)>
1278 per_cell) const {
1279 if (ShapeUtil::IsZeroElementArray(shape())) {
1280 return;
1281 }
1282 std::vector<int64_t> indices(shape().rank(), 0);
1283
1284 Shape shape_dynamic = shape();
1285 for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1286 shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1287 }
1288 do {
1289 per_cell(indices, Get<NativeT>(indices));
1290 } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1291 }
1292
1293 template <typename NativeT>
MutableEachCell(std::function<NativeT (absl::Span<const int64_t> indices,NativeT value)> per_cell)1294 void MutableLiteralBase::MutableEachCell(
1295 std::function<NativeT(absl::Span<const int64_t> indices, NativeT value)>
1296 per_cell) {
1297 if (ShapeUtil::IsZeroElementArray(shape())) {
1298 return;
1299 }
1300 std::vector<int64_t> indices(shape().rank(), 0);
1301 Shape shape_dynamic = shape();
1302 for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1303 shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1304 }
1305 do {
1306 Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
1307 } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1308 }
1309
1310 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)1311 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
1312 CHECK(shape().IsArray());
1313 CHECK_EQ(shape().rank(), 1);
1314 CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
1315 CHECK_EQ(shape().element_type(),
1316 primitive_util::NativeToPrimitiveType<NativeT>());
1317 auto data_span = data<NativeT>();
1318 std::copy(values.begin(), values.end(), data_span.begin());
1319 }
1320
1321 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1322 void MutableLiteralBase::PopulateR2(
1323 std::initializer_list<std::initializer_list<NativeT>> values) {
1324 CHECK(shape().IsArray());
1325 CHECK_EQ(shape().rank(), 2);
1326 CHECK_EQ(shape().element_type(),
1327 primitive_util::NativeToPrimitiveType<NativeT>());
1328
1329 const int64_t dim0_size = values.size();
1330 const int64_t dim1_size = values.begin()->size();
1331 CHECK_EQ(dim0_size, shape().dimensions(0));
1332 CHECK_EQ(dim1_size, shape().dimensions(1));
1333
1334 int64_t dim0 = 0;
1335 for (auto inner_list : values) {
1336 int64_t dim1 = 0;
1337 for (auto value : inner_list) {
1338 Set({dim0, dim1}, value);
1339 ++dim1;
1340 }
1341 CHECK_EQ(dim1_size, dim1);
1342 ++dim0;
1343 }
1344 }
1345
1346 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1347 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1348 CHECK(shape().IsArray());
1349 CHECK_EQ(shape().element_type(),
1350 primitive_util::NativeToPrimitiveType<NativeT>());
1351 CHECK_EQ(shape().rank(), values.num_dimensions());
1352 for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1353 CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1354 }
1355 values.Each([this](absl::Span<const int64_t> indices, NativeT value) {
1356 this->Set(indices, value);
1357 });
1358 }
1359
1360 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1361 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1362 PopulateFromArray(values);
1363 }
1364
1365 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1366 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1367 PopulateFromArray(values);
1368 }
1369
1370 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1371 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1372 PopulateFromArray(values);
1373 }
1374
1375 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)1376 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
1377 bool parallel) {
1378 const Shape& this_shape = shape();
1379 const int64_t rank = this_shape.rank();
1380 TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1381 TF_RET_CHECK(this_shape.element_type() ==
1382 primitive_util::NativeToPrimitiveType<NativeT>())
1383 << "Failing to populate literal with element type "
1384 << primitive_util::LowercasePrimitiveTypeName(this_shape.element_type())
1385 << " using data of type "
1386 << primitive_util::LowercasePrimitiveTypeName(
1387 primitive_util::NativeToPrimitiveType<NativeT>());
1388 absl::Span<NativeT> literal_data = data<NativeT>();
1389 if (rank > 0) {
1390 StrideConfig stride_config(this_shape, this_shape, this_shape.dimensions());
1391 int64_t minor_dimension_size =
1392 ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1393
1394 auto init_function = [&](absl::Span<const int64_t> indexes, int thread_id) {
1395 DimensionVector minor_scan_indexes(rank, 0);
1396 const int64_t index =
1397 IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1398 std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1399 for (int64_t i = 0; i < minor_dimension_size; ++i) {
1400 minor_scan_indexes[stride_config.minor_dimension] = i;
1401 literal_data.at(index + i) = generator(minor_scan_indexes, thread_id);
1402 }
1403 };
1404 if (parallel) {
1405 ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1406 stride_config.dimensions,
1407 stride_config.step, init_function);
1408 } else {
1409 ShapeUtil::ForEachIndex(
1410 this_shape, stride_config.base, stride_config.dimensions,
1411 stride_config.step,
1412 [&init_function](absl::Span<const int64_t> indexes) {
1413 init_function(indexes, /*thread_id=*/-1);
1414 return true;
1415 });
1416 }
1417 } else {
1418 // For scalars.
1419 literal_data.at(0) = generator({}, /*thread_id=*/-1);
1420 }
1421 return OkStatus();
1422 }
1423 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1424 Status MutableLiteralBase::Populate(const FnType& generator) {
1425 return PopulateInternal<NativeT>(
1426 [&](absl::Span<const int64_t> indexes, int /*thread_id*/) {
1427 return generator(indexes);
1428 },
1429 /*parallel=*/false);
1430 }
1431
1432 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1433 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1434 return PopulateInternal<NativeT>(
1435 [&](absl::Span<const int64_t> indexes, int thread_id) {
1436 return generator(indexes, thread_id);
1437 },
1438 /*parallel=*/true);
1439 }
1440
1441 template <typename NativeT>
PopulateWithValue(NativeT value)1442 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1443 CHECK(shape().IsArray());
1444 CHECK_EQ(shape().element_type(),
1445 primitive_util::NativeToPrimitiveType<NativeT>());
1446 for (NativeT& element : data<NativeT>()) {
1447 element = value;
1448 }
1449 }
1450
1451 template <typename NativeT>
Replicate(int64_t times)1452 Literal LiteralBase::Replicate(int64_t times) const {
1453 DimensionVector bounds = {times};
1454 bounds.reserve(shape().dimensions_size() + 1);
1455 for (int64_t bound : shape().dimensions()) {
1456 bounds.push_back(bound);
1457 }
1458 Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1459 int64_t elements = ShapeUtil::ElementsIn(literal.shape());
1460 if (elements == 0) {
1461 return literal;
1462 }
1463
1464 DimensionVector output_indices(bounds.size(), 0);
1465 absl::Span<const int64_t> input_indices = output_indices;
1466 input_indices.remove_prefix(1);
1467
1468 bool done = false;
1469 while (!done) {
1470 const auto element = Get<NativeT>(input_indices);
1471 literal.Set<NativeT>(output_indices, element);
1472
1473 done = true;
1474 for (int n = 0; n < output_indices.size(); ++n) {
1475 ++output_indices[n];
1476 if (output_indices[n] < bounds[n]) {
1477 done = false;
1478 break;
1479 }
1480 output_indices[n] = 0;
1481 }
1482 }
1483 return literal;
1484 }
1485
1486 } // namespace xla
1487
1488 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1489