1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/array2d.h"
33 #include "tensorflow/compiler/xla/array3d.h"
34 #include "tensorflow/compiler/xla/array4d.h"
35 #include "tensorflow/compiler/xla/index_util.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/bitmap.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/types.h"
49
50 namespace xla {
51
52 // Forward declare Literal and LiteralSlice class to be used by the creation
53 // methods in the base class.
54 class Literal;
55 class LiteralSlice;
56
57 // Abstract base class for literals.
58 class LiteralBase {
59 public:
60 virtual ~LiteralBase() = 0;
61
62 // Literals are equal if they have compatible shapes and the same data
63 // values. Layout is not compared.
64 bool operator==(const LiteralBase& other) const;
65 bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66
67 // Returns the shape of the literal.
shape()68 const Shape& shape() const { return root_piece().subshape(); }
69
70 // Serialize to proto.
71 LiteralProto ToProto() const;
72
73 // Returns a Span of the array for this literal for the given NativeT
74 // (e.g., float). CHECKs if the subshape of the literal at the given
75 // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76 // to native type.
77 template <typename NativeT>
78 absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79
80 // Returns a const pointer to (or size of) the underlying buffer holding the
81 // array at the given shape index. CHECKs if the subshape of the literal at
82 // the given ShapeIndex is not array.
83 const void* untyped_data(const ShapeIndex& shape_index = {}) const;
84 int64 size_bytes(const ShapeIndex& shape_index = {}) const;
85
86 // Returns this literal's data as a string. This literal must be a rank-1 U8
87 // array.
88 string GetR1U8AsString() const;
89
90 // Returns a string representation of the literal value. The Shape of the
91 // literal is a prefix of the literal value in the string.
92
93 // Warning: this function can take minutes for multi-million
94 // element Literals.
95 string ToString() const;
96
97 // Similar to ToString, but return the result in a compact
98 // one-line form.
99 string ToStringOneline() const;
100
101 // Returns a string representation of the literal value which does *not*
102 // include the shape string.
103 string ToStringWithoutShape() const;
104
105 // Similar to ToStringWithoutShape, but return the result in a compact
106 // one-line form.
107 string ToStringWithoutShapeOneline() const;
108
109 // Returns a string representation of the literal value which includes the
110 // shape string with its layout.does *not* include the shape string.
111 string ToStringWithLayout() const;
112
113 // Similar to ToStringWithLayout, but return the result in a compact
114 // one-line form.
115 string ToStringWithLayoutOneline() const;
116
117 // Gets an element in the literal at the given index. The multi_index is
118 // CHECKed against the dimension sizes.
119 template <typename NativeT>
120 NativeT Get(absl::Span<const int64> multi_index,
121 const ShapeIndex& shape_index) const;
122 // Overloads of Get for array literals. CHECKs if the literal is not
123 // array-shaped and dense.
124 template <typename NativeT>
125 NativeT Get(absl::Span<const int64> multi_index) const;
126
127 // Get the dynamic size on dim_index in the literal at the given shape_index.
128 int32 GetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index) const;
129 int32 GetDynamicSize(int64_t dim_index) const;
130
131 // Returns the element value at index (0, ..., 0), however many zeroes are
132 // required for that index.
133 template <typename NativeT>
134 NativeT GetFirstElement() const;
135
136 // As above but returns any integer type casted to an int64.
137 absl::optional<int64> GetFirstInteger() const;
138
139 // As Get(), but determines the correct type and converts the value
140 // into text.
141 string GetAsString(absl::Span<const int64> multi_index,
142 const ShapeIndex& shape_index = {}) const;
143
144 // Return whether the value at the specified index is equal to the provided
145 // generic `value` (T must be an arithmetic type).
146 //
147 // Precondition: must be an array.
148 template <typename T>
149 typename std::enable_if<(std::is_arithmetic<T>::value ||
150 std::is_same<T, Eigen::half>::value ||
151 std::is_same<T, bfloat16>::value),
152 bool>::type
IsEqualAt(absl::Span<const int64> multi_index,T value)153 IsEqualAt(absl::Span<const int64> multi_index, T value) const {
154 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
155 return *as_s64 == value;
156 }
157 complex128 as_complex128 = *GetAsComplex128(multi_index);
158 return as_complex128.imag() == 0 && as_complex128.real() == value;
159 }
160
IsEqualAt(absl::Span<const int64> multi_index,complex128 value)161 bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
162 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
163 return *as_s64 == value.real() && value.imag() == 0;
164 }
165 auto as_complex128 = GetAsComplex128(multi_index);
166 return *as_complex128 == value;
167 }
168
169 // As Get(), but determines the correct type and converts the value into
170 // int64. This literal must be an array.
171 absl::optional<int64> GetIntegralAsS64(
172 absl::Span<const int64> multi_index) const;
173
174 // As Get(), but determines the correct type, and converts the value into
175 // double. This literal must be an array.
176 absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
177
178 // As Get(), but determines the correct type, and converts the value into
179 // complex128. All floating point types can be converted into complex128.
180 //
181 // This literal must be an array.
182 absl::optional<complex128> GetAsComplex128(
183 absl::Span<const int64> multi_index) const;
184
185 // Invokes the "per cell" callback for each element in the provided
186 // literal with the element's indices and a string representation of
187 // the element's value.
188 //
189 // This function is useful if you want a polymorphic representation
190 // of the tensor's elements (turning it to a string for something
191 // like representation in a protobuf).
192 //
193 // This literal must have a dense layout.
194 void EachCellAsString(
195 const std::function<void(absl::Span<const int64> indices,
196 const string& value)>& per_cell) const;
197 template <typename NativeT>
198 void EachCell(
199 std::function<void(absl::Span<const int64> indices, NativeT value)>
200 per_cell) const;
201
202 // Returns whether every element in this literal is equal to value.
203 //
204 // value is an int8 because we expect this to be called with small
205 // compile-time constants (0, -1, etc.) and so that whatever value you pass
206 // can be represented exactly by floating-point types as small as 16 bits.
207 //
208 // If value doesn't fit in this literal's type, returns false. Values of 1/0
209 // are considered equal to true/false; other values are not considered equal
210 // to true. Also if this literal is not array-shaped false is returned.
211 bool IsAll(int8_t value) const;
212
213 // Like IsAll(const Literal&, int8), except we check whether the literal is
214 // equal to a particular floating-point number.
215 //
216 // If the literal is not a floating-point value, this always returns false.
217 //
218 // This casts value to the type of literal, then compares using ==. The usual
219 // admonishments about floating-point equality checks apply. We expect you to
220 // use this to check for values that can be expressed precisely as a float,
221 // e.g. -0.5. Also if this literal is not array-shaped false is returned.
222 bool IsAllFloat(float value) const;
223
224 // Like IsAll(const Literal&, int8), except we check whether the literal is
225 // equal to a particular complex number.
226 //
227 // If the literal is not a complex value, this always returns false.
228 //
229 // This casts value to the type of literal, then compares using ==. The usual
230 // admonishments about floating-point equality checks apply. We expect you to
231 // use this to check for complex values that can be expressed precisely as
232 // float pairs e.g. (-0.5, 1.0).
233 //
234 // This literal must have a dense layout.
235 bool IsAllComplex(complex64 value) const;
236
237 // Literal consists entirely of the first element of the literal.
238 bool IsAllFirst() const;
239
240 // Literal consists entirely of an iota.
241 bool IsR1Iota() const;
242
243 // Returns the stride if the literal is a strided iota.
244 absl::optional<int64> IsR1StridedIota() const;
245
246 // Returns whether this literal is zero at the specified index. This literal
247 // must be an array with a dense layout.
248 bool IsZero(absl::Span<const int64> indices) const;
249
250 // Returns the count of the elements in the array at the given shape index in
251 // this literal.
252 int64 element_count(const ShapeIndex& index = {}) const {
253 if (index.empty()) {
254 // Common case, avoid GetSubshape().
255 return ShapeUtil::ElementsIn(shape());
256 }
257 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
258 }
259
260 // Compute a hash for this literal.
261 size_t Hash() const;
262
263 // Converts this literal to the given shape. Returns an error is the
264 // conversion is not possible.
265 StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
266
267 // Converts this literal to another primitive type using a bitcast
268 // conversion. The to and from primitive types must have the same bit
269 // width. Returns an error if the conversion is not possible. This literal
270 // must be array-shaped.
271 StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
272
273 // Converts this literal to another primitive type. Returns an error if the
274 // conversion is not possible. This literal must be array-shaped.
275 StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
276
277 // Clones the underlying buffers into a new Literal.
278 Literal Clone() const;
279
280 // TODO(b/67651157): The methods below which perform computation on Literals
281 // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
282 // evaluator code which operates on Literals.
283 //
284 // Creates a new value that has the equivalent value as this
285 // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
286 // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
287 // minor-to-major dimension layout and the value in the cell at any given
288 // logical index (i0, i1) will be the same.
289 //
290 // For tuple shaped literals, shape_index should be used to select the inner
291 // array that the new layout applies to.
292 //
293 // Note: this is useful when the client wants to ensure that a value placed in
294 // the XLA allocation tracker has a particular layout; for efficiency
295 // purposes or avoiding unimplemented operation/layout combinations.
296 Literal Relayout(const Layout& new_layout,
297 const ShapeIndex& shape_index = {}) const;
298
299 // An overload of Relayout which changes the layout of the entire shape rather
300 // than being limited to a single array within the shape.
301 Literal Relayout(const Shape& shape_with_layout) const;
302
303 // Generate a new literal whose static sizes are equal to the previous
304 // literal's dynamic sizes.
305 Literal ToStatic() const;
306
307 // Expand a static literal into a new one with a bounded dyanmic literal. The
308 // static dimensions of the original literal becomes dynamic dimensions of the
309 // new literal, where the argument `bounded_shape` becomes the bounded shape
310 // of the new literal.
311 //
312 // Precondition: bounded_shape.is_dynamic()
313 Literal ToBoundedDynamic(const Shape& bounded_shape) const;
314
315 // Creates a new literal by reshaping this literal to have the given
316 // dimensions. The total number of elements must not change; The
317 // implementation currently only supports monotonic dim0-major layouts.
318 // This literal must be an array.
319 StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
320
321 // Creates a new literal by broadcasting this literal with `dimensions` to
322 // yield a literal of shape `result_shape`.
323 StatusOr<Literal> Broadcast(const Shape& result_shape,
324 absl::Span<const int64> dimensions) const;
325
326 // Creates a new literal by reordering the dimensions of this literal.
327 // The given `permutation` must be a permutation of the dimension numbers
328 // in the original literal, and it specifies the order of the new dimensions
329 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
330 // For example, a transpose call on a literal of shape [3 x 8 x 4] and
331 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
332 // This literal must be an array.
333 Literal Transpose(absl::Span<const int64> permutation) const;
334
335 // Creates a sub-array from this literal by extracting the indices
336 // [start_index, limit_index) of each dimension. The result literal has the
337 // same rank and layout as for the given literal. The number of indices in
338 // start_indices and limit_indices must be the rank of the literal, and the
339 // indices follow the order of the dimensions.
340 // This literal must be an array.
341 Literal Slice(absl::Span<const int64> start_indices,
342 absl::Span<const int64> limit_indices) const;
343
344 // Creates a literal with a prepended dimension with bound "times"; e.g. a
345 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
346 // literal replicated four times.
347 // This literal must be an array.
348 template <typename NativeT>
349 Literal Replicate(int64_t times) const;
350
351 // Creates a new Literal object with the shape specified as parameter.
352 // The content of the literal values is the default value of the primitive
353 // type of literal itself (0 for numeric types, and false for predicates).
354 //
355 // Note: It's an antipattern to use this method then immediately call
356 // MutableLiteralBase::Populate on the result (since that results in zero
357 // initialization, then reinitialization. Consider if a call to
358 // absl::make_unique<Literal>(shape), followed by the call to
359 // MutableLiteralBase::Populate can be used instead.
360 static Literal CreateFromShape(const Shape& shape);
361
362 protected:
363 // A data structure representing a subshape at a particular ShapeIndex within
364 // the literal. For array-shaped ShapeIndexes, this data structure holds the
365 // pointer to the memory allocated for the array data.
366 class Piece {
367 public:
368 // Returns the buffer holding the array data for this piece as an array
369 // slice. This piece must be array-shaped.
370 template <typename NativeT>
371 absl::Span<const NativeT> data() const;
372 template <typename NativeT>
373 absl::Span<NativeT> data();
374
375 // Returns the buffer holding the array data for this piece as a void*. This
376 // piece must be array-shaped.
377 void* untyped_data();
378 const void* untyped_data() const;
379
380 // Gets or sets an element in the array at the given index. The multi_index
381 // is CHECKed against the dimension sizes of the array. This piece must be
382 // array-shaped.
383 template <typename NativeT>
384 NativeT Get(absl::Span<const int64> index) const;
385 template <typename NativeT>
386 void Set(absl::Span<const int64> index, NativeT value);
387
388 int32 GetDynamicSize(int64_t dim_index) const;
389 void SetDynamicSize(int64_t dim_index, int32_t size);
390 // Gets/sets the buffer holding the array data.
buffer()391 char* buffer() const { return buffer_; }
set_buffer(char * buffer)392 void set_buffer(char* buffer) { buffer_ = buffer; }
393
394 // Gets/sets the buffer holding dynamic sizes.
dynamic_size_buffer()395 int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
set_dynamic_size_buffer(int32 * dynamic_size_buffer)396 void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
397 dynamic_size_buffer_ = dynamic_size_buffer;
398 }
399
dynamic_size_buffer_bytes()400 int64 dynamic_size_buffer_bytes() const {
401 return subshape().dimensions_size() * sizeof(int32);
402 }
403
404 // Gets or sets the subshape of this piece. This reference points to a
405 // subshape within the shape in the containing Literal (Literal::shape_).
subshape()406 const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)407 void set_subshape(const Shape* subshape) { subshape_ = subshape; }
408
409 // Returns the size in bytes of the buffer holding the array data.
size_bytes()410 int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
411
412 // Returns the number of elements in this piece's array.
element_count()413 int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
414
415 // Returns the child piece at 'index' of this piece.
child(int64_t index)416 Piece& child(int64_t index) { return children_[index]; }
417
418 // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)419 void emplace_back(Piece child_piece) {
420 children_.emplace_back(std::move(child_piece));
421 }
422
423 // Returns the size of children pieces of this piece.
children_size()424 int64 children_size() { return children_.size(); }
425
426 // Visitor functions that recursively traverses the piece and calls the
427 // given function at each child piece. The function has the type:
428 // void (const ShapeIndex& index, const Piece& piece)
429 template <typename Fn>
ForEachSubpiece(const Fn & func)430 void ForEachSubpiece(const Fn& func) const {
431 ShapeIndex index;
432 return ForEachHelper(
433 [&func](const ShapeIndex& index, const Piece& piece) {
434 func(index, piece);
435 return Status::OK();
436 },
437 *this, &index)
438 .IgnoreError();
439 }
440 // Same as above, but the function has the type:
441 // Status (const ShapeIndex& index, const Piece& piece)
442 // The first non-OK return value is returned by the function.
443 template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)444 Status ForEachSubpieceWithStatus(const Fn& func) const {
445 ShapeIndex index;
446 return ForEachHelper(func, *this, &index);
447 }
448 // Same as above, but the function has the type:
449 // Bool (const ShapeIndex& index, const Piece& piece)
450 // The first non-true return value is returned by the function.
451 template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)452 bool ForEachSubpieceWithBool(const Fn& func) const {
453 ShapeIndex index;
454 return ForEachHelperBool(func, *this, &index);
455 }
456 // Same as above, but the function has the type:
457 // Void (const ShapeIndex& index, Piece& piece)
458 template <typename Fn>
ForEachMutableSubpiece(const Fn & func)459 void ForEachMutableSubpiece(const Fn& func) {
460 ShapeIndex index;
461 return ForEachMutableHelper(
462 [&func](const ShapeIndex& index, Piece* piece) {
463 func(index, piece);
464 return Status::OK();
465 },
466 const_cast<xla::LiteralBase::Piece*>(this), &index)
467 .IgnoreError();
468 }
469 // Same as above, but the function has the type:
470 // Status (const ShapeIndex& index, Piece& piece)
471 // The first non-OK return value is returned by the function.
472 template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)473 Status ForEachMutableSubpieceWithStatus(const Fn& func) {
474 ShapeIndex index;
475 return ForEachMutableHelper(
476 func, const_cast<xla::LiteralBase::Piece*>(this), &index);
477 }
478
479 // Returns true if this piece and 'other' contain the same data. This piece
480 // and 'other' must be array-shaped and compatible. If a literal has dynamic
481 // shape, comparison is done only for the valid elements.
482 bool EqualElements(const Piece& other) const;
483
484 // Returns true if this piece and other pieces have the same dynamic
485 // dimension sizes.
486 bool EqualDynamicSize(const Piece& other) const;
487
488 // Writes the shape and data (if array-shaped) into the given proto.
489 void WriteToProto(LiteralProto* proto) const;
490
491 // Copy the data from 'src' into this piece's buffer. Shapes of this piece
492 // and src must be compatible. If only_dynamic_bound is true, only elements
493 // within dynamic bounds will be copied.
494 Status CopyFrom(const Piece& src, bool only_dynamic_bound);
495
496 // Copies the data from the given proto into this piece. The shape of this
497 // piece must be equal (not just compatible) to the shape of the proto.
498 Status CopyFromProto(const LiteralProto& proto);
499
500 private:
501 // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
502 // The first non-OK (or non-true) value is returned by the function.
503 // The callable 'func' has the same signature as described above in
504 // ForEachSubpiece*.
505 template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)506 Status ForEachHelper(const Fn& func, const Piece& piece,
507 ShapeIndex* index) const {
508 TF_RETURN_IF_ERROR(func(*index, piece));
509 for (int64_t i = 0; i < piece.children_.size(); ++i) {
510 index->push_back(i);
511 TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
512 index->pop_back();
513 }
514 return Status::OK();
515 }
516 template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)517 bool ForEachHelperBool(const Fn& func, const Piece& piece,
518 ShapeIndex* index) const {
519 if (!func(*index, piece)) {
520 return false;
521 }
522 for (int64_t i = 0; i < piece.children_.size(); ++i) {
523 index->push_back(i);
524 if (!ForEachHelperBool(func, piece.children_[i], index)) {
525 return false;
526 }
527 index->pop_back();
528 }
529 return true;
530 }
531 template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)532 Status ForEachMutableHelper(const Fn& func, Piece* piece,
533 ShapeIndex* index) {
534 TF_RETURN_IF_ERROR(func(*index, piece));
535 for (int64_t i = 0; i < piece->children_.size(); ++i) {
536 index->push_back(i);
537 TF_RETURN_IF_ERROR(
538 ForEachMutableHelper(func, &piece->children_[i], index));
539 index->pop_back();
540 }
541 return Status::OK();
542 }
543
544 // Recursive helper for EqualElements.
545 template <typename NativeT>
546 bool EqualElementsInternal(const Piece& other,
547 std::vector<int64>* multi_index) const;
548
549 // Internal helper to copy elements from another given piece
550 template <typename NativeT>
551 void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
552
553 // For array-shaped pieces, this is the buffer holding the literal data.
554 char* buffer_ = nullptr;
555
556 int32* dynamic_size_buffer_ = nullptr;
557
558 // The shape of piece. This points into the shape of the containing Literal
559 // (Literal::shape_).
560 const Shape* subshape_ = nullptr;
561
562 // Children pieces for tuple shaped pieces.
563 std::vector<Piece> children_ = {};
564 }; // class Piece
565
piece(const ShapeIndex & shape_index)566 const Piece& piece(const ShapeIndex& shape_index) const {
567 Piece* piece = &const_cast<Piece&>(root_piece());
568 for (const auto i : shape_index) {
569 DCHECK_GE(i, 0);
570 DCHECK_LT(i, piece->children_size());
571 piece = &piece->child(i);
572 }
573 return *piece;
574 }
575
576 // Returns the piece at the root of the shape.
577 virtual const Piece& root_piece() const = 0;
578
579 // LiteralSlice and Literal must access Pieces of other Literals.
580 friend class MutableLiteralBase;
581 friend class LiteralSlice;
582 friend class BorrowingLiteral;
583
584 private:
585 template <typename NativeT>
586 Literal SliceInternal(const Shape& result_shape,
587 absl::Span<const int64> start_indices) const;
588 };
589
590 // Abstract base class representing a mutable literal in XLA.
591 class MutableLiteralBase : public LiteralBase {
592 public:
593 virtual ~MutableLiteralBase() = 0;
594
595 // Returns a Span view of the array for this literal for the
596 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
597 // given ShapeIndex is not array. See primitive_util.h for the mapping from
598 // XLA type to native type.
599 template <typename NativeT>
600 absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
601 // Unhide const method from parent class.
602 using LiteralBase::data;
603
604 // TODO(b/67651157): Remove this accessor. Literal users should not be able to
605 // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()606 Shape* mutable_shape_do_not_use() { return shape_.get(); }
607
608 // Set the dynamic size on dim_index in the literal at the given shape_index.
609 void SetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index,
610 int32_t size);
611 void SetDynamicSize(int64_t dim_index, int32_t size);
612
613 // Returns a pointer to the underlying buffer holding the array at the given
614 // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
615 // is not array.
616 void* untyped_data(const ShapeIndex& shape_index = {});
617 // Unhide const method from parent class.
618 using LiteralBase::untyped_data;
619
620 template <typename NativeT>
621 void MutableEachCell(
622 std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
623 per_cell);
624
625 // Copy values from 'src_literal' rooted at 'src_shape_index' into this
626 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
627 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
628 // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
629 // is true, only elements within dynamic bounds will be copied.
630 Status CopyFrom(const LiteralSlice& src_literal,
631 const ShapeIndex& dest_shape_index = {},
632 const ShapeIndex& src_shape_index = {},
633 bool only_dynamic_bound = false);
634
635 // Copies the values from src_literal, starting at src_base shape indexes,
636 // to this literal, starting at dest_base, where the copy size in each
637 // dimension is specified by copy_size.
638 // The src_literal and this literal must have the same primitive type,
639 // src_base+copy_size must fit the source literal dimensions, as well as
640 // dest_base+copy_size must fit the destination literal dimensions.
  // Note: if either src_literal or this literal contains dimensions with zero
  // elements, then copy_size must be 0 in those dimensions and the
  // corresponding base indices must also be 0.
644 // This literal and 'src_literal' must be arrays.
645 Status CopySliceFrom(const LiteralSlice& src_literal,
646 absl::Span<const int64> src_base,
647 absl::Span<const int64> dest_base,
648 absl::Span<const int64> copy_size);
649
650 // Copies one element from src_literal[src_index] to (*this)[dest_index].
651 Status CopyElementFrom(const LiteralSlice& src_literal,
652 absl::Span<const int64> src_index,
653 absl::Span<const int64> dest_index);
654
655 // Sets an element in the literal at the given index. The multi_index is
656 // CHECKed against the dimension sizes.
657 template <typename NativeT>
658 void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
659 NativeT value);
660 // Overloads of Set for array literals. CHECKs if the literal is not
661 // array-shaped and dense.
662 template <typename NativeT>
663 void Set(absl::Span<const int64> multi_index, NativeT value);
664
665 // As Set(), but truncates `value` to the literal element type before storing.
666 // This literal must be an array.
667 Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64_t value);
668
669 // As Set(), but truncates `value` to the literal element type before storing.
670 // This literal must be an array.
671 Status SetFromDouble(absl::Span<const int64> multi_index, double value);
672
673 // Populate this literal with the given values. Examples:
674 //
675 // // Populate with floats.
676 // Array2D<float> float_values = ...
  //   literal.PopulateR2FromArray2D(float_values);
678 //
679 // // Populate with int32s.
680 // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
681 //
682 // The shape and element type of this literal must match given values. For
683 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
684 // array of S32.
685 template <typename NativeT>
686 void PopulateR1(absl::Span<const NativeT> values);
687 void PopulateR1(const tensorflow::core::Bitmap& values);
688 template <typename NativeT>
689 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
690 template <typename NativeT>
691 void PopulateFromArray(const Array<NativeT>& values);
692 template <typename NativeT>
693 void PopulateR2FromArray2D(const Array2D<NativeT>& values);
694 template <typename NativeT>
695 void PopulateR3FromArray3D(const Array3D<NativeT>& values);
696 template <typename NativeT>
697 void PopulateR4FromArray4D(const Array4D<NativeT>& values);
698
699 // Populates literal values by calling the generator function for every cell
700 // in this literal object.
701 //
702 // generator must be a callable of the type
703 // NativeT(absl::Span<int64> indexes) or compatible.
704 //
705 // This literal must have a dense layout.
706 template <typename NativeT, typename FnType>
707 Status Populate(const FnType& generator);
708
709 // A parallel version of Populate(). This can be used if the generator is
710 // thread-safe and the values for the shape's different elements are
711 // independent.
712 template <typename NativeT, typename FnType>
713 Status PopulateParallel(const FnType& generator);
714
715 // Fills this literal with the given value.
716 template <typename NativeT>
717 void PopulateWithValue(NativeT value);
718
719 // This operation is the inverse of DecomposeTuple. The given elements are
720 // moved into the tuple elements of a new tuple-shaped Literal which is
721 // returned. Upon return, each of the Literals in 'elements' is set to a nil
722 // shape (empty tuple).
723 static Literal MoveIntoTuple(absl::Span<Literal> elements);
724
725 // Serialize from a proto.
726 static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
727 bool prohibit_empty_literal = true);
728
729 protected:
730 // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)731 Piece& piece(const ShapeIndex& shape_index) {
732 return const_cast<Piece&>(LiteralBase::piece(shape_index));
733 }
734
root_piece()735 Piece& root_piece() const override { return *root_piece_; };
736
737 // Internal template helper for the Literal::CopySliceFrom(), matching its
738 // arguments one by one.
739 template <typename NativeT>
740 Status CopySliceFromInternal(const LiteralBase& src_literal,
741 absl::Span<const int64> src_base,
742 absl::Span<const int64> dest_base,
743 absl::Span<const int64> copy_size);
744
// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
struct StrideConfig {
  // Computes the scan configuration for visiting a region of `dimensions`
  // elements across a source literal of `source_shape` and a destination
  // literal of `dest_shape`. Defined out of line; NOTE(review): the field
  // values below are filled in there — confirm details against it.
  StrideConfig(const Shape& source_shape, const Shape& dest_shape,
               absl::Span<const int64> dimensions);

  // The dimensions of the stride operation. Essentially every dimension
  // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
  // steps.
  // This is a non-owning absl::Span. NOTE(review): presumably it views the
  // `dimensions` argument given to the constructor, so that storage must
  // outlive this struct — confirm.
  absl::Span<const int64> dimensions;
  // Per-dimension starting index of the scan (base[i] above).
  DimensionVector base;
  // Per-dimension increment between visited indices (step[i] above).
  DimensionVector step;
  // The dimension scanned by the innermost loop.
  int64 minor_dimension = 0;
  // The size of the strides for source and destination. One of the two
  // (the one looping through its most minor dimension) will be 1, while
  // the other will be the stride size at the dimension matching the other
  // shape's most minor dimension being scanned.
  int64 dest_stride = 1;
  int64 source_stride = 1;
  // The size of the inner loop on the most minor dimension.
  int64 minor_loop_size = 1;
};
767
768 // Literal class always owns the shape. The parent class borrows this shape.
769 std::unique_ptr<Shape> shape_;
770
771 Piece* root_piece_ = nullptr;
772
773 // Implementation details shared between Populate() and PopulateParallel()
774 template <typename NativeT, typename FnType>
775 Status PopulateInternal(const FnType& generator, bool parallel);
776
777 friend class LiteralBase;
778 friend class MutableBorrowingLiteral;
779 };
780 std::ostream& operator<<(std::ostream& out, const Literal& literal);
781
// The underlying buffer and shape is always owned by this class.
class Literal : public MutableLiteralBase {
 public:
  // Creates a literal with the nil shape (ShapeUtil::MakeNil(), the empty
  // tuple) by delegating to Literal(const Shape&).
  Literal() : Literal(ShapeUtil::MakeNil()) {}

  // Create a literal of the given shape. The literal is allocated sufficient
  // memory to hold the shape. Memory is uninitialized.
  explicit Literal(const Shape& shape);
  virtual ~Literal();

  // Literals are moveable, but not copyable. To copy a literal use
  // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
  // of literals which can be expensive.
  // NOTE(review): the move operations are not declared noexcept; consider
  // adding it if their out-of-line definitions cannot throw.
  Literal(const Literal& other) = delete;
  Literal& operator=(const Literal& other) = delete;
  Literal(Literal&& other);
  // 'allocate_arrays' indicates whether to allocate memory for the arrays in
  // the shape. If false, buffer pointers inside of the Literal::Pieces are set
  // to nullptr.
  Literal(const Shape& shape, bool allocate_arrays);
  Literal& operator=(Literal&& other);

  // Similar to CopyFrom, but with move semantics. The subshape of this literal
  // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
  // (layouts and shapes must match), but need not be arrays. The memory
  // allocated in this literal for the subshape at dest_shape_index is
  // deallocated, and the respective buffers are replaced with those in
  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
  // NOTE(review): default arguments on a virtual function are bound from the
  // static type of the call expression; an override declaring a different
  // default would be ignored when called through a Literal&.
  virtual Status MoveFrom(Literal&& src_literal,
                          const ShapeIndex& dest_shape_index = {});

  // Returns a vector containing the tuple elements of this Literal as separate
  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
  // elements are moved into the new Literals; no data is copied. Upon return
  // this Literal is set to a nil shape (empty tuple).
  std::vector<Literal> DecomposeTuple();

  // Returns a subliteral specified by given shape_index. No data is copied, the
  // current literal becomes invalid after this function call.
  Literal SubLiteral(ShapeIndexView shape_index);

 private:
  // Deallocate the buffers held by this literal.
  void DeallocateBuffers();

  // Recursively sets the subshapes and buffers of all subpieces rooted at
  // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
  // the shape.
  void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};
832
// The underlying buffer is not owned by this class and is always owned by
// others. The shape is not owned by this class and not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
 public:
  virtual ~MutableBorrowingLiteral();

  // Creates a literal with no root piece (MutableLiteralBase's root_piece_
  // keeps its nullptr default).
  MutableBorrowingLiteral() : MutableLiteralBase() {}

  // Copy operations. NOTE(review): presumably these produce another view of
  // the same borrowed buffers rather than copying data — confirm in the
  // out-of-line definitions.
  MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
  MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);

  // Implicit conversion constructors.
  MutableBorrowingLiteral(MutableLiteralBase* literal);
  MutableBorrowingLiteral(MutableBorrowingLiteral literal,
                          const ShapeIndex& view_root);
  MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);

  // Create a literal from a list of buffers and a shape.
  // Returns a tuple literal if `shape` is a tuple type.
  MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);

 private:
  // Recursively copies the subtree rooted at `src_piece` into `dest_piece`,
  // guided by `shape`. For buffers only the pointers are copied, but not the
  // content.
  void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
                        Piece* dest_piece);
};
861
862 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
863 // literal buffers always owned by others.
864 class LiteralSlice : public LiteralBase {
865 public:
LiteralSlice()866 LiteralSlice() : LiteralBase() {}
867
868 // Implicit conversion constructors.
869 LiteralSlice(const LiteralBase& literal);
870 LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
871
872 private:
root_piece()873 const Piece& root_piece() const override { return *root_piece_; };
874
875 const Piece* root_piece_; // Not owned.
876 };
877
878 // A read-only Literal where the underlying buffers are never owned by this
879 // class.
880 class BorrowingLiteral : public LiteralBase {
881 public:
BorrowingLiteral()882 BorrowingLiteral() : LiteralBase() {}
883
884 // 'src_buf_ptr' is not owned by this class and must outlive the
885 // lifetime of this class. It points to an appropriately sized buffer with
886 // data interpretered as indicated by 'shape'.
887 // This constructor is only used for array shapes.
888 BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
889 // Similar as above, except to be used for constructing non-nested tuples.
890 BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
891 const Shape& shape);
892 // TODO(b/79707221): adding constructors for nested tuples as well.
893
894 private:
895 // Recursively builds the subtree for the given piece and sets the subshapes
896 // of the given piece with the given shape.
897 void BuildPieceSubtree(const Shape& shape, Piece* piece);
898
899 // Accessor for the root piece of this literal.
root_piece()900 const Piece& root_piece() const override { return root_piece_; };
901 Piece root_piece_;
902
903 // Shape of this literal. Stored as unique_ptr such that the (default) move
904 // construction of this class would be trivially correct: the pointer to Shape
905 // root_piece_ stores will still point to the correct address.
906 std::unique_ptr<Shape> shape_;
907 };
908
909 template <typename NativeT>
data()910 absl::Span<const NativeT> LiteralBase::Piece::data() const {
911 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
912 DCHECK_EQ(subshape().element_type(),
913 primitive_util::NativeToPrimitiveType<NativeT>())
914 << "Attempting to access "
915 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
916 << " type, but literal element type is "
917 << PrimitiveType_Name(subshape().element_type());
918 return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
919 element_count());
920 }
921
922 template <typename NativeT>
data()923 absl::Span<NativeT> LiteralBase::Piece::data() {
924 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
925 DCHECK_EQ(subshape().element_type(),
926 primitive_util::NativeToPrimitiveType<NativeT>())
927 << "Attempting to access "
928 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
929 << " type, but literal element type is "
930 << PrimitiveType_Name(subshape().element_type());
931 return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
932 element_count());
933 }
934
935 template <typename NativeT>
Get(absl::Span<const int64> multi_index)936 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
937 CHECK(LayoutUtil::IsDenseArray(subshape())) << subshape();
938 return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
939 subshape(), multi_index)];
940 }
941
942 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)943 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
944 NativeT value) {
945 CHECK(LayoutUtil::IsDenseArray(subshape()));
946 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
947 subshape(), multi_index)] = value;
948 }
949
950 template <typename NativeT>
data(const ShapeIndex & shape_index)951 absl::Span<const NativeT> LiteralBase::data(
952 const ShapeIndex& shape_index) const {
953 return piece(shape_index).data<NativeT>();
954 }
955
956 template <typename NativeT>
data(const ShapeIndex & shape_index)957 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
958 return piece(shape_index).data<NativeT>();
959 }
960
961 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)962 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
963 const ShapeIndex& shape_index) const {
964 return piece(shape_index).Get<NativeT>(multi_index);
965 }
966
967 template <typename NativeT>
Get(absl::Span<const int64> multi_index)968 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
969 return root_piece().Get<NativeT>(multi_index);
970 }
971
972 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)973 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
974 const ShapeIndex& shape_index,
975 NativeT value) {
976 return piece(shape_index).Set<NativeT>(multi_index, value);
977 }
978
979 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)980 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
981 NativeT value) {
982 return root_piece().Set<NativeT>(multi_index, value);
983 }
984
985 template <typename NativeT>
GetFirstElement()986 NativeT LiteralBase::GetFirstElement() const {
987 return data<NativeT>().at(0);
988 }
989
990 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)991 void LiteralBase::EachCell(
992 std::function<void(absl::Span<const int64> indices, NativeT value)>
993 per_cell) const {
994 if (ShapeUtil::IsZeroElementArray(shape())) {
995 return;
996 }
997 std::vector<int64> indices(shape().rank(), 0);
998
999 Shape shape_dynamic = shape();
1000 for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1001 shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1002 }
1003 do {
1004 per_cell(indices, Get<NativeT>(indices));
1005 } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1006 }
1007
1008 template <typename NativeT>
MutableEachCell(std::function<NativeT (absl::Span<const int64> indices,NativeT value)> per_cell)1009 void MutableLiteralBase::MutableEachCell(
1010 std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
1011 per_cell) {
1012 if (ShapeUtil::IsZeroElementArray(shape())) {
1013 return;
1014 }
1015 std::vector<int64> indices(shape().rank(), 0);
1016
1017 Shape shape_dynamic = shape();
1018 for (int64_t i = 0; i < shape_dynamic.rank(); ++i) {
1019 shape_dynamic.set_dimensions(i, GetDynamicSize(i));
1020 }
1021 do {
1022 Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
1023 } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
1024 }
1025
1026 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)1027 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
1028 CHECK(shape().IsArray());
1029 CHECK_EQ(shape().rank(), 1);
1030 CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
1031 CHECK_EQ(shape().element_type(),
1032 primitive_util::NativeToPrimitiveType<NativeT>());
1033 auto data_span = data<NativeT>();
1034 std::copy(values.begin(), values.end(), data_span.begin());
1035 }
1036
1037 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1038 void MutableLiteralBase::PopulateR2(
1039 std::initializer_list<std::initializer_list<NativeT>> values) {
1040 CHECK(shape().IsArray());
1041 CHECK_EQ(shape().rank(), 2);
1042 CHECK_EQ(shape().element_type(),
1043 primitive_util::NativeToPrimitiveType<NativeT>());
1044
1045 const int64_t dim0_size = values.size();
1046 const int64_t dim1_size = values.begin()->size();
1047 CHECK_EQ(dim0_size, shape().dimensions(0));
1048 CHECK_EQ(dim1_size, shape().dimensions(1));
1049
1050 int64_t dim0 = 0;
1051 for (auto inner_list : values) {
1052 int64_t dim1 = 0;
1053 for (auto value : inner_list) {
1054 Set({dim0, dim1}, value);
1055 ++dim1;
1056 }
1057 CHECK_EQ(dim1_size, dim1);
1058 ++dim0;
1059 }
1060 }
1061
1062 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1063 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1064 CHECK(shape().IsArray());
1065 CHECK_EQ(shape().element_type(),
1066 primitive_util::NativeToPrimitiveType<NativeT>());
1067 CHECK_EQ(shape().rank(), values.num_dimensions());
1068 for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1069 CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1070 }
1071 values.Each([this](absl::Span<const int64> indices, NativeT value) {
1072 this->Set(indices, value);
1073 });
1074 }
1075
1076 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1077 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1078 PopulateFromArray(values);
1079 }
1080
1081 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1082 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1083 PopulateFromArray(values);
1084 }
1085
1086 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1087 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1088 PopulateFromArray(values);
1089 }
1090
// Shared implementation of Populate() and PopulateParallel().
//
// Requires a dense array literal whose element type matches NativeT; otherwise
// an error Status is returned via TF_RET_CHECK. Scalars (rank 0) are filled by
// a single generator({}) call. For rank > 0 the work is split into runs along
// stride_config.minor_dimension: ForEachIndex / ForEachIndexParallel supply a
// starting multi-index, and init_function fills minor_dimension_size
// consecutive linear elements from there, calling generator with the full
// multi-index of each element.
//
// NOTE(review): writing to `index + i` assumes StrideConfig picks a minor
// dimension that is contiguous in the layout, and steps over it so each run is
// visited exactly once — confirm against the StrideConfig constructor.
//
// When `parallel` is true, init_function (and therefore `generator`) is run by
// ForEachIndexParallel, so `generator` must be thread-safe.
template <typename NativeT, typename FnType>
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
                                            bool parallel) {
  const Shape& this_shape = shape();
  const int64_t rank = this_shape.rank();
  TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
  TF_RET_CHECK(this_shape.element_type() ==
               primitive_util::NativeToPrimitiveType<NativeT>());
  absl::Span<NativeT> literal_data = data<NativeT>();
  if (rank > 0) {
    StrideConfig stride_config(this_shape, this_shape,
                               AsInt64Slice(this_shape.dimensions()));
    int64_t minor_dimension_size =
        ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);

    // Fills one run of minor_dimension_size elements starting at `indexes`.
    // The scratch index vector is local to each call so that concurrent
    // invocations in the parallel path do not share state.
    auto init_function = [&](absl::Span<const int64> indexes) {
      DimensionVector minor_scan_indexes(rank, 0);
      const int64_t index =
          IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
      std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
      for (int64_t i = 0; i < minor_dimension_size; ++i) {
        // Only the minor coordinate changes along the run.
        minor_scan_indexes[stride_config.minor_dimension] = i;
        // at() bounds-checks the linear position.
        literal_data.at(index + i) = generator(minor_scan_indexes);
      }
    };
    if (parallel) {
      ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
                                      stride_config.dimensions,
                                      stride_config.step, init_function);
    } else {
      // ForEachIndex's visitor returns a bool; returning true continues the
      // scan over all indices.
      ShapeUtil::ForEachIndex(
          this_shape, stride_config.base, stride_config.dimensions,
          stride_config.step,
          [&init_function](absl::Span<const int64> indexes) {
            init_function(indexes);
            return true;
          });
    }
  } else {
    // For scalars.
    literal_data.at(0) = generator({});
  }
  return Status::OK();
}
1135 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1136 Status MutableLiteralBase::Populate(const FnType& generator) {
1137 return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1138 }
1139
1140 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1141 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1142 return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1143 }
1144
1145 template <typename NativeT>
PopulateWithValue(NativeT value)1146 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1147 CHECK(shape().IsArray());
1148 CHECK_EQ(shape().element_type(),
1149 primitive_util::NativeToPrimitiveType<NativeT>());
1150 for (NativeT& element : data<NativeT>()) {
1151 element = value;
1152 }
1153 }
1154
1155 template <typename NativeT>
Replicate(int64_t times)1156 Literal LiteralBase::Replicate(int64_t times) const {
1157 DimensionVector bounds = {times};
1158 bounds.reserve(shape().dimensions_size() + 1);
1159 for (int64_t bound : shape().dimensions()) {
1160 bounds.push_back(bound);
1161 }
1162 Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1163 int64_t elements = ShapeUtil::ElementsIn(literal.shape());
1164 if (elements == 0) {
1165 return literal;
1166 }
1167
1168 DimensionVector output_indices(bounds.size(), 0);
1169 absl::Span<const int64> input_indices = output_indices;
1170 input_indices.remove_prefix(1);
1171
1172 bool done = false;
1173 while (!done) {
1174 const auto element = Get<NativeT>(input_indices);
1175 literal.Set<NativeT>(output_indices, element);
1176
1177 done = true;
1178 for (int n = 0; n < output_indices.size(); ++n) {
1179 ++output_indices[n];
1180 if (output_indices[n] < bounds[n]) {
1181 done = false;
1182 break;
1183 }
1184 output_indices[n] = 0;
1185 }
1186 }
1187 return literal;
1188 }
1189
1190 } // namespace xla
1191
1192 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1193