/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
#define TENSORFLOW_COMPILER_XLA_LITERAL_H_

#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Forward declare the Literal and LiteralSlice classes to be used by the
// creation methods in the base class.
class Literal;
class LiteralSlice;

// Abstract base class for literals.
class LiteralBase {
 public:
  virtual ~LiteralBase() = 0;

  // Literals are equal if they have compatible shapes and the same data
  // values. Layout is not compared.
  bool operator==(const LiteralBase& other) const;
  bool operator!=(const LiteralBase& other) const { return !(*this == other); }

  // Returns the shape of the literal.
  const Shape& shape() const { return root_piece().subshape(); }

  // Serialize to proto.
  LiteralProto ToProto() const;

  // Returns a Span of the array for this literal for the given NativeT
  // (e.g., float). CHECKs if the subshape of the literal at the given
  // ShapeIndex is not an array. See primitive_util.h for the mapping from XLA
  // type to native type.
  template <typename NativeT>
  absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;

  // Returns a const pointer to the sparse index array. Returns nullptr if the
  // literal is not a sparse array.
  const SparseIndexArray* sparse_indices(
      const ShapeIndex& shape_index = {}) const;

  // Returns a const pointer to (or size of) the underlying buffer holding the
  // array at the given shape index. CHECKs if the subshape of the literal at
  // the given ShapeIndex is not an array.
  const void* untyped_data(const ShapeIndex& shape_index = {}) const;
  int64 size_bytes(const ShapeIndex& shape_index = {}) const;

  // Returns this literal's data as a string. This literal must be a rank-1 U8
  // array.
  string GetR1U8AsString() const;

  // Returns a string representation of the literal value. The Shape of the
  // literal is a prefix of the literal value in the string.

  // Warning: this function can take minutes for multi-million
  // element Literals.
  string ToString() const;

  // Returns a string representation of the literal value which does *not*
  // include the shape string.
  string ToStringWithoutShape() const;

  // Returns a string representation of the literal value which includes the
  // shape string with its layout.
  string ToStringWithLayout() const;

  // Gets an element in the literal at the given index. The multi_index is
  // CHECKed against the dimension sizes.
  template <typename NativeT>
  NativeT Get(absl::Span<const int64> multi_index,
              const ShapeIndex& shape_index) const;
  // Overloads of Get for array literals. CHECKs if the literal is not
  // array-shaped and dense.
  template <typename NativeT>
  NativeT Get(absl::Span<const int64> multi_index) const;
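  //
  // Example usage (a minimal sketch; LiteralUtil::CreateR2 lives in
  // literal_util.h and the values below are made up for illustration):
  //
  //   Literal literal = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
  //   float v = literal.Get<float>({1, 0});  // v == 3.f
  //
  //   // For tuple-shaped literals, shape_index selects the tuple element,
  //   // e.g. tuple_literal.Get<int32>({0}, /*shape_index=*/{1}) reads index
  //   // {0} of the second tuple element.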

  // Returns the element value at index (0, ..., 0), however many zeroes are
  // required for that index.
  template <typename NativeT>
  NativeT GetFirstElement() const;

  // As Get(), but determines the correct type and converts the value
  // into text.
  string GetAsString(absl::Span<const int64> multi_index,
                     const ShapeIndex& shape_index = {}) const;
  // As GetSparseElement(), but determines the correct type and converts the
  // value into text.
  string GetSparseElementAsString(int64 sparse_element_number,
                                  const ShapeIndex& shape_index = {}) const;
  // As Get(), but determines the correct type and converts the value into
  // int64. This literal must be an array.
  StatusOr<int64> GetIntegralAsS64(absl::Span<const int64> multi_index) const;

  // Returns the multi-index of the element in a sparse literal at the given
  // sparse element number. The sparse element number is the position within
  // the sparse array's list of (index, value) pairs, and is checked against
  // the total number of (index, value) pairs in the sparse array.
  absl::Span<const int64> GetSparseIndex(
      int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;

  // Returns the value of the element in a sparse literal at the given sparse
  // element number. The sparse element number is the position within the
  // sparse array's list of (index, value) pairs, and is checked against the
  // total number of (index, value) pairs in the sparse array.
  template <typename NativeT>
  NativeT GetSparseElement(int64 sparse_element_number,
                           const ShapeIndex& shape_index = {}) const;

  // Invokes the "per cell" callback for each element in the provided
  // literal with the element's indices and a string representation of
  // the element's value.
  //
  // This function is useful if you want a polymorphic representation
  // of the tensor's elements (turning it to a string for something
  // like representation in a protobuf).
  //
  // This literal must have a dense layout.
  void EachCellAsString(
      const std::function<void(absl::Span<const int64> indices,
                               const string& value)>& per_cell) const;
  template <typename NativeT>
  void EachCell(
      std::function<void(absl::Span<const int64> indices, NativeT value)>
          per_cell) const;

  // Returns whether every element in this literal is equal to value.
  //
  // value is an int8 because we expect this to be called with small
  // compile-time constants (0, -1, etc.) and so that whatever value you pass
  // can be represented exactly by floating-point types as small as 16 bits.
  //
  // If value doesn't fit in this literal's type, returns false. Values of 1/0
  // are considered equal to true/false; other values are not considered equal
  // to true. If this literal is not array-shaped, false is also returned.
  bool IsAll(int8 value) const;
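  //
  // Example (a minimal sketch; LiteralUtil::CreateR1 is from literal_util.h
  // and the values are illustrative only):
  //
  //   Literal zeros = LiteralUtil::CreateR1<float>({0.f, 0.f, 0.f});
  //   zeros.IsAll(0);  // true: every element equals 0.
  //   zeros.IsAll(1);  // false.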

  // Like IsAll(const Literal&, int8), except we check whether the literal is
  // equal to a particular floating-point number.
  //
  // If the literal is not a floating-point value, this always returns false.
  //
  // This casts value to the type of literal, then compares using ==. The usual
  // admonishments about floating-point equality checks apply. We expect you to
  // use this to check for values that can be expressed precisely as a float,
  // e.g. -0.5. If this literal is not array-shaped, false is also returned.
  bool IsAllFloat(float value) const;

  // Like IsAll(const Literal&, int8), except we check whether the literal is
  // equal to a particular complex number.
  //
  // If the literal is not a complex value, this always returns false.
  //
  // This casts value to the type of literal, then compares using ==. The usual
  // admonishments about floating-point equality checks apply. We expect you to
  // use this to check for complex values that can be expressed precisely as
  // float pairs e.g. (-0.5, 1.0).
  //
  // This literal must have a dense layout.
  bool IsAllComplex(complex64 value) const;

  // Returns whether this literal consists entirely of its first element.
  bool IsAllFirst() const;

  // Returns whether this rank-1 literal is an iota (0, 1, 2, ...).
  bool IsR1Iota() const;

  // Returns whether this literal is zero at the specified index. This literal
  // must be an array with a dense layout.
  bool IsZero(absl::Span<const int64> indices) const;

  // Returns the count of the elements in the array at the given shape index in
  // this literal.
  int64 element_count(const ShapeIndex& index = {}) const {
    if (index.empty()) {
      // Common case, avoid GetSubshape().
      return ShapeUtil::ElementsIn(shape());
    }
    return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
  }

  // Returns the count of the elements in the sparse array at the given shape
  // index in this literal, which will be no larger than
  // LayoutUtil::MaxSparseElements(GetSubshape(shape(), index).layout()).
  int64 sparse_element_count() const;

  // Computes a hash for this literal. This literal must not be a sparse tensor
  // or a tuple containing a sparse tensor.
  size_t Hash() const;

  // Converts this literal to the given shape. Returns an error if the
  // conversion is not possible.
  StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;

  // Converts this literal to another primitive type using a bitcast
  // conversion. The to and from primitive types must have the same bit
  // width. Returns an error if the conversion is not possible. This literal
  // must be array-shaped.
  StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;

  // Converts this literal to another primitive type. Returns an error if the
  // conversion is not possible. This literal must be array-shaped.
  StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
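  //
  // Example (a minimal sketch; the input literal and values are made up for
  // illustration):
  //
  //   Literal s32_literal = LiteralUtil::CreateR1<int32>({1, 2, 3});
  //   StatusOr<Literal> f32_or = s32_literal.Convert(F32);
  //   if (f32_or.ok()) {
  //     Literal f32_literal = f32_or.ConsumeValueOrDie();
  //     // f32_literal is f32[3] {1.0, 2.0, 3.0}.
  //   }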

  // Clones the underlying buffers into a new Literal.
  Literal Clone() const;

  // TODO(b/67651157): The methods below which perform computation on Literals
  // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
  // evaluator code which operates on Literals.
  //
  // Creates a new value that has the equivalent value as this
  // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
  // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
  // minor-to-major dimension layout and the value in the cell at any given
  // logical index (i0, i1) will be the same.
  //
  // For tuple shaped literals, shape_index should be used to select the inner
  // array that the new layout applies to.
  //
  // Note: this is useful when the client wants to ensure that a value placed
  // in the XLA allocation tracker has a particular layout; for efficiency
  // purposes or to avoid unimplemented operation/layout combinations.
  Literal Relayout(const Layout& new_layout,
                   const ShapeIndex& shape_index = {}) const;

  // An overload of Relayout which changes the layout of the entire shape
  // rather than being limited to a single array within the shape.
  Literal Relayout(const Shape& shape_with_layout) const;

  // Creates a new literal by reshaping this literal to have the given
  // dimensions. The total number of elements must not change; the
  // implementation currently only supports monotonic dim0-major layouts.
  // This literal must be an array.
  StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;

  // Creates a new literal by broadcasting this literal with `dimensions` to
  // yield a literal of shape `result_shape`.
  StatusOr<Literal> Broadcast(const Shape& result_shape,
                              absl::Span<const int64> dimensions) const;

  // Creates a new literal by reordering the dimensions of this literal.
  // The given `permutation` must be a permutation of the dimension numbers
  // in the original literal, and it specifies the order of the new dimensions
  // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
  // For example, a transpose call on a literal of shape [3 x 8 x 4] and
  // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
  // This literal must be an array.
  Literal Transpose(absl::Span<const int64> permutation) const;

  // Creates a sub-array from this literal by extracting the indices
  // [start_index, limit_index) of each dimension. The result literal has the
  // same rank and layout as this literal. The number of indices in
  // start_indices and limit_indices must be the rank of the literal, and the
  // indices follow the order of the dimensions.
  // This literal must be an array.
  Literal Slice(absl::Span<const int64> start_indices,
                absl::Span<const int64> limit_indices) const;
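  //
  // Example (a minimal sketch with made-up shapes):
  //
  //   // 'matrix' is f32[3,4]; take rows [1,3) and columns [0,2).
  //   Literal sub = matrix.Slice(/*start_indices=*/{1, 0},
  //                              /*limit_indices=*/{3, 2});
  //   // 'sub' is f32[2,2].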

  // Creates a literal with a prepended dimension with bound "times"; e.g. a
  // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
  // literal replicated four times.
  // This literal must be an array.
  template <typename NativeT>
  Literal Replicate(int64 times) const;

  // Creates a new Literal object with the shape specified as parameter.
  // The content of the literal values is the default value of the primitive
  // type of literal itself (0 for numeric types, and false for predicates).
  //
  // Note: It's an antipattern to use this method then immediately call
  // MutableLiteralBase::Populate on the result (since that results in zero
  // initialization, then reinitialization). Consider whether a call to
  // absl::make_unique<Literal>(shape), followed by the call to
  // MutableLiteralBase::Populate, can be used instead.
  static Literal CreateFromShape(const Shape& shape);
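  //
  // For instance (an illustrative sketch of the note above; the shape and
  // generator are hypothetical):
  //
  //   // Avoid: zero-initializes every element, then overwrites them all.
  //   Literal l = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {16}));
  //   l.Populate<float>([](absl::Span<const int64> idx) { return 1.f; })
  //       .IgnoreError();
  //
  //   // Prefer: allocate uninitialized storage, then populate once.
  //   auto l2 = absl::make_unique<Literal>(ShapeUtil::MakeShape(F32, {16}));
  //   l2->Populate<float>([](absl::Span<const int64> idx) { return 1.f; })
  //       .IgnoreError();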

 protected:
  // A data structure representing a subshape at a particular ShapeIndex within
  // the literal. For array-shaped ShapeIndexes, this data structure holds the
  // pointer to the memory allocated for the array data.
  class Piece {
   public:
    // Returns the buffer holding the array data for this piece as an array
    // slice. This piece must be array-shaped.
    template <typename NativeT>
    absl::Span<const NativeT> data() const;
    template <typename NativeT>
    absl::Span<NativeT> data();

    // Returns the buffer holding the array data for this piece as a void*.
    // This piece must be array-shaped.
    void* untyped_data();
    const void* untyped_data() const;

    // Gets or sets an element in the array at the given index. The multi_index
    // is CHECKed against the dimension sizes of the array. This piece must be
    // array-shaped.
    template <typename NativeT>
    NativeT Get(absl::Span<const int64> index) const;
    template <typename NativeT>
    void Set(absl::Span<const int64> index, NativeT value);

    // Gets/sets the buffer holding the array data.
    char* buffer() const { return buffer_; }
    void set_buffer(char* buffer) { buffer_ = buffer; }

    // The array of multi-indices that provide the locations of non-zero
    // elements in a sparse array. Only used if
    // LayoutUtil::IsSparseArray(shape()) is true.
    SparseIndexArray* sparse_indices() const { return sparse_indices_; }
    void set_sparse_indices(SparseIndexArray* sparse_indices) {
      sparse_indices_ = sparse_indices;
    }

    // Gets or sets the subshape of this piece. This reference points to a
    // subshape within the shape in the containing Literal (Literal::shape_).
    const Shape& subshape() const { return *subshape_; }
    void set_subshape(const Shape* subshape) { subshape_ = subshape; }

    // Returns the size in bytes of the buffer holding the array data.
    int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }

    // Returns the number of elements in this piece's array.
    int64 element_count() const {
      // If this is a sparse array, use the number of elements represented by
      // the indices in the associated SparseIndexArray.
      return LayoutUtil::IsSparseArray(subshape())
                 ? sparse_indices()->index_count()
                 : ShapeUtil::ElementsIn(subshape());
    }

    // Returns the child piece at 'index' of this piece.
    Piece& child(int64 index) { return children_[index]; }

    // Adds a child piece to this piece's children.
    void emplace_back(Piece child_piece) {
      children_.emplace_back(std::move(child_piece));
    }

    // Returns the number of child pieces of this piece.
    int64 children_size() { return children_.size(); }

    // Visitor functions that recursively traverse the piece and call the
    // given function at each child piece. The function has the type:
    //   void (const ShapeIndex& index, const Piece& piece)
    template <typename Fn>
    void ForEachSubpiece(const Fn& func) const {
      ShapeIndex index;
      return ForEachHelper(
                 [&func](const ShapeIndex& index, const Piece& piece) {
                   func(index, piece);
                   return Status::OK();
                 },
                 *this, &index)
          .IgnoreError();
    }
    // Same as above, but the function has the type:
    //   Status (const ShapeIndex& index, const Piece& piece)
    // The first non-OK return value is returned by the function.
    template <typename Fn>
    Status ForEachSubpieceWithStatus(const Fn& func) const {
      ShapeIndex index;
      return ForEachHelper(func, *this, &index);
    }
    // Same as above, but the function has the type:
    //   bool (const ShapeIndex& index, const Piece& piece)
    // Returns false as soon as the function returns false for any subpiece.
    template <typename Fn>
    bool ForEachSubpieceWithBool(const Fn& func) const {
      ShapeIndex index;
      return ForEachHelperBool(func, *this, &index);
    }
    // Same as above, but the function has the type:
    //   void (const ShapeIndex& index, Piece* piece)
    template <typename Fn>
    void ForEachMutableSubpiece(const Fn& func) {
      ShapeIndex index;
      return ForEachMutableHelper(
                 [&func](const ShapeIndex& index, Piece* piece) {
                   func(index, piece);
                   return Status::OK();
                 },
                 const_cast<xla::LiteralBase::Piece*>(this), &index)
          .IgnoreError();
    }
    // Same as above, but the function has the type:
    //   Status (const ShapeIndex& index, Piece* piece)
    // The first non-OK return value is returned by the function.
    template <typename Fn>
    Status ForEachMutableSubpieceWithStatus(const Fn& func) {
      ShapeIndex index;
      return ForEachMutableHelper(
          func, const_cast<xla::LiteralBase::Piece*>(this), &index);
    }

    // Returns true if this piece and 'other' contain the same data. This piece
    // and 'other' must be array-shaped and compatible.
    bool EqualElements(const Piece& other) const;

    // Writes the shape and data (if array-shaped) into the given proto.
    void WriteToProto(LiteralProto* proto) const;

    // Copies the data from 'src' into this piece's buffer. Shapes of this
    // piece and src must be compatible.
    Status CopyFrom(const Piece& src);

    // Copies the data from the given proto into this piece. The shape of this
    // piece must be equal (not just compatible) to the shape of the proto.
    Status CopyFromProto(const LiteralProto& proto);

    // Sorts the elements in a sparse array.
    void SortSparseElements();

   private:
    // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
    // The first non-OK (or non-true) value is returned by the function.
    // The callable 'func' has the same signature as described above in
    // ForEachSubpiece*.
    template <typename Fn>
    Status ForEachHelper(const Fn& func, const Piece& piece,
                         ShapeIndex* index) const {
      TF_RETURN_IF_ERROR(func(*index, piece));
      for (int64 i = 0; i < piece.children_.size(); ++i) {
        index->push_back(i);
        TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
        index->pop_back();
      }
      return Status::OK();
    }
    template <typename Fn>
    bool ForEachHelperBool(const Fn& func, const Piece& piece,
                           ShapeIndex* index) const {
      if (!func(*index, piece)) {
        return false;
      }
      for (int64 i = 0; i < piece.children_.size(); ++i) {
        index->push_back(i);
        if (!ForEachHelperBool(func, piece.children_[i], index)) {
          return false;
        }
        index->pop_back();
      }
      return true;
    }
    template <typename Fn>
    Status ForEachMutableHelper(const Fn& func, Piece* piece,
                                ShapeIndex* index) {
      TF_RETURN_IF_ERROR(func(*index, piece));
      for (int64 i = 0; i < piece->children_.size(); ++i) {
        index->push_back(i);
        TF_RETURN_IF_ERROR(
            ForEachMutableHelper(func, &piece->children_[i], index));
        index->pop_back();
      }
      return Status::OK();
    }

    // Recursive helper for EqualElements.
    template <typename NativeT>
    bool EqualElementsInternal(const Piece& other,
                               std::vector<int64>* multi_index) const;

    // Helper for SortSparseElements that has the element type as a template
    // parameter.
    template <typename NativeT>
    void SortSparseElementsInternal();

    // For array-shaped pieces, this is the buffer holding the literal data.
    char* buffer_ = nullptr;

    // For sparse arrays, this is the array of indices.
    SparseIndexArray* sparse_indices_ = nullptr;

    // The shape of this piece. This points into the shape of the containing
    // Literal (Literal::shape_).
    const Shape* subshape_ = nullptr;

    // Children pieces for tuple shaped pieces.
    std::vector<Piece> children_ = {};
  };  // class Piece

  const Piece& piece(const ShapeIndex& shape_index) const {
    Piece* piece = &const_cast<Piece&>(root_piece());
    for (const auto i : shape_index) {
      DCHECK_GE(i, 0);
      DCHECK_LT(i, piece->children_size());
      piece = &piece->child(i);
    }
    return *piece;
  }

  // Returns the piece at the root of the shape.
  virtual const Piece& root_piece() const = 0;

  // LiteralSlice and Literal must access Pieces of other Literals.
  friend class MutableLiteralBase;
  friend class LiteralSlice;
  friend class BorrowingLiteral;

 private:
  template <typename NativeT>
  Literal SliceInternal(const Shape& result_shape,
                        absl::Span<const int64> start_indices) const;
};

// Abstract base class representing a mutable literal in XLA.
class MutableLiteralBase : public LiteralBase {
 public:
  virtual ~MutableLiteralBase() = 0;

  // Returns a Span view of the array for this literal for the
  // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
  // given ShapeIndex is not an array. See primitive_util.h for the mapping
  // from XLA type to native type.
  template <typename NativeT>
  absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
  // Unhide const method from parent class.
  using LiteralBase::data;

  // Returns a pointer to the sparse index array. Returns nullptr if the
  // literal is not a sparse array.
  SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});

  // TODO(b/67651157): Remove this accessor. Literal users should not be able
  // to mutate the shape as this can produce malformed Literals.
  Shape* mutable_shape_do_not_use() { return shape_.get(); }

  // Returns a pointer to the underlying buffer holding the array at the given
  // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
  // is not an array.
  void* untyped_data(const ShapeIndex& shape_index = {});
  // Unhide const method from parent class.
  using LiteralBase::untyped_data;

  // Populates a literal with a sparse layout with the given indices and
  // values. Each index in the indices array is CHECKed against the dimensions
  // in the literal's shape. If sort is true, then the indices and values will
  // be sorted. If sort is false, then the indices and values are assumed to
  // already be in sorted order. See CreateSparse for an example of how data
  // are populated.
  template <typename NativeT>
  void PopulateSparse(SparseIndexArray indices,
                      absl::Span<const NativeT> values, bool sort = true);

  // Copies values from 'src_literal' rooted at 'src_shape_index' into this
  // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
  // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
  // rooted at 'src_shape_index', but need not be arrays.
  Status CopyFrom(const LiteralSlice& src_literal,
                  const ShapeIndex& dest_shape_index = {},
                  const ShapeIndex& src_shape_index = {});

  // Copies the values from src_literal, starting at src_base shape indexes,
  // to this literal, starting at dest_base, where the copy size in each
  // dimension is specified by copy_size.
  // The src_literal and this literal must have the same primitive type,
  // src_base+copy_size must fit the source literal dimensions, and
  // dest_base+copy_size must fit the destination literal dimensions.
  // Note: if either src_literal or this literal contains dimensions with zero
  // elements, then copy_size must be 0 in these dimensions and the
  // corresponding base indices must be 0.
  // This literal and 'src_literal' must be arrays.
  Status CopySliceFrom(const LiteralSlice& src_literal,
                       absl::Span<const int64> src_base,
                       absl::Span<const int64> dest_base,
                       absl::Span<const int64> copy_size);
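  //
  // Example (a minimal sketch with made-up shapes, inside a function that
  // returns Status):
  //
  //   // Copy the 2x2 top-left block of 'src' (f32[4,4]) into 'dst' (f32[8,8])
  //   // starting at row 3, column 5.
  //   TF_RETURN_IF_ERROR(dst.CopySliceFrom(src,
  //                                        /*src_base=*/{0, 0},
  //                                        /*dest_base=*/{3, 5},
  //                                        /*copy_size=*/{2, 2}));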

  // Copies one element from src_literal[src_index] to (*this)[dest_index].
  Status CopyElementFrom(const LiteralSlice& src_literal,
                         absl::Span<const int64> src_index,
                         absl::Span<const int64> dest_index);

  // Sets an element in the literal at the given index. The multi_index is
  // CHECKed against the dimension sizes.
  template <typename NativeT>
  void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
           NativeT value);
  // Overloads of Set for array literals. CHECKs if the literal is not
  // array-shaped and dense.
  template <typename NativeT>
  void Set(absl::Span<const int64> multi_index, NativeT value);

  // Appends the given element to the literal. If the elements are not appended
  // in sorted order, then SortSparseElements should be called before calling
  // other methods. This literal must have a sparse layout.
  template <typename NativeT>
  void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
                           const ShapeIndex& shape_index = {});

  // Sorts the elements in a sparse array.
  void SortSparseElements(const ShapeIndex& shape_index = {});

  // As Set(), but truncates `value` to the literal element type before
  // storing. This literal must be an array.
  Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);

  // Populates this literal with the given values. Examples:
  //
  //   // Populate with floats.
  //   Array2D<float> float_values = ...
  //   literal.PopulateR2FromArray2D(float_values);
  //
  //   // Populate with int32s.
  //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
  //
  // The shape and element type of this literal must match the given values.
  // For example, in the call above to literal.PopulateR2(), 'literal' must be
  // a 2x2 array of S32.
  template <typename NativeT>
  void PopulateR1(absl::Span<const NativeT> values);
  void PopulateR1(const tensorflow::core::Bitmap& values);
  template <typename NativeT>
  void PopulateR2(
      std::initializer_list<std::initializer_list<NativeT>> values);
  template <typename NativeT>
  void PopulateFromArray(const Array<NativeT>& values);
  template <typename NativeT>
  void PopulateR2FromArray2D(const Array2D<NativeT>& values);
  template <typename NativeT>
  void PopulateR3FromArray3D(const Array3D<NativeT>& values);
  template <typename NativeT>
  void PopulateR4FromArray4D(const Array4D<NativeT>& values);

  // Populates literal values by calling the generator function for every cell
  // in this literal object.
  //
  // generator must be a callable of the type
  //   NativeT(absl::Span<int64> indexes) or compatible.
  //
  // This literal must have a dense layout.
  template <typename NativeT, typename FnType>
  Status Populate(const FnType& generator);
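  //
  // Example (a minimal sketch; the shape is made up and the call is assumed to
  // be inside a function that returns Status):
  //
  //   // 'literal' is f32[2,3]; fill element (i, j) with i * 10 + j.
  //   TF_RETURN_IF_ERROR(literal.Populate<float>(
  //       [](absl::Span<const int64> indexes) {
  //         return static_cast<float>(indexes[0] * 10 + indexes[1]);
  //       }));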

  // A parallel version of Populate(). This can be used if the generator is
  // thread-safe and the values for the shape's different elements are
  // independent.
  template <typename NativeT, typename FnType>
  Status PopulateParallel(const FnType& generator);

  // Fills this literal with the given value.
  template <typename NativeT>
  void PopulateWithValue(NativeT value);

  // This operation is the inverse of DecomposeTuple. The given elements are
  // moved into the tuple elements of a new tuple-shaped Literal which is
  // returned. Upon return, each of the Literals in 'elements' is set to a nil
  // shape (empty tuple).
  static Literal MoveIntoTuple(absl::Span<Literal> elements);

  // Deserializes from a proto.
  static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
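  //
  // Example (a minimal sketch of a serialize/deserialize round trip):
  //
  //   LiteralProto proto = literal.ToProto();
  //   StatusOr<Literal> restored_or = Literal::CreateFromProto(proto);
  //   if (restored_or.ok()) {
  //     Literal restored = restored_or.ConsumeValueOrDie();
  //     CHECK(restored == literal);
  //   }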

 protected:
  // Returns the piece at the given ShapeIndex.
  Piece& piece(const ShapeIndex& shape_index) {
    return const_cast<Piece&>(LiteralBase::piece(shape_index));
  }

  const Piece& root_piece() const override { return *root_piece_; };

  // Internal template helper for the Literal::CopySliceFrom(), matching its
  // arguments one by one.
  template <typename NativeT>
  Status CopySliceFromInternal(const LiteralBase& src_literal,
                               absl::Span<const int64> src_base,
                               absl::Span<const int64> dest_base,
                               absl::Span<const int64> copy_size);

  // Utility structure which is used to create the optimal configuration for
  // a ShapeUtil::ForEachIndex() scan across two literals.
  struct StrideConfig {
    StrideConfig(const Shape& source_shape, const Shape& dest_shape,
                 absl::Span<const int64> dimensions);

    // The dimensions of the stride operation. Essentially every dimension
    // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
    // steps.
    absl::Span<const int64> dimensions;
    DimensionVector base;
    DimensionVector step;
    int64 minor_dimension = 0;
    // The size of the strides for source and destination. One of the two
    // (the one looping through its most minor dimension) will be 1, while
    // the other will be the stride size at the dimension matching the other
    // shape most minor dimension being scanned.
    int64 dest_stride = 1;
    int64 source_stride = 1;
    // The size of the inner loop on the most minor dimension.
    int64 minor_loop_size = 1;
  };

  // The Literal class always owns the shape. The parent class borrows this
  // shape.
  std::unique_ptr<Shape> shape_;

  Piece* root_piece_ = nullptr;

  // Implementation details shared between Populate() and PopulateParallel().
  template <typename NativeT, typename FnType>
  Status PopulateInternal(const FnType& generator, bool parallel);

  friend class LiteralBase;
  friend class MutableBorrowingLiteral;
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);

// The underlying buffer and shape are always owned by this class.
class Literal : public MutableLiteralBase {
 public:
  Literal() : Literal(ShapeUtil::MakeNil()) {}

  // Creates a literal of the given shape. The literal is allocated sufficient
  // memory to hold the shape. Memory is uninitialized.
  explicit Literal(const Shape& shape);
  virtual ~Literal();

  // Literals are moveable, but not copyable. To copy a literal use
  // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
  // of literals which can be expensive.
  Literal(const Literal& other) = delete;
  Literal& operator=(const Literal& other) = delete;
  Literal(Literal&& other);
  // 'allocate_arrays' indicates whether to allocate memory for the arrays in
  // the shape. If false, buffer pointers inside of the Literal::Pieces are set
  // to nullptr.
  Literal(const Shape& shape, bool allocate_arrays);
  Literal& operator=(Literal&& other);

  // Similar to CopyFrom, but with move semantics. The subshape of this literal
  // rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal'
  // (layouts and shapes must match), but need not be arrays. The memory
  // allocated in this literal for the subshape at dest_shape_index is
  // deallocated, and the respective buffers are replaced with those in
  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
  virtual Status MoveFrom(Literal&& src_literal,
                          const ShapeIndex& dest_shape_index = {});

  // Returns a vector containing the tuple elements of this Literal as separate
  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
  // elements are moved into the new Literals; no data is copied. Upon return
  // this Literal is set to a nil shape (empty tuple).
  std::vector<Literal> DecomposeTuple();
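  //
  // Example (a minimal sketch; 'tuple_literal' is assumed to be a two-element
  // tuple such as (f32[2], s32[])):
  //
  //   std::vector<Literal> elements = tuple_literal.DecomposeTuple();
  //   // elements[0] and elements[1] now own the former tuple elements;
  //   // tuple_literal is left with a nil shape (empty tuple).
  //
  //   // The inverse operation moves them back into a tuple:
  //   Literal rebuilt = Literal::MoveIntoTuple(absl::MakeSpan(elements));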

 private:
  // Deallocates the buffers held by this literal.
  void DeallocateBuffers();

  // Recursively sets the subshapes and buffers of all subpieces rooted at
  // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays
  // in the shape.
  void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};

// The underlying buffer is not owned by this class and is always owned by
// others. The shape is not owned by this class and not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
 public:
  virtual ~MutableBorrowingLiteral();

  MutableBorrowingLiteral() : MutableLiteralBase() {}

  MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
  MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);

  // Implicit conversion constructors.
  MutableBorrowingLiteral(const MutableLiteralBase& literal);
  MutableBorrowingLiteral(MutableLiteralBase* literal);
  MutableBorrowingLiteral(MutableBorrowingLiteral literal,
                          const ShapeIndex& view_root);
  MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);

 private:
  // Recursively copies the subtree from the `src_piece` at the given child
  // index to the `dest_piece`. For buffers only the pointers are copied, but
  // not the content.
  void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
                        Piece* dest_piece);
};

// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
// literal buffers always owned by others.
class LiteralSlice : public LiteralBase {
 public:
  LiteralSlice() : LiteralBase() {}

  // Implicit conversion constructors.
  LiteralSlice(const LiteralBase& literal);
  LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);

 private:
  const Piece& root_piece() const override { return *root_piece_; };

  const Piece* root_piece_;  // Not owned.
};

// A read-only Literal where the underlying buffers are never owned by this
// class.
class BorrowingLiteral : public LiteralBase {
 public:
  BorrowingLiteral() : LiteralBase() {}

  // 'src_buf_ptr' is not owned by this class and must outlive the
  // lifetime of this class. It points to an appropriately sized buffer with
  // data interpreted as indicated by 'shape'.
  // This constructor is only used for array shapes.
  BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
  // Similar to above, except used for constructing non-nested tuples.
  BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                   const Shape& shape);
  // TODO(b/79707221): add constructors for nested tuples as well.
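  //
  // Example (a minimal sketch; the buffer and shape are hypothetical, and the
  // caller must keep 'data' alive for as long as the view is used):
  //
  //   std::vector<float> data = {1.f, 2.f, 3.f, 4.f};
  //   BorrowingLiteral view(reinterpret_cast<const char*>(data.data()),
  //                         ShapeUtil::MakeShape(F32, {2, 2}));
  //   float v = view.Get<float>({1, 1});  // Reads from 'data' without copying.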

 private:
  // Recursively builds the subtree for the given piece and sets the subshapes
  // of the given piece with the given shape.
  void BuildPieceSubtree(const Shape& shape, Piece* piece);

  // Accessor for the root piece of this literal.
  const Piece& root_piece() const override { return root_piece_; };
  Piece root_piece_;

  // Shape of this literal. Stored as unique_ptr such that the (default) move
  // construction of this class would be trivially correct: the pointer to
  // Shape that root_piece_ stores will still point to the correct address.
  std::unique_ptr<Shape> shape_;
};

template <typename NativeT>
absl::Span<const NativeT> LiteralBase::Piece::data() const {
  DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  DCHECK_EQ(subshape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>())
      << "Attempting to access "
      << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
      << " type, but literal element type is "
      << PrimitiveType_Name(subshape().element_type());
  return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
                                   element_count());
}

template <typename NativeT>
absl::Span<NativeT> LiteralBase::Piece::data() {
  DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  DCHECK_EQ(subshape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>())
      << "Attempting to access "
      << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
      << " type, but literal element type is "
      << PrimitiveType_Name(subshape().element_type());
  return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
                             element_count());
}

template <typename NativeT>
NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
      subshape(), multi_index)];
}

template <typename NativeT>
void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
                             NativeT value) {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
      subshape(), multi_index)] = value;
}

template <typename NativeT>
absl::Span<const NativeT> LiteralBase::data(
    const ShapeIndex& shape_index) const {
  return piece(shape_index).data<NativeT>();
}

template <typename NativeT>
absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
  return piece(shape_index).data<NativeT>();
}

template <typename NativeT>
inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  return piece(shape_index).Get<NativeT>(multi_index);
}

template <typename NativeT>
inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
  return root_piece().Get<NativeT>(multi_index);
}

template <typename NativeT>
inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
                                    const ShapeIndex& shape_index,
                                    NativeT value) {
  return piece(shape_index).Set<NativeT>(multi_index, value);
}

template <typename NativeT>
inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
                                    NativeT value) {
  return root_piece().Set<NativeT>(multi_index, value);
}

template <typename NativeT>
NativeT LiteralBase::GetFirstElement() const {
  return data<NativeT>().at(0);
}

template <typename NativeT>
NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
                                      const ShapeIndex& shape_index) const {
  CHECK(
      LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
  return data<NativeT>(shape_index)[sparse_element_number];
}

template <typename NativeT>
void MutableLiteralBase::AppendSparseElement(
    absl::Span<const int64> multi_index, NativeT value,
    const ShapeIndex& shape_index) {
  Piece& p = piece(shape_index);
  const Shape& subshape = p.subshape();
  CHECK(LayoutUtil::IsSparseArray(subshape));
  int64 rank = subshape.rank();
  CHECK_EQ(multi_index.size(), rank);
  for (int64 i = 0; i < rank; ++i) {
    CHECK_GE(multi_index[i], 0);
    CHECK_LT(multi_index[i], subshape.dimensions(i));
  }
  int64 last_element = p.sparse_indices()->index_count();
  CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
  p.sparse_indices()->Append(multi_index);
  CHECK_LT(last_element, p.data<NativeT>().size());
  p.data<NativeT>()[last_element] = value;
}

template <typename NativeT>
void LiteralBase::EachCell(
    std::function<void(absl::Span<const int64> indices, NativeT value)>
        per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  std::vector<int64> indices(shape().rank(), 0);
  do {
    per_cell(indices, Get<NativeT>(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}

template <typename NativeT>
inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 1);
  CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
  auto data_span = data<NativeT>();
  std::copy(values.begin(), values.end(), data_span.begin());
}

template <typename NativeT>
void MutableLiteralBase::PopulateR2(
    std::initializer_list<std::initializer_list<NativeT>> values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().rank(), 2);
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());

  const int64 dim0_size = values.size();
  const int64 dim1_size = values.begin()->size();
  CHECK_EQ(dim0_size, shape().dimensions(0));
  CHECK_EQ(dim1_size, shape().dimensions(1));

  int64 dim0 = 0;
  for (auto inner_list : values) {
    int64 dim1 = 0;
    for (auto value : inner_list) {
      Set({dim0, dim1}, value);
      ++dim1;
    }
    CHECK_EQ(dim1_size, dim1);
    ++dim0;
  }
}

template <typename NativeT>
void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
  CHECK_EQ(shape().rank(), values.num_dimensions());
  for (int dim = 0; dim < values.num_dimensions(); ++dim) {
    CHECK_EQ(values.dim(dim), shape().dimensions(dim));
  }
  values.Each([this](absl::Span<const int64> indices, NativeT value) {
    this->Set(indices, value);
  });
}

template <typename NativeT>
void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
  PopulateFromArray(values);
}

template <typename NativeT>
void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
  PopulateFromArray(values);
}

template <typename NativeT>
void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
  PopulateFromArray(values);
}

template <typename NativeT>
void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
                                        absl::Span<const NativeT> values,
                                        bool sort) {
  CHECK(LayoutUtil::IsSparseArray(shape()));
  int rank = shape().rank();
  CHECK_EQ(indices.rank(), rank);
  int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
  CHECK_LE(indices.max_indices(), max_elements);
  int64 num_elements = values.size();
  CHECK_LE(num_elements, max_elements);
  CHECK_EQ(num_elements, indices.index_count());
  auto root_data = root_piece().data<NativeT>();
  // Piece::data() returns a Span of size equal to the number of indices
  // in the SparseIndexArray. So there is no need to adjust the size of the
  // data here. It is enough to just copy the incoming values into the data
  // buffer.
  std::copy(values.begin(), values.end(), root_data.begin());
  *this->root_piece().sparse_indices() = std::move(indices);
  if (sort) {
    auto root_data = this->root_piece().data<NativeT>();
    this->root_piece().sparse_indices()->SortWithValues(root_data);
  }
  DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
}

template <typename NativeT, typename FnType>
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
                                            bool parallel) {
  const Shape& this_shape = shape();
  const int64 rank = this_shape.rank();
  TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
  TF_RET_CHECK(this_shape.element_type() ==
               primitive_util::NativeToPrimitiveType<NativeT>());
  absl::Span<NativeT> literal_data = data<NativeT>();
  if (rank > 0) {
    StrideConfig stride_config(this_shape, this_shape,
                               AsInt64Slice(this_shape.dimensions()));
    int64 minor_dimension_size =
        ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);

    auto init_function = [&](absl::Span<const int64> indexes) {
      DimensionVector minor_scan_indexes(rank, 0);
      const int64 index =
          IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
      std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
      for (int64 i = 0; i < minor_dimension_size; ++i) {
        minor_scan_indexes[stride_config.minor_dimension] = i;
        literal_data.at(index + i) = generator(minor_scan_indexes);
      }
    };
    if (parallel) {
      ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
                                      stride_config.dimensions,
                                      stride_config.step, init_function);
    } else {
      ShapeUtil::ForEachIndex(
          this_shape, stride_config.base, stride_config.dimensions,
          stride_config.step,
          [&init_function](absl::Span<const int64> indexes) {
            init_function(indexes);
            return true;
          });
    }
  } else {
    // For scalars.
    literal_data.at(0) = generator({});
  }
  return Status::OK();
}
template <typename NativeT, typename FnType>
Status MutableLiteralBase::Populate(const FnType& generator) {
  return PopulateInternal<NativeT>(generator, /*parallel=*/false);
}

template <typename NativeT, typename FnType>
Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
  return PopulateInternal<NativeT>(generator, /*parallel=*/true);
}

template <typename NativeT>
void MutableLiteralBase::PopulateWithValue(NativeT value) {
  CHECK(shape().IsArray());
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
  for (NativeT& element : data<NativeT>()) {
    element = value;
  }
}

template <typename NativeT>
Literal LiteralBase::Replicate(int64 times) const {
  DimensionVector bounds = {times};
  bounds.reserve(shape().dimensions_size() + 1);
  for (int64 bound : shape().dimensions()) {
    bounds.push_back(bound);
  }
  Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
  int64 elements = ShapeUtil::ElementsIn(literal.shape());
  if (elements == 0) {
    return literal;
  }

  DimensionVector output_indices(bounds.size(), 0);
  absl::Span<const int64> input_indices = output_indices;
  input_indices.remove_prefix(1);

  bool done = false;
  while (!done) {
    const auto element = Get<NativeT>(input_indices);
    literal.Set<NativeT>(output_indices, element);

    done = true;
    for (int n = 0; n < output_indices.size(); ++n) {
      ++output_indices[n];
      if (output_indices[n] < bounds[n]) {
        done = false;
        break;
      }
      output_indices[n] = 0;
    }
  }
  return literal;
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_LITERAL_H_