1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Utilities for creating and converting xla::Literal objects.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
51
52 namespace xla {
53
54 class LiteralUtil {
55 public:
56 LiteralUtil() = delete;
57
58 // Returns a literal scalar representing the first element.
59 static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
60
61 // Creates a new literal of a given rank. To minimize ambiguity (for users
62 // and the compiler) these CreateR[0-2] methods should explicitly specify the
63 // native type. For example:
64 //
65 // CreateR1<float>({1.0, 42.0});
66 // CreateR2<uint32>({{1, 2}, {3, 4}});
67 //
68 // The variants not ending with WithLayout use the default XLA layout for the
69 // literal's linear representation in memory.
70 template <typename NativeT>
71 static Literal CreateR0(NativeT value);
72 template <typename NativeT>
73 static Literal CreateR1(absl::Span<const NativeT> values);
74 static Literal CreateR1(const tensorflow::core::Bitmap& values);
75 template <typename NativeT>
76 static Literal CreateR2(
77 std::initializer_list<std::initializer_list<NativeT>> values);
78 template <typename NativeT>
79 static Literal CreateR2WithLayout(
80 std::initializer_list<std::initializer_list<NativeT>> values,
81 const Layout& layout);
82 template <typename NativeT>
83 static Literal CreateR3(std::initializer_list<
84 std::initializer_list<std::initializer_list<NativeT>>>
85 values);
86 template <typename NativeT>
87 static Literal CreateR3WithLayout(
88 std::initializer_list<
89 std::initializer_list<std::initializer_list<NativeT>>>
90 values,
91 const Layout& layout);
92 template <typename NativeT>
93 static Literal CreateR4(
94 std::initializer_list<std::initializer_list<
95 std::initializer_list<std::initializer_list<NativeT>>>>
96 values);
97 template <typename NativeT>
98 static Literal CreateR4WithLayout(
99 std::initializer_list<std::initializer_list<
100 std::initializer_list<std::initializer_list<NativeT>>>>
101 values,
102 const Layout& layout);
103
104 // Creates a scalar literal value zero of the given primitive type.
105 static Literal Zero(PrimitiveType primitive_type);
106 // Creates a scalar literal value one of the given primitive type.
107 static Literal One(PrimitiveType primitive_type);
108 // Creates a scalar literal value containing the minimum value of the given
109 // primitive type. For floating-point types, returns -inf.
110 static Literal MinValue(PrimitiveType primitive_type);
111 // Creates a scalar literal value containing the maximum value of the given
112 // primitive type. For floating-point types, returns inf.
113 static Literal MaxValue(PrimitiveType primitive_type);
114 // Creates a scalar literal value containing the NaN value of the given
115 // primitive type. Fail for non-inexact types. For complex types, returns a
116 // nan + nan * j value.
117 static StatusOr<Literal> NanValue(PrimitiveType primitive_type);
118 // Creates a literal of the given shape where each element is `value`.
119 template <typename NativeT>
120 static Literal CreateFullWithDescendingLayout(
121 absl::Span<const int64> dimensions, NativeT value);
122
123 // Creates a new literal from an Array type. The variants not ending with
124 // WithLayout use the default XLA layout for the literal's linear
125 // representation in memory.
126 template <typename NativeT>
127 static Literal CreateFromArray(const Array<NativeT>& values);
128 template <typename NativeT>
129 static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
130 const Layout& layout);
131 template <typename NativeT>
132 static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
133 template <typename NativeT>
134 static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
135 const Layout& layout);
136 template <typename NativeT>
137 static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
138 template <typename NativeT>
139 static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
140 const Layout& layout);
141 template <typename NativeT>
142 static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
143 template <typename NativeT>
144 static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
145 const Layout& layout);
146
147 // Creates a new vector of U8s literal value from a string.
148 static Literal CreateR1U8(absl::string_view value);
149
150 // Creates a linspace-populated literal with the given number of rows and
151 // columns.
152 static Literal CreateR2F32Linspace(float from, float to, int64 rows,
153 int64 cols);
154
155 // Creates a literal that projects the (x, y) dimensions given in values into
156 // the z dimension given by "projection".
157 template <typename NativeT>
158 static Literal CreateR3Projected(
159 std::initializer_list<std::initializer_list<NativeT>> values,
160 int64 projection);
161
162 // Creates a literal that projects the (x, y) dimensions given in values into
163 // the z and p dimensions given.
164 template <typename NativeT>
165 static Literal CreateR4Projected(
166 std::initializer_list<std::initializer_list<NativeT>> values,
167 int64 projection_p, int64 projection_z);
168
169 // Returns an identity matrix (rank 2) with the given row and column count.
170 template <typename NativeT>
171 static Literal MakeIdentityR2(int64 size);
172
173 // Returns a tuple literal composed of given literals. Data is copied from the
174 // given elements into the returned literal.
175 static Literal MakeTuple(absl::Span<const Literal* const> elements);
176
177 static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
178
179 // As above, but intended to be invoked with move semantics; i.e.
180 //
181 // std::vector<Literal> elements = ...;
182 // auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
183 //
184 // This would have been declared as an overload, but there is ambiguity
185 // in invocation between the above signature and this one.
186 static Literal MakeTupleOwned(std::vector<Literal> elements);
187
188 // This overload lets you pass a list of Literals to MakeTupleOwned:
189 //
190 // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
191 //
192 // Simply relying on the MakeTupleOwned(std::vector<Literal>)
193 // overload doesn't work because std::initializer_list's elements are always
194 // const.
195 //
196 // The arguments to this function must all be Literal.
197 template <typename... Ts>
MakeTupleOwned(Ts...elements)198 static Literal MakeTupleOwned(Ts... elements) {
199 std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
200 std::vector<Literal> v;
201 v.insert(v.begin(), std::make_move_iterator(arr.begin()),
202 std::make_move_iterator(arr.end()));
203 return MakeTupleOwned(std::move(v));
204 }
205
206 // Create a constant token literal. Token types have no value.
207 static Literal CreateToken();
208
209 // Creates a new Literal object with its values havings the primitive_type
210 // type, and with dimensions defined by the dimensions parameter.
211 // The content of the literal values is the default value of the primitive
212 // type of literal itself (0 for numeric types, and false for predicates).
213 static Literal CreateFromDimensions(PrimitiveType primitive_type,
214 absl::Span<const int64> dimensions);
215
216 // If the given literal's data type is bfloat16, converts it to a float
217 // literal; otherwise, returns a copy of it. If the literal is a tuple,
218 // recursively converts its elements.
219 static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
220
221 // If the given literal's data type is bfloat16, converts it to a double
222 // literal; otherwise, returns a copy of it. If the literal is a tuple,
223 // recursively converts its elements.
224 static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
225
226 // If the given literal's data type is float, converts it to a bfloat16
227 // literal; otherwise, returns a copy of it. If the literal is a tuple,
228 // recursively converts its elements.
229 static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
230
231 // If the given literal's data type is float, converts it to a double
232 // literal; otherwise, returns a copy of it. If the literal is a tuple,
233 // recursively converts its elements.
234 static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
235
236 // If the given literal's data type is double, converts it to a bfloat16
237 // literal; otherwise, returns a copy of it. If the literal is a tuple,
238 // recursively converts its elements.
239 static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
240
241 // If the given literal's data type is double, converts it to a bfloat16
242 // literal; otherwise, returns a copy of it. If the literal is a tuple,
243 // recursively converts its elements.
244 static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
245
246 // Creates a literal with a new shape with the given new dimensions using the
247 // data in the given input literal. For reshaping purposes the (flat) data
248 // buffer of the input literal is assumed to have the given minor_to_major
249 // layout order.
250 static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
251 absl::Span<const int64> minor_to_major,
252 const LiteralSlice& literal);
253
254 // Creates a literal with the supplied shape, and uses the provided value
255 // generator to populate the literal's values.
256 // Returns the new literal object, or an error Status if failed.
257 template <
258 PrimitiveType type,
259 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
260 static StatusOr<Literal> CreateLiteralWithGenerator(
261 const Shape& shape,
262 const std::function<T(absl::Span<const int64>)>& generator);
263
264 // Creates a literal with the supplied shape, and initializes the literal
265 // values using a normal distribution with given mean and stddev standard
266 // deviation, and using the engine as entropy generator.
267 // Returns the new literal object, or an error Status if failed.
268 template <
269 PrimitiveType type, typename E,
270 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
271 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
272 T mean, T stddev);
273
274 // Creates a literal with the supplied shape, and initializes the literal
275 // values using a normal distribution with given mean and stddev standard
276 // deviation.
277 // Returns the new literal object, or an error Status if failed.
278 template <
279 PrimitiveType type,
280 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
281 static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
282 T stddev);
283
284 //
285 // End of factory methods.
286
287 // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
288 // be returned for a 2-dimensional index with dimension 0 index equal to 7,
289 // dimension 1 equal to 8.
290 static string MultiIndexAsString(absl::Span<const int64> multi_index);
291 };
292
293 std::ostream& operator<<(std::ostream& out, const Literal& literal);
294
295 template <typename NativeT>
CreateR0(NativeT value)296 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
297 Literal literal(ShapeUtil::MakeShape(
298 primitive_util::NativeToPrimitiveType<NativeT>(), {}));
299 literal.Set({}, value);
300 return literal;
301 }
302
303 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)304 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
305 Literal literal(
306 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
307 {static_cast<int64>(values.size())}));
308 literal.PopulateR1(values);
309 return literal;
310 }
311
312 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)313 /* static */ Literal LiteralUtil::CreateR2WithLayout(
314 std::initializer_list<std::initializer_list<NativeT>> values,
315 const Layout& layout) {
316 Literal literal(ShapeUtil::MakeShapeWithLayout(
317 primitive_util::NativeToPrimitiveType<NativeT>(),
318 {static_cast<int64>(values.size()),
319 static_cast<int64>(values.begin()->size())},
320 AsInt64Slice(layout.minor_to_major())));
321 literal.PopulateR2(values);
322 return literal;
323 }
324
325 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)326 /* static */ Literal LiteralUtil::CreateR2(
327 std::initializer_list<std::initializer_list<NativeT>> values) {
328 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
329 }
330
331 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)332 /* static */ Literal LiteralUtil::CreateR3WithLayout(
333 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
334 values,
335 const Layout& layout) {
336 const int64 d0 = values.size();
337 const int64 d1 = values.begin()->size();
338 const int64 d2 = values.begin()->begin()->size();
339 Array3D<NativeT> tmp(d0, d1, d2);
340 int64 i0 = 0;
341 for (auto d1_values : values) {
342 int64 i1 = 0;
343 for (auto d2_values : d1_values) {
344 int64 i2 = 0;
345 for (auto value : d2_values) {
346 tmp(i0, i1, i2) = value;
347 ++i2;
348 }
349 ++i1;
350 }
351 ++i0;
352 }
353 return CreateR3FromArray3DWithLayout(tmp, layout);
354 }
355
356 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)357 /* static */ Literal LiteralUtil::CreateR3(
358 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
359 values) {
360 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
361 }
362
363 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)364 /* static */ Literal LiteralUtil::CreateR4WithLayout(
365 std::initializer_list<std::initializer_list<
366 std::initializer_list<std::initializer_list<NativeT>>>>
367 values,
368 const Layout& layout) {
369 const int64 d0 = values.size();
370 const int64 d1 = values.begin()->size();
371 const int64 d2 = values.begin()->begin()->size();
372 const int64 d3 = values.begin()->begin()->begin()->size();
373 Array4D<NativeT> tmp(d0, d1, d2, d3);
374 int64 i0 = 0;
375 for (auto d1_values : values) {
376 int64 i1 = 0;
377 for (auto d2_values : d1_values) {
378 int64 i2 = 0;
379 for (auto d3_values : d2_values) {
380 int64 i3 = 0;
381 for (auto value : d3_values) {
382 tmp(i0, i1, i2, i3) = value;
383 ++i3;
384 }
385 ++i2;
386 }
387 ++i1;
388 }
389 ++i0;
390 }
391 return CreateR4FromArray4DWithLayout(tmp, layout);
392 }
393
394 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)395 /* static */ Literal LiteralUtil::CreateR4(
396 std::initializer_list<std::initializer_list<
397 std::initializer_list<std::initializer_list<NativeT>>>>
398 values) {
399 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
400 }
401
402 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)403 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
404 const Array<NativeT>& values, const Layout& layout) {
405 Literal literal(ShapeUtil::MakeShapeWithLayout(
406 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
407 AsInt64Slice(layout.minor_to_major())));
408 literal.PopulateFromArray(values);
409 return literal;
410 }
411
412 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)413 /* static */ Literal LiteralUtil::CreateFromArray(
414 const Array<NativeT>& values) {
415 return CreateFromArrayWithLayout(
416 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
417 }
418
419 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)420 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
421 const Array2D<NativeT>& values, const Layout& layout) {
422 return CreateFromArrayWithLayout(values, layout);
423 }
424
425 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)426 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
427 const Array2D<NativeT>& values) {
428 return CreateFromArray(values);
429 }
430
431 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)432 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
433 const Array3D<NativeT>& values, const Layout& layout) {
434 return CreateFromArrayWithLayout(values, layout);
435 }
436
437 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)438 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
439 const Array3D<NativeT>& values) {
440 return CreateFromArray(values);
441 }
442
443 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection)444 /* static */ Literal LiteralUtil::CreateR3Projected(
445 std::initializer_list<std::initializer_list<NativeT>> values,
446 int64 projection) {
447 int64 dim0_size = projection;
448 int64 dim1_size = values.size();
449 int64 dim2_size = values.begin()->size();
450
451 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
452 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
453 int64 dim1 = 0;
454 for (auto inner_list : values) {
455 int64 dim2 = 0;
456 for (auto value : inner_list) {
457 array(dim0, dim1, dim2) = value;
458 ++dim2;
459 }
460 CHECK_EQ(dim2_size, dim2);
461 ++dim1;
462 }
463 CHECK_EQ(dim1_size, dim1);
464 }
465 return CreateR3FromArray3D(array);
466 }
467
468 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection_p,int64 projection_z)469 /* static */ Literal LiteralUtil::CreateR4Projected(
470 std::initializer_list<std::initializer_list<NativeT>> values,
471 int64 projection_p, int64 projection_z) {
472 int64 dim0_size = projection_p;
473 int64 dim1_size = projection_z;
474 int64 dim2_size = values.size();
475 int64 dim3_size = values.begin()->size();
476
477 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
478 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
479 for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
480 int64 dim2 = 0;
481 for (auto inner_list : values) {
482 int64 dim3 = 0;
483 for (auto value : inner_list) {
484 array(dim0, dim1, dim2, dim3) = value;
485 ++dim3;
486 }
487 CHECK_EQ(dim3_size, dim3);
488 ++dim2;
489 }
490 CHECK_EQ(dim2_size, dim2);
491 }
492 }
493 return CreateR4FromArray4D(array);
494 }
495
496 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)497 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
498 const Array4D<NativeT>& values) {
499 return CreateFromArray(values);
500 }
501
502 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)503 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
504 const Array4D<NativeT>& values, const Layout& layout) {
505 return CreateFromArrayWithLayout(values, layout);
506 }
507
508 // Returns an identity matrix (rank 2) with the given row and column count.
509 template <typename NativeT>
MakeIdentityR2(int64 size)510 /* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
511 Array2D<NativeT> array(size, size, 0);
512 for (int64 i = 0; i < size; ++i) {
513 array(i, i) = 1;
514 }
515 return CreateR2FromArray2D(array);
516 }
517
518 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,NativeT value)519 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
520 absl::Span<const int64> dimensions, NativeT value) {
521 Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
522 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
523 literal.PopulateWithValue(value);
524 return literal;
525 }
526
527 template <PrimitiveType type, typename T>
CreateLiteralWithGenerator(const Shape & shape,const std::function<T (absl::Span<const int64>)> & generator)528 /* static */ StatusOr<Literal> LiteralUtil::CreateLiteralWithGenerator(
529 const Shape& shape,
530 const std::function<T(absl::Span<const int64>)>& generator) {
531 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
532 TF_RET_CHECK(shape.element_type() == type);
533 Literal literal(shape);
534 TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
535 [&](absl::Span<const int64> indexes) { return generator(indexes); }));
536 return std::move(literal);
537 }
538
539 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)540 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
541 const Shape& shape, E* engine, T mean, T stddev) {
542 using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
543 std::normal_distribution<NativeT> generator(mean, stddev);
544 return CreateLiteralWithGenerator<type, NativeT>(
545 shape,
546 [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
547 }
548
549 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)550 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
551 const Shape& shape, T mean, T stddev) {
552 std::minstd_rand0 engine;
553 return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
554 }
555
556 } // namespace xla
557
558 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
559