• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 class HloModule;
48 
// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function.  It has some inputs
// (parameters) and returns exactly one value (the value of its root node).  If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total order.
// Instead, they have a partial order determined by their data and control
// dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program.  Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation".  For example, the kMap
// instruction has a nested computation and "applies" it to every element of its
// input, elementwise.  (That is, the input [x, y, z] is transformed to [f(x),
// f(y), f(z)].)
65 class HloComputation {
66  public:
67   // Builder class for HloComputation.
68   class Builder {
69    public:
70     explicit Builder(const string& name,
71                      HloInstruction* fusion_instruction = nullptr)
name_(name)72         : name_(name),
73           last_added_instruction_(nullptr),
74           fusion_instruction_(fusion_instruction) {}
75 
76     // Build and return an HloComputation. The parameter root_instruction
77     // specifies the already-added instruction to use as the root. If
78     // root_instruction is nullptr then use the last added instruction as the
79     // root.
80     std::unique_ptr<HloComputation> Build(
81         HloInstruction* root_instruction = nullptr);
82 
AddInstruction(std::unique_ptr<HloInstruction> instruction)83     HloInstruction* AddInstruction(
84         std::unique_ptr<HloInstruction> instruction) {
85       instructions_.push_back(std::move(instruction));
86       last_added_instruction_ = instructions_.back().get();
87       return last_added_instruction_;
88     }
89 
ForEachInstruction(const std::function<Status (const HloInstruction *)> & func)90     Status ForEachInstruction(
91         const std::function<Status(const HloInstruction*)>& func) const {
92       for (const auto& instruction : instructions_) {
93         TF_RETURN_IF_ERROR(func(instruction.get()));
94       }
95       return Status::OK();
96     }
97 
98    private:
99     const string name_;
100     HloInstruction* last_added_instruction_;
101     HloInstruction* fusion_instruction_;
102     std::vector<std::unique_ptr<HloInstruction>> instructions_;
103   };
104 
105   // Helper class to automatically set the OpMetadata for every instruction
106   // added to a computation.
107   class MetadataBuilder {
108    public:
MetadataBuilder(HloComputation * computation,const OpMetadata & metadata)109     MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
110         : computation_(computation), metadata_(metadata) {}
111 
AddInstruction(std::unique_ptr<HloInstruction> instruction)112     HloInstruction* AddInstruction(
113         std::unique_ptr<HloInstruction> instruction) {
114       instruction->set_metadata(metadata_);
115       return computation_->AddInstruction(std::move(instruction));
116     }
117 
118    private:
119     HloComputation* computation_;
120     OpMetadata metadata_;
121   };
122 
123   // Add an instruction to the computation. The computation takes ownership of
124   // the instruction.
125   HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
126 
127   // Remove the param_no'th parameter from the computation.
128   // Note this is only applicatable to the computation for the fusion
129   // instruction.
130   Status RemoveParameter(int64 param_no);
131 
132   // Remove unused parameters from the computation.
133   // Note this is only applicatable to the computation for the fusion
134   // instruction.
135   Status RemoveUnusedParametersFromFusedComputation();
136 
137   // Remove unused parameters from the computation. Unlike
138   // RemoveUnusedParametersFromFusedComputation, this function can be used
139   // to remove parameters from non-fusion computations.
140   Status RemoveUnusedParametersFromAnyComputation();
141 
142   // Adds a new parameter instruction to a fusion computation.
143   //
144   // This should be a new parameter. Instruction will be appended to parameters
145   // and inserted to the instruction list.
146   HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
147 
148   // Adds a new parameter instruction to the entry computation and update
149   // the parent module config to reflect the change.
150   //
151   // This should be a new parameter. Instruction will be appended to parameters
152   // and inserted to the instruction list.
153   HloInstruction* AddEntryComputationParameter(
154       std::unique_ptr<HloInstruction> instruction);
155 
156   // Replaces an old parameter with a new parameter. Adds the new parameter
157   // instruction to the entry computation.
158   Status ReplaceEntryComputationParameter(
159       int64 param_no, HloInstruction* old_instruction,
160       std::unique_ptr<HloInstruction> instruction);
161 
162   // Remove an instruction from the computation. The instruction must have no
163   // users. Instruction is deallocated with this call.
164   Status RemoveInstruction(HloInstruction* instruction);
165 
166   // Removes an instruction from the computation. The instruction must have no
167   // users. Instruction is deallocated with this call. The instruction will be
168   // removed even if it is marked as not removable.
169   Status ForceRemoveInstruction(HloInstruction* instruction);
170 
171   // Remove an instruction (including side effecting ones) from the computation
172   // and also transitively any operand that has no side effect and no users post
173   // removing an instruction. The instruction must have no users. Instruction is
174   // deallocated with this call. If given, the cleanup routine is executed on a
175   // removed instruction before its deallocation.
176   Status RemoveInstructionAndUnusedOperands(
177       HloInstruction* instruction,
178       std::function<void(HloInstruction*)> cleanup = nullptr);
179 
180   // Set the root of the computation to the given instruction. The instruction
181   // must have already been added to the computation. In addition it must have
182   // the same shape as the result of the computation for non fusion
183   // computations, except if accept_different_shape is set to true.
184   void set_root_instruction(HloInstruction* new_root_instruction,
185                             bool accept_different_shape = false);
186 
187   // Return the root instruction of the computation. The root instruction is the
188   // instruction which produces the output of the computation.
root_instruction()189   HloInstruction* root_instruction() const { return root_instruction_; }
190 
191   // Returns the number of parameters for this computation.
num_parameters()192   int64 num_parameters() const { return param_instructions_.size(); }
193 
194   // Returns the parameter instruction for the given parameter number.
parameter_instruction(int64 param_no)195   HloInstruction* parameter_instruction(int64 param_no) const {
196     CHECK_GE(param_no, 0);
197     CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
198         << "Computation " << name() << " has no parameter number " << param_no;
199     return param_instructions_[param_no];
200   }
201 
parameter_instructions()202   const std::vector<HloInstruction*>& parameter_instructions() const {
203     return param_instructions_;
204   }
205 
name()206   const string& name() const { return name_; }
207 
208   // Use the given NameUniquer to select a unique name for the computation based
209   // on the computation's existing name.
210   void UniquifyName(NameUniquer* name_uniquer);
211 
212   // Return a string representation of the computation.
213   //
214   // (We express the default options using an overload rather than a default
215   // param because gdb ignores default params, but does resolve overloads.)
ToString()216   string ToString() const { return ToString(HloPrintOptions()); }
217   string ToString(const HloPrintOptions& options) const;
218 
219   // Overload which accepts an order to emit the instructions in.
220   string ToString(
221       const HloPrintOptions& options,
222       absl::Span<const HloInstruction* const> instruction_order) const;
223 
224   // Returns a serialized representation of this computation.
225   HloComputationProto ToProto() const;
226 
227   // Creates a computation from the given proto. Arguments:
228   //
229   //   proto: the proto to convert from.
230   //   computation_map: a map from computation id to HloComputation*. This map
231   //     must contain all computations which the newly constructed computation
232   //     calls.
233   static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
234       const HloComputationProto& proto,
235       const absl::flat_hash_map<int64, HloComputation*>& computation_map,
236       bool prohibit_empty_literal = true);
237 
238   using InstructionSequence = tensorflow::gtl::iterator_range<
239       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;
240 
241   using ConstInstructionSequence =
242       tensorflow::gtl::iterator_range<UnwrappingIterator<
243           std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;
244 
245   // Gets the instructions in this computation.
246   //
247   // The returned type is a range of HloInstruction*s, so you can iterate over
248   // it using a range-based for loop in the natural way:
249   //
250   //   for (HloInstruction* instr : computation->instructions()) { ... }
251   //
instructions()252   ConstInstructionSequence instructions() const {
253     return {MakeUnwrappingIterator(instructions_.begin()),
254             MakeUnwrappingIterator(instructions_.end())};
255   }
instructions()256   InstructionSequence instructions() {
257     return {MakeUnwrappingIterator(instructions_.begin()),
258             MakeUnwrappingIterator(instructions_.end())};
259   }
260 
261   // Compute and return a post-order of the instructions in the computation. In
262   // this order, definitions of values always appear before their uses.
263   std::vector<HloInstruction*> MakeInstructionPostOrder() const;
264 
instruction_count()265   int64 instruction_count() const { return instruction_iterators_.size(); }
266 
267   // Creates and returns a list of the embedded computations called by this
268   // computation. This includes all embedded computations called directly or
269   // transitively. The embedded computations are sorted such that if computation
270   // A calls computation B (eg, via a map instruction) then A will appear after
271   // B in the list.
272   std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
273 
274   // Creates a fusion instruction containing the given instructions.
275   // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
276   // into a library call. Instructions must be in reverse topological order
277   // (root of the fused expression first). Replaces all uses of the original
278   // root instruction with the fusion instruction. The original instructions are
279   // removed if they have no uses after fusion (this is necessarily true for at
280   // least the root).
281   HloInstruction* CreateFusionInstruction(
282       absl::Span<HloInstruction* const> instructions_to_fuse,
283       HloInstruction::FusionKind fusion_kind);
284 
285   // Create a deep copy of the given instruction and return the instruction
286   // producing the copied result. All instructions performing the copy are added
287   // to the computation. For array-shaped values, this method trivially returns
288   // a kCopy instruction. For tuple-shaped instructions, the copy is performed
289   // with a series of kGetTupleElement and kTuple instructions. If
290   // indices_to_copy is non-null then this ShapeTree indicates which elements
291   // (arrays) of the shape to copy. Non-copied elements are passed through
292   // transparently. If copies_added is non-null, then the added kCopy
293   // instructions will be inserted in the respective index in the given
294   // ShapeTree.
295   StatusOr<HloInstruction*> DeepCopyInstruction(
296       HloInstruction* instruction,
297       const ShapeTree<bool>* indices_to_copy = nullptr,
298       ShapeTree<HloInstruction*>* copies_added = nullptr);
299 
300   // As above, but uses a custom function to copy the leaf nodes, which could
301   // create alternative HLOs other than kCopy, or even pass-throughs.
302   StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
303       HloInstruction* instruction,
304       const std::function<
305           HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
306                           HloComputation* computation)>& copy_leaf);
307 
308   // Computes and returns the ProgramShape of this computation (shape of
309   // parameters and result with layout).
310   ProgramShape ComputeProgramShape(bool include_ids = true) const;
311 
312   // Return whether `*this` and `other` are functionally equivalent.
313   bool Equal(const HloComputation& other, bool is_layout_sensitive) const;
314 
315   // Return whether `*this` and `other` are functionally equivalent.
316   bool operator==(const HloComputation& other) const {
317     return Equal(other, true);
318   }
319 
320   // Replaces old instruction with newly created instruction. Removes old
321   // instruction from computation. Updates uses and root instruction.
322   Status ReplaceWithNewInstruction(
323       HloInstruction* old_instruction,
324       std::unique_ptr<HloInstruction> new_instruction);
325 
326   // Replaces an old instruction with a newly created instruction, and adds the
327   // new instruction as an entry computation's parameter. Removes old
328   // instruction from computation. Updates uses and root instruction.
329   Status ReplaceWithNewEntryComputationParameter(
330       HloInstruction* old_instruction,
331       std::unique_ptr<HloInstruction> new_instruction);
332 
333   // Replace old instruction with new instruction.  Updates uses and root
334   // instruction. Removes old instruction from computation. Precondition:
335   // old_instruction and new_instruction must have the compatible shapes.
336   // If |new_instruction| doesn't have any sharding information it will
337   // receive the sharding information of |old_instruction|.
338   Status ReplaceInstruction(HloInstruction* old_instruction,
339                             HloInstruction* new_instruction);
340 
341   // Set/get the module containing this computation.
set_parent(HloModule * module)342   void set_parent(HloModule* module) { parent_ = module; }
parent()343   const HloModule* parent() const { return parent_; }
parent()344   HloModule* parent() { return parent_; }
345 
346   // Visit every node in the computation in DFS post-order with the given
347   // visitor. This is similar to calling HloInstruction::Accept on the root of
348   // the computation except this method also visits instructions not reachable
349   // via the root. The root instruction of the computation is visited last, and
350   // the visitor's FinishVisit method is called once upon completion (with the
351   // root instruction as the argument).
352   template <typename HloInstructionPtr>
353   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;
354 
355   // Same as Accept() above, but the order of operand and control predecessor
356   // visitation is determined by the given operand order; if compare(A, B) ==
357   // true, A is visited before B.
358   Status AcceptWithOperandOrder(
359       DfsHloVisitor* visitor,
360       const HloInstruction::CompareFunction& operand_order) const;
361 
362   // Visit every node in the computation in the given order. 'order' must
363   // be a topological sort of all instructions in the computation.
364   template <typename HloInstructionPtr>
365   Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
366                        absl::Span<HloInstruction* const> order) const;
367 
368   // Returns a deep copy of this computation including all instructions.
369   // If the clone context is specified, it will be populated with the cloned
370   // object mappings, and its module() will be used to add new computations
371   // into.
372   std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
373                                         HloCloneContext* context = nullptr);
374 
375   // Like Clone(), but if an instruction is present in replacement_map, we use
376   // the map's value to replace that instruction in the cloned computation.
377   //
378   // If replacements maps a key to nullptr, we remove that instruction from the
379   // new computation.  If an element of `replacements` references an instruction
380   // that's not already in the computation, it's cloned and added to the new
381   // computation.
382   //
383   // 'extra_parameters' allows to specify additional parameters that should be
384   // added to the computation.
385   //
386   // All relevant instructions are cloned, *including* unique_ptr in the
387   // `replacements` map.
388   std::unique_ptr<HloComputation> CloneWithReplacements(
389       absl::flat_hash_map<const HloInstruction*,
390                           std::unique_ptr<HloInstruction>>
391           replacements,
392       absl::Span<const HloInstruction* const> extra_parameters = {},
393       HloCloneContext* context = nullptr, const string& suffix = "clone");
394 
395   // Convenience overloads for CloneWithReplacements.  You want to do
396   //
397   //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
398   //
399   // but that doesn't work because std::initializer_list is not movable.  These
400   // overloads let you do
401   //
402   //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});   // OK
403   //
404   std::unique_ptr<HloComputation> CloneWithReplacementPairs(
405       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
406       HloCloneContext* context = nullptr, const string& suffix = "clone");
407   std::unique_ptr<HloComputation> CloneWithReplacementPairs(
408       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
409       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
410       HloCloneContext* context = nullptr, const string& suffix = "clone");
411   std::unique_ptr<HloComputation> CloneWithReplacementPairs(
412       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
413       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
414       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
415       HloCloneContext* context = nullptr, const string& suffix = "clone");
416 
417   // Returns true if the given instruction can be removed from the computation.
418   // Parameter instructions cannot be removed without violating invariants of
419   // the HLO computation with the exception of fusion computation. A parameter
420   // instruction is removable for a fusion computation.
421   //
422   // Note that IsSafelyRemovable() is a necessary condition to remove an
423   // instruction rather than a sufficient condition. For example, instructions
424   // with side-effect (e.g., Send, Infeed) may be removed from a computation,
425   // but the transformation must guarantee the invariants relevant to the
426   // instructions still hold (e.g., Send and Recv must be removed together to
427   // make each channel complete).
428   bool IsSafelyRemovable(const HloInstruction* instruction);
429 
430   // Returns a map from channel-id to the group of instructions associated with
431   // the channel. These instructions will be considered as a single node for
432   // dependency purposes. Send and RecvDone are in the group, and AllReduces
433   // with the same channel id are in the group.
434   using ChannelDependencyGroup =
435       absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
436   ChannelDependencyGroup ComputeChannelDependencies() const;
437 
438   // Returns true if this computation has a side effect. A computation has a
439   // side effect if it contains one or more instructions with a side effect.
440   bool HasSideEffect() const;
441 
442   // Returns if this computation is a fusion computation.
IsFusionComputation()443   bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }
444 
445   // Returns if this computation is the entry computation of the module.
446   bool IsEntryComputation() const;
447 
448   // Returns the owning fusion instruction, or nullptr if this is not a fusion
449   // computation.
FusionInstruction()450   HloInstruction* FusionInstruction() const { return fusion_instruction_; }
SetFusionInstruction(HloInstruction * fusion_instruction)451   void SetFusionInstruction(HloInstruction* fusion_instruction) {
452     fusion_instruction_ = fusion_instruction;
453   }
454 
455   // Clear the unique ID of the computation so that it can be re-assigned, such
456   // as for the purpose of compacting the unique IDs.
ClearUniqueIdInternal()457   void ClearUniqueIdInternal() { unique_id_ = -1; }
458 
459   // The id of this computation should be unique within the module.
SetUniqueId(int64 id)460   void SetUniqueId(int64 id) {
461     CHECK_EQ(unique_id_, -1);
462     CHECK_GE(id, 0);
463     unique_id_ = id;
464   }
465 
466   // Returns the instruction in this computation that has name `name`.  Returns
467   // null if there is no such computation.
468   HloInstruction* GetInstructionWithName(absl::string_view name);
469 
unique_id()470   int64 unique_id() const { return unique_id_; }
471 
472  private:
473   explicit HloComputation(
474       const string& name, int parameter_count,
475       std::vector<std::unique_ptr<HloInstruction>>* instructions,
476       HloInstruction* root_instruction, HloInstruction* fusion_instruction);
477 
478   // Internal helper for adding instructions.
479   HloInstruction* AddInstructionInternal(
480       std::unique_ptr<HloInstruction> instruction);
481 
482   // Fuses HLOs in instructions_to_fuse into fusion_instruction.
483   //
484   // Pre-condition: fusion_instruction's opcode is kFusion.
485   void FuseInstructionsInto(
486       absl::Span<HloInstruction* const> instructions_to_fuse,
487       HloInstruction* fusion_instruction);
488 
489   // Internal helper for recursive copying of an instruction. Creates and
490   // returns a deep copy of the given instruction.
491   StatusOr<HloInstruction*> DeepCopyHelper(
492       HloInstruction* instruction, ShapeIndex* index,
493       const std::function<
494           HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
495                           HloComputation* computation)>& copy_leaf);
496 
497   // Internal helper to collect unreachable roots.
498   std::vector<HloInstruction*> CollectUnreachableRoots() const;
499 
500   enum VisitState { kVisiting, kVisited };
501   void ComputeInstructionPostOrder(
502       const HloComputation::ChannelDependencyGroup& channel_dependency_map,
503       std::vector<HloInstruction*>* post_order, HloInstruction* root,
504       absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;
505 
506   Status RemoveUnusedParametersImpl(bool allow_non_fusion);
507 
508   Status RemoveInstructionImpl(HloInstruction* instruction,
509                                bool ignore_safety_check);
510 
511   string name_;
512   int64 unique_id_;
513   HloInstruction* root_instruction_;
514 
515   // If this computation is a fusion computation, this field points to the
516   // corresponding fusion instruction.  Otherwise, this is null.
517   HloInstruction* fusion_instruction_;
518 
519   // Module containing this computation.
520   HloModule* parent_ = nullptr;
521 
522   // Store instructions in std::list as they can be added and removed
523   // arbitrarily and we want a stable iteration order. Keep a map from
524   // instruction pointer to location in the list for fast lookup.
525   using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
526   InstructionList instructions_;
527   absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
528       instruction_iterators_;
529 
530   std::vector<HloInstruction*> param_instructions_;
531 
532   TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
533 };
534 
535 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor)536 Status HloComputation::Accept(
537     DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
538   // Visit unreachable roots. Beware that the visitor might delete the currently
539   // visited root, which would invalidate iterators if the unreachable roots
540   // weren't computed ahead of time.
541   for (HloInstruction* root : CollectUnreachableRoots()) {
542     VLOG(3) << "Traversing unreachable root: " << root->ToString();
543     // Call FinishVisit only at the end.
544     TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
545   }
546   // Visit the computation root instruction last.
547   return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
548 }
549 
550 // Explicit instantiations.
551 template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
552 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
553 
554 template <typename HloInstructionPtr>
AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr> * visitor,absl::Span<HloInstruction * const> order)555 Status HloComputation::AcceptOrdered(
556     DfsHloVisitorBase<HloInstructionPtr>* visitor,
557     absl::Span<HloInstruction* const> order) const {
558   VLOG(3) << "Accepting visitor with order.";
559   for (HloInstruction* root : CollectUnreachableRoots()) {
560     TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
561   }
562   TF_RET_CHECK(order.size() == instruction_count());
563   absl::flat_hash_set<const HloInstruction*> visited;
564   for (const HloInstruction* instruction : order) {
565     VLOG(3) << "Visiting ordered: " << instruction->ToString();
566     TF_RET_CHECK(instruction_iterators_.contains(instruction))
567         << "Instruction " << instruction->name() << " is not in computation "
568         << name();
569     TF_RET_CHECK(!visited.contains(instruction))
570         << "Instruction " << instruction->name()
571         << " appears more than once in order";
572     HloInstruction* mutable_instruction =
573         const_cast<HloInstruction*>(instruction);
574     TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
575     TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
576     visitor->SetVisited(*mutable_instruction);
577     TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
578     visited.insert(instruction);
579   }
580   TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
581   return Status::OK();
582 }
583 
584 // Explicit instantiations.
585 template Status HloComputation::AcceptOrdered(
586     DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
587 template Status HloComputation::AcceptOrdered(
588     ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
589 
590 }  // namespace xla
591 
592 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
593