Home
last modified time | relevance | path

Searched refs:HloComputation (Results 1 – 25 of 314) sorted by relevance

12345678910>>...13

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_module.h73 HloComputation* AddEntryComputation(
74 std::unique_ptr<HloComputation> computation);
79 HloComputation* AddEntryComputationWithLayouts(
80 std::unique_ptr<HloComputation> computation);
85 void ReplaceEntryComputation(HloComputation* entry_computation);
88 HloComputation* AddEmbeddedComputation(
89 std::unique_ptr<HloComputation> computation);
92 Status RemoveEmbeddedComputation(HloComputation* to_remove);
105 const std::unordered_map<HloComputation*, HloComputation*>& replacements);
118 HloComputation* DeepCloneComputation(HloComputation* computation,
[all …]
Dcall_graph.h62 const std::vector<HloComputation*>& called_computations, in CallSite()
72 const std::vector<HloComputation*>& called_computations() const { in called_computations()
86 const std::vector<HloComputation*> called_computations_;
95 CallGraphNode(HloComputation* computation);
98 HloComputation* computation() const { return computation_; } in computation()
111 const std::vector<HloComputation*>& callees() const { return callees_; } in callees()
119 const std::vector<HloComputation*>& callers() const { return callers_; } in callers()
151 HloComputation* computation_;
155 std::vector<HloComputation*> callees_;
156 absl::flat_hash_set<HloComputation*> callee_set_;
[all …]
Dhlo_module.cc59 void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) { in ReplaceEntryComputation()
67 HloComputation* HloModule::AddComputationInternal( in AddComputationInternal()
68 std::unique_ptr<HloComputation> computation, bool is_entry, in AddComputationInternal()
122 HloComputation* HloModule::AddEntryComputation( in AddEntryComputation()
123 std::unique_ptr<HloComputation> computation) { in AddEntryComputation()
129 HloComputation* HloModule::AddEntryComputationWithLayouts( in AddEntryComputationWithLayouts()
130 std::unique_ptr<HloComputation> computation) { in AddEntryComputationWithLayouts()
136 Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { in RemoveEmbeddedComputation()
142 computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) { in RemoveEmbeddedComputation()
151 HloComputation* HloModule::AddEmbeddedComputation( in AddEmbeddedComputation()
[all …]
Dflatten_call_graph_test.cc36 std::unique_ptr<HloComputation> MakeScalarComputation() { in MakeScalarComputation()
37 HloComputation::Builder builder(TestName() + ".ScalarComputation"); in MakeScalarComputation()
47 std::unique_ptr<HloComputation> MakeMappingComputation( in MakeMappingComputation()
48 HloComputation* map_computation, int64 callsites) { in MakeMappingComputation()
49 HloComputation::Builder builder(TestName() + ".MappingComputation"); in MakeMappingComputation()
62 std::unique_ptr<HloComputation> MakeCallingComputation( in MakeCallingComputation()
63 HloComputation* callee_computation, int64 callsites, in MakeCallingComputation()
65 HloComputation::Builder builder(TestName() + suffix); in MakeCallingComputation()
78 std::unique_ptr<HloComputation> MakeConditionComputation() { in MakeConditionComputation()
79 HloComputation::Builder builder(TestName() + ".ConditionComputation"); in MakeConditionComputation()
[all …]
Dhlo_clone_context.h27 class HloComputation; variable
49 void MapComputation(const HloComputation* old_computation, in MapComputation()
50 HloComputation* new_computation) { in MapComputation()
62 HloComputation* FindComputation(const HloComputation* old_computation) const { in FindComputation()
72 HloComputation* GetComputation(const HloComputation* old_computation) const { in GetComputation()
81 const absl::flat_hash_map<const HloComputation*, HloComputation*>&
90 absl::flat_hash_map<const HloComputation*, HloComputation*> computations_;
Dhlo_computation.cc51 std::unique_ptr<HloComputation> HloComputation::Builder::Build( in Build()
63 return absl::WrapUnique(new HloComputation( in Build()
67 HloComputation::HloComputation( in HloComputation() function in xla::HloComputation
95 HloInstruction* HloComputation::AddInstruction( in AddInstruction()
106 HloInstruction* HloComputation::AddInstructionInternal( in AddInstructionInternal()
119 HloInstruction* HloComputation::AddParameter( in AddParameter()
130 HloInstruction* HloComputation::AddEntryComputationParameter( in AddEntryComputationParameter()
148 Status HloComputation::ReplaceEntryComputationParameter( in ReplaceEntryComputationParameter()
168 Status HloComputation::RemoveParameter(int64 param_no) { in RemoveParameter()
192 Status HloComputation::RemoveUnusedParametersFromFusedComputation() { in RemoveUnusedParametersFromFusedComputation()
[all …]
Dcall_graph_test.cc37 std::unique_ptr<HloComputation> MakeScalarComputation( in MakeScalarComputation()
39 HloComputation::Builder builder(TestName() + ".ScalarComputation"); in MakeScalarComputation()
49 std::unique_ptr<HloComputation> MakeMappingComputation( in MakeMappingComputation()
50 HloComputation* map_computation, int64 callsites) { in MakeMappingComputation()
51 HloComputation::Builder builder(TestName() + ".MappingComputation"); in MakeMappingComputation()
64 std::unique_ptr<HloComputation> MakeCallingComputation( in MakeCallingComputation()
65 HloComputation* callee_computation, int64 callsites, in MakeCallingComputation()
67 HloComputation::Builder builder(TestName() + suffix); in MakeCallingComputation()
80 std::unique_ptr<HloComputation> MakeConditionComputation() { in MakeConditionComputation()
81 HloComputation::Builder builder(TestName() + ".ConditionComputation"); in MakeConditionComputation()
[all …]
Dhlo_instructions.h52 const std::function<bool(const HloComputation*, const HloComputation*)>&
121 const std::function<bool(const HloComputation*, const HloComputation*)>&
149 const std::function<bool(const HloComputation*, const HloComputation*)>&
173 const std::function<bool(const HloComputation*, const HloComputation*)>&
199 const std::function<bool(const HloComputation*, const HloComputation*)>&
224 const std::function<bool(const HloComputation*, const HloComputation*)>&
252 const std::function<bool(const HloComputation*, const HloComputation*)>& in IdenticalSlowPathIgnoringChannelIdValues() argument
270 const std::function<bool(const HloComputation*, const HloComputation*)>&
293 const std::function<bool(const HloComputation*, const HloComputation*)>&
381 const std::function<bool(const HloComputation*, const HloComputation*)>&
[all …]
Dhlo_computation.h65 class HloComputation {
80 std::unique_ptr<HloComputation> Build(
109 MetadataBuilder(HloComputation* computation, const OpMetadata& metadata) in MetadataBuilder()
119 HloComputation* computation_;
234 static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
236 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
273 std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
307 HloComputation* computation)>& copy_leaf);
314 bool Equal(const HloComputation& other, bool is_layout_sensitive) const { in Equal()
322 bool EqualIgnoringChannelIdValues(const HloComputation& other, in EqualIgnoringChannelIdValues()
[all …]
Dcall_inliner_test.cc48 HloComputation::Builder inner(TestName() + ".inner"); in TEST_F()
55 HloComputation* inner_computation = in TEST_F()
59 HloComputation::Builder outer(TestName() + ".outer"); in TEST_F()
87 HloComputation::Builder just_false(TestName() + ".false"); in TEST_F()
90 HloComputation* false_computation = in TEST_F()
93 HloComputation::Builder call_false_builder(TestName() + ".call_false"); in TEST_F()
98 HloComputation* call_false = in TEST_F()
101 HloComputation::Builder outer(TestName() + ".outer"); in TEST_F()
125 HloComputation::Builder just_false(TestName() + ".false"); in TEST_F()
131 HloComputation* false_computation = in TEST_F()
[all …]
Dhlo_memory_scheduler.h44 HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
46 const absl::flat_hash_map<const HloComputation*, int64>&,
64 HloComputation* computation,
68 const absl::flat_hash_map<const HloComputation*, int64>&
74 HloComputation* computation,
78 const absl::flat_hash_map<const HloComputation*, int64>&
84 HloComputation* computation,
88 const absl::flat_hash_map<const HloComputation*, int64>&
97 HloComputation* computation,
101 const absl::flat_hash_map<const HloComputation*, int64>&
[all …]
Dhlo_module_test.cc44 std::unique_ptr<HloComputation> CreateConstantComputation() { in CreateConstantComputation()
45 auto builder = HloComputation::Builder("Constant"); in CreateConstantComputation()
52 std::unique_ptr<HloComputation> CreateCallComputation( in CreateCallComputation()
53 absl::Span<HloComputation* const> computations) { in CreateCallComputation()
54 auto builder = HloComputation::Builder("Call"); in CreateCallComputation()
118 HloComputation* fused_computation; in TEST_F()
120 auto b = HloComputation::Builder("Fused"); in TEST_F()
129 auto b = HloComputation::Builder("Entry"); in TEST_F()
179 auto builder = HloComputation::Builder("Constant"); in TEST_F()
283 HloComputation* entry = module->entry_computation(); in TEST_F()
[all …]
Dwhile_loop_invariant_code_motion_test.cc33 HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
37 static void FindOnlyWhileInstruction(HloComputation* computation, in FindOnlyWhileInstruction()
50 HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( in MakeAlwaysTrueComputation()
52 HloComputation::Builder builder(TestName() + ".always_true"); in MakeAlwaysTrueComputation()
66 HloComputation* while_body = [&]() { in TEST_F()
67 HloComputation::Builder builder(TestName() + ".while_body"); in TEST_F()
83 HloComputation::Builder builder(TestName()); in TEST_F()
89 HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); in TEST_F()
108 HloComputation* while_body = [&]() { in TEST_F()
109 HloComputation::Builder builder(TestName() + ".while_body"); in TEST_F()
[all …]
Dcall_graph.cc81 [](string* out, const HloComputation* computation) { in ToString()
86 CallGraphNode::CallGraphNode(HloComputation* computation) in CallGraphNode()
102 HloComputation* caller = caller_callsite.instruction()->parent(); in AddCallerCallSite()
132 const HloComputation* computation) const { in GetNode()
138 CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { in GetNode()
145 const HloComputation* a, const HloComputation* b, in DominatesHelper()
146 absl::flat_hash_set<const HloComputation*>* visited) const { in DominatesHelper()
161 for (const HloComputation* b_caller : b_node.callers()) { in DominatesHelper()
169 bool CallGraph::Dominates(const HloComputation* a, in Dominates()
170 const HloComputation* b) const { in Dominates()
[all …]
Dhlo_schedule.cc35 absl::flat_hash_map<int64, const HloComputation*> id_to_computation; in CreateFromProto()
36 for (const HloComputation* computation : module->computations()) { in CreateFromProto()
47 const HloComputation* computation = comp_it->second; in CreateFromProto()
84 void HloSchedule::set_sequence(const HloComputation* computation, in set_sequence()
89 void HloSchedule::set_sequence(const HloComputation* computation, in set_sequence()
96 const HloComputation* computation) { in GetOrCreateSequence()
108 const HloComputation* computation) const { in sequence()
113 const HloComputation* computation) { in UpdateComputationSchedule()
204 std::vector<HloComputation*> nonfusion_computations = in Update()
206 for (const HloComputation* computation : nonfusion_computations) { in Update()
[all …]
Dwhile_util.cc30 static StatusOr<HloComputation*> WidenWhileCondition( in WidenWhileCondition()
31 HloComputation* narrow_condition, const Shape& wide_shape) { in WidenWhileCondition()
35 HloComputation* wide_while_cond = [&]() { in WidenWhileCondition()
36 HloComputation::Builder builder(StrCat("wide.", narrow_condition->name())); in WidenWhileCondition()
62 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
63 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { in WidenWhileBody()
66 HloComputation* wide_while_body = [&]() { in WidenWhileBody()
67 HloComputation::Builder builder(StrCat("wide.", narrow_body->name())); in WidenWhileBody()
109 HloComputation * new_while_condition, in MakeInstructionsLiveIn()
112 HloComputation* new_while_body; in MakeInstructionsLiveIn()
[all …]
Dhlo_instruction.h63 class HloComputation; variable
286 using HloComputationPredicate = std::function<bool(const HloComputation*)>;
330 bool print_computation(const HloComputation* comp) const { in print_computation()
363 HloComputationPredicate print_computation_ = [](const HloComputation* comp) {
500 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
580 HloComputation* map_computation);
653 HloComputation* reduce_computation,
810 HloComputation* reduce_computation);
826 HloComputation* reduce_computation);
833 const Window& window, HloComputation* reduce_computation);
[all …]
Dhlo_rematerialization_test.cc66 HloComputation* computation = in TEST_F()
106 HloComputation* computation = in TEST_F()
127 HloComputation* computation = in TEST_F()
150 auto cond_builder = HloComputation::Builder(TestName() + ".cond"); in TEST_F()
155 HloComputation* while_cond = in TEST_F()
158 HloComputation* body_computation = module->AddEmbeddedComputation( in TEST_F()
160 HloComputation* entry_computation = in TEST_F()
186 auto cond_builder = HloComputation::Builder(TestName() + ".cond"); in TEST_F()
191 HloComputation* while_cond = in TEST_F()
194 HloComputation* body_computation = module->AddEmbeddedComputation( in TEST_F()
[all …]
Dhlo_alias_analysis_test.cc119 auto builder = HloComputation::Builder(TestName()); in TEST_F()
148 auto builder = HloComputation::Builder(TestName()); in TEST_F()
197 auto builder = HloComputation::Builder(TestName()); in TEST_F()
225 auto builder = HloComputation::Builder(TestName()); in TEST_F()
274 auto builder = HloComputation::Builder(TestName()); in TEST_F()
333 auto body_builder = HloComputation::Builder("body"); in TEST_F()
344 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); in TEST_F()
347 auto cond_builder = HloComputation::Builder("condition"); in TEST_F()
352 HloComputation* condition = in TEST_F()
355 auto builder = HloComputation::Builder(TestName()); in TEST_F()
[all …]
Dhlo_ordering_test.cc58 auto builder_c = HloComputation::Builder("C"); in TEST_F()
61 HloComputation* computation_c = in TEST_F()
64 auto builder_b = HloComputation::Builder("B"); in TEST_F()
69 HloComputation* computation_b = in TEST_F()
72 auto builder_a = HloComputation::Builder("A"); in TEST_F()
75 HloComputation* computation_a = in TEST_F()
78 auto builder = HloComputation::Builder(TestName()); in TEST_F()
131 auto body_builder = HloComputation::Builder("body"); in TEST_F()
136 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); in TEST_F()
138 auto cond_builder = HloComputation::Builder("condition"); in TEST_F()
[all …]
Dhlo_creation_utils_test.cc33 HloComputation** entry_computation) { in CreateModuleWithProgramShape()
48 HloComputation** entry_computation, PrimitiveType primitive_type_output) { in CreateModuleWithProgramShape()
63 HloComputation* entry_computation; in TEST_F()
82 HloComputation* entry_computation; in TEST_F()
105 HloComputation* entry_computation; in TEST_F()
124 HloComputation* entry_computation; in TEST_F()
143 HloComputation* entry_computation; in TEST_F()
162 HloComputation* entry_computation; in TEST_F()
183 HloComputation* entry_computation; in TEST_F()
203 HloComputation* entry_computation; in TEST_F()
[all …]
Dall_reduce_combiner_test.cc46 for (HloComputation* computation : module.computations()) { in AllReduceCount()
63 std::vector<int64> sizes_in_kib, std::vector<HloComputation*> reductions, in MakeCrossReplicaReductions()
64 std::vector<HloInstruction*>* inputs, HloComputation::Builder* b) { in MakeCrossReplicaReductions()
69 HloComputation* reduction = reductions[i]; in MakeCrossReplicaReductions()
86 HloComputation* MakeReduction(const HloOpcode type, HloModule* module) { in MakeReduction()
87 HloComputation::Builder sum_builder(HloOpcodeString(type)); in MakeReduction()
94 HloComputation* reduction = in MakeReduction()
116 HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get()); in TEST_F()
118 HloComputation::Builder b(TestName()); in TEST_F()
160 HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get()); in TEST_F()
[all …]
Dhlo_schedule.h118 const HloComputation* computation) const;
123 const HloComputation* computation);
126 void set_sequence(const HloComputation* computation,
128 void set_sequence(const HloComputation* computation,
138 bool is_computation_scheduled(const HloComputation* computation) const { in is_computation_scheduled()
143 void remove_computation(const HloComputation* computation) { in remove_computation()
150 void remove_instruction(const HloComputation* computation, in remove_instruction()
157 void replace_instruction(const HloComputation* computation, in replace_instruction()
192 Status UpdateComputationSchedule(const HloComputation* computation);
Dhlo_rematerialization_test_utils.h57 std::unique_ptr<HloComputation> MakeRematerializableComputation(
59 auto builder = HloComputation::Builder(TestName() + suffix);
101 std::unique_ptr<HloComputation> MakeRematerializableWhileComputation(
102 HloComputation* while_cond, HloComputation* while_body,
104 auto builder = HloComputation::Builder(TestName() + suffix);
129 std::unique_ptr<HloComputation> MakeConditionComputation() { in MakeConditionComputation()
130 auto builder = HloComputation::Builder(TestName() + ".cond"); in MakeConditionComputation()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dhlo_function_importer.h39 class HloComputation; variable
48 static Status ImportAsFunc(const xla::HloComputation& computation,
50 std::unordered_map<const xla::HloComputation*,
55 static Status ImportAsRegion(const xla::HloComputation& computation,
61 const xla::HloComputation& computation,
81 std::unordered_map<const xla::HloComputation*, in HloFunctionImporter() argument
94 StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);
97 tensorflow::Status ImportAsRegion(const HloComputation& computation,
102 tensorflow::Status ImportInstructions(const HloComputation& computation,
105 const xla::HloComputation& computation,
[all …]

12345678910>>...13