• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Utilities for creating, converting and reshaping xla::Literal objects.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20 
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
51 
52 namespace xla {
53 
54 class LiteralUtil {
55  public:
56   LiteralUtil() = delete;
57 
58   // Returns a literal scalar representing the first element.
59   static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
60 
61   // Creates a new literal of a given rank. To minimize ambiguity (for users
62   // and the compiler) these CreateR[0-2] methods should explicitly specify the
63   // native type. For example:
64   //
65   //  CreateR1<float>({1.0, 42.0});
66   //  CreateR2<uint32>({{1, 2}, {3, 4}});
67   //
68   // The variants not ending with WithLayout use the default XLA layout for the
69   // literal's linear representation in memory.
70   template <typename NativeT>
71   static Literal CreateR0(NativeT value);
72   template <typename NativeT>
73   static Literal CreateR1(absl::Span<const NativeT> values);
74   static Literal CreateR1(const tensorflow::core::Bitmap& values);
75   template <typename NativeT>
76   static Literal CreateR2(
77       std::initializer_list<std::initializer_list<NativeT>> values);
78   template <typename NativeT>
79   static Literal CreateR2WithLayout(
80       std::initializer_list<std::initializer_list<NativeT>> values,
81       const Layout& layout);
82   template <typename NativeT>
83   static Literal CreateR3(std::initializer_list<
84                           std::initializer_list<std::initializer_list<NativeT>>>
85                               values);
86   template <typename NativeT>
87   static Literal CreateR3WithLayout(
88       std::initializer_list<
89           std::initializer_list<std::initializer_list<NativeT>>>
90           values,
91       const Layout& layout);
92   template <typename NativeT>
93   static Literal CreateR4(
94       std::initializer_list<std::initializer_list<
95           std::initializer_list<std::initializer_list<NativeT>>>>
96           values);
97   template <typename NativeT>
98   static Literal CreateR4WithLayout(
99       std::initializer_list<std::initializer_list<
100           std::initializer_list<std::initializer_list<NativeT>>>>
101           values,
102       const Layout& layout);
103 
104   // Creates a scalar literal value zero of the given primitive type.
105   static Literal Zero(PrimitiveType primitive_type);
106   // Creates a scalar literal value one of the given primitive type.
107   static Literal One(PrimitiveType primitive_type);
108   // Creates a scalar literal value containing the minimum value of the given
109   // primitive type. For floating-point types, returns -inf.
110   static Literal MinValue(PrimitiveType primitive_type);
111   // Creates a scalar literal value containing the maximum value of the given
112   // primitive type. For floating-point types, returns inf.
113   static Literal MaxValue(PrimitiveType primitive_type);
114   // Creates a scalar literal value containing the NaN value of the given
115   // primitive type. Fail for non-inexact types. For complex types, returns a
116   // nan + nan * j value.
117   static StatusOr<Literal> NanValue(PrimitiveType primitive_type);
118   // Creates a literal of the given shape where each element is `value`.
119   template <typename NativeT>
120   static Literal CreateFullWithDescendingLayout(
121       absl::Span<const int64> dimensions, NativeT value);
122 
123   // Creates a new literal from an Array type. The variants not ending with
124   // WithLayout use the default XLA layout for the literal's linear
125   // representation in memory.
126   template <typename NativeT>
127   static Literal CreateFromArray(const Array<NativeT>& values);
128   template <typename NativeT>
129   static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
130                                            const Layout& layout);
131   template <typename NativeT>
132   static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
133   template <typename NativeT>
134   static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
135                                                const Layout& layout);
136   template <typename NativeT>
137   static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
138   template <typename NativeT>
139   static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
140                                                const Layout& layout);
141   template <typename NativeT>
142   static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
143   template <typename NativeT>
144   static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
145                                                const Layout& layout);
146 
147   // Creates a new vector of U8s literal value from a string.
148   static Literal CreateR1U8(absl::string_view value);
149 
150   // Creates a linspace-populated literal with the given number of rows and
151   // columns.
152   static Literal CreateR2F32Linspace(float from, float to, int64_t rows,
153                                      int64_t cols);
154 
155   // Creates a literal that projects the (x, y) dimensions given in values into
156   // the z dimension given by "projection".
157   template <typename NativeT>
158   static Literal CreateR3Projected(
159       std::initializer_list<std::initializer_list<NativeT>> values,
160       int64_t projection);
161 
162   // Creates a literal that projects the (x, y) dimensions given in values into
163   // the z and p dimensions given.
164   template <typename NativeT>
165   static Literal CreateR4Projected(
166       std::initializer_list<std::initializer_list<NativeT>> values,
167       int64_t projection_p, int64_t projection_z);
168 
169   // Returns an identity matrix (rank 2) with the given row and column count.
170   template <typename NativeT>
171   static Literal MakeIdentityR2(int64_t size);
172 
173   // Returns a tuple literal composed of given literals. Data is copied from the
174   // given elements into the returned literal.
175   static Literal MakeTuple(absl::Span<const Literal* const> elements);
176 
177   static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
178 
179   // As above, but intended to be invoked with move semantics; i.e.
180   //
181   //  std::vector<Literal> elements = ...;
182   //  auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
183   //
184   // This would have been declared as an overload, but there is ambiguity
185   // in invocation between the above signature and this one.
186   static Literal MakeTupleOwned(std::vector<Literal> elements);
187 
188   // This overload lets you pass a list of Literals to MakeTupleOwned:
189   //
190   //   LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
191   //
192   // Simply relying on the MakeTupleOwned(std::vector<Literal>)
193   // overload doesn't work because std::initializer_list's elements are always
194   // const.
195   //
196   // The arguments to this function must all be Literal.
197   template <typename... Ts>
MakeTupleOwned(Ts...elements)198   static Literal MakeTupleOwned(Ts... elements) {
199     std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
200     std::vector<Literal> v;
201     v.insert(v.begin(), std::make_move_iterator(arr.begin()),
202              std::make_move_iterator(arr.end()));
203     return MakeTupleOwned(std::move(v));
204   }
205 
206   // Create a constant token literal. Token types have no value.
207   static Literal CreateToken();
208 
209   // Creates a new Literal object with its values havings the primitive_type
210   // type, and with dimensions defined by the dimensions parameter.
211   // The content of the literal values is the default value of the primitive
212   // type of literal itself (0 for numeric types, and false for predicates).
213   static Literal CreateFromDimensions(PrimitiveType primitive_type,
214                                       absl::Span<const int64> dimensions);
215 
216   // If the given literal's data type is bfloat16, converts it to a float
217   // literal; otherwise, returns a copy of it. If the literal is a tuple,
218   // recursively converts its elements.
219   static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
220 
221   // If the given literal's data type is bfloat16, converts it to a double
222   // literal; otherwise, returns a copy of it. If the literal is a tuple,
223   // recursively converts its elements.
224   static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
225 
226   // If the given literal's data type is float, converts it to a bfloat16
227   // literal; otherwise, returns a copy of it. If the literal is a tuple,
228   // recursively converts its elements.
229   static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
230 
231   // If the given literal's data type is float, converts it to a double
232   // literal; otherwise, returns a copy of it. If the literal is a tuple,
233   // recursively converts its elements.
234   static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
235 
236   // If the given literal's data type is double, converts it to a bfloat16
237   // literal; otherwise, returns a copy of it. If the literal is a tuple,
238   // recursively converts its elements.
239   static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
240 
241   // Creates a scalar literal whose value is the maximum value of a given
242   // literal slice.
243   static Literal MaxElement(const LiteralSlice& literal);
244 
245   // If the given literal's data type is double, converts it to a bfloat16
246   // literal; otherwise, returns a copy of it. If the literal is a tuple,
247   // recursively converts its elements.
248   static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
249 
250   // Creates a literal with a new shape with the given new dimensions using the
251   // data in the given input literal. For reshaping purposes the (flat) data
252   // buffer of the input literal is assumed to have the given minor_to_major
253   // layout order.
254   static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
255                               absl::Span<const int64> minor_to_major,
256                               const LiteralSlice& literal);
257 
258   // Creates a literal with the supplied shape, and uses the provided value
259   // generator to populate the literal's values.
260   // Returns the new literal object, or an error Status if failed.
261   template <
262       PrimitiveType type,
263       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
264   static StatusOr<Literal> CreateLiteralWithGenerator(
265       const Shape& shape,
266       const std::function<T(absl::Span<const int64>)>& generator);
267 
268   // Creates a literal with the supplied shape, and initializes the literal
269   // values using a normal distribution with given mean and stddev standard
270   // deviation, and using the engine as entropy generator.
271   // Returns the new literal object, or an error Status if failed.
272   template <
273       PrimitiveType type, typename E,
274       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
275   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
276                                                T mean, T stddev);
277 
278   // Creates a literal with the supplied shape, and initializes the literal
279   // values using a normal distribution with given mean and stddev standard
280   // deviation.
281   // Returns the new literal object, or an error Status if failed.
282   template <
283       PrimitiveType type,
284       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
285   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
286                                                T stddev);
287 
288   //
289   // End of factory methods.
290 
291   // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
292   // be returned for a 2-dimensional index with dimension 0 index equal to 7,
293   // dimension 1 equal to 8.
294   static string MultiIndexAsString(absl::Span<const int64> multi_index);
295 };
296 
// Stream-insertion operator for Literal, declared here and defined out of line.
// NOTE(review): the exact text format comes from the .cc definition, not this
// header; presumably it matches Literal::ToString(), but confirm that before
// relying on it.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
298 
299 template <typename NativeT>
CreateR0(NativeT value)300 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
301   Literal literal(ShapeUtil::MakeShape(
302       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
303   literal.Set({}, value);
304   return literal;
305 }
306 
307 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)308 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
309   Literal literal(
310       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
311                            {static_cast<int64>(values.size())}));
312   literal.PopulateR1(values);
313   return literal;
314 }
315 
316 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)317 /* static */ Literal LiteralUtil::CreateR2WithLayout(
318     std::initializer_list<std::initializer_list<NativeT>> values,
319     const Layout& layout) {
320   Literal literal(ShapeUtil::MakeShapeWithLayout(
321       primitive_util::NativeToPrimitiveType<NativeT>(),
322       {static_cast<int64>(values.size()),
323        static_cast<int64>(values.begin()->size())},
324       AsInt64Slice(layout.minor_to_major())));
325   literal.PopulateR2(values);
326   return literal;
327 }
328 
329 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)330 /* static */ Literal LiteralUtil::CreateR2(
331     std::initializer_list<std::initializer_list<NativeT>> values) {
332   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
333 }
334 
335 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)336 /* static */ Literal LiteralUtil::CreateR3WithLayout(
337     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
338         values,
339     const Layout& layout) {
340   const int64_t d0 = values.size();
341   const int64_t d1 = values.begin()->size();
342   const int64_t d2 = values.begin()->begin()->size();
343   Array3D<NativeT> tmp(d0, d1, d2);
344   int64_t i0 = 0;
345   for (auto d1_values : values) {
346     int64_t i1 = 0;
347     for (auto d2_values : d1_values) {
348       int64_t i2 = 0;
349       for (auto value : d2_values) {
350         tmp(i0, i1, i2) = value;
351         ++i2;
352       }
353       ++i1;
354     }
355     ++i0;
356   }
357   return CreateR3FromArray3DWithLayout(tmp, layout);
358 }
359 
360 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)361 /* static */ Literal LiteralUtil::CreateR3(
362     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
363         values) {
364   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
365 }
366 
367 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)368 /* static */ Literal LiteralUtil::CreateR4WithLayout(
369     std::initializer_list<std::initializer_list<
370         std::initializer_list<std::initializer_list<NativeT>>>>
371         values,
372     const Layout& layout) {
373   const int64_t d0 = values.size();
374   const int64_t d1 = values.begin()->size();
375   const int64_t d2 = values.begin()->begin()->size();
376   const int64_t d3 = values.begin()->begin()->begin()->size();
377   Array4D<NativeT> tmp(d0, d1, d2, d3);
378   int64_t i0 = 0;
379   for (auto d1_values : values) {
380     int64_t i1 = 0;
381     for (auto d2_values : d1_values) {
382       int64_t i2 = 0;
383       for (auto d3_values : d2_values) {
384         int64_t i3 = 0;
385         for (auto value : d3_values) {
386           tmp(i0, i1, i2, i3) = value;
387           ++i3;
388         }
389         ++i2;
390       }
391       ++i1;
392     }
393     ++i0;
394   }
395   return CreateR4FromArray4DWithLayout(tmp, layout);
396 }
397 
398 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)399 /* static */ Literal LiteralUtil::CreateR4(
400     std::initializer_list<std::initializer_list<
401         std::initializer_list<std::initializer_list<NativeT>>>>
402         values) {
403   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
404 }
405 
406 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)407 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
408     const Array<NativeT>& values, const Layout& layout) {
409   Literal literal(ShapeUtil::MakeShapeWithLayout(
410       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
411       AsInt64Slice(layout.minor_to_major())));
412   literal.PopulateFromArray(values);
413   return literal;
414 }
415 
416 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)417 /* static */ Literal LiteralUtil::CreateFromArray(
418     const Array<NativeT>& values) {
419   return CreateFromArrayWithLayout(
420       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
421 }
422 
423 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)424 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
425     const Array2D<NativeT>& values, const Layout& layout) {
426   return CreateFromArrayWithLayout(values, layout);
427 }
428 
429 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)430 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
431     const Array2D<NativeT>& values) {
432   return CreateFromArray(values);
433 }
434 
435 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)436 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
437     const Array3D<NativeT>& values, const Layout& layout) {
438   return CreateFromArrayWithLayout(values, layout);
439 }
440 
441 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)442 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
443     const Array3D<NativeT>& values) {
444   return CreateFromArray(values);
445 }
446 
447 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection)448 /* static */ Literal LiteralUtil::CreateR3Projected(
449     std::initializer_list<std::initializer_list<NativeT>> values,
450     int64_t projection) {
451   int64_t dim0_size = projection;
452   int64_t dim1_size = values.size();
453   int64_t dim2_size = values.begin()->size();
454 
455   Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
456   for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
457     int64_t dim1 = 0;
458     for (auto inner_list : values) {
459       int64_t dim2 = 0;
460       for (auto value : inner_list) {
461         array(dim0, dim1, dim2) = value;
462         ++dim2;
463       }
464       CHECK_EQ(dim2_size, dim2);
465       ++dim1;
466     }
467     CHECK_EQ(dim1_size, dim1);
468   }
469   return CreateR3FromArray3D(array);
470 }
471 
472 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64_t projection_p,int64_t projection_z)473 /* static */ Literal LiteralUtil::CreateR4Projected(
474     std::initializer_list<std::initializer_list<NativeT>> values,
475     int64_t projection_p, int64_t projection_z) {
476   int64_t dim0_size = projection_p;
477   int64_t dim1_size = projection_z;
478   int64_t dim2_size = values.size();
479   int64_t dim3_size = values.begin()->size();
480 
481   Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
482   for (int64_t dim0 = 0; dim0 < dim0_size; ++dim0) {
483     for (int64_t dim1 = 0; dim1 < dim1_size; ++dim1) {
484       int64_t dim2 = 0;
485       for (auto inner_list : values) {
486         int64_t dim3 = 0;
487         for (auto value : inner_list) {
488           array(dim0, dim1, dim2, dim3) = value;
489           ++dim3;
490         }
491         CHECK_EQ(dim3_size, dim3);
492         ++dim2;
493       }
494       CHECK_EQ(dim2_size, dim2);
495     }
496   }
497   return CreateR4FromArray4D(array);
498 }
499 
500 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)501 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
502     const Array4D<NativeT>& values) {
503   return CreateFromArray(values);
504 }
505 
506 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)507 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
508     const Array4D<NativeT>& values, const Layout& layout) {
509   return CreateFromArrayWithLayout(values, layout);
510 }
511 
512 // Returns an identity matrix (rank 2) with the given row and column count.
513 template <typename NativeT>
MakeIdentityR2(int64_t size)514 /* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) {
515   Array2D<NativeT> array(size, size, 0);
516   for (int64_t i = 0; i < size; ++i) {
517     array(i, i) = 1;
518   }
519   return CreateR2FromArray2D(array);
520 }
521 
522 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,NativeT value)523 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
524     absl::Span<const int64> dimensions, NativeT value) {
525   Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
526       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
527   literal.PopulateWithValue(value);
528   return literal;
529 }
530 
531 template <PrimitiveType type, typename T>
CreateLiteralWithGenerator(const Shape & shape,const std::function<T (absl::Span<const int64>)> & generator)532 /* static */ StatusOr<Literal> LiteralUtil::CreateLiteralWithGenerator(
533     const Shape& shape,
534     const std::function<T(absl::Span<const int64>)>& generator) {
535   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
536   TF_RET_CHECK(shape.element_type() == type);
537   Literal literal(shape);
538   TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
539       [&](absl::Span<const int64> indexes) { return generator(indexes); }));
540   return std::move(literal);
541 }
542 
543 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)544 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
545     const Shape& shape, E* engine, T mean, T stddev) {
546   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
547   std::normal_distribution<NativeT> generator(mean, stddev);
548   return CreateLiteralWithGenerator<type, NativeT>(
549       shape,
550       [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
551 }
552 
553 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)554 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
555     const Shape& shape, T mean, T stddev) {
556   std::minstd_rand0 engine;
557   return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
558 }
559 
560 }  // namespace xla
561 
562 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
563