1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/array2d.h"
33 #include "tensorflow/compiler/xla/array3d.h"
34 #include "tensorflow/compiler/xla/array4d.h"
35 #include "tensorflow/compiler/xla/index_util.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/bitmap.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/types.h"
49
50 namespace xla {
51
52 // Forward declare Literal and LiteralSlice class to be used by the creation
53 // methods in the base class.
54 class Literal;
55 class LiteralSlice;
56
57 // Abstract base class for literals.
58 class LiteralBase {
59 public:
60 virtual ~LiteralBase() = 0;
61
62 // Literals are equal if they have compatible shapes and the same data
63 // values. Layout is not compared.
64 bool operator==(const LiteralBase& other) const;
65 bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66
67 // Returns the shape of the literal.
shape()68 const Shape& shape() const { return root_piece().subshape(); }
69
70 // Serialize to proto.
71 LiteralProto ToProto() const;
72
73 // Returns a Span of the array for this literal for the given NativeT
74 // (e.g., float). CHECKs if the subshape of the literal at the given
75 // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76 // to native type.
77 template <typename NativeT>
78 absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79
80 // Returns a const pointer to (or size of) the underlying buffer holding the
81 // array at the given shape index. CHECKs if the subshape of the literal at
82 // the given ShapeIndex is not array.
83 const void* untyped_data(const ShapeIndex& shape_index = {}) const;
84 int64 size_bytes(const ShapeIndex& shape_index = {}) const;
85
86 // Returns this literal's data as a string. This literal must be a rank-1 U8
87 // array.
88 string GetR1U8AsString() const;
89
90 // Returns a string representation of the literal value. The Shape of the
91 // literal is a prefix of the literal value in the string.
92
93 // Warning: this function can take minutes for multi-million
94 // element Literals.
95 string ToString() const;
96
97 // Similar to ToString, but return the result in a compact
98 // one-line form.
99 string ToStringOneline() const;
100
101 // Returns a string representation of the literal value which does *not*
102 // include the shape string.
103 string ToStringWithoutShape() const;
104
105 // Similar to ToStringWithoutShape, but return the result in a compact
106 // one-line form.
107 string ToStringWithoutShapeOneline() const;
108
// Returns a string representation of the literal value which includes the
// shape string with its layout.
111 string ToStringWithLayout() const;
112
113 // Gets an element in the literal at the given index. The multi_index is
114 // CHECKed against the dimension sizes.
115 template <typename NativeT>
116 NativeT Get(absl::Span<const int64> multi_index,
117 const ShapeIndex& shape_index) const;
118 // Overloads of Get for array literals. CHECKs if the literal is not
119 // array-shaped and dense.
120 template <typename NativeT>
121 NativeT Get(absl::Span<const int64> multi_index) const;
122
123 // Get the dynamic size on dim_index in the literal at the given shape_index.
124 int32 GetDynamicSize(int64 dim_index, const ShapeIndex& shape_index) const;
125 int32 GetDynamicSize(int64 dim_index) const;
126
127 // Returns the element value at index (0, ..., 0), however many zeroes are
128 // required for that index.
129 template <typename NativeT>
130 NativeT GetFirstElement() const;
131
132 // As above but returns any integer type casted to an int64.
133 absl::optional<int64> GetFirstInteger() const;
134
135 // As Get(), but determines the correct type and converts the value
136 // into text.
137 string GetAsString(absl::Span<const int64> multi_index,
138 const ShapeIndex& shape_index = {}) const;
139
140 // Return whether the value at the specified index is equal to the provided
141 // generic `value` (T must be an arithmetic type).
142 //
143 // Precondition: must be an array.
144 template <typename T>
145 typename std::enable_if<(std::is_arithmetic<T>::value ||
146 std::is_same<T, Eigen::half>::value ||
147 std::is_same<T, bfloat16>::value),
148 bool>::type
IsEqualAt(absl::Span<const int64> multi_index,T value)149 IsEqualAt(absl::Span<const int64> multi_index, T value) const {
150 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
151 return *as_s64 == value;
152 }
153 complex128 as_complex128 = *GetAsComplex128(multi_index);
154 return as_complex128.imag() == 0 && as_complex128.real() == value;
155 }
156
IsEqualAt(absl::Span<const int64> multi_index,complex128 value)157 bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
158 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
159 return *as_s64 == value.real() && value.imag() == 0;
160 }
161 auto as_complex128 = GetAsComplex128(multi_index);
162 return *as_complex128 == value;
163 }
164
165 // As Get(), but determines the correct type and converts the value into
166 // int64. This literal must be an array.
167 absl::optional<int64> GetIntegralAsS64(
168 absl::Span<const int64> multi_index) const;
169
170 // As Get(), but determines the correct type, and converts the value into
171 // double. This literal must be an array.
172 absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
173
174 // As Get(), but determines the correct type, and converts the value into
175 // complex128. All floating point types can be converted into complex128.
176 //
177 // This literal must be an array.
178 absl::optional<complex128> GetAsComplex128(
179 absl::Span<const int64> multi_index) const;
180
181 // Invokes the "per cell" callback for each element in the provided
182 // literal with the element's indices and a string representation of
183 // the element's value.
184 //
185 // This function is useful if you want a polymorphic representation
186 // of the tensor's elements (turning it to a string for something
187 // like representation in a protobuf).
188 //
189 // This literal must have a dense layout.
190 void EachCellAsString(
191 const std::function<void(absl::Span<const int64> indices,
192 const string& value)>& per_cell) const;
193 template <typename NativeT>
194 void EachCell(
195 std::function<void(absl::Span<const int64> indices, NativeT value)>
196 per_cell) const;
197
198 // Returns whether every element in this literal is equal to value.
199 //
200 // value is an int8 because we expect this to be called with small
201 // compile-time constants (0, -1, etc.) and so that whatever value you pass
202 // can be represented exactly by floating-point types as small as 16 bits.
203 //
204 // If value doesn't fit in this literal's type, returns false. Values of 1/0
205 // are considered equal to true/false; other values are not considered equal
206 // to true. Also if this literal is not array-shaped false is returned.
207 bool IsAll(int8 value) const;
208
// Like IsAll(int8), except we check whether the literal is
// equal to a particular floating-point number.
211 //
212 // If the literal is not a floating-point value, this always returns false.
213 //
214 // This casts value to the type of literal, then compares using ==. The usual
215 // admonishments about floating-point equality checks apply. We expect you to
216 // use this to check for values that can be expressed precisely as a float,
217 // e.g. -0.5. Also if this literal is not array-shaped false is returned.
218 bool IsAllFloat(float value) const;
219
// Like IsAll(int8), except we check whether the literal is
// equal to a particular complex number.
222 //
223 // If the literal is not a complex value, this always returns false.
224 //
225 // This casts value to the type of literal, then compares using ==. The usual
226 // admonishments about floating-point equality checks apply. We expect you to
227 // use this to check for complex values that can be expressed precisely as
228 // float pairs e.g. (-0.5, 1.0).
229 //
230 // This literal must have a dense layout.
231 bool IsAllComplex(complex64 value) const;
232
233 // Literal consists entirely of the first element of the literal.
234 bool IsAllFirst() const;
235
236 // Literal consists entirely of an iota.
237 bool IsR1Iota() const;
238
239 // Returns whether this literal is zero at the specified index. This literal
240 // must be an array with a dense layout.
241 bool IsZero(absl::Span<const int64> indices) const;
242
243 // Returns the count of the elements in the array at the given shape index in
244 // this literal.
245 int64 element_count(const ShapeIndex& index = {}) const {
246 if (index.empty()) {
247 // Common case, avoid GetSubshape().
248 return ShapeUtil::ElementsIn(shape());
249 }
250 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
251 }
252
253 // Compute a hash for this literal.
254 size_t Hash() const;
255
// Converts this literal to the given shape. Returns an error if the
// conversion is not possible.
258 StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
259
260 // Converts this literal to another primitive type using a bitcast
261 // conversion. The to and from primitive types must have the same bit
262 // width. Returns an error if the conversion is not possible. This literal
263 // must be array-shaped.
264 StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
265
266 // Converts this literal to another primitive type. Returns an error if the
267 // conversion is not possible. This literal must be array-shaped.
268 StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
269
270 // Clones the underlying buffers into a new Literal.
271 Literal Clone() const;
272
273 // TODO(b/67651157): The methods below which perform computation on Literals
274 // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
275 // evaluator code which operates on Literals.
276 //
277 // Creates a new value that has the equivalent value as this
278 // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
279 // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
280 // minor-to-major dimension layout and the value in the cell at any given
281 // logical index (i0, i1) will be the same.
282 //
283 // For tuple shaped literals, shape_index should be used to select the inner
284 // array that the new layout applies to.
285 //
286 // Note: this is useful when the client wants to ensure that a value placed in
287 // the XLA allocation tracker has a particular layout; for efficiency
288 // purposes or avoiding unimplemented operation/layout combinations.
289 Literal Relayout(const Layout& new_layout,
290 const ShapeIndex& shape_index = {}) const;
291
292 // An overload of Relayout which changes the layout of the entire shape rather
293 // than being limited to a single array within the shape.
294 Literal Relayout(const Shape& shape_with_layout) const;
295
296 // Generate a new literal whose static sizes are equal to the previous
297 // literal's dynamic sizes.
298 Literal ToStatic() const;
299
// Expand a static literal into a new literal with a bounded dynamic shape.
// The static dimensions of the original literal become dynamic dimensions of
// the new literal, where the argument `bounded_shape` becomes the bounded
// shape of the new literal.
304 //
305 // Precondition: bounded_shape.is_dynamic()
306 Literal ToBoundedDynamic(const Shape& bounded_shape) const;
307
308 // Creates a new literal by reshaping this literal to have the given
309 // dimensions. The total number of elements must not change; The
310 // implementation currently only supports monotonic dim0-major layouts.
311 // This literal must be an array.
312 StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
313
314 // Creates a new literal by broadcasting this literal with `dimensions` to
315 // yield a literal of shape `result_shape`.
316 StatusOr<Literal> Broadcast(const Shape& result_shape,
317 absl::Span<const int64> dimensions) const;
318
319 // Creates a new literal by reordering the dimensions of this literal.
320 // The given `permutation` must be a permutation of the dimension numbers
321 // in the original literal, and it specifies the order of the new dimensions
322 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
323 // For example, a transpose call on a literal of shape [3 x 8 x 4] and
324 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
325 // This literal must be an array.
326 Literal Transpose(absl::Span<const int64> permutation) const;
327
328 // Creates a sub-array from this literal by extracting the indices
329 // [start_index, limit_index) of each dimension. The result literal has the
330 // same rank and layout as for the given literal. The number of indices in
331 // start_indices and limit_indices must be the rank of the literal, and the
332 // indices follow the order of the dimensions.
333 // This literal must be an array.
334 Literal Slice(absl::Span<const int64> start_indices,
335 absl::Span<const int64> limit_indices) const;
336
337 // Creates a literal with a prepended dimension with bound "times"; e.g. a
338 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
339 // literal replicated four times.
340 // This literal must be an array.
341 template <typename NativeT>
342 Literal Replicate(int64 times) const;
343
344 // Creates a new Literal object with the shape specified as parameter.
345 // The content of the literal values is the default value of the primitive
346 // type of literal itself (0 for numeric types, and false for predicates).
347 //
// Note: It's an antipattern to use this method then immediately call
// MutableLiteralBase::Populate on the result (since that results in zero
// initialization, then reinitialization). Consider whether a call to
// absl::make_unique<Literal>(shape), followed by a call to
// MutableLiteralBase::Populate, can be used instead.
353 static Literal CreateFromShape(const Shape& shape);
354
355 protected:
356 // A data structure representing a subshape at a particular ShapeIndex within
357 // the literal. For array-shaped ShapeIndexes, this data structure holds the
358 // pointer to the memory allocated for the array data.
359 class Piece {
360 public:
361 // Returns the buffer holding the array data for this piece as an array
362 // slice. This piece must be array-shaped.
363 template <typename NativeT>
364 absl::Span<const NativeT> data() const;
365 template <typename NativeT>
366 absl::Span<NativeT> data();
367
368 // Returns the buffer holding the array data for this piece as a void*. This
369 // piece must be array-shaped.
370 void* untyped_data();
371 const void* untyped_data() const;
372
373 // Gets or sets an element in the array at the given index. The multi_index
374 // is CHECKed against the dimension sizes of the array. This piece must be
375 // array-shaped.
376 template <typename NativeT>
377 NativeT Get(absl::Span<const int64> index) const;
378 template <typename NativeT>
379 void Set(absl::Span<const int64> index, NativeT value);
380
381 int32 GetDynamicSize(int64 dim_index) const;
382 void SetDynamicSize(int64 dim_index, int32 size);
383 // Gets/sets the buffer holding the array data.
buffer()384 char* buffer() const { return buffer_; }
set_buffer(char * buffer)385 void set_buffer(char* buffer) { buffer_ = buffer; }
386
387 // Gets/sets the buffer holding dynamic sizes.
dynamic_size_buffer()388 int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
set_dynamic_size_buffer(int32 * dynamic_size_buffer)389 void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
390 dynamic_size_buffer_ = dynamic_size_buffer;
391 }
392
dynamic_size_buffer_bytes()393 int64 dynamic_size_buffer_bytes() const {
394 return subshape().dimensions_size() * sizeof(int32);
395 }
396
397 // Gets or sets the subshape of this piece. This reference points to a
398 // subshape within the shape in the containing Literal (Literal::shape_).
subshape()399 const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)400 void set_subshape(const Shape* subshape) { subshape_ = subshape; }
401
402 // Returns the size in bytes of the buffer holding the array data.
size_bytes()403 int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
404
405 // Returns the number of elements in this piece's array.
element_count()406 int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
407
408 // Returns the child piece at 'index' of this piece.
child(int64 index)409 Piece& child(int64 index) { return children_[index]; }
410
411 // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)412 void emplace_back(Piece child_piece) {
413 children_.emplace_back(std::move(child_piece));
414 }
415
416 // Returns the size of children pieces of this piece.
children_size()417 int64 children_size() { return children_.size(); }
418
419 // Visitor functions that recursively traverses the piece and calls the
420 // given function at each child piece. The function has the type:
421 // void (const ShapeIndex& index, const Piece& piece)
422 template <typename Fn>
ForEachSubpiece(const Fn & func)423 void ForEachSubpiece(const Fn& func) const {
424 ShapeIndex index;
425 return ForEachHelper(
426 [&func](const ShapeIndex& index, const Piece& piece) {
427 func(index, piece);
428 return Status::OK();
429 },
430 *this, &index)
431 .IgnoreError();
432 }
433 // Same as above, but the function has the type:
434 // Status (const ShapeIndex& index, const Piece& piece)
435 // The first non-OK return value is returned by the function.
436 template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)437 Status ForEachSubpieceWithStatus(const Fn& func) const {
438 ShapeIndex index;
439 return ForEachHelper(func, *this, &index);
440 }
441 // Same as above, but the function has the type:
442 // Bool (const ShapeIndex& index, const Piece& piece)
443 // The first non-true return value is returned by the function.
444 template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)445 bool ForEachSubpieceWithBool(const Fn& func) const {
446 ShapeIndex index;
447 return ForEachHelperBool(func, *this, &index);
448 }
449 // Same as above, but the function has the type:
450 // Void (const ShapeIndex& index, Piece& piece)
451 template <typename Fn>
ForEachMutableSubpiece(const Fn & func)452 void ForEachMutableSubpiece(const Fn& func) {
453 ShapeIndex index;
454 return ForEachMutableHelper(
455 [&func](const ShapeIndex& index, Piece* piece) {
456 func(index, piece);
457 return Status::OK();
458 },
459 const_cast<xla::LiteralBase::Piece*>(this), &index)
460 .IgnoreError();
461 }
462 // Same as above, but the function has the type:
463 // Status (const ShapeIndex& index, Piece& piece)
464 // The first non-OK return value is returned by the function.
465 template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)466 Status ForEachMutableSubpieceWithStatus(const Fn& func) {
467 ShapeIndex index;
468 return ForEachMutableHelper(
469 func, const_cast<xla::LiteralBase::Piece*>(this), &index);
470 }
471
472 // Returns true if this piece and 'other' contain the same data. This piece
473 // and 'other' must be array-shaped and compatible. If a literal has dynamic
474 // shape, comparison is done only for the valid elements.
475 bool EqualElements(const Piece& other) const;
476
477 // Returns true if this piece and other pieces have the same dynamic
478 // dimension sizes.
479 bool EqualDynamicSize(const Piece& other) const;
480
481 // Writes the shape and data (if array-shaped) into the given proto.
482 void WriteToProto(LiteralProto* proto) const;
483
484 // Copy the data from 'src' into this piece's buffer. Shapes of this piece
485 // and src must be compatible. If only_dynamic_bound is true, only elements
486 // within dynamic bounds will be copied.
487 Status CopyFrom(const Piece& src, bool only_dynamic_bound);
488
489 // Copies the data from the given proto into this piece. The shape of this
490 // piece must be equal (not just compatible) to the shape of the proto.
491 Status CopyFromProto(const LiteralProto& proto);
492
493 private:
494 // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
495 // The first non-OK (or non-true) value is returned by the function.
496 // The callable 'func' has the same signature as described above in
497 // ForEachSubpiece*.
498 template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)499 Status ForEachHelper(const Fn& func, const Piece& piece,
500 ShapeIndex* index) const {
501 TF_RETURN_IF_ERROR(func(*index, piece));
502 for (int64 i = 0; i < piece.children_.size(); ++i) {
503 index->push_back(i);
504 TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
505 index->pop_back();
506 }
507 return Status::OK();
508 }
509 template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)510 bool ForEachHelperBool(const Fn& func, const Piece& piece,
511 ShapeIndex* index) const {
512 if (!func(*index, piece)) {
513 return false;
514 }
515 for (int64 i = 0; i < piece.children_.size(); ++i) {
516 index->push_back(i);
517 if (!ForEachHelperBool(func, piece.children_[i], index)) {
518 return false;
519 }
520 index->pop_back();
521 }
522 return true;
523 }
524 template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)525 Status ForEachMutableHelper(const Fn& func, Piece* piece,
526 ShapeIndex* index) {
527 TF_RETURN_IF_ERROR(func(*index, piece));
528 for (int64 i = 0; i < piece->children_.size(); ++i) {
529 index->push_back(i);
530 TF_RETURN_IF_ERROR(
531 ForEachMutableHelper(func, &piece->children_[i], index));
532 index->pop_back();
533 }
534 return Status::OK();
535 }
536
537 // Recursive helper for EqualElements.
538 template <typename NativeT>
539 bool EqualElementsInternal(const Piece& other,
540 std::vector<int64>* multi_index) const;
541
542 // Internal helper to copy elements from another given piece
543 template <typename NativeT>
544 void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
545
546 // For array-shaped pieces, this is the buffer holding the literal data.
547 char* buffer_ = nullptr;
548
549 int32* dynamic_size_buffer_ = nullptr;
550
551 // The shape of piece. This points into the shape of the containing Literal
552 // (Literal::shape_).
553 const Shape* subshape_ = nullptr;
554
555 // Children pieces for tuple shaped pieces.
556 std::vector<Piece> children_ = {};
557 }; // class Piece
558
piece(const ShapeIndex & shape_index)559 const Piece& piece(const ShapeIndex& shape_index) const {
560 Piece* piece = &const_cast<Piece&>(root_piece());
561 for (const auto i : shape_index) {
562 DCHECK_GE(i, 0);
563 DCHECK_LT(i, piece->children_size());
564 piece = &piece->child(i);
565 }
566 return *piece;
567 }
568
569 // Returns the piece at the root of the shape.
570 virtual const Piece& root_piece() const = 0;
571
572 // LiteralSlice and Literal must access Pieces of other Literals.
573 friend class MutableLiteralBase;
574 friend class LiteralSlice;
575 friend class BorrowingLiteral;
576
577 private:
578 template <typename NativeT>
579 Literal SliceInternal(const Shape& result_shape,
580 absl::Span<const int64> start_indices) const;
581 };
582
583 // Abstract base class representing a mutable literal in XLA.
584 class MutableLiteralBase : public LiteralBase {
585 public:
586 virtual ~MutableLiteralBase() = 0;
587
588 // Returns a Span view of the array for this literal for the
589 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
590 // given ShapeIndex is not array. See primitive_util.h for the mapping from
591 // XLA type to native type.
592 template <typename NativeT>
593 absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
594 // Unhide const method from parent class.
595 using LiteralBase::data;
596
597 // TODO(b/67651157): Remove this accessor. Literal users should not be able to
598 // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()599 Shape* mutable_shape_do_not_use() { return shape_.get(); }
600
601 // Set the dynamic size on dim_index in the literal at the given shape_index.
602 void SetDynamicSize(int64 dim_index, const ShapeIndex& shape_index,
603 int32 size);
604 void SetDynamicSize(int64 dim_index, int32 size);
605
606 // Returns a pointer to the underlying buffer holding the array at the given
607 // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
608 // is not array.
609 void* untyped_data(const ShapeIndex& shape_index = {});
610 // Unhide const method from parent class.
611 using LiteralBase::untyped_data;
612
613 // Copy values from 'src_literal' rooted at 'src_shape_index' into this
614 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
615 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
616 // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
617 // is true, only elements within dynamic bounds will be copied.
618 Status CopyFrom(const LiteralSlice& src_literal,
619 const ShapeIndex& dest_shape_index = {},
620 const ShapeIndex& src_shape_index = {},
621 bool only_dynamic_bound = false);
622
623 // Copies the values from src_literal, starting at src_base shape indexes,
624 // to this literal, starting at dest_base, where the copy size in each
625 // dimension is specified by copy_size.
626 // The src_literal and this literal must have the same primitive type,
627 // src_base+copy_size must fit the source literal dimensions, as well as
628 // dest_base+copy_size must fit the destination literal dimensions.
629 // Note: if either src_literal or this literal contains dimensions with zero
630 // element, then copy_size must be 0 in these dimensions while the
631 // corresponding base indices being 0.
632 // This literal and 'src_literal' must be arrays.
633 Status CopySliceFrom(const LiteralSlice& src_literal,
634 absl::Span<const int64> src_base,
635 absl::Span<const int64> dest_base,
636 absl::Span<const int64> copy_size);
637
638 // Copies one element from src_literal[src_index] to (*this)[dest_index].
639 Status CopyElementFrom(const LiteralSlice& src_literal,
640 absl::Span<const int64> src_index,
641 absl::Span<const int64> dest_index);
642
643 // Sets an element in the literal at the given index. The multi_index is
644 // CHECKed against the dimension sizes.
645 template <typename NativeT>
646 void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
647 NativeT value);
648 // Overloads of Set for array literals. CHECKs if the literal is not
649 // array-shaped and dense.
650 template <typename NativeT>
651 void Set(absl::Span<const int64> multi_index, NativeT value);
652
653 // As Set(), but truncates `value` to the literal element type before storing.
654 // This literal must be an array.
655 Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
656
657 // As Set(), but truncates `value` to the literal element type before storing.
658 // This literal must be an array.
659 Status SetFromDouble(absl::Span<const int64> multi_index, double value);
660
661 // Populate this literal with the given values. Examples:
662 //
663 // // Populate with floats.
664 // Array2D<float> float_values = ...
//   literal.PopulateR2FromArray2D(float_values);
666 //
667 // // Populate with int32s.
668 // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
669 //
670 // The shape and element type of this literal must match given values. For
671 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
672 // array of S32.
673 template <typename NativeT>
674 void PopulateR1(absl::Span<const NativeT> values);
675 void PopulateR1(const tensorflow::core::Bitmap& values);
676 template <typename NativeT>
677 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
678 template <typename NativeT>
679 void PopulateFromArray(const Array<NativeT>& values);
680 template <typename NativeT>
681 void PopulateR2FromArray2D(const Array2D<NativeT>& values);
682 template <typename NativeT>
683 void PopulateR3FromArray3D(const Array3D<NativeT>& values);
684 template <typename NativeT>
685 void PopulateR4FromArray4D(const Array4D<NativeT>& values);
686
687 // Populates literal values by calling the generator function for every cell
688 // in this literal object.
689 //
690 // generator must be a callable of the type
691 // NativeT(absl::Span<int64> indexes) or compatible.
692 //
693 // This literal must have a dense layout.
694 template <typename NativeT, typename FnType>
695 Status Populate(const FnType& generator);
696
697 // A parallel version of Populate(). This can be used if the generator is
698 // thread-safe and the values for the shape's different elements are
699 // independent.
700 template <typename NativeT, typename FnType>
701 Status PopulateParallel(const FnType& generator);
702
703 // Fills this literal with the given value.
704 template <typename NativeT>
705 void PopulateWithValue(NativeT value);
706
707 // This operation is the inverse of DecomposeTuple. The given elements are
708 // moved into the tuple elements of a new tuple-shaped Literal which is
709 // returned. Upon return, each of the Literals in 'elements' is set to a nil
710 // shape (empty tuple).
711 static Literal MoveIntoTuple(absl::Span<Literal> elements);
712
713 // Serialize from a proto.
714 static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
715 bool prohibit_empty_literal = true);
716
717 protected:
718 // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)719 Piece& piece(const ShapeIndex& shape_index) {
720 return const_cast<Piece&>(LiteralBase::piece(shape_index));
721 }
722
root_piece()723 Piece& root_piece() const override { return *root_piece_; };
724
  // Internal template helper for the Literal::CopySliceFrom(), matching its
  // arguments one by one. Defined out of line.
  // NOTE(review): NativeT is presumably the native element type of the
  // literals involved, and the returned Status presumably reports invalid
  // arguments -- confirm against the definition.
  template <typename NativeT>
  Status CopySliceFromInternal(const LiteralBase& src_literal,
                               absl::Span<const int64> src_base,
                               absl::Span<const int64> dest_base,
                               absl::Span<const int64> copy_size);
732
  // Utility structure which is used to create the optimal configuration for
  // a ShapeUtil::ForEachIndex() scan across two literals.
  //
  // All fields other than 'dimensions', 'base' and 'step' carry in-class
  // defaults; the constructor is defined out of line.
  struct StrideConfig {
    StrideConfig(const Shape& source_shape, const Shape& dest_shape,
                 absl::Span<const int64> dimensions);

    // The dimensions of the stride operation. Essentially every dimension
    // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
    // steps.
    // NOTE(review): an absl::Span is a non-owning view, so the storage it
    // refers to has to outlive this StrideConfig.
    absl::Span<const int64> dimensions;
    // Per-dimension start index of the scan (see 'dimensions' above).
    DimensionVector base;
    // Per-dimension increment of the scan (see 'dimensions' above).
    DimensionVector step;
    // Index of the dimension that is handled by the inner loop of the scan.
    int64 minor_dimension = 0;
    // The size of the strides for source and destination. One of the two
    // (the one looping through its most minor dimension) will be 1, while
    // the other will be the stride size at the dimension matching the other
    // shape most minor dimension being scanned.
    int64 dest_stride = 1;
    int64 source_stride = 1;
    // The size of the inner loop on the most minor dimension.
    int64 minor_loop_size = 1;
  };
755
  // Literal class always owns the shape. The parent class borrows this shape.
  std::unique_ptr<Shape> shape_;

  // Root of the piece tree; root_piece() returns a reference to the pointee.
  // Starts out null.
  Piece* root_piece_ = nullptr;
760
  // Implementation details shared between Populate() and PopulateParallel().
  // 'parallel' selects how the cells are visited: ShapeUtil::ForEachIndex()
  // when false, ShapeUtil::ForEachIndexParallel() when true.
  template <typename NativeT, typename FnType>
  Status PopulateInternal(const FnType& generator, bool parallel);
764
765 friend class LiteralBase;
766 friend class MutableBorrowingLiteral;
767 };
// Stream insertion operator for Literal. Defined out of line.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
769
// The underlying buffer and shape is always owned by this class.
class Literal : public MutableLiteralBase {
 public:
  // Creates a literal with a nil shape (empty tuple).
  Literal() : Literal(ShapeUtil::MakeNil()) {}

  // Create a literal of the given shape. The literal is allocated sufficient
  // memory to hold the shape. Memory is uninitialized.
  explicit Literal(const Shape& shape);
  virtual ~Literal();

  // Literals are moveable, but not copyable. To copy a literal use
  // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
  // of literals which can be expensive.
  Literal(const Literal& other) = delete;
  Literal& operator=(const Literal& other) = delete;
  Literal(Literal&& other);
  // 'allocate_arrays' indicates whether to allocate memory for the arrays in
  // the shape. If false, buffer pointers inside of the Literal::Pieces are set
  // to nullptr.
  Literal(const Shape& shape, bool allocate_arrays);
  Literal& operator=(Literal&& other);

  // Similar to CopyFrom, but with move semantics. The subshape of this literal
  // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
  // (layouts and shapes must match), but need not be arrays. The memory
  // allocated in this literal for the subshape at dest_shape_index is
  // deallocated, and the respective buffers are replaced with those in
  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
  // By default ('dest_shape_index' = {}) the root of this literal is the
  // destination.
  virtual Status MoveFrom(Literal&& src_literal,
                          const ShapeIndex& dest_shape_index = {});

  // Returns a vector containing the tuple elements of this Literal as separate
  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
  // elements are moved into the new Literals; no data is copied. Upon return
  // this Literal is set to a nil shape (empty tuple).
  std::vector<Literal> DecomposeTuple();

 private:
  // Deallocate the buffers held by this literal.
  void DeallocateBuffers();

  // Recursively sets the subshapes and buffers of all subpieces rooted at
  // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays
  // in the shape.
  void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};
816
// The underlying buffer is not owned by this class and is always owned by
// others. The shape is not owned by this class and not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
 public:
  virtual ~MutableBorrowingLiteral();

  // Creates a literal that does not yet refer to any buffer.
  MutableBorrowingLiteral() : MutableLiteralBase() {}

  // Unlike Literal, this class is copyable; since buffers are never owned
  // here, copies refer to buffers owned by others as well.
  MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
  MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);

  // Implicit conversion constructors.
  MutableBorrowingLiteral(MutableLiteralBase* literal);
  // NOTE(review): 'view_root' presumably selects the sub-literal of 'literal'
  // that becomes the root of the new view -- confirm against the definition.
  MutableBorrowingLiteral(MutableBorrowingLiteral literal,
                          const ShapeIndex& view_root);
  MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);

  // Create a literal from a list of buffers and a shape.
  // Returns a tuple literal if `shape` is a tuple type.
  MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);

 private:
  // Recursively copies the subtree from the `src_piece` at the given child
  // index to the `dest_piece`. For buffers only the pointers are copied, but
  // not the content.
  // NOTE(review): the signature has no child-index parameter, so the mention
  // of a "given child index" above looks stale -- confirm against the
  // definition.
  void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
                        Piece* dest_piece);
};
845
846 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
847 // literal buffers always owned by others.
848 class LiteralSlice : public LiteralBase {
849 public:
LiteralSlice()850 LiteralSlice() : LiteralBase() {}
851
852 // Implicit conversion constructors.
853 LiteralSlice(const LiteralBase& literal);
854 LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
855
856 private:
root_piece()857 const Piece& root_piece() const override { return *root_piece_; };
858
859 const Piece* root_piece_; // Not owned.
860 };
861
// A read-only Literal where the underlying buffers are never owned by this
// class.
class BorrowingLiteral : public LiteralBase {
 public:
  // Creates a literal that does not refer to any buffer; shape_ is left null.
  BorrowingLiteral() : LiteralBase() {}

  // 'src_buf_ptr' is not owned by this class and must outlive the
  // lifetime of this class. It points to an appropriately sized buffer with
  // data interpreted as indicated by 'shape'.
  // This constructor is only used for array shapes.
  BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
  // Similar as above, except to be used for constructing non-nested tuples.
  BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                   const Shape& shape);
  // TODO(b/79707221): adding constructors for nested tuples as well.

 private:
  // Recursively builds the subtree for the given piece and sets the subshapes
  // of the given piece with the given shape.
  void BuildPieceSubtree(const Shape& shape, Piece* piece);

  // Accessor for the root piece of this literal. The root piece is stored by
  // value in this object.
  const Piece& root_piece() const override { return root_piece_; };
  Piece root_piece_;

  // Shape of this literal. Stored as unique_ptr such that the (default) move
  // construction of this class would be trivially correct: the pointer to Shape
  // root_piece_ stores will still point to the correct address.
  std::unique_ptr<Shape> shape_;
};
892
893 template <typename NativeT>
data()894 absl::Span<const NativeT> LiteralBase::Piece::data() const {
895 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
896 DCHECK_EQ(subshape().element_type(),
897 primitive_util::NativeToPrimitiveType<NativeT>())
898 << "Attempting to access "
899 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
900 << " type, but literal element type is "
901 << PrimitiveType_Name(subshape().element_type());
902 return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
903 element_count());
904 }
905
906 template <typename NativeT>
data()907 absl::Span<NativeT> LiteralBase::Piece::data() {
908 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
909 DCHECK_EQ(subshape().element_type(),
910 primitive_util::NativeToPrimitiveType<NativeT>())
911 << "Attempting to access "
912 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
913 << " type, but literal element type is "
914 << PrimitiveType_Name(subshape().element_type());
915 return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
916 element_count());
917 }
918
919 template <typename NativeT>
Get(absl::Span<const int64> multi_index)920 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
921 CHECK(LayoutUtil::IsDenseArray(subshape()));
922 return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
923 subshape(), multi_index)];
924 }
925
926 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)927 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
928 NativeT value) {
929 CHECK(LayoutUtil::IsDenseArray(subshape()));
930 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
931 subshape(), multi_index)] = value;
932 }
933
934 template <typename NativeT>
data(const ShapeIndex & shape_index)935 absl::Span<const NativeT> LiteralBase::data(
936 const ShapeIndex& shape_index) const {
937 return piece(shape_index).data<NativeT>();
938 }
939
940 template <typename NativeT>
data(const ShapeIndex & shape_index)941 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
942 return piece(shape_index).data<NativeT>();
943 }
944
945 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)946 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
947 const ShapeIndex& shape_index) const {
948 return piece(shape_index).Get<NativeT>(multi_index);
949 }
950
951 template <typename NativeT>
Get(absl::Span<const int64> multi_index)952 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
953 return root_piece().Get<NativeT>(multi_index);
954 }
955
956 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)957 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
958 const ShapeIndex& shape_index,
959 NativeT value) {
960 return piece(shape_index).Set<NativeT>(multi_index, value);
961 }
962
963 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)964 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
965 NativeT value) {
966 return root_piece().Set<NativeT>(multi_index, value);
967 }
968
969 template <typename NativeT>
GetFirstElement()970 NativeT LiteralBase::GetFirstElement() const {
971 return data<NativeT>().at(0);
972 }
973
974 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)975 void LiteralBase::EachCell(
976 std::function<void(absl::Span<const int64> indices, NativeT value)>
977 per_cell) const {
978 if (ShapeUtil::IsZeroElementArray(shape())) {
979 return;
980 }
981 std::vector<int64> indices(shape().rank(), 0);
982
983 Shape shape_dynamic = shape();
984 for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
985 shape_dynamic.set_dimensions(i, GetDynamicSize(i));
986 }
987 do {
988 per_cell(indices, Get<NativeT>(indices));
989 } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
990 }
991
992 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)993 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
994 CHECK(shape().IsArray());
995 CHECK_EQ(shape().rank(), 1);
996 CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
997 CHECK_EQ(shape().element_type(),
998 primitive_util::NativeToPrimitiveType<NativeT>());
999 auto data_span = data<NativeT>();
1000 std::copy(values.begin(), values.end(), data_span.begin());
1001 }
1002
1003 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1004 void MutableLiteralBase::PopulateR2(
1005 std::initializer_list<std::initializer_list<NativeT>> values) {
1006 CHECK(shape().IsArray());
1007 CHECK_EQ(shape().rank(), 2);
1008 CHECK_EQ(shape().element_type(),
1009 primitive_util::NativeToPrimitiveType<NativeT>());
1010
1011 const int64 dim0_size = values.size();
1012 const int64 dim1_size = values.begin()->size();
1013 CHECK_EQ(dim0_size, shape().dimensions(0));
1014 CHECK_EQ(dim1_size, shape().dimensions(1));
1015
1016 int64 dim0 = 0;
1017 for (auto inner_list : values) {
1018 int64 dim1 = 0;
1019 for (auto value : inner_list) {
1020 Set({dim0, dim1}, value);
1021 ++dim1;
1022 }
1023 CHECK_EQ(dim1_size, dim1);
1024 ++dim0;
1025 }
1026 }
1027
1028 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1029 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1030 CHECK(shape().IsArray());
1031 CHECK_EQ(shape().element_type(),
1032 primitive_util::NativeToPrimitiveType<NativeT>());
1033 CHECK_EQ(shape().rank(), values.num_dimensions());
1034 for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1035 CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1036 }
1037 values.Each([this](absl::Span<const int64> indices, NativeT value) {
1038 this->Set(indices, value);
1039 });
1040 }
1041
1042 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1043 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1044 PopulateFromArray(values);
1045 }
1046
1047 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1048 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1049 PopulateFromArray(values);
1050 }
1051
1052 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1053 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1054 PopulateFromArray(values);
1055 }
1056
1057 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)1058 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
1059 bool parallel) {
1060 const Shape& this_shape = shape();
1061 const int64 rank = this_shape.rank();
1062 TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1063 TF_RET_CHECK(this_shape.element_type() ==
1064 primitive_util::NativeToPrimitiveType<NativeT>());
1065 absl::Span<NativeT> literal_data = data<NativeT>();
1066 if (rank > 0) {
1067 StrideConfig stride_config(this_shape, this_shape,
1068 AsInt64Slice(this_shape.dimensions()));
1069 int64 minor_dimension_size =
1070 ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1071
1072 auto init_function = [&](absl::Span<const int64> indexes) {
1073 DimensionVector minor_scan_indexes(rank, 0);
1074 const int64 index =
1075 IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1076 std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1077 for (int64 i = 0; i < minor_dimension_size; ++i) {
1078 minor_scan_indexes[stride_config.minor_dimension] = i;
1079 literal_data.at(index + i) = generator(minor_scan_indexes);
1080 }
1081 };
1082 if (parallel) {
1083 ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1084 stride_config.dimensions,
1085 stride_config.step, init_function);
1086 } else {
1087 ShapeUtil::ForEachIndex(
1088 this_shape, stride_config.base, stride_config.dimensions,
1089 stride_config.step,
1090 [&init_function](absl::Span<const int64> indexes) {
1091 init_function(indexes);
1092 return true;
1093 });
1094 }
1095 } else {
1096 // For scalars.
1097 literal_data.at(0) = generator({});
1098 }
1099 return Status::OK();
1100 }
1101 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1102 Status MutableLiteralBase::Populate(const FnType& generator) {
1103 return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1104 }
1105
1106 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1107 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1108 return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1109 }
1110
1111 template <typename NativeT>
PopulateWithValue(NativeT value)1112 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1113 CHECK(shape().IsArray());
1114 CHECK_EQ(shape().element_type(),
1115 primitive_util::NativeToPrimitiveType<NativeT>());
1116 for (NativeT& element : data<NativeT>()) {
1117 element = value;
1118 }
1119 }
1120
1121 template <typename NativeT>
Replicate(int64 times)1122 Literal LiteralBase::Replicate(int64 times) const {
1123 DimensionVector bounds = {times};
1124 bounds.reserve(shape().dimensions_size() + 1);
1125 for (int64 bound : shape().dimensions()) {
1126 bounds.push_back(bound);
1127 }
1128 Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1129 int64 elements = ShapeUtil::ElementsIn(literal.shape());
1130 if (elements == 0) {
1131 return literal;
1132 }
1133
1134 DimensionVector output_indices(bounds.size(), 0);
1135 absl::Span<const int64> input_indices = output_indices;
1136 input_indices.remove_prefix(1);
1137
1138 bool done = false;
1139 while (!done) {
1140 const auto element = Get<NativeT>(input_indices);
1141 literal.Set<NativeT>(output_indices, element);
1142
1143 done = true;
1144 for (int n = 0; n < output_indices.size(); ++n) {
1145 ++output_indices[n];
1146 if (output_indices[n] < bounds[n]) {
1147 done = false;
1148 break;
1149 }
1150 output_indices[n] = 0;
1151 }
1152 }
1153 return literal;
1154 }
1155
1156 } // namespace xla
1157
1158 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1159