1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Utilities for creating and manipulating xla::Literal objects.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
29
30 #include "absl/memory/memory.h"
31 #include "absl/strings/string_view.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/array2d.h"
34 #include "tensorflow/compiler/xla/array3d.h"
35 #include "tensorflow/compiler/xla/array4d.h"
36 #include "tensorflow/compiler/xla/index_util.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/primitive_util.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/sparse_index_array.h"
42 #include "tensorflow/compiler/xla/status_macros.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/compiler/xla/xla_data.pb.h"
46 #include "tensorflow/core/lib/core/bitmap.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/protobuf.h"
51 #include "tensorflow/core/platform/types.h"
52
53 namespace xla {
54
55 class LiteralUtil {
56 public:
57 LiteralUtil() = delete;
58
59 // Returns a literal scalar representing the first element.
60 static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
61
62 // Creates a new literal of a given rank. To minimize ambiguity (for users
63 // and the compiler) these CreateR[0-2] methods should explicitly specify the
64 // native type. For example:
65 //
66 // CreateR1<float>({1.0, 42.0});
67 // CreateR2<uint32>({{1, 2}, {3, 4}});
68 //
69 // The variants not ending with WithLayout use the default XLA layout for the
70 // literal's linear representation in memory.
71 template <typename NativeT>
72 static Literal CreateR0(NativeT value);
73 template <typename NativeT>
74 static Literal CreateR1(absl::Span<const NativeT> values);
75 static Literal CreateR1(const tensorflow::core::Bitmap& values);
76 template <typename NativeT>
77 static Literal CreateR2(
78 std::initializer_list<std::initializer_list<NativeT>> values);
79 template <typename NativeT>
80 static Literal CreateR2WithLayout(
81 std::initializer_list<std::initializer_list<NativeT>> values,
82 const Layout& layout);
83 template <typename NativeT>
84 static Literal CreateR3(std::initializer_list<
85 std::initializer_list<std::initializer_list<NativeT>>>
86 values);
87 template <typename NativeT>
88 static Literal CreateR3WithLayout(
89 std::initializer_list<
90 std::initializer_list<std::initializer_list<NativeT>>>
91 values,
92 const Layout& layout);
93 template <typename NativeT>
94 static Literal CreateR4(
95 std::initializer_list<std::initializer_list<
96 std::initializer_list<std::initializer_list<NativeT>>>>
97 values);
98 template <typename NativeT>
99 static Literal CreateR4WithLayout(
100 std::initializer_list<std::initializer_list<
101 std::initializer_list<std::initializer_list<NativeT>>>>
102 values,
103 const Layout& layout);
104
105 // Creates a literal with a sparse layout and the given indices and values.
106 // The shape is initialized from the given dimensions. The minor dimension of
107 // the indices array must equal the rank of the shape (i.e. size of the
108 // dimensions array). The major dimension of the indices array must equal the
109 // number of elements in the values array. The maximum number of elements in
110 // the array is taken from the max_indices() value of the index array.
111 //
112 // XLA assumes that sparse literals are in sorted order for all operations. If
113 // the `sort` argument is true, then the indices and values will be sorted
114 // while copying them into the literal. If you have ensured that the indices
115 // and values are already sorted, then you may set the `sort` argument to
116 // false to skip the sorting step.
117 //
118 // For example:
119 //
120 // CreateSparse(
121 // {12, 12, 12},
122 // SparseIndexArray(10, 3,
123 // Array2D{
124 // {0, 1, 2},
125 // {3, 4, 5},
126 // {6, 7, 8},
127 // {9, 10, 11},
128 // }),
129 // {1.0, 2.0 3.0, 4.0})
130 //
131 // This creates an array with shape F64[12,12,12]sparse{10}, that has the
132 // following non-zero values:
133 //
134 // [0, 1, 2]: 1.0
135 // [3, 4, 5]: 2.0
136 // [6, 7, 8]: 3.0
137 // [9, 10, 11]: 4.0
138 //
139 template <typename NativeT>
140 static Literal CreateSparse(absl::Span<const int64> dimensions,
141 SparseIndexArray indices,
142 absl::Span<const NativeT> values,
143 bool sort = true);
144
145 // Creates a scalar literal value zero of the given primitive type.
146 static Literal Zero(PrimitiveType primitive_type);
147 // Creates a scalar literal value one of the given primitive type.
148 static Literal One(PrimitiveType primitive_type);
149 // Creates a scalar literal value containing the minimum value of the given
150 // primitive type. For floating-point types, returns -inf.
151 static Literal MinValue(PrimitiveType primitive_type);
152 // Creates a scalar literal value containing the maximum value of the given
153 // primitive type. For floating-point types, returns inf.
154 static Literal MaxValue(PrimitiveType primitive_type);
155 // Creates a literal of the given shape where each element is `value`.
156 template <typename NativeT>
157 static Literal CreateFullWithDescendingLayout(
158 absl::Span<const int64> dimensions, NativeT value);
159
160 // Creates a new literal from an Array type. The variants not ending with
161 // WithLayout use the default XLA layout for the literal's linear
162 // representation in memory.
163 template <typename NativeT>
164 static Literal CreateFromArray(const Array<NativeT>& values);
165 template <typename NativeT>
166 static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
167 const Layout& layout);
168 template <typename NativeT>
169 static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
170 template <typename NativeT>
171 static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
172 const Layout& layout);
173 template <typename NativeT>
174 static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
175 template <typename NativeT>
176 static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
177 const Layout& layout);
178 template <typename NativeT>
179 static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
180 template <typename NativeT>
181 static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
182 const Layout& layout);
183
184 // Creates a new vector of U8s literal value from a string.
185 static Literal CreateR1U8(absl::string_view value);
186
187 // Creates a linspace-populated literal with the given number of rows and
188 // columns.
189 static Literal CreateR2F32Linspace(float from, float to, int64 rows,
190 int64 cols);
191
192 // Creates a literal that projects the (x, y) dimensions given in values into
193 // the z dimension given by "projection".
194 template <typename NativeT>
195 static Literal CreateR3Projected(
196 std::initializer_list<std::initializer_list<NativeT>> values,
197 int64 projection);
198
199 // Creates a literal that projects the (x, y) dimensions given in values into
200 // the z and p dimensions given.
201 template <typename NativeT>
202 static Literal CreateR4Projected(
203 std::initializer_list<std::initializer_list<NativeT>> values,
204 int64 projection_p, int64 projection_z);
205
206 // Returns an identity matrix (rank 2) with the given row and column count.
207 template <typename NativeT>
208 static Literal MakeIdentityR2(int64 size);
209
210 // Returns a tuple literal composed of given literals. Data is copied from the
211 // given elements into the returned literal.
212 static Literal MakeTuple(absl::Span<const Literal* const> elements);
213
214 static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
215
216 // As above, but intended to be invoked with move semantics; i.e.
217 //
218 // std::vector<Literal> elements = ...;
219 // auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
220 //
221 // This would have been declared as an overload, but there is ambiguity
222 // in invocation between the above signature and this one.
223 static Literal MakeTupleOwned(std::vector<Literal> elements);
224
225 // This overload lets you pass a braced list of Literals to
226 // MakeTupleOwned:
227 //
228 // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
229 //
230 // Simply relying on the MakeTupleOwned(std::vector<Literal>)
231 // overload doesn't work because std::initializer_list's elements are always
232 // const.
233 //
234 // The arguments to this function must all be Literal.
235 template <typename... Ts>
MakeTupleOwned(Ts...elements)236 static Literal MakeTupleOwned(Ts... elements) {
237 std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
238 std::vector<Literal> v;
239 v.insert(v.begin(), std::make_move_iterator(arr.begin()),
240 std::make_move_iterator(arr.end()));
241 return MakeTupleOwned(std::move(v));
242 }
243
244 // Create a constant token literal. Token types have no value.
245 static Literal CreateToken();
246
247 // Creates a new Literal object with its values havings the primitive_type
248 // type, and with dimensions defined by the dimensions parameter.
249 // The content of the literal values is the default value of the primitive
250 // type of literal itself (0 for numeric types, and false for predicates).
251 static Literal CreateFromDimensions(PrimitiveType primitive_type,
252 absl::Span<const int64> dimensions);
253
254 // If the given literal's data type is bfloat16, converts it to a float
255 // literal; otherwise, returns a copy of it. If the literal is a tuple,
256 // recursively converts its elements.
257 static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
258
259 // If the given literal's data type is float, converts it to a bfloat16
260 // literal; otherwise, returns a copy of it. If the literal is a tuple,
261 // recursively converts its elements.
262 static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
263
264 // Creates a literal with a new shape with the given new dimensions using the
265 // data in the given input literal. For reshaping purposes the (flat) data
266 // buffer of the input literal is assumed to have the given minor_to_major
267 // layout order.
268 static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
269 absl::Span<const int64> minor_to_major,
270 const LiteralSlice& literal);
271
272 // Creates a literal with the supplied shape, and uses the provided value
273 // generator to populate the literal's values.
274 // Returns the new literal object, or an error Status if failed.
275 template <
276 PrimitiveType type,
277 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
278 static StatusOr<Literal> CreateRandomLiteral(
279 const Shape& shape,
280 const std::function<T(absl::Span<const int64>)>& generator);
281
282 // Creates a literal with the supplied shape, and initializes the literal
283 // values using a normal distribution with given mean and stddev standard
284 // deviation, and using the engine as entropy generator.
285 // Returns the new literal object, or an error Status if failed.
286 template <
287 PrimitiveType type, typename E,
288 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
289 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
290 T mean, T stddev);
291
292 // Creates a literal with the supplied shape, and initializes the literal
293 // values using a normal distribution with given mean and stddev standard
294 // deviation.
295 // Returns the new literal object, or an error Status if failed.
296 template <
297 PrimitiveType type,
298 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
299 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
300 T stddev);
301
302 //
303 // End of factory methods.
304
305 // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
306 // be returned for a 2-dimensional index with dimension 0 index equal to 7,
307 // dimension 1 equal to 8.
308 static string MultiIndexAsString(absl::Span<const int64> multi_index);
309 };
310
// Stream-insertion operator for Literal; returns `out` so insertions can be
// chained. Only declared in this header; the definition (and therefore the
// exact output format) lives elsewhere.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
312
313 template <typename NativeT>
CreateR0(NativeT value)314 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
315 Literal literal(ShapeUtil::MakeShape(
316 primitive_util::NativeToPrimitiveType<NativeT>(), {}));
317 literal.Set({}, value);
318 return literal;
319 }
320
321 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)322 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
323 Literal literal(
324 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
325 {static_cast<int64>(values.size())}));
326 literal.PopulateR1(values);
327 return literal;
328 }
329
330 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)331 /* static */ Literal LiteralUtil::CreateR2WithLayout(
332 std::initializer_list<std::initializer_list<NativeT>> values,
333 const Layout& layout) {
334 Literal literal(ShapeUtil::MakeShapeWithLayout(
335 primitive_util::NativeToPrimitiveType<NativeT>(),
336 {static_cast<int64>(values.size()),
337 static_cast<int64>(values.begin()->size())},
338 AsInt64Slice(layout.minor_to_major())));
339 literal.PopulateR2(values);
340 return literal;
341 }
342
343 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)344 /* static */ Literal LiteralUtil::CreateR2(
345 std::initializer_list<std::initializer_list<NativeT>> values) {
346 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
347 }
348
349 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)350 /* static */ Literal LiteralUtil::CreateR3WithLayout(
351 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
352 values,
353 const Layout& layout) {
354 const int64 d0 = values.size();
355 const int64 d1 = values.begin()->size();
356 const int64 d2 = values.begin()->begin()->size();
357 Array3D<NativeT> tmp(d0, d1, d2);
358 int64 i0 = 0;
359 for (auto d1_values : values) {
360 int64 i1 = 0;
361 for (auto d2_values : d1_values) {
362 int64 i2 = 0;
363 for (auto value : d2_values) {
364 tmp(i0, i1, i2) = value;
365 ++i2;
366 }
367 ++i1;
368 }
369 ++i0;
370 }
371 return CreateR3FromArray3DWithLayout(tmp, layout);
372 }
373
374 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)375 /* static */ Literal LiteralUtil::CreateR3(
376 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
377 values) {
378 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
379 }
380
381 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)382 /* static */ Literal LiteralUtil::CreateR4WithLayout(
383 std::initializer_list<std::initializer_list<
384 std::initializer_list<std::initializer_list<NativeT>>>>
385 values,
386 const Layout& layout) {
387 const int64 d0 = values.size();
388 const int64 d1 = values.begin()->size();
389 const int64 d2 = values.begin()->begin()->size();
390 const int64 d3 = values.begin()->begin()->begin()->size();
391 Array4D<NativeT> tmp(d0, d1, d2, d3);
392 int64 i0 = 0;
393 for (auto d1_values : values) {
394 int64 i1 = 0;
395 for (auto d2_values : d1_values) {
396 int64 i2 = 0;
397 for (auto d3_values : d2_values) {
398 int64 i3 = 0;
399 for (auto value : d3_values) {
400 tmp(i0, i1, i2, i3) = value;
401 ++i3;
402 }
403 ++i2;
404 }
405 ++i1;
406 }
407 ++i0;
408 }
409 return CreateR4FromArray4DWithLayout(tmp, layout);
410 }
411
412 template <typename NativeT>
CreateSparse(absl::Span<const int64> dimensions,SparseIndexArray indices,absl::Span<const NativeT> values,bool sort)413 /* static */ Literal LiteralUtil::CreateSparse(
414 absl::Span<const int64> dimensions, SparseIndexArray indices,
415 absl::Span<const NativeT> values, bool sort) {
416 int64 num_elements = values.size();
417 int64 rank = dimensions.size();
418 CHECK_EQ(num_elements, indices.index_count());
419 CHECK_EQ(rank, indices.rank());
420 Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
421 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
422 indices.max_indices()));
423 literal.PopulateSparse(indices, values, sort);
424 return literal;
425 }
426
427 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)428 /* static */ Literal LiteralUtil::CreateR4(
429 std::initializer_list<std::initializer_list<
430 std::initializer_list<std::initializer_list<NativeT>>>>
431 values) {
432 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
433 }
434
435 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)436 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
437 const Array<NativeT>& values, const Layout& layout) {
438 Literal literal(ShapeUtil::MakeShapeWithLayout(
439 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
440 AsInt64Slice(layout.minor_to_major())));
441 literal.PopulateFromArray(values);
442 return literal;
443 }
444
445 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)446 /* static */ Literal LiteralUtil::CreateFromArray(
447 const Array<NativeT>& values) {
448 return CreateFromArrayWithLayout(
449 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
450 }
451
452 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)453 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
454 const Array2D<NativeT>& values, const Layout& layout) {
455 return CreateFromArrayWithLayout(values, layout);
456 }
457
458 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)459 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
460 const Array2D<NativeT>& values) {
461 return CreateFromArray(values);
462 }
463
464 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)465 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
466 const Array3D<NativeT>& values, const Layout& layout) {
467 return CreateFromArrayWithLayout(values, layout);
468 }
469
470 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)471 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
472 const Array3D<NativeT>& values) {
473 return CreateFromArray(values);
474 }
475
476 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection)477 /* static */ Literal LiteralUtil::CreateR3Projected(
478 std::initializer_list<std::initializer_list<NativeT>> values,
479 int64 projection) {
480 int64 dim0_size = projection;
481 int64 dim1_size = values.size();
482 int64 dim2_size = values.begin()->size();
483
484 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
485 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
486 int64 dim1 = 0;
487 for (auto inner_list : values) {
488 int64 dim2 = 0;
489 for (auto value : inner_list) {
490 array(dim0, dim1, dim2) = value;
491 ++dim2;
492 }
493 CHECK_EQ(dim2_size, dim2);
494 ++dim1;
495 }
496 CHECK_EQ(dim1_size, dim1);
497 }
498 return CreateR3FromArray3D(array);
499 }
500
501 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection_p,int64 projection_z)502 /* static */ Literal LiteralUtil::CreateR4Projected(
503 std::initializer_list<std::initializer_list<NativeT>> values,
504 int64 projection_p, int64 projection_z) {
505 int64 dim0_size = projection_p;
506 int64 dim1_size = projection_z;
507 int64 dim2_size = values.size();
508 int64 dim3_size = values.begin()->size();
509
510 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
511 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
512 for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
513 int64 dim2 = 0;
514 for (auto inner_list : values) {
515 int64 dim3 = 0;
516 for (auto value : inner_list) {
517 array(dim0, dim1, dim2, dim3) = value;
518 ++dim3;
519 }
520 CHECK_EQ(dim3_size, dim3);
521 ++dim2;
522 }
523 CHECK_EQ(dim2_size, dim2);
524 }
525 }
526 return CreateR4FromArray4D(array);
527 }
528
529 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)530 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
531 const Array4D<NativeT>& values) {
532 return CreateFromArray(values);
533 }
534
535 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)536 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
537 const Array4D<NativeT>& values, const Layout& layout) {
538 return CreateFromArrayWithLayout(values, layout);
539 }
540
541 // Returns an identity matrix (rank 2) with the given row and column count.
542 template <typename NativeT>
MakeIdentityR2(int64 size)543 /* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
544 Array2D<NativeT> array(size, size, 0);
545 for (int64 i = 0; i < size; ++i) {
546 array(i, i) = 1;
547 }
548 return CreateR2FromArray2D(array);
549 }
550
551 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,NativeT value)552 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
553 absl::Span<const int64> dimensions, NativeT value) {
554 Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
555 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
556 literal.PopulateWithValue(value);
557 return literal;
558 }
559
560 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,const std::function<T (absl::Span<const int64>)> & generator)561 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
562 const Shape& shape,
563 const std::function<T(absl::Span<const int64>)>& generator) {
564 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
565 TF_RET_CHECK(shape.element_type() == type);
566 Literal literal(shape);
567 TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
568 [&](absl::Span<const int64> indexes) { return generator(indexes); }));
569 return std::move(literal);
570 }
571
572 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)573 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
574 const Shape& shape, E* engine, T mean, T stddev) {
575 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
576 std::normal_distribution<NativeT> generator(mean, stddev);
577 return CreateRandomLiteral<type, NativeT>(
578 shape,
579 [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
580 }
581
582 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)583 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
584 const Shape& shape, T mean, T stddev) {
585 std::minstd_rand0 engine;
586 return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
587 }
588
589 } // namespace xla
590
591 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
592