/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function. It has some inputs
// (parameters) and returns exactly one value (the value of its root node). If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total
// order. Instead, they have a partial order determined by their data and
// control dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program. Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation". For example, the kMap
// instruction has a nested computation and "applies" it to every element of
// its input, elementwise. (That is, the input [x, y, z] is transformed to
// [f(x), f(y), f(z)].)
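//
// A rough usage sketch (assuming `shape` is an array Shape defined
// elsewhere): a computation that adds its two parameters can be built with
// the Builder class declared below.
//
//   HloComputation::Builder builder("add");
//   HloInstruction* x = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, shape, "x"));
//   HloInstruction* y = builder.AddInstruction(
//       HloInstruction::CreateParameter(1, shape, "y"));
//   builder.AddInstruction(
//       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
//   std::unique_ptr<HloComputation> computation = builder.Build();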
class HloComputation {
 public:
  // Used by instructions_.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;

  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const std::string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}
    Builder(Builder&& b) = default;
    virtual ~Builder() = default;

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    // Add the instruction to be part of this computation.
    // If the new instruction is derived from another one,
    // you probably want to do
    // `original_inst->AddInstruction(new_inst)` instead.
    virtual HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return OkStatus();
    }

    HloInstruction* last_added_instruction() const {
      return last_added_instruction_;
    }

   private:
    const std::string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;

    Builder(const Builder&) = delete;
    Builder& operator=(const Builder&) = delete;
  };

  // Helper class to automatically set the OpMetadata for every instruction
  // added to a computation.
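  //
  // A minimal sketch of intended usage (assuming `computation`, `metadata`,
  // `shape`, and `operand` are defined elsewhere):
  //
  //   HloComputation::MetadataBuilder builder(computation, metadata);
  //   builder.AddInstruction(
  //       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, operand));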
  class MetadataBuilder {
   public:
    MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
        : computation_(computation), metadata_(metadata) {}

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instruction->set_metadata(metadata_);
      return computation_->AddInstruction(std::move(instruction));
    }

   private:
    HloComputation* computation_;
    OpMetadata metadata_;
  };

  ~HloComputation();

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
                                 const std::string& new_name = "");

  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
                                 const OpMetadata* metadata);

  // Replace the old parameter at index param_no with
  // `instruction`. Updates uses and root instruction. Removes old
  // instruction from computation. No check is done on the shape.
  HloInstruction* ReplaceParameter(int64_t param_no,
                                   std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64_t param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveUnusedParametersFromFusedComputation();

  // Remove unused parameters from the computation. Unlike
  // RemoveUnusedParametersFromFusedComputation, this function can be used
  // to remove parameters from non-fusion computations.
  Status RemoveUnusedParametersFromAnyComputation();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameters and inserted into the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and updates
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. The instruction will be appended to the
  // parameters and inserted into the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Replaces an old parameter with a new parameter. Adds the new parameter
  // instruction to the entry computation. Updates the users of the old
  // instruction.
  Status ReplaceEntryComputationParameter(
      int64_t param_no, HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Removes an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call. The instruction will be
  // removed even if it is marked as not removable.
  Status ForceRemoveInstruction(HloInstruction* instruction);

  // Remove an instruction (including side-effecting ones) from the computation
  // and also transitively any operand that has no side effect and no users
  // after removing the instruction. The instruction must have no users.
  // Instruction is deallocated with this call. If given, the cleanup routine
  // is executed on a removed instruction before its deallocation.
  Status RemoveInstructionAndUnusedOperands(
      HloInstruction* instruction,
      std::function<void(HloInstruction*)> cleanup = nullptr);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation. In addition it must have
  // the same shape as the result of the computation for non-fusion
  // computations, except if accept_different_shape is set to true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64_t num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64_t param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64_t>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const std::string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  std::string ToString() const { return ToString(HloPrintOptions()); }
  std::string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  std::string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a Cord representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  absl::Cord ToCord() const { return ToCord(HloPrintOptions()); }
  absl::Cord ToCord(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  absl::Cord ToCord(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  // Generates a hash value of an HLO computation. The hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO computations,
  // with respect to the HloComputation::Equal() method.
  template <typename H>
  friend H AbslHashValue(H h, const HloComputation& computation) {
    auto instructions = computation.MakeInstructionPostOrder();
    for (auto* instruction : instructions) {
      h = H::combine(std::move(h), *instruction);
    }
    return H::combine(std::move(h), instructions.size());
  }

  using InstructionSequence = tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

  using ConstInstructionSequence =
      tensorflow::gtl::iterator_range<UnwrappingIterator<
          std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  ConstInstructionSequence instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  InstructionSequence instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;

  int64_t instruction_count() const { return instruction_iterators_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if
  // computation A calls computation B (e.g., via a map instruction) then A
  // will appear after B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or
  // fusion into a library call. Instructions must be in reverse topological
  // order (root of the fused expression first). Replaces all uses of the
  // original root instruction with the fusion instruction. The original
  // instructions are removed if they have no uses after fusion (this is
  // necessarily true for at least the root).
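  //
  // A hedged sketch (assuming `root` and its operand `producer` are already
  // in this computation, listed root-first):
  //
  //   HloInstruction* fusion = computation->CreateFusionInstruction(
  //       {root, producer}, HloInstruction::FusionKind::kLoop);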
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

  // Creates a call instruction containing the given instructions. Instructions
  // must be in reverse topological order (root of the called computation
  // first). Replaces all uses of the original root instruction with the call
  // instruction. The original instructions are removed if they have no uses
  // after creating the call (this is necessarily true for at least the root).
  HloInstruction* CreateCallInstruction(
      absl::Span<HloInstruction* const> instructions_to_call);

  // Creates an async start/done instruction pair where `instruction` is
  // wrapped inside an asynchronous computation. The context shapes are
  // appended to the output tuple of the async-start, which is backend
  // specific. Returns the async-done instruction. The new async-start
  // instruction is the operand of the async-done instruction, so it can be
  // accessed through the done instruction. If present,
  // `async_execution_thread` will be attached to the async-start/update/done
  // instructions as well as the wrapped computations.
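  //
  // A minimal sketch (assuming `instr` is in this computation); the done
  // instruction's operand 0 is the corresponding async-start:
  //
  //   TF_ASSIGN_OR_RETURN(
  //       HloInstruction * done,
  //       computation->CreateAsyncInstructions(instr, /*context_shapes=*/{}));
  //   HloInstruction* start = done->mutable_operand(0);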
  StatusOr<HloInstruction*> CreateAsyncInstructions(
      HloInstruction* instruction, absl::Span<const Shape> context_shapes,
      absl::string_view async_execution_thread =
          HloInstruction::kMainExecutionThread);

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are
  // added to the computation. For array-shaped values, this method trivially
  // returns a kCopy instruction. For tuple-shaped instructions, the copy is
  // performed with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
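  //
  // For example, a sketch that deep-copies only leaf {0} of a tuple-shaped
  // `instr`, passing the other elements through:
  //
  //   ShapeTree<bool> indices_to_copy(instr->shape(), /*init_value=*/false);
  //   *indices_to_copy.mutable_element({0}) = true;
  //   TF_ASSIGN_OR_RETURN(
  //       HloInstruction * copy,
  //       computation->DeepCopyInstruction(instr, &indices_to_copy));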
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape(bool include_ids = true) const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool Equal(const HloComputation& other, bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/false,
                         /*ignore_execution_thread=*/false);
  }

  // Same as Equal() but ignores channel ID value mismatches on instructions,
  // as long as the two instructions both have channel IDs or neither has a
  // channel ID.
  bool EqualIgnoringChannelIdValues(const HloComputation& other,
                                    bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/true,
                         /*ignore_execution_thread=*/false);
  }

  bool EqualIgnoringExecutionThread(const HloComputation& other,
                                    bool is_layout_sensitive,
                                    bool ignore_channel_id_values) const {
    return EqualInternal(other, is_layout_sensitive, ignore_channel_id_values,
                         /*ignore_execution_thread=*/true);
  }

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const {
    return Equal(other, true);
  }
  bool operator!=(const HloComputation& other) const {
    return !(*this == other);
  }

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
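  //
  // A sketch of a common pattern (assuming `old_instr` is in this
  // computation; the bitcast stands in for whatever replacement is built):
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       old_instr, HloInstruction::CreateBitcast(
  //                      old_instr->shape(), old_instr->mutable_operand(0))));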
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replaces an old instruction with a newly created instruction, and adds the
  // new instruction as an entry computation's parameter. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewEntryComputationParameter(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction. Updates uses and root
  // instruction. Removes old instruction from computation. Transitively
  // removes non-side-effecting operands of old instruction that no longer
  // have users, similar to RemoveInstructionAndUnusedOperands(). Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  // If preserve_sharding is true, the replacement will fail and the function
  // will return false if the new and old instructions have incompatible
  // shardings. Otherwise, when the replacement happens, if |new_instruction|
  // doesn't have any sharding information it will receive the sharding
  // information of |old_instruction|, and the function will return true.
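  //
  // A hedged sketch of the sharding-preserving variant (assuming `old_instr`
  // and `new_instr` are in this computation and have compatible shapes):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       bool replaced,
  //       computation->ReplaceInstruction(old_instr, new_instr,
  //                                       /*preserve_sharding=*/true));
  //   // If !replaced, the shardings were incompatible and nothing changed.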
  StatusOr<bool> ReplaceInstruction(HloInstruction* old_instruction,
                                    HloInstruction* new_instruction,
                                    bool preserve_sharding);

  // Same as above, with preserve_sharding=false. Since this replacement always
  // happens, it returns just a Status as opposed to StatusOr<bool>.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Same as ReplaceInstruction, but the new instruction can have a different
  // shape.
  StatusOr<bool> ReplaceInstructionWithDifferentShape(
      HloInstruction* old_instruction, HloInstruction* new_instruction,
      bool preserve_sharding);
  Status ReplaceInstructionWithDifferentShape(HloInstruction* old_instruction,
                                              HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;

  // Returns a deep copy of this computation, including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and its module() will be used to add new computations
  // into.
  std::unique_ptr<HloComputation> Clone(const std::string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in replacement_map, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements is nullptr, don't perform replacement.
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation. If an element of `replacements` references an instruction
  // that's not already in the computation, it's cloned and added to the new
  // computation.
  //
  // 'extra_parameters' allows specifying additional parameters that should be
  // added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptrs in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      const absl::flat_hash_map<const HloInstruction*,
                                std::unique_ptr<HloInstruction>>* replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const std::string& suffix = "clone",
      const HloInstruction* new_root = nullptr);

  // Convenience overloads for CloneWithReplacements. You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable. These
  // overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});  // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const std::string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const std::string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const std::string& suffix = "clone");

  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation, with the exception of fusion computations: a
  // parameter instruction is removable for a fusion computation.
  //
  // Note that IsSafelyRemovable() is a necessary condition to remove an
  // instruction rather than a sufficient condition. For example, instructions
  // with side effects (e.g., Send, Infeed) may be removed from a computation,
  // but the transformation must guarantee the invariants relevant to the
  // instructions still hold (e.g., Send and Recv must be removed together to
  // make each channel complete).
  bool IsSafelyRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated with
  // the channel. These instructions will be considered as a single node for
  // dependency purposes. Send and RecvDone are in the group, and AllReduces
  // with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns whether this computation is a fusion computation.
  // Do not use this method to determine if fusion_instruction_ != nullptr.
  // Instead, directly do: FusionInstruction() != nullptr.
  bool IsFusionComputation() const { return is_fusion_computation_; }

  // Returns whether this computation is the entry computation of the module.
  bool IsEntryComputation() const;

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    CHECK(!IsCustomCallComputation() && !IsAsyncComputation());
    fusion_instruction_ = fusion_instruction;
    is_fusion_computation_ |= (fusion_instruction != nullptr);
  }

  // Returns whether this computation is a custom-call computation.
  bool IsCustomCallComputation() const { return is_custom_call_computation_; }

  // Returns the owning custom call instruction, or nullptr if this is not a
  // custom call computation.
  HloInstruction* CustomCallInstruction() const {
    return custom_call_instruction_;
  }
  void SetCustomCallInstruction(HloInstruction* custom_call_instruction) {
    CHECK(!IsFusionComputation() && !IsAsyncComputation());
    custom_call_instruction_ = custom_call_instruction;
    is_custom_call_computation_ |= (custom_call_instruction != nullptr);
  }

  // Returns whether this computation is an async computation.
  bool IsAsyncComputation() const { return !async_instructions_.empty(); }

  // Returns the owning async instructions. The vector is empty if this is not
  // an async computation.
  const std::vector<HloInstruction*>& AsyncInstructions() const {
    return async_instructions_;
  }

  std::vector<HloInstruction*>& AsyncInstructions() {
    return async_instructions_;
  }

  void AddAsyncInstruction(HloInstruction* async_instruction) {
    CHECK(async_instruction != nullptr)
        << "Nullptr shouldn't be added as a computation's async instruction.";
    CHECK(!IsFusionComputation() && !IsCustomCallComputation());
    CHECK(async_instruction->opcode() == HloOpcode::kAsyncStart ||
          async_instruction->opcode() == HloOpcode::kAsyncUpdate ||
          async_instruction->opcode() == HloOpcode::kAsyncDone);
    async_instructions_.push_back(async_instruction);
  }

  void RemoveAsyncInstruction(HloInstruction* async_instruction) {
    if (async_instruction == nullptr) {
      return;
    }
    async_instructions_.erase(
        std::remove(async_instructions_.begin(), async_instructions_.end(),
                    async_instruction),
        async_instructions_.end());
  }

  // Returns whether this computation is invoked by an HLO instruction.
  bool IsCalledComputation() const {
    return IsFusionComputation() || IsCustomCallComputation();
  }

  // Clear the unique ID of the computation so that it can be re-assigned, such
  // as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64_t id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`. Returns
  // nullptr if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64_t unique_id() const { return unique_id_; }

  void SetExecutionThread(absl::string_view execution_thread) {
    execution_thread_ = std::string(execution_thread);
  }

  absl::string_view execution_thread() const { return execution_thread_; }

  // Returns true if this computation is annotated with the "main" execution
  // thread.
  bool IsMainThread() const {
    return execution_thread_ == HloInstruction::kMainExecutionThread;
  }

  // Deallocate instructions that are marked by "RemoveInstruction". The
  // two-stage cleanup process is designed such that HloPass can have stable
  // internal pointers to HloInstructions while we create and remove
  // HloInstructions in a pass.
  void Cleanup() { to_be_deleted_.clear(); }

  // Returns true if a given instruction is marked dead in this computation.
  bool IsMarkedAsDead(const HloInstruction* inst);

 private:
  explicit HloComputation(
      const std::string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Internal helper for comparison with different options.
  bool EqualInternal(const HloComputation& other, bool is_layout_sensitive,
                     bool ignore_channel_id_values,
                     bool ignore_execution_thread) const;

  // Appends (fuses) HLOs in instructions_to_append into the called computation
  // of the caller.
  void AppendInstructionsIntoCalledComputation(
      absl::Span<HloInstruction* const> instructions_to_append,
      HloInstruction* caller);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      HloInstruction* root,
      HloComputation::ChannelDependencyGroup& channel_dependencies,
      absl::flat_hash_map<HloInstruction*, VisitState>& visited,
      std::vector<HloInstruction*>& post_order) const;

  Status RemoveUnusedParametersImpl(bool allow_non_fusion);

  Status RemoveInstructionImpl(HloInstruction* instruction,
                               bool ignore_safety_check);

  std::string name_;
  int64_t unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction (if it is live). Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Determines whether this computation is a fusion computation. A fusion
  // computation ordinarily also has a non-null fusion_instruction_. However,
  // if a fusion instruction is removed during compilation, the fusion
  // computation becomes unreachable, and its fusion_instruction_ is set to
  // null. We still need to regard such computations as fusion computations
  // for HLO scheduling purposes.
  bool is_fusion_computation_;

  // If this computation is a custom-call computation, this field points to the
  // corresponding custom-call instruction (if it is live). Otherwise, this is
  // null.
  HloInstruction* custom_call_instruction_;

  // Determines whether this computation is a custom-call computation.
  bool is_custom_call_computation_;

  // If this computation is an async computation, this field points to the
  // corresponding async instructions (if live) that call this computation.
  // Otherwise, this is empty.
  std::vector<HloInstruction*> async_instructions_;

  // Execution thread of this computation. By default, it's the main thread.
  std::string execution_thread_ = HloInstruction::kMainExecutionThread;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  // Removed instructions are moved into to_be_deleted_ first and then
  // deallocated when Cleanup is called.
  std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;

  std::vector<HloInstruction*> param_instructions_;

  HloComputation(const HloComputation&) = delete;
  HloComputation& operator=(const HloComputation&) = delete;
};

template <typename HloInstructionPtr>
Status HloComputation::Accept(
    DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
  // Visit unreachable roots. Beware that the visitor might delete the
  // currently visited root, which would invalidate iterators if the
  // unreachable roots weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    VLOG(3) << "Traversing unreachable root: " << root->ToString();
    // Call FinishVisit only at the end.
    TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
}

// Explicit instantiations.
template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;

template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered(
    DfsHloVisitorBase<HloInstructionPtr>* visitor,
    absl::Span<HloInstruction* const> order) const {
  VLOG(3) << "Accepting visitor with order.";
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
  }
  TF_RET_CHECK(order.size() == instruction_count());
  absl::flat_hash_set<const HloInstruction*> visited;
  for (const HloInstruction* instruction : order) {
    VLOG(3) << "Visiting ordered: " << instruction->ToString();
    TF_RET_CHECK(instruction_iterators_.contains(instruction))
        << "Instruction " << instruction->name() << " is not in computation "
        << name();
    TF_RET_CHECK(!visited.contains(instruction))
        << "Instruction " << instruction->name()
        << " appears more than once in order";
    HloInstruction* mutable_instruction =
        const_cast<HloInstruction*>(instruction);
    TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
    TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
    visitor->SetVisited(*mutable_instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
    visited.insert(instruction);
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
  return OkStatus();
}

// Explicit instantiations.
template Status HloComputation::AcceptOrdered(
    DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
template Status HloComputation::AcceptOrdered(
    ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_