• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 class HloModule;
48 
// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function.  It has some inputs
// (parameters) and returns exactly one value (the value of its root node).  If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total order.
// Instead, they have a partial order determined by their data and control
// dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program.  Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation".  For example, the kMap
// instruction has a nested computation and "applies" it to every element of its
// input, elementwise.  (That is, the input [x, y, z] is transformed to [f(x),
// f(y), f(z)].)
65 class HloComputation {
66  public:
67   // Builder class for HloComputation.
68   class Builder {
69    public:
70     explicit Builder(const string& name,
71                      HloInstruction* fusion_instruction = nullptr)
name_(name)72         : name_(name),
73           last_added_instruction_(nullptr),
74           fusion_instruction_(fusion_instruction) {}
75     Builder(Builder&& b) = default;
76     virtual ~Builder() = default;
77 
78     // Build and return an HloComputation. The parameter root_instruction
79     // specifies the already-added instruction to use as the root. If
80     // root_instruction is nullptr then use the last added instruction as the
81     // root.
82     std::unique_ptr<HloComputation> Build(
83         HloInstruction* root_instruction = nullptr);
84 
AddInstruction(std::unique_ptr<HloInstruction> instruction)85     virtual HloInstruction* AddInstruction(
86         std::unique_ptr<HloInstruction> instruction) {
87       instructions_.push_back(std::move(instruction));
88       last_added_instruction_ = instructions_.back().get();
89       return last_added_instruction_;
90     }
91 
ForEachInstruction(const std::function<Status (const HloInstruction *)> & func)92     Status ForEachInstruction(
93         const std::function<Status(const HloInstruction*)>& func) const {
94       for (const auto& instruction : instructions_) {
95         TF_RETURN_IF_ERROR(func(instruction.get()));
96       }
97       return Status::OK();
98     }
99 
100    private:
101     const string name_;
102     HloInstruction* last_added_instruction_;
103     HloInstruction* fusion_instruction_;
104     std::vector<std::unique_ptr<HloInstruction>> instructions_;
105 
106     TF_DISALLOW_COPY_AND_ASSIGN(Builder);
107   };
108 
109   // Helper class to automatically set the OpMetadata for every instruction
110   // added to a computation.
111   class MetadataBuilder {
112    public:
MetadataBuilder(HloComputation * computation,const OpMetadata & metadata)113     MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
114         : computation_(computation), metadata_(metadata) {}
115 
AddInstruction(std::unique_ptr<HloInstruction> instruction)116     HloInstruction* AddInstruction(
117         std::unique_ptr<HloInstruction> instruction) {
118       instruction->set_metadata(metadata_);
119       return computation_->AddInstruction(std::move(instruction));
120     }
121 
122    private:
123     HloComputation* computation_;
124     OpMetadata metadata_;
125   };
126 
127   ~HloComputation();
128 
129   // Add an instruction to the computation. The computation takes ownership of
130   // the instruction.
131   HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
132                                  const std::string& new_name = "");
133 
134   // Remove the param_no'th parameter from the computation.
135   // Note this is only applicatable to the computation for the fusion
136   // instruction.
137   Status RemoveParameter(int64_t param_no);
138 
139   // Remove unused parameters from the computation.
140   // Note this is only applicatable to the computation for the fusion
141   // instruction.
142   Status RemoveUnusedParametersFromFusedComputation();
143 
144   // Remove unused parameters from the computation. Unlike
145   // RemoveUnusedParametersFromFusedComputation, this function can be used
146   // to remove parameters from non-fusion computations.
147   Status RemoveUnusedParametersFromAnyComputation();
148 
149   // Adds a new parameter instruction to a fusion computation.
150   //
151   // This should be a new parameter. Instruction will be appended to parameters
152   // and inserted to the instruction list.
153   HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
154 
155   // Adds a new parameter instruction to the entry computation and update
156   // the parent module config to reflect the change.
157   //
158   // This should be a new parameter. Instruction will be appended to parameters
159   // and inserted to the instruction list.
160   HloInstruction* AddEntryComputationParameter(
161       std::unique_ptr<HloInstruction> instruction);
162 
163   // Replaces an old parameter with a new parameter. Adds the new parameter
164   // instruction to the entry computation.
165   Status ReplaceEntryComputationParameter(
166       int64_t param_no, HloInstruction* old_instruction,
167       std::unique_ptr<HloInstruction> instruction);
168 
169   // Remove an instruction from the computation. The instruction must have no
170   // users. Instruction is deallocated with this call.
171   Status RemoveInstruction(HloInstruction* instruction);
172 
173   // Removes an instruction from the computation. The instruction must have no
174   // users. Instruction is deallocated with this call. The instruction will be
175   // removed even if it is marked as not removable.
176   Status ForceRemoveInstruction(HloInstruction* instruction);
177 
178   // Remove an instruction (including side effecting ones) from the computation
179   // and also transitively any operand that has no side effect and no users post
180   // removing an instruction. The instruction must have no users. Instruction is
181   // deallocated with this call. If given, the cleanup routine is executed on a
182   // removed instruction before its deallocation.
183   Status RemoveInstructionAndUnusedOperands(
184       HloInstruction* instruction,
185       std::function<void(HloInstruction*)> cleanup = nullptr);
186 
187   // Set the root of the computation to the given instruction. The instruction
188   // must have already been added to the computation. In addition it must have
189   // the same shape as the result of the computation for non fusion
190   // computations, except if accept_different_shape is set to true.
191   void set_root_instruction(HloInstruction* new_root_instruction,
192                             bool accept_different_shape = false);
193 
194   // Return the root instruction of the computation. The root instruction is the
195   // instruction which produces the output of the computation.
root_instruction()196   HloInstruction* root_instruction() const { return root_instruction_; }
197 
198   // Returns the number of parameters for this computation.
num_parameters()199   int64 num_parameters() const { return param_instructions_.size(); }
200 
201   // Returns the parameter instruction for the given parameter number.
parameter_instruction(int64_t param_no)202   HloInstruction* parameter_instruction(int64_t param_no) const {
203     CHECK_GE(param_no, 0);
204     CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
205         << "Computation " << name() << " has no parameter number " << param_no;
206     return param_instructions_[param_no];
207   }
208 
parameter_instructions()209   const std::vector<HloInstruction*>& parameter_instructions() const {
210     return param_instructions_;
211   }
212 
name()213   const string& name() const { return name_; }
214 
215   // Use the given NameUniquer to select a unique name for the computation based
216   // on the computation's existing name.
217   void UniquifyName(NameUniquer* name_uniquer);
218 
219   // Return a string representation of the computation.
220   //
221   // (We express the default options using an overload rather than a default
222   // param because gdb ignores default params, but does resolve overloads.)
ToString()223   string ToString() const { return ToString(HloPrintOptions()); }
224   string ToString(const HloPrintOptions& options) const;
225 
226   // Overload which accepts an order to emit the instructions in.
227   string ToString(
228       const HloPrintOptions& options,
229       absl::Span<const HloInstruction* const> instruction_order) const;
230 
231   // Returns a serialized representation of this computation.
232   HloComputationProto ToProto() const;
233 
234   // Creates a computation from the given proto. Arguments:
235   //
236   //   proto: the proto to convert from.
237   //   computation_map: a map from computation id to HloComputation*. This map
238   //     must contain all computations which the newly constructed computation
239   //     calls.
240   static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
241       const HloComputationProto& proto,
242       const absl::flat_hash_map<int64, HloComputation*>& computation_map,
243       bool prohibit_empty_literal = true);
244 
245   using InstructionSequence = tensorflow::gtl::iterator_range<
246       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;
247 
248   using ConstInstructionSequence =
249       tensorflow::gtl::iterator_range<UnwrappingIterator<
250           std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;
251 
252   // Gets the instructions in this computation.
253   //
254   // The returned type is a range of HloInstruction*s, so you can iterate over
255   // it using a range-based for loop in the natural way:
256   //
257   //   for (HloInstruction* instr : computation->instructions()) { ... }
258   //
instructions()259   ConstInstructionSequence instructions() const {
260     return {MakeUnwrappingIterator(instructions_.begin()),
261             MakeUnwrappingIterator(instructions_.end())};
262   }
instructions()263   InstructionSequence instructions() {
264     return {MakeUnwrappingIterator(instructions_.begin()),
265             MakeUnwrappingIterator(instructions_.end())};
266   }
267 
268   // Compute and return a post-order of the instructions in the computation. In
269   // this order, definitions of values always appear before their uses.
270   std::vector<HloInstruction*> MakeInstructionPostOrder() const;
271 
instruction_count()272   int64 instruction_count() const { return instruction_iterators_.size(); }
273 
274   // Creates and returns a list of the embedded computations called by this
275   // computation. This includes all embedded computations called directly or
276   // transitively. The embedded computations are sorted such that if computation
277   // A calls computation B (eg, via a map instruction) then A will appear after
278   // B in the list.
279   std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
280 
281   // Creates a fusion instruction containing the given instructions.
282   // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
283   // into a library call. Instructions must be in reverse topological order
284   // (root of the fused expression first). Replaces all uses of the original
285   // root instruction with the fusion instruction. The original instructions are
286   // removed if they have no uses after fusion (this is necessarily true for at
287   // least the root).
288   HloInstruction* CreateFusionInstruction(
289       absl::Span<HloInstruction* const> instructions_to_fuse,
290       HloInstruction::FusionKind fusion_kind);
291 
292   // Create a deep copy of the given instruction and return the instruction
293   // producing the copied result. All instructions performing the copy are added
294   // to the computation. For array-shaped values, this method trivially returns
295   // a kCopy instruction. For tuple-shaped instructions, the copy is performed
296   // with a series of kGetTupleElement and kTuple instructions. If
297   // indices_to_copy is non-null then this ShapeTree indicates which elements
298   // (arrays) of the shape to copy. Non-copied elements are passed through
299   // transparently. If copies_added is non-null, then the added kCopy
300   // instructions will be inserted in the respective index in the given
301   // ShapeTree.
302   StatusOr<HloInstruction*> DeepCopyInstruction(
303       HloInstruction* instruction,
304       const ShapeTree<bool>* indices_to_copy = nullptr,
305       ShapeTree<HloInstruction*>* copies_added = nullptr);
306 
307   // As above, but uses a custom function to copy the leaf nodes, which could
308   // create alternative HLOs other than kCopy, or even pass-throughs.
309   StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
310       HloInstruction* instruction,
311       const std::function<
312           HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
313                           HloComputation* computation)>& copy_leaf);
314 
315   // Computes and returns the ProgramShape of this computation (shape of
316   // parameters and result with layout).
317   ProgramShape ComputeProgramShape(bool include_ids = true) const;
318 
319   // Return whether `*this` and `other` are functionally equivalent.
Equal(const HloComputation & other,bool is_layout_sensitive)320   bool Equal(const HloComputation& other, bool is_layout_sensitive) const {
321     return EqualInternal(other, is_layout_sensitive,
322                          /*ignore_channel_id_values=*/false);
323   }
324 
325   // Same as Equal() but ignores channel ID value mismatches on instructions, as
326   // long as the two instructions both have channel IDs or neither has a channel
327   // ID.
EqualIgnoringChannelIdValues(const HloComputation & other,bool is_layout_sensitive)328   bool EqualIgnoringChannelIdValues(const HloComputation& other,
329                                     bool is_layout_sensitive) const {
330     return EqualInternal(other, is_layout_sensitive,
331                          /*ignore_channel_id_values=*/true);
332   }
333 
334   // Return whether `*this` and `other` are functionally equivalent.
335   bool operator==(const HloComputation& other) const {
336     return Equal(other, true);
337   }
338 
339   // Replaces old instruction with newly created instruction. Removes old
340   // instruction from computation. Updates uses and root instruction.
341   Status ReplaceWithNewInstruction(
342       HloInstruction* old_instruction,
343       std::unique_ptr<HloInstruction> new_instruction);
344 
345   // Replaces an old instruction with a newly created instruction, and adds the
346   // new instruction as an entry computation's parameter. Removes old
347   // instruction from computation. Updates uses and root instruction.
348   Status ReplaceWithNewEntryComputationParameter(
349       HloInstruction* old_instruction,
350       std::unique_ptr<HloInstruction> new_instruction);
351 
352   // Replace old instruction with new instruction.  Updates uses and root
353   // instruction. Removes old instruction from computation. Precondition:
354   // old_instruction and new_instruction must have the compatible shapes.
355   // If |new_instruction| doesn't have any sharding information it will
356   // receive the sharding information of |old_instruction|.
357   Status ReplaceInstruction(HloInstruction* old_instruction,
358                             HloInstruction* new_instruction);
359 
360   // Set/get the module containing this computation.
set_parent(HloModule * module)361   void set_parent(HloModule* module) { parent_ = module; }
parent()362   const HloModule* parent() const { return parent_; }
parent()363   HloModule* parent() { return parent_; }
364 
365   // Visit every node in the computation in DFS post-order with the given
366   // visitor. This is similar to calling HloInstruction::Accept on the root of
367   // the computation except this method also visits instructions not reachable
368   // via the root. The root instruction of the computation is visited last, and
369   // the visitor's FinishVisit method is called once upon completion (with the
370   // root instruction as the argument).
371   template <typename HloInstructionPtr>
372   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;
373 
374   // Same as Accept() above, but the order of operand and control predecessor
375   // visitation is determined by the given operand order; if compare(A, B) ==
376   // true, A is visited before B.
377   Status AcceptWithOperandOrder(
378       DfsHloVisitor* visitor,
379       const HloInstruction::CompareFunction& operand_order) const;
380 
381   // Visit every node in the computation in the given order. 'order' must
382   // be a topological sort of all instructions in the computation.
383   template <typename HloInstructionPtr>
384   Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
385                        absl::Span<HloInstruction* const> order) const;
386 
387   // Returns a deep copy of this computation including all instructions.
388   // If the clone context is specified, it will be populated with the cloned
389   // object mappings, and its module() will be used to add new computations
390   // into.
391   std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
392                                         HloCloneContext* context = nullptr);
393 
394   // Like Clone(), but if an instruction is present in replacement_map, we use
395   // the map's value to replace that instruction in the cloned computation.
396   //
397   // If replacements maps a key to nullptr, we remove that instruction from the
398   // new computation.  If an element of `replacements` references an instruction
399   // that's not already in the computation, it's cloned and added to the new
400   // computation.
401   //
402   // 'extra_parameters' allows to specify additional parameters that should be
403   // added to the computation.
404   //
405   // All relevant instructions are cloned, *including* unique_ptr in the
406   // `replacements` map.
407   std::unique_ptr<HloComputation> CloneWithReplacements(
408       absl::flat_hash_map<const HloInstruction*,
409                           std::unique_ptr<HloInstruction>>
410           replacements,
411       absl::Span<const HloInstruction* const> extra_parameters = {},
412       HloCloneContext* context = nullptr, const string& suffix = "clone",
413       const HloInstruction* new_root = nullptr);
414 
415   // Convenience overloads for CloneWithReplacements.  You want to do
416   //
417   //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
418   //
419   // but that doesn't work because std::initializer_list is not movable.  These
420   // overloads let you do
421   //
422   //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});   // OK
423   //
424   std::unique_ptr<HloComputation> CloneWithReplacementPairs(
425       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
426       HloCloneContext* context = nullptr, const string& suffix = "clone");
427   std::unique_ptr<HloComputation> CloneWithReplacementPairs(
428       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
429       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
430       HloCloneContext* context = nullptr, const string& suffix = "clone");
431   std::unique_ptr<HloComputation> CloneWithReplacementPairs(
432       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
433       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
434       std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
435       HloCloneContext* context = nullptr, const string& suffix = "clone");
436 
437   // Returns true if the given instruction can be removed from the computation.
438   // Parameter instructions cannot be removed without violating invariants of
439   // the HLO computation with the exception of fusion computation. A parameter
440   // instruction is removable for a fusion computation.
441   //
442   // Note that IsSafelyRemovable() is a necessary condition to remove an
443   // instruction rather than a sufficient condition. For example, instructions
444   // with side-effect (e.g., Send, Infeed) may be removed from a computation,
445   // but the transformation must guarantee the invariants relevant to the
446   // instructions still hold (e.g., Send and Recv must be removed together to
447   // make each channel complete).
448   bool IsSafelyRemovable(const HloInstruction* instruction);
449 
450   // Returns a map from channel-id to the group of instructions associated with
451   // the channel. These instructions will be considered as a single node for
452   // dependency purposes. Send and RecvDone are in the group, and AllReduces
453   // with the same channel id are in the group.
454   using ChannelDependencyGroup =
455       absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
456   ChannelDependencyGroup ComputeChannelDependencies() const;
457 
458   // Returns true if this computation has a side effect. A computation has a
459   // side effect if it contains one or more instructions with a side effect.
460   bool HasSideEffect() const;
461 
462   // Returns if this computation is a fusion computation.
463   // Do not use this method to determine if fusion_instruction_ != nullptr.
464   // Instead, directly do: FusionInstruction() != nullptr
IsFusionComputation()465   bool IsFusionComputation() const { return is_fusion_computation_; }
466 
467   // Returns if this computation is the entry computation of the module.
468   bool IsEntryComputation() const;
469 
470   // Returns the owning fusion instruction, or nullptr if this is not a fusion
471   // computation.
FusionInstruction()472   HloInstruction* FusionInstruction() const { return fusion_instruction_; }
SetFusionInstruction(HloInstruction * fusion_instruction)473   void SetFusionInstruction(HloInstruction* fusion_instruction) {
474     fusion_instruction_ = fusion_instruction;
475     is_fusion_computation_ |= (fusion_instruction != nullptr);
476   }
477 
478   // Returns if this computation is a custom-call computation.
IsCustomCallComputation()479   bool IsCustomCallComputation() const { return is_custom_call_computation_; }
480 
481   // Returns the owning custom call instruction, or nullptr if this is not a
482   // custom call computation.
CustomCallInstruction()483   HloInstruction* CustomCallInstruction() const {
484     return custom_call_instruction_;
485   }
SetCustomCallInstruction(HloInstruction * custom_call_instruction)486   void SetCustomCallInstruction(HloInstruction* custom_call_instruction) {
487     custom_call_instruction_ = custom_call_instruction;
488     is_custom_call_computation_ |= (custom_call_instruction != nullptr);
489   }
490 
491   // Returns if this computation is invoked by an Hlo instruction.
IsCalledComputation()492   bool IsCalledComputation() const {
493     return IsFusionComputation() || IsCustomCallComputation();
494   }
495 
496   // Clear the unique ID of the computation so that it can be re-assigned, such
497   // as for the purpose of compacting the unique IDs.
ClearUniqueIdInternal()498   void ClearUniqueIdInternal() { unique_id_ = -1; }
499 
500   // The id of this computation should be unique within the module.
SetUniqueId(int64_t id)501   void SetUniqueId(int64_t id) {
502     CHECK_EQ(unique_id_, -1);
503     CHECK_GE(id, 0);
504     unique_id_ = id;
505   }
506 
507   // Returns the instruction in this computation that has name `name`.  Returns
508   // null if there is no such computation.
509   HloInstruction* GetInstructionWithName(absl::string_view name);
510 
unique_id()511   int64 unique_id() const { return unique_id_; }
512 
513   // Deallocate instructions that are marked by "RemoveInstruction". The two
514   // stage clean up process is designed such that HloPass can have stable
515   // internal pointers to HloInstructions while we create and remove
516   // HloInstructions in a pass.
Cleanup()517   void Cleanup() { to_be_deleted_.clear(); }
518 
519   // Returns true if a given instruction is marked dead in this computation.
520   bool IsMarkedAsDead(const HloInstruction* inst);
521 
522  private:
523   explicit HloComputation(
524       const string& name, int parameter_count,
525       std::vector<std::unique_ptr<HloInstruction>>* instructions,
526       HloInstruction* root_instruction, HloInstruction* fusion_instruction);
527 
528   // Internal helper for adding instructions.
529   HloInstruction* AddInstructionInternal(
530       std::unique_ptr<HloInstruction> instruction);
531 
532   // Internal helper for comparison with different options.
533   bool EqualInternal(const HloComputation& other, bool is_layout_sensitive,
534                      bool ignore_channel_id_values) const;
535 
536   // Fuses HLOs in instructions_to_fuse into fusion_instruction.
537   //
538   // Pre-condition: fusion_instruction's opcode is kFusion.
539   void FuseInstructionsInto(
540       absl::Span<HloInstruction* const> instructions_to_fuse,
541       HloInstruction* fusion_instruction);
542 
543   // Internal helper for recursive copying of an instruction. Creates and
544   // returns a deep copy of the given instruction.
545   StatusOr<HloInstruction*> DeepCopyHelper(
546       HloInstruction* instruction, ShapeIndex* index,
547       const std::function<
548           HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
549                           HloComputation* computation)>& copy_leaf);
550 
551   // Internal helper to collect unreachable roots.
552   std::vector<HloInstruction*> CollectUnreachableRoots() const;
553 
554   enum VisitState { kVisiting, kVisited };
555   void ComputeInstructionPostOrder(
556       const HloComputation::ChannelDependencyGroup& channel_dependency_group,
557       std::vector<HloInstruction*>* post_order, HloInstruction* root,
558       absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;
559 
560   Status RemoveUnusedParametersImpl(bool allow_non_fusion);
561 
562   Status RemoveInstructionImpl(HloInstruction* instruction,
563                                bool ignore_safety_check);
564 
565   string name_;
566   int64 unique_id_;
567   HloInstruction* root_instruction_;
568 
569   // If this computation is a fusion computation, this field points to the
570   // corresponding fusion instruction (if it is live). Otherwise, this is null.
571   HloInstruction* fusion_instruction_;
572 
573   // Determines whether this computation is a fusion computation. A fusion
574   // computation ordinarily also has a non-null fusion_instruction_. However, if
575   // a fusion instruction is removed during compilation, the fusion computation
576   // becomes unreachable, and its fusion_instruction_ is set to null. We still
577   // need to regard such computations as fusion computations for HLO scheduling
578   // purposes.
579   bool is_fusion_computation_;
580 
581   // If this computation is a custom-call computation, this field points to the
582   // corresponding custom-call instruction (if it is live). Otherwise, this is
583   // null.
584   HloInstruction* custom_call_instruction_;
585 
586   // Determines whether this computation is a custom-call computation. A
587   bool is_custom_call_computation_;
588 
589   // Module containing this computation.
590   HloModule* parent_ = nullptr;
591 
592   // Store instructions in std::list as they can be added and removed
593   // arbitrarily and we want a stable iteration order. Keep a map from
594   // instruction pointer to location in the list for fast lookup.
595   using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
596   InstructionList instructions_;
597   absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
598       instruction_iterators_;
599 
600   // Removed instructions are moved into to_be_deleted_ first and then
601   // deallocated when Cleanup is called.
602   std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;
603 
604   std::vector<HloInstruction*> param_instructions_;
605 
606   TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
607 };
608 
609 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor)610 Status HloComputation::Accept(
611     DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
612   // Visit unreachable roots. Beware that the visitor might delete the currently
613   // visited root, which would invalidate iterators if the unreachable roots
614   // weren't computed ahead of time.
615   for (HloInstruction* root : CollectUnreachableRoots()) {
616     VLOG(3) << "Traversing unreachable root: " << root->ToString();
617     // Call FinishVisit only at the end.
618     TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
619   }
620   // Visit the computation root instruction last.
621   return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
622 }
623 
624 // Explicit instantiations.
625 template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
626 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
627 
628 template <typename HloInstructionPtr>
AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr> * visitor,absl::Span<HloInstruction * const> order)629 Status HloComputation::AcceptOrdered(
630     DfsHloVisitorBase<HloInstructionPtr>* visitor,
631     absl::Span<HloInstruction* const> order) const {
632   VLOG(3) << "Accepting visitor with order.";
633   for (HloInstruction* root : CollectUnreachableRoots()) {
634     TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
635   }
636   TF_RET_CHECK(order.size() == instruction_count());
637   absl::flat_hash_set<const HloInstruction*> visited;
638   for (const HloInstruction* instruction : order) {
639     VLOG(3) << "Visiting ordered: " << instruction->ToString();
640     TF_RET_CHECK(instruction_iterators_.contains(instruction))
641         << "Instruction " << instruction->name() << " is not in computation "
642         << name();
643     TF_RET_CHECK(!visited.contains(instruction))
644         << "Instruction " << instruction->name()
645         << " appears more than once in order";
646     HloInstruction* mutable_instruction =
647         const_cast<HloInstruction*>(instruction);
648     TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
649     TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
650     visitor->SetVisited(*mutable_instruction);
651     TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
652     visited.insert(instruction);
653   }
654   TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
655   return Status::OK();
656 }
657 
658 // Explicit instantiations.
659 template Status HloComputation::AcceptOrdered(
660     DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
661 template Status HloComputation::AcceptOrdered(
662     ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
663 
664 }  // namespace xla
665 
666 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
667