• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/test_utils.h"
17 
18 #include <cmath>
19 
20 #include "absl/base/casts.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
28 #include "tensorflow/compiler/xla/service/transfer_manager.h"
29 
30 namespace xla {
31 
32 namespace {
33 
34 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)35 void PopulateWithRandomFloatingPointData(Literal* literal,
36                                          std::minstd_rand0* engine) {
37   std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
38   for (FloatT& value : literal->data<FloatT>()) {
39     value = static_cast<FloatT>(generator(*engine));
40   }
41 }
42 
43 // Populates a floating point literal with random floating points sampled from a
44 // uniform-log distribution spanning approximately the entire range of the
45 // representable floating point.
46 template <typename FloatT>
PopulateWithRandomFullRangeFloatingPointData(Literal * literal,std::minstd_rand0 * engine)47 void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
48                                                   std::minstd_rand0* engine) {
49   constexpr float kSpecialValueProbability = 1e-6;
50   constexpr float kSpecialValues[] = {+0.F,
51                                       -0.F,
52                                       1.F,
53                                       -1.F,
54                                       std::numeric_limits<float>::infinity(),
55                                       -std::numeric_limits<float>::infinity()};
56   constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
57   std::uniform_real_distribution<float> special_value_gen(0, 1);
58 
59   // Generates floating points with a log-uniform distribution. This causes the
60   // exponent of the floating point to have a uniform distribution.
61   int min_exp, max_exp;
62   if (std::is_same<FloatT, bfloat16>()) {
63     min_exp = std::numeric_limits<float>::min_exponent;
64     max_exp = std::numeric_limits<float>::max_exponent;
65   } else {
66     min_exp = std::numeric_limits<FloatT>::min_exponent;
67     max_exp = std::numeric_limits<FloatT>::max_exponent;
68   }
69   std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);
70 
71   for (FloatT& value : literal->data<FloatT>()) {
72     // Each special value has a kSpecialValueProbability chance to be generated
73     // instead of sampling using the normal distributions.
74     if (special_value_gen(*engine) <
75         kSpecialValueProbability * kNumSpecialValues) {
76       value =
77           static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
78     } else {
79       float sign = ((*engine)() % 2 == 0) ? 1 : -1;
80       value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
81     }
82   }
83 }
84 
85 template <typename FloatT>
86 void PopulateWithIntNext(Literal* literal);
87 
88 template <>
PopulateWithIntNext(Literal * literal)89 void PopulateWithIntNext<half>(Literal* literal) {
90   // Duplicates may be generated if we don't have enough bits.
91   uint16 next_value = 0;
92   for (half& value : literal->data<half>()) {
93     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
94     // the sign bit. We could be less wasteful, but this is best-effort anyway.
95     uint16 exponent_msb = next_value & 0x4000;
96     value.x = (next_value & 0xBFFF) | (exponent_msb << 1);
97     next_value++;
98   }
99 }
100 
101 template <>
PopulateWithIntNext(Literal * literal)102 void PopulateWithIntNext<bfloat16>(Literal* literal) {
103   // Duplicates may be generated if we don't have enough bits.
104   // Start at 0x80 rather than 0 to avoid denormals.
105   uint16 next_value = 0x80;
106   for (bfloat16& value : literal->data<bfloat16>()) {
107     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
108     // the sign bit. We could be less wasteful, but this is best-effort anyway.
109     uint16 exponent_msb = next_value & 0x4000;
110     value.value = (next_value & 0xBFFF) | (exponent_msb << 1);
111     next_value++;
112   }
113 }
114 
115 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)116 void PopulateWithNextAfter(Literal* literal) {
117   // Duplicates may be generated if the number of elements in the literal
118   // exceeds the number of positive values supported by the type.
119   float next_value = std::numeric_limits<float>::min();
120   for (float& value : literal->data<float>()) {
121     value = next_value;
122     next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
123   }
124 }
125 
126 template <typename FloatT,
127           typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
128                                       std::is_same<half, FloatT>::value,
129                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)130 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
131   PopulateWithIntNext<FloatT>(literal);
132   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
133                *engine);
134 }
135 
136 template <typename FloatT,
137           typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
138                                       !std::is_same<half, FloatT>::value,
139                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)140 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
141   PopulateWithNextAfter<FloatT>(literal);
142   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
143                *engine);
144 }
145 
146 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)147 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
148                                    bool no_duplicates, bool use_large_range) {
149   CHECK(engine != nullptr);
150   CHECK_EQ(literal->shape().element_type(),
151            primitive_util::NativeToPrimitiveType<FloatT>());
152   if (no_duplicates) {
153     PopulateWithNoDuplicateData<FloatT>(literal, engine);
154   } else if (use_large_range) {
155     PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
156   } else {
157     PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
158   }
159 }
160 
161 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)162 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
163                              bool no_duplicates, bool use_large_range) {
164   using InnerFloatT = typename ComplexT::value_type;
165   CHECK(engine != nullptr);
166   CHECK_EQ(result->shape().element_type(),
167            primitive_util::NativeToPrimitiveType<ComplexT>());
168   Shape floating_point_shape = ShapeUtil::ChangeElementType(
169       result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
170   Literal real_lit(floating_point_shape);
171   Literal imaginary_lit(floating_point_shape);
172 
173   PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
174                                              use_large_range);
175   PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
176                                              no_duplicates, use_large_range);
177 
178   absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
179   absl::Span<const InnerFloatT> imaginary_data =
180       imaginary_lit.data<InnerFloatT>();
181   absl::Span<ComplexT> result_data = result->data<ComplexT>();
182   for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
183     result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
184   }
185 }
186 
187 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)188 void PopulateWithFloatingPointData<half>(Literal* literal,
189                                          std::minstd_rand0* engine,
190                                          bool no_duplicates,
191                                          bool use_large_range) {
192   CHECK(engine != nullptr);
193   CHECK_EQ(literal->shape().element_type(),
194            primitive_util::NativeToPrimitiveType<half>());
195   if (no_duplicates) {
196     PopulateWithNoDuplicateData<half>(literal, engine);
197   } else if (use_large_range) {
198     PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
199   } else {
200     PopulateWithRandomFloatingPointData<half, float>(literal, engine);
201   }
202 }
203 
204 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)205 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
206                                              std::minstd_rand0* engine,
207                                              bool no_duplicates,
208                                              bool use_large_range) {
209   CHECK(engine != nullptr);
210   CHECK_EQ(literal->shape().element_type(),
211            primitive_util::NativeToPrimitiveType<bfloat16>());
212   if (no_duplicates) {
213     PopulateWithNoDuplicateData<bfloat16>(literal, engine);
214   } else if (use_large_range) {
215     PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
216   } else {
217     PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
218   }
219 }
220 
221 // uniform_int_distribution is not defined for 8-bit integers.
222 // Use 'short' for those types.
223 template <typename IntT>
224 struct RngT {
225   using type = IntT;
226 };
227 
228 template <>
229 struct RngT<int8> {
230   using type = int16;
231 };
232 
233 template <>
234 struct RngT<uint8> {
235   using type = uint16;
236 };
237 
238 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)239 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
240                                     bool no_duplicates) {
241   CHECK(engine != nullptr);
242   CHECK_EQ(literal->shape().element_type(),
243            primitive_util::NativeToPrimitiveType<IntT>());
244   if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
245                            std::numeric_limits<IntT>::max()) {
246     std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
247     std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
248                  *engine);
249   } else {
250     std::uniform_int_distribution<typename RngT<IntT>::type> generator(
251         std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
252     for (IntT& value : literal->data<IntT>()) {
253       value = generator(*engine);
254     }
255   }
256 }
257 
258 // Similar to MakeFakeLiteral but takes a random number generator engine to
259 // enable reusing the engine across randomly generated literals. 'no_duplicates'
260 // indicates that there should be no duplicate values in each generated
261 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
262 // are not supported and uniqueness cannot be guaranteed if the number of
263 // elements exceeds the number of different values supported by the type.
MakeFakeLiteralInternal(const Shape & shape,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)264 StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
265                                           std::minstd_rand0* engine,
266                                           bool no_duplicates,
267                                           bool use_large_range) {
268   if (shape.IsTuple()) {
269     std::vector<Literal> elements;
270     for (const Shape& element_shape : shape.tuple_shapes()) {
271       TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
272                                                element_shape, engine,
273                                                no_duplicates, use_large_range));
274       elements.push_back(std::move(element));
275     }
276     return LiteralUtil::MakeTupleOwned(std::move(elements));
277   }
278   if (engine == nullptr) {
279     return Literal::CreateFromShape(shape);
280   }
281   // Clear tiles/element size in shape's layout before using it for creating
282   // literal.
283   Shape new_shape = shape;
284   new_shape.mutable_layout()->clear_tiles();
285   new_shape.mutable_layout()->set_element_size_in_bits(0);
286   Literal literal(new_shape);
287   switch (shape.element_type()) {
288     case BF16:
289       PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
290                                               use_large_range);
291       break;
292     case F16:
293       PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
294                                           use_large_range);
295       break;
296     case F32:
297       PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
298                                            use_large_range);
299       break;
300     case F64:
301       PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
302                                             use_large_range);
303       break;
304     case S8:
305       PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
306       break;
307     case U8:
308       PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
309       break;
310     case S16:
311       PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
312       break;
313     case U16:
314       PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
315       break;
316     case S32:
317       PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
318       break;
319     case U32:
320       PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
321       break;
322     case S64:
323       PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
324       break;
325     case U64:
326       PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
327       break;
328     case C64:
329       PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
330                                          use_large_range);
331       break;
332     case C128:
333       PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
334                                           use_large_range);
335       break;
336     case PRED: {
337       std::uniform_int_distribution<int> generator(0, 1);
338       TF_CHECK_OK(
339           literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
340             return generator(*engine);
341           }));
342       break;
343     }
344     default:
345       return Unimplemented("Unsupported type for fake literal generation: %s",
346                            ShapeUtil::HumanString(shape));
347   }
348   return std::move(literal);
349 }
350 
351 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)352 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
353                                               std::minstd_rand0* engine,
354                                               IntT min, IntT max) {
355   CHECK(engine != nullptr);
356   CHECK_EQ(literal->shape().element_type(),
357            primitive_util::NativeToPrimitiveType<IntT>());
358   std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
359   for (IntT& value : literal->data<IntT>()) {
360     value = generator(*engine);
361   }
362 }
363 
364 // Same as MakeFakeLiteralInternal but generates random numbers in the given
365 // range [min, max]. Currently this works only for INT types.
MakeFakeLiteralInternalWithBounds(const Shape & shape,std::minstd_rand0 * engine,int64 min,int64 max,bool is_sorted)366 StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
367                                                     std::minstd_rand0* engine,
368                                                     int64 min, int64 max,
369                                                     bool is_sorted) {
370   if (shape.IsTuple()) {
371     std::vector<Literal> elements;
372     for (const Shape& element_shape : shape.tuple_shapes()) {
373       TF_ASSIGN_OR_RETURN(Literal element,
374                           MakeFakeLiteralInternalWithBounds(
375                               element_shape, engine, min, max, is_sorted));
376       elements.push_back(std::move(element));
377     }
378     return LiteralUtil::MakeTupleOwned(std::move(elements));
379   }
380   if (engine == nullptr) {
381     return Literal::CreateFromShape(shape);
382   }
383   // Clear tiles/element size in shape's layout before using it for creating
384   // literal.
385   Shape new_shape = shape;
386   new_shape.mutable_layout()->clear_tiles();
387   new_shape.mutable_layout()->set_element_size_in_bits(0);
388   Literal literal(new_shape);
389   switch (shape.element_type()) {
390     case S8:
391       PopulateWithRandomIntegralDataWithBounds<int8>(
392           &literal, engine, static_cast<int8>(min), static_cast<int8>(max));
393       if (is_sorted) {
394         std::sort(literal.data<int8>().begin(), literal.data<int8>().end());
395       }
396       break;
397     case U8:
398       PopulateWithRandomIntegralDataWithBounds<uint8>(
399           &literal, engine, static_cast<uint8>(min), static_cast<uint8>(max));
400       if (is_sorted) {
401         std::sort(literal.data<uint8>().begin(), literal.data<uint8>().end());
402       }
403       break;
404     case S16:
405       PopulateWithRandomIntegralDataWithBounds<int16>(
406           &literal, engine, static_cast<int16>(min), static_cast<int16>(max));
407       if (is_sorted) {
408         std::sort(literal.data<int16>().begin(), literal.data<int16>().end());
409       }
410       break;
411     case U16:
412       PopulateWithRandomIntegralDataWithBounds<uint16>(
413           &literal, engine, static_cast<uint16>(min), static_cast<uint16>(max));
414       if (is_sorted) {
415         std::sort(literal.data<uint16>().begin(), literal.data<uint16>().end());
416       }
417       break;
418     case S32:
419       PopulateWithRandomIntegralDataWithBounds<int32>(
420           &literal, engine, static_cast<int32>(min), static_cast<int32>(max));
421       if (is_sorted) {
422         std::sort(literal.data<int32>().begin(), literal.data<int32>().end());
423       }
424       break;
425     case U32:
426       PopulateWithRandomIntegralDataWithBounds<uint32>(
427           &literal, engine, static_cast<uint32>(min), static_cast<uint32>(max));
428       if (is_sorted) {
429         std::sort(literal.data<uint32>().begin(), literal.data<uint32>().end());
430       }
431       break;
432     case S64:
433       PopulateWithRandomIntegralDataWithBounds<int64>(
434           &literal, engine, static_cast<int64>(min), static_cast<int64>(max));
435       if (is_sorted) {
436         std::sort(literal.data<int64>().begin(), literal.data<int64>().end());
437       }
438       break;
439     case U64:
440       PopulateWithRandomIntegralDataWithBounds<uint64>(
441           &literal, engine, static_cast<uint64>(min), static_cast<uint64>(max));
442       if (is_sorted) {
443         std::sort(literal.data<uint64>().begin(), literal.data<uint64>().end());
444       }
445       break;
446     default:
447       return Unimplemented(
448           "Unsupported type for fake random literal generation with bounds: %s",
449           ShapeUtil::HumanString(shape));
450   }
451   return std::move(literal);
452 }
453 
// Identity constant that a reduction computation needs as its init value.
enum class ConstantType {
  kUnknown,  // The identity could not be determined.
  kZero,     // Zero (used for kAdd reductions).
  kOne,      // One (used for kMultiply reductions).
};
455 
456 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)457 ConstantType GetInitValue(const HloComputation& computation) {
458   // TODO(b/77635120): Add init values, for min, max, and their arg variants.
459   const HloInstruction* const root = computation.root_instruction();
460   if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
461       root->operand(0)->opcode() != HloOpcode::kParameter ||
462       root->operand(1)->opcode() != HloOpcode::kParameter ||
463       root->operand(0) == root->operand(1)) {
464     return ConstantType::kUnknown;
465   }
466 
467   switch (root->opcode()) {
468     case HloOpcode::kAdd:
469       return ConstantType::kZero;
470     case HloOpcode::kMultiply:
471       return ConstantType::kOne;
472     default:
473       return ConstantType::kUnknown;
474   }
475 }
476 
477 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
478 // initialization value.
NeedsInitValue(const HloUse & use)479 bool NeedsInitValue(const HloUse& use) {
480   const HloInstruction* const instruction = use.instruction;
481   const HloOpcode opcode = instruction->opcode();
482   const int64 op_num = use.operand_number;
483   return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
484           (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
485           (opcode == HloOpcode::kReduce &&
486            op_num >= instruction->operand_count() / 2));
487 }
488 
489 // Generate random values that are constrained to the input_shape minus the
490 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64 index_bound,std::minstd_rand0 * engine)491 Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) {
492   std::uniform_int_distribution<int32> generator(0, index_bound);
493   return LiteralUtil::CreateR0<int32>(generator(*engine));
494 }
495 
496 // Use dataflow analysis on each parameter to see if there are uses that would
497 // be problematic when generating input data.  Returns the list of instructions
498 // that correspond to their uses.
499 //
500 // Should be paired with the CreateLiteralForConstrainedUses() function below.
FindConstrainedUses(const HloDataflowAnalysis & dataflow,const HloInstruction & param)501 std::vector<HloInstruction*> FindConstrainedUses(
502     const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
503   std::vector<HloInstruction*> constrained_uses;
504   for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
505     const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
506     for (const HloUse& use : value.uses()) {
507       HloInstruction* instruction = use.instruction;
508       const HloOpcode opcode = instruction->opcode();
509       const int64 op_num = use.operand_number;
510       if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
511           (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
512         constrained_uses.push_back(instruction);
513       } else if ((opcode == HloOpcode::kGather ||
514                   opcode == HloOpcode::kScatter) &&
515                  op_num == 1) {
516         constrained_uses.push_back(instruction);
517       } else if (opcode == HloOpcode::kFusion) {
518         const HloInstruction* const to_analyze =
519             instruction->fused_parameter(op_num);
520         auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
521         constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
522                                 fused_uses.end());
523       } else if (NeedsInitValue(use)) {
524         constrained_uses.push_back(instruction);
525       } else if (opcode == HloOpcode::kConvert ||
526                  opcode == HloOpcode::kReducePrecision) {
527         auto converted_uses = FindConstrainedUses(dataflow, *instruction);
528         constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
529                                 converted_uses.end());
530       } else if (opcode == HloOpcode::kSort &&
531                  instruction->operand_count() >= 2 && op_num == 0) {
532         // Operand 0 of sort is the array of keys used for key/value
533         // (two-operand) kSort instructions. Since sort stability is not
534         // guaranteed, constrain keys of key-value sort not to have duplicates,
535         // since otherwise the value order may legitimately differ.
536         constrained_uses.push_back(instruction);
537       }
538     }
539   }
540   return constrained_uses;
541 }
542 
543 // Given a parameter, generate a random Literal to use as input if there exist
544 // no constrained uses in the dataflow graph.  If such constraints exist,
545 // generate a constrained literal (either bounded in the case of indices, or
546 // zero in the case of init_values for reductions).
CreateLiteralForConstrainedUses(const absl::Span<HloInstruction * const> constrained_uses,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)547 StatusOr<Literal> CreateLiteralForConstrainedUses(
548     const absl::Span<HloInstruction* const> constrained_uses,
549     const HloInstruction& param, const Shape& param_shape,
550     std::minstd_rand0* engine, bool use_large_range) {
551   int64 index_bound = INT64_MAX;
552   bool no_duplicates = false;
553   bool needs_constant = false;
554   bool needs_sorted_indices = false;
555   ConstantType constant_type = ConstantType::kUnknown;
556   for (HloInstruction* use : constrained_uses) {
557     switch (use->opcode()) {
558       case HloOpcode::kDynamicSlice:
559       case HloOpcode::kDynamicUpdateSlice: {
560         const Shape& indexed_shape = use->operand(0)->shape();
561         const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
562                                        ? use->shape()
563                                        : use->operand(1)->shape();
564         const int64 first_index =
565             Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
566         for (int64 operand = first_index; operand < use->operand_count();
567              ++operand) {
568           if (use->operand(operand) == &param) {
569             index_bound = std::min(
570                 index_bound,
571                 ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
572                     ShapeUtil::GetDimension(slice_shape,
573                                             operand - first_index));
574           }
575         }
576         break;
577       }
578       case HloOpcode::kGather:
579       case HloOpcode::kScatter: {
580         const Shape& operand_shape = use->operand(0)->shape();
581         if (use->operand(1) == &param) {
582           auto index_map =
583               use->opcode() == HloOpcode::kGather
584                   ? use->gather_dimension_numbers().start_index_map()
585                   : use->scatter_dimension_numbers()
586                         .scatter_dims_to_operand_dims();
587           for (const auto dim_in_operand : index_map) {
588             index_bound =
589                 std::min(index_bound, operand_shape.dimensions(dim_in_operand));
590           }
591         }
592         if (use->opcode() == HloOpcode::kScatter) {
593           needs_sorted_indices |=
594               Cast<const HloScatterInstruction>(use)->indices_are_sorted();
595         } else {
596           needs_sorted_indices |=
597               Cast<const HloGatherInstruction>(use)->indices_are_sorted();
598         }
599         break;
600       }
601       case HloOpcode::kReduce:
602       case HloOpcode::kReduceWindow:
603         needs_constant = true;
604         constant_type = GetInitValue(*use->to_apply());
605         break;
606 
607       case HloOpcode::kSelectAndScatter:
608         needs_constant = true;
609         constant_type = GetInitValue(*use->scatter());
610         break;
611 
612       case HloOpcode::kSort:
613         no_duplicates = true;
614         break;
615 
616       default:
617         return Unimplemented(
618             "Constrained operand generation not implemented for %s.",
619             use->ToString());
620     }
621   }
622   int constraint_count = 0;
623   constraint_count += no_duplicates ? 1 : 0;
624   constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
625   constraint_count += needs_constant ? 1 : 0;
626   if (constraint_count > 1) {
627     return Unimplemented("Conflicting operand generation constraints.");
628   }
629   if (index_bound != INT64_MAX) {
630     return MakeFakeLiteralInternalWithBounds(param_shape, engine, -1,
631                                              index_bound, needs_sorted_indices);
632   } else if (needs_constant) {
633     switch (constant_type) {
634       case ConstantType::kZero:
635         return LiteralUtil::Zero(param_shape.element_type());
636       case ConstantType::kOne:
637         return LiteralUtil::One(param_shape.element_type());
638       case ConstantType::kUnknown:
639         // We want the identity element for the computation, but we don't really
640         // know what it is - so any value we generate will be just as wrong.
641         return MakeFakeLiteralInternal(param_shape, engine,
642                                        /*no_duplicates=*/false,
643                                        use_large_range);
644     }
645   } else {
646     return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
647                                    use_large_range);
648   }
649 }
650 
651 // Given a module entry parameter, use the dataflow analysis to see if a
652 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)653 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
654                                           const HloInstruction& param,
655                                           const Shape& param_shape,
656                                           std::minstd_rand0* engine,
657                                           bool use_large_range) {
658   const auto constrained_uses = FindConstrainedUses(dataflow, param);
659   return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
660                                          engine, use_large_range);
661 }
662 
663 }  // namespace
664 
MakeFakeLiteral(const Shape & shape,bool pseudo_random,bool use_large_range)665 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
666                                   bool use_large_range) {
667   auto engine =
668       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
669   return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
670                                  use_large_range);
671 }
672 
MakeFakeArguments(HloModule * const module,bool pseudo_random,bool use_large_range)673 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
674                                                  bool pseudo_random,
675                                                  bool use_large_range) {
676   auto engine =
677       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
678   return MakeFakeArguments(module, engine.get(), use_large_range);
679 }
680 
MakeFakeArguments(HloModule * const module,std::minstd_rand0 * engine,bool use_large_range)681 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
682                                                  std::minstd_rand0* engine,
683                                                  bool use_large_range) {
684   TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
685   const auto params = module->entry_computation()->parameter_instructions();
686   std::vector<Literal> arguments(params.size());
687   for (int i = 0; i < params.size(); ++i) {
688     const HloModuleConfig& module_config = module->config();
689     const Shape& param_shape = (module_config.has_entry_computation_layout() &&
690                                 module_config.entry_computation_layout()
691                                     .parameter_layout(i)
692                                     .shape()
693                                     .is_static())
694                                    ? module_config.entry_computation_layout()
695                                          .parameter_layout(i)
696                                          .shape()
697                                    : params[i]->shape();
698 
699     arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], param_shape,
700                                            engine, use_large_range)
701                        .ValueOrDie();
702   }
703   return std::move(arguments);
704 }
705 
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)706 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
707                        bool allow_mixed_precision) {
708   return HloVerifier(/*layout_sensitive=*/layout_sensitive,
709                      /*allow_mixed_precision=*/allow_mixed_precision)
710       .Run(module)
711       .status();
712 }
713 
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)714 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
715                                                       HloInstruction* lhs,
716                                                       HloInstruction* rhs) {
717   CHECK_LE(lhs->shape().rank(), 2);
718   CHECK_LE(rhs->shape().rank(), 2);
719   PrecisionConfig precision_config;
720   precision_config.mutable_operand_precision()->Resize(
721       2, PrecisionConfig::DEFAULT);
722   DotDimensionNumbers dot_dimension_numbers;
723   dot_dimension_numbers.add_lhs_contracting_dimensions(
724       lhs->shape().rank() > 1 ? 1 : 0);
725   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
726   return absl::make_unique<HloDotInstruction>(
727       shape, lhs, rhs, dot_dimension_numbers, precision_config);
728 }
729 }  // namespace xla
730