1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utilities for dealing with Literal protobufs.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
51
52 namespace xla {
53
54 class LiteralUtil {
55 public:
56 LiteralUtil() = delete;
57
58 // Returns a literal scalar representing the first element.
59 static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
60
61 // Creates a new literal of a given rank. To minimize ambiguity (for users
62 // and the compiler) these CreateR[0-2] methods should explicitly specify the
63 // native type. For example:
64 //
65 // CreateR1<float>({1.0, 42.0});
66 // CreateR2<uint32>({{1, 2}, {3, 4}});
67 //
68 // The variants not ending with WithLayout use the default XLA layout for the
69 // literal's linear representation in memory.
70 template <typename NativeT>
71 static Literal CreateR0(NativeT value);
72 template <typename NativeT>
73 static Literal CreateR1(absl::Span<const NativeT> values);
74 static Literal CreateR1(const tensorflow::core::Bitmap& values);
75 template <typename NativeT>
76 static Literal CreateR2(
77 std::initializer_list<std::initializer_list<NativeT>> values);
78 template <typename NativeT>
79 static Literal CreateR2WithLayout(
80 std::initializer_list<std::initializer_list<NativeT>> values,
81 const Layout& layout);
82 template <typename NativeT>
83 static Literal CreateR3(std::initializer_list<
84 std::initializer_list<std::initializer_list<NativeT>>>
85 values);
86 template <typename NativeT>
87 static Literal CreateR3WithLayout(
88 std::initializer_list<
89 std::initializer_list<std::initializer_list<NativeT>>>
90 values,
91 const Layout& layout);
92 template <typename NativeT>
93 static Literal CreateR4(
94 std::initializer_list<std::initializer_list<
95 std::initializer_list<std::initializer_list<NativeT>>>>
96 values);
97 template <typename NativeT>
98 static Literal CreateR4WithLayout(
99 std::initializer_list<std::initializer_list<
100 std::initializer_list<std::initializer_list<NativeT>>>>
101 values,
102 const Layout& layout);
103
104 // Creates a scalar literal value zero of the given primitive type.
105 static Literal Zero(PrimitiveType primitive_type);
106 // Creates a scalar literal value one of the given primitive type.
107 static Literal One(PrimitiveType primitive_type);
108 // Creates a scalar literal value containing the minimum value of the given
109 // primitive type. For floating-point types, returns -inf.
110 static Literal MinValue(PrimitiveType primitive_type);
111 // Creates a scalar literal value containing the maximum value of the given
112 // primitive type. For floating-point types, returns inf.
113 static Literal MaxValue(PrimitiveType primitive_type);
114 // Creates a scalar literal value containing the NaN value of the given
115 // primitive type. Fail for non-inexact types. For complex types, returns a
116 // nan + nan * j value.
117 static StatusOr<Literal> NanValue(PrimitiveType primitive_type);
118 // Creates a literal of the given shape where each element is `value`.
119 template <typename NativeT>
120 static Literal CreateFullWithDescendingLayout(
121 absl::Span<const int64> dimensions, NativeT value);
122
123 // Creates a new literal from an Array type. The variants not ending with
124 // WithLayout use the default XLA layout for the literal's linear
125 // representation in memory.
126 template <typename NativeT>
127 static Literal CreateFromArray(const Array<NativeT>& values);
128 template <typename NativeT>
129 static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
130 const Layout& layout);
131 template <typename NativeT>
132 static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
133 template <typename NativeT>
134 static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
135 const Layout& layout);
136 template <typename NativeT>
137 static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
138 template <typename NativeT>
139 static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
140 const Layout& layout);
141 template <typename NativeT>
142 static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
143 template <typename NativeT>
144 static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
145 const Layout& layout);
146
147 // Creates a new vector of U8s literal value from a string.
148 static Literal CreateR1U8(absl::string_view value);
149
150 // Creates a linspace-populated literal with the given number of rows and
151 // columns.
152 static Literal CreateR2F32Linspace(float from, float to, int64_t rows,
153 int64_t cols);
154
155 // Creates a literal that projects the (x, y) dimensions given in values into
156 // the z dimension given by "projection".
157 template <typename NativeT>
158 static Literal CreateR3Projected(
159 std::initializer_list<std::initializer_list<NativeT>> values,
160 int64_t projection);
161
162 // Creates a literal that projects the (x, y) dimensions given in values into
163 // the z and p dimensions given.
164 template <typename NativeT>
165 static Literal CreateR4Projected(
166 std::initializer_list<std::initializer_list<NativeT>> values,
167 int64_t projection_p, int64_t projection_z);
168
169 // Returns an identity matrix (rank 2) with the given row and column count.
170 template <typename NativeT>
171 static Literal MakeIdentityR2(int64_t size);
172
173 // Returns a tuple literal composed of given literals. Data is copied from the
174 // given elements into the returned literal.
175 static Literal MakeTuple(absl::Span<const Literal* const> elements);
176
177 static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
178
179 // As above, but intended to be invoked with move semantics; i.e.
180 //
181 // std::vector<Literal> elements = ...;
182 // auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
183 //
184 // This would have been declared as an overload, but there is ambiguity
185 // in invocation between the above signature and this one.
186 static Literal MakeTupleOwned(std::vector<Literal> elements);
187
188 // This overload lets you pass a list of Literals to MakeTupleOwned:
189 //
190 // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
191 //
192 // Simply relying on the MakeTupleOwned(std::vector<Literal>)
193 // overload doesn't work because std::initializer_list's elements are always
194 // const.
195 //
196 // The arguments to this function must all be Literal.
197 template <typename... Ts>
MakeTupleOwned(Ts...elements)198 static Literal MakeTupleOwned(Ts... elements) {
199 std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
200 std::vector<Literal> v;
201 v.insert(v.begin(), std::make_move_iterator(arr.begin()),
202 std::make_move_iterator(arr.end()));
203 return MakeTupleOwned(std::move(v));
204 }
205
206 // Create a constant token literal. Token types have no value.
207 static Literal CreateToken();
208
209 // Creates a new Literal object with its values havings the primitive_type
210 // type, and with dimensions defined by the dimensions parameter.
211 // The content of the literal values is the default value of the primitive
212 // type of literal itself (0 for numeric types, and false for predicates).
213 static Literal CreateFromDimensions(PrimitiveType primitive_type,
214 absl::Span<const int64> dimensions);
215
216 // If the given literal's data type is bfloat16, converts it to a float
217 // literal; otherwise, returns a copy of it. If the literal is a tuple,
218 // recursively converts its elements.
219 static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
220
221 // If the given literal's data type is bfloat16, converts it to a double
222 // literal; otherwise, returns a copy of it. If the literal is a tuple,
223 // recursively converts its elements.
224 static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
225
226 // If the given literal's data type is float, converts it to a bfloat16
227 // literal; otherwise, returns a copy of it. If the literal is a tuple,
228 // recursively converts its elements.
229 static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
230
231 // If the given literal's data type is float, converts it to a double
232 // literal; otherwise, returns a copy of it. If the literal is a tuple,
233 // recursively converts its elements.
234 static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
235
236 // If the given literal's data type is double, converts it to a bfloat16
237 // literal; otherwise, returns a copy of it. If the literal is a tuple,
238 // recursively converts its elements.
239 static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
240
241 // Creates a scalar literal whose value is the maximum value of a given
242 // literal slice.
243 static Literal MaxElement(const LiteralSlice& literal);
244
245 // If the given literal's data type is double, converts it to a bfloat16
246 // literal; otherwise, returns a copy of it. If the literal is a tuple,
247 // recursively converts its elements.
248 static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
249
250 // Creates a literal with a new shape with the given new dimensions using the
251 // data in the given input literal. For reshaping purposes the (flat) data
252 // buffer of the input literal is assumed to have the given minor_to_major
253 // layout order.
254 static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
255 absl::Span<const int64> minor_to_major,
256 const LiteralSlice& literal);
257
258 // Creates a literal with the supplied shape, and uses the provided value
259 // generator to populate the literal's values.
260 // Returns the new literal object, or an error Status if failed.
261 template <
262 PrimitiveType type,
263 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
264 static StatusOr<Literal> CreateLiteralWithGenerator(
265 const Shape& shape,
266 const std::function<T(absl::Span<const int64>)>& generator);
267
268 // Creates a literal with the supplied shape, and initializes the literal
269 // values using a normal distribution with given mean and stddev standard
270 // deviation, and using the engine as entropy generator.
271 // Returns the new literal object, or an error Status if failed.
272 template <
273 PrimitiveType type, typename E,
274 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
275 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
276 T mean, T stddev);
277
278 // Creates a literal with the supplied shape, and initializes the literal
279 // values using a normal distribution with given mean and stddev standard
280 // deviation.
281 // Returns the new literal object, or an error Status if failed.
282 template <
283 PrimitiveType type,
284 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
285 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
286 T stddev);
287
288 //
289 // End of factory methods.
290
291 // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
292 // be returned for a 2-dimensional index with dimension 0 index equal to 7,
293 // dimension 1 equal to 8.
294 static string MultiIndexAsString(absl::Span<const int64> multi_index);
295 };
296
// Stream-insertion operator for Literal: writes a textual form of `literal`
// to `out`. NOTE(review): the exact output format is defined by the
// implementation (not visible in this header); presumably returns `out` to
// allow chaining.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
298
299 template <typename NativeT>
CreateR0(NativeT value)300 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
301 Literal literal(ShapeUtil::MakeShape(
302 primitive_util::NativeToPrimitiveType<NativeT>(), {}));
303 literal.Set({}, value);
304 return literal;
305 }
306
307 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)308 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
309 Literal literal(
310 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
311 {static_cast<int64>(values.size())}));
312 literal.PopulateR1(values);
313 return literal;
314 }
315
316 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)317 /* static */ Literal LiteralUtil::CreateR2WithLayout(
318 std::initializer_list<std::initializer_list<NativeT>> values,
319 const Layout& layout) {
320 Literal literal(ShapeUtil::MakeShapeWithLayout(
321 primitive_util::NativeToPrimitiveType<NativeT>(),
322 {static_cast<int64>(values.size()),
323 static_cast<int64>(values.begin()->size())},
324 AsInt64Slice(layout.minor_to_major())));
325 literal.PopulateR2(values);
326 return literal;
327 }
328
329 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)330 /* static */ Literal LiteralUtil::CreateR2(
331 std::initializer_list<std::initializer_list<NativeT>> values) {
332 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
333 }
334
335 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)336 /* static */ Literal LiteralUtil::CreateR3WithLayout(
337 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
338 values,
339 const Layout& layout) {
340 const int64_t d0 = values.size();
341 const int64_t d1 = values.begin()->size();
342 const int64_t d2 = values.begin()->begin()->size();
343 Array3D<NativeT> tmp(d0, d1, d2);
344 int64_t i0 = 0;
345 for (auto d1_values : values) {
346 int64_t i1 = 0;
347 for (auto d2_values : d1_values) {
348 int64_t i2 = 0;
349 for (auto value : d2_values) {
350 tmp(i0, i1, i2) = value;
351 ++i2;
352 }
353 ++i1;
354 }
355 ++i0;
356 }
357 return CreateR3FromArray3DWithLayout(tmp, layout);
358 }
359
360 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)361 /* static */ Literal LiteralUtil::CreateR3(
362 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
363 values) {
364 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
365 }
366
367 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)368 /* static */ Literal LiteralUtil::CreateR4WithLayout(
369 std::initializer_list<std::initializer_list<
370 std::initializer_list<std::initializer_list<NativeT>>>>
371 values,
372 const Layout& layout) {
373 const int64_t d0 = values.size();
374 const int64_t d1 = values.begin()->size();
375 const int64_t d2 = values.begin()->begin()->size();
376 const int64_t d3 = values.begin()->begin()->begin()->size();
377 Array4D<NativeT> tmp(d0, d1, d2, d3);
378 int64_t i0 = 0;
379 for (auto d1_values : values) {
380 int64_t i1 = 0;
381 for (auto d2_values : d1_values) {
382 int64_t i2 = 0;
383 for (auto d3_values : d2_values) {
384 int64_t i3 = 0;
385 for (auto value : d3_values) {
386 tmp(i0, i1, i2, i3) = value;
387 ++i3;
388 }
389 ++i2;
390 }
391 ++i1;
392 }
393 ++i0;
394 }
395 return CreateR4FromArray4DWithLayout(tmp, layout);
396 }
397
398 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)399 /* static */ Literal LiteralUtil::CreateR4(
400 std::initializer_list<std::initializer_list<
401 std::initializer_list<std::initializer_list<NativeT>>>>
402 values) {
403 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
404 }
405
406 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)407 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
408 const Array<NativeT>& values, const Layout& layout) {
409 Literal literal(ShapeUtil::MakeShapeWithLayout(
410 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
411 AsInt64Slice(layout.minor_to_major())));
412 literal.PopulateFromArray(values);
413 return literal;
414 }
415
416 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)417 /* static */ Literal LiteralUtil::CreateFromArray(
418 const Array<NativeT>& values) {
419 return CreateFromArrayWithLayout(
420 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
421 }
422
423 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)424 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
425 const Array2D<NativeT>& values, const Layout& layout) {
426 return CreateFromArrayWithLayout(values, layout);
427 }
428
429 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)430 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
431 const Array2D<NativeT>& values) {
432 return CreateFromArray(values);
433 }
434
435 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)436 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
437 const Array3D<NativeT>& values, const Layout& layout) {
438 return CreateFromArrayWithLayout(values, layout);
439 }
440
441 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)442 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
443 const Array3D<NativeT>& values) {
444 return CreateFromArray(values);
445 }
446
447 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection)448 /* static */ Literal LiteralUtil::CreateR3Projected(
449 std::initializer_list<std::initializer_list<NativeT>> values,
450 int64_t projection) {
451 int64_t dim0_size = projection;
452 int64_t dim1_size = values.size();
453 int64_t dim2_size = values.begin()->size();
454
455 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
456 for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
457 int64_t dim1 = 0;
458 for (auto inner_list : values) {
459 int64_t dim2 = 0;
460 for (auto value : inner_list) {
461 array(dim0, dim1, dim2) = value;
462 ++dim2;
463 }
464 CHECK_EQ(dim2_size, dim2);
465 ++dim1;
466 }
467 CHECK_EQ(dim1_size, dim1);
468 }
469 return CreateR3FromArray3D(array);
470 }
471
472 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection_p,int64_t projection_z)473 /* static */ Literal LiteralUtil::CreateR4Projected(
474 std::initializer_list<std::initializer_list<NativeT>> values,
475 int64_t projection_p, int64_t projection_z) {
476 int64_t dim0_size = projection_p;
477 int64_t dim1_size = projection_z;
478 int64_t dim2_size = values.size();
479 int64_t dim3_size = values.begin()->size();
480
481 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
482 for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
483 for (int64_t dim1 = 0; dim1 < dim1_size; ++dim1) {
484 int64_t dim2 = 0;
485 for (auto inner_list : values) {
486 int64_t dim3 = 0;
487 for (auto value : inner_list) {
488 array(dim0, dim1, dim2, dim3) = value;
489 ++dim3;
490 }
491 CHECK_EQ(dim3_size, dim3);
492 ++dim2;
493 }
494 CHECK_EQ(dim2_size, dim2);
495 }
496 }
497 return CreateR4FromArray4D(array);
498 }
499
500 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)501 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
502 const Array4D<NativeT>& values) {
503 return CreateFromArray(values);
504 }
505
506 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)507 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
508 const Array4D<NativeT>& values, const Layout& layout) {
509 return CreateFromArrayWithLayout(values, layout);
510 }
511
512 // Returns an identity matrix (rank 2) with the given row and column count.
513 template <typename NativeT>
MakeIdentityR2(int64_t size)514 /* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) {
515 Array2D<NativeT> array(size, size, 0);
516 for (int64_t i = 0; i < size; ++i) {
517 array(i, i) = 1;
518 }
519 return CreateR2FromArray2D(array);
520 }
521
522 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,NativeT value)523 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
524 absl::Span<const int64> dimensions, NativeT value) {
525 Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
526 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
527 literal.PopulateWithValue(value);
528 return literal;
529 }
530
531 template <PrimitiveType type, typename T>
CreateLiteralWithGenerator(const Shape & shape,const std::function<T (absl::Span<const int64>)> & generator)532 /* static */ StatusOr<Literal> LiteralUtil::CreateLiteralWithGenerator(
533 const Shape& shape,
534 const std::function<T(absl::Span<const int64>)>& generator) {
535 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
536 TF_RET_CHECK(shape.element_type() == type);
537 Literal literal(shape);
538 TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
539 [&](absl::Span<const int64> indexes) { return generator(indexes); }));
540 return std::move(literal);
541 }
542
543 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)544 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
545 const Shape& shape, E* engine, T mean, T stddev) {
546 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
547 std::normal_distribution<NativeT> generator(mean, stddev);
548 return CreateLiteralWithGenerator<type, NativeT>(
549 shape,
550 [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
551 }
552
553 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)554 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
555 const Shape& shape, T mean, T stddev) {
556 std::minstd_rand0 engine;
557 return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
558 }
559
560 } // namespace xla
561
562 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
563