• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/test_utils.h"
17 
18 #include <cmath>
19 
20 #include "absl/base/casts.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
28 #include "tensorflow/compiler/xla/service/transfer_manager.h"
29 
30 namespace xla {
31 
32 namespace {
33 
34 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)35 void PopulateWithRandomFloatingPointData(Literal* literal,
36                                          std::minstd_rand0* engine) {
37   std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
38   for (FloatT& value : literal->data<FloatT>()) {
39     value = static_cast<FloatT>(generator(*engine));
40   }
41 }
42 
43 // Populates a floating point literal with random floating points sampled from a
44 // uniform-log distribution spanning approximately the entire range of the
45 // representable floating point.
46 template <typename FloatT>
PopulateWithRandomFullRangeFloatingPointData(Literal * literal,std::minstd_rand0 * engine)47 void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
48                                                   std::minstd_rand0* engine) {
49   constexpr float kSpecialValueProbability = 1e-6;
50   constexpr float kSpecialValues[] = {+0.F,
51                                       -0.F,
52                                       1.F,
53                                       -1.F,
54                                       std::numeric_limits<float>::infinity(),
55                                       -std::numeric_limits<float>::infinity()};
56   constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
57   std::uniform_real_distribution<float> special_value_gen(0, 1);
58 
59   // Generates floating points with a log-uniform distribution. This causes the
60   // exponent of the floating point to have a uniform distribution.
61   int min_exp, max_exp;
62   if (std::is_same<FloatT, bfloat16>()) {
63     min_exp = std::numeric_limits<float>::min_exponent;
64     max_exp = std::numeric_limits<float>::max_exponent;
65   } else {
66     min_exp = std::numeric_limits<FloatT>::min_exponent;
67     max_exp = std::numeric_limits<FloatT>::max_exponent;
68   }
69   std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);
70 
71   for (FloatT& value : literal->data<FloatT>()) {
72     // Each special value has a kSpecialValueProbability chance to be generated
73     // instead of sampling using the normal distributions.
74     if (special_value_gen(*engine) <
75         kSpecialValueProbability * kNumSpecialValues) {
76       value =
77           static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
78     } else {
79       float sign = ((*engine)() % 2 == 0) ? 1 : -1;
80       value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
81     }
82   }
83 }
84 
85 template <typename FloatT>
86 void PopulateWithIntNext(Literal* literal);
87 
88 template <>
PopulateWithIntNext(Literal * literal)89 void PopulateWithIntNext<half>(Literal* literal) {
90   // Duplicates may be generated if we don't have enough bits.
91   uint16 next_value = 0;
92   for (half& value : literal->data<half>()) {
93     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
94     // the sign bit. We could be less wasteful, but this is best-effort anyway.
95     uint16 exponent_msb = next_value & 0x4000;
96     value = Eigen::numext::bit_cast<half, uint16>((next_value & 0xBFFF) |
97                                                   (exponent_msb << 1));
98     next_value++;
99   }
100 }
101 
102 template <>
PopulateWithIntNext(Literal * literal)103 void PopulateWithIntNext<bfloat16>(Literal* literal) {
104   // Duplicates may be generated if we don't have enough bits.
105   // Start at 0x80 rather than 0 to avoid denormals.
106   uint16 next_value = 0x80;
107   for (bfloat16& value : literal->data<bfloat16>()) {
108     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
109     // the sign bit. We could be less wasteful, but this is best-effort anyway.
110     uint16 exponent_msb = next_value & 0x4000;
111     value = Eigen::numext::bit_cast<bfloat16, uint16>((next_value & 0xBFFF) |
112                                                       (exponent_msb << 1));
113     next_value++;
114   }
115 }
116 
117 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)118 void PopulateWithNextAfter(Literal* literal) {
119   // Duplicates may be generated if the number of elements in the literal
120   // exceeds the number of positive values supported by the type.
121   float next_value = std::numeric_limits<float>::min();
122   for (float& value : literal->data<float>()) {
123     value = next_value;
124     next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
125   }
126 }
127 
128 template <typename FloatT,
129           typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
130                                       std::is_same<half, FloatT>::value,
131                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)132 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
133   PopulateWithIntNext<FloatT>(literal);
134   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
135                *engine);
136 }
137 
138 template <typename FloatT,
139           typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
140                                       !std::is_same<half, FloatT>::value,
141                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)142 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
143   PopulateWithNextAfter<FloatT>(literal);
144   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
145                *engine);
146 }
147 
148 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)149 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
150                                    bool no_duplicates, bool use_large_range) {
151   CHECK(engine != nullptr);
152   CHECK_EQ(literal->shape().element_type(),
153            primitive_util::NativeToPrimitiveType<FloatT>());
154   if (no_duplicates) {
155     PopulateWithNoDuplicateData<FloatT>(literal, engine);
156   } else if (use_large_range) {
157     PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
158   } else {
159     PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
160   }
161 }
162 
163 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)164 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
165                              bool no_duplicates, bool use_large_range) {
166   using InnerFloatT = typename ComplexT::value_type;
167   CHECK(engine != nullptr);
168   CHECK_EQ(result->shape().element_type(),
169            primitive_util::NativeToPrimitiveType<ComplexT>());
170   Shape floating_point_shape = ShapeUtil::ChangeElementType(
171       result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
172   Literal real_lit(floating_point_shape);
173   Literal imaginary_lit(floating_point_shape);
174 
175   PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
176                                              use_large_range);
177   PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
178                                              no_duplicates, use_large_range);
179 
180   absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
181   absl::Span<const InnerFloatT> imaginary_data =
182       imaginary_lit.data<InnerFloatT>();
183   absl::Span<ComplexT> result_data = result->data<ComplexT>();
184   for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
185     result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
186   }
187 }
188 
189 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)190 void PopulateWithFloatingPointData<half>(Literal* literal,
191                                          std::minstd_rand0* engine,
192                                          bool no_duplicates,
193                                          bool use_large_range) {
194   CHECK(engine != nullptr);
195   CHECK_EQ(literal->shape().element_type(),
196            primitive_util::NativeToPrimitiveType<half>());
197   if (no_duplicates) {
198     PopulateWithNoDuplicateData<half>(literal, engine);
199   } else if (use_large_range) {
200     PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
201   } else {
202     PopulateWithRandomFloatingPointData<half, float>(literal, engine);
203   }
204 }
205 
206 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)207 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
208                                              std::minstd_rand0* engine,
209                                              bool no_duplicates,
210                                              bool use_large_range) {
211   CHECK(engine != nullptr);
212   CHECK_EQ(literal->shape().element_type(),
213            primitive_util::NativeToPrimitiveType<bfloat16>());
214   if (no_duplicates) {
215     PopulateWithNoDuplicateData<bfloat16>(literal, engine);
216   } else if (use_large_range) {
217     PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
218   } else {
219     PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
220   }
221 }
222 
223 // uniform_int_distribution is not defined for 8-bit integers.
224 // Use 'short' for those types.
225 template <typename IntT>
226 struct RngT {
227   using type = IntT;
228 };
229 
230 template <>
231 struct RngT<int8> {
232   using type = int16;
233 };
234 
235 template <>
236 struct RngT<uint8> {
237   using type = uint16;
238 };
239 
240 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)241 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
242                                     bool no_duplicates) {
243   CHECK(engine != nullptr);
244   CHECK_EQ(literal->shape().element_type(),
245            primitive_util::NativeToPrimitiveType<IntT>());
246   if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
247                            std::numeric_limits<IntT>::max()) {
248     std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
249     std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
250                  *engine);
251   } else {
252     std::uniform_int_distribution<typename RngT<IntT>::type> generator(
253         std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
254     for (IntT& value : literal->data<IntT>()) {
255       value = generator(*engine);
256     }
257   }
258 }
259 
260 // Similar to MakeFakeLiteral but takes a random number generator engine to
261 // enable reusing the engine across randomly generated literals. 'no_duplicates'
262 // indicates that there should be no duplicate values in each generated
263 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
264 // are not supported and uniqueness cannot be guaranteed if the number of
265 // elements exceeds the number of different values supported by the type.
MakeFakeLiteralInternal(const Shape & shape,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)266 StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
267                                           std::minstd_rand0* engine,
268                                           bool no_duplicates,
269                                           bool use_large_range) {
270   if (shape.IsTuple()) {
271     std::vector<Literal> elements;
272     for (const Shape& element_shape : shape.tuple_shapes()) {
273       TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
274                                                element_shape, engine,
275                                                no_duplicates, use_large_range));
276       elements.push_back(std::move(element));
277     }
278     return LiteralUtil::MakeTupleOwned(std::move(elements));
279   }
280   if (engine == nullptr) {
281     return Literal::CreateFromShape(shape);
282   }
283   // Clear tiles/element size in shape's layout before using it for creating
284   // literal.
285   Shape new_shape = shape;
286   new_shape.mutable_layout()->clear_tiles();
287   new_shape.mutable_layout()->set_element_size_in_bits(0);
288   Literal literal(new_shape);
289   switch (shape.element_type()) {
290     case BF16:
291       PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
292                                               use_large_range);
293       break;
294     case F16:
295       PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
296                                           use_large_range);
297       break;
298     case F32:
299       PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
300                                            use_large_range);
301       break;
302     case F64:
303       PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
304                                             use_large_range);
305       break;
306     case S8:
307       PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
308       break;
309     case U8:
310       PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
311       break;
312     case S16:
313       PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
314       break;
315     case U16:
316       PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
317       break;
318     case S32:
319       PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
320       break;
321     case U32:
322       PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
323       break;
324     case S64:
325       PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
326       break;
327     case U64:
328       PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
329       break;
330     case C64:
331       PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
332                                          use_large_range);
333       break;
334     case C128:
335       PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
336                                           use_large_range);
337       break;
338     case PRED: {
339       std::uniform_int_distribution<int> generator(0, 1);
340       TF_CHECK_OK(
341           literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
342             return generator(*engine);
343           }));
344       break;
345     }
346     default:
347       return Unimplemented("Unsupported type for fake literal generation: %s",
348                            ShapeUtil::HumanString(shape));
349   }
350   return std::move(literal);
351 }
352 
353 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)354 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
355                                               std::minstd_rand0* engine,
356                                               IntT min, IntT max) {
357   CHECK(engine != nullptr);
358   CHECK_EQ(literal->shape().element_type(),
359            primitive_util::NativeToPrimitiveType<IntT>());
360   std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
361   for (IntT& value : literal->data<IntT>()) {
362     value = generator(*engine);
363   }
364 }
365 
366 // Same as MakeFakeLiteralInternal but generates random numbers in the given
367 // range [min, max]. Currently this works only for INT types.
MakeFakeLiteralInternalWithBounds(const Shape & shape,std::minstd_rand0 * engine,int64_t min,int64_t max,bool is_sorted)368 StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
369                                                     std::minstd_rand0* engine,
370                                                     int64_t min, int64_t max,
371                                                     bool is_sorted) {
372   if (shape.IsTuple()) {
373     std::vector<Literal> elements;
374     for (const Shape& element_shape : shape.tuple_shapes()) {
375       TF_ASSIGN_OR_RETURN(Literal element,
376                           MakeFakeLiteralInternalWithBounds(
377                               element_shape, engine, min, max, is_sorted));
378       elements.push_back(std::move(element));
379     }
380     return LiteralUtil::MakeTupleOwned(std::move(elements));
381   }
382   if (engine == nullptr) {
383     return Literal::CreateFromShape(shape);
384   }
385   // Clear tiles/element size in shape's layout before using it for creating
386   // literal.
387   Shape new_shape = shape;
388   new_shape.mutable_layout()->clear_tiles();
389   new_shape.mutable_layout()->set_element_size_in_bits(0);
390   Literal literal(new_shape);
391   switch (shape.element_type()) {
392     case S8:
393       PopulateWithRandomIntegralDataWithBounds<int8>(
394           &literal, engine, static_cast<int8>(min), static_cast<int8>(max));
395       if (is_sorted) {
396         std::sort(literal.data<int8>().begin(), literal.data<int8>().end());
397       }
398       break;
399     case U8:
400       PopulateWithRandomIntegralDataWithBounds<uint8>(
401           &literal, engine, static_cast<uint8>(min), static_cast<uint8>(max));
402       if (is_sorted) {
403         std::sort(literal.data<uint8>().begin(), literal.data<uint8>().end());
404       }
405       break;
406     case S16:
407       PopulateWithRandomIntegralDataWithBounds<int16>(
408           &literal, engine, static_cast<int16>(min), static_cast<int16>(max));
409       if (is_sorted) {
410         std::sort(literal.data<int16>().begin(), literal.data<int16>().end());
411       }
412       break;
413     case U16:
414       PopulateWithRandomIntegralDataWithBounds<uint16>(
415           &literal, engine, static_cast<uint16>(min), static_cast<uint16>(max));
416       if (is_sorted) {
417         std::sort(literal.data<uint16>().begin(), literal.data<uint16>().end());
418       }
419       break;
420     case S32:
421       PopulateWithRandomIntegralDataWithBounds<int32>(
422           &literal, engine, static_cast<int32>(min), static_cast<int32>(max));
423       if (is_sorted) {
424         std::sort(literal.data<int32>().begin(), literal.data<int32>().end());
425       }
426       break;
427     case U32:
428       PopulateWithRandomIntegralDataWithBounds<uint32>(
429           &literal, engine, static_cast<uint32>(min), static_cast<uint32>(max));
430       if (is_sorted) {
431         std::sort(literal.data<uint32>().begin(), literal.data<uint32>().end());
432       }
433       break;
434     case S64:
435       PopulateWithRandomIntegralDataWithBounds<int64>(
436           &literal, engine, static_cast<int64>(min), static_cast<int64>(max));
437       if (is_sorted) {
438         std::sort(literal.data<int64>().begin(), literal.data<int64>().end());
439       }
440       break;
441     case U64:
442       PopulateWithRandomIntegralDataWithBounds<uint64>(
443           &literal, engine, static_cast<uint64>(min), static_cast<uint64>(max));
444       if (is_sorted) {
445         std::sort(literal.data<uint64>().begin(), literal.data<uint64>().end());
446       }
447       break;
448     default:
449       return Unimplemented(
450           "Unsupported type for fake random literal generation with bounds: %s",
451           ShapeUtil::HumanString(shape));
452   }
453   return std::move(literal);
454 }
455 
// The constant a reduction-like computation needs as its init value, as
// classified by GetInitValue.
enum class ConstantType {
  kUnknown,  // Could not be determined.
  kZero,     // Returned for an add of the two parameters.
  kOne,      // Returned for a multiply of the two parameters.
};
457 
458 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)459 ConstantType GetInitValue(const HloComputation& computation) {
460   // TODO(b/77635120): Add init values, for min, max, and their arg variants.
461   const HloInstruction* const root = computation.root_instruction();
462   if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
463       root->operand(0)->opcode() != HloOpcode::kParameter ||
464       root->operand(1)->opcode() != HloOpcode::kParameter ||
465       root->operand(0) == root->operand(1)) {
466     return ConstantType::kUnknown;
467   }
468 
469   switch (root->opcode()) {
470     case HloOpcode::kAdd:
471       return ConstantType::kZero;
472     case HloOpcode::kMultiply:
473       return ConstantType::kOne;
474     default:
475       return ConstantType::kUnknown;
476   }
477 }
478 
479 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
480 // initialization value.
NeedsInitValue(const HloUse & use)481 bool NeedsInitValue(const HloUse& use) {
482   const HloInstruction* const instruction = use.instruction;
483   const HloOpcode opcode = instruction->opcode();
484   const int64_t op_num = use.operand_number;
485   return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
486           (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
487           (opcode == HloOpcode::kReduce &&
488            op_num >= instruction->operand_count() / 2));
489 }
490 
491 // Generate random values that are constrained to the input_shape minus the
492 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64_t index_bound,std::minstd_rand0 * engine)493 Literal MakeRandomIndex(int64_t index_bound, std::minstd_rand0* engine) {
494   std::uniform_int_distribution<int32> generator(0, index_bound);
495   return LiteralUtil::CreateR0<int32>(generator(*engine));
496 }
497 
498 // Use dataflow analysis on each parameter to see if there are uses that would
499 // be problematic when generating input data.  Returns the list of instructions
500 // that correspond to their uses.
501 //
502 // Should be paired with the CreateLiteralForConstrainedUses() function below.
FindConstrainedUses(const HloDataflowAnalysis & dataflow,const HloInstruction & param)503 std::vector<HloInstruction*> FindConstrainedUses(
504     const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
505   std::vector<HloInstruction*> constrained_uses;
506   for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
507     const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
508     for (const HloUse& use : value.uses()) {
509       HloInstruction* instruction = use.instruction;
510       const HloOpcode opcode = instruction->opcode();
511       const int64_t op_num = use.operand_number;
512       if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
513           (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
514         constrained_uses.push_back(instruction);
515       } else if ((opcode == HloOpcode::kGather ||
516                   opcode == HloOpcode::kScatter) &&
517                  op_num == 1) {
518         constrained_uses.push_back(instruction);
519       } else if (opcode == HloOpcode::kFusion) {
520         const HloInstruction* const to_analyze =
521             instruction->fused_parameter(op_num);
522         auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
523         constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
524                                 fused_uses.end());
525       } else if (NeedsInitValue(use)) {
526         constrained_uses.push_back(instruction);
527       } else if (opcode == HloOpcode::kConvert ||
528                  opcode == HloOpcode::kReducePrecision) {
529         auto converted_uses = FindConstrainedUses(dataflow, *instruction);
530         constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
531                                 converted_uses.end());
532       } else if (opcode == HloOpcode::kSort &&
533                  instruction->operand_count() >= 2 && op_num == 0) {
534         // Operand 0 of sort is the array of keys used for key/value
535         // (two-operand) kSort instructions. Since sort stability is not
536         // guaranteed, constrain keys of key-value sort not to have duplicates,
537         // since otherwise the value order may legitimately differ.
538         constrained_uses.push_back(instruction);
539       }
540     }
541   }
542   return constrained_uses;
543 }
544 
545 // Given a parameter, generate a random Literal to use as input if there exist
546 // no constrained uses in the dataflow graph.  If such constraints exist,
547 // generate a constrained literal (either bounded in the case of indices, or
548 // zero in the case of init_values for reductions).
CreateLiteralForConstrainedUses(const absl::Span<HloInstruction * const> constrained_uses,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)549 StatusOr<Literal> CreateLiteralForConstrainedUses(
550     const absl::Span<HloInstruction* const> constrained_uses,
551     const HloInstruction& param, const Shape& param_shape,
552     std::minstd_rand0* engine, bool use_large_range) {
553   int64_t index_bound = INT64_MAX;
554   bool no_duplicates = false;
555   bool needs_constant = false;
556   bool needs_sorted_indices = false;
557   ConstantType constant_type = ConstantType::kUnknown;
558   for (HloInstruction* use : constrained_uses) {
559     switch (use->opcode()) {
560       case HloOpcode::kDynamicSlice:
561       case HloOpcode::kDynamicUpdateSlice: {
562         const Shape& indexed_shape = use->operand(0)->shape();
563         const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
564                                        ? use->shape()
565                                        : use->operand(1)->shape();
566         const int64_t first_index =
567             Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
568         for (int64_t operand = first_index; operand < use->operand_count();
569              ++operand) {
570           if (use->operand(operand) == &param) {
571             index_bound = std::min(
572                 index_bound,
573                 ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
574                     ShapeUtil::GetDimension(slice_shape,
575                                             operand - first_index));
576           }
577         }
578         break;
579       }
580       case HloOpcode::kGather:
581       case HloOpcode::kScatter: {
582         const Shape& operand_shape = use->operand(0)->shape();
583         if (use->operand(1) == &param) {
584           auto index_map =
585               use->opcode() == HloOpcode::kGather
586                   ? use->gather_dimension_numbers().start_index_map()
587                   : use->scatter_dimension_numbers()
588                         .scatter_dims_to_operand_dims();
589           for (const auto dim_in_operand : index_map) {
590             index_bound =
591                 std::min(index_bound, operand_shape.dimensions(dim_in_operand));
592           }
593         }
594         if (use->opcode() == HloOpcode::kScatter) {
595           needs_sorted_indices |=
596               Cast<const HloScatterInstruction>(use)->indices_are_sorted();
597         } else {
598           needs_sorted_indices |=
599               Cast<const HloGatherInstruction>(use)->indices_are_sorted();
600         }
601         break;
602       }
603       case HloOpcode::kReduce:
604       case HloOpcode::kReduceWindow:
605         needs_constant = true;
606         constant_type = GetInitValue(*use->to_apply());
607         break;
608 
609       case HloOpcode::kSelectAndScatter:
610         needs_constant = true;
611         constant_type = GetInitValue(*use->scatter());
612         break;
613 
614       case HloOpcode::kSort:
615         no_duplicates = true;
616         break;
617 
618       default:
619         return Unimplemented(
620             "Constrained operand generation not implemented for %s.",
621             use->ToString());
622     }
623   }
624   int constraint_count = 0;
625   constraint_count += no_duplicates ? 1 : 0;
626   constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
627   constraint_count += needs_constant ? 1 : 0;
628   if (constraint_count > 1) {
629     return Unimplemented("Conflicting operand generation constraints.");
630   }
631   if (index_bound != INT64_MAX) {
632     return MakeFakeLiteralInternalWithBounds(param_shape, engine, -1,
633                                              index_bound, needs_sorted_indices);
634   } else if (needs_constant) {
635     switch (constant_type) {
636       case ConstantType::kZero:
637         return LiteralUtil::Zero(param_shape.element_type());
638       case ConstantType::kOne:
639         return LiteralUtil::One(param_shape.element_type());
640       case ConstantType::kUnknown:
641         // We want the identity element for the computation, but we don't really
642         // know what it is - so any value we generate will be just as wrong.
643         return MakeFakeLiteralInternal(param_shape, engine,
644                                        /*no_duplicates=*/false,
645                                        use_large_range);
646     }
647   } else {
648     return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
649                                    use_large_range);
650   }
651 }
652 
653 // Given a module entry parameter, use the dataflow analysis to see if a
654 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)655 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
656                                           const HloInstruction& param,
657                                           const Shape& param_shape,
658                                           std::minstd_rand0* engine,
659                                           bool use_large_range) {
660   const auto constrained_uses = FindConstrainedUses(dataflow, param);
661   return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
662                                          engine, use_large_range);
663 }
664 
665 }  // namespace
666 
MakeFakeLiteral(const Shape & shape,bool pseudo_random,bool use_large_range)667 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
668                                   bool use_large_range) {
669   auto engine =
670       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
671   return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
672                                  use_large_range);
673 }
674 
MakeFakeArguments(HloModule * const module,bool pseudo_random,bool use_large_range)675 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
676                                                  bool pseudo_random,
677                                                  bool use_large_range) {
678   auto engine =
679       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
680   return MakeFakeArguments(module, engine.get(), use_large_range);
681 }
682 
MakeFakeArguments(HloModule * const module,std::minstd_rand0 * engine,bool use_large_range)683 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
684                                                  std::minstd_rand0* engine,
685                                                  bool use_large_range) {
686   TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
687   const auto params = module->entry_computation()->parameter_instructions();
688   std::vector<Literal> arguments(params.size());
689   for (int i = 0; i < params.size(); ++i) {
690     const HloModuleConfig& module_config = module->config();
691     const Shape& param_shape = (module_config.has_entry_computation_layout() &&
692                                 module_config.entry_computation_layout()
693                                     .parameter_layout(i)
694                                     .shape()
695                                     .is_static())
696                                    ? module_config.entry_computation_layout()
697                                          .parameter_layout(i)
698                                          .shape()
699                                    : params[i]->shape();
700 
701     arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], param_shape,
702                                            engine, use_large_range)
703                        .ValueOrDie();
704   }
705   return std::move(arguments);
706 }
707 
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)708 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
709                        bool allow_mixed_precision) {
710   return HloVerifier(/*layout_sensitive=*/layout_sensitive,
711                      /*allow_mixed_precision=*/allow_mixed_precision)
712       .Run(module)
713       .status();
714 }
715 
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)716 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
717                                                       HloInstruction* lhs,
718                                                       HloInstruction* rhs) {
719   CHECK_LE(lhs->shape().rank(), 2);
720   CHECK_LE(rhs->shape().rank(), 2);
721   PrecisionConfig precision_config;
722   precision_config.mutable_operand_precision()->Resize(
723       2, PrecisionConfig::DEFAULT);
724   DotDimensionNumbers dot_dimension_numbers;
725   dot_dimension_numbers.add_lhs_contracting_dimensions(
726       lhs->shape().rank() > 1 ? 1 : 0);
727   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
728   return absl::make_unique<HloDotInstruction>(
729       shape, lhs, rhs, dot_dimension_numbers, precision_config);
730 }
731 }  // namespace xla
732