/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/test_utils.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"

namespace xla {

namespace {

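// Fills |literal| with values drawn from a uniform distribution over
// [-0.1, 0.2). Values are generated as GeneratorT and then cast to the
// literal's element type FloatT.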
template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointData(Literal* literal,
                                         std::minstd_rand0* engine) {
  std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
  for (FloatT& value : literal->data<FloatT>()) {
    value = static_cast<FloatT>(generator(*engine));
  }
}

// Populates a floating point literal with random floating-point values
// sampled from a log-uniform distribution spanning approximately the entire
// range of values representable by the type.
template <typename FloatT>
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
                                                  std::minstd_rand0* engine) {
  constexpr float kSpecialValueProbability = 1e-6;
  constexpr float kSpecialValues[] = {+0.F,
                                      -0.F,
                                      1.F,
                                      -1.F,
                                      std::numeric_limits<float>::infinity(),
                                      -std::numeric_limits<float>::infinity()};
  constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
  std::uniform_real_distribution<float> special_value_gen(0, 1);

  // Generates floating-point values with a log-uniform distribution. This
  // causes the exponent of the generated value to have a uniform distribution.
  int min_exp, max_exp;
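  // bfloat16 has the same 8-bit exponent as float, so use float's exponent
  // range for it.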
  if (std::is_same<FloatT, bfloat16>()) {
    min_exp = std::numeric_limits<float>::min_exponent;
    max_exp = std::numeric_limits<float>::max_exponent;
  } else {
    min_exp = std::numeric_limits<FloatT>::min_exponent;
    max_exp = std::numeric_limits<FloatT>::max_exponent;
  }
  std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);

  for (FloatT& value : literal->data<FloatT>()) {
    // Each special value has a kSpecialValueProbability chance of being
    // generated instead of sampling from the log-uniform distribution.
    if (special_value_gen(*engine) <
        kSpecialValueProbability * kNumSpecialValues) {
      value =
          static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
    } else {
      float sign = ((*engine)() % 2 == 0) ? 1 : -1;
      value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
    }
  }
}

template <typename FloatT>
void PopulateWithIntNext(Literal* literal);

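// In both half and bfloat16, bit 15 is the sign bit and bit 14 is the most
// significant exponent bit, so masking with 0x4000 extracts the exponent MSB
// and shifting it left by one moves it into the sign bit.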
template <>
void PopulateWithIntNext<half>(Literal* literal) {
  // Duplicates may be generated if we don't have enough bits.
  uint16 next_value = 0;
  for (half& value : literal->data<half>()) {
    // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
    // the sign bit. We could be less wasteful, but this is best-effort anyway.
    uint16 exponent_msb = next_value & 0x4000;
    value.x = (next_value & 0xBFFF) | (exponent_msb << 1);
    next_value++;
  }
}

template <>
void PopulateWithIntNext<bfloat16>(Literal* literal) {
  // Duplicates may be generated if we don't have enough bits.
  // Start at 0x80 rather than 0 to avoid denormals.
  uint16 next_value = 0x80;
  for (bfloat16& value : literal->data<bfloat16>()) {
    // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
    // the sign bit. We could be less wasteful, but this is best-effort anyway.
    uint16 exponent_msb = next_value & 0x4000;
    value.value = (next_value & 0xBFFF) | (exponent_msb << 1);
    next_value++;
  }
}

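// Fills the literal with strictly increasing float values, starting at the
// smallest positive normal float and stepping to the next representable value
// on each iteration.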
template <typename FloatT>
void PopulateWithNextAfter(Literal* literal) {
  // Duplicates may be generated if the number of elements in the literal
  // exceeds the number of positive values supported by the type.
  float next_value = std::numeric_limits<float>::min();
  for (float& value : literal->data<float>()) {
    value = next_value;
    next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
  }
}

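// Dispatch on the element type: the 16-bit floats (half, bfloat16) use the
// bit-pattern generator above, while wider float types use the nextafter-based
// generator. In both cases the values are shuffled afterwards so the distinct
// values appear in a random order.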
template <typename FloatT,
          typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
                                      std::is_same<half, FloatT>::value,
                                  int>::type = 0>
void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
  PopulateWithIntNext<FloatT>(literal);
  std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
               *engine);
}

template <typename FloatT,
          typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
                                      !std::is_same<half, FloatT>::value,
                                  int>::type = 0>
void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
  PopulateWithNextAfter<FloatT>(literal);
  std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
               *engine);
}


template <typename FloatT>
void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
                                   bool no_duplicates, bool use_large_range) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<FloatT>());
  if (no_duplicates) {
    PopulateWithNoDuplicateData<FloatT>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
  }
}

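// Complex literals are built from two independently generated literals of the
// underlying real type, one for the real parts and one for the imaginary
// parts, which are then paired element-wise.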
template <typename ComplexT>
void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
                             bool no_duplicates, bool use_large_range) {
  using InnerFloatT = typename ComplexT::value_type;
  CHECK(engine != nullptr);
  CHECK_EQ(result->shape().element_type(),
           primitive_util::NativeToPrimitiveType<ComplexT>());
  Shape floating_point_shape = ShapeUtil::ChangeElementType(
      result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
  Literal real_lit(floating_point_shape);
  Literal imaginary_lit(floating_point_shape);

  PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
                                             use_large_range);
  PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
                                             no_duplicates, use_large_range);

  absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
  absl::Span<const InnerFloatT> imaginary_data =
      imaginary_lit.data<InnerFloatT>();
  absl::Span<ComplexT> result_data = result->data<ComplexT>();
  for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
    result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
  }
}

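// For half and bfloat16, random values are drawn from a float distribution and
// then cast down to the narrower type.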
template <>
void PopulateWithFloatingPointData<half>(Literal* literal,
                                         std::minstd_rand0* engine,
                                         bool no_duplicates,
                                         bool use_large_range) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<half>());
  if (no_duplicates) {
    PopulateWithNoDuplicateData<half>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<half, float>(literal, engine);
  }
}

template <>
void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
                                             std::minstd_rand0* engine,
                                             bool no_duplicates,
                                             bool use_large_range) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<bfloat16>());
  if (no_duplicates) {
    PopulateWithNoDuplicateData<bfloat16>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
  }
}


// uniform_int_distribution is not defined for 8-bit integers.
// Use 16-bit integer types for them instead.
template <typename IntT>
struct RngT {
  using type = IntT;
};

template <>
struct RngT<int8> {
  using type = int16;
};

template <>
struct RngT<uint8> {
  using type = uint16;
};

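// Fills the literal with random integers. When no_duplicates is set and the
// element count is below the type's maximum value, a shuffled iota sequence
// starting at zero is used so that every element is distinct.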
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
                                    bool no_duplicates) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<IntT>());
  if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
                           std::numeric_limits<IntT>::max()) {
    std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
    std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
                 *engine);
  } else {
    std::uniform_int_distribution<typename RngT<IntT>::type> generator(
        std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
    for (IntT& value : literal->data<IntT>()) {
      value = generator(*engine);
    }
  }
}

// Similar to MakeFakeLiteral but takes a random number generator engine to
// enable reusing the engine across randomly generated literals.
// 'no_duplicates' indicates that there should be no duplicate values in each
// generated array. This uniqueness is best-effort only: some types (half and
// bfloat16) are not supported, and uniqueness cannot be guaranteed if the
// number of elements exceeds the number of distinct values representable by
// the type.
StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
                                          std::minstd_rand0* engine,
                                          bool no_duplicates,
                                          bool use_large_range) {
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    for (const Shape& element_shape : shape.tuple_shapes()) {
      TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
                                               element_shape, engine,
                                               no_duplicates, use_large_range));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
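  // Without a random engine, just return a literal created directly from the
  // shape, with no random data.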
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in the shape's layout before using it to create
  // the literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);
  switch (shape.element_type()) {
    case BF16:
      PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
                                              use_large_range);
      break;
    case F16:
      PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
                                          use_large_range);
      break;
    case F32:
      PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
                                           use_large_range);
      break;
    case F64:
      PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
                                            use_large_range);
      break;
    case S8:
      PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
      break;
    case U8:
      PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
      break;
    case S16:
      PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
      break;
    case U16:
      PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
      break;
    case S32:
      PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
      break;
    case U32:
      PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
      break;
    case S64:
      PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
      break;
    case U64:
      PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
      break;
    case C64:
      PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
                                         use_large_range);
      break;
    case C128:
      PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
                                          use_large_range);
      break;
    case PRED: {
      std::uniform_int_distribution<int> generator(0, 1);
      TF_CHECK_OK(
          literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
            return generator(*engine);
          }));
      break;
    }
    default:
      return Unimplemented("Unsupported type for fake literal generation: %s",
                           ShapeUtil::HumanString(shape));
  }
  return std::move(literal);
}

template <typename IntT>
void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
                                              std::minstd_rand0* engine,
                                              IntT min, IntT max) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<IntT>());
  std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
  for (IntT& value : literal->data<IntT>()) {
    value = generator(*engine);
  }
}

// Same as MakeFakeLiteralInternal but generates random numbers in the given
// range [min, max]. Currently this works only for INT types.
StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
                                                    std::minstd_rand0* engine,
                                                    int64 min, int64 max,
                                                    bool is_sorted) {
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    for (const Shape& element_shape : shape.tuple_shapes()) {
      TF_ASSIGN_OR_RETURN(Literal element,
                          MakeFakeLiteralInternalWithBounds(
                              element_shape, engine, min, max, is_sorted));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in the shape's layout before using it to create
  // the literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);
  switch (shape.element_type()) {
    case S8:
      PopulateWithRandomIntegralDataWithBounds<int8>(
          &literal, engine, static_cast<int8>(min), static_cast<int8>(max));
      if (is_sorted) {
        std::sort(literal.data<int8>().begin(), literal.data<int8>().end());
      }
      break;
    case U8:
      PopulateWithRandomIntegralDataWithBounds<uint8>(
          &literal, engine, static_cast<uint8>(min), static_cast<uint8>(max));
      if (is_sorted) {
        std::sort(literal.data<uint8>().begin(), literal.data<uint8>().end());
      }
      break;
    case S16:
      PopulateWithRandomIntegralDataWithBounds<int16>(
          &literal, engine, static_cast<int16>(min), static_cast<int16>(max));
      if (is_sorted) {
        std::sort(literal.data<int16>().begin(), literal.data<int16>().end());
      }
      break;
    case U16:
      PopulateWithRandomIntegralDataWithBounds<uint16>(
          &literal, engine, static_cast<uint16>(min), static_cast<uint16>(max));
      if (is_sorted) {
        std::sort(literal.data<uint16>().begin(), literal.data<uint16>().end());
      }
      break;
    case S32:
      PopulateWithRandomIntegralDataWithBounds<int32>(
          &literal, engine, static_cast<int32>(min), static_cast<int32>(max));
      if (is_sorted) {
        std::sort(literal.data<int32>().begin(), literal.data<int32>().end());
      }
      break;
    case U32:
      PopulateWithRandomIntegralDataWithBounds<uint32>(
          &literal, engine, static_cast<uint32>(min), static_cast<uint32>(max));
      if (is_sorted) {
        std::sort(literal.data<uint32>().begin(), literal.data<uint32>().end());
      }
      break;
    case S64:
      PopulateWithRandomIntegralDataWithBounds<int64>(
          &literal, engine, static_cast<int64>(min), static_cast<int64>(max));
      if (is_sorted) {
        std::sort(literal.data<int64>().begin(), literal.data<int64>().end());
      }
      break;
    case U64:
      PopulateWithRandomIntegralDataWithBounds<uint64>(
          &literal, engine, static_cast<uint64>(min), static_cast<uint64>(max));
      if (is_sorted) {
        std::sort(literal.data<uint64>().begin(), literal.data<uint64>().end());
      }
      break;
    default:
      return Unimplemented(
          "Unsupported type for fake random literal generation with bounds: %s",
          ShapeUtil::HumanString(shape));
  }
  return std::move(literal);
}

enum class ConstantType { kUnknown, kZero, kOne };

// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
  // TODO(b/77635120): Add init values, for min, max, and their arg variants.
  const HloInstruction* const root = computation.root_instruction();
  if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
      root->operand(0)->opcode() != HloOpcode::kParameter ||
      root->operand(1)->opcode() != HloOpcode::kParameter ||
      root->operand(0) == root->operand(1)) {
    return ConstantType::kUnknown;
  }

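  // The init value should be the identity of the reduce computation: zero for
  // addition, one for multiplication. Other computations are not recognized.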
  switch (root->opcode()) {
    case HloOpcode::kAdd:
      return ConstantType::kZero;
    case HloOpcode::kMultiply:
      return ConstantType::kOne;
    default:
      return ConstantType::kUnknown;
  }
}

// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
// initialization value.
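// For variadic kReduce, the init values occupy the second half of the operands
// (operand numbers [N/2, N)); kReduceWindow's init value is operand 1 and
// kSelectAndScatter's is operand 2.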
bool NeedsInitValue(const HloUse& use) {
  const HloInstruction* const instruction = use.instruction;
  const HloOpcode opcode = instruction->opcode();
  const int64 op_num = use.operand_number;
  return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
          (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
          (opcode == HloOpcode::kReduce &&
           op_num >= instruction->operand_count() / 2));
}

// Generates a random index bounded by index_bound (typically an input shape
// dimension minus the corresponding output shape dimension) so as not to
// produce wrapping slices, for instance.
Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) {
  std::uniform_int_distribution<int32> generator(0, index_bound);
  return LiteralUtil::CreateR0<int32>(generator(*engine));
}

// Use dataflow analysis on each parameter to see if there are uses that would
// be problematic when generating input data. Returns the list of instructions
// that correspond to those uses.
//
// Should be paired with the CreateLiteralForConstrainedUses() function below.
std::vector<HloInstruction*> FindConstrainedUses(
    const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
  std::vector<HloInstruction*> constrained_uses;
  for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
    const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
    for (const HloUse& use : value.uses()) {
      HloInstruction* instruction = use.instruction;
      const HloOpcode opcode = instruction->opcode();
      const int64 op_num = use.operand_number;
      if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
          (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
        constrained_uses.push_back(instruction);
      } else if ((opcode == HloOpcode::kGather ||
                  opcode == HloOpcode::kScatter) &&
                 op_num == 1) {
        constrained_uses.push_back(instruction);
      } else if (opcode == HloOpcode::kFusion) {
        const HloInstruction* const to_analyze =
            instruction->fused_parameter(op_num);
        auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
        constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
                                fused_uses.end());
      } else if (NeedsInitValue(use)) {
        constrained_uses.push_back(instruction);
      } else if (opcode == HloOpcode::kConvert ||
                 opcode == HloOpcode::kReducePrecision) {
        auto converted_uses = FindConstrainedUses(dataflow, *instruction);
        constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
                                converted_uses.end());
      } else if (opcode == HloOpcode::kSort &&
                 instruction->operand_count() >= 2 && op_num == 0) {
        // Operand 0 of sort is the array of keys used for key/value
        // (two-operand) kSort instructions. Since sort stability is not
        // guaranteed, constrain the keys of a key-value sort to have no
        // duplicates; otherwise the value order may legitimately differ.
        constrained_uses.push_back(instruction);
      }
    }
  }
  return constrained_uses;
}

// Given a parameter, generate a random Literal to use as input if there exist
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
StatusOr<Literal> CreateLiteralForConstrainedUses(
    const absl::Span<HloInstruction* const> constrained_uses,
    const HloInstruction& param, const Shape& param_shape,
    std::minstd_rand0* engine, bool use_large_range) {
  int64 index_bound = INT64_MAX;
  bool no_duplicates = false;
  bool needs_constant = false;
  bool needs_sorted_indices = false;
  ConstantType constant_type = ConstantType::kUnknown;
  for (HloInstruction* use : constrained_uses) {
    switch (use->opcode()) {
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kDynamicUpdateSlice: {
        const Shape& indexed_shape = use->operand(0)->shape();
        const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
                                       ? use->shape()
                                       : use->operand(1)->shape();
        const int64 first_index =
            Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
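        // The largest in-bounds start index along each dimension is the
        // operand extent minus the slice extent, so tighten index_bound to the
        // smallest such value over the dimensions this parameter indexes.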
        for (int64 operand = first_index; operand < use->operand_count();
             ++operand) {
          if (use->operand(operand) == &param) {
            index_bound = std::min(
                index_bound,
                ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
                    ShapeUtil::GetDimension(slice_shape,
                                            operand - first_index));
          }
        }
        break;
      }
      case HloOpcode::kGather:
      case HloOpcode::kScatter: {
        const Shape& operand_shape = use->operand(0)->shape();
        if (use->operand(1) == &param) {
          auto index_map =
              use->opcode() == HloOpcode::kGather
                  ? use->gather_dimension_numbers().start_index_map()
                  : use->scatter_dimension_numbers()
                        .scatter_dims_to_operand_dims();
          for (const auto dim_in_operand : index_map) {
            index_bound =
                std::min(index_bound, operand_shape.dimensions(dim_in_operand));
          }
        }
        if (use->opcode() == HloOpcode::kScatter) {
          needs_sorted_indices |=
              Cast<const HloScatterInstruction>(use)->indices_are_sorted();
        } else {
          needs_sorted_indices |=
              Cast<const HloGatherInstruction>(use)->indices_are_sorted();
        }
        break;
      }
      case HloOpcode::kReduce:
      case HloOpcode::kReduceWindow:
        needs_constant = true;
        constant_type = GetInitValue(*use->to_apply());
        break;

      case HloOpcode::kSelectAndScatter:
        needs_constant = true;
        constant_type = GetInitValue(*use->scatter());
        break;

      case HloOpcode::kSort:
        no_duplicates = true;
        break;

      default:
        return Unimplemented(
            "Constrained operand generation not implemented for %s.",
            use->ToString());
    }
  }
  int constraint_count = 0;
  constraint_count += no_duplicates ? 1 : 0;
  constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
  constraint_count += needs_constant ? 1 : 0;
  if (constraint_count > 1) {
    return Unimplemented("Conflicting operand generation constraints.");
  }
  if (index_bound != INT64_MAX) {
    return MakeFakeLiteralInternalWithBounds(param_shape, engine, -1,
                                             index_bound, needs_sorted_indices);
  } else if (needs_constant) {
    switch (constant_type) {
      case ConstantType::kZero:
        return LiteralUtil::Zero(param_shape.element_type());
      case ConstantType::kOne:
        return LiteralUtil::One(param_shape.element_type());
      case ConstantType::kUnknown:
        // We want the identity element for the computation, but we don't
        // really know what it is - so any value we generate will be just as
        // wrong.
        return MakeFakeLiteralInternal(param_shape, engine,
                                       /*no_duplicates=*/false,
                                       use_large_range);
    }
  } else {
    return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
                                   use_large_range);
  }
}

// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
                                          const HloInstruction& param,
                                          const Shape& param_shape,
                                          std::minstd_rand0* engine,
                                          bool use_large_range) {
  const auto constrained_uses = FindConstrainedUses(dataflow, param);
  return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
                                         engine, use_large_range);
}

}  // namespace

StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
                                  bool use_large_range) {
  auto engine =
      pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
  return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
                                 use_large_range);
}

StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
                                                 bool pseudo_random,
                                                 bool use_large_range) {
  auto engine =
      pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
  return MakeFakeArguments(module, engine.get(), use_large_range);
}

StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
                                                 std::minstd_rand0* engine,
                                                 bool use_large_range) {
  TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
  const auto params = module->entry_computation()->parameter_instructions();
  std::vector<Literal> arguments(params.size());
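  // For each parameter, prefer the static shape recorded in the entry
  // computation layout when one is available; otherwise fall back to the
  // parameter instruction's own shape.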
  for (int i = 0; i < params.size(); ++i) {
    const HloModuleConfig& module_config = module->config();
    const Shape& param_shape = (module_config.has_entry_computation_layout() &&
                                module_config.entry_computation_layout()
                                    .parameter_layout(i)
                                    .shape()
                                    .is_static())
                                   ? module_config.entry_computation_layout()
                                         .parameter_layout(i)
                                         .shape()
                                   : params[i]->shape();

    arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], param_shape,
                                           engine, use_large_range)
                       .ValueOrDie();
  }
  return std::move(arguments);
}

Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
                       bool allow_mixed_precision) {
  return HloVerifier(/*layout_sensitive=*/layout_sensitive,
                     /*allow_mixed_precision=*/allow_mixed_precision)
      .Run(module)
      .status();
}

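// Builds a dot instruction that contracts the last dimension of lhs with the
// first dimension of rhs (an ordinary matrix or matrix-vector product), using
// default operand precision.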
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs) {
  CHECK_LE(lhs->shape().rank(), 2);
  CHECK_LE(rhs->shape().rank(), 2);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(
      lhs->shape().rank() > 1 ? 1 : 0);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  return absl::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dot_dimension_numbers, precision_config);
}

}  // namespace xla