/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function.  It has some inputs
// (parameters) and returns exactly one value (the value of its root node).  If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total order.
// Instead, they have a partial order determined by their data and control
// dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program.  Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation".  For example, the kMap
// instruction has a nested computation and "applies" it to every element of its
// input, elementwise.  (That is, the input [x, y, z] is transformed to [f(x),
// f(y), f(z)].)
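//
// For illustration, a computation rendered in HLO text looks roughly like the
// following (a hypothetical sketch, not tied to any particular module):
//
//   %add (x: f32[], y: f32[]) -> f32[] {
//     %x = f32[] parameter(0)
//     %y = f32[] parameter(1)
//     ROOT %add.0 = f32[] add(f32[] %x, f32[] %y)
//   }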
class HloComputation {
 public:
  // Used by instructions_.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;

  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const std::string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}
    Builder(Builder&& b) = default;
    virtual ~Builder() = default;

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    // Add the instruction to be part of this computation.
    // If the new instruction is derived from another one,
    // you probably want to do
    // `original_inst->AddInstruction(new_inst)` instead.
    virtual HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return OkStatus();
    }

    HloInstruction* last_added_instruction() const {
      return last_added_instruction_;
    }

   private:
    const std::string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;

    Builder(const Builder&) = delete;
    Builder& operator=(const Builder&) = delete;
  };
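
  // Typical Builder usage (a hypothetical sketch; `shape` is assumed to be a
  // scalar f32 Shape defined by the caller):
  //
  //   HloComputation::Builder builder("add");
  //   HloInstruction* x = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, shape, "x"));
  //   HloInstruction* y = builder.AddInstruction(
  //       HloInstruction::CreateParameter(1, shape, "y"));
  //   builder.AddInstruction(
  //       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
  //   std::unique_ptr<HloComputation> computation = builder.Build();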

  // Helper class to automatically set the OpMetadata for every instruction
  // added to a computation.
  class MetadataBuilder {
   public:
    MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
        : computation_(computation), metadata_(metadata) {}

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instruction->set_metadata(metadata_);
      return computation_->AddInstruction(std::move(instruction));
    }

   private:
    HloComputation* computation_;
    OpMetadata metadata_;
  };
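
  // For example (a hypothetical sketch; `computation` and `operand` are
  // assumed to exist):
  //
  //   OpMetadata metadata;
  //   metadata.set_op_name("my_op");
  //   HloComputation::MetadataBuilder builder(computation, metadata);
  //   builder.AddInstruction(HloInstruction::CreateUnary(
  //       operand->shape(), HloOpcode::kNegate, operand));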

  ~HloComputation();

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
                                 const std::string& new_name = "");

  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
                                 const OpMetadata* metadata);

  // Replace the old parameter at index param_no with
  // `instruction`. Updates uses and root instruction. Removes old
  // instruction from computation. No check is done on the shape.
  HloInstruction* ReplaceParameter(int64_t param_no,
                                   std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64_t param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveUnusedParametersFromFusedComputation();

  // Remove unused parameters from the computation. Unlike
  // RemoveUnusedParametersFromFusedComputation, this function can be used
  // to remove parameters from non-fusion computations.
  Status RemoveUnusedParametersFromAnyComputation();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and updates
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Replaces an old parameter with a new parameter. Adds the new parameter
  // instruction to the entry computation. Updates the users of the old
  // instruction.
  Status ReplaceEntryComputationParameter(
      int64_t param_no, HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Removes an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call. The instruction will be
  // removed even if it is marked as not removable.
  Status ForceRemoveInstruction(HloInstruction* instruction);

  // Removes an instruction (including side-effecting ones) from the
  // computation, and also transitively removes any of its operands that have
  // no side effects and no remaining users after the removal. The instruction
  // must have no users. Instruction is deallocated with this call. If given,
  // the cleanup routine is executed on each removed instruction before its
  // deallocation.
  Status RemoveInstructionAndUnusedOperands(
      HloInstruction* instruction,
      std::function<void(HloInstruction*)> cleanup = nullptr);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation. In addition it must have
  // the same shape as the result of the computation for non-fusion
  // computations, except if accept_different_shape is set to true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64_t num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64_t param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64_t>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const std::string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  std::string ToString() const { return ToString(HloPrintOptions()); }
  std::string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  std::string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a Cord representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  absl::Cord ToCord() const { return ToCord(HloPrintOptions()); }
  absl::Cord ToCord(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  absl::Cord ToCord(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  // Generates a hash value of an HLO computation. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO computations,
  // with respect to HloComputation::Equal() method.
  template <typename H>
  friend H AbslHashValue(H h, const HloComputation& computation) {
    auto instructions = computation.MakeInstructionPostOrder();
    for (auto* instruction : instructions) {
      h = H::combine(std::move(h), *instruction);
    }
    return H::combine(std::move(h), instructions.size());
  }

  using InstructionSequence = tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

  using ConstInstructionSequence =
      tensorflow::gtl::iterator_range<UnwrappingIterator<
          std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  ConstInstructionSequence instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  InstructionSequence instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;

  int64_t instruction_count() const { return instruction_iterators_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if computation
  // A calls computation B (eg, via a map instruction) then A will appear after
  // B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
  // into a library call. Instructions must be in reverse topological order
  // (root of the fused expression first). Replaces all uses of the original
  // root instruction with the fusion instruction. The original instructions are
  // removed if they have no uses after fusion (this is necessarily true for at
  // least the root).
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);
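
  // For example (a hypothetical sketch; `add` is the root of the expression to
  // fuse and `mul` is its producer):
  //
  //   HloInstruction* fusion = computation->CreateFusionInstruction(
  //       /*instructions_to_fuse=*/{add, mul},
  //       HloInstruction::FusionKind::kLoop);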

  // Creates a call instruction containing the given instructions.  Instructions
  // must be in reverse topological order (root of the called computation
  // first). Replaces all uses of the original root instruction with the call
  // instruction. The original instructions are removed if they have no uses
  // after creating the call (this is necessarily true for at least the root).
  HloInstruction* CreateCallInstruction(
      absl::Span<HloInstruction* const> instructions_to_call);

  // Creates an async start/done instruction pair where `instruction` is
  // wrapped inside an asynchronous computation. The context shapes are
  // appended to the output tuple of the async start, which is backend
  // specific. Returns the async done instruction. The new async start
  // instruction is the operand of the returned async done instruction, so it
  // can be accessed through that operand. If present,
  // `async_execution_thread` will be attached to the async-start/update/done
  // instructions as well as wrapped computations.
  StatusOr<HloInstruction*> CreateAsyncInstructions(
      HloInstruction* instruction, absl::Span<const Shape> context_shapes,
      absl::string_view async_execution_thread =
          HloInstruction::kMainExecutionThread);
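
  // For example (a hypothetical sketch; `collective` is an instruction already
  // in this computation):
  //
  //   TF_ASSIGN_OR_RETURN(HloInstruction * done,
  //                       computation->CreateAsyncInstructions(
  //                           collective, /*context_shapes=*/{}));
  //   HloInstruction* start = done->mutable_operand(0);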

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are added
  // to the computation. For array-shaped values, this method trivially returns
  // a kCopy instruction. For tuple-shaped instructions, the copy is performed
  // with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);
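
  // For example (a hypothetical sketch): to copy only the first leaf of a
  // tuple-shaped `inst`, mark just that index in the ShapeTree:
  //
  //   ShapeTree<bool> indices_to_copy(inst->shape(), /*init_value=*/false);
  //   *indices_to_copy.mutable_element({0}) = true;
  //   TF_ASSIGN_OR_RETURN(HloInstruction * copy,
  //                       computation->DeepCopyInstruction(inst,
  //                                                        &indices_to_copy));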

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape(bool include_ids = true) const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool Equal(const HloComputation& other, bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/false,
                         /*ignore_execution_thread=*/false);
  }

  // Same as Equal() but ignores channel ID value mismatches on instructions, as
  // long as the two instructions both have channel IDs or neither has a channel
  // ID.
  bool EqualIgnoringChannelIdValues(const HloComputation& other,
                                    bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/true,
                         /*ignore_execution_thread=*/false);
  }

  bool EqualIgnoringExecutionThread(const HloComputation& other,
                                    bool is_layout_sensitive,
                                    bool ignore_channel_id_values) const {
    return EqualInternal(other, is_layout_sensitive, ignore_channel_id_values,
                         /*ignore_execution_thread=*/true);
  }

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const {
    return Equal(other, true);
  }
  bool operator!=(const HloComputation& other) const {
    return !(*this == other);
  }

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);
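
  // For example (a hypothetical sketch; `old_inst` is an instruction in this
  // computation whose operand 0 can be bitcast to its shape):
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       old_inst, HloInstruction::CreateBitcast(
  //                     old_inst->shape(), old_inst->mutable_operand(0))));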

  // Replaces an old instruction with a newly created instruction, and adds the
  // new instruction as an entry computation's parameter. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewEntryComputationParameter(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction.  Updates uses and root
  // instruction. Removes old instruction from computation. Transitively removes
  // non-side effecting operands of old instruction that no longer have users,
  // similar to RemoveInstructionAndUnusedOperands(). Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  // If preserve_sharding is true, the replacement will fail if the new and old
  // instructions have incompatible shardings, and the function will return
  // false. Otherwise, when the replacement happens, if |new_instruction|
  // doesn't have any sharding information it will receive the sharding
  // information of |old_instruction|, and the function will return true.
  StatusOr<bool> ReplaceInstruction(HloInstruction* old_instruction,
                                    HloInstruction* new_instruction,
                                    bool preserve_sharding);

  // Same as above, with preserve_sharding=false. Since this replacement always
  // happens, it returns just a Status as opposed to StatusOr<bool>.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Same as ReplaceInstruction, but the new instruction can have a different
  // shape.
  StatusOr<bool> ReplaceInstructionWithDifferentShape(
      HloInstruction* old_instruction, HloInstruction* new_instruction,
      bool preserve_sharding);
  Status ReplaceInstructionWithDifferentShape(HloInstruction* old_instruction,
                                              HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;
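
  // For example (a hypothetical sketch; `visitor` is an instance of some
  // DfsHloVisitor subclass):
  //
  //   std::vector<HloInstruction*> order =
  //       computation->MakeInstructionPostOrder();
  //   TF_RETURN_IF_ERROR(computation->AcceptOrdered(&visitor, order));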

  // Returns a deep copy of this computation including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and its module() will be used to add new computations
  // into.
  std::unique_ptr<HloComputation> Clone(const std::string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in replacement_map, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements is nullptr, don't perform replacement.
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation.  If an element of `replacements` references an instruction
  // that's not already in the computation, it's cloned and added to the new
  // computation.
  //
  // 'extra_parameters' allows specifying additional parameters that should be
  // added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptr in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      const absl::flat_hash_map<const HloInstruction*,
                                std::unique_ptr<HloInstruction>>* replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const std::string& suffix = "clone",
      const HloInstruction* new_root = nullptr);

  // Convenience overloads for CloneWithReplacements.  You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable.  These
  // overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});   // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const std::string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const std::string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const std::string& suffix = "clone");

  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation, with the exception of fusion computations: a
  // parameter instruction is removable for a fusion computation.
  //
  // Note that IsSafelyRemovable() is a necessary condition to remove an
  // instruction rather than a sufficient condition. For example, instructions
  // with side effects (e.g., Send, Infeed) may be removed from a computation,
  // but the transformation must guarantee that the invariants relevant to the
  // instructions still hold (e.g., Send and Recv must be removed together to
  // make each channel complete).
  bool IsSafelyRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated with
  // the channel. These instructions will be considered as a single node for
  // dependency purposes. Send and RecvDone are in the group, and AllReduces
  // with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns if this computation is a fusion computation.
  // Do not use this method to determine if fusion_instruction_ != nullptr.
  // Instead, directly do: FusionInstruction() != nullptr
  bool IsFusionComputation() const { return is_fusion_computation_; }

  // Returns if this computation is the entry computation of the module.
  bool IsEntryComputation() const;

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    CHECK(!IsCustomCallComputation() && !IsAsyncComputation());
    fusion_instruction_ = fusion_instruction;
    is_fusion_computation_ |= (fusion_instruction != nullptr);
  }

  // Returns if this computation is a custom-call computation.
  bool IsCustomCallComputation() const { return is_custom_call_computation_; }

  // Returns the owning custom call instruction, or nullptr if this is not a
  // custom call computation.
  HloInstruction* CustomCallInstruction() const {
    return custom_call_instruction_;
  }
  void SetCustomCallInstruction(HloInstruction* custom_call_instruction) {
    CHECK(!IsFusionComputation() && !IsAsyncComputation());
    custom_call_instruction_ = custom_call_instruction;
    is_custom_call_computation_ |= (custom_call_instruction != nullptr);
  }

  // Returns if this computation is an async computation.
  bool IsAsyncComputation() const { return !async_instructions_.empty(); }

  // Returns the owning async instructions. The result is empty if this is not
  // an async computation.
  const std::vector<HloInstruction*>& AsyncInstructions() const {
    return async_instructions_;
  }

  std::vector<HloInstruction*>& AsyncInstructions() {
    return async_instructions_;
  }

  void AddAsyncInstruction(HloInstruction* async_instruction) {
    CHECK(async_instruction != nullptr)
        << "Nullptr shouldn't be added as computation's async instruction.";
    CHECK(!IsFusionComputation() && !IsCustomCallComputation());
    CHECK(async_instruction->opcode() == HloOpcode::kAsyncStart ||
          async_instruction->opcode() == HloOpcode::kAsyncUpdate ||
          async_instruction->opcode() == HloOpcode::kAsyncDone);
    async_instructions_.push_back(async_instruction);
  }

  void RemoveAsyncInstruction(HloInstruction* async_instruction) {
    if (async_instruction == nullptr) {
      return;
    }
    async_instructions_.erase(
        std::remove(async_instructions_.begin(), async_instructions_.end(),
                    async_instruction),
        async_instructions_.end());
  }

  // Returns if this computation is invoked by an HLO instruction.
  bool IsCalledComputation() const {
    return IsFusionComputation() || IsCustomCallComputation();
  }

  // Clear the unique ID of the computation so that it can be re-assigned, such
  // as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64_t id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`.  Returns
  // null if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64_t unique_id() const { return unique_id_; }

  void SetExecutionThread(absl::string_view execution_thread) {
    execution_thread_ = std::string(execution_thread);
  }

  absl::string_view execution_thread() const { return execution_thread_; }
  // Returns true if this computation is annotated on the "main" execution
  // thread.
  bool IsMainThread() const {
    return execution_thread_ == HloInstruction::kMainExecutionThread;
  }

  // Deallocate instructions that are marked by "RemoveInstruction". The
  // two-stage cleanup process is designed such that HloPass can have stable
  // internal pointers to HloInstructions while we create and remove
  // HloInstructions in a pass.
  void Cleanup() { to_be_deleted_.clear(); }
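
  // For example (a hypothetical sketch of a pass's removal flow):
  //
  //   TF_RETURN_IF_ERROR(computation->RemoveInstruction(dead));
  //   // `dead` is only marked for deletion; pointers to it remain stable
  //   // until the pass calls:
  //   computation->Cleanup();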

  // Returns true if a given instruction is marked dead in this computation.
  bool IsMarkedAsDead(const HloInstruction* inst);

 private:
  explicit HloComputation(
      const std::string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Internal helper for comparison with different options.
  bool EqualInternal(const HloComputation& other, bool is_layout_sensitive,
                     bool ignore_channel_id_values,
                     bool ignore_execution_thread) const;

  // Appends (fuses) HLOs in instructions_to_append into the called computation
  // of the caller.
  void AppendInstructionsIntoCalledComputation(
      absl::Span<HloInstruction* const> instructions_to_append,
      HloInstruction* caller);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      HloInstruction* root,
      HloComputation::ChannelDependencyGroup& channel_dependencies,
      absl::flat_hash_map<HloInstruction*, VisitState>& visited,
      std::vector<HloInstruction*>& post_order) const;

  Status RemoveUnusedParametersImpl(bool allow_non_fusion);

  Status RemoveInstructionImpl(HloInstruction* instruction,
                               bool ignore_safety_check);

  std::string name_;
  int64_t unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction (if it is live). Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Determines whether this computation is a fusion computation. A fusion
  // computation ordinarily also has a non-null fusion_instruction_. However, if
  // a fusion instruction is removed during compilation, the fusion computation
  // becomes unreachable, and its fusion_instruction_ is set to null. We still
  // need to regard such computations as fusion computations for HLO scheduling
  // purposes.
  bool is_fusion_computation_;

  // If this computation is a custom-call computation, this field points to the
  // corresponding custom-call instruction (if it is live). Otherwise, this is
  // null.
  HloInstruction* custom_call_instruction_;

  // Determines whether this computation is a custom-call computation.
  bool is_custom_call_computation_;

  // If this computation is an async computation, this field points to the
  // corresponding async instructions (if live) that call this computation.
  // Otherwise, this is empty.
  std::vector<HloInstruction*> async_instructions_;

  // Execution thread of this computation. By default, it is the main thread.
  std::string execution_thread_ = HloInstruction::kMainExecutionThread;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  // Removed instructions are moved into to_be_deleted_ first and then
  // deallocated when Cleanup is called.
  std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;

  std::vector<HloInstruction*> param_instructions_;

  HloComputation(const HloComputation&) = delete;
  HloComputation& operator=(const HloComputation&) = delete;
};

template <typename HloInstructionPtr>
Status HloComputation::Accept(
    DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
  // Visit unreachable roots. Beware that the visitor might delete the currently
  // visited root, which would invalidate iterators if the unreachable roots
  // weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    VLOG(3) << "Traversing unreachable root: " << root->ToString();
    // Call FinishVisit only at the end.
    TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
}

// Explicit instantiations.
template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;

template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered(
    DfsHloVisitorBase<HloInstructionPtr>* visitor,
    absl::Span<HloInstruction* const> order) const {
  VLOG(3) << "Accepting visitor with order.";
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
  }
  TF_RET_CHECK(order.size() == instruction_count());
  absl::flat_hash_set<const HloInstruction*> visited;
  for (const HloInstruction* instruction : order) {
    VLOG(3) << "Visiting ordered: " << instruction->ToString();
    TF_RET_CHECK(instruction_iterators_.contains(instruction))
        << "Instruction " << instruction->name() << " is not in computation "
        << name();
    TF_RET_CHECK(!visited.contains(instruction))
        << "Instruction " << instruction->name()
        << " appears more than once in order";
    HloInstruction* mutable_instruction =
        const_cast<HloInstruction*>(instruction);
    TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
    TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
    visitor->SetVisited(*mutable_instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
    visited.insert(instruction);
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
  return OkStatus();
}

// Explicit instantiations.
template Status HloComputation::AcceptOrdered(
    DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
template Status HloComputation::AcceptOrdered(
    ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_