/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/test_utils.h"

#include <cmath>
#include <memory>

#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"

namespace xla {
namespace {

template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointData(Literal* literal,
                                         std::minstd_rand0* engine) {
  std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
  for (FloatT& value : literal->data<FloatT>()) {
    value = static_cast<FloatT>(generator(*engine));
  }
}

// Populates a floating point literal with random floating points sampled from
// a uniform-log distribution spanning approximately the entire range of the
// representable floating point type.
template <typename FloatT>
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
                                                  std::minstd_rand0* engine) {
  constexpr float kSpecialValueProbability = 1e-6;
  constexpr float kSpecialValues[] = {+0.F,
                                      -0.F,
                                      1.F,
                                      -1.F,
                                      std::numeric_limits<float>::infinity(),
                                      -std::numeric_limits<float>::infinity()};
  constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
  std::uniform_real_distribution<float> special_value_gen(0, 1);

  // Generates floating points with a log-uniform distribution. This causes the
  // exponent of the floating point to have a uniform distribution.
  int min_exp, max_exp;
  if (std::is_same<FloatT, bfloat16>()) {
    min_exp = std::numeric_limits<float>::min_exponent;
    max_exp = std::numeric_limits<float>::max_exponent;
  } else {
    min_exp = std::numeric_limits<FloatT>::min_exponent;
    max_exp = std::numeric_limits<FloatT>::max_exponent;
  }
  std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);

  for (FloatT& value : literal->data<FloatT>()) {
    // Each special value has a kSpecialValueProbability chance to be generated
    // instead of sampling using the normal distributions.
    if (special_value_gen(*engine) <
        kSpecialValueProbability * kNumSpecialValues) {
      value =
          static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
    } else {
      float sign = ((*engine)() % 2 == 0) ? 1 : -1;
      value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
    }
  }
}
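
// For intuition about the range produced above (numbers for FloatT = float):
// std::numeric_limits<float>::min_exponent is -125 and max_exponent is 128,
// so exponents are drawn uniformly from [-126, 127] and the generated
// magnitudes span roughly [1.2e-38, 1.7e38]. Each of the six special values is
// emitted with probability kSpecialValueProbability, i.e. about 6e-6 of all
// generated elements are special values overall.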

template <typename FloatT>
void PopulateWithIntNext(Literal* literal);

template <>
void PopulateWithIntNext<half>(Literal* literal) {
  // Duplicates may be generated if we don't have enough bits.
  uint16_t next_value = 0;
  for (half& value : literal->data<half>()) {
    // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
    // the sign bit. We could be less wasteful, but this is best-effort anyway.
    uint16_t exponent_msb = next_value & 0x4000;
    value = Eigen::numext::bit_cast<half, uint16_t>((next_value & 0xBFFF) |
                                                    (exponent_msb << 1));
    next_value++;
  }
}

template <>
void PopulateWithIntNext<bfloat16>(Literal* literal) {
  // Duplicates may be generated if we don't have enough bits.
  // Start at 0x80 rather than 0 to avoid denormals.
  uint16_t next_value = 0x80;
  for (bfloat16& value : literal->data<bfloat16>()) {
    // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
    // the sign bit. We could be less wasteful, but this is best-effort anyway.
    uint16_t exponent_msb = next_value & 0x4000;
    value = Eigen::numext::bit_cast<bfloat16, uint16_t>((next_value & 0xBFFF) |
                                                        (exponent_msb << 1));
    next_value++;
  }
}

template <typename FloatT>
void PopulateWithNextAfter(Literal* literal) {
  // Duplicates may be generated if the number of elements in the literal
  // exceeds the number of positive values supported by the type.
  FloatT next_value = std::numeric_limits<FloatT>::min();
  for (FloatT& value : literal->data<FloatT>()) {
    value = next_value;
    next_value =
        std::nextafter(next_value, std::numeric_limits<FloatT>::max());
  }
}

template <typename FloatT,
          typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
                                      std::is_same<half, FloatT>::value,
                                  int>::type = 0>
void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
  PopulateWithIntNext<FloatT>(literal);
  std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
               *engine);
}

template <typename FloatT,
          typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
                                      !std::is_same<half, FloatT>::value,
                                  int>::type = 0>
void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
  PopulateWithNextAfter<FloatT>(literal);
  std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
               *engine);
}

template <typename FloatT>
void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
                                   bool no_duplicates, bool use_large_range) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<FloatT>());
  if (no_duplicates) {
    PopulateWithNoDuplicateData<FloatT>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
  }
}
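
// Dispatch summary for the helper above and its half/bfloat16 specializations
// further below: F32/F64 literals sample with a generator of the same type,
// while F16/BF16 sample with a float-typed generator and narrow the result,
// since std::uniform_real_distribution is only specified for float, double,
// and long double.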

template <typename ComplexT>
void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
                             bool no_duplicates, bool use_large_range) {
  using InnerFloatT = typename ComplexT::value_type;
  CHECK(engine != nullptr);
  CHECK_EQ(result->shape().element_type(),
           primitive_util::NativeToPrimitiveType<ComplexT>());
  Shape floating_point_shape = ShapeUtil::ChangeElementType(
      result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
  Literal real_lit(floating_point_shape);
  Literal imaginary_lit(floating_point_shape);

  PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
                                             use_large_range);
  PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
                                             no_duplicates, use_large_range);

  absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
  absl::Span<const InnerFloatT> imaginary_data =
      imaginary_lit.data<InnerFloatT>();
  absl::Span<ComplexT> result_data = result->data<ComplexT>();
  for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
    result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
  }
}

template <>
void PopulateWithFloatingPointData<half>(Literal* literal,
                                         std::minstd_rand0* engine,
                                         bool no_duplicates,
                                         bool use_large_range) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<half>());
  if (no_duplicates) {
    PopulateWithNoDuplicateData<half>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<half, float>(literal, engine);
  }
}

template <>
void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
                                             std::minstd_rand0* engine,
                                             bool no_duplicates,
                                             bool use_large_range) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<bfloat16>());
  if (no_duplicates) {
    PopulateWithNoDuplicateData<bfloat16>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
  }
}

// uniform_int_distribution is not defined for 8-bit integers.
// Use 'short' for those types.
template <typename IntT>
struct RngT {
  using type = IntT;
};

template <>
struct RngT<int8_t> {
  using type = int16_t;
};

template <>
struct RngT<uint8_t> {
  using type = uint16_t;
};

template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
                                    bool no_duplicates) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<IntT>());
  if (no_duplicates &&
      ShapeUtil::ElementsIn(literal->shape()) <
          std::numeric_limits<IntT>::max()) {
    std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
    std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
                 *engine);
  } else {
    std::uniform_int_distribution<typename RngT<IntT>::type> generator(
        std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
    for (IntT& value : literal->data<IntT>()) {
      value = generator(*engine);
    }
  }
}
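
// For example, an S8 literal is filled by drawing from a
// std::uniform_int_distribution<int16_t> over [-128, 127] and narrowing on
// assignment. When no_duplicates is set and the literal has fewer elements
// than std::numeric_limits<IntT>::max(), the iota + shuffle path above
// guarantees uniqueness instead.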

// Similar to MakeFakeLiteral but takes a random number generator engine to
// enable reusing the engine across randomly generated literals.
// 'no_duplicates' indicates that there should be no duplicate values in each
// generated array. This uniqueness is best-effort only. Some types (half and
// bfloat16) are not supported and uniqueness cannot be guaranteed if the
// number of elements exceeds the number of different values supported by the
// type.
StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
                                          std::minstd_rand0* engine,
                                          bool no_duplicates,
                                          bool use_large_range) {
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    const auto& shape_tuple_shapes = shape.tuple_shapes();
    elements.reserve(shape_tuple_shapes.size());
    for (const Shape& element_shape : shape_tuple_shapes) {
      TF_ASSIGN_OR_RETURN(Literal element,
                          MakeFakeLiteralInternal(element_shape, engine,
                                                  no_duplicates,
                                                  use_large_range));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in shape's layout before using it for creating
  // literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);
  switch (shape.element_type()) {
    case BF16:
      PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
                                              use_large_range);
      break;
    case F16:
      PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
                                          use_large_range);
      break;
    case F32:
      PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
                                           use_large_range);
      break;
    case F64:
      PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
                                            use_large_range);
      break;
    case S8:
      PopulateWithRandomIntegralData<int8_t>(&literal, engine, no_duplicates);
      break;
    case U8:
      PopulateWithRandomIntegralData<uint8_t>(&literal, engine, no_duplicates);
      break;
    case S16:
      PopulateWithRandomIntegralData<int16_t>(&literal, engine, no_duplicates);
      break;
    case U16:
      PopulateWithRandomIntegralData<uint16_t>(&literal, engine, no_duplicates);
      break;
    case S32:
      PopulateWithRandomIntegralData<int32_t>(&literal, engine, no_duplicates);
      break;
    case U32:
      PopulateWithRandomIntegralData<uint32_t>(&literal, engine, no_duplicates);
      break;
    case S64:
      PopulateWithRandomIntegralData<int64_t>(&literal, engine, no_duplicates);
      break;
    case U64:
      PopulateWithRandomIntegralData<uint64_t>(&literal, engine, no_duplicates);
      break;
    case C64:
      PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
                                         use_large_range);
      break;
    case C128:
      PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
                                          use_large_range);
      break;
    case PRED: {
      std::uniform_int_distribution<int> generator(0, 1);
      TF_CHECK_OK(
          literal.Populate<bool>([&](absl::Span<const int64_t> /*indices*/) {
            return generator(*engine);
          }));
      break;
    }
    default:
      return Unimplemented("Unsupported type for fake literal generation: %s",
                           ShapeUtil::HumanString(shape));
  }
  return std::move(literal);
}
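
// Illustrative use (not exercised in this file): filling a 2x3 F32 literal
// with a fixed seed.
//
//   std::minstd_rand0 engine(42);
//   StatusOr<Literal> literal =
//       MakeFakeLiteralInternal(ShapeUtil::MakeShape(F32, {2, 3}), &engine,
//                               /*no_duplicates=*/false,
//                               /*use_large_range=*/false);
//
// Passing engine == nullptr instead returns an all-zero literal via
// Literal::CreateFromShape.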

template <typename IntT>
void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
                                              std::minstd_rand0* engine,
                                              IntT min, IntT max) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<IntT>());
  std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
  for (IntT& value : literal->data<IntT>()) {
    value = generator(*engine);
  }
}

// Same as MakeFakeLiteralInternal but generates random numbers in the given
// range [min, max]. Currently this works only for INT types.
StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
                                                    std::minstd_rand0* engine,
                                                    int64_t min, int64_t max,
                                                    bool is_sorted) {
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    const auto& shape_tuple_shapes = shape.tuple_shapes();
    elements.reserve(shape_tuple_shapes.size());
    for (const Shape& element_shape : shape_tuple_shapes) {
      TF_ASSIGN_OR_RETURN(Literal element,
                          MakeFakeLiteralInternalWithBounds(
                              element_shape, engine, min, max, is_sorted));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in shape's layout before using it for creating
  // literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);
  switch (shape.element_type()) {
    case S8:
      PopulateWithRandomIntegralDataWithBounds<int8_t>(
          &literal, engine, static_cast<int8_t>(min),
          static_cast<int8_t>(max));
      if (is_sorted) {
        std::sort(literal.data<int8_t>().begin(),
                  literal.data<int8_t>().end());
      }
      break;
    case U8:
      PopulateWithRandomIntegralDataWithBounds<uint8_t>(
          &literal, engine, static_cast<uint8_t>(min),
          static_cast<uint8_t>(max));
      if (is_sorted) {
        std::sort(literal.data<uint8_t>().begin(),
                  literal.data<uint8_t>().end());
      }
      break;
    case S16:
      PopulateWithRandomIntegralDataWithBounds<int16_t>(
          &literal, engine, static_cast<int16_t>(min),
          static_cast<int16_t>(max));
      if (is_sorted) {
        std::sort(literal.data<int16_t>().begin(),
                  literal.data<int16_t>().end());
      }
      break;
    case U16:
      PopulateWithRandomIntegralDataWithBounds<uint16_t>(
          &literal, engine, static_cast<uint16_t>(min),
          static_cast<uint16_t>(max));
      if (is_sorted) {
        std::sort(literal.data<uint16_t>().begin(),
                  literal.data<uint16_t>().end());
      }
      break;
    case S32:
      PopulateWithRandomIntegralDataWithBounds<int32_t>(
          &literal, engine, static_cast<int32_t>(min),
          static_cast<int32_t>(max));
      if (is_sorted) {
        std::sort(literal.data<int32_t>().begin(),
                  literal.data<int32_t>().end());
      }
      break;
    case U32:
      PopulateWithRandomIntegralDataWithBounds<uint32_t>(
          &literal, engine, static_cast<uint32_t>(min),
          static_cast<uint32_t>(max));
      if (is_sorted) {
        std::sort(literal.data<uint32_t>().begin(),
                  literal.data<uint32_t>().end());
      }
      break;
    case S64:
      PopulateWithRandomIntegralDataWithBounds<int64_t>(
          &literal, engine, static_cast<int64_t>(min),
          static_cast<int64_t>(max));
      if (is_sorted) {
        std::sort(literal.data<int64_t>().begin(),
                  literal.data<int64_t>().end());
      }
      break;
    case U64:
      PopulateWithRandomIntegralDataWithBounds<uint64_t>(
          &literal, engine, static_cast<uint64_t>(min),
          static_cast<uint64_t>(max));
      if (is_sorted) {
        std::sort(literal.data<uint64_t>().begin(),
                  literal.data<uint64_t>().end());
      }
      break;
    default:
      return Unimplemented(
          "Unsupported type for fake random literal generation with bounds: "
          "%s",
          ShapeUtil::HumanString(shape));
  }
  return std::move(literal);
}

enum class ConstantType { kUnknown, kZero, kOne };

// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
  // TODO(b/77635120): Add init values, for min, max, and their arg variants.
  const HloInstruction* const root = computation.root_instruction();
  if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
      root->operand(0)->opcode() != HloOpcode::kParameter ||
      root->operand(1)->opcode() != HloOpcode::kParameter ||
      root->operand(0) == root->operand(1)) {
    return ConstantType::kUnknown;
  }

  switch (root->opcode()) {
    case HloOpcode::kAdd:
      return ConstantType::kZero;
    case HloOpcode::kMultiply:
      return ConstantType::kOne;
    default:
      return ConstantType::kUnknown;
  }
}

// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
// initialization value.
bool NeedsInitValue(const HloUse& use) {
  const HloInstruction* const instruction = use.instruction;
  const HloOpcode opcode = instruction->opcode();
  const int64_t op_num = use.operand_number;
  return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
          (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
          (opcode == HloOpcode::kReduce &&
           op_num >= instruction->operand_count() / 2));
}
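
// Concrete example for the two helpers above: a reduce whose to_apply
// computation is add(p0, p1) yields ConstantType::kZero, so its init-value
// operand is later fed a zero literal; multiply(p0, p1) yields kOne. Anything
// else (e.g. a max computation) currently falls back to kUnknown and receives
// unconstrained fake data.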

// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
Literal MakeRandomIndex(int64_t index_bound, std::minstd_rand0* engine) {
  std::uniform_int_distribution<int32_t> generator(0, index_bound);
  return LiteralUtil::CreateR0<int32_t>(generator(*engine));
}

// Returns true if `dest' is reachable from `src' through data-formatting and
// custom call instructions within the same computation.
bool ReachableViaDataFormatting(const HloInstruction* src,
                                const HloInstruction* dest) {
  if (src == dest) {
    return true;
  }
  switch (dest->opcode()) {
    case HloOpcode::kReshape:
    case HloOpcode::kTranspose:
    case HloOpcode::kCopy:
    case HloOpcode::kSlice:
      break;
    case HloOpcode::kCustomCall:
      if (dest->custom_call_target() == "AssumeGatherIndicesInBound") {
        break;
      }
      return false;
    default:
      return false;
  }
  for (const auto* operand : dest->operands()) {
    if (ReachableViaDataFormatting(src, operand)) {
      return true;
    }
  }
  return false;
}

// Use dataflow analysis on each parameter to see if there are uses that would
// be problematic when generating input data. Returns the list of instructions
// that correspond to their uses.
//
// Should be paired with the CreateLiteralForConstrainedUses() function below.
std::vector<HloInstruction*> FindConstrainedUses(
    const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
  std::vector<HloInstruction*> constrained_uses;
  for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
    const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
    for (const HloUse& use : value.GetUses()) {
      HloInstruction* instruction = use.instruction;
      const HloOpcode opcode = instruction->opcode();
      const int64_t op_num = use.operand_number;
      if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
          (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
        constrained_uses.push_back(instruction);
      } else if ((opcode == HloOpcode::kGather ||
                  opcode == HloOpcode::kScatter) &&
                 op_num == 1) {
        constrained_uses.push_back(instruction);
      } else if (opcode == HloOpcode::kFusion) {
        const HloInstruction* const to_analyze =
            instruction->fused_parameter(op_num);
        auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
        constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
                                fused_uses.end());
      } else if (NeedsInitValue(use)) {
        constrained_uses.push_back(instruction);
      } else if (opcode == HloOpcode::kConvert ||
                 opcode == HloOpcode::kReducePrecision) {
        auto converted_uses = FindConstrainedUses(dataflow, *instruction);
        constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
                                converted_uses.end());
      } else if (opcode == HloOpcode::kSort &&
                 instruction->operand_count() >= 2 && op_num == 0) {
        // Operand 0 of sort is the array of keys used for key/value
        // (two-operand) kSort instructions. Since sort stability is not
        // guaranteed, constrain keys of key-value sort not to have duplicates,
        // since otherwise the value order may legitimately differ.
        constrained_uses.push_back(instruction);
      }
    }
  }

  for (auto* instruction : param.parent()->instructions()) {
    const HloOpcode opcode = instruction->opcode();
    if (opcode == HloOpcode::kGather || opcode == HloOpcode::kScatter) {
      if (instruction->operand(1) == &param) {
        // Above already covers this case.
        continue;
      }
      if (ReachableViaDataFormatting(&param, instruction->operand(1))) {
        constrained_uses.push_back(instruction);
      }
    }
  }
  return constrained_uses;
}
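
// Example of what the traversal above collects: a parameter that feeds
// operand 1 of a gather (its indices), either directly or only through
// reshapes, transposes, copies, or slices, is associated with that gather so
// its values can later be bounded by the gathered operand's dimensions.
// Likewise, the keys operand of a two-operand sort is recorded so that it can
// be generated without duplicates.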

// Given a parameter, generate a random Literal to use as input if there exist
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
StatusOr<Literal> CreateLiteralForConstrainedUses(
    const absl::Span<HloInstruction* const> constrained_uses,
    const HloInstruction& param, const Shape& param_shape,
    std::minstd_rand0* engine, bool use_large_range) {
  int64_t index_bound = INT64_MAX;
  bool no_duplicates = false;
  bool needs_constant = false;
  bool needs_sorted_indices = false;
  ConstantType constant_type = ConstantType::kUnknown;
  for (HloInstruction* use : constrained_uses) {
    switch (use->opcode()) {
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kDynamicUpdateSlice: {
        const Shape& indexed_shape = use->operand(0)->shape();
        const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
                                       ? use->shape()
                                       : use->operand(1)->shape();
        const int64_t first_index =
            Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
        for (int64_t operand = first_index; operand < use->operand_count();
             ++operand) {
          if (use->operand(operand) == &param) {
            index_bound = std::min(
                index_bound,
                ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
                    ShapeUtil::GetDimension(slice_shape,
                                            operand - first_index));
          }
        }
        break;
      }
      case HloOpcode::kGather:
      case HloOpcode::kScatter: {
        const Shape& operand_shape = use->operand(0)->shape();
        auto index_map = use->opcode() == HloOpcode::kGather
                             ? use->gather_dimension_numbers().start_index_map()
                             : use->scatter_dimension_numbers()
                                   .scatter_dims_to_operand_dims();
        for (const auto dim_in_operand : index_map) {
          index_bound = std::min(
              index_bound, operand_shape.dimensions(dim_in_operand) - 1);
        }
        if (use->opcode() == HloOpcode::kScatter) {
          needs_sorted_indices |=
              Cast<HloScatterInstruction>(use)->indices_are_sorted();
        } else {
          needs_sorted_indices |=
              Cast<HloGatherInstruction>(use)->indices_are_sorted();
        }
        break;
      }
      case HloOpcode::kReduce:
      case HloOpcode::kReduceWindow:
        needs_constant = true;
        constant_type = GetInitValue(*use->to_apply());
        break;

      case HloOpcode::kSelectAndScatter:
        needs_constant = true;
        constant_type = GetInitValue(*use->scatter());
        break;

      case HloOpcode::kSort:
        no_duplicates = true;
        break;

      default:
        return Unimplemented(
            "Constrained operand generation not implemented for %s.",
            use->ToString());
    }
  }
  int constraint_count = 0;
  constraint_count += no_duplicates ? 1 : 0;
  constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
  constraint_count += needs_constant ? 1 : 0;
  if (constraint_count > 1) {
    return Unimplemented("Conflicting operand generation constraints.");
  }
  if (index_bound != INT64_MAX) {
    return MakeFakeLiteralInternalWithBounds(param_shape, engine, 0,
                                             index_bound, needs_sorted_indices);
  } else if (needs_constant) {
    switch (constant_type) {
      case ConstantType::kZero:
        return LiteralUtil::Zero(param_shape.element_type());
      case ConstantType::kOne:
        return LiteralUtil::One(param_shape.element_type());
      case ConstantType::kUnknown:
        // We want the identity element for the computation, but we don't
        // really know what it is - so any value we generate will be just as
        // wrong.
        return MakeFakeLiteralInternal(param_shape, engine,
                                       /*no_duplicates=*/false,
                                       use_large_range);
    }
  } else {
    return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
                                   use_large_range);
  }
}
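
// Worked example for the dynamic-slice case above: slicing an f32[10] operand
// down to f32[4] gives index_bound = 10 - 4 = 6, so the index parameter is
// generated by MakeFakeLiteralInternalWithBounds over the inclusive range
// [0, 6] and the slice never starts out of bounds.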

// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
                                          const HloInstruction& param,
                                          const Shape& param_shape,
                                          std::minstd_rand0* engine,
                                          bool use_large_range) {
  const auto constrained_uses = FindConstrainedUses(dataflow, param);
  return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
                                         engine, use_large_range);
}

}  // namespace

StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
                                  bool use_large_range) {
  auto engine =
      pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
  return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
                                 use_large_range);
}

StatusOr<std::vector<Literal>> MakeFakeArguments(const HloModule* module,
                                                 bool pseudo_random,
                                                 bool use_large_range) {
  auto engine =
      pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
  return MakeFakeArguments(module, engine.get(), use_large_range);
}

StatusOr<std::vector<Literal>> MakeFakeArguments(const HloModule* module,
                                                 std::minstd_rand0* engine,
                                                 bool use_large_range) {
  TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
  const auto params = module->entry_computation()->parameter_instructions();
  std::vector<Literal> arguments(params.size());
  for (int i = 0; i < params.size(); ++i) {
    const HloModuleConfig& module_config = module->config();
    const Shape& param_shape = (module_config.has_entry_computation_layout() &&
                                module_config.entry_computation_layout()
                                    .parameter_layout(i)
                                    .shape()
                                    .is_static())
                                   ? module_config.entry_computation_layout()
                                         .parameter_layout(i)
                                         .shape()
                                   : params[i]->shape();

    TF_ASSIGN_OR_RETURN(arguments[i],
                        MakeConstrainedArgument(*dataflow, *params[i],
                                                param_shape, engine,
                                                use_large_range));
  }
  return std::move(arguments);
}

Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
                       bool allow_mixed_precision) {
  return HloVerifier(/*layout_sensitive=*/layout_sensitive,
                     /*allow_mixed_precision=*/allow_mixed_precision)
      .Run(module)
      .status();
}

std::unique_ptr<HloInstruction> CreateCanonicalDot(const Shape& shape,
                                                   HloInstruction* lhs,
                                                   HloInstruction* rhs) {
  CHECK_LE(lhs->shape().rank(), 2);
  CHECK_LE(rhs->shape().rank(), 2);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(
      lhs->shape().rank() > 1 ? 1 : 0);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  return std::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dot_dimension_numbers, precision_config);
}

}  // namespace xla