• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 
18 #include "absl/base/casts.h"
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/primitive_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
24 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
25 #include "tensorflow/compiler/xla/service/transfer_manager.h"
26 #include "tensorflow/compiler/xla/tests/test_utils.h"
27 
28 namespace xla {
29 
30 namespace {
31 
32 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)33 void PopulateWithRandomFloatingPointData(Literal* literal,
34                                          std::minstd_rand0* engine) {
35   std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
36   for (FloatT& value : literal->data<FloatT>()) {
37     value = static_cast<FloatT>(generator(*engine));
38   }
39 }
40 
41 template <typename FloatT>
42 void PopulateWithIntNext(Literal* literal);
43 
44 template <>
PopulateWithIntNext(Literal * literal)45 void PopulateWithIntNext<half>(Literal* literal) {
46   // Duplicates may be generated if we don't have enough bits.
47   uint16 next_value = 0;
48   for (half& value : literal->data<half>()) {
49     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
50     // the sign bit. We could be less wasteful, but this is best-effort anyway.
51     uint16 exponent_msb = next_value & 0x4000;
52     value.x = (next_value & 0xBFFF) | (exponent_msb << 1);
53     next_value++;
54   }
55 }
56 
57 template <>
PopulateWithIntNext(Literal * literal)58 void PopulateWithIntNext<bfloat16>(Literal* literal) {
59   // Duplicates may be generated if we don't have enough bits.
60   // Start at 0x80 rather than 0 to avoid denormals.
61   uint16 next_value = 0x80;
62   for (bfloat16& value : literal->data<bfloat16>()) {
63     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
64     // the sign bit. We could be less wasteful, but this is best-effort anyway.
65     uint16 exponent_msb = next_value & 0x4000;
66     value.value = (next_value & 0xBFFF) | (exponent_msb << 1);
67     next_value++;
68   }
69 }
70 
71 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)72 void PopulateWithNextAfter(Literal* literal) {
73   // Duplicates may be generated if the number of elements in the literal
74   // exceeds the number of positive values supported by the type.
75   float next_value = std::numeric_limits<float>::min();
76   for (float& value : literal->data<float>()) {
77     value = next_value;
78     next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
79   }
80 }
81 
82 template <typename FloatT,
83           typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
84                                       std::is_same<half, FloatT>::value,
85                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)86 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
87   PopulateWithIntNext<FloatT>(literal);
88   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
89                *engine);
90 }
91 
92 template <typename FloatT,
93           typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
94                                       !std::is_same<half, FloatT>::value,
95                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)96 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
97   PopulateWithNextAfter<FloatT>(literal);
98   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
99                *engine);
100 }
101 
102 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)103 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
104                                    bool no_duplicates) {
105   CHECK(engine != nullptr);
106   CHECK_EQ(literal->shape().element_type(),
107            primitive_util::NativeToPrimitiveType<FloatT>());
108   if (no_duplicates) {
109     PopulateWithNoDuplicateData<FloatT>(literal, engine);
110   } else {
111     PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
112   }
113 }
114 
115 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates)116 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
117                              bool no_duplicates) {
118   using InnerFloatT = typename ComplexT::value_type;
119   CHECK(engine != nullptr);
120   CHECK_EQ(result->shape().element_type(),
121            primitive_util::NativeToPrimitiveType<ComplexT>());
122   Shape floating_point_shape = ShapeUtil::ChangeElementType(
123       result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
124   Literal real_lit(floating_point_shape);
125   Literal imaginary_lit(floating_point_shape);
126 
127   PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates);
128   PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
129                                              no_duplicates);
130 
131   absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
132   absl::Span<const InnerFloatT> imaginary_data =
133       imaginary_lit.data<InnerFloatT>();
134   absl::Span<ComplexT> result_data = result->data<ComplexT>();
135   for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
136     result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
137   }
138 }
139 
140 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)141 void PopulateWithFloatingPointData<half>(Literal* literal,
142                                          std::minstd_rand0* engine,
143                                          bool no_duplicates) {
144   CHECK(engine != nullptr);
145   CHECK_EQ(literal->shape().element_type(),
146            primitive_util::NativeToPrimitiveType<half>());
147   if (no_duplicates) {
148     PopulateWithNoDuplicateData<half>(literal, engine);
149   } else {
150     PopulateWithRandomFloatingPointData<half, float>(literal, engine);
151   }
152 }
153 
154 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)155 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
156                                              std::minstd_rand0* engine,
157                                              bool no_duplicates) {
158   CHECK(engine != nullptr);
159   CHECK_EQ(literal->shape().element_type(),
160            primitive_util::NativeToPrimitiveType<bfloat16>());
161   if (no_duplicates) {
162     PopulateWithNoDuplicateData<bfloat16>(literal, engine);
163   } else {
164     PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
165   }
166 }
167 
168 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)169 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
170                                     bool no_duplicates) {
171   CHECK(engine != nullptr);
172   CHECK_EQ(literal->shape().element_type(),
173            primitive_util::NativeToPrimitiveType<IntT>());
174   if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
175                            std::numeric_limits<IntT>::max()) {
176     std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
177     std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
178                  *engine);
179   } else {
180     std::uniform_int_distribution<IntT> generator(
181         std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
182     for (IntT& value : literal->data<IntT>()) {
183       value = generator(*engine);
184     }
185   }
186 }
187 
188 // Similar to MakeFakeLiteral but takes a random number generator engine to
189 // enable reusing the engine across randomly generated literals. 'no_duplicates'
190 // indicates that there should be no duplicate values in each generated
191 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
192 // are not supported and uniqueness cannot be guaranteed if the number of
193 // elements exceeds the number of different values supported by the type.
MakeFakeLiteralInternal(const Shape & shape,std::minstd_rand0 * engine,bool no_duplicates)194 StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
195                                           std::minstd_rand0* engine,
196                                           bool no_duplicates) {
197   if (shape.IsTuple()) {
198     std::vector<Literal> elements;
199     for (const Shape& element_shape : shape.tuple_shapes()) {
200       TF_ASSIGN_OR_RETURN(
201           Literal element,
202           MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
203       elements.push_back(std::move(element));
204     }
205     return LiteralUtil::MakeTupleOwned(std::move(elements));
206   }
207   if (engine == nullptr) {
208     return Literal::CreateFromShape(shape);
209   }
210   Literal literal(shape);
211   switch (shape.element_type()) {
212     case BF16:
213       PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates);
214       break;
215     case F16:
216       PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates);
217       break;
218     case F32:
219       PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates);
220       break;
221     case F64:
222       PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates);
223       break;
224     case S8:
225       PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
226       break;
227     case U8:
228       PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
229       break;
230     case S16:
231       PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
232       break;
233     case U16:
234       PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
235       break;
236     case S32:
237       PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
238       break;
239     case U32:
240       PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
241       break;
242     case S64:
243       PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
244       break;
245     case U64:
246       PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
247       break;
248     case C64:
249       PopulateWithComplexData<complex64>(&literal, engine, no_duplicates);
250       break;
251     case C128:
252       PopulateWithComplexData<complex128>(&literal, engine, no_duplicates);
253       break;
254     case PRED: {
255       std::uniform_int_distribution<int> generator(0, 1);
256       TF_CHECK_OK(
257           literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
258             return generator(*engine);
259           }));
260       break;
261     }
262     // Token requires no data.
263     case TOKEN:
264       break;
265     default:
266       return Unimplemented("Unsupported type for fake literal generation: %s",
267                            ShapeUtil::HumanString(shape));
268   }
269   return std::move(literal);
270 }
271 
272 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)273 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
274                                               std::minstd_rand0* engine,
275                                               IntT min, IntT max) {
276   CHECK(engine != nullptr);
277   CHECK_EQ(literal->shape().element_type(),
278            primitive_util::NativeToPrimitiveType<IntT>());
279   std::uniform_int_distribution<IntT> generator(min, max);
280   for (IntT& value : literal->data<IntT>()) {
281     value = generator(*engine);
282   }
283 }
284 
285 // Same as MakeFakeLiteralInternal but generates random numbers in the given
286 // range [min, max]. Currently this works only for INT types.
MakeFakeLiteralInternalWithBounds(const Shape & shape,std::minstd_rand0 * engine,int64 min,int64 max)287 StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
288                                                     std::minstd_rand0* engine,
289                                                     int64 min, int64 max) {
290   if (shape.IsTuple()) {
291     std::vector<Literal> elements;
292     for (const Shape& element_shape : shape.tuple_shapes()) {
293       TF_ASSIGN_OR_RETURN(
294           Literal element,
295           MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max));
296       elements.push_back(std::move(element));
297     }
298     return LiteralUtil::MakeTupleOwned(std::move(elements));
299   }
300   if (engine == nullptr) {
301     return Literal::CreateFromShape(shape);
302   }
303   Literal literal(shape);
304   switch (shape.element_type()) {
305     case S8:
306       PopulateWithRandomIntegralDataWithBounds<int8>(
307           &literal, engine, static_cast<int8>(min), static_cast<int8>(max));
308       break;
309     case U8:
310       PopulateWithRandomIntegralDataWithBounds<uint8>(
311           &literal, engine, static_cast<uint8>(min), static_cast<uint8>(max));
312       break;
313     case S16:
314       PopulateWithRandomIntegralDataWithBounds<int16>(
315           &literal, engine, static_cast<int16>(min), static_cast<int16>(max));
316       break;
317     case U16:
318       PopulateWithRandomIntegralDataWithBounds<uint16>(
319           &literal, engine, static_cast<uint16>(min), static_cast<uint16>(max));
320       break;
321     case S32:
322       PopulateWithRandomIntegralDataWithBounds<int32>(
323           &literal, engine, static_cast<int32>(min), static_cast<int32>(max));
324       break;
325     case U32:
326       PopulateWithRandomIntegralDataWithBounds<uint32>(
327           &literal, engine, static_cast<uint32>(min), static_cast<uint32>(max));
328       break;
329     case S64:
330       PopulateWithRandomIntegralDataWithBounds<int64>(
331           &literal, engine, static_cast<int64>(min), static_cast<int64>(max));
332       break;
333     case U64:
334       PopulateWithRandomIntegralDataWithBounds<uint64>(
335           &literal, engine, static_cast<uint64>(min), static_cast<uint64>(max));
336       break;
337     default:
338       return Unimplemented(
339           "Unsupported type for fake random literal generation with bounds: %s",
340           ShapeUtil::HumanString(shape));
341   }
342   return std::move(literal);
343 }
344 
// Identity-element kinds that GetInitValue can recognize for a reducer
// computation.
enum class ConstantType {
  kUnknown,  // No recognized identity element.
  kZero,     // Identity for add.
  kOne,      // Identity for multiply.
};
346 
347 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)348 ConstantType GetInitValue(const HloComputation& computation) {
349   // TODO(b/77635120): Add init values, for min, max, and their arg variants.
350   const HloInstruction* const root = computation.root_instruction();
351   if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
352       root->operand(0)->opcode() != HloOpcode::kParameter ||
353       root->operand(1)->opcode() != HloOpcode::kParameter ||
354       root->operand(0) == root->operand(1)) {
355     return ConstantType::kUnknown;
356   }
357 
358   switch (root->opcode()) {
359     case HloOpcode::kAdd:
360       return ConstantType::kZero;
361     case HloOpcode::kMultiply:
362       return ConstantType::kOne;
363     default:
364       return ConstantType::kUnknown;
365   }
366 }
367 
368 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
369 // initialization value.
NeedsInitValue(const HloUse & use)370 bool NeedsInitValue(const HloUse& use) {
371   const HloInstruction* const instruction = use.instruction;
372   const HloOpcode opcode = instruction->opcode();
373   const int64 op_num = use.operand_number;
374   return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
375           (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
376           (opcode == HloOpcode::kReduce &&
377            op_num >= instruction->operand_count() / 2));
378 }
379 
380 // Generate random values that are constrained to the input_shape minus the
381 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64 index_bound,std::minstd_rand0 * engine)382 Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) {
383   std::uniform_int_distribution<int32> generator(0, index_bound);
384   return LiteralUtil::CreateR0<int32>(generator(*engine));
385 }
386 
387 // Use dataflow analysis on each parameter to see if there are uses that would
388 // be problematic when generating input data.  Returns the list of instructions
389 // that correspond to their uses.
390 //
391 // Should be paired with the CreateLiteralForConstrainedUses() function below.
FindConstrainedUses(const HloDataflowAnalysis & dataflow,const HloInstruction & param)392 std::vector<HloInstruction*> FindConstrainedUses(
393     const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
394   std::vector<HloInstruction*> constrained_uses;
395   for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
396     const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
397     for (const HloUse& use : value.uses()) {
398       HloInstruction* instruction = use.instruction;
399       const HloOpcode opcode = instruction->opcode();
400       const int64 op_num = use.operand_number;
401       if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
402           (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
403         constrained_uses.push_back(instruction);
404       } else if ((opcode == HloOpcode::kGather ||
405                   opcode == HloOpcode::kScatter) &&
406                  op_num == 1) {
407         constrained_uses.push_back(instruction);
408       } else if (opcode == HloOpcode::kFusion) {
409         const HloInstruction* const to_analyze =
410             instruction->fused_parameter(op_num);
411         auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
412         constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
413                                 fused_uses.end());
414       } else if (NeedsInitValue(use)) {
415         constrained_uses.push_back(instruction);
416       } else if (opcode == HloOpcode::kConvert ||
417                  opcode == HloOpcode::kReducePrecision) {
418         auto converted_uses = FindConstrainedUses(dataflow, *instruction);
419         constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
420                                 converted_uses.end());
421       } else if (opcode == HloOpcode::kSort &&
422                  instruction->operand_count() >= 2 && op_num == 0) {
423         // Operand 0 of sort is the array of keys used for key/value
424         // (two-operand) kSort instructions. Since sort stability is not
425         // guaranteed, constrain keys of key-value sort not to have duplicates,
426         // since otherwise the value order may legitimately differ.
427         constrained_uses.push_back(instruction);
428       }
429     }
430   }
431   return constrained_uses;
432 }
433 
434 // Given a parameter, generate a random Literal to use as input if there exist
435 // no constrained uses in the dataflow graph.  If such constraints exist,
436 // generate a constrained literal (either bounded in the case of indices, or
437 // zero in the case of init_values for reductions).
CreateLiteralForConstrainedUses(const absl::Span<HloInstruction * const> constrained_uses,const HloInstruction & param,std::minstd_rand0 * engine)438 StatusOr<Literal> CreateLiteralForConstrainedUses(
439     const absl::Span<HloInstruction* const> constrained_uses,
440     const HloInstruction& param, std::minstd_rand0* engine) {
441   int64 index_bound = INT64_MAX;
442   bool no_duplicates = false;
443   bool needs_constant = false;
444   ConstantType constant_type = ConstantType::kUnknown;
445   for (HloInstruction* use : constrained_uses) {
446     switch (use->opcode()) {
447       case HloOpcode::kDynamicSlice:
448       case HloOpcode::kDynamicUpdateSlice: {
449         const Shape& indexed_shape = use->operand(0)->shape();
450         const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
451                                        ? use->shape()
452                                        : use->operand(1)->shape();
453         const int64 first_index =
454             Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
455         for (int64 operand = first_index; operand < use->operand_count();
456              ++operand) {
457           if (use->operand(operand) == &param) {
458             index_bound = std::min(
459                 index_bound,
460                 ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
461                     ShapeUtil::GetDimension(slice_shape,
462                                             operand - first_index));
463           }
464         }
465         break;
466       }
467       case HloOpcode::kGather:
468       case HloOpcode::kScatter: {
469         const Shape& operand_shape = use->operand(0)->shape();
470         if (use->operand(1) == &param) {
471           auto index_map =
472               use->opcode() == HloOpcode::kGather
473                   ? use->gather_dimension_numbers().start_index_map()
474                   : use->scatter_dimension_numbers()
475                         .scatter_dims_to_operand_dims();
476           for (const auto dim_in_operand : index_map) {
477             index_bound =
478                 std::min(index_bound, operand_shape.dimensions(dim_in_operand));
479           }
480         }
481         break;
482       }
483       case HloOpcode::kReduce:
484       case HloOpcode::kReduceWindow:
485         needs_constant = true;
486         constant_type = GetInitValue(*use->to_apply());
487         break;
488 
489       case HloOpcode::kSelectAndScatter:
490         needs_constant = true;
491         constant_type = GetInitValue(*use->scatter());
492         break;
493 
494       case HloOpcode::kSort:
495         no_duplicates = true;
496         break;
497 
498       default:
499         return Unimplemented(
500             "Constrained operand generation not implemented for %s.",
501             use->ToString());
502     }
503   }
504   int constraint_count = 0;
505   constraint_count += no_duplicates ? 1 : 0;
506   constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
507   constraint_count += needs_constant ? 1 : 0;
508   if (constraint_count > 1) {
509     return Unimplemented("Conflicting operand generation constraints.");
510   }
511   if (index_bound != INT64_MAX) {
512     return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1,
513                                              index_bound);
514   } else if (needs_constant) {
515     switch (constant_type) {
516       case ConstantType::kZero:
517         return LiteralUtil::Zero(param.shape().element_type());
518       case ConstantType::kOne:
519         return LiteralUtil::One(param.shape().element_type());
520       case ConstantType::kUnknown:
521         // We want the identity element for the computation, but we don't really
522         // know what it is - so any value we generate will be just as wrong.
523         return MakeFakeLiteralInternal(param.shape(), engine,
524                                        /*no_duplicates=*/false);
525     }
526   } else {
527     return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates);
528   }
529 }
530 
531 // Given a module entry parameter, use the dataflow analysis to see if a
532 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,std::minstd_rand0 * engine)533 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
534                                           const HloInstruction& param,
535                                           std::minstd_rand0* engine) {
536   const auto constrained_uses = FindConstrainedUses(dataflow, param);
537   return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
538 }
539 
540 }  // namespace
541 
MakeFakeLiteral(const Shape & shape,bool pseudo_random)542 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
543   auto engine =
544       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
545   return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
546 }
547 
MakeFakeArguments(HloModule * const module,bool pseudo_random)548 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
549                                                  bool pseudo_random) {
550   auto engine =
551       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
552   return MakeFakeArguments(module, engine.get());
553 }
554 
MakeFakeArguments(HloModule * const module,std::minstd_rand0 * engine)555 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
556                                                  std::minstd_rand0* engine) {
557   TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
558   const auto params = module->entry_computation()->parameter_instructions();
559   std::vector<Literal> arguments(params.size());
560   for (int i = 0; i < params.size(); ++i) {
561     arguments[i] =
562         MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
563   }
564   return std::move(arguments);
565 }
566 
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)567 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
568                        bool allow_mixed_precision) {
569   return HloVerifier(/*layout_sensitive=*/layout_sensitive,
570                      /*allow_mixed_precision=*/allow_mixed_precision)
571       .Run(module)
572       .status();
573 }
574 
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)575 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
576                                                       HloInstruction* lhs,
577                                                       HloInstruction* rhs) {
578   CHECK_EQ(lhs->shape().rank(), 2);
579   CHECK_EQ(rhs->shape().rank(), 2);
580   PrecisionConfig precision_config;
581   precision_config.mutable_operand_precision()->Resize(
582       2, PrecisionConfig::DEFAULT);
583   DotDimensionNumbers dot_dimension_numbers;
584   dot_dimension_numbers.add_lhs_contracting_dimensions(1);
585   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
586   return absl::make_unique<HloDotInstruction>(
587       shape, lhs, rhs, dot_dimension_numbers, precision_config);
588 }
589 }  // namespace xla
590