1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tests/test_utils.h"
17
18 #include <cmath>
19
20 #include "absl/base/casts.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
28 #include "tensorflow/compiler/xla/service/transfer_manager.h"
29
30 namespace xla {
31
32 namespace {
33
34 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)35 void PopulateWithRandomFloatingPointData(Literal* literal,
36 std::minstd_rand0* engine) {
37 std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
38 for (FloatT& value : literal->data<FloatT>()) {
39 value = static_cast<FloatT>(generator(*engine));
40 }
41 }
42
43 // Populates a floating point literal with random floating points sampled from a
44 // uniform-log distribution spanning approximately the entire range of the
45 // representable floating point.
template <typename FloatT>
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
                                                  std::minstd_rand0* engine) {
  // With a tiny probability, emit one of a few hand-picked special values
  // (signed zeros, +/-1, +/-inf) instead of a random sample.
  constexpr float kSpecialValueProbability = 1e-6;
  constexpr float kSpecialValues[] = {+0.F,
                                      -0.F,
                                      1.F,
                                      -1.F,
                                      std::numeric_limits<float>::infinity(),
                                      -std::numeric_limits<float>::infinity()};
  constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
  std::uniform_real_distribution<float> special_value_gen(0, 1);

  // Generates floating points with a log-uniform distribution. This causes the
  // exponent of the floating point to have a uniform distribution.
  int min_exp, max_exp;
  if (std::is_same<FloatT, bfloat16>()) {
    // bfloat16 shares float's 8-bit exponent, so float's exponent range is
    // used — presumably because std::numeric_limits is not usable for
    // bfloat16 here; TODO confirm.
    min_exp = std::numeric_limits<float>::min_exponent;
    max_exp = std::numeric_limits<float>::max_exponent;
  } else {
    min_exp = std::numeric_limits<FloatT>::min_exponent;
    max_exp = std::numeric_limits<FloatT>::max_exponent;
  }
  // Sampling the exponent uniformly and exponentiating below makes the
  // resulting magnitude log-uniform.
  std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);

  for (FloatT& value : literal->data<FloatT>()) {
    // Each special value has a kSpecialValueProbability chance to be generated
    // instead of sampling using the normal distributions.
    if (special_value_gen(*engine) <
        kSpecialValueProbability * kNumSpecialValues) {
      value =
          static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
    } else {
      // Random sign, log-uniform magnitude.
      float sign = ((*engine)() % 2 == 0) ? 1 : -1;
      value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
    }
  }
}
84
85 template <typename FloatT>
86 void PopulateWithIntNext(Literal* literal);
87
template <>
void PopulateWithIntNext<half>(Literal* literal) {
  // Fills the literal with consecutive bit patterns so values are distinct.
  // Duplicates may be generated if we don't have enough bits.
  uint16 next_value = 0;
  for (half& value : literal->data<half>()) {
    // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
    // the sign bit. We could be less wasteful, but this is best-effort anyway.
    // Bit 14 is the exponent MSB in the 1-5-10 half layout; 0xBFFF clears it
    // and `<< 1` relocates it to bit 15 (the sign bit).
    uint16 exponent_msb = next_value & 0x4000;
    value = Eigen::numext::bit_cast<half, uint16>((next_value & 0xBFFF) |
                                                  (exponent_msb << 1));
    next_value++;
  }
}
101
template <>
void PopulateWithIntNext<bfloat16>(Literal* literal) {
  // Fills the literal with consecutive bit patterns so values are distinct.
  // Duplicates may be generated if we don't have enough bits.
  // Start at 0x80 rather than 0 to avoid denormals.
  uint16 next_value = 0x80;
  for (bfloat16& value : literal->data<bfloat16>()) {
    // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
    // the sign bit. We could be less wasteful, but this is best-effort anyway.
    // Bit 14 is the exponent MSB in the 1-8-7 bfloat16 layout; 0xBFFF clears
    // it and `<< 1` relocates it to bit 15 (the sign bit).
    uint16 exponent_msb = next_value & 0x4000;
    value = Eigen::numext::bit_cast<bfloat16, uint16>((next_value & 0xBFFF) |
                                                      (exponent_msb << 1));
    next_value++;
  }
}
116
117 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)118 void PopulateWithNextAfter(Literal* literal) {
119 // Duplicates may be generated if the number of elements in the literal
120 // exceeds the number of positive values supported by the type.
121 float next_value = std::numeric_limits<float>::min();
122 for (float& value : literal->data<float>()) {
123 value = next_value;
124 next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
125 }
126 }
127
128 template <typename FloatT,
129 typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
130 std::is_same<half, FloatT>::value,
131 int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)132 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
133 PopulateWithIntNext<FloatT>(literal);
134 std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
135 *engine);
136 }
137
138 template <typename FloatT,
139 typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
140 !std::is_same<half, FloatT>::value,
141 int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)142 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
143 PopulateWithNextAfter<FloatT>(literal);
144 std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
145 *engine);
146 }
147
148 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)149 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
150 bool no_duplicates, bool use_large_range) {
151 CHECK(engine != nullptr);
152 CHECK_EQ(literal->shape().element_type(),
153 primitive_util::NativeToPrimitiveType<FloatT>());
154 if (no_duplicates) {
155 PopulateWithNoDuplicateData<FloatT>(literal, engine);
156 } else if (use_large_range) {
157 PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
158 } else {
159 PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
160 }
161 }
162
163 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)164 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
165 bool no_duplicates, bool use_large_range) {
166 using InnerFloatT = typename ComplexT::value_type;
167 CHECK(engine != nullptr);
168 CHECK_EQ(result->shape().element_type(),
169 primitive_util::NativeToPrimitiveType<ComplexT>());
170 Shape floating_point_shape = ShapeUtil::ChangeElementType(
171 result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
172 Literal real_lit(floating_point_shape);
173 Literal imaginary_lit(floating_point_shape);
174
175 PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
176 use_large_range);
177 PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
178 no_duplicates, use_large_range);
179
180 absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
181 absl::Span<const InnerFloatT> imaginary_data =
182 imaginary_lit.data<InnerFloatT>();
183 absl::Span<ComplexT> result_data = result->data<ComplexT>();
184 for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
185 result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
186 }
187 }
188
189 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)190 void PopulateWithFloatingPointData<half>(Literal* literal,
191 std::minstd_rand0* engine,
192 bool no_duplicates,
193 bool use_large_range) {
194 CHECK(engine != nullptr);
195 CHECK_EQ(literal->shape().element_type(),
196 primitive_util::NativeToPrimitiveType<half>());
197 if (no_duplicates) {
198 PopulateWithNoDuplicateData<half>(literal, engine);
199 } else if (use_large_range) {
200 PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
201 } else {
202 PopulateWithRandomFloatingPointData<half, float>(literal, engine);
203 }
204 }
205
206 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)207 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
208 std::minstd_rand0* engine,
209 bool no_duplicates,
210 bool use_large_range) {
211 CHECK(engine != nullptr);
212 CHECK_EQ(literal->shape().element_type(),
213 primitive_util::NativeToPrimitiveType<bfloat16>());
214 if (no_duplicates) {
215 PopulateWithNoDuplicateData<bfloat16>(literal, engine);
216 } else if (use_large_range) {
217 PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
218 } else {
219 PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
220 }
221 }
222
223 // uniform_int_distribution is not defined for 8-bit integers.
224 // Use 'short' for those types.
template <typename IntT>
struct RngT {
  // Default: sample directly in the requested integer type.
  using type = IntT;
};
229
template <>
struct RngT<int8> {
  // uniform_int_distribution does not support 8-bit types; widen to int16.
  using type = int16;
};
234
template <>
struct RngT<uint8> {
  // uniform_int_distribution does not support 8-bit types; widen to uint16.
  using type = uint16;
};
239
240 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)241 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
242 bool no_duplicates) {
243 CHECK(engine != nullptr);
244 CHECK_EQ(literal->shape().element_type(),
245 primitive_util::NativeToPrimitiveType<IntT>());
246 if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
247 std::numeric_limits<IntT>::max()) {
248 std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
249 std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
250 *engine);
251 } else {
252 std::uniform_int_distribution<typename RngT<IntT>::type> generator(
253 std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
254 for (IntT& value : literal->data<IntT>()) {
255 value = generator(*engine);
256 }
257 }
258 }
259
260 // Similar to MakeFakeLiteral but takes a random number generator engine to
261 // enable reusing the engine across randomly generated literals. 'no_duplicates'
262 // indicates that there should be no duplicate values in each generated
263 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
264 // are not supported and uniqueness cannot be guaranteed if the number of
265 // elements exceeds the number of different values supported by the type.
StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
                                          std::minstd_rand0* engine,
                                          bool no_duplicates,
                                          bool use_large_range) {
  // Tuples: recursively generate a literal for each element shape.
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    for (const Shape& element_shape : shape.tuple_shapes()) {
      TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
                                               element_shape, engine,
                                               no_duplicates, use_large_range));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
  // No engine: skip randomization and return the default literal for the
  // shape.
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in shape's layout before using it for creating
  // literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);
  // Dispatch on element type; each helper CHECKs that the literal's element
  // type matches its template parameter.
  switch (shape.element_type()) {
    case BF16:
      PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
                                              use_large_range);
      break;
    case F16:
      PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
                                          use_large_range);
      break;
    case F32:
      PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
                                           use_large_range);
      break;
    case F64:
      PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
                                            use_large_range);
      break;
    case S8:
      PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
      break;
    case U8:
      PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
      break;
    case S16:
      PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
      break;
    case U16:
      PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
      break;
    case S32:
      PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
      break;
    case U32:
      PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
      break;
    case S64:
      PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
      break;
    case U64:
      PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
      break;
    case C64:
      PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
                                         use_large_range);
      break;
    case C128:
      PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
                                          use_large_range);
      break;
    case PRED: {
      // Booleans: an unbiased 0/1 draw per element.
      std::uniform_int_distribution<int> generator(0, 1);
      TF_CHECK_OK(
          literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
            return generator(*engine);
          }));
      break;
    }
    default:
      return Unimplemented("Unsupported type for fake literal generation: %s",
                           ShapeUtil::HumanString(shape));
  }
  return std::move(literal);
}
352
353 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)354 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
355 std::minstd_rand0* engine,
356 IntT min, IntT max) {
357 CHECK(engine != nullptr);
358 CHECK_EQ(literal->shape().element_type(),
359 primitive_util::NativeToPrimitiveType<IntT>());
360 std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
361 for (IntT& value : literal->data<IntT>()) {
362 value = generator(*engine);
363 }
364 }
365
366 // Same as MakeFakeLiteralInternal but generates random numbers in the given
367 // range [min, max]. Currently this works only for INT types.
StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
                                                    std::minstd_rand0* engine,
                                                    int64_t min, int64_t max,
                                                    bool is_sorted) {
  // Tuples: recurse into each element shape with the same bounds.
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    for (const Shape& element_shape : shape.tuple_shapes()) {
      TF_ASSIGN_OR_RETURN(Literal element,
                          MakeFakeLiteralInternalWithBounds(
                              element_shape, engine, min, max, is_sorted));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
  // No engine: skip randomization and return the default literal for the
  // shape.
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in shape's layout before using it for creating
  // literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);
  // Every supported case narrows the bounds into the element type, fills the
  // literal with uniform samples, and sorts in place when requested.
  switch (shape.element_type()) {
    case S8:
      PopulateWithRandomIntegralDataWithBounds<int8>(
          &literal, engine, static_cast<int8>(min), static_cast<int8>(max));
      if (is_sorted) {
        std::sort(literal.data<int8>().begin(), literal.data<int8>().end());
      }
      break;
    case U8:
      PopulateWithRandomIntegralDataWithBounds<uint8>(
          &literal, engine, static_cast<uint8>(min), static_cast<uint8>(max));
      if (is_sorted) {
        std::sort(literal.data<uint8>().begin(), literal.data<uint8>().end());
      }
      break;
    case S16:
      PopulateWithRandomIntegralDataWithBounds<int16>(
          &literal, engine, static_cast<int16>(min), static_cast<int16>(max));
      if (is_sorted) {
        std::sort(literal.data<int16>().begin(), literal.data<int16>().end());
      }
      break;
    case U16:
      PopulateWithRandomIntegralDataWithBounds<uint16>(
          &literal, engine, static_cast<uint16>(min), static_cast<uint16>(max));
      if (is_sorted) {
        std::sort(literal.data<uint16>().begin(), literal.data<uint16>().end());
      }
      break;
    case S32:
      PopulateWithRandomIntegralDataWithBounds<int32>(
          &literal, engine, static_cast<int32>(min), static_cast<int32>(max));
      if (is_sorted) {
        std::sort(literal.data<int32>().begin(), literal.data<int32>().end());
      }
      break;
    case U32:
      PopulateWithRandomIntegralDataWithBounds<uint32>(
          &literal, engine, static_cast<uint32>(min), static_cast<uint32>(max));
      if (is_sorted) {
        std::sort(literal.data<uint32>().begin(), literal.data<uint32>().end());
      }
      break;
    case S64:
      PopulateWithRandomIntegralDataWithBounds<int64>(
          &literal, engine, static_cast<int64>(min), static_cast<int64>(max));
      if (is_sorted) {
        std::sort(literal.data<int64>().begin(), literal.data<int64>().end());
      }
      break;
    case U64:
      PopulateWithRandomIntegralDataWithBounds<uint64>(
          &literal, engine, static_cast<uint64>(min), static_cast<uint64>(max));
      if (is_sorted) {
        std::sort(literal.data<uint64>().begin(), literal.data<uint64>().end());
      }
      break;
    default:
      return Unimplemented(
          "Unsupported type for fake random literal generation with bounds: %s",
          ShapeUtil::HumanString(shape));
  }
  return std::move(literal);
}
455
// Categories of identity constants used to initialize reductions.
enum class ConstantType { kUnknown, kZero, kOne };
457
458 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)459 ConstantType GetInitValue(const HloComputation& computation) {
460 // TODO(b/77635120): Add init values, for min, max, and their arg variants.
461 const HloInstruction* const root = computation.root_instruction();
462 if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
463 root->operand(0)->opcode() != HloOpcode::kParameter ||
464 root->operand(1)->opcode() != HloOpcode::kParameter ||
465 root->operand(0) == root->operand(1)) {
466 return ConstantType::kUnknown;
467 }
468
469 switch (root->opcode()) {
470 case HloOpcode::kAdd:
471 return ConstantType::kZero;
472 case HloOpcode::kMultiply:
473 return ConstantType::kOne;
474 default:
475 return ConstantType::kUnknown;
476 }
477 }
478
479 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
480 // initialization value.
NeedsInitValue(const HloUse & use)481 bool NeedsInitValue(const HloUse& use) {
482 const HloInstruction* const instruction = use.instruction;
483 const HloOpcode opcode = instruction->opcode();
484 const int64_t op_num = use.operand_number;
485 return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
486 (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
487 (opcode == HloOpcode::kReduce &&
488 op_num >= instruction->operand_count() / 2));
489 }
490
491 // Generate random values that are constrained to the input_shape minus the
492 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64_t index_bound,std::minstd_rand0 * engine)493 Literal MakeRandomIndex(int64_t index_bound, std::minstd_rand0* engine) {
494 std::uniform_int_distribution<int32> generator(0, index_bound);
495 return LiteralUtil::CreateR0<int32>(generator(*engine));
496 }
497
498 // Use dataflow analysis on each parameter to see if there are uses that would
499 // be problematic when generating input data. Returns the list of instructions
500 // that correspond to their uses.
501 //
502 // Should be paired with the CreateLiteralForConstrainedUses() function below.
std::vector<HloInstruction*> FindConstrainedUses(
    const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
  std::vector<HloInstruction*> constrained_uses;
  // Walk every value the parameter produces and inspect each of its uses.
  for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
    const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
    for (const HloUse& use : value.uses()) {
      HloInstruction* instruction = use.instruction;
      const HloOpcode opcode = instruction->opcode();
      const int64_t op_num = use.operand_number;
      if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
          (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
        // Parameter feeds dynamic-slice start indices, which must be bounded.
        constrained_uses.push_back(instruction);
      } else if ((opcode == HloOpcode::kGather ||
                  opcode == HloOpcode::kScatter) &&
                 op_num == 1) {
        // Parameter is the indices operand of a gather/scatter.
        constrained_uses.push_back(instruction);
      } else if (opcode == HloOpcode::kFusion) {
        // Follow the parameter into the fused computation and collect
        // constraints from there.
        const HloInstruction* const to_analyze =
            instruction->fused_parameter(op_num);
        auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
        constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
                                fused_uses.end());
      } else if (NeedsInitValue(use)) {
        // Parameter feeds a reduce/reduce-window/select-and-scatter init
        // value.
        constrained_uses.push_back(instruction);
      } else if (opcode == HloOpcode::kConvert ||
                 opcode == HloOpcode::kReducePrecision) {
        // Constraints on the converted result apply to the source value too.
        auto converted_uses = FindConstrainedUses(dataflow, *instruction);
        constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
                                converted_uses.end());
      } else if (opcode == HloOpcode::kSort &&
                 instruction->operand_count() >= 2 && op_num == 0) {
        // Operand 0 of sort is the array of keys used for key/value
        // (two-operand) kSort instructions. Since sort stability is not
        // guaranteed, constrain keys of key-value sort not to have duplicates,
        // since otherwise the value order may legitimately differ.
        constrained_uses.push_back(instruction);
      }
    }
  }
  return constrained_uses;
}
544
545 // Given a parameter, generate a random Literal to use as input if there exist
546 // no constrained uses in the dataflow graph. If such constraints exist,
547 // generate a constrained literal (either bounded in the case of indices, or
548 // zero in the case of init_values for reductions).
StatusOr<Literal> CreateLiteralForConstrainedUses(
    const absl::Span<HloInstruction* const> constrained_uses,
    const HloInstruction& param, const Shape& param_shape,
    std::minstd_rand0* engine, bool use_large_range) {
  // Accumulated constraints; INT64_MAX means "no index bound seen".
  int64_t index_bound = INT64_MAX;
  bool no_duplicates = false;
  bool needs_constant = false;
  bool needs_sorted_indices = false;
  ConstantType constant_type = ConstantType::kUnknown;
  for (HloInstruction* use : constrained_uses) {
    switch (use->opcode()) {
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kDynamicUpdateSlice: {
        // Bound each start index so the slice cannot wrap: at most
        // (dimension size - slice size) in the corresponding dimension.
        const Shape& indexed_shape = use->operand(0)->shape();
        const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
                                       ? use->shape()
                                       : use->operand(1)->shape();
        const int64_t first_index =
            Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
        for (int64_t operand = first_index; operand < use->operand_count();
             ++operand) {
          if (use->operand(operand) == &param) {
            index_bound = std::min(
                index_bound,
                ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
                    ShapeUtil::GetDimension(slice_shape,
                                            operand - first_index));
          }
        }
        break;
      }
      case HloOpcode::kGather:
      case HloOpcode::kScatter: {
        // Bound indices by the smallest operand dimension they may address.
        const Shape& operand_shape = use->operand(0)->shape();
        if (use->operand(1) == &param) {
          auto index_map =
              use->opcode() == HloOpcode::kGather
                  ? use->gather_dimension_numbers().start_index_map()
                  : use->scatter_dimension_numbers()
                        .scatter_dims_to_operand_dims();
          for (const auto dim_in_operand : index_map) {
            index_bound =
                std::min(index_bound, operand_shape.dimensions(dim_in_operand));
          }
        }
        if (use->opcode() == HloOpcode::kScatter) {
          needs_sorted_indices |=
              Cast<const HloScatterInstruction>(use)->indices_are_sorted();
        } else {
          needs_sorted_indices |=
              Cast<const HloGatherInstruction>(use)->indices_are_sorted();
        }
        break;
      }
      case HloOpcode::kReduce:
      case HloOpcode::kReduceWindow:
        // Parameter feeds an init value; use the computation's identity.
        needs_constant = true;
        constant_type = GetInitValue(*use->to_apply());
        break;

      case HloOpcode::kSelectAndScatter:
        needs_constant = true;
        constant_type = GetInitValue(*use->scatter());
        break;

      case HloOpcode::kSort:
        // Keys of a key/value sort must be unique (see FindConstrainedUses).
        no_duplicates = true;
        break;

      default:
        return Unimplemented(
            "Constrained operand generation not implemented for %s.",
            use->ToString());
    }
  }
  // The three constraint kinds are mutually exclusive here; give up rather
  // than generate data that violates one of them.
  int constraint_count = 0;
  constraint_count += no_duplicates ? 1 : 0;
  constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
  constraint_count += needs_constant ? 1 : 0;
  if (constraint_count > 1) {
    return Unimplemented("Conflicting operand generation constraints.");
  }
  if (index_bound != INT64_MAX) {
    // NOTE(review): -1 is used as the lower bound — presumably to also
    // exercise slightly-out-of-range (clamped) indices; confirm with callers.
    return MakeFakeLiteralInternalWithBounds(param_shape, engine, -1,
                                             index_bound, needs_sorted_indices);
  } else if (needs_constant) {
    switch (constant_type) {
      case ConstantType::kZero:
        return LiteralUtil::Zero(param_shape.element_type());
      case ConstantType::kOne:
        return LiteralUtil::One(param_shape.element_type());
      case ConstantType::kUnknown:
        // We want the identity element for the computation, but we don't really
        // know what it is - so any value we generate will be just as wrong.
        return MakeFakeLiteralInternal(param_shape, engine,
                                       /*no_duplicates=*/false,
                                       use_large_range);
    }
  } else {
    return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
                                   use_large_range);
  }
}
652
653 // Given a module entry parameter, use the dataflow analysis to see if a
654 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)655 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
656 const HloInstruction& param,
657 const Shape& param_shape,
658 std::minstd_rand0* engine,
659 bool use_large_range) {
660 const auto constrained_uses = FindConstrainedUses(dataflow, param);
661 return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
662 engine, use_large_range);
663 }
664
665 } // namespace
666
MakeFakeLiteral(const Shape & shape,bool pseudo_random,bool use_large_range)667 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
668 bool use_large_range) {
669 auto engine =
670 pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
671 return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
672 use_large_range);
673 }
674
MakeFakeArguments(HloModule * const module,bool pseudo_random,bool use_large_range)675 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
676 bool pseudo_random,
677 bool use_large_range) {
678 auto engine =
679 pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
680 return MakeFakeArguments(module, engine.get(), use_large_range);
681 }
682
MakeFakeArguments(HloModule * const module,std::minstd_rand0 * engine,bool use_large_range)683 StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
684 std::minstd_rand0* engine,
685 bool use_large_range) {
686 TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
687 const auto params = module->entry_computation()->parameter_instructions();
688 std::vector<Literal> arguments(params.size());
689 for (int i = 0; i < params.size(); ++i) {
690 const HloModuleConfig& module_config = module->config();
691 const Shape& param_shape = (module_config.has_entry_computation_layout() &&
692 module_config.entry_computation_layout()
693 .parameter_layout(i)
694 .shape()
695 .is_static())
696 ? module_config.entry_computation_layout()
697 .parameter_layout(i)
698 .shape()
699 : params[i]->shape();
700
701 arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], param_shape,
702 engine, use_large_range)
703 .ValueOrDie();
704 }
705 return std::move(arguments);
706 }
707
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)708 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
709 bool allow_mixed_precision) {
710 return HloVerifier(/*layout_sensitive=*/layout_sensitive,
711 /*allow_mixed_precision=*/allow_mixed_precision)
712 .Run(module)
713 .status();
714 }
715
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)716 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
717 HloInstruction* lhs,
718 HloInstruction* rhs) {
719 CHECK_LE(lhs->shape().rank(), 2);
720 CHECK_LE(rhs->shape().rank(), 2);
721 PrecisionConfig precision_config;
722 precision_config.mutable_operand_precision()->Resize(
723 2, PrecisionConfig::DEFAULT);
724 DotDimensionNumbers dot_dimension_numbers;
725 dot_dimension_numbers.add_lhs_contracting_dimensions(
726 lhs->shape().rank() > 1 ? 1 : 0);
727 dot_dimension_numbers.add_rhs_contracting_dimensions(0);
728 return absl::make_unique<HloDotInstruction>(
729 shape, lhs, rhs, dot_dimension_numbers, precision_config);
730 }
731 } // namespace xla
732