1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/tests/test_utils.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
27
28 namespace xla {
29
30 namespace {
31
32 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)33 void PopulateWithRandomFloatingPointData(Literal* literal,
34 std::minstd_rand0* engine) {
35 std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
36 for (FloatT& value : literal->data<FloatT>()) {
37 value = static_cast<FloatT>(generator(*engine));
38 }
39 }
40
41 template <typename FloatT>
42 void PopulateWithIntNext(Literal* literal);
43
44 template <>
PopulateWithIntNext(Literal * literal)45 void PopulateWithIntNext<half>(Literal* literal) {
46 // Duplicates may be generated if we don't have enough bits.
47 uint16 next_value = 0;
48 for (half& value : literal->data<half>()) {
49 // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
50 // the sign bit. We could be less wasteful, but this is best-effort anyway.
51 uint16 exponent_msb = next_value & 0x4000;
52 value.x = (next_value & 0xBFFF) | (exponent_msb << 1);
53 next_value++;
54 }
55 }
56
57 template <>
PopulateWithIntNext(Literal * literal)58 void PopulateWithIntNext<bfloat16>(Literal* literal) {
59 // Duplicates may be generated if we don't have enough bits.
60 // Start at 0x80 rather than 0 to avoid denormals.
61 uint16 next_value = 0x80;
62 for (bfloat16& value : literal->data<bfloat16>()) {
63 // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
64 // the sign bit. We could be less wasteful, but this is best-effort anyway.
65 uint16 exponent_msb = next_value & 0x4000;
66 value.value = (next_value & 0xBFFF) | (exponent_msb << 1);
67 next_value++;
68 }
69 }
70
71 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)72 void PopulateWithNextAfter(Literal* literal) {
73 // Duplicates may be generated if the number of elements in the literal
74 // exceeds the number of positive values supported by the type.
75 float next_value = std::numeric_limits<float>::min();
76 for (float& value : literal->data<float>()) {
77 value = next_value;
78 next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
79 }
80 }
81
82 template <typename FloatT,
83 typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
84 std::is_same<half, FloatT>::value,
85 int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)86 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
87 PopulateWithIntNext<FloatT>(literal);
88 std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
89 *engine);
90 }
91
92 template <typename FloatT,
93 typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
94 !std::is_same<half, FloatT>::value,
95 int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)96 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
97 PopulateWithNextAfter<FloatT>(literal);
98 std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
99 *engine);
100 }
101
102 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)103 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
104 bool no_duplicates) {
105 CHECK(engine != nullptr);
106 CHECK_EQ(literal->shape().element_type(),
107 primitive_util::NativeToPrimitiveType<FloatT>());
108 if (no_duplicates) {
109 PopulateWithNoDuplicateData<FloatT>(literal, engine);
110 } else {
111 PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
112 }
113 }
114
115 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates)116 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
117 bool no_duplicates) {
118 using InnerFloatT = typename ComplexT::value_type;
119 CHECK(engine != nullptr);
120 CHECK_EQ(result->shape().element_type(),
121 primitive_util::NativeToPrimitiveType<ComplexT>());
122 Shape floating_point_shape = ShapeUtil::ChangeElementType(
123 result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
124 Literal real_lit(floating_point_shape);
125 Literal imaginary_lit(floating_point_shape);
126
127 PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates);
128 PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
129 no_duplicates);
130
131 absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
132 absl::Span<const InnerFloatT> imaginary_data =
133 imaginary_lit.data<InnerFloatT>();
134 absl::Span<ComplexT> result_data = result->data<ComplexT>();
135 for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
136 result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
137 }
138 }
139
140 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)141 void PopulateWithFloatingPointData<half>(Literal* literal,
142 std::minstd_rand0* engine,
143 bool no_duplicates) {
144 CHECK(engine != nullptr);
145 CHECK_EQ(literal->shape().element_type(),
146 primitive_util::NativeToPrimitiveType<half>());
147 if (no_duplicates) {
148 PopulateWithNoDuplicateData<half>(literal, engine);
149 } else {
150 PopulateWithRandomFloatingPointData<half, float>(literal, engine);
151 }
152 }
153
154 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)155 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
156 std::minstd_rand0* engine,
157 bool no_duplicates) {
158 CHECK(engine != nullptr);
159 CHECK_EQ(literal->shape().element_type(),
160 primitive_util::NativeToPrimitiveType<bfloat16>());
161 if (no_duplicates) {
162 PopulateWithNoDuplicateData<bfloat16>(literal, engine);
163 } else {
164 PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
165 }
166 }
167
168 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)169 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
170 bool no_duplicates) {
171 CHECK(engine != nullptr);
172 CHECK_EQ(literal->shape().element_type(),
173 primitive_util::NativeToPrimitiveType<IntT>());
174 if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
175 std::numeric_limits<IntT>::max()) {
176 std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
177 std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
178 *engine);
179 } else {
180 std::uniform_int_distribution<IntT> generator(
181 std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
182 for (IntT& value : literal->data<IntT>()) {
183 value = generator(*engine);
184 }
185 }
186 }
187
188 // Similar to MakeFakeLiteral but takes a random number generator engine to
189 // enable reusing the engine across randomly generated literals. 'no_duplicates'
190 // indicates that there should be no duplicate values in each generated
191 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
192 // are not supported and uniqueness cannot be guaranteed if the number of
193 // elements exceeds the number of different values supported by the type.
MakeFakeLiteralInternal(const Shape & shape,std::minstd_rand0 * engine,bool no_duplicates)194 StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
195 std::minstd_rand0* engine,
196 bool no_duplicates) {
197 if (shape.IsTuple()) {
198 std::vector<Literal> elements;
199 for (const Shape& element_shape : shape.tuple_shapes()) {
200 TF_ASSIGN_OR_RETURN(
201 Literal element,
202 MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
203 elements.push_back(std::move(element));
204 }
205 return LiteralUtil::MakeTupleOwned(std::move(elements));
206 }
207 if (engine == nullptr) {
208 return Literal::CreateFromShape(shape);
209 }
210 Literal literal(shape);
211 switch (shape.element_type()) {
212 case BF16:
213 PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates);
214 break;
215 case F16:
216 PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates);
217 break;
218 case F32:
219 PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates);
220 break;
221 case F64:
222 PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates);
223 break;
224 case S8:
225 PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
226 break;
227 case U8:
228 PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
229 break;
230 case S16:
231 PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
232 break;
233 case U16:
234 PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
235 break;
236 case S32:
237 PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
238 break;
239 case U32:
240 PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
241 break;
242 case S64:
243 PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
244 break;
245 case U64:
246 PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
247 break;
248 case C64:
249 PopulateWithComplexData<complex64>(&literal, engine, no_duplicates);
250 break;
251 case C128:
252 PopulateWithComplexData<complex128>(&literal, engine, no_duplicates);
253 break;
254 case PRED: {
255 std::uniform_int_distribution<int> generator(0, 1);
256 TF_CHECK_OK(
257 literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
258 return generator(*engine);
259 }));
260 break;
261 }
262 // Token requires no data.
263 case TOKEN:
264 break;
265 default:
266 return Unimplemented("Unsupported type for fake literal generation: %s",
267 ShapeUtil::HumanString(shape));
268 }
269 return std::move(literal);
270 }
271
272 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)273 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
274 std::minstd_rand0* engine,
275 IntT min, IntT max) {
276 CHECK(engine != nullptr);
277 CHECK_EQ(literal->shape().element_type(),
278 primitive_util::NativeToPrimitiveType<IntT>());
279 std::uniform_int_distribution<IntT> generator(min, max);
280 for (IntT& value : literal->data<IntT>()) {
281 value = generator(*engine);
282 }
283 }
284
285 // Same as MakeFakeLiteralInternal but generates random numbers in the given
286 // range [min, max]. Currently this works only for INT types.
MakeFakeLiteralInternalWithBounds(const Shape & shape,std::minstd_rand0 * engine,int64 min,int64 max)287 StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
288 std::minstd_rand0* engine,
289 int64 min, int64 max) {
290 if (shape.IsTuple()) {
291 std::vector<Literal> elements;
292 for (const Shape& element_shape : shape.tuple_shapes()) {
293 TF_ASSIGN_OR_RETURN(
294 Literal element,
295 MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max));
296 elements.push_back(std::move(element));
297 }
298 return LiteralUtil::MakeTupleOwned(std::move(elements));
299 }
300 if (engine == nullptr) {
301 return Literal::CreateFromShape(shape);
302 }
303 Literal literal(shape);
304 switch (shape.element_type()) {
305 case S8:
306 PopulateWithRandomIntegralDataWithBounds<int8>(
307 &literal, engine, static_cast<int8>(min), static_cast<int8>(max));
308 break;
309 case U8:
310 PopulateWithRandomIntegralDataWithBounds<uint8>(
311 &literal, engine, static_cast<uint8>(min), static_cast<uint8>(max));
312 break;
313 case S16:
314 PopulateWithRandomIntegralDataWithBounds<int16>(
315 &literal, engine, static_cast<int16>(min), static_cast<int16>(max));
316 break;
317 case U16:
318 PopulateWithRandomIntegralDataWithBounds<uint16>(
319 &literal, engine, static_cast<uint16>(min), static_cast<uint16>(max));
320 break;
321 case S32:
322 PopulateWithRandomIntegralDataWithBounds<int32>(
323 &literal, engine, static_cast<int32>(min), static_cast<int32>(max));
324 break;
325 case U32:
326 PopulateWithRandomIntegralDataWithBounds<uint32>(
327 &literal, engine, static_cast<uint32>(min), static_cast<uint32>(max));
328 break;
329 case S64:
330 PopulateWithRandomIntegralDataWithBounds<int64>(
331 &literal, engine, static_cast<int64>(min), static_cast<int64>(max));
332 break;
333 case U64:
334 PopulateWithRandomIntegralDataWithBounds<uint64>(
335 &literal, engine, static_cast<uint64>(min), static_cast<uint64>(max));
336 break;
337 default:
338 return Unimplemented(
339 "Unsupported type for fake random literal generation with bounds: %s",
340 ShapeUtil::HumanString(shape));
341 }
342 return std::move(literal);
343 }
344
// Kind of constant a parameter has to be initialized with; see GetInitValue.
enum class ConstantType {
  kUnknown,  // No specific constant is known to be required.
  kZero,     // The value zero.
  kOne       // The value one.
};
346
347 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)348 ConstantType GetInitValue(const HloComputation& computation) {
349 // TODO(b/77635120): Add init values, for min, max, and their arg variants.
350 const HloInstruction* const root = computation.root_instruction();
351 if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
352 root->operand(0)->opcode() != HloOpcode::kParameter ||
353 root->operand(1)->opcode() != HloOpcode::kParameter ||
354 root->operand(0) == root->operand(1)) {
355 return ConstantType::kUnknown;
356 }
357
358 switch (root->opcode()) {
359 case HloOpcode::kAdd:
360 return ConstantType::kZero;
361 case HloOpcode::kMultiply:
362 return ConstantType::kOne;
363 default:
364 return ConstantType::kUnknown;
365 }
366 }
367
368 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
369 // initialization value.
NeedsInitValue(const HloUse & use)370 bool NeedsInitValue(const HloUse& use) {
371 const HloInstruction* const instruction = use.instruction;
372 const HloOpcode opcode = instruction->opcode();
373 const int64 op_num = use.operand_number;
374 return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
375 (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
376 (opcode == HloOpcode::kReduce &&
377 op_num >= instruction->operand_count() / 2));
378 }
379
380 // Generate random values that are constrained to the input_shape minus the
381 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64 index_bound,std::minstd_rand0 * engine)382 Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) {
383 std::uniform_int_distribution<int32> generator(0, index_bound);
384 return LiteralUtil::CreateR0<int32>(generator(*engine));
385 }
386
387 // Use dataflow analysis on each parameter to see if there are uses that would
388 // be problematic when generating input data. Returns the list of instructions
389 // that correspond to their uses.
390 //
391 // Should be paired with the CreateLiteralForConstrainedUses() function below.
FindConstrainedUses(const HloDataflowAnalysis & dataflow,const HloInstruction & param)392 std::vector<HloInstruction*> FindConstrainedUses(
393 const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
394 std::vector<HloInstruction*> constrained_uses;
395 for (const auto& pair : dataflow.GetInstructionValueSet(¶m)) {
396 const HloValue& value = dataflow.GetUniqueValueAt(¶m, pair.first);
397 for (const HloUse& use : value.uses()) {
398 HloInstruction* instruction = use.instruction;
399 const HloOpcode opcode = instruction->opcode();
400 const int64 op_num = use.operand_number;
401 if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
402 (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
403 constrained_uses.push_back(instruction);
404 } else if ((opcode == HloOpcode::kGather ||
405 opcode == HloOpcode::kScatter) &&
406 op_num == 1) {
407 constrained_uses.push_back(instruction);
408 } else if (opcode == HloOpcode::kFusion) {
409 const HloInstruction* const to_analyze =
410 instruction->fused_parameter(op_num);
411 auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
412 constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
413 fused_uses.end());
414 } else if (NeedsInitValue(use)) {
415 constrained_uses.push_back(instruction);
416 } else if (opcode == HloOpcode::kConvert ||
417 opcode == HloOpcode::kReducePrecision) {
418 auto converted_uses = FindConstrainedUses(dataflow, *instruction);
419 constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
420 converted_uses.end());
421 } else if (opcode == HloOpcode::kSort &&
422 instruction->operand_count() >= 2 && op_num == 0) {
423 // Operand 0 of sort is the array of keys used for key/value
424 // (two-operand) kSort instructions. Since sort stability is not
425 // guaranteed, constrain keys of key-value sort not to have duplicates,
426 // since otherwise the value order may legitimately differ.
427 constrained_uses.push_back(instruction);
428 }
429 }
430 }
431 return constrained_uses;
432 }
433
434 // Given a parameter, generate a random Literal to use as input if there exist
435 // no constrained uses in the dataflow graph. If such constraints exist,
436 // generate a constrained literal (either bounded in the case of indices, or
437 // zero in the case of init_values for reductions).
CreateLiteralForConstrainedUses(const absl::Span<HloInstruction * const> constrained_uses,const HloInstruction & param,std::minstd_rand0 * engine)438 StatusOr<Literal> CreateLiteralForConstrainedUses(
439 const absl::Span<HloInstruction* const> constrained_uses,
440 const HloInstruction& param, std::minstd_rand0* engine) {
441 int64 index_bound = INT64_MAX;
442 bool no_duplicates = false;
443 bool needs_constant = false;
444 ConstantType constant_type = ConstantType::kUnknown;
445 for (HloInstruction* use : constrained_uses) {
446 switch (use->opcode()) {
447 case HloOpcode::kDynamicSlice:
448 case HloOpcode::kDynamicUpdateSlice: {
449 const Shape& indexed_shape = use->operand(0)->shape();
450 const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
451 ? use->shape()
452 : use->operand(1)->shape();
453 const int64 first_index =
454 Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
455 for (int64 operand = first_index; operand < use->operand_count();
456 ++operand) {
457 if (use->operand(operand) == ¶m) {
458 index_bound = std::min(
459 index_bound,
460 ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
461 ShapeUtil::GetDimension(slice_shape,
462 operand - first_index));
463 }
464 }
465 break;
466 }
467 case HloOpcode::kGather:
468 case HloOpcode::kScatter: {
469 const Shape& operand_shape = use->operand(0)->shape();
470 if (use->operand(1) == ¶m) {
471 auto index_map =
472 use->opcode() == HloOpcode::kGather
473 ? use->gather_dimension_numbers().start_index_map()
474 : use->scatter_dimension_numbers()
475 .scatter_dims_to_operand_dims();
476 for (const auto dim_in_operand : index_map) {
477 index_bound =
478 std::min(index_bound, operand_shape.dimensions(dim_in_operand));
479 }
480 }
481 break;
482 }
483 case HloOpcode::kReduce:
484 case HloOpcode::kReduceWindow:
485 needs_constant = true;
486 constant_type = GetInitValue(*use->to_apply());
487 break;
488
489 case HloOpcode::kSelectAndScatter:
490 needs_constant = true;
491 constant_type = GetInitValue(*use->scatter());
492 break;
493
494 case HloOpcode::kSort:
495 no_duplicates = true;
496 break;
497
498 default:
499 return Unimplemented(
500 "Constrained operand generation not implemented for %s.",
501 use->ToString());
502 }
503 }
504 int constraint_count = 0;
505 constraint_count += no_duplicates ? 1 : 0;
506 constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
507 constraint_count += needs_constant ? 1 : 0;
508 if (constraint_count > 1) {
509 return Unimplemented("Conflicting operand generation constraints.");
510 }
511 if (index_bound != INT64_MAX) {
512 return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1,
513 index_bound);
514 } else if (needs_constant) {
515 switch (constant_type) {
516 case ConstantType::kZero:
517 return LiteralUtil::Zero(param.shape().element_type());
518 case ConstantType::kOne:
519 return LiteralUtil::One(param.shape().element_type());
520 case ConstantType::kUnknown:
521 // We want the identity element for the computation, but we don't really
522 // know what it is - so any value we generate will be just as wrong.
523 return MakeFakeLiteralInternal(param.shape(), engine,
524 /*no_duplicates=*/false);
525 }
526 } else {
527 return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates);
528 }
529 }
530
531 // Given a module entry parameter, use the dataflow analysis to see if a
532 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,std::minstd_rand0 * engine)533 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
534 const HloInstruction& param,
535 std::minstd_rand0* engine) {
536 const auto constrained_uses = FindConstrainedUses(dataflow, param);
537 return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
538 }
539
540 } // namespace
541
MakeFakeLiteral(const Shape & shape,bool pseudo_random)542 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
543 auto engine =
544 pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
545 return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
546 }
547
MakeFakeArguments(HloModule * const module,bool pseudo_random)548 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
549 bool pseudo_random) {
550 auto engine =
551 pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
552 return MakeFakeArguments(module, engine.get());
553 }
554
MakeFakeArguments(HloModule * const module,std::minstd_rand0 * engine)555 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
556 std::minstd_rand0* engine) {
557 TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
558 const auto params = module->entry_computation()->parameter_instructions();
559 std::vector<Literal> arguments(params.size());
560 for (int i = 0; i < params.size(); ++i) {
561 arguments[i] =
562 MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
563 }
564 return std::move(arguments);
565 }
566
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)567 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
568 bool allow_mixed_precision) {
569 return HloVerifier(/*layout_sensitive=*/layout_sensitive,
570 /*allow_mixed_precision=*/allow_mixed_precision)
571 .Run(module)
572 .status();
573 }
574
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)575 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
576 HloInstruction* lhs,
577 HloInstruction* rhs) {
578 CHECK_EQ(lhs->shape().rank(), 2);
579 CHECK_EQ(rhs->shape().rank(), 2);
580 PrecisionConfig precision_config;
581 precision_config.mutable_operand_precision()->Resize(
582 2, PrecisionConfig::DEFAULT);
583 DotDimensionNumbers dot_dimension_numbers;
584 dot_dimension_numbers.add_lhs_contracting_dimensions(1);
585 dot_dimension_numbers.add_rhs_contracting_dimensions(0);
586 return absl::make_unique<HloDotInstruction>(
587 shape, lhs, rhs, dot_dimension_numbers, precision_config);
588 }
589 } // namespace xla
590