1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/array2d.h"
32 #include "tensorflow/compiler/xla/array3d.h"
33 #include "tensorflow/compiler/xla/array4d.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/primitive_util.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/bitmap.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/macros.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/types.h"
48
49 namespace xla {
50
// Forward declarations. LiteralBase (below) declares methods that return
// Literal by value, and the mutable interface accepts LiteralSlice
// arguments, before either class is defined.
class LiteralSlice;
class Literal;
55
56 // Abstract base class for literals.
57 class LiteralBase {
58 public:
59 virtual ~LiteralBase() = 0;
60
61 // Literals are equal if they have compatible shapes and the same data
62 // values. Layout is not compared.
63 bool operator==(const LiteralBase& other) const;
64 bool operator!=(const LiteralBase& other) const { return !(*this == other); }
65
66 // Returns the shape of the literal.
shape()67 const Shape& shape() const { return root_piece().subshape(); }
68
69 // Serialize to proto.
70 LiteralProto ToProto() const;
71
72 // Returns a Span of the array for this literal for the given NativeT
73 // (e.g., float). CHECKs if the subshape of the literal at the given
74 // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
75 // to native type.
76 template <typename NativeT>
77 absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
78
79 // Returns a const pointer to (or size of) the underlying buffer holding the
80 // array at the given shape index. CHECKs if the subshape of the literal at
81 // the given ShapeIndex is not array.
82 const void* untyped_data(const ShapeIndex& shape_index = {}) const;
83 int64 size_bytes(const ShapeIndex& shape_index = {}) const;
84
85 // Returns this literal's data as a string. This literal must be a rank-1 U8
86 // array.
87 string GetR1U8AsString() const;
88
89 // Returns a string representation of the literal value. The Shape of the
90 // literal is a prefix of the literal value in the string.
91
92 // Warning: this function can take minutes for multi-million
93 // element Literals.
94 string ToString() const;
95
96 // Returns a string representation of the literal value which does *not*
97 // include the shape string.
98 string ToStringWithoutShape() const;
99
100 // Returns a string representation of the literal value which includes the
101 // shape string with its layout.does *not* include the shape string.
102 string ToStringWithLayout() const;
103
104 // Gets an element in the literal at the given index. The multi_index is
105 // CHECKed against the dimension sizes.
106 template <typename NativeT>
107 NativeT Get(absl::Span<const int64> multi_index,
108 const ShapeIndex& shape_index) const;
109 // Overloads of Get for array literals. CHECKs if the literal is not
110 // array-shaped and dense.
111 template <typename NativeT>
112 NativeT Get(absl::Span<const int64> multi_index) const;
113
114 // Returns the element value at index (0, ..., 0), however many zeroes are
115 // required for that index.
116 template <typename NativeT>
117 NativeT GetFirstElement() const;
118
119 // As Get(), but determines the correct type and converts the value
120 // into text.
121 string GetAsString(absl::Span<const int64> multi_index,
122 const ShapeIndex& shape_index = {}) const;
123
124 // Return whether the value at the specified index is equal to the provided
125 // generic `value` (T must be an arithmetic type).
126 //
127 // Precondition: must be an array.
128 template <typename T>
129 typename std::enable_if<(std::is_arithmetic<T>::value ||
130 std::is_same<T, Eigen::half>::value ||
131 std::is_same<T, bfloat16>::value),
132 bool>::type
IsEqualAt(absl::Span<const int64> multi_index,T value)133 IsEqualAt(absl::Span<const int64> multi_index, T value) const {
134 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
135 return *as_s64 == value;
136 }
137 complex128 as_complex128 = *GetAsComplex128(multi_index);
138 return as_complex128.imag() == 0 && as_complex128.real() == value;
139 }
140
IsEqualAt(absl::Span<const int64> multi_index,complex128 value)141 bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
142 if (auto as_s64 = GetIntegralAsS64(multi_index)) {
143 return *as_s64 == value.real() && value.imag() == 0;
144 }
145 auto as_complex128 = GetAsComplex128(multi_index);
146 return *as_complex128 == value;
147 }
148
149 // As Get(), but determines the correct type and converts the value into
150 // int64. This literal must be an array.
151 absl::optional<int64> GetIntegralAsS64(
152 absl::Span<const int64> multi_index) const;
153
154 // As Get(), but determines the correct type, and converts the value into
155 // double. This literal must be an array.
156 absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
157
158 // As Get(), but determines the correct type, and converts the value into
159 // complex128. All floating point types can be converted into complex128.
160 //
161 // This literal must be an array.
162 absl::optional<complex128> GetAsComplex128(
163 absl::Span<const int64> multi_index) const;
164
165 // Invokes the "per cell" callback for each element in the provided
166 // literal with the element's indices and a string representation of
167 // the element's value.
168 //
169 // This function is useful if you want a polymorphic representation
170 // of the tensor's elements (turning it to a string for something
171 // like representation in a protobuf).
172 //
173 // This literal must have a dense layout.
174 void EachCellAsString(
175 const std::function<void(absl::Span<const int64> indices,
176 const string& value)>& per_cell) const;
177 template <typename NativeT>
178 void EachCell(
179 std::function<void(absl::Span<const int64> indices, NativeT value)>
180 per_cell) const;
181
182 // Returns whether every element in this literal is equal to value.
183 //
184 // value is an int8 because we expect this to be called with small
185 // compile-time constants (0, -1, etc.) and so that whatever value you pass
186 // can be represented exactly by floating-point types as small as 16 bits.
187 //
188 // If value doesn't fit in this literal's type, returns false. Values of 1/0
189 // are considered equal to true/false; other values are not considered equal
190 // to true. Also if this literal is not array-shaped false is returned.
191 bool IsAll(int8 value) const;
192
193 // Like IsAll(const Literal&, int8), except we check whether the literal is
194 // equal to a particular floating-point number.
195 //
196 // If the literal is not a floating-point value, this always returns false.
197 //
198 // This casts value to the type of literal, then compares using ==. The usual
199 // admonishments about floating-point equality checks apply. We expect you to
200 // use this to check for values that can be expressed precisely as a float,
201 // e.g. -0.5. Also if this literal is not array-shaped false is returned.
202 bool IsAllFloat(float value) const;
203
204 // Like IsAll(const Literal&, int8), except we check whether the literal is
205 // equal to a particular complex number.
206 //
207 // If the literal is not a complex value, this always returns false.
208 //
209 // This casts value to the type of literal, then compares using ==. The usual
210 // admonishments about floating-point equality checks apply. We expect you to
211 // use this to check for complex values that can be expressed precisely as
212 // float pairs e.g. (-0.5, 1.0).
213 //
214 // This literal must have a dense layout.
215 bool IsAllComplex(complex64 value) const;
216
217 // Literal consists entirely of the first element of the literal.
218 bool IsAllFirst() const;
219
220 // Literal consists entirely of an iota.
221 bool IsR1Iota() const;
222
223 // Returns whether this literal is zero at the specified index. This literal
224 // must be an array with a dense layout.
225 bool IsZero(absl::Span<const int64> indices) const;
226
227 // Returns the count of the elements in the array at the given shape index in
228 // this literal.
229 int64 element_count(const ShapeIndex& index = {}) const {
230 if (index.empty()) {
231 // Common case, avoid GetSubshape().
232 return ShapeUtil::ElementsIn(shape());
233 }
234 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
235 }
236
237 // Compute a hash for this literal.
238 size_t Hash() const;
239
240 // Converts this literal to the given shape. Returns an error is the
241 // conversion is not possible.
242 StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
243
244 // Converts this literal to another primitive type using a bitcast
245 // conversion. The to and from primitive types must have the same bit
246 // width. Returns an error if the conversion is not possible. This literal
247 // must be array-shaped.
248 StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
249
250 // Converts this literal to another primitive type. Returns an error if the
251 // conversion is not possible. This literal must be array-shaped.
252 StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
253
254 // Clones the underlying buffers into a new Literal.
255 Literal Clone() const;
256
257 // TODO(b/67651157): The methods below which perform computation on Literals
258 // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
259 // evaluator code which operates on Literals.
260 //
261 // Creates a new value that has the equivalent value as this
262 // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
263 // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
264 // minor-to-major dimension layout and the value in the cell at any given
265 // logical index (i0, i1) will be the same.
266 //
267 // For tuple shaped literals, shape_index should be used to select the inner
268 // array that the new layout applies to.
269 //
270 // Note: this is useful when the client wants to ensure that a value placed in
271 // the XLA allocation tracker has a particular layout; for efficiency
272 // purposes or avoiding unimplemented operation/layout combinations.
273 Literal Relayout(const Layout& new_layout,
274 const ShapeIndex& shape_index = {}) const;
275
276 // An overload of Relayout which changes the layout of the entire shape rather
277 // than being limited to a single array within the shape.
278 Literal Relayout(const Shape& shape_with_layout) const;
279
280 // Creates a new literal by reshaping this literal to have the given
281 // dimensions. The total number of elements must not change; The
282 // implementation currently only supports monotonic dim0-major layouts.
283 // This literal must be an array.
284 StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
285
286 // Creates a new literal by broadcasting this literal with `dimensions` to
287 // yield a literal of shape `result_shape`.
288 StatusOr<Literal> Broadcast(const Shape& result_shape,
289 absl::Span<const int64> dimensions) const;
290
291 // Creates a new literal by reordering the dimensions of this literal.
292 // The given `permutation` must be a permutation of the dimension numbers
293 // in the original literal, and it specifies the order of the new dimensions
294 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
295 // For example, a transpose call on a literal of shape [3 x 8 x 4] and
296 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
297 // This literal must be an array.
298 Literal Transpose(absl::Span<const int64> permutation) const;
299
300 // Creates a sub-array from this literal by extracting the indices
301 // [start_index, limit_index) of each dimension. The result literal has the
302 // same rank and layout as for the given literal. The number of indices in
303 // start_indices and limit_indices must be the rank of the literal, and the
304 // indices follow the order of the dimensions.
305 // This literal must be an array.
306 Literal Slice(absl::Span<const int64> start_indices,
307 absl::Span<const int64> limit_indices) const;
308
309 // Creates a literal with a prepended dimension with bound "times"; e.g. a
310 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
311 // literal replicated four times.
312 // This literal must be an array.
313 template <typename NativeT>
314 Literal Replicate(int64 times) const;
315
316 // Creates a new Literal object with the shape specified as parameter.
317 // The content of the literal values is the default value of the primitive
318 // type of literal itself (0 for numeric types, and false for predicates).
319 //
320 // Note: It's an antipattern to use this method then immediately call
321 // MutableLiteralBase::Populate on the result (since that results in zero
322 // initialization, then reinitialization. Consider if a call to
323 // absl::make_unique<Literal>(shape), followed by the call to
324 // MutableLiteralBase::Populate can be used instead.
325 static Literal CreateFromShape(const Shape& shape);
326
327 protected:
328 // A data structure representing a subshape at a particular ShapeIndex within
329 // the literal. For array-shaped ShapeIndexes, this data structure holds the
330 // pointer to the memory allocated for the array data.
331 class Piece {
332 public:
333 // Returns the buffer holding the array data for this piece as an array
334 // slice. This piece must be array-shaped.
335 template <typename NativeT>
336 absl::Span<const NativeT> data() const;
337 template <typename NativeT>
338 absl::Span<NativeT> data();
339
340 // Returns the buffer holding the array data for this piece as a void*. This
341 // piece must be array-shaped.
342 void* untyped_data();
343 const void* untyped_data() const;
344
345 // Gets or sets an element in the array at the given index. The multi_index
346 // is CHECKed against the dimension sizes of the array. This piece must be
347 // array-shaped.
348 template <typename NativeT>
349 NativeT Get(absl::Span<const int64> index) const;
350 template <typename NativeT>
351 void Set(absl::Span<const int64> index, NativeT value);
352
353 // Gets/sets the buffer holding the array data.
buffer()354 char* buffer() const { return buffer_; }
set_buffer(char * buffer)355 void set_buffer(char* buffer) { buffer_ = buffer; }
356
357 // Gets or sets the subshape of this piece. This reference points to a
358 // subshape within the shape in the containing Literal (Literal::shape_).
subshape()359 const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)360 void set_subshape(const Shape* subshape) { subshape_ = subshape; }
361
362 // Returns the size in bytes of the buffer holding the array data.
size_bytes()363 int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
364
365 // Returns the number of elements in this piece's array.
element_count()366 int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
367
368 // Returns the child piece at 'index' of this piece.
child(int64 index)369 Piece& child(int64 index) { return children_[index]; }
370
371 // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)372 void emplace_back(Piece child_piece) {
373 children_.emplace_back(std::move(child_piece));
374 }
375
376 // Returns the size of children pieces of this piece.
children_size()377 int64 children_size() { return children_.size(); }
378
379 // Visitor functions that recursively traverses the piece and calls the
380 // given function at each child piece. The function has the type:
381 // void (const ShapeIndex& index, const Piece& piece)
382 template <typename Fn>
ForEachSubpiece(const Fn & func)383 void ForEachSubpiece(const Fn& func) const {
384 ShapeIndex index;
385 return ForEachHelper(
386 [&func](const ShapeIndex& index, const Piece& piece) {
387 func(index, piece);
388 return Status::OK();
389 },
390 *this, &index)
391 .IgnoreError();
392 }
393 // Same as above, but the function has the type:
394 // Status (const ShapeIndex& index, const Piece& piece)
395 // The first non-OK return value is returned by the function.
396 template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)397 Status ForEachSubpieceWithStatus(const Fn& func) const {
398 ShapeIndex index;
399 return ForEachHelper(func, *this, &index);
400 }
401 // Same as above, but the function has the type:
402 // Bool (const ShapeIndex& index, const Piece& piece)
403 // The first non-true return value is returned by the function.
404 template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)405 bool ForEachSubpieceWithBool(const Fn& func) const {
406 ShapeIndex index;
407 return ForEachHelperBool(func, *this, &index);
408 }
409 // Same as above, but the function has the type:
410 // Void (const ShapeIndex& index, Piece& piece)
411 template <typename Fn>
ForEachMutableSubpiece(const Fn & func)412 void ForEachMutableSubpiece(const Fn& func) {
413 ShapeIndex index;
414 return ForEachMutableHelper(
415 [&func](const ShapeIndex& index, Piece* piece) {
416 func(index, piece);
417 return Status::OK();
418 },
419 const_cast<xla::LiteralBase::Piece*>(this), &index)
420 .IgnoreError();
421 }
422 // Same as above, but the function has the type:
423 // Status (const ShapeIndex& index, Piece& piece)
424 // The first non-OK return value is returned by the function.
425 template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)426 Status ForEachMutableSubpieceWithStatus(const Fn& func) {
427 ShapeIndex index;
428 return ForEachMutableHelper(
429 func, const_cast<xla::LiteralBase::Piece*>(this), &index);
430 }
431
432 // Returns true if this piece and 'other' contain the same data. This piece
433 // and 'other' must be array-shaped and compatible.
434 bool EqualElements(const Piece& other) const;
435
436 // Writes the shape and data (if array-shaped) into the given proto.
437 void WriteToProto(LiteralProto* proto) const;
438
439 // Copy the data from 'src' into this piece's buffer. Shapes of this piece
440 // and src must be compatible.
441 Status CopyFrom(const Piece& src);
442
443 // Copies the data from the given proto into this piece. The shape of this
444 // piece must be equal (not just compatible) to the shape of the proto.
445 Status CopyFromProto(const LiteralProto& proto);
446
447 private:
448 // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
449 // The first non-OK (or non-true) value is returned by the function.
450 // The callable 'func' has the same signature as described above in
451 // ForEachSubpiece*.
452 template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)453 Status ForEachHelper(const Fn& func, const Piece& piece,
454 ShapeIndex* index) const {
455 TF_RETURN_IF_ERROR(func(*index, piece));
456 for (int64 i = 0; i < piece.children_.size(); ++i) {
457 index->push_back(i);
458 TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
459 index->pop_back();
460 }
461 return Status::OK();
462 }
463 template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)464 bool ForEachHelperBool(const Fn& func, const Piece& piece,
465 ShapeIndex* index) const {
466 if (!func(*index, piece)) {
467 return false;
468 }
469 for (int64 i = 0; i < piece.children_.size(); ++i) {
470 index->push_back(i);
471 if (!ForEachHelperBool(func, piece.children_[i], index)) {
472 return false;
473 }
474 index->pop_back();
475 }
476 return true;
477 }
478 template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)479 Status ForEachMutableHelper(const Fn& func, Piece* piece,
480 ShapeIndex* index) {
481 TF_RETURN_IF_ERROR(func(*index, piece));
482 for (int64 i = 0; i < piece->children_.size(); ++i) {
483 index->push_back(i);
484 TF_RETURN_IF_ERROR(
485 ForEachMutableHelper(func, &piece->children_[i], index));
486 index->pop_back();
487 }
488 return Status::OK();
489 }
490
491 // Recursive helper for EqualElements.
492 template <typename NativeT>
493 bool EqualElementsInternal(const Piece& other,
494 std::vector<int64>* multi_index) const;
495
496 // For array-shaped pieces, this is the buffer holding the literal data.
497 char* buffer_ = nullptr;
498
499 // The shape of piece. This points into the shape of the containing Literal
500 // (Literal::shape_).
501 const Shape* subshape_ = nullptr;
502
503 // Children pieces for tuple shaped pieces.
504 std::vector<Piece> children_ = {};
505 }; // class Piece
506
piece(const ShapeIndex & shape_index)507 const Piece& piece(const ShapeIndex& shape_index) const {
508 Piece* piece = &const_cast<Piece&>(root_piece());
509 for (const auto i : shape_index) {
510 DCHECK_GE(i, 0);
511 DCHECK_LT(i, piece->children_size());
512 piece = &piece->child(i);
513 }
514 return *piece;
515 }
516
517 // Returns the piece at the root of the shape.
518 virtual const Piece& root_piece() const = 0;
519
520 // LiteralSlice and Literal must access Pieces of other Literals.
521 friend class MutableLiteralBase;
522 friend class LiteralSlice;
523 friend class BorrowingLiteral;
524
525 private:
526 template <typename NativeT>
527 Literal SliceInternal(const Shape& result_shape,
528 absl::Span<const int64> start_indices) const;
529 };
530
531 // Abstract base class representing a mutable literal in XLA.
532 class MutableLiteralBase : public LiteralBase {
533 public:
534 virtual ~MutableLiteralBase() = 0;
535
536 // Returns a Span view of the array for this literal for the
537 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
538 // given ShapeIndex is not array. See primitive_util.h for the mapping from
539 // XLA type to native type.
540 template <typename NativeT>
541 absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
542 // Unhide const method from parent class.
543 using LiteralBase::data;
544
545 // TODO(b/67651157): Remove this accessor. Literal users should not be able to
546 // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()547 Shape* mutable_shape_do_not_use() { return shape_.get(); }
548
549 // Returns a pointer to the underlying buffer holding the array at the given
550 // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
551 // is not array.
552 void* untyped_data(const ShapeIndex& shape_index = {});
553 // Unhide const method from parent class.
554 using LiteralBase::untyped_data;
555
556 // Copy values from 'src_literal' rooted at 'src_shape_index' into this
557 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
558 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
559 // rooted at 'src_shape_index', but need not be arrays.
560 Status CopyFrom(const LiteralSlice& src_literal,
561 const ShapeIndex& dest_shape_index = {},
562 const ShapeIndex& src_shape_index = {});
563
564 // Copies the values from src_literal, starting at src_base shape indexes,
565 // to this literal, starting at dest_base, where the copy size in each
566 // dimension is specified by copy_size.
567 // The src_literal and this literal must have the same primitive type,
568 // src_base+copy_size must fit the source literal dimensions, as well as
569 // dest_base+copy_size must fit the destination literal dimensions.
570 // Note: if either src_literal or this literal contains dimensions with zero
571 // element, then copy_size must be 0 in these dimensions while the
572 // corresponding base indices being 0.
573 // This literal and 'src_literal' must be arrays.
574 Status CopySliceFrom(const LiteralSlice& src_literal,
575 absl::Span<const int64> src_base,
576 absl::Span<const int64> dest_base,
577 absl::Span<const int64> copy_size);
578
579 // Copies one element from src_literal[src_index] to (*this)[dest_index].
580 Status CopyElementFrom(const LiteralSlice& src_literal,
581 absl::Span<const int64> src_index,
582 absl::Span<const int64> dest_index);
583
584 // Sets an element in the literal at the given index. The multi_index is
585 // CHECKed against the dimension sizes.
586 template <typename NativeT>
587 void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
588 NativeT value);
589 // Overloads of Set for array literals. CHECKs if the literal is not
590 // array-shaped and dense.
591 template <typename NativeT>
592 void Set(absl::Span<const int64> multi_index, NativeT value);
593
594 // As Set(), but truncates `value` to the literal element type before storing.
595 // This literal must be an array.
596 Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
597
598 // As Set(), but truncates `value` to the literal element type before storing.
599 // This literal must be an array.
600 Status SetFromDouble(absl::Span<const int64> multi_index, double value);
601
602 // Populate this literal with the given values. Examples:
603 //
604 // // Populate with floats.
605 // Array2D<float> float_values = ...
606 // literal.PopulateR2FromArray2D(values);
607 //
608 // // Populate with int32s.
609 // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
610 //
611 // The shape and element type of this literal must match given values. For
612 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
613 // array of S32.
614 template <typename NativeT>
615 void PopulateR1(absl::Span<const NativeT> values);
616 void PopulateR1(const tensorflow::core::Bitmap& values);
617 template <typename NativeT>
618 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
619 template <typename NativeT>
620 void PopulateFromArray(const Array<NativeT>& values);
621 template <typename NativeT>
622 void PopulateR2FromArray2D(const Array2D<NativeT>& values);
623 template <typename NativeT>
624 void PopulateR3FromArray3D(const Array3D<NativeT>& values);
625 template <typename NativeT>
626 void PopulateR4FromArray4D(const Array4D<NativeT>& values);
627
628 // Populates literal values by calling the generator function for every cell
629 // in this literal object.
630 //
631 // generator must be a callable of the type
632 // NativeT(absl::Span<int64> indexes) or compatible.
633 //
634 // This literal must have a dense layout.
635 template <typename NativeT, typename FnType>
636 Status Populate(const FnType& generator);
637
638 // A parallel version of Populate(). This can be used if the generator is
639 // thread-safe and the values for the shape's different elements are
640 // independent.
641 template <typename NativeT, typename FnType>
642 Status PopulateParallel(const FnType& generator);
643
644 // Fills this literal with the given value.
645 template <typename NativeT>
646 void PopulateWithValue(NativeT value);
647
648 // This operation is the inverse of DecomposeTuple. The given elements are
649 // moved into the tuple elements of a new tuple-shaped Literal which is
650 // returned. Upon return, each of the Literals in 'elements' is set to a nil
651 // shape (empty tuple).
652 static Literal MoveIntoTuple(absl::Span<Literal> elements);
653
654 // Serialize from a proto.
655 static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
656 bool prohibit_empty_literal = true);
657
658 protected:
659 // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)660 Piece& piece(const ShapeIndex& shape_index) {
661 return const_cast<Piece&>(LiteralBase::piece(shape_index));
662 }
663
root_piece()664 Piece& root_piece() const override { return *root_piece_; };
665
666 // Internal template helper for the Literal::CopySliceFrom(), matching its
667 // arguments one by one.
668 template <typename NativeT>
669 Status CopySliceFromInternal(const LiteralBase& src_literal,
670 absl::Span<const int64> src_base,
671 absl::Span<const int64> dest_base,
672 absl::Span<const int64> copy_size);
673
674 // Utility structure which is used to create the optimal configuration for
675 // a ShapeUtil::ForEachIndex() scan across two literals.
676 struct StrideConfig {
677 StrideConfig(const Shape& source_shape, const Shape& dest_shape,
678 absl::Span<const int64> dimensions);
679
680 // The dimensions of the stride operation. Essentially every dimension
681 // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
682 // steps.
683 absl::Span<const int64> dimensions;
684 DimensionVector base;
685 DimensionVector step;
686 int64 minor_dimension = 0;
687 // The size of the strides for source and destination. One of the two
688 // (the one looping through its most minor dimension) will be 1, while
689 // the other will be the stride size at the dimension matching the other
690 // shape most minor dimension being scanned.
691 int64 dest_stride = 1;
692 int64 source_stride = 1;
693 // The size of the inner loop on the most minor dimension.
694 int64 minor_loop_size = 1;
695 };
696
697 // Literal class always owns the shape. The parent class borrows this shape.
698 std::unique_ptr<Shape> shape_;
699
700 Piece* root_piece_ = nullptr;
701
702 // Implementation details shared between Populate() and PopulateParallel()
703 template <typename NativeT, typename FnType>
704 Status PopulateInternal(const FnType& generator, bool parallel);
705
706 friend class LiteralBase;
707 friend class MutableBorrowingLiteral;
708 };
709 std::ostream& operator<<(std::ostream& out, const Literal& literal);
710
711 // The underlying buffer and shape is always owned by this class.
class Literal : public MutableLiteralBase {
 public:
  // Creates a literal with a nil shape (ShapeUtil::MakeNil(), an empty tuple).
  Literal() : Literal(ShapeUtil::MakeNil()) {}

  // Creates a literal of the given shape. The literal is allocated sufficient
  // memory to hold the shape. Memory is uninitialized.
  explicit Literal(const Shape& shape);
  virtual ~Literal();

  // Literals are movable, but not copyable. To copy a literal use
  // Literal::Clone (NOTE(review): older comments also mention
  // Literal::CloneToUnique, which may no longer exist; confirm). Forbidding
  // implicit copies prevents inadvertent copies of literals, which can be
  // expensive.
  Literal(const Literal& other) = delete;
  Literal& operator=(const Literal& other) = delete;
  Literal(Literal&& other);
  // Creates a literal of the given shape. 'allocate_arrays' indicates whether
  // to allocate memory for the arrays in the shape. If false, buffer pointers
  // inside of the Literal::Pieces are set to nullptr.
  Literal(const Shape& shape, bool allocate_arrays);
  Literal& operator=(Literal&& other);

  // Similar to CopyFrom, but with move semantics. The subshape of this literal
  // rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal'
  // (layouts and shapes must match), but need not be arrays. The memory
  // allocated in this literal for the subshape at dest_shape_index is
  // deallocated, and the respective buffers are replaced with those in
  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
  virtual Status MoveFrom(Literal&& src_literal,
                          const ShapeIndex& dest_shape_index = {});

  // Returns a vector containing the tuple elements of this Literal as separate
  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
  // elements are moved into the new Literals; no data is copied. Upon return
  // this Literal is set to a nil shape (empty tuple).
  std::vector<Literal> DecomposeTuple();

 private:
  // Deallocates the buffers held by this literal.
  void DeallocateBuffers();

  // Recursively sets the subshapes and buffers of all subpieces rooted at
  // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays
  // in the shape.
  void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};
757
758 // The underlying buffer is not owned by this class and is always owned by
759 // others. The shape is not owned by this class and not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
 public:
  virtual ~MutableBorrowingLiteral();

  // Creates a literal that does not borrow anything yet.
  MutableBorrowingLiteral() : MutableLiteralBase() {}

  MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
  MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);

  // Implicit conversion constructors.
  // Borrows from `literal`. NOTE(review): `literal` presumably must outlive
  // this object -- confirm in literal.cc.
  MutableBorrowingLiteral(MutableLiteralBase* literal);
  // Presumably creates a view of `literal` rooted at the subshape at
  // `view_root` (by analogy with LiteralSlice) -- confirm in literal.cc.
  MutableBorrowingLiteral(MutableBorrowingLiteral literal,
                          const ShapeIndex& view_root);
  // Wraps the externally owned buffer `src_buf_ptr`, interpreted as `shape`.
  // NOTE(review): the pointer is `const char*`, yet this class permits
  // mutating element data through MutableLiteralBase; confirm how literal.cc
  // reconciles that.
  MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);

 private:
  // Recursively copies the piece subtree rooted at `src_piece` into
  // `dest_piece`, guided by `shape` (NOTE(review): presumably the subshape at
  // the current level -- confirm). For buffers only the pointers are copied,
  // not the content.
  void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
                        Piece* dest_piece);
};
782
783 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
784 // literal buffers always owned by others.
785 class LiteralSlice : public LiteralBase {
786 public:
LiteralSlice()787 LiteralSlice() : LiteralBase() {}
788
789 // Implicit conversion constructors.
790 LiteralSlice(const LiteralBase& literal);
791 LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
792
793 private:
root_piece()794 const Piece& root_piece() const override { return *root_piece_; };
795
796 const Piece* root_piece_; // Not owned.
797 };
798
799 // A read-only Literal where the underlying buffers are never owned by this
800 // class.
801 class BorrowingLiteral : public LiteralBase {
802 public:
BorrowingLiteral()803 BorrowingLiteral() : LiteralBase() {}
804
805 // 'src_buf_ptr' is not owned by this class and must outlive the
806 // lifetime of this class. It points to an appropriately sized buffer with
807 // data interpretered as indicated by 'shape'.
808 // This constructor is only used for array shapes.
809 BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
810 // Similar as above, except to be used for constructing non-nested tuples.
811 BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
812 const Shape& shape);
813 // TODO(b/79707221): adding constructors for nested tuples as well.
814
815 private:
816 // Recursively builds the subtree for the given piece and sets the subshapes
817 // of the given piece with the given shape.
818 void BuildPieceSubtree(const Shape& shape, Piece* piece);
819
820 // Accessor for the root piece of this literal.
root_piece()821 const Piece& root_piece() const override { return root_piece_; };
822 Piece root_piece_;
823
824 // Shape of this literal. Stored as unique_ptr such that the (default) move
825 // construction of this class would be trivially correct: the pointer to Shape
826 // root_piece_ stores will still point to the correct address.
827 std::unique_ptr<Shape> shape_;
828 };
829
830 template <typename NativeT>
data()831 absl::Span<const NativeT> LiteralBase::Piece::data() const {
832 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
833 DCHECK_EQ(subshape().element_type(),
834 primitive_util::NativeToPrimitiveType<NativeT>())
835 << "Attempting to access "
836 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
837 << " type, but literal element type is "
838 << PrimitiveType_Name(subshape().element_type());
839 return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
840 element_count());
841 }
842
843 template <typename NativeT>
data()844 absl::Span<NativeT> LiteralBase::Piece::data() {
845 DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
846 DCHECK_EQ(subshape().element_type(),
847 primitive_util::NativeToPrimitiveType<NativeT>())
848 << "Attempting to access "
849 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
850 << " type, but literal element type is "
851 << PrimitiveType_Name(subshape().element_type());
852 return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
853 element_count());
854 }
855
856 template <typename NativeT>
Get(absl::Span<const int64> multi_index)857 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
858 CHECK(LayoutUtil::IsDenseArray(subshape()));
859 return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
860 subshape(), multi_index)];
861 }
862
863 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)864 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
865 NativeT value) {
866 CHECK(LayoutUtil::IsDenseArray(subshape()));
867 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
868 subshape(), multi_index)] = value;
869 }
870
871 template <typename NativeT>
data(const ShapeIndex & shape_index)872 absl::Span<const NativeT> LiteralBase::data(
873 const ShapeIndex& shape_index) const {
874 return piece(shape_index).data<NativeT>();
875 }
876
877 template <typename NativeT>
data(const ShapeIndex & shape_index)878 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
879 return piece(shape_index).data<NativeT>();
880 }
881
882 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)883 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
884 const ShapeIndex& shape_index) const {
885 return piece(shape_index).Get<NativeT>(multi_index);
886 }
887
888 template <typename NativeT>
Get(absl::Span<const int64> multi_index)889 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
890 return root_piece().Get<NativeT>(multi_index);
891 }
892
893 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)894 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
895 const ShapeIndex& shape_index,
896 NativeT value) {
897 return piece(shape_index).Set<NativeT>(multi_index, value);
898 }
899
900 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)901 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
902 NativeT value) {
903 return root_piece().Set<NativeT>(multi_index, value);
904 }
905
906 template <typename NativeT>
GetFirstElement()907 NativeT LiteralBase::GetFirstElement() const {
908 return data<NativeT>().at(0);
909 }
910
911 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)912 void LiteralBase::EachCell(
913 std::function<void(absl::Span<const int64> indices, NativeT value)>
914 per_cell) const {
915 if (ShapeUtil::IsZeroElementArray(shape())) {
916 return;
917 }
918 std::vector<int64> indices(shape().rank(), 0);
919 do {
920 per_cell(indices, Get<NativeT>(indices));
921 } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
922 }
923
924 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)925 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
926 CHECK(shape().IsArray());
927 CHECK_EQ(shape().rank(), 1);
928 CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
929 CHECK_EQ(shape().element_type(),
930 primitive_util::NativeToPrimitiveType<NativeT>());
931 auto data_span = data<NativeT>();
932 std::copy(values.begin(), values.end(), data_span.begin());
933 }
934
935 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)936 void MutableLiteralBase::PopulateR2(
937 std::initializer_list<std::initializer_list<NativeT>> values) {
938 CHECK(shape().IsArray());
939 CHECK_EQ(shape().rank(), 2);
940 CHECK_EQ(shape().element_type(),
941 primitive_util::NativeToPrimitiveType<NativeT>());
942
943 const int64 dim0_size = values.size();
944 const int64 dim1_size = values.begin()->size();
945 CHECK_EQ(dim0_size, shape().dimensions(0));
946 CHECK_EQ(dim1_size, shape().dimensions(1));
947
948 int64 dim0 = 0;
949 for (auto inner_list : values) {
950 int64 dim1 = 0;
951 for (auto value : inner_list) {
952 Set({dim0, dim1}, value);
953 ++dim1;
954 }
955 CHECK_EQ(dim1_size, dim1);
956 ++dim0;
957 }
958 }
959
960 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)961 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
962 CHECK(shape().IsArray());
963 CHECK_EQ(shape().element_type(),
964 primitive_util::NativeToPrimitiveType<NativeT>());
965 CHECK_EQ(shape().rank(), values.num_dimensions());
966 for (int dim = 0; dim < values.num_dimensions(); ++dim) {
967 CHECK_EQ(values.dim(dim), shape().dimensions(dim));
968 }
969 values.Each([this](absl::Span<const int64> indices, NativeT value) {
970 this->Set(indices, value);
971 });
972 }
973
974 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)975 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
976 PopulateFromArray(values);
977 }
978
979 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)980 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
981 PopulateFromArray(values);
982 }
983
984 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)985 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
986 PopulateFromArray(values);
987 }
988
989 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)990 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
991 bool parallel) {
992 const Shape& this_shape = shape();
993 const int64 rank = this_shape.rank();
994 TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
995 TF_RET_CHECK(this_shape.element_type() ==
996 primitive_util::NativeToPrimitiveType<NativeT>());
997 absl::Span<NativeT> literal_data = data<NativeT>();
998 if (rank > 0) {
999 StrideConfig stride_config(this_shape, this_shape,
1000 AsInt64Slice(this_shape.dimensions()));
1001 int64 minor_dimension_size =
1002 ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1003
1004 auto init_function = [&](absl::Span<const int64> indexes) {
1005 DimensionVector minor_scan_indexes(rank, 0);
1006 const int64 index =
1007 IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1008 std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1009 for (int64 i = 0; i < minor_dimension_size; ++i) {
1010 minor_scan_indexes[stride_config.minor_dimension] = i;
1011 literal_data.at(index + i) = generator(minor_scan_indexes);
1012 }
1013 };
1014 if (parallel) {
1015 ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1016 stride_config.dimensions,
1017 stride_config.step, init_function);
1018 } else {
1019 ShapeUtil::ForEachIndex(
1020 this_shape, stride_config.base, stride_config.dimensions,
1021 stride_config.step,
1022 [&init_function](absl::Span<const int64> indexes) {
1023 init_function(indexes);
1024 return true;
1025 });
1026 }
1027 } else {
1028 // For scalars.
1029 literal_data.at(0) = generator({});
1030 }
1031 return Status::OK();
1032 }
1033 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1034 Status MutableLiteralBase::Populate(const FnType& generator) {
1035 return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1036 }
1037
1038 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1039 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1040 return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1041 }
1042
1043 template <typename NativeT>
PopulateWithValue(NativeT value)1044 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1045 CHECK(shape().IsArray());
1046 CHECK_EQ(shape().element_type(),
1047 primitive_util::NativeToPrimitiveType<NativeT>());
1048 for (NativeT& element : data<NativeT>()) {
1049 element = value;
1050 }
1051 }
1052
1053 template <typename NativeT>
Replicate(int64 times)1054 Literal LiteralBase::Replicate(int64 times) const {
1055 DimensionVector bounds = {times};
1056 bounds.reserve(shape().dimensions_size() + 1);
1057 for (int64 bound : shape().dimensions()) {
1058 bounds.push_back(bound);
1059 }
1060 Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1061 int64 elements = ShapeUtil::ElementsIn(literal.shape());
1062 if (elements == 0) {
1063 return literal;
1064 }
1065
1066 DimensionVector output_indices(bounds.size(), 0);
1067 absl::Span<const int64> input_indices = output_indices;
1068 input_indices.remove_prefix(1);
1069
1070 bool done = false;
1071 while (!done) {
1072 const auto element = Get<NativeT>(input_indices);
1073 literal.Set<NativeT>(output_indices, element);
1074
1075 done = true;
1076 for (int n = 0; n < output_indices.size(); ++n) {
1077 ++output_indices[n];
1078 if (output_indices[n] < bounds[n]) {
1079 done = false;
1080 break;
1081 }
1082 output_indices[n] = 0;
1083 }
1084 }
1085 return literal;
1086 }
1087
1088 } // namespace xla
1089
1090 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1091