/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function. It has some inputs
// (parameters) and returns exactly one value (the value of its root node). If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total order.
// Instead, they have a partial order determined by their data and control
// dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program. Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation". For example, the kMap
// instruction has a nested computation and "applies" it to every element of its
// input, elementwise. (That is, the input [x, y, z] is transformed to [f(x),
// f(y), f(z)].)
class HloComputation {
 public:
  // Builder class for HloComputation.
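  //
  // Example usage (an illustrative sketch; `shape` is a hypothetical
  // array shape):
  //
  //   auto builder = HloComputation::Builder("example");
  //   HloInstruction* x = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, shape, "x"));
  //   HloInstruction* neg = builder.AddInstruction(
  //       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, x));
  //   std::unique_ptr<HloComputation> computation = builder.Build(neg);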
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };

  // Helper class to automatically set the OpMetadata for every instruction
  // added to a computation.
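  //
  // For example (an illustrative sketch; the op_name value is arbitrary):
  //
  //   OpMetadata metadata;
  //   metadata.set_op_name("my_op");
  //   HloComputation::MetadataBuilder builder(computation, metadata);
  //   builder.AddInstruction(...);  // the added instruction carries `metadata`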
  class MetadataBuilder {
   public:
    MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
        : computation_(computation), metadata_(metadata) {}

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instruction->set_metadata(metadata_);
      return computation_->AddInstruction(std::move(instruction));
    }

   private:
    HloComputation* computation_;
    OpMetadata metadata_;
  };

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to fusion computations.
  Status RemoveParameter(int64 param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to fusion computations.
  Status RemoveUnusedParametersFromFusedComputation();

  // Remove unused parameters from the computation. Unlike
  // RemoveUnusedParametersFromFusedComputation, this function can be used
  // to remove parameters from non-fusion computations.
  Status RemoveUnusedParametersFromAnyComputation();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameter list and inserted into the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and updates
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameter list and inserted into the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Replaces an old parameter with a new parameter. Adds the new parameter
  // instruction to the entry computation.
  Status ReplaceEntryComputationParameter(
      int64 param_no, HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. The instruction is deallocated by this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Removes an instruction from the computation. The instruction must have no
  // users. The instruction is deallocated by this call. The instruction will
  // be removed even if it is marked as not removable.
  Status ForceRemoveInstruction(HloInstruction* instruction);

  // Removes an instruction (including side-effecting ones) from the
  // computation, and also transitively removes any operand that has no side
  // effect and no users after the instruction is removed. The instruction
  // must have no users. The instruction is deallocated by this call. If
  // given, the cleanup routine is executed on each removed instruction before
  // its deallocation.
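  //
  // For example (an illustrative sketch; `dead` is a hypothetical instruction
  // with no users):
  //
  //   TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(
  //       dead, [](HloInstruction* removed) {
  //         VLOG(1) << "Removing " << removed->name();
  //       }));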
  Status RemoveInstructionAndUnusedOperands(
      HloInstruction* instruction,
      std::function<void(HloInstruction*)> cleanup = nullptr);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation. In addition it must have
  // the same shape as the result of the computation for non-fusion
  // computations, except if accept_different_shape is set to true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  using InstructionSequence = tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

  using ConstInstructionSequence =
      tensorflow::gtl::iterator_range<UnwrappingIterator<
          std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  ConstInstructionSequence instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  InstructionSequence instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
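  //
  // For example, you can iterate in dependency order like this:
  //
  //   for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
  //     // All operands of `instr` have already been visited here.
  //   }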
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;

  int64 instruction_count() const { return instruction_iterators_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if
  // computation A calls computation B (e.g., via a map instruction) then A
  // will appear after B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
  // into a library call. Instructions must be in reverse topological order
  // (root of the fused expression first). Replaces all uses of the original
  // root instruction with the fusion instruction. The original instructions are
  // removed if they have no uses after fusion (this is necessarily true for at
  // least the root).
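  //
  // For example (an illustrative sketch; `add` is the root of the expression
  // to fuse and `mul` is its operand):
  //
  //   computation->CreateFusionInstruction(
  //       {add, mul}, HloInstruction::FusionKind::kLoop);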
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are
  // added to the computation. For array-shaped values, this method trivially
  // returns a kCopy instruction. For tuple-shaped instructions, the copy is
  // performed with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
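  //
  // For example, a copier that passes every leaf through unchanged (an
  // illustrative sketch):
  //
  //   computation->DeepCopyInstructionWithCustomCopier(
  //       instruction,
  //       [](HloInstruction* leaf, const ShapeIndex& leaf_index,
  //          HloComputation* computation) { return leaf; });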
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape(bool include_ids = true) const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool Equal(const HloComputation& other, bool is_layout_sensitive) const;

  // Same as Equal() with is_layout_sensitive set to true.
  bool operator==(const HloComputation& other) const {
    return Equal(other, true);
  }

  // Replaces an old instruction with a newly created instruction. Removes the
  // old instruction from the computation. Updates uses and root instruction.
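  //
  // For example (an illustrative sketch; `old_instr` and `operand` are
  // hypothetical instructions already in the computation):
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       old_instr, HloInstruction::CreateUnary(old_instr->shape(),
  //                                              HloOpcode::kNegate, operand)));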
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replaces an old instruction with a newly created instruction, and adds the
  // new instruction as an entry computation's parameter. Removes the old
  // instruction from the computation. Updates uses and root instruction.
  Status ReplaceWithNewEntryComputationParameter(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction. Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  // If |new_instruction| doesn't have any sharding information it will
  // receive the sharding information of |old_instruction|.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
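  //
  // For example (an illustrative sketch; MyVisitor is a hypothetical subclass
  // of DfsHloVisitorWithDefault):
  //
  //   MyVisitor visitor;
  //   TF_RETURN_IF_ERROR(computation->Accept(&visitor));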
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;

  // Returns a deep copy of this computation including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and its module() will be the module into which new
  // computations are added.
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in `replacements`, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation. If an element of `replacements` references an instruction
  // that's not already in the computation, it's cloned and added to the new
  // computation.
  //
  // 'extra_parameters' allows specifying additional parameters that should be
  // added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptr in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const string& suffix = "clone");

  // Convenience overloads for CloneWithReplacements. You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable. These
  // overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});  // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const string& suffix = "clone");

  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation, with the exception of fusion computations: a
  // parameter instruction is removable from a fusion computation.
  //
  // Note that IsSafelyRemovable() is a necessary condition to remove an
  // instruction rather than a sufficient condition. For example, instructions
  // with side effects (e.g., Send, Infeed) may be removed from a computation,
  // but the transformation must guarantee that the invariants relevant to the
  // instructions still hold (e.g., Send and Recv must be removed together to
  // make each channel complete).
  bool IsSafelyRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated with
  // the channel. These instructions will be considered as a single node for
  // dependency purposes. Send and RecvDone are in the group, and AllReduces
  // with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns whether this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns whether this computation is the entry computation of the module.
  bool IsEntryComputation() const;

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

  // Clear the unique ID of the computation so that it can be re-assigned, such
  // as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64 id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`. Returns
  // nullptr if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64 unique_id() const { return unique_id_; }

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      const HloComputation::ChannelDependencyGroup& channel_dependency_map,
      std::vector<HloInstruction*>* post_order, HloInstruction* root,
      absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;

  Status RemoveUnusedParametersImpl(bool allow_non_fusion);

  Status RemoveInstructionImpl(HloInstruction* instruction,
                               bool ignore_safety_check);

  string name_;
  int64 unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction. Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

template <typename HloInstructionPtr>
Status HloComputation::Accept(
    DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
  // Visit unreachable roots. Beware that the visitor might delete the
  // currently visited root, which would invalidate iterators if the
  // unreachable roots weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    VLOG(3) << "Traversing unreachable root: " << root->ToString();
    // Call FinishVisit only at the end.
    TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
}

// Explicit instantiations.
template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;

template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered(
    DfsHloVisitorBase<HloInstructionPtr>* visitor,
    absl::Span<HloInstruction* const> order) const {
  VLOG(3) << "Accepting visitor with order.";
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
  }
  TF_RET_CHECK(order.size() == instruction_count());
  absl::flat_hash_set<const HloInstruction*> visited;
  for (const HloInstruction* instruction : order) {
    VLOG(3) << "Visiting ordered: " << instruction->ToString();
    TF_RET_CHECK(instruction_iterators_.contains(instruction))
        << "Instruction " << instruction->name() << " is not in computation "
        << name();
    TF_RET_CHECK(!visited.contains(instruction))
        << "Instruction " << instruction->name()
        << " appears more than once in order";
    HloInstruction* mutable_instruction =
        const_cast<HloInstruction*>(instruction);
    TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
    TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
    visitor->SetVisited(*mutable_instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
    visited.insert(instruction);
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
  return Status::OK();
}

// Explicit instantiations.
template Status HloComputation::AcceptOrdered(
    DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
template Status HloComputation::AcceptOrdered(
    ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_