/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
18
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/iterator_util.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
34 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/name_uniquer.h"
37 #include "tensorflow/compiler/xla/shape_tree.h"
38 #include "tensorflow/compiler/xla/statusor.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/xla_data.pb.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/types.h"
44
45 namespace xla {
46
47 class HloModule;
48
49 // Describes a computation at the HLO level.
50 //
51 // You can think of an HloComputation like a function. It has some inputs
52 // (parameters) and returns exactly one value (the value of its root node). If
53 // you want to return multiple values, you can return a tuple.
54 //
55 // The instructions inside of a computation do not have an explicit total order.
56 // Instead, they have a partial order determined by their data and control
57 // dependencies.
58 //
59 // An HloModule contains one "entry computation" -- this is like main() in a C
60 // program. Every other computation inside of a module is attached to one or
61 // more HloInstructions, as a "nested computation". For example, the kMap
62 // instruction has a nested computation and "applies" it to every element of its
63 // input, elementwise. (That is, the input [x, y, z] is transformed to [f(x),
64 // f(y), f(z)].)
65 class HloComputation {
66 public:
67 // Builder class for HloComputation.
68 class Builder {
69 public:
70 explicit Builder(const string& name,
71 HloInstruction* fusion_instruction = nullptr)
name_(name)72 : name_(name),
73 last_added_instruction_(nullptr),
74 fusion_instruction_(fusion_instruction) {}
75 Builder(Builder&& b) = default;
76 virtual ~Builder() = default;
77
78 // Build and return an HloComputation. The parameter root_instruction
79 // specifies the already-added instruction to use as the root. If
80 // root_instruction is nullptr then use the last added instruction as the
81 // root.
82 std::unique_ptr<HloComputation> Build(
83 HloInstruction* root_instruction = nullptr);
84
AddInstruction(std::unique_ptr<HloInstruction> instruction)85 virtual HloInstruction* AddInstruction(
86 std::unique_ptr<HloInstruction> instruction) {
87 instructions_.push_back(std::move(instruction));
88 last_added_instruction_ = instructions_.back().get();
89 return last_added_instruction_;
90 }
91
ForEachInstruction(const std::function<Status (const HloInstruction *)> & func)92 Status ForEachInstruction(
93 const std::function<Status(const HloInstruction*)>& func) const {
94 for (const auto& instruction : instructions_) {
95 TF_RETURN_IF_ERROR(func(instruction.get()));
96 }
97 return Status::OK();
98 }
99
100 private:
101 const string name_;
102 HloInstruction* last_added_instruction_;
103 HloInstruction* fusion_instruction_;
104 std::vector<std::unique_ptr<HloInstruction>> instructions_;
105
106 TF_DISALLOW_COPY_AND_ASSIGN(Builder);
107 };
108
109 // Helper class to automatically set the OpMetadata for every instruction
110 // added to a computation.
111 class MetadataBuilder {
112 public:
MetadataBuilder(HloComputation * computation,const OpMetadata & metadata)113 MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
114 : computation_(computation), metadata_(metadata) {}
115
AddInstruction(std::unique_ptr<HloInstruction> instruction)116 HloInstruction* AddInstruction(
117 std::unique_ptr<HloInstruction> instruction) {
118 instruction->set_metadata(metadata_);
119 return computation_->AddInstruction(std::move(instruction));
120 }
121
122 private:
123 HloComputation* computation_;
124 OpMetadata metadata_;
125 };
126
127 ~HloComputation();
128
129 // Add an instruction to the computation. The computation takes ownership of
130 // the instruction.
131 HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
132 const std::string& new_name = "");
133
134 // Remove the param_no'th parameter from the computation.
135 // Note this is only applicatable to the computation for the fusion
136 // instruction.
137 Status RemoveParameter(int64_t param_no);
138
139 // Remove unused parameters from the computation.
140 // Note this is only applicatable to the computation for the fusion
141 // instruction.
142 Status RemoveUnusedParametersFromFusedComputation();
143
144 // Remove unused parameters from the computation. Unlike
145 // RemoveUnusedParametersFromFusedComputation, this function can be used
146 // to remove parameters from non-fusion computations.
147 Status RemoveUnusedParametersFromAnyComputation();
148
149 // Adds a new parameter instruction to a fusion computation.
150 //
151 // This should be a new parameter. Instruction will be appended to parameters
152 // and inserted to the instruction list.
153 HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
154
155 // Adds a new parameter instruction to the entry computation and update
156 // the parent module config to reflect the change.
157 //
158 // This should be a new parameter. Instruction will be appended to parameters
159 // and inserted to the instruction list.
160 HloInstruction* AddEntryComputationParameter(
161 std::unique_ptr<HloInstruction> instruction);
162
163 // Replaces an old parameter with a new parameter. Adds the new parameter
164 // instruction to the entry computation.
165 Status ReplaceEntryComputationParameter(
166 int64_t param_no, HloInstruction* old_instruction,
167 std::unique_ptr<HloInstruction> instruction);
168
169 // Remove an instruction from the computation. The instruction must have no
170 // users. Instruction is deallocated with this call.
171 Status RemoveInstruction(HloInstruction* instruction);
172
173 // Removes an instruction from the computation. The instruction must have no
174 // users. Instruction is deallocated with this call. The instruction will be
175 // removed even if it is marked as not removable.
176 Status ForceRemoveInstruction(HloInstruction* instruction);
177
178 // Remove an instruction (including side effecting ones) from the computation
179 // and also transitively any operand that has no side effect and no users post
180 // removing an instruction. The instruction must have no users. Instruction is
181 // deallocated with this call. If given, the cleanup routine is executed on a
182 // removed instruction before its deallocation.
183 Status RemoveInstructionAndUnusedOperands(
184 HloInstruction* instruction,
185 std::function<void(HloInstruction*)> cleanup = nullptr);
186
187 // Set the root of the computation to the given instruction. The instruction
188 // must have already been added to the computation. In addition it must have
189 // the same shape as the result of the computation for non fusion
190 // computations, except if accept_different_shape is set to true.
191 void set_root_instruction(HloInstruction* new_root_instruction,
192 bool accept_different_shape = false);
193
194 // Return the root instruction of the computation. The root instruction is the
195 // instruction which produces the output of the computation.
root_instruction()196 HloInstruction* root_instruction() const { return root_instruction_; }
197
198 // Returns the number of parameters for this computation.
num_parameters()199 int64 num_parameters() const { return param_instructions_.size(); }
200
201 // Returns the parameter instruction for the given parameter number.
parameter_instruction(int64_t param_no)202 HloInstruction* parameter_instruction(int64_t param_no) const {
203 CHECK_GE(param_no, 0);
204 CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
205 << "Computation " << name() << " has no parameter number " << param_no;
206 return param_instructions_[param_no];
207 }
208
parameter_instructions()209 const std::vector<HloInstruction*>& parameter_instructions() const {
210 return param_instructions_;
211 }
212
name()213 const string& name() const { return name_; }
214
215 // Use the given NameUniquer to select a unique name for the computation based
216 // on the computation's existing name.
217 void UniquifyName(NameUniquer* name_uniquer);
218
219 // Return a string representation of the computation.
220 //
221 // (We express the default options using an overload rather than a default
222 // param because gdb ignores default params, but does resolve overloads.)
ToString()223 string ToString() const { return ToString(HloPrintOptions()); }
224 string ToString(const HloPrintOptions& options) const;
225
226 // Overload which accepts an order to emit the instructions in.
227 string ToString(
228 const HloPrintOptions& options,
229 absl::Span<const HloInstruction* const> instruction_order) const;
230
231 // Returns a serialized representation of this computation.
232 HloComputationProto ToProto() const;
233
234 // Creates a computation from the given proto. Arguments:
235 //
236 // proto: the proto to convert from.
237 // computation_map: a map from computation id to HloComputation*. This map
238 // must contain all computations which the newly constructed computation
239 // calls.
240 static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
241 const HloComputationProto& proto,
242 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
243 bool prohibit_empty_literal = true);
244
245 using InstructionSequence = tensorflow::gtl::iterator_range<
246 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;
247
248 using ConstInstructionSequence =
249 tensorflow::gtl::iterator_range<UnwrappingIterator<
250 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;
251
252 // Gets the instructions in this computation.
253 //
254 // The returned type is a range of HloInstruction*s, so you can iterate over
255 // it using a range-based for loop in the natural way:
256 //
257 // for (HloInstruction* instr : computation->instructions()) { ... }
258 //
instructions()259 ConstInstructionSequence instructions() const {
260 return {MakeUnwrappingIterator(instructions_.begin()),
261 MakeUnwrappingIterator(instructions_.end())};
262 }
instructions()263 InstructionSequence instructions() {
264 return {MakeUnwrappingIterator(instructions_.begin()),
265 MakeUnwrappingIterator(instructions_.end())};
266 }
267
268 // Compute and return a post-order of the instructions in the computation. In
269 // this order, definitions of values always appear before their uses.
270 std::vector<HloInstruction*> MakeInstructionPostOrder() const;
271
instruction_count()272 int64 instruction_count() const { return instruction_iterators_.size(); }
273
274 // Creates and returns a list of the embedded computations called by this
275 // computation. This includes all embedded computations called directly or
276 // transitively. The embedded computations are sorted such that if computation
277 // A calls computation B (eg, via a map instruction) then A will appear after
278 // B in the list.
279 std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
280
281 // Creates a fusion instruction containing the given instructions.
282 // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
283 // into a library call. Instructions must be in reverse topological order
284 // (root of the fused expression first). Replaces all uses of the original
285 // root instruction with the fusion instruction. The original instructions are
286 // removed if they have no uses after fusion (this is necessarily true for at
287 // least the root).
288 HloInstruction* CreateFusionInstruction(
289 absl::Span<HloInstruction* const> instructions_to_fuse,
290 HloInstruction::FusionKind fusion_kind);
291
292 // Create a deep copy of the given instruction and return the instruction
293 // producing the copied result. All instructions performing the copy are added
294 // to the computation. For array-shaped values, this method trivially returns
295 // a kCopy instruction. For tuple-shaped instructions, the copy is performed
296 // with a series of kGetTupleElement and kTuple instructions. If
297 // indices_to_copy is non-null then this ShapeTree indicates which elements
298 // (arrays) of the shape to copy. Non-copied elements are passed through
299 // transparently. If copies_added is non-null, then the added kCopy
300 // instructions will be inserted in the respective index in the given
301 // ShapeTree.
302 StatusOr<HloInstruction*> DeepCopyInstruction(
303 HloInstruction* instruction,
304 const ShapeTree<bool>* indices_to_copy = nullptr,
305 ShapeTree<HloInstruction*>* copies_added = nullptr);
306
307 // As above, but uses a custom function to copy the leaf nodes, which could
308 // create alternative HLOs other than kCopy, or even pass-throughs.
309 StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
310 HloInstruction* instruction,
311 const std::function<
312 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
313 HloComputation* computation)>& copy_leaf);
314
315 // Computes and returns the ProgramShape of this computation (shape of
316 // parameters and result with layout).
317 ProgramShape ComputeProgramShape(bool include_ids = true) const;
318
319 // Return whether `*this` and `other` are functionally equivalent.
Equal(const HloComputation & other,bool is_layout_sensitive)320 bool Equal(const HloComputation& other, bool is_layout_sensitive) const {
321 return EqualInternal(other, is_layout_sensitive,
322 /*ignore_channel_id_values=*/false);
323 }
324
325 // Same as Equal() but ignores channel ID value mismatches on instructions, as
326 // long as the two instructions both have channel IDs or neither has a channel
327 // ID.
EqualIgnoringChannelIdValues(const HloComputation & other,bool is_layout_sensitive)328 bool EqualIgnoringChannelIdValues(const HloComputation& other,
329 bool is_layout_sensitive) const {
330 return EqualInternal(other, is_layout_sensitive,
331 /*ignore_channel_id_values=*/true);
332 }
333
334 // Return whether `*this` and `other` are functionally equivalent.
335 bool operator==(const HloComputation& other) const {
336 return Equal(other, true);
337 }
338
339 // Replaces old instruction with newly created instruction. Removes old
340 // instruction from computation. Updates uses and root instruction.
341 Status ReplaceWithNewInstruction(
342 HloInstruction* old_instruction,
343 std::unique_ptr<HloInstruction> new_instruction);
344
345 // Replaces an old instruction with a newly created instruction, and adds the
346 // new instruction as an entry computation's parameter. Removes old
347 // instruction from computation. Updates uses and root instruction.
348 Status ReplaceWithNewEntryComputationParameter(
349 HloInstruction* old_instruction,
350 std::unique_ptr<HloInstruction> new_instruction);
351
352 // Replace old instruction with new instruction. Updates uses and root
353 // instruction. Removes old instruction from computation. Precondition:
354 // old_instruction and new_instruction must have the compatible shapes.
355 // If |new_instruction| doesn't have any sharding information it will
356 // receive the sharding information of |old_instruction|.
357 Status ReplaceInstruction(HloInstruction* old_instruction,
358 HloInstruction* new_instruction);
359
360 // Set/get the module containing this computation.
set_parent(HloModule * module)361 void set_parent(HloModule* module) { parent_ = module; }
parent()362 const HloModule* parent() const { return parent_; }
parent()363 HloModule* parent() { return parent_; }
364
365 // Visit every node in the computation in DFS post-order with the given
366 // visitor. This is similar to calling HloInstruction::Accept on the root of
367 // the computation except this method also visits instructions not reachable
368 // via the root. The root instruction of the computation is visited last, and
369 // the visitor's FinishVisit method is called once upon completion (with the
370 // root instruction as the argument).
371 template <typename HloInstructionPtr>
372 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;
373
374 // Same as Accept() above, but the order of operand and control predecessor
375 // visitation is determined by the given operand order; if compare(A, B) ==
376 // true, A is visited before B.
377 Status AcceptWithOperandOrder(
378 DfsHloVisitor* visitor,
379 const HloInstruction::CompareFunction& operand_order) const;
380
381 // Visit every node in the computation in the given order. 'order' must
382 // be a topological sort of all instructions in the computation.
383 template <typename HloInstructionPtr>
384 Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
385 absl::Span<HloInstruction* const> order) const;
386
387 // Returns a deep copy of this computation including all instructions.
388 // If the clone context is specified, it will be populated with the cloned
389 // object mappings, and its module() will be used to add new computations
390 // into.
391 std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
392 HloCloneContext* context = nullptr);
393
394 // Like Clone(), but if an instruction is present in replacement_map, we use
395 // the map's value to replace that instruction in the cloned computation.
396 //
397 // If replacements maps a key to nullptr, we remove that instruction from the
398 // new computation. If an element of `replacements` references an instruction
399 // that's not already in the computation, it's cloned and added to the new
400 // computation.
401 //
402 // 'extra_parameters' allows to specify additional parameters that should be
403 // added to the computation.
404 //
405 // All relevant instructions are cloned, *including* unique_ptr in the
406 // `replacements` map.
407 std::unique_ptr<HloComputation> CloneWithReplacements(
408 absl::flat_hash_map<const HloInstruction*,
409 std::unique_ptr<HloInstruction>>
410 replacements,
411 absl::Span<const HloInstruction* const> extra_parameters = {},
412 HloCloneContext* context = nullptr, const string& suffix = "clone",
413 const HloInstruction* new_root = nullptr);
414
415 // Convenience overloads for CloneWithReplacements. You want to do
416 //
417 // CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}}) // ERROR
418 //
419 // but that doesn't work because std::initializer_list is not movable. These
420 // overloads let you do
421 //
422 // CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)}); // OK
423 //
424 std::unique_ptr<HloComputation> CloneWithReplacementPairs(
425 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
426 HloCloneContext* context = nullptr, const string& suffix = "clone");
427 std::unique_ptr<HloComputation> CloneWithReplacementPairs(
428 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
429 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
430 HloCloneContext* context = nullptr, const string& suffix = "clone");
431 std::unique_ptr<HloComputation> CloneWithReplacementPairs(
432 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
433 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
434 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
435 HloCloneContext* context = nullptr, const string& suffix = "clone");
436
437 // Returns true if the given instruction can be removed from the computation.
438 // Parameter instructions cannot be removed without violating invariants of
439 // the HLO computation with the exception of fusion computation. A parameter
440 // instruction is removable for a fusion computation.
441 //
442 // Note that IsSafelyRemovable() is a necessary condition to remove an
443 // instruction rather than a sufficient condition. For example, instructions
444 // with side-effect (e.g., Send, Infeed) may be removed from a computation,
445 // but the transformation must guarantee the invariants relevant to the
446 // instructions still hold (e.g., Send and Recv must be removed together to
447 // make each channel complete).
448 bool IsSafelyRemovable(const HloInstruction* instruction);
449
450 // Returns a map from channel-id to the group of instructions associated with
451 // the channel. These instructions will be considered as a single node for
452 // dependency purposes. Send and RecvDone are in the group, and AllReduces
453 // with the same channel id are in the group.
454 using ChannelDependencyGroup =
455 absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
456 ChannelDependencyGroup ComputeChannelDependencies() const;
457
458 // Returns true if this computation has a side effect. A computation has a
459 // side effect if it contains one or more instructions with a side effect.
460 bool HasSideEffect() const;
461
462 // Returns if this computation is a fusion computation.
463 // Do not use this method to determine if fusion_instruction_ != nullptr.
464 // Instead, directly do: FusionInstruction() != nullptr
IsFusionComputation()465 bool IsFusionComputation() const { return is_fusion_computation_; }
466
467 // Returns if this computation is the entry computation of the module.
468 bool IsEntryComputation() const;
469
470 // Returns the owning fusion instruction, or nullptr if this is not a fusion
471 // computation.
FusionInstruction()472 HloInstruction* FusionInstruction() const { return fusion_instruction_; }
SetFusionInstruction(HloInstruction * fusion_instruction)473 void SetFusionInstruction(HloInstruction* fusion_instruction) {
474 fusion_instruction_ = fusion_instruction;
475 is_fusion_computation_ |= (fusion_instruction != nullptr);
476 }
477
478 // Returns if this computation is a custom-call computation.
IsCustomCallComputation()479 bool IsCustomCallComputation() const { return is_custom_call_computation_; }
480
481 // Returns the owning custom call instruction, or nullptr if this is not a
482 // custom call computation.
CustomCallInstruction()483 HloInstruction* CustomCallInstruction() const {
484 return custom_call_instruction_;
485 }
SetCustomCallInstruction(HloInstruction * custom_call_instruction)486 void SetCustomCallInstruction(HloInstruction* custom_call_instruction) {
487 custom_call_instruction_ = custom_call_instruction;
488 is_custom_call_computation_ |= (custom_call_instruction != nullptr);
489 }
490
491 // Returns if this computation is invoked by an Hlo instruction.
IsCalledComputation()492 bool IsCalledComputation() const {
493 return IsFusionComputation() || IsCustomCallComputation();
494 }
495
496 // Clear the unique ID of the computation so that it can be re-assigned, such
497 // as for the purpose of compacting the unique IDs.
ClearUniqueIdInternal()498 void ClearUniqueIdInternal() { unique_id_ = -1; }
499
500 // The id of this computation should be unique within the module.
SetUniqueId(int64_t id)501 void SetUniqueId(int64_t id) {
502 CHECK_EQ(unique_id_, -1);
503 CHECK_GE(id, 0);
504 unique_id_ = id;
505 }
506
507 // Returns the instruction in this computation that has name `name`. Returns
508 // null if there is no such computation.
509 HloInstruction* GetInstructionWithName(absl::string_view name);
510
unique_id()511 int64 unique_id() const { return unique_id_; }
512
513 // Deallocate instructions that are marked by "RemoveInstruction". The two
514 // stage clean up process is designed such that HloPass can have stable
515 // internal pointers to HloInstructions while we create and remove
516 // HloInstructions in a pass.
Cleanup()517 void Cleanup() { to_be_deleted_.clear(); }
518
519 // Returns true if a given instruction is marked dead in this computation.
520 bool IsMarkedAsDead(const HloInstruction* inst);
521
522 private:
523 explicit HloComputation(
524 const string& name, int parameter_count,
525 std::vector<std::unique_ptr<HloInstruction>>* instructions,
526 HloInstruction* root_instruction, HloInstruction* fusion_instruction);
527
528 // Internal helper for adding instructions.
529 HloInstruction* AddInstructionInternal(
530 std::unique_ptr<HloInstruction> instruction);
531
532 // Internal helper for comparison with different options.
533 bool EqualInternal(const HloComputation& other, bool is_layout_sensitive,
534 bool ignore_channel_id_values) const;
535
536 // Fuses HLOs in instructions_to_fuse into fusion_instruction.
537 //
538 // Pre-condition: fusion_instruction's opcode is kFusion.
539 void FuseInstructionsInto(
540 absl::Span<HloInstruction* const> instructions_to_fuse,
541 HloInstruction* fusion_instruction);
542
543 // Internal helper for recursive copying of an instruction. Creates and
544 // returns a deep copy of the given instruction.
545 StatusOr<HloInstruction*> DeepCopyHelper(
546 HloInstruction* instruction, ShapeIndex* index,
547 const std::function<
548 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
549 HloComputation* computation)>& copy_leaf);
550
551 // Internal helper to collect unreachable roots.
552 std::vector<HloInstruction*> CollectUnreachableRoots() const;
553
554 enum VisitState { kVisiting, kVisited };
555 void ComputeInstructionPostOrder(
556 const HloComputation::ChannelDependencyGroup& channel_dependency_group,
557 std::vector<HloInstruction*>* post_order, HloInstruction* root,
558 absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;
559
560 Status RemoveUnusedParametersImpl(bool allow_non_fusion);
561
562 Status RemoveInstructionImpl(HloInstruction* instruction,
563 bool ignore_safety_check);
564
565 string name_;
566 int64 unique_id_;
567 HloInstruction* root_instruction_;
568
569 // If this computation is a fusion computation, this field points to the
570 // corresponding fusion instruction (if it is live). Otherwise, this is null.
571 HloInstruction* fusion_instruction_;
572
573 // Determines whether this computation is a fusion computation. A fusion
574 // computation ordinarily also has a non-null fusion_instruction_. However, if
575 // a fusion instruction is removed during compilation, the fusion computation
576 // becomes unreachable, and its fusion_instruction_ is set to null. We still
577 // need to regard such computations as fusion computations for HLO scheduling
578 // purposes.
579 bool is_fusion_computation_;
580
581 // If this computation is a custom-call computation, this field points to the
582 // corresponding custom-call instruction (if it is live). Otherwise, this is
583 // null.
584 HloInstruction* custom_call_instruction_;
585
586 // Determines whether this computation is a custom-call computation. A
587 bool is_custom_call_computation_;
588
589 // Module containing this computation.
590 HloModule* parent_ = nullptr;
591
592 // Store instructions in std::list as they can be added and removed
593 // arbitrarily and we want a stable iteration order. Keep a map from
594 // instruction pointer to location in the list for fast lookup.
595 using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
596 InstructionList instructions_;
597 absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
598 instruction_iterators_;
599
600 // Removed instructions are moved into to_be_deleted_ first and then
601 // deallocated when Cleanup is called.
602 std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;
603
604 std::vector<HloInstruction*> param_instructions_;
605
606 TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
607 };
608
609 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor)610 Status HloComputation::Accept(
611 DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
612 // Visit unreachable roots. Beware that the visitor might delete the currently
613 // visited root, which would invalidate iterators if the unreachable roots
614 // weren't computed ahead of time.
615 for (HloInstruction* root : CollectUnreachableRoots()) {
616 VLOG(3) << "Traversing unreachable root: " << root->ToString();
617 // Call FinishVisit only at the end.
618 TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
619 }
620 // Visit the computation root instruction last.
621 return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
622 }
623
624 // Explicit instantiations.
625 template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
626 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
627
628 template <typename HloInstructionPtr>
AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr> * visitor,absl::Span<HloInstruction * const> order)629 Status HloComputation::AcceptOrdered(
630 DfsHloVisitorBase<HloInstructionPtr>* visitor,
631 absl::Span<HloInstruction* const> order) const {
632 VLOG(3) << "Accepting visitor with order.";
633 for (HloInstruction* root : CollectUnreachableRoots()) {
634 TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
635 }
636 TF_RET_CHECK(order.size() == instruction_count());
637 absl::flat_hash_set<const HloInstruction*> visited;
638 for (const HloInstruction* instruction : order) {
639 VLOG(3) << "Visiting ordered: " << instruction->ToString();
640 TF_RET_CHECK(instruction_iterators_.contains(instruction))
641 << "Instruction " << instruction->name() << " is not in computation "
642 << name();
643 TF_RET_CHECK(!visited.contains(instruction))
644 << "Instruction " << instruction->name()
645 << " appears more than once in order";
646 HloInstruction* mutable_instruction =
647 const_cast<HloInstruction*>(instruction);
648 TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
649 TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
650 visitor->SetVisited(*mutable_instruction);
651 TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
652 visited.insert(instruction);
653 }
654 TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
655 return Status::OK();
656 }
657
658 // Explicit instantiations.
659 template Status HloComputation::AcceptOrdered(
660 DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
661 template Status HloComputation::AcceptOrdered(
662 ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
663
664 } // namespace xla
665
666 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
667