/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function.  It has some inputs
// (parameters) and returns exactly one value (the value of its root node).  If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total order.
// Instead, they have a partial order determined by their data and control
// dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program.  Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation".  For example, the kMap
// instruction has a nested computation and "applies" it to every element of its
// input, elementwise.  (That is, the input [x, y, z] is transformed to [f(x),
// f(y), f(z)].)
class HloComputation {
 public:
  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };
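
  // Example (illustrative sketch, not part of the API): building a small
  // computation that adds one to its scalar parameter. Here `shape` is
  // assumed to be a scalar F32 Shape; the factory functions come from
  // hlo_instruction.h and literal_util.h.
  //
  //   HloComputation::Builder builder("add_one");
  //   HloInstruction* x = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, shape, "x"));
  //   HloInstruction* one = builder.AddInstruction(
  //       HloInstruction::CreateConstant(LiteralUtil::One(F32)));
  //   HloInstruction* sum = builder.AddInstruction(
  //       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, one));
  //   std::unique_ptr<HloComputation> computation = builder.Build(sum);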

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to fusion computations.
  Status RemoveParameter(int64 param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to fusion computations.
  Status RemoveUnusedParameters();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameters and inserted into the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and updates
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameters and inserted into the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Remove an instruction (including side-effecting ones) from the
  // computation, and also transitively remove any operands that have no side
  // effects and no users once the instruction is removed. The instruction must
  // have no users. Instruction is deallocated with this call.
  Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation. In addition it must have
  // the same shape as the result of the computation for non-fusion
  // computations, unless accept_different_shape is set to true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is the
  // instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation based
  // on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map);

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;
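
  // Example (illustrative): by the time `instr` is visited below, all of its
  // operands have already appeared on earlier iterations.
  //
  //   for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
  //     ...  // instr's operands were visited before instr itself.
  //   }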

  int64 instruction_count() const { return instruction_iterators_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if computation
  // A calls computation B (e.g., via a map instruction) then A will appear
  // after B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
  // into a library call. Instructions must be in reverse topological order
  // (root of the fused expression first). Replaces all uses of the original
  // root instruction with the fusion instruction. The original instructions are
  // removed if they have no uses after fusion (this is necessarily true for at
  // least the root).
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);
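
  // Example (illustrative sketch): fuse `mul` (the root of the fused
  // expression, listed first) and the `add` that produces one of its operands
  // into a loop fusion. Both instructions are assumed to already belong to
  // this computation.
  //
  //   HloInstruction* fusion = computation->CreateFusionInstruction(
  //       {mul, add}, HloInstruction::FusionKind::kLoop);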

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are added
  // to the computation. For array-shaped values, this method trivially returns
  // a kCopy instruction. For tuple-shaped instructions, the copy is performed
  // with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);
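
  // Example (illustrative sketch): a copier that emits a plain kCopy for
  // every leaf, equivalent in spirit to the default behavior; `instruction`
  // is assumed to be in this computation.
  //
  //   auto copy_leaf = [](HloInstruction* leaf, const ShapeIndex& leaf_index,
  //                       HloComputation* computation) {
  //     return computation->AddInstruction(HloInstruction::CreateUnary(
  //         leaf->shape(), HloOpcode::kCopy, leaf));
  //   };
  //   TF_ASSIGN_OR_RETURN(
  //       HloInstruction * copy,
  //       computation->DeepCopyInstructionWithCustomCopier(instruction,
  //                                                        copy_leaf));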

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape() const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const;

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);
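
  // Example (illustrative sketch): rewrite `mul`, assumed to multiply
  // `operand` by a constant two, into an add of `operand` with itself.
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       mul, HloInstruction::CreateBinary(mul->shape(), HloOpcode::kAdd,
  //                                         operand, operand)));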

  // Replace old instruction with new instruction.  Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;

  // Same as Accept() above, but the visitor is given as a function.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;
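
  // Example (illustrative): count the instructions in the computation with
  // the function-based visitor.
  //
  //   int64 count = 0;
  //   TF_RETURN_IF_ERROR(computation->Accept(
  //       [&count](const HloInstruction*) {
  //         ++count;
  //         return Status::OK();
  //       }));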

  // Returns a deep copy of this computation including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and new computations will be added into its module().
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in `replacements`, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation.  If an element of `replacements` references an instruction
  // that's not already in the computation, it's cloned and added to the new
  // computation.
  //
  // 'extra_parameters' lets you specify additional parameters that should be
  // added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptr in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const string& suffix = "clone");

  // Convenience overloads for CloneWithReplacements.  You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable.  These
  // overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});   // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
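
  // Example (illustrative sketch): clone this computation while swapping the
  // existing instruction `a` (assumed to be a unary op in this computation)
  // for a freshly built negation of its operand.
  //
  //   auto clone = computation->CloneWithReplacementPairs(
  //       {a, HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate,
  //                                       a->mutable_operand(0))});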

  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation with the exception of fusion computation. A parameter
  // instruction is removable for a fusion computation.
  //
  // Note that IsRemovable() is a necessary condition to remove an instruction
  // rather than a sufficient condition. For example, instructions with
  // side-effect (e.g., Send, Infeed) may be removed from a computation, but the
  // transformation must guarantee the invariants relevant to the instructions
  // still hold (e.g., Send and Recv must be removed together to make each
  // channel complete).
  bool IsRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated with
  // the channel. These instructions will be considered as a single node for
  // dependency purposes. Send and RecvDone are in the group, and AllReduces
  // with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns true if this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

  // Clear the unique ID of the computation so that it can be re-assigned, such
  // as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64 id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`.  Returns
  // null if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64 unique_id() const { return unique_id_; }
 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      const HloComputation::ChannelDependencyGroup& channel_dependency_map,
      std::vector<HloInstruction*>* post_order, HloInstruction* root,
      absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;

  string name_;
  int64 unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction.  Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_