
Searched refs:hlo (Results 1 – 25 of 131) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
dfs_hlo_visitor.h
74 virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
75 virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
77 virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
78 virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
79 virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0;
80 virtual Status HandleMaximum(HloInstructionPtr hlo) { in HandleMaximum() argument
81 return HandleElementwiseBinary(hlo); in HandleMaximum()
83 virtual Status HandleMinimum(HloInstructionPtr hlo) { in HandleMinimum() argument
84 return HandleElementwiseBinary(hlo); in HandleMinimum()
86 virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
[all …]
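
The pattern visible above is that narrow handlers such as HandleMaximum and HandleMinimum fall through to HandleElementwiseBinary, so a concrete visitor usually only needs the coarse hooks. A minimal sketch, assuming the DfsHloVisitorWithDefault base from the same directory; the class name and counter are invented for illustration:

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {

// Counts elementwise binary ops.  kMaximum, kMinimum, kAdd, ... all reach
// HandleElementwiseBinary through the per-opcode defaults shown above.
class ElementwiseBinaryCounter : public DfsHloVisitorWithDefault {
 public:
  Status DefaultAction(HloInstruction* /*hlo*/) override {
    return Status::OK();
  }
  Status HandleElementwiseBinary(HloInstruction* /*hlo*/) override {
    ++count_;
    return Status::OK();
  }
  int count() const { return count_; }

 private:
  int count_ = 0;
};

}  // namespace xla
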
bfloat16_normalization.cc
36 Status DefaultAction(HloInstruction* hlo) override;
48 Status HandleInstruction(HloInstruction* hlo);
52 Status HandleMultipleOutputs(HloInstruction* hlo);
55 Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
60 Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
65 Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
72 HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);
80 HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { in InsertConvertAfterOutput() argument
81 bool is_root = computation->root_instruction() == hlo; in InsertConvertAfterOutput()
82 std::vector<HloInstruction*> materialized_users = hlo->users(); in InsertConvertAfterOutput()
[all …]
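
InsertConvertAfterOutput above follows a common HLO rewrite idiom: snapshot the user list, add a kConvert, then repoint the users. A hedged sketch of that idiom with the same API, leaving out the root and tuple handling the real pass performs; the function name is hypothetical:

#include <vector>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Insert a convert-to-`to` after `hlo` and make its existing users read the
// converted value.  Sketch only: no root, tuple, or in-place handling.
Status InsertConvertAfterOutputSketch(HloInstruction* hlo, PrimitiveType to,
                                      HloComputation* computation) {
  std::vector<HloInstruction*> materialized_users = hlo->users();
  Shape converted_shape = ShapeUtil::ChangeElementType(hlo->shape(), to);
  HloInstruction* convert = computation->AddInstruction(
      HloInstruction::CreateConvert(converted_shape, hlo));
  for (HloInstruction* user : materialized_users) {
    TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
  }
  return Status::OK();
}

}  // namespace xla
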
bfloat16_propagation.cc
206 bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, in AllUsersConsumeBF16() argument
209 const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index); in AllUsersConsumeBF16()
214 auto& value_set = dataflow_->GetValueSet(&hlo, index); in AllUsersConsumeBF16()
311 void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, in DetermineInstructionPrecision() argument
318 [this, hlo, &postpone_processing_called_computations] { in DetermineInstructionPrecision()
320 if (hlo->opcode() == HloOpcode::kFusion) { in DetermineInstructionPrecision()
321 DetermineFusionComputationPrecision(hlo); in DetermineInstructionPrecision()
322 } else if (hlo->opcode() == HloOpcode::kWhile) { in DetermineInstructionPrecision()
323 DetermineWhileComputationsPrecision(hlo); in DetermineInstructionPrecision()
326 instructions_visited_in_backward_pass_.insert(hlo); in DetermineInstructionPrecision()
[all …]
dynamic_dimension_inference.cc
33 Status DefaultAction(HloInstruction* hlo) override;
42 Status HandleParameter(HloInstruction* hlo) override;
44 Status HandleReduce(HloInstruction* hlo) override;
46 Status HandleDot(HloInstruction* hlo) override;
48 Status HandleTuple(HloInstruction* hlo) override;
50 Status HandleTranspose(HloInstruction* hlo) override;
52 Status HandleReshape(HloInstruction* hlo) override;
54 Status HandlePad(HloInstruction* hlo) override;
56 Status HandleBroadcast(HloInstruction* hlo) override;
58 Status HandleGetDimensionSize(HloInstruction* hlo) override;
[all …]
bfloat16_conversion_folding.cc
35 Status DefaultAction(HloInstruction* hlo) override;
50 Status TryFoldBF16Conversions(HloInstruction* hlo);
55 Status FoldOutputConversions(HloInstruction* hlo);
60 Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);
68 HloInstruction* hlo) { in FoldOutputConversions() argument
69 std::vector<HloInstruction*> materialized_users = hlo->users(); in FoldOutputConversions()
70 hlo->mutable_shape()->set_element_type(BF16); in FoldOutputConversions()
73 TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); in FoldOutputConversions()
80 HloInstruction* hlo, int64 operand_index) { in FoldOperandConversion() argument
82 auto operand = hlo->mutable_operand(operand_index); in FoldOperandConversion()
[all …]
hlo_element_type_converter.cc
38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { in ToElementType() argument
39 if (hlo->shape().element_type() != type) { in ToElementType()
40 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); in ToElementType()
41 hlo = hlo->parent()->AddInstruction( in ToElementType()
42 HloInstruction::CreateConvert(shape, hlo)); in ToElementType()
44 CHECK_EQ(hlo->shape().element_type(), type); in ToElementType()
45 return hlo; in ToElementType()
48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { in HasOperandType() argument
49 for (HloInstruction* operand : hlo->operands()) { in HasOperandType()
85 HloInstruction* ConvertTupleElements(HloInstruction* hlo, in ConvertTupleElements() argument
[all …]
hlo_module_group_metadata.cc
70 const auto visitor = [this](HloInstruction* hlo) -> Status { in Build() argument
74 const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent()); in Build()
79 if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) { in Build()
81 if (IsChannelInstruction(hlo)) { in Build()
82 peers.push_back(PeerComputation(hlo)); in Build()
83 } else if (hlo->IsCrossModuleAllReduce()) { in Build()
84 for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { in Build()
85 if (instr == hlo) { in Build()
104 tracked_instructions_comms_[tracked->instruction()].push_back(hlo); in Build()
106 } else if (IsCompanionInstruction(hlo)) { in Build()
[all …]
dfs_hlo_visitor_with_default_test.cc
40 Status DefaultAction(HloInstruction* hlo) override { in TEST_F() argument
43 TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2)) in TEST_F()
44 << hlo->ToString(); in TEST_F()
45 TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1)) in TEST_F()
46 << hlo->ToString(); in TEST_F()
50 Status HandleElementwiseBinary(HloInstruction* hlo) override { in TEST_F() argument
52 TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2) in TEST_F()
53 << hlo->ToString(); in TEST_F()
56 Status HandleElementwiseUnary(HloInstruction* hlo) override { in TEST_F() argument
58 TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1) in TEST_F()
[all …]
bfloat16_support.cc
23 bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, in SupportsBF16Operand() argument
25 switch (hlo.opcode()) { in SupportsBF16Operand()
36 return hlo.operand(0)->shape().element_type() == BF16; in SupportsBF16Operand()
43 bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { in SupportsBF16Output()
44 switch (hlo.opcode()) { in SupportsBF16Output()
54 return hlo.shape().element_type() == BF16; in SupportsBF16Output()
61 bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const { in SupportsMixedPrecisions()
62 switch (hlo.opcode()) { in SupportsMixedPrecisions()
79 const HloInstruction& hlo, int64 operand_index) { in EffectiveOperandPrecisionIsOutputPrecision() argument
80 switch (hlo.opcode()) { in EffectiveOperandPrecisionIsOutputPrecision()
[all …]
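
These three queries (SupportsBF16Operand, SupportsBF16Output, SupportsMixedPrecisions) are what the bfloat16 passes consult, and a backend widens BF16 coverage by subclassing. A hedged sketch with a hypothetical class name and an arbitrary opcode choice:

#include "tensorflow/compiler/xla/service/bfloat16_support.h"

namespace xla {

// Permits BF16 adds on top of whatever the base class already allows.
// Purely illustrative; a real backend would enumerate its supported ops.
class SketchBF16Support : public BFloat16Support {
 public:
  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           BFloat16Support::SupportsBF16Operand(hlo, operand_index);
  }
  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           BFloat16Support::SupportsBF16Output(hlo);
  }
};

}  // namespace xla
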
dfs_hlo_visitor_with_default.h
54 Status HandleElementwiseUnary(HloInstructionPtr hlo) override { in HandleElementwiseUnary() argument
55 return DefaultAction(hlo); in HandleElementwiseUnary()
57 Status HandleElementwiseBinary(HloInstructionPtr hlo) override { in HandleElementwiseBinary() argument
58 return DefaultAction(hlo); in HandleElementwiseBinary()
61 Status HandleBatchNormTraining(HloInstructionPtr hlo) override { in HandleBatchNormTraining() argument
62 return DefaultAction(hlo); in HandleBatchNormTraining()
65 Status HandleBatchNormInference(HloInstructionPtr hlo) override { in HandleBatchNormInference() argument
66 return DefaultAction(hlo); in HandleBatchNormInference()
69 Status HandleBatchNormGrad(HloInstructionPtr hlo) override { in HandleBatchNormGrad() argument
70 return DefaultAction(hlo); in HandleBatchNormGrad()
[all …]
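
Since every Handle* here forwards to DefaultAction, a visitor that only overrides DefaultAction (or a coarse hook, as in the counter sketched after the dfs_hlo_visitor.h entry above) can be driven over a whole computation. A hedged fragment, assuming that earlier sketch plus the usual XLA status and logging macros:

namespace xla {

// Fragment sketch: drive a DfsHloVisitorWithDefault subclass (here the
// hypothetical ElementwiseBinaryCounter sketched earlier) over one
// computation.
Status CountElementwiseBinaryOps(HloComputation* computation) {
  ElementwiseBinaryCounter counter;
  TF_RETURN_IF_ERROR(computation->Accept(&counter));
  VLOG(1) << "elementwise binary ops: " << counter.count();
  return Status::OK();
}

}  // namespace xla
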
hlo_execution_profile.cc
94 for (const HloInstruction* hlo : computation->instructions()) { in CreateHloProfilePrinterData() local
97 instruction_info->set_long_name(hlo->ToString()); in CreateHloProfilePrinterData()
98 instruction_info->set_short_name(hlo->ToString( in CreateHloProfilePrinterData()
101 instruction_info->set_category(hlo->ToCategory()); in CreateHloProfilePrinterData()
102 instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); in CreateHloProfilePrinterData()
104 cost_analysis.transcendental_count(*hlo)); in CreateHloProfilePrinterData()
105 instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo)); in CreateHloProfilePrinterData()
107 cost_analysis.optimal_seconds(*hlo)); in CreateHloProfilePrinterData()
109 hlo_profile_index_map.GetProfileIndexFor(*hlo)); in CreateHloProfilePrinterData()
133 void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, in SetCyclesTakenBy() argument
[all …]
hlo_cost_analysis.h
52 Status HandleElementwiseUnary(const HloInstruction* hlo) override;
53 Status HandleElementwiseBinary(const HloInstruction* hlo) override;
58 Status HandleSelect(const HloInstruction* hlo) override;
59 Status HandleTupleSelect(const HloInstruction* hlo) override;
62 Status HandleReducePrecision(const HloInstruction* hlo) override;
74 Status HandleTriangularSolve(const HloInstruction* hlo) override;
75 Status HandleCholesky(const HloInstruction* hlo) override;
77 Status HandleAllToAll(const HloInstruction* hlo) override;
78 Status HandleCollectivePermute(const HloInstruction* hlo) override;
79 Status HandleReplicaId(const HloInstruction* hlo) override;
[all …]
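
The per-instruction counters this header exposes (flop_count, bytes_accessed, optimal_seconds) are filled in by running the analysis as a visitor, which is how hlo_execution_profile.cc above consumes them. A hedged usage sketch; shape_size_fn is an assumed backend-supplied shape-size callback and the function name is invented:

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

namespace xla {

// Sketch: populate HloCostAnalysis over a computation and log a few of the
// counters listed above.
Status DumpCosts(const HloComputation* computation,
                 const HloCostAnalysis::ShapeSizeFunction& shape_size_fn) {
  HloCostAnalysis cost_analysis(shape_size_fn);
  TF_RETURN_IF_ERROR(computation->Accept(&cost_analysis));
  for (const HloInstruction* hlo : computation->instructions()) {
    VLOG(1) << hlo->name() << ": flops=" << cost_analysis.flop_count(*hlo)
            << " bytes=" << cost_analysis.bytes_accessed(*hlo);
  }
  return Status::OK();
}

}  // namespace xla
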
elemental_ir_emitter.cc
1162 const HloInstruction* hlo, llvm::Value* x) { in EmitReducePrecision() argument
1163 if (hlo->operand(0)->shape().element_type() != F32) { in EmitReducePrecision()
1166 return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), in EmitReducePrecision()
1167 /*mantissa_bits=*/hlo->mantissa_bits(), b_); in EmitReducePrecision()
1369 const HloInstruction* hlo, in ConvertValueForDistribution() argument
1373 operand_to_generator.at(hlo->operand(0))(index)); in ConvertValueForDistribution()
1375 operand_to_generator.at(hlo->operand(1))(index)); in ConvertValueForDistribution()
1376 PrimitiveType elem_prim_ty = hlo->shape().element_type(); in ConvertValueForDistribution()
1447 switch (hlo->random_distribution()) { in ConvertValueForDistribution()
1477 RandomDistribution_Name(hlo->random_distribution())); in ConvertValueForDistribution()
[all …]
hlo_module_group_util.cc
91 for (HloInstruction* hlo : instruction_group) { in GlobalPredecessors()
92 for (HloInstruction* operand : hlo->operands()) { in GlobalPredecessors()
95 for (HloInstruction* control_predecessor : hlo->control_predecessors()) { in GlobalPredecessors()
169 for (HloInstruction* hlo : instruction_group) { in GlobalSuccessors()
170 for (HloInstruction* user : hlo->users()) { in GlobalSuccessors()
173 for (HloInstruction* control_successor : hlo->control_successors()) { in GlobalSuccessors()
251 HloInstruction* hlo = stack.top(); in VisitTopologicalOrder() local
259 if (metadata_.IsCompanionInstruction(hlo)) { in VisitTopologicalOrder()
260 for (HloInstruction* companion : metadata_.Companions(hlo)) { in VisitTopologicalOrder()
263 } else if (hlo->IsCrossModuleAllReduce()) { in VisitTopologicalOrder()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
stream_assignment.cc
28 bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { in HasStreamAssigned()
29 return hlo_to_stream_number_.contains(&hlo); in HasStreamAssigned()
32 int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { in StreamNumberForHlo()
33 return FindOrDie(hlo_to_stream_number_, &hlo); in StreamNumberForHlo()
36 void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo, in AssignStreamToHlo() argument
42 InsertOrDie(&hlo_to_stream_number_, hlo, stream_num); in AssignStreamToHlo()
43 VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString(); in AssignStreamToHlo()
66 const HloInstruction& hlo, const StreamAssignment& stream_assignment, in ComputeStreamToAssign() argument
69 if (hlo.opcode() == HloOpcode::kParameter || in ComputeStreamToAssign()
70 hlo.opcode() == HloOpcode::kConstant) { in ComputeStreamToAssign()
[all …]
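
AssignStreamToHlo records a stream number, HasStreamAssigned and StreamNumberForHlo query it, and the lookup dies if nothing was recorded, so the check comes first. A hedged helper sketch; the function name is invented and the stream number is arbitrary:

#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"

namespace xla {
namespace gpu {

// Sketch helper: return the stream for `hlo`, assigning stream 0 first if the
// instruction has no stream yet.  Real callers get a fully populated
// StreamAssignment from the scheduling pipeline.
int EnsureStream(StreamAssignment* assignment, const HloInstruction* hlo) {
  if (!assignment->HasStreamAssigned(*hlo)) {
    assignment->AssignStreamToHlo(hlo, /*stream_num=*/0);
  }
  return assignment->StreamNumberForHlo(*hlo);
}

}  // namespace gpu
}  // namespace xla
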
gpu_hlo_schedule.cc
86 for (const HloInstruction* hlo : thunk_launch_order) { in GpuHloOrdering() local
87 predecessor_map->SetReachable(hlo, hlo); in GpuHloOrdering()
88 if (stream_assignment.HasStreamAssigned(*hlo)) { in GpuHloOrdering()
92 immediate_preds.insert(immediate_preds.end(), hlo->operands().begin(), in GpuHloOrdering()
93 hlo->operands().end()); in GpuHloOrdering()
95 hlo->control_predecessors().begin(), in GpuHloOrdering()
96 hlo->control_predecessors().end()); in GpuHloOrdering()
100 const int stream_no = stream_assignment.StreamNumberForHlo(*hlo); in GpuHloOrdering()
104 predecessor_map->FastSetReachabilityToUnion(immediate_preds, hlo); in GpuHloOrdering()
105 last_instruction_per_stream[stream_no] = hlo; in GpuHloOrdering()
[all …]
ir_emission_utils.cc
86 bool ImplementedAsGemm(const HloInstruction& hlo) { in ImplementedAsGemm() argument
88 if (hlo.opcode() == HloOpcode::kDot) { in ImplementedAsGemm()
89 return DotImplementedAsGemm(hlo); in ImplementedAsGemm()
92 if (hlo.opcode() == HloOpcode::kFusion && in ImplementedAsGemm()
93 hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && in ImplementedAsGemm()
94 (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply || in ImplementedAsGemm()
95 hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) { in ImplementedAsGemm()
97 const HloInstruction* dot = hlo.fused_expression_root()->operand(0); in ImplementedAsGemm()
99 dot = hlo.fused_expression_root()->operand(1); in ImplementedAsGemm()
116 bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { in IsCustomCallToDnnBatchNorm() argument
[all …]
cudnn_batchnorm_thunk.cc
83 const BufferAllocation::Slice& output, const HloInstruction* hlo) in CudnnBatchNormForwardInferenceThunk() argument
84 : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo), in CudnnBatchNormForwardInferenceThunk()
93 CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); in CudnnBatchNormForwardInferenceThunk()
94 CHECK_EQ(hlo->custom_call_target(), in CudnnBatchNormForwardInferenceThunk()
97 LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape())); in CudnnBatchNormForwardInferenceThunk()
98 CHECK_EQ(hlo->shape().element_type(), F32) << "Not yet implemented"; in CudnnBatchNormForwardInferenceThunk()
142 const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) in CudnnBatchNormForwardTrainingThunk() argument
143 : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo), in CudnnBatchNormForwardTrainingThunk()
153 CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); in CudnnBatchNormForwardTrainingThunk()
154 CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); in CudnnBatchNormForwardTrainingThunk()
[all …]
hlo_to_ir_bindings.cc
153 llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, in GetTypedIrValue() argument
157 ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); in GetTypedIrValue()
168 ir_value->setName(llvm_ir::IrName(&hlo, "raw")); in GetTypedIrValue()
171 typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed")); in GetTypedIrValue()
176 void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, in BindHloToIrValue() argument
179 VLOG(2) << "Binding " << hlo.ToString(); in BindHloToIrValue()
181 const Shape& hlo_shape = hlo.shape(); in BindHloToIrValue()
182 llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value); in BindHloToIrValue()
184 if (!BoundToIrValue(hlo)) { in BindHloToIrValue()
186 InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr)); in BindHloToIrValue()
[all …]
ir_emitter_nested.cc
102 for (const auto* hlo : nested_computation.instructions()) { in EmitBasePointersForNestedComputation() local
103 if (hlo->opcode() != HloOpcode::kParameter && in EmitBasePointersForNestedComputation()
104 hlo != nested_computation.root_instruction()) { in EmitBasePointersForNestedComputation()
105 non_io_hlos.push_back(hlo); in EmitBasePointersForNestedComputation()
117 const HloInstruction& hlo, in EmitTargetElementLoop() argument
121 if (hlo.IsMultiOutputFusion()) { in EmitTargetElementLoop()
123 ConstructIrArrayForOutputs(hlo); in EmitTargetElementLoop()
126 llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_); in EmitTargetElementLoop()
129 return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) in EmitTargetElementLoop()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
elemental_ir_emitter.cc
108 const HloInstruction* hlo, in MakeElementGenerator() argument
110 switch (hlo->opcode()) { in MakeElementGenerator()
112 return [this, hlo, &operand_to_generator]( in MakeElementGenerator()
115 for (int i = 0; i < hlo->operand_count(); i++) { in MakeElementGenerator()
117 operand_to_generator.at(hlo->operand(i))(index)); in MakeElementGenerator()
120 return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo), in MakeElementGenerator()
121 operands, llvm_ir::IrName(hlo)); in MakeElementGenerator()
124 return [this, hlo, &operand_to_generator](const IrArray::Index& index) { in MakeElementGenerator()
126 Cast<HloReduceWindowInstruction>(hlo), in MakeElementGenerator()
127 operand_to_generator.at(hlo->operand(0)), index); in MakeElementGenerator()
[all …]
cpu_instruction_fusion.cc
30 bool CanBeLoopFused(const HloInstruction& hlo) { in CanBeLoopFused() argument
33 return hlo.IsElementwise() || // in CanBeLoopFused()
34 hlo.opcode() == HloOpcode::kBroadcast || in CanBeLoopFused()
35 hlo.opcode() == HloOpcode::kConcatenate || in CanBeLoopFused()
36 hlo.opcode() == HloOpcode::kDynamicSlice || in CanBeLoopFused()
37 hlo.opcode() == HloOpcode::kDynamicUpdateSlice || in CanBeLoopFused()
38 hlo.opcode() == HloOpcode::kGather || in CanBeLoopFused()
39 hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || in CanBeLoopFused()
40 hlo.opcode() == HloOpcode::kReshape || in CanBeLoopFused()
41 hlo.opcode() == HloOpcode::kReverse || in CanBeLoopFused()
[all …]
conv_canonicalization.cc
33 for (HloInstruction* hlo : in Run()
35 if (hlo->opcode() == HloOpcode::kConvolution && in Run()
36 !PotentiallyImplementedAsEigenConvolution(*hlo, in Run()
39 hlo->convolution_dimension_numbers(); in Run()
57 HloInstruction* input = hlo->mutable_operand(0); in Run()
78 HloInstruction* kernel = hlo->mutable_operand(1); in Run()
105 new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); in Run()
109 hlo->shape().dimensions(dnums.output_spatial_dimensions(i)); in Run()
112 new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim); in Run()
114 ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); in Run()
[all …]
/external/tensorflow/tensorflow/compiler/xla/tools/
hlo_extractor.cc
59 Status DefaultAction(const HloInstruction* hlo) override { in DefaultAction() argument
62 if (boundary_ != nullptr && boundary_->count(hlo) > 0) { in DefaultAction()
64 parameter_number_, hlo->shape(), hlo->name()); in DefaultAction()
66 clone_context_.MapInstruction(hlo, new_parameter.get()); in DefaultAction()
71 for (auto operand : hlo->operands()) { in DefaultAction()
75 hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_); in DefaultAction()
117 auto hlo = worklist.front(); in ComputeBoundary() local
119 int64 hops = visited[hlo]; in ComputeBoundary()
121 boundary->insert(hlo); in ComputeBoundary()
124 for (const HloInstruction* operand : hlo->operands()) { in ComputeBoundary()
/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/
alias_analysis.cc
33 void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, in AddAliasingInformationToIrArray() argument
37 if (hlo.opcode() == HloOpcode::kParameter && in AddAliasingInformationToIrArray()
38 hlo.parent() == hlo.parent()->parent()->entry_computation()) { in AddAliasingInformationToIrArray()
44 assignment_.GetAllSlices(&hlo, index); in AddAliasingInformationToIrArray()
69 assignment_, hlo); in AddAliasingInformationToIrArray()
81 if (hlo.opcode() == HloOpcode::kParameter) { in AddAliasingInformationToIrArray()
84 if (absl::c_linear_search(parameter_instructions, &hlo)) { in AddAliasingInformationToIrArray()
125 const BufferAssignment& assignment, const HloInstruction& hlo) { in GetNoaliasMetadataForBuffer() argument
153 for (HloInstruction* user : hlo.users()) { in GetNoaliasMetadataForBuffer()
160 add_buffers_to_worklist(&hlo); in GetNoaliasMetadataForBuffer()
[all …]
