/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function. It has some inputs
// (parameters) and returns exactly one value (the value of its root node).
// If you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total
// order. Instead, they have a partial order determined by their data and
// control dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a
// C program. Every other computation inside of a module is attached to one
// or more HloInstructions, as a "nested computation". For example, the kMap
// instruction has a nested computation and "applies" it to every element of
// its input, elementwise. (That is, the input [x, y, z] is transformed to
// [f(x), f(y), f(z)].)
class HloComputation {
 public:
  // Builder class for HloComputation.
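  //
  // Example (illustrative sketch; `shape` is a hypothetical Shape defined by
  // the caller):
  //
  //   HloComputation::Builder builder("add_one");
  //   auto param = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, shape, "x"));
  //   auto one = builder.AddInstruction(
  //       HloInstruction::CreateConstant(LiteralUtil::One(F32)));
  //   builder.AddInstruction(
  //       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, one));
  //   std::unique_ptr<HloComputation> computation = builder.Build();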
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };

  // Helper class to automatically set the OpMetadata for every instruction
  // added to a computation.
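  //
  // Example (illustrative sketch; `computation`, `metadata`, `shape`, and
  // `operand` are assumed to be defined by the caller):
  //
  //   HloComputation::MetadataBuilder b(computation, metadata);
  //   b.AddInstruction(
  //       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, operand));
  //   // The added kNegate instruction now carries `metadata`.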
  class MetadataBuilder {
   public:
    MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
        : computation_(computation), metadata_(metadata) {}

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instruction->set_metadata(metadata_);
      return computation_->AddInstruction(std::move(instruction));
    }

   private:
    HloComputation* computation_;
    OpMetadata metadata_;
  };

  // Add an instruction to the computation. The computation takes ownership
  // of the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
                                 const std::string& new_name = "");

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64 param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveUnusedParametersFromFusedComputation();

  // Remove unused parameters from the computation. Unlike
  // RemoveUnusedParametersFromFusedComputation, this function can be used
  // to remove parameters from non-fusion computations.
  Status RemoveUnusedParametersFromAnyComputation();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameter list and inserted into the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and updates
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameter list and inserted into the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Replaces an old parameter with a new parameter. Adds the new parameter
  // instruction to the entry computation.
  Status ReplaceEntryComputationParameter(
      int64 param_no, HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Removes an instruction from the computation. The instruction must have
  // no users. Instruction is deallocated with this call. The instruction
  // will be removed even if it is marked as not removable.
  Status ForceRemoveInstruction(HloInstruction* instruction);

  // Removes an instruction (including side-effecting ones) from the
  // computation, and also transitively removes any operands that have no
  // side effects and no remaining users after the removal. The instruction
  // must have no users. The instruction is deallocated with this call. If
  // given, the cleanup routine is executed on each removed instruction
  // before its deallocation.
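  //
  // Example (illustrative sketch; `instr` is assumed to be a dead
  // instruction in this computation):
  //
  //   TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(
  //       instr, [](HloInstruction* dead) {
  //         VLOG(1) << "Removing " << dead->name();
  //       }));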
  Status RemoveInstructionAndUnusedOperands(
      HloInstruction* instruction,
      std::function<void(HloInstruction*)> cleanup = nullptr);

  // Set the root of the computation to the given instruction. The
  // instruction must have already been added to the computation. In addition
  // it must have the same shape as the result of the computation for
  // non-fusion computations, except if accept_different_shape is set to
  // true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number "
        << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This
  //     map must contain all computations which the newly constructed
  //     computation calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  using InstructionSequence = tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

  using ConstInstructionSequence =
      tensorflow::gtl::iterator_range<UnwrappingIterator<
          std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate
  // over it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  ConstInstructionSequence instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  InstructionSequence instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation.
  // In this order, definitions of values always appear before their uses.
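  //
  // Example (illustrative sketch):
  //
  //   for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
  //     // All operands of `instr` have already been visited at this point.
  //   }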
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;

  int64 instruction_count() const { return instruction_iterators_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if
  // computation A calls computation B (eg, via a map instruction) then A
  // will appear after B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or
  // fusion into a library call. Instructions must be in reverse topological
  // order (root of the fused expression first). Replaces all uses of the
  // original root instruction with the fusion instruction. The original
  // instructions are removed if they have no uses after fusion (this is
  // necessarily true for at least the root).
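  //
  // Example (illustrative sketch; `mul` and `add` are assumed to be
  // instructions in this computation, with `mul` consuming `add`):
  //
  //   HloInstruction* fusion = computation->CreateFusionInstruction(
  //       {mul, add}, HloInstruction::FusionKind::kLoop);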
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are
  // added to the computation. For array-shaped values, this method trivially
  // returns a kCopy instruction. For tuple-shaped instructions, the copy is
  // performed with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
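  //
  // Example (illustrative sketch): explicitly copy every leaf; returning
  // `leaf` unchanged instead would make that element a pass-through.
  //
  //   auto copied = computation->DeepCopyInstructionWithCustomCopier(
  //       instr, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
  //                 HloComputation* computation) {
  //         return computation->AddInstruction(HloInstruction::CreateUnary(
  //             leaf->shape(), HloOpcode::kCopy, leaf));
  //       });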
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape(bool include_ids = true) const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool Equal(const HloComputation& other, bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/false);
  }

  // Same as Equal() but ignores channel ID value mismatches on instructions,
  // as long as the two instructions both have channel IDs or neither has a
  // channel ID.
  bool EqualIgnoringChannelIdValues(const HloComputation& other,
                                    bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/true);
  }

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const {
    return Equal(other, true);
  }

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
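  //
  // Example (illustrative sketch; `old_instr` is assumed to be a binary
  // instruction in this computation):
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       old_instr, HloInstruction::CreateBinary(
  //                      old_instr->shape(), HloOpcode::kMultiply,
  //                      old_instr->mutable_operand(0),
  //                      old_instr->mutable_operand(1))));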
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replaces an old instruction with a newly created instruction, and adds
  // the new instruction as an entry computation's parameter. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewEntryComputationParameter(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction. Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  // If |new_instruction| doesn't have any sharding information it will
  // receive the sharding information of |old_instruction|.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root
  // of the computation except this method also visits instructions not
  // reachable via the root. The root instruction of the computation is
  // visited last, and the visitor's FinishVisit method is called once upon
  // completion (with the root instruction as the argument).
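  //
  // Example (illustrative sketch; `CountingVisitor` is a hypothetical
  // DfsHloVisitorWithDefault subclass):
  //
  //   class CountingVisitor : public DfsHloVisitorWithDefault {
  //    public:
  //     Status DefaultAction(HloInstruction* hlo) override {
  //       ++count_;
  //       return Status::OK();
  //     }
  //     int64 count_ = 0;
  //   };
  //
  //   CountingVisitor visitor;
  //   TF_RETURN_IF_ERROR(computation->Accept(&visitor));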
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;

  // Returns a deep copy of this computation including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and its module() will be used to add new computations
  // into.
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in replacement_map, we
  // use the map's value to replace that instruction in the cloned
  // computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from
  // the new computation. If an element of `replacements` references an
  // instruction that's not already in the computation, it's cloned and added
  // to the new computation.
  //
  // 'extra_parameters' allows specifying additional parameters that should
  // be added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptr in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const string& suffix = "clone",
      const HloInstruction* new_root = nullptr);

  // Convenience overloads for CloneWithReplacements. You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable.
  // These overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});  // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const string& suffix = "clone");

  // Returns true if the given instruction can be removed from the
  // computation. Parameter instructions cannot be removed without violating
  // invariants of the HLO computation, with the exception of fusion
  // computations: a parameter instruction is removable from a fusion
  // computation.
  //
  // Note that IsSafelyRemovable() is a necessary condition to remove an
  // instruction rather than a sufficient condition. For example,
  // instructions with side effects (e.g., Send, Infeed) may be removed from
  // a computation, but the transformation must guarantee the invariants
  // relevant to the instructions still hold (e.g., Send and Recv must be
  // removed together to make each channel complete).
  bool IsSafelyRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated
  // with the channel. These instructions will be considered as a single node
  // for dependency purposes. Send and RecvDone are in the group, and
  // AllReduces with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns whether this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns whether this computation is the entry computation of the module.
  bool IsEntryComputation() const;

  // Returns the owning fusion instruction, or nullptr if this is not a
  // fusion computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

  // Clear the unique ID of the computation so that it can be re-assigned,
  // such as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64 id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`.
  // Returns null if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64 unique_id() const { return unique_id_; }

  // Deallocate instructions that are marked by "RemoveInstruction". The
  // two-stage cleanup process is designed such that HloPass can have stable
  // internal pointers to HloInstructions while we create and remove
  // HloInstructions in a pass.
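  //
  // Example (illustrative sketch of the two-stage pattern inside a pass;
  // `dead` is assumed to be an unused instruction):
  //
  //   TF_RETURN_IF_ERROR(computation->RemoveInstruction(dead));
  //   // Pointers to surviving instructions remain valid here.
  //   computation->Cleanup();  // `dead` is deallocated only now.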
  void Cleanup() { to_be_deleted_.clear(); }

  // Returns true if a given instruction is marked dead in this computation.
  bool IsMarkedAsDead(const HloInstruction* inst);

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Internal helper for comparison with different options.
  bool EqualInternal(const HloComputation& other, bool is_layout_sensitive,
                     bool ignore_channel_id_values) const;

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      const HloComputation::ChannelDependencyGroup& channel_dependency_group,
      std::vector<HloInstruction*>* post_order, HloInstruction* root,
      absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;

  Status RemoveUnusedParametersImpl(bool allow_non_fusion);

  Status RemoveInstructionImpl(HloInstruction* instruction,
                               bool ignore_safety_check);

  string name_;
  int64 unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction. Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  // Removed instructions are moved into to_be_deleted_ first and then
  // deallocated when Cleanup is called.
  std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

template <typename HloInstructionPtr>
Status HloComputation::Accept(
    DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
  // Visit unreachable roots. Beware that the visitor might delete the
  // currently visited root, which would invalidate iterators if the
  // unreachable roots weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    VLOG(3) << "Traversing unreachable root: " << root->ToString();
    // Call FinishVisit only at the end.
    TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
}

// Explicit instantiations.
template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;

template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered(
    DfsHloVisitorBase<HloInstructionPtr>* visitor,
    absl::Span<HloInstruction* const> order) const {
  VLOG(3) << "Accepting visitor with order.";
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
  }
  TF_RET_CHECK(order.size() == instruction_count());
  absl::flat_hash_set<const HloInstruction*> visited;
  for (const HloInstruction* instruction : order) {
    VLOG(3) << "Visiting ordered: " << instruction->ToString();
    TF_RET_CHECK(instruction_iterators_.contains(instruction))
        << "Instruction " << instruction->name() << " is not in computation "
        << name();
    TF_RET_CHECK(!visited.contains(instruction))
        << "Instruction " << instruction->name()
        << " appears more than once in order";
    HloInstruction* mutable_instruction =
        const_cast<HloInstruction*>(instruction);
    TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
    TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
    visitor->SetVisited(*mutable_instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
    visited.insert(instruction);
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
  return Status::OK();
}

// Explicit instantiations.
template Status HloComputation::AcceptOrdered(
    DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
template Status HloComputation::AcceptOrdered(
    ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_