1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // HLO instructions are in DAG form and represent the computations that the user
17 // has built up via the XLA service interface. They are ultimately lowered
18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
19 // form; e.g. see DfsHloVisitor.
20 
21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
23 
24 #include <functional>
25 #include <iosfwd>
26 #include <list>
27 #include <memory>
28 #include <set>
29 #include <string>
30 #include <tuple>
31 #include <vector>
32 
33 #include "absl/container/flat_hash_map.h"
34 #include "absl/container/flat_hash_set.h"
35 #include "absl/container/inlined_vector.h"
36 #include "absl/memory/memory.h"
37 #include "absl/strings/str_cat.h"
38 #include "absl/strings/string_view.h"
39 #include "absl/types/span.h"
40 #include "tensorflow/compiler/xla/comparison_util.h"
41 #include "tensorflow/compiler/xla/iterator_util.h"
42 #include "tensorflow/compiler/xla/literal.h"
43 #include "tensorflow/compiler/xla/map_util.h"
44 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
45 #include "tensorflow/compiler/xla/service/hlo.pb.h"
46 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
47 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
50 #include "tensorflow/compiler/xla/service/name_uniquer.h"
51 #include "tensorflow/compiler/xla/shape_tree.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/xla_data.pb.h"
54 #include "tensorflow/core/lib/core/status.h"
55 #include "tensorflow/core/lib/gtl/iterator_range.h"
56 #include "tensorflow/core/platform/logging.h"
57 #include "tensorflow/core/platform/macros.h"
58 #include "tensorflow/core/platform/protobuf.h"
59 #include "tensorflow/core/platform/types.h"
60 
61 namespace xla {
62 
63 class HloComputation;
64 class HloModule;
65 
66 string PrintName(const string& name, bool print_ids);
67 
68 // A bunch of switches that control how the hlo text should be printed.
69 class HloPrintOptions {
70  public:
71   enum class PrintSubcomputationMode {
72     kOff,         // Do not print anything about subcomputations.
73     kNameOnly,    // Only print the name of subcomputations.
74     kFullBodies,  // Print the full bodies of subcomputations.
75   };
76 
77   // Constructs the default print options: don't print large constants, don't
78   // compact operands, no indentation.
79   HloPrintOptions()
80       : print_large_constants_(false),
81         print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
82         print_metadata_(true),
83         print_backend_config_(true),
84         compact_operands_(false),
85         include_layout_in_shapes_(true),
86         print_operand_shape_(true),
87         print_operand_names_(true),
88         print_program_shape_(true),
89         print_percent_(true),
90         print_control_dependencies_(true),
91         canonicalize_instruction_names_(false),
92         indent_amount_(0),
93         is_in_nested_computation_(false),
94         print_ids_(true),
95         canonicalize_computations_(false) {}
96 
97   static HloPrintOptions ShortParsable() {
98     return HloPrintOptions()
99         .set_print_large_constants(true)
100         .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
101         .set_print_metadata(false)
102         .set_print_backend_config(false)
103         .set_print_operand_shape(false)
104         .set_print_program_shape(false)
105         .set_print_percent(false)
106         .set_print_control_dependencies(false);
107   }
108 
109   // Options to produce the canonical string representing an isomorphic
110   // computation graph.
111   static HloPrintOptions Canonical() {
112     return HloPrintOptions()
113         .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
114         .set_print_metadata(false)
115         .set_print_backend_config(false)
116         .set_compact_operands(false)
117         .set_print_operand_names(false)
118         .set_print_operand_shape(true)
119         .set_print_program_shape(false)
120         .set_print_percent(false)
121         .set_print_control_dependencies(false)
122         .set_canonicalize_instruction_names(true);
123   }
124 
125   // Options to produce a fingerprint of an HLO.
126   static HloPrintOptions Fingerprint() {
127     return HloPrintOptions()
128         .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
129         .set_print_metadata(false)
130         .set_print_backend_config(false)
131         .set_compact_operands(true)
132         .set_print_operand_names(false)
133         .set_print_operand_shape(true)
134         .set_print_program_shape(false)
135         .set_print_percent(false)
136         .set_print_control_dependencies(false)
137         .set_canonicalize_instruction_names(true)
138         .set_print_ids(false)
139         .set_canonicalize_computations(true);
140   }
141 
142   // If true, large constants will be printed out.
143   HloPrintOptions& set_print_large_constants(bool value) {
144     print_large_constants_ = value;
145     return *this;
146   }
147 
148   HloPrintOptions& set_print_subcomputation_mode(
149       PrintSubcomputationMode value) {
150     print_subcomputation_mode_ = value;
151     return *this;
152   }
153 
154   // If true, metadata will be printed.
155   HloPrintOptions& set_print_metadata(bool value) {
156     print_metadata_ = value;
157     return *this;
158   }
159 
160   // If true, backend_config will be printed.
161   HloPrintOptions& set_print_backend_config(bool value) {
162     print_backend_config_ = value;
163     return *this;
164   }
165 
166   // If true, operands' shapes will be printed.
167   HloPrintOptions& set_print_operand_shape(bool value) {
168     print_operand_shape_ = value;
169     return *this;
170   }
171 
172   // If true, the operand names will be printed.
173   HloPrintOptions& set_print_operand_names(bool value) {
174     print_operand_names_ = value;
175     return *this;
176   }
177 
178   // If true, all printed names include unique identifiers.
179   HloPrintOptions& set_print_ids(bool value) {
180     print_ids_ = value;
181     return *this;
182   }
183 
184   // If true, program shape of hlo computations will be printed.
185   HloPrintOptions& set_print_program_shape(bool value) {
186     print_program_shape_ = value;
187     return *this;
188   }
189 
190   // If true, names will be printed with prefix '%'.
191   HloPrintOptions& set_print_percent(bool value) {
192     print_percent_ = value;
193     return *this;
194   }
195 
196   // If true, control dependencies will be printed.
197   HloPrintOptions& set_print_control_dependencies(bool value) {
198     print_control_dependencies_ = value;
199     return *this;
200   }
201 
202   // If true, only a part of the operands will be printed out (note that in this
203   // case the text will not be parsable).
204   HloPrintOptions& set_compact_operands(bool value) {
205     compact_operands_ = value;
206     return *this;
207   }
208 
209   // If true, include the layout in any shapes that are printed (instruction
210   // and operands).
211   HloPrintOptions& set_include_layout_in_shapes(bool value) {
212     include_layout_in_shapes_ = value;
213     return *this;
214   }
215 
216   // If true, canonicalizes instructions' name. Instead of using "%foo.1" as
217   // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
218   HloPrintOptions& set_canonicalize_instruction_names(bool value) {
219     canonicalize_instruction_names_ = value;
220     return *this;
221   }
222 
223   // If true, canonicalizes computations, sorting by computations' names.
224   HloPrintOptions& set_canonicalize_computations(bool value) {
225     canonicalize_computations_ = value;
226     return *this;
227   }
228 
229   // The indent of the hlo text block.
230   HloPrintOptions& set_indent_amount(int value) {
231     indent_amount_ = value;
232     return *this;
233   }
234 
235   // If true, indicates the instruction being printed is inside a nested
236   // computation.
237   HloPrintOptions& set_is_in_nested_computation(bool value) {
238     is_in_nested_computation_ = value;
239     return *this;
240   }
241 
242   // Instructions are selected for printing by a predicate function
243   // (`set_print_instruction`). We also print their surrounding instructions
244   // for ease of reading. We print `leading_and_trailing_instructions_number`
245   // instructions before and after the qualified ones inside a computation.
246   HloPrintOptions& set_leading_and_trailing_instructions_number(int value) {
247     leading_and_trailing_instructions_number_ = value;
248     return *this;
249   }
250 
251   // A callback which takes an HloInstruction*, its string representation,
252   // the indentation level of the resulting block, and a
253   // bool variable indicating whether the instruction is root or not. The return
254   // value is a string which is used for this instruction during printing.
255   using FormatInstructionFunc =
256       std::function<string(const HloInstruction*, const string&, int, bool)>;
257 
258   HloPrintOptions& set_format_instruction(FormatInstructionFunc callback) {
259     format_instruction_ = callback;
260     return *this;
261   }
262 
263   using HloInstructionPredicate = std::function<bool(const HloInstruction*)>;
264 
265   // A callback which takes an HloInstruction* and returns whether it should be
266   // printed or not.
267   HloPrintOptions& set_print_instruction(HloInstructionPredicate callback) {
268     print_instruction_ = callback;
269     return *this;
270   }
271 
272   using HloComputationPredicate = std::function<bool(const HloComputation*)>;
273 
274   // A callback which takes an HloComputation* and returns whether it should be
275   // printed or not.
276   HloPrintOptions& set_print_computation(HloComputationPredicate callback) {
277     print_computation_ = callback;
278     return *this;
279   }
280 
281   bool print_large_constants() const { return print_large_constants_; }
282   PrintSubcomputationMode print_subcomputation_mode() const {
283     return print_subcomputation_mode_;
284   }
285   bool print_metadata() const { return print_metadata_; }
286   bool print_backend_config() const { return print_backend_config_; }
287   bool compact_operands() const { return compact_operands_; }
288   bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
289   bool print_operand_shape() const { return print_operand_shape_; }
290   bool print_operand_names() const { return print_operand_names_; }
291   bool print_ids() const { return print_ids_; }
292   bool print_program_shape() const { return print_program_shape_; }
293   bool print_percent() const { return print_percent_; }
294   bool print_control_dependencies() const {
295     return print_control_dependencies_;
296   }
297   bool canonicalize_instruction_names() const {
298     return canonicalize_instruction_names_;
299   }
300   bool canonicalize_computations() const { return canonicalize_computations_; }
301   int indent_amount() const { return indent_amount_; }
302   int is_in_nested_computation() const { return is_in_nested_computation_; }
303   int leading_and_trailing_instructions_number() const {
304     return leading_and_trailing_instructions_number_;
305   }
306   string format_instruction(const HloInstruction* instr,
307                             const string& instr_name, int indent,
308                             bool is_root) const {
309     return format_instruction_(instr, instr_name, indent, is_root);
310   }
311   bool print_instruction(const HloInstruction* instr) const {
312     return print_instruction_(instr);
313   }
314   bool print_computation(const HloComputation* comp) const {
315     return print_computation_(comp);
316   }
317 
318  private:
319   bool print_large_constants_;
320   PrintSubcomputationMode print_subcomputation_mode_;
321   bool print_metadata_;
322   bool print_backend_config_;
323   bool compact_operands_;
324   bool include_layout_in_shapes_;
325   bool print_operand_shape_;
326   bool print_operand_names_;
327   bool print_program_shape_;
328   bool print_percent_;
329   bool print_control_dependencies_;
330   bool canonicalize_instruction_names_;
331   int indent_amount_;
332   bool is_in_nested_computation_;
333   bool print_ids_;
334   bool canonicalize_computations_;
335   int leading_and_trailing_instructions_number_ = 3;
336   FormatInstructionFunc format_instruction_ = [](const HloInstruction* instr,
337                                                  const string& instr_name,
338                                                  int indent, bool is_root) {
339     return absl::StrCat(string(2 * indent, ' '), is_root ? "ROOT " : "",
340                         instr_name);
341   };
342   HloInstructionPredicate print_instruction_ = [](const HloInstruction* instr) {
343     return true;
344   };
345   HloComputationPredicate print_computation_ = [](const HloComputation* comp) {
346     return true;
347   };
348 };
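
// As a usage sketch (the `options` variable name is illustrative), the setters
// above return *this and are meant to be chained, so a custom option set can
// be derived from one of the presets:
//
//   HloPrintOptions options = HloPrintOptions::Canonical()
//                                 .set_print_metadata(true)
//                                 .set_indent_amount(2);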
349 
350 // For canonical string output, we need to have a canonical way to rename
351 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
352 // where <xxx> is an index starting from 0.
353 class CanonicalNameMap {
354  public:
355   CanonicalNameMap() : index(0) {}
356 
357   string LookupOrInsert(const string& old_name) {
358     auto iter = canonical_name_map.find(old_name);
359     if (iter != canonical_name_map.end()) {
360       return iter->second;
361     }
362 
363     string new_name = absl::StrCat("tmp_", index++);
364     canonical_name_map[old_name] = new_name;
365     return new_name;
366   }
367   void Clear() {
368     canonical_name_map.clear();
369     index = 0;
370   }
371 
372  private:
373   int64 index;
374   absl::flat_hash_map<string, string> canonical_name_map;
375 };
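
// To illustrate the renaming behavior (the instruction names below are
// hypothetical), repeated lookups return a stable canonical name:
//
//   CanonicalNameMap name_map;
//   name_map.LookupOrInsert("add.42");  // returns "tmp_0"
//   name_map.LookupOrInsert("mul.7");   // returns "tmp_1"
//   name_map.LookupOrInsert("add.42");  // returns "tmp_0" again (already mapped)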
376 
377 // HLO instructions are the atomic unit of the high-level compiler's IR.
378 //
379 // HloInstructions live inside of an HloComputation, which is analogous to a
380 // function in other programming languages.  Nodes have no total order within
381 // their computation.  Instead, they have a partial ordering determined by their
382 // data and control dependencies.
383 //
384 // HLO does not have basic blocks or explicit "branch" instructions.  Instead,
385 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
386 // control flow.  For example, the kConditional HLO executes one of two possible
387 // computations, depending on the runtime value of a predicate.
388 //
389 // HLO is pure (mostly).  It has no concept of mutable state.  Instead, data
390 // values are produced by one HLO and flow into consumers across dependency
391 // edges.
392 class HloInstruction {
393  public:
394   // A fusion node computes the same value a call to its fusion computation
395   // would compute.  However, the choice of fusion kind dictates codegen
396   // strategy for the backend.
397   //
398   // To generate code for a kFusion HloInstruction, most backends do something
399   // like the following:
400   //
401   // 1) Identify the "primary" HloInstruction of the fused computation.
402   // 2) Emit code that does the work of the primary node, creating its inputs
403   //    and transforming its outputs as specified by the fused computation.
404   //
405   // In step (2), the code emitted is usually similar to the code that would be
406   // emitted for an *unfused* version of the primary node, except that
407   //
408   //  - when the primary node reads an element of one of its operands, instead
409   //    of loading the value from memory, it *computes* the value based on the
410   //    contents of the fused computation.
411   //  - when the primary node outputs a value, instead of storing it to memory,
412   //    it forwards the value to its users, which then perform additional
413   //    computations before the value is finally stored to memory at the root of
414   //    the fusion node.
415   //
416   // An HloInstruction's FusionKind helps us find the kFusion instruction's
417   // primary node, and can also affect how we generate code in step (2).
418   //
419   //  - kInput: The primary node is the root of the fused instruction.
420   //
421   //  - kOutput: The primary node is not the root of the fused instruction.
422   //    This fusion kind requires that one operand buffer of the fusion
423   //    instruction be able to alias the output buffer.  This constraint is
424   //    usually enough to let backends find the primary node unambiguously.
425   //
426   //  - kLoop: The primary node is the root of the fused computation, but,
427   //    unlike in input fusion, we prescribe a specific implementation for
428   //    codegen.  Rather than generating code that looks like the code we'd emit
429   //    for an unfused version of the primary/root node, we emit code that
430   //    generates one element of the root at a time.
431   //
432   //  - kCustom: Custom category for backend-specific fusions that don't fit
433   //    into the above patterns.
434   //
435   // Not all backends support all fusion kinds, and given a particular fused
436   // computation, it's not in general safe to change its fusion kind.  Creation
437   // of fusion nodes is always backend-specific.
438   //
439   // For elementwise ops (e.g. kAdd), most backends would emit a
440   // one-element-at-a-time implementation for the unfused version, so loop
441   // fusion and input fusion are probably equivalent if the root node is
442   // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
443   // implementation might emit something more sophisticated for an unfused or
444   // input-fusion reduce, but will emit the naive code that reduces one element
445   // at a time for loop fusion with a reduce as the root.
446   //
447   // Another way to think of loop fusion is that it's equivalent to input
448   // fusion, but where the root node is an implicit identity node, whose
449   // unfused implementation is "read one element, write one element".
450   //
451   // TODO(b/79869434): This categorization scheme is not great.  For one thing,
452   // input and loop fusion are basically the same thing: There is no reason for
453   // the HLO to encode backend-specific decisions about how e.g. a reduce that's
454   // the root of a fusion should be lowered.  In addition, this scheme as
455   // written doesn't work for multi-output fusion, where the primary node is
456   // never actually the root (which is a kTuple instruction that gathers the
457   // multiple outputs of the fusion).
458   enum class FusionKind {
459     kLoop,
460     kInput,
461     kOutput,
462     kCustom,
463   };
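
  // As a sketch of how a backend might wrap an existing elementwise
  // instruction in a loop fusion (the `add` instruction below is hypothetical;
  // CreateFusion is declared further down in this class):
  //
  //   auto fusion = HloInstruction::CreateFusion(
  //       add->shape(), HloInstruction::FusionKind::kLoop, /*fused_root=*/add);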
464 
465   virtual ~HloInstruction();
466 
467   // Creates an instruction from the given proto. Arguments:
468   //
469   //   proto: the proto to convert from.
470   //   instruction_map: a map from instruction id to HloInstruction*. This map
471   //     must contain all operands of the newly constructed instruction.
472   //   computation_map: a map from computation id to HloComputation*. This map
473   //     must contain all computations which the newly constructed instruction
474   //     calls.
475   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
476       const HloInstructionProto& proto,
477       const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
478       const absl::flat_hash_map<int64, HloComputation*>& computation_map,
479       bool prohibit_empty_literal = true);
480 
481   // Creates a parameter-retrieving instruction.
482   static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
483                                                          const Shape& shape,
484                                                          const string& name);
485 
486   // Creates a literal constant instruction.
487   static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
488 
489   // Creates an Iota instruction.
490   static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
491                                                     int64 iota_dimension);
492 
493   // Creates a get tuple element instruction.
494   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
495       const Shape& shape, HloInstruction* operand, int64 index);
496 
497   // Creates a trace instruction that logs the input operand in the computation.
498   static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
499                                                      HloInstruction* operand);
500 
501   // Creates a random number generation instruction that fills a shape with
502   // random numbers from a given distribution.
503   //
504   // The parameters to the instruction are interpreted as follows:
505   //
506   //  - If `distribution` is RNG_UNIFORM, generates a number in range
507   //    [param0, param1).
508   //
509   //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
510   //    with mean `param0` and standard deviation `param1`.
511   static std::unique_ptr<HloInstruction> CreateRng(
512       const Shape& shape, RandomDistribution distribution,
513       absl::Span<HloInstruction* const> parameters);
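
  // For example, a uniform draw over f32[16] in the range [0, 1) could be
  // built as follows (a sketch; `low` and `high` stand for hypothetical
  // constant instructions producing 0.0f and 1.0f):
  //
  //   auto rng = HloInstruction::CreateRng(
  //       ShapeUtil::MakeShape(F32, {16}), RNG_UNIFORM, {low, high});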
514 
515   // Creates an instruction to update the random number generator state to
516   // reflect the new state after `delta` units of 32 random bits are generated
517   // and returns the old state.
518   static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState(
519       const Shape& shape, int64 delta);
520 
521   // Creates a unary instruction (one operand).
522   // Precondition: opcode must be a legitimate unary operation.
523   static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
524                                                      HloOpcode opcode,
525                                                      HloInstruction* operand);
526 
527   // Creates a binary instruction (two operands).
528   // Precondition: opcode must be a legitimate binary operation.
529   static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
530                                                       HloOpcode opcode,
531                                                       HloInstruction* lhs,
532                                                       HloInstruction* rhs);
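
  // For instance (a usage sketch; `lhs` and `rhs` are hypothetical operands of
  // shape f32[8]):
  //
  //   auto add = HloInstruction::CreateBinary(
  //       ShapeUtil::MakeShape(F32, {8}), HloOpcode::kAdd, lhs, rhs);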
533 
534   // Creates a ternary instruction (three operands).
535   // Precondition: opcode must be a legitimate ternary operation.
536   static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
537                                                        HloOpcode opcode,
538                                                        HloInstruction* lhs,
539                                                        HloInstruction* rhs,
540                                                        HloInstruction* ehs);
541 
542   // Creates a variadic instruction (variable number of operands).
543   // Precondition: opcode must be a legitimate variadic operation.
544   static std::unique_ptr<HloInstruction> CreateVariadic(
545       const Shape& shape, HloOpcode opcode,
546       absl::Span<HloInstruction* const> operands);
547 
548   // Creates a map instruction, where the computation (given by the handle) is
549   // applied element-wise to every element in operands (across the operands,
550   // at a given index)
551   static std::unique_ptr<HloInstruction> CreateMap(
552       const Shape& shape, absl::Span<HloInstruction* const> operands,
553       HloComputation* map_computation);
554 
555   // Creates a convolution op, where rhs is the convolutional filter
556   // and window describes how the filter is applied to lhs.
557   static std::unique_ptr<HloInstruction> CreateConvolve(
558       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
559       int64 feature_group_count, int64 batch_group_count, const Window& window,
560       const ConvolutionDimensionNumbers& dimension_numbers,
561       const PrecisionConfig& precision_config);
562 
563   // Creates an FFT op, of the type indicated by fft_type.
564   static std::unique_ptr<HloInstruction> CreateFft(
565       const Shape& shape, HloInstruction* operand, FftType fft_type,
566       absl::Span<const int64> fft_length);
567 
568   // Creates a compare op, performing the comparison specified in direction.
569   static std::unique_ptr<HloInstruction> CreateCompare(
570       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
571       ComparisonDirection direction);
572 
573   static std::unique_ptr<HloInstruction> CreateTriangularSolve(
574       const Shape& shape, HloInstruction* a, HloInstruction* b,
575       const TriangularSolveOptions& options);
576 
577   static std::unique_ptr<HloInstruction> CreateCholesky(
578       const Shape& shape, HloInstruction* a, const CholeskyOptions& options);
579 
580   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
581   // dimensions specified in 'dimension_numbers'.
582   static std::unique_ptr<HloInstruction> CreateDot(
583       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
584       const DotDimensionNumbers& dimension_numbers,
585       const PrecisionConfig& precision_config);
586 
587   // Creates a reduce-precision op, where operand is the data to reduce in
588   // precision, and exponent_bits and mantissa_bits describe the precision to
589   // reduce it to.
590   static std::unique_ptr<HloInstruction> CreateReducePrecision(
591       const Shape& shape, HloInstruction* operand, const int exponent_bits,
592       const int mantissa_bits);
593 
594   // Creates a cross replica reduction op.
595   //
596   // `reduction_computation`: the reduction function.
597   //
598   // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
599   // empty, all replicas belong to one group in the order of 0 - (n-1).
600   // Allreduce will be applied within subgroups.
601   // For example, if we have 4 replicas, then replica_groups={{0,2},{1,3}} means
602   // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
603   //
604   // `channel_id`: for Allreduce nodes from different modules, if
605   // they have the same channel_id, they will be 'Allreduce'd. If
606   // empty, Allreduce will not be applied across modules.
607   static std::unique_ptr<HloInstruction> CreateAllReduce(
608       const Shape& shape, absl::Span<HloInstruction* const> operands,
609       HloComputation* reduce_computation,
610       const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
611       const absl::optional<int64>& channel_id);
612 
613   // An all-to-all op takes N array operands of the same shape and scatters them
614   // to N replicas.  Each replica gathers the results into a tuple.
615   //
616   // For example, suppose we have 3 replicas, with replica i passing inputs
617   // [a_i, b_i, c_i] to its all-to-all op.  Then the resulting tuples are
618   //
619   //   replica 0: (a_0, a_1, a_2)
620   //   replica 1: (b_0, b_1, b_2)
621   //   replica 2: (c_0, c_1, c_2).
622   //
623   // If replica_groups is set, the op is sharded and the replicas are permuted.
624   // To explain by way of example, suppose we have replica_groups={{1,2},{3,0}}.
625   // Then each replica passes two operands, say [a_i, b_i], and the result is
626   //
627   //   replica 0: (b_3, b_0)
628   //   replica 1: (a_1, a_2)
629   //   replica 2: (b_1, b_2)
630   //   replica 3: (a_3, a_0).
631   //
632   // All replica groups must have the same number of elements, and the number of
633   // operands must be equal to the size of one replica group.  Each replica must
634   // appear in exactly one group.
635   //
636   // Note that this instruction is different than the all-to-all op in
637   // xla_builder.h.  The version in XlaBuilder takes one input and slices it,
638   // and then concatenates the results into a single array.  This instruction
639   // takes multiple inputs and returns a tuple; it doesn't slice or concatenate.
640   // It is used to implement the higher-level instruction in XlaBuilder.
641   static std::unique_ptr<HloInstruction> CreateAllToAll(
642       const Shape& shape, absl::Span<HloInstruction* const> operands,
643       const std::vector<ReplicaGroup>& replica_groups,
644       const absl::optional<int64>& channel_id,
645       const absl::optional<int64>& split_dimension = absl::nullopt);
646 
647   // Creates a communication instruction that permutes data across replicas.
648   // Data is sent/received according to the (source_replica_id,
649   // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
650   // target_replica_id in any pair, the output on that replica is a tensor
651   // consisting of 0(s) in `shape`.
652   static std::unique_ptr<HloInstruction> CreateCollectivePermute(
653       const Shape& shape, HloInstruction* operand,
654       const std::vector<std::pair<int64, int64>>& source_target_pairs,
655       const absl::optional<int64>& channel_id);
656 
657   // Creates an instruction that returns a U32 replica ID.
658   static std::unique_ptr<HloInstruction> CreateReplicaId();
659 
660   // Creates an instruction that returns a U32 partition ID.
661   static std::unique_ptr<HloInstruction> CreatePartitionId();
662 
663   // Creates a conversion instruction, where operand is the data to convert and
664   // shape is the target shape for the conversion.
665   static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
666                                                        HloInstruction* operand);
667 
668   // Creates a bitcast instruction, where operand is the data to
669   // convert and shape is the target shape for the conversion.
670   static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape,
671                                                        HloInstruction* operand);
672 
673   // Creates a bitcast conversion instruction, where operand is the data to
674   // convert and shape is the target shape for the conversion.
675   static std::unique_ptr<HloInstruction> CreateBitcastConvert(
676       const Shape& shape, HloInstruction* operand);
677 
678   // Creates an infeed instruction, which reads data of the given shape from the
679   // Infeed interface of the device. infeed_shape is the shape of the data
680   // received from the infeed *not* the shape of the infeed instruction which
681   // is a tuple containing the infeed_shape and the TOKEN.
682   static std::unique_ptr<HloInstruction> CreateInfeed(
683       const Shape& infeed_shape, HloInstruction* token_operand,
684       const string& config);
685 
686   // Creates an outfeed instruction, which outputs data. outfeed_shape is the
687   // shape of the data being outfed *not* the shape of the outfeed instruction
688   // which is a TOKEN.
689   static std::unique_ptr<HloInstruction> CreateOutfeed(
690       const Shape& outfeed_shape, HloInstruction* operand,
691       HloInstruction* token_operand, absl::string_view outfeed_config);
692 
693   // Creates an asynchronous send instruction with the given channel id, which
694   // initiates sending the operand data to a unique receive instruction in
695   // another computation that has the same channel id. If is_host_transfer is
696   // true, then this Send operation transfers data to the host.
697   static std::unique_ptr<HloInstruction> CreateSend(
698       HloInstruction* operand, HloInstruction* token, int64 channel_id,
699       bool is_host_transfer = false);
700 
701   // Blocks until data transfer for the Send instruction (operand) is complete.
702   // The operand must be kSend.
703   static std::unique_ptr<HloInstruction> CreateSendDone(
704       HloInstruction* operand, bool is_host_transfer = false);
705 
706   // Creates an asynchronous receive instruction with the given channel id,
707   // which allocates resources to receive data of the given shape from a unique
708   // send instruction in another computation that has the same channel id.  If
709   // is_host_transfer is true, then this Recv operation transfers data from the
710   // host.
711   static std::unique_ptr<HloInstruction> CreateRecv(
712       const Shape& shape, HloInstruction* token, int64 channel_id,
713       bool is_host_transfer = false);
714 
715   // Blocks until data transfer for the Recv instruction (operand) is complete
716   // and returns the receive buffer. The operand must be kRecv.
717   static std::unique_ptr<HloInstruction> CreateRecvDone(
718       HloInstruction* operand, bool is_host_transfer = false);
719 
720   // Creates a slice instruction, where the operand is sliced by the given
721   // start/limit indices.
722   static std::unique_ptr<HloInstruction> CreateSlice(
723       const Shape& shape, HloInstruction* operand,
724       absl::Span<const int64> start_indices,
725       absl::Span<const int64> limit_indices, absl::Span<const int64> strides);
726 
727   // Creates a slice instruction, where the first operand is sliced by
728   // start indices specified in the second operand, and by size specified in
729   // 'slice_sizes'.
730   static std::unique_ptr<HloInstruction> CreateDynamicSlice(
731       const Shape& shape, HloInstruction* operand,
732       absl::Span<HloInstruction* const> start_indices,
733       absl::Span<const int64> slice_sizes);
734 
735   // Creates a dynamic update slice instruction, which updates a slice
736   // of 'operand' with 'update' and 'start_indices'.
737   static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
738       const Shape& shape, HloInstruction* operand, HloInstruction* update,
739       absl::Span<HloInstruction* const> start_indices);
740 
741   // Creates a concatenate instruction, where the operands are concatenated on
742   // the provided dimension.
743   static std::unique_ptr<HloInstruction> CreateConcatenate(
744       const Shape& shape, absl::Span<HloInstruction* const> operands,
745       int64 dimension);
746 
747   // Creates a reduce instruction, where the computation (given by the handle)
748   // is applied successively to every element in operand. For example, let f be
749   // the function to apply, which takes 2 arguments, an accumulator and the
750   // current value. Let init be an initial value (which is normally chosen to be
751   // the identity element for f, e.g. 0 if f is addition).
752   // Then the reduce HLO will compute:
753   // f(...f(f(init, value0), value1)...)
754   static std::unique_ptr<HloInstruction> CreateReduce(
755       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
756       absl::Span<const int64> dimensions_to_reduce,
757       HloComputation* reduce_computation);
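
  // For example, summing a f32[10,20] operand over its second dimension might
  // look like the following sketch (`input`, `zero`, and `add_computation` are
  // hypothetical):
  //
  //   auto sum = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {10}), input, zero,
  //       /*dimensions_to_reduce=*/{1}, add_computation);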
758 
759   // A more general, multiple-argument version of the above.
760   // The function to apply, f, now takes 2*N arguments:
761   // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
762   // valueN], and returns an N-tuple. The performed computation is (for
763   // commutative and associative f operators) equivalent to:
764   //
765   // f_1 = f(init0, ...  initN, input0.value0, ..., inputN.value0)
766   // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
767   // ..., inputN.value1)
768   // ...
769   static std::unique_ptr<HloInstruction> CreateReduce(
770       const Shape& shape, absl::Span<HloInstruction* const> operands,
771       absl::Span<HloInstruction* const> init_values,
772       absl::Span<const int64> dimensions_to_reduce,
773       HloComputation* reduce_computation);
774 
775   // Creates a reduce-window instruction, where the computation (given
776   // by the handle) is applied window-wise at each valid window
777   // position in the operand.
778   static std::unique_ptr<HloInstruction> CreateReduceWindow(
779       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
780       const Window& window, HloComputation* reduce_computation);
781 
782   // Creates a batch-norm-training instruction.
783   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
784       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
785       HloInstruction* offset, float epsilon, int64 feature_index);
786 
787   // Creates a batch-norm-inference instruction.
788   static std::unique_ptr<HloInstruction> CreateBatchNormInference(
789       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
790       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
791       float epsilon, int64 feature_index);
792 
793   // Creates a batch-norm-grad instruction.
794   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
795       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
796       HloInstruction* mean, HloInstruction* variance,
797       HloInstruction* grad_output, float epsilon, int64 feature_index);
798 
799   // Creates a scatter computation that scatters the `source` array to the
800   // selected indices of each window.
801   static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
802       const Shape& shape, HloInstruction* operand, HloComputation* select,
803       const Window& window, HloInstruction* source, HloInstruction* init_value,
804       HloComputation* scatter);
805 
806   // Creates a broadcast instruction.
807   static std::unique_ptr<HloInstruction> CreateBroadcast(
808       const Shape& shape, HloInstruction* operand,
809       absl::Span<const int64> broadcast_dimensions);
810 
811   // Creates a sequence of instructions that performs an explicit broadcast of
812   // the operand to the target shape.
813   //
814   // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
815   // returned as a unique_ptr for API consistency with other factory methods in
816   // this interface.
817   //
818   // TODO(b/72173833) Ideally HloComputations would always be present, and so
819   // the adder being passed by the caller would not be necessary.
820   static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
821       const Shape& output_shape, HloInstruction* operand,
822       const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
823           adder);
824 
825   // Creates a pad instruction, where the operand is padded on the edges and
826   // between the elements with the given padding value.
827   static std::unique_ptr<HloInstruction> CreatePad(
828       const Shape& shape, HloInstruction* operand,
829       HloInstruction* padding_value, const PaddingConfig& padding_config);
830 
831   // Creates a reshape instruction, where the operand is flattened row-major
832   // order and then reshaped to the given result shape.
833   static std::unique_ptr<HloInstruction> CreateReshape(
834       const Shape& shape, HloInstruction* operand,
835       int64 inferred_dimension = -1);
836 
837   // Creates a transpose instruction which permutes the operand dimensions.
838   static std::unique_ptr<HloInstruction> CreateTranspose(
839       const Shape& shape, HloInstruction* operand,
840       absl::Span<const int64> dimensions);
841 
842   // Creates a n-ary sort op with a 'compare' computation which is used for
843   // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
844   // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
845   // specific index positions which should be compared, and should return a
846   // PRED. 'is_stable' specifies whether stable sorting is required.
847   static std::unique_ptr<HloInstruction> CreateSort(
848       const Shape& shape, int64 dimension,
849       absl::Span<HloInstruction* const> operands, HloComputation* compare,
850       bool is_stable);
851 
852   // Creates a while instruction, given a condition computation, a body
853   // computation, and the initial value for the input of the computations. For
854   // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
855   // corresponds to the C code below.
856   // int32 i = 1; while (i < 1000) { i = i * 2; } int32 result = i;
857   static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
858                                                      HloComputation* condition,
859                                                      HloComputation* body,
860                                                      HloInstruction* init);
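
  // A corresponding construction sketch (the two computations and the `one`
  // constant instruction are hypothetical):
  //
  //   auto loop = HloInstruction::CreateWhile(
  //       ShapeUtil::MakeShape(S32, {}), condition_computation,
  //       body_computation, /*init=*/one);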
861 
862   static std::unique_ptr<HloInstruction> CreateConditional(
863       const Shape& shape, HloInstruction* pred,
864       HloInstruction* true_computation_arg, HloComputation* true_computation,
865       HloInstruction* false_computation_arg, HloComputation* false_computation);
866 
867   static std::unique_ptr<HloInstruction> CreateConditional(
868       const Shape& shape, HloInstruction* branch_index,
869       absl::Span<HloComputation* const> branch_computations,
870       absl::Span<HloInstruction* const> branch_computation_args);
871 
872   static std::unique_ptr<HloInstruction> CreateGather(
873       const Shape& shape, HloInstruction* operand,
874       HloInstruction* start_indices,
875       const GatherDimensionNumbers& gather_dim_numbers,
876       absl::Span<const int64> slice_sizes, bool indices_are_sorted);
877 
878   static std::unique_ptr<HloInstruction> CreateScatter(
879       const Shape& shape, HloInstruction* operand,
880       HloInstruction* scatter_indices, HloInstruction* updates,
881       HloComputation* update_computation,
882       const ScatterDimensionNumbers& scatter_dim_numbers,
883       bool indices_are_sorted, bool unique_indices);
884 
885   // Creates a kDomain instruction which delimits an HLO domain which has
886   // the provided user and operand side metadata.
887   static std::unique_ptr<HloInstruction> CreateDomain(
888       const Shape& shape, HloInstruction* operand,
889       std::unique_ptr<DomainMetadata> operand_side_metadata,
890       std::unique_ptr<DomainMetadata> user_side_metadata);
891 
892   // Creates a fusion instruction. A fusion instruction contains one or more
893   // fused instructions forming an expression with a single root
894   // "fused_root". Additional instructions can be added to the fusion
895   // instruction with the method FuseInstruction.
896   static std::unique_ptr<HloInstruction> CreateFusion(
897       const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
898 
899   static std::unique_ptr<HloInstruction> CreateFusion(
900       const Shape& shape, FusionKind fusion_kind,
901       absl::Span<HloInstruction* const> operands,
902       HloComputation* fusion_computation);
903 
904   // Creates a call instruction that applies the given computation on the given
905   // operands. "shape" is the resultant shape.
906   static std::unique_ptr<HloInstruction> CreateCall(
907       const Shape& shape, absl::Span<HloInstruction* const> operands,
908       HloComputation* computation);
909 
910   // Creates a custom call instruction that applies the given custom call target
911   // to the given operands. "opaque" can be an arbitrary string with a
912   // backend-specific interpretation. "shape" is the resultant shape.
913   static std::unique_ptr<HloInstruction> CreateCustomCall(
914       const Shape& shape, absl::Span<HloInstruction* const> operands,
915       absl::string_view custom_call_target, string opaque = "");
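
  // For example (a sketch; the target name and operands are hypothetical, and
  // their interpretation is entirely backend-specific):
  //
  //   auto custom = HloInstruction::CreateCustomCall(
  //       out_shape, {arg0, arg1}, "my_custom_target", /*opaque=*/"");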
916 
917   // Overload which constrains the layouts of the operand and result. 'shape'
918   // and 'operand_shapes_with_layout' must have layouts.
919   // 'operand_shapes_with_layout' must have a compatible element for each
920   // operand.
921   static std::unique_ptr<HloInstruction> CreateCustomCall(
922       const Shape& shape, absl::Span<HloInstruction* const> operands,
923       absl::string_view custom_call_target,
924       absl::Span<const Shape> operand_shapes_with_layout, string opaque = "");
925 
926   // Creates a tuple instruction with the given elements. This is a convenience
927   // wrapper around CreateVariadic.
928   static std::unique_ptr<HloInstruction> CreateTuple(
929       absl::Span<HloInstruction* const> elements);
930 
931   // Creates a reverse instruction, which reverses the order of the elements
932   // in the specified dimensions.
933   static std::unique_ptr<HloInstruction> CreateReverse(
934       const Shape& shape, HloInstruction* operand,
935       absl::Span<const int64> dimensions);
936 
937   // Creates an AfterAll instruction used for joining or creating new values of
938   // token type which thread through side-effecting operations. Operands must
939   // all be tokens, and there must be at least one operand.
940   static std::unique_ptr<HloInstruction> CreateAfterAll(
941       absl::Span<HloInstruction* const> operands);
942 
943   // Creates an AfterAll instruction which creates a token type out of thin air
944   // (no operands). This is a separate method from CreateAfterAll to facilitate
945   // the removal of operand-less AfterAll instructions.
946   // TODO(b/110532604): Remove this capability of creating a token from nothing
947   // when we plumb a primordial token from the entry computation.
948   static std::unique_ptr<HloInstruction> CreateToken();
949 
950   static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
951       const Shape& shape, HloInstruction* operand, int64 dimension);
952 
953   static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
954       const Shape& shape, HloInstruction* operand, HloInstruction* val,
955       int64 dimension);
956 
957   static std::unique_ptr<HloInstruction> CreateAddDependency(
958       HloInstruction* data_operand, HloInstruction* token_operand);
959 
960   // Returns the opcode for this instruction.
961   HloOpcode opcode() const { return opcode_; }
962 
963   // Returns true if this instruction has a side effect, irrespective of whether
964   // any called computations may contain an instruction with side effects.
965   bool HasSideEffectNoRecurse() const;
966 
967   // Returns true if this instruction has a side effect. An instruction has a
968   // side effect if it uses certain opcodes or calls a computation with a side
969   // effect.
970   bool HasSideEffect() const;
971 
972   // Returns the result shape of this instruction.
973   const Shape& shape() const;
974 
975   // Returns the (mutable) result shape of this instruction.
976   Shape* mutable_shape() { return &shape_; }
977 
978   // Returns the ith operand to this instruction.
979   const HloInstruction* operand(int64 i) const;
980 
981   // Returns the ith operand to this instruction.
982   HloInstruction* mutable_operand(int64 i);
983 
984   // Returns the number of operands to this instruction.
985   int64 operand_count() const { return operands_.size(); }
986 
987   // Returns the vector of operands of this instruction.
988   using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
989   const InstructionVector& operands() const { return operands_; }
990 
991   // Returns the vector of unique operands, in the same order they are found
992   // within the operand vector.
993   InstructionVector unique_operands() const;
994 
995   // Returns the index of 'target' in the operands sequence.
996   // Precondition: target must be an operand (or a fatal error will occur).
997   int64 operand_index(const HloInstruction* target) const;
998 
999   // Returns the number of users of this instruction.
1000   int64 user_count() const { return users_.size(); }
1001 
1002   // Returns the users of this instruction.
1003   const std::vector<HloInstruction*>& users() const { return users_; }
1004 
1005   // Returns the index of the user in the users() vector.
1006   //
1007   // Precondition: `user` is a user of the instruction.
1008   int64 UserId(HloInstruction* user);
1009 
1010   // Returns true if this instruction is a user of 'instruction'.
1011   bool IsUserOf(const HloInstruction* instruction) const {
1012     return ContainsKey(instruction->user_map_, this);
1013   }
1014 
1015   // Adds a control dependency from this instruction to the given
1016   // instruction. This instruction becomes a control predecessor of
1017   // 'instruction', and 'instruction' becomes a control successor of this
1018   // instruction. Returns an error status if either of the given instructions
1019   // does not belong to the same computation.
1020   //
1021   // This is used to enforce an additional ordering requirement that is not
1022   // captured by normal data dependencies, such as ordering among Send or Recv
1023   // operations to avoid deadlock.
1024   Status AddControlDependencyTo(HloInstruction* instruction);
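
  // For example, to require a Send to execute before a Recv (a sketch with
  // hypothetical instruction pointers, using TensorFlow's usual status macro):
  //
  //   TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv));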
1025 
1026   // Removes a previously added control dependency from this instruction to
1027   // 'instruction'.
1028   Status RemoveControlDependencyTo(HloInstruction* instruction);
1029 
1030   // Drops all control predecessors and successors from this HLO instruction.
1031   Status DropAllControlDeps();
1032 
1033   // Copies the control predecessors and successors on this HLO instruction to
1034   // `inst`.  Does not do a deep copy so this makes sense only if `inst` and
1035   // this HLO are in the same module.
1036   //
1037   // Depending on the use cases we see in practice, in the future we may
1038   // consider folding the logic here into Clone, CloneWithNewOperands and
1039   // ReplaceAllUsesWith by treating control dependencies like data dependencies.
1040   Status CopyAllControlDepsFrom(const HloInstruction* inst);
1041 
1042   // Returns the set of control predecessors (successors) of this
1043   // instruction. Control predecessors (successors) must execute before (after)
1044   // the current instruction.
1045   const std::vector<HloInstruction*>& control_predecessors() const {
1046     return control_predecessors_;
1047   }
1048   const std::vector<HloInstruction*>& control_successors() const {
1049     return control_successors_;
1050   }
1051 
1052   // Returns true if "other" performs the same computation as this instruction.
1053   bool Identical(
1054       const HloInstruction& other,
1055       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
1056           eq_operands = std::equal_to<const HloInstruction*>(),
1057       const std::function<bool(const HloComputation*, const HloComputation*)>&
1058           eq_computations = std::equal_to<const HloComputation*>(),
1059       bool layout_sensitive = true) const {
1060     // An instruction is always identical to itself.
1061     if (this == &other) {
1062       return true;
1063     }
1064 
1065     // Identical instruction must have the same opcode, shape, and identical
1066     // operands.
1067     if (opcode() != other.opcode()) {
1068       return false;
1069     }
1070     if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
1071                            : ShapeUtil::Compatible(shape(), other.shape()))) {
1072       return false;
1073     }
1074     if (operands().size() != other.operands().size()) {
1075       return false;
1076     }
1077 
1078     // Two AllReduces are Identical if they have the same channel_id.
1079     // Their operands don't have to be Identical.
1080     if (!IsCrossModuleAllReduce()) {
1081       // Use an explicit loop rather than ContainerEquals, because copying
1082       // around std::functions may be too expensive in some cases.
1083       for (size_t i = 0; i < operands().size(); ++i) {
1084         if (!eq_operands(operand(i), other.operand(i))) {
1085           return false;
1086         }
1087       }
1088     }
1089 
1090     if (backend_config_ != other.backend_config_) {
1091       return false;
1092     }
1093 
1094     return IdenticalSlowPath(other, eq_computations);
1095   }
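
  // As a usage sketch, a layout-insensitive structural comparison of two
  // instructions `a` and `b` (hypothetical pointers) could be written as:
  //
  //   bool same = a->Identical(
  //       *b, std::equal_to<const HloInstruction*>(),
  //       std::equal_to<const HloComputation*>(), /*layout_sensitive=*/false);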
1096 
1097   // Generates a hash value of an HLO instruction. Hash considers
1098   // information on opcode, shape, operands, and typically a root instruction.
1099   // This function returns the same hash value for equivalent HLO instructions,
1100   // with respect to HloInstruction::Identical() method.
1101   //
1102   // Uses hash_operand function to compute hash values of its operands.
1103   // At the very top level, hash_operand should be non-recursive to prevent
1104   // non-termination.
1105   uint64 Hash(
1106       const std::function<uint64(const HloInstruction*)>& hash_operand) const;
1107 
1108   // Calls the above method with non-recursive hash_operand function.
1109   uint64 Hash() const;
1110 
1111   // Returns whether the instruction has a constant operand.
1112   bool HasConstantOperand() const;
1113 
1114   // Replaces the use of this instruction in "user" with "new_producer". Note
1115   // that there might be multiple uses of this instruction in "user"; all will
1116   // be replaced.
1117   //
1118   // If user is a fusion instruction, this function will remove any duplicated
1119   // operands of it which could be created due to this replacement.
1120   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
1121 
1122   // Same as ReplaceUseWith(), but new_producer can have a different shape.
1123   Status ReplaceUseWithDifferentShape(HloInstruction* user,
1124                                       HloInstruction* new_producer);
1125 
1126   // Replaces the specified operand with new_operand. The old and new operands
1127   // must have compatible shapes ignoring floating-point precision.
1128   //
1129   // This function does NOT remove duplicated operands even if this instruction
1130   // is a fusion, so that the existing operand numbers do not change.
1131   Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand);
1132 
1133   // Same as ReplaceOperandWith(), but new_operand can have a different shape.
1134   Status ReplaceOperandWithDifferentShape(int64 operand_num,
1135                                           HloInstruction* new_operand);
1136 
1137   // Replaces all uses of this instruction with the new producer. If
1138   // new_producer is a user of this instruction then new_producer remains a use
1139   // of this instruction to avoid introducing cycles into the graph.
1140   //
1141   // If this instruction is the root of its computation, sets the computation's
1142   // root to new_producer.
1143   //
1144   // The new producer must have a compatible shape ignoring floating-point
1145   // precision.
1146   //
1147   // If a user is a fusion instruction, this function will remove any duplicated
1148   // operands of it which could be created due to this replacement.
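  //
  // Typical use when rewriting a graph (sketch; `old_instr` and `new_instr`
  // are assumed instructions with compatible shapes):
  //
  //   TF_RETURN_IF_ERROR(old_instr->ReplaceAllUsesWith(new_instr));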
1149   Status ReplaceAllUsesWith(HloInstruction* new_producer);
1150 
1151   // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
1152   Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
1153 
1154   // Performs a postorder DFS visit using this node as the root. If
1155   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
1156   // complete. If ignore_control_predecessors is true, instructions only
1157   // reachable via control dependencies will not be visited, and the postorder
1158   // will not take control dependencies into account. It is as if the control
1159   // dependencies didn't exist in the graph at all.
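  //
  // For instance (sketch; `MyVisitor` is a hypothetical DfsHloVisitor
  // subclass and `root` an existing HloInstruction*):
  //
  //   MyVisitor visitor;
  //   TF_RETURN_IF_ERROR(root->Accept(&visitor));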
1160   template <typename HloInstructionPtr>
1161   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
1162                 bool call_finish_visit = true,
1163                 bool ignore_control_predecessors = false);
1164   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
1165                 bool ignore_control_predecessors = false) const {
1166     return const_cast<HloInstruction*>(this)->Accept(
1167         visitor, call_finish_visit, ignore_control_predecessors);
1168   }
1169 
1170   // Same as Accept() above, but the order of operand and control predecessor
1171   // visitation is determined by the given operand order; if compare(A, B) ==
1172   // true, A is visited before B.
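  //
  // As an illustration (sketch; `root` and `visitor` are assumed to exist),
  // an order that visits operands by ascending unique id:
  //
  //   TF_RETURN_IF_ERROR(root->AcceptWithOperandOrder(
  //       &visitor,
  //       [](const HloInstruction* a, const HloInstruction* b) {
  //         return a->unique_id() < b->unique_id();
  //       }));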
1173   using CompareFunction =
1174       std::function<bool(const HloInstruction*, const HloInstruction*)>;
1175   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
1176                                 const CompareFunction& operand_order,
1177                                 bool call_finish_visit = true);
1178 
1179   // Visit this instruction and only this instruction with the given visitor.
1180   template <typename HloInstructionPtr>
1181   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
1182 
1183   // Returns the first non-GetTupleElement ancestor of this instruction. If
1184   // that ancestor is tuple-shaped, populates 'index' with the (possibly
1185   // nested) tuple indices used on the path from the ancestor to this node.
1186   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
1187       const;
1188 
1189   std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
1190     auto rv =
1191         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
1192     return {const_cast<HloInstruction*>(rv.first), rv.second};
1193   }
1194 
1195   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
1196   const HloInstruction* LatestNonGteAncestor() const;
1197 
1198   HloInstruction* LatestNonGteAncestor() {
1199     return const_cast<HloInstruction*>(
1200         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
1201   }
1202 
1203   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
1204   // The setter should only be called by HloModule or HloComputation methods.
1205   //
1206   // Precondition: The instruction has a valid to_apply_ field.
1207   HloComputation* to_apply() const;
1208   void set_to_apply(HloComputation* to_apply);
1209 
1210   // Gets/sets the while_condition or while_body HloComputation for While. The
1211   // setters should only be called by HloModule or HloComputation methods.
1212   //
1213   // Precondition: The instruction is a While instruction.
1214   HloComputation* while_condition() const;
1215   HloComputation* while_body() const;
1216   void set_while_condition(HloComputation* while_condition);
1217   void set_while_body(HloComputation* while_body);
1218 
1219   HloInstruction* while_init() const;
1220 
1221   // Gets/sets the true and false HloComputation for Conditional.
1222   //
1223   // Precondition: The instruction is a predicated Conditional instruction.
1224   HloComputation* true_computation() const;
1225   HloComputation* false_computation() const;
1226 
1227   // Gets the branch HloComputations for Conditional.
1228   //
1229   // Precondition: The instruction is a Conditional instruction.
1230   const std::vector<HloComputation*>& branch_computations() const;
1231   int branch_count() const;
1232   HloComputation* branch_computation(int b) const;
1233   // Sets a branch HloComputation for Conditional.
1234   // The setter should only be called by HloModule or HloComputation methods.
1235   //
1236   // Precondition: The instruction is a Conditional instruction.
1237   void set_branch_computation(int b, HloComputation* computation);
1238 
1239   // Returns a string for the signature of this instruction if considered as a
1240   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
1241   string SignatureString() const;
1242 
1243   // Returns a debugging string that represents this instruction.
1244   //
1245   // (We express the default options using an overload rather than a default
1246   // param because gdb ignores default params, but does resolve overloads.)
1247   //
1248   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
1249   // default, backing off on providing full information for very large strings,
1250   // or provide a different name for a ToString-like function that does that.
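  //
  // A typical debugging use (sketch; `instr` is an existing HloInstruction*):
  //
  //   HloPrintOptions options;  // default options
  //   LOG(INFO) << instr->ToString(options);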
1251   string ToString() const { return ToString(HloPrintOptions()); }
1252   string ToString(const HloPrintOptions& options) const;
1253 
1254   // Components of the ToString() representation:
1255 
1256   // Returns a string representation of the operand list.
1257   string OperandsToString(const HloPrintOptions& options) const;
1258 
1259   // Returns string representation of op-specific attributes.
1260   std::vector<string> ExtraAttributesToString(
1261       const HloPrintOptions& options) const;
1262 
1263   // As ToString, but returns a shorter string.
1264   string ToShortString() const;
1265 
1266   // Prints an instruction to a string.
1267   //
1268   // The canonical string representation needs to name operands and instruction
1269   // names in a consistent way. This is implemented through the
1270   // canonical_name_map.
1271   string ToStringWithCanonicalNameMap(
1272       const HloPrintOptions& options,
1273       CanonicalNameMap* canonical_name_map) const;
1274 
1275   // Returns a serialized representation of this instruction.
1276   virtual HloInstructionProto ToProto() const;
1277 
1278   // Returns a category for the HLO. This could be something like "convolution"
1279   // or "elementwise".
1280   virtual string ToCategory() const;
1281 
1282   // Returns a logging instruction, if the output of this instruction is logged.
1283   //
1284   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
1285   HloInstruction* tracing() const;
1286   void set_tracing(HloInstruction* trace_instruction);
1287 
1288   // Returns true if this instruction is fused, i.e. contained within a fusion
1289   // instruction.
1290   bool IsFused() const;
1291 
1292   bool IsLoopFusion() const;
1293   bool IsInputFusion() const;
1294   bool IsOutputFusion() const;
1295   bool IsCustomFusion() const;
1296 
1297   // Returns true if this instruction can be legally fused into a fusion
1298   // instruction.
1299   bool IsFusible() const;
1300 
1301   // Returns the sharding applied to this operator.
1302   // REQUIRES: has_sharding() is true.
1303   const HloSharding& sharding() const {
1304     CHECK(has_sharding());
1305     return *sharding_;
1306   }
1307   std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
1308 
1309   // Returns the sharding applied to this operator, or default_ if none exists.
1310   const HloSharding& sharding_or_default(const HloSharding& default_) const {
1311     return sharding_ ? *sharding_ : default_;
1312   }
1313   // Returns the sharding unique device, if any.
1314   absl::optional<int64> sharding_unique_device() const {
1315     if (sharding_ == nullptr) {
1316       return absl::optional<int64>();
1317     }
1318     return sharding_->UniqueDevice();
1319   }
1320   // Sets the sharding of this operator. Should only be called by HloModule or
1321   // HloComputation methods.
1322   void set_sharding(const HloSharding& sharding) {
1323     sharding_ = std::make_shared<const HloSharding>(sharding);
1324   }
1325   void set_sharding(std::shared_ptr<const HloSharding> sharding) {
1326     sharding_ = std::move(sharding);
1327   }
1328   void set_single_sharding(const HloSharding& sharding);
1329   // Sets a sharding that assigns the current instruction to device.
1330   void set_device_sharding(int64 device) {
1331     set_single_sharding(HloSharding::AssignDevice(device));
1332   }
1333   // Remove any sharding from this operator.
1334   void clear_sharding() { sharding_ = nullptr; }
1335   // Return true if this operator has a sharding assigned.
1336   bool has_sharding() const { return sharding_ != nullptr; }
1337   // Checks whether the instruction has compatible sharding with the other
1338   // instruction.
1339   bool has_compatible_sharding(const HloInstruction* other) const {
1340     if (!has_sharding()) {
1341       return !other->has_sharding();
1342     }
1343     return other->has_sharding() ? sharding() == other->sharding() : false;
1344   }
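  //
  // Example of assigning and querying a sharding (sketch; `instr` is an
  // existing HloInstruction*; uses only accessors declared above):
  //
  //   instr->set_device_sharding(/*device=*/0);
  //   absl::optional<int64> device = instr->sharding_unique_device();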
1345 
1346   // When creating a new instruction which either replaces, or shifts up (kCopy
1347   // insertion case), another instruction, we need to make sure that certain
1348   // properties of the original instruction are copied into the derived one. As
1349   // of today, the metadata and sharding are propagated to the derived
1350   // instruction.
1351   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
1352 
1353   // Clones the HLO instruction. The clone will have the same opcode, shape, and
1354   // operands. After creation the clone has no uses. "this" (the instruction
1355   // cloned from) is not changed. Suffix is the string to append to the name of
1356   // the instruction to form the name of the cloned instruction.
1357   // Ignores the control predecessors and successors of this HLO instruction.
1358   std::unique_ptr<HloInstruction> Clone(
1359       const string& suffix = "clone", HloCloneContext* context = nullptr) const;
1360 
1361   // Clones the HLO instruction as above but with new shape and operands.
1362   std::unique_ptr<HloInstruction> CloneWithNewOperands(
1363       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1364       HloCloneContext* context = nullptr) const;
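  //
  // Usage sketch (hypothetical `new_shape`, `lhs`, and `rhs`; the clone is not
  // yet part of any computation):
  //
  //   std::unique_ptr<HloInstruction> copy =
  //       instr->CloneWithNewOperands(new_shape, {lhs, rhs});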
1365 
1366   // Returns the computations this instruction directly calls (if any).
1367   const std::vector<HloComputation*>& called_computations() const {
1368     return called_computations_;
1369   }
1370 
1371   // Replaces all called computations based on a map function. This is needed
1372   // when we clone hlo_computations and want the instructions to point to the
1373   // newly cloned nodes.
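  //
  // For instance (sketch; `clone_map` is a hypothetical
  // absl::flat_hash_map<HloComputation*, HloComputation*> from original to
  // cloned computations):
  //
  //   instr->ReplaceCalledComputations(
  //       [&](HloComputation* c) { return clone_map.at(c); });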
1374   void ReplaceCalledComputations(
1375       std::function<HloComputation*(HloComputation*)> map_function) {
1376     for (int64 i = 0; i < called_computations_.size(); ++i) {
1377       called_computations_[i] = map_function(called_computations_[i]);
1378     }
1379   }
1380 
1381   // Clears out the called computations.
1382   //
1383   // This is, in particular, necessary when inlining function bodies into their
1384   // caller. If there were side-effecting operations in the called computations,
1385   // the call itself is considered side-effecting and thus cannot be removed. By
1386   // clearing out the computations, we record the fact that all side-effecting
1387   // properties have already been absorbed by the caller, and make the call HLO
1388   // removable.
1389   void ClearCalledComputations() { called_computations_.clear(); }
1390 
1391   // Returns true if this instruction performs an elementwise operation on
1392   // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
1393   // to compute the output at index {i_0,i_1,...,i_n}, the only element required
1394   // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
1395   //
1396   // Note on performance: when this instruction is kFusion, this method, in the
1397   // worst case, scans all fused instructions. We could speed this up by
1398   // caching.
1399   bool IsElementwiseOnOperand(int64 operand_idx) const;
1400 
1401   // Returns true if this instruction is elementwise on all its operands.
1402   bool IsElementwise() const;
1403 
1404   // Returns true if this is a cross module all-reduce instruction.
1405   bool IsCrossModuleAllReduce() const;
1406 
1407   // Returns true if this is a cross-replica all-reduce instruction.
1408   bool IsCrossReplicaAllReduce() const;
1409 
1410   // Returns true if this instruction is binary and elementwise.
1411   bool IsElementwiseBinary() const;
1412 
1413   // Returns whether this instruction may reuse elements of its `i`th operand.
1414   bool ReusesOperandElements(int64 i) const {
1415     return OperandElementUse(i) == UseKind::kReuse;
1416   }
1417 
1418   // Returns the indices at which the given operand appears in the operand list
1419   // of this instruction. Note that an instruction can use the same operand
1420   // multiple times.
1421   absl::InlinedVector<int64, 4> OperandIndices(
1422       const HloInstruction* operand) const;
1423 
1424   // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
1425   // this reshape merely inserts or deletes 1-sized dimensions, return the input
1426   // indices of the deleted dimensions and the output indices of the inserted
1427   // dimensions.
1428   //
1429   // Precondition: this op must be a reshape.
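  //
  // Illustrative use (sketch; `reshape` is assumed to be an existing kReshape
  // instruction):
  //
  //   bool trivial;
  //   std::vector<int64> deleted_dims, inserted_dims;
  //   std::tie(trivial, deleted_dims, inserted_dims) =
  //       reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();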
1430   std::tuple<bool, std::vector<int64>, std::vector<int64>>
1431   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
1432 
1433   // Gets the string identifier for this instruction.
1434   const string& name() const { return name_; }
1435 
1436   // Sets the string identifier for this instruction. Name will be sanitized to
1437   // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
1438   void SetAndSanitizeName(const string& name) {
1439     name_ = NameUniquer::GetSanitizedName(name);
1440   }
1441 
1442   // Use the given NameUniquer to select a unique name for the instruction based
1443   // on the instruction's existing name.
1444   void UniquifyName(NameUniquer* name_uniquer);
1445 
1446   // Clear the unique ID of the instruction so that it can be re-assigned, such
1447   // as for the purpose of compacting the instruction unique IDs.
1448   void ClearUniqueIdInternal() { unique_id_ = -1; }
1449 
1450   // Set the unique id for this instruction to "id"
1451   void SetUniqueId(int id) {
1452     CHECK_EQ(unique_id_, -1);  // Should not be assigned already
1453     CHECK_GE(id, 0);
1454     unique_id_ = id;
1455   }
1456 
1457   // Return the unique ID assigned to this node via SetUniqueId (or -1
1458   // if no id has been assigned yet).
1459   int unique_id() const { return unique_id_; }
1460 
1461   // Returns the backend-specific configuration for how a backend should compile
1462   // this HLO. The meaning of the field is backend specific. Not for use before
1463   // or during general HLO optimization, since HLO optimizations do not preserve
1464   // this field and they cannot interpret it due to its meaning being backend
1465   // specific. The exception is CustomCall, where this field is preserved and
1466   // no general HLO optimization needs to interpret it.
1467   //
1468   // ConfigProto should be a protobuf Message type.
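  //
  // Typical round trip (sketch; `MyBackendConfig` is a hypothetical proto
  // message and `instr` an existing HloInstruction*):
  //
  //   TF_ASSIGN_OR_RETURN(MyBackendConfig cfg,
  //                       instr->backend_config<MyBackendConfig>());
  //   cfg.set_some_field(42);  // hypothetical field
  //   TF_RETURN_IF_ERROR(instr->set_backend_config(cfg));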
1469   template <typename ConfigProto>
1470   StatusOr<ConfigProto> backend_config() const {
1471     ConfigProto proto;
1472     TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
1473     return std::move(proto);
1474   }
1475   Status set_backend_config(const tensorflow::protobuf::Message& proto);
1476 
1477   void set_frontend_attributes(FrontendAttributes frontend_attributes) {
1478     frontend_attributes_ = std::move(frontend_attributes);
1479   }
1480 
1481   const FrontendAttributes& frontend_attributes() const {
1482     return frontend_attributes_;
1483   }
1484 
1485   // Getter/setter for raw JSON-encoded backend config.  Prefer the
1486   // functions above that deal in proto Messages where possible.
1487   const string& raw_backend_config_string() const { return backend_config_; }
1488   void set_raw_backend_config_string(string config_str) {
1489     backend_config_ = std::move(config_str);
1490   }
1491 
1492   bool is_default_config() const { return is_default_config_; }
1493   void set_default_config() { is_default_config_ = true; }
1494 
1495   // Returns a string representation of a proto in the format used by
1496   // raw_backend_config_string.
1497   //
1498   // This is morally equivalent to:
1499   //
1500   //   HloInstruction instr;
1501   //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
1502   //   return instr.raw_backend_config_string();
1503   //
1504   static StatusOr<string> BackendConfigToRawString(
1505       const tensorflow::protobuf::Message& proto);
1506 
1507   // Returns the information used to tell the implementation what sort of
1508   // precision is requested. The meaning of the field is backend
1509   // specific. At the moment, it is only supported for kConvolution and kDot.
1510   // Transformations on one kDot or kConvolution to another will preserve this
1511   // information. Transformations to other HLOs will not preserve this
1512   // information but it is presumed that the alternate lowering is strictly
1513   // superior.
1514   // Precondition: opcode must be kConvolution or kDot.
1515   const PrecisionConfig& precision_config() const;
1516   PrecisionConfig* mutable_precision_config();
1517 
1518   // Sets the debug metadata for this instruction.
1519   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
1520   const OpMetadata& metadata() const { return metadata_; }
1521 
1522   // Set/get the computation containing this instruction. set_parent should only
1523   // be called by HloComputation methods which add/remove instructions to
1524   // computations.
1525   void set_parent(HloComputation* computation) { parent_ = computation; }
1526   const HloComputation* parent() const { return parent_; }
1527   HloComputation* parent() { return parent_; }
1528 
1529   // Returns the module for this instruction.
1530   HloModule* GetModule() const;
1531 
1532   // Returns whether we could assign input and output layouts to this
1533   // instruction to make it a bitcast.
1534   bool CouldBeBitcast() const;
1535 
1536   // Get/Set the number of partitions per outer dimension (in order, starting
1537   // with outer-most dimension first). Currently used by the parallel cpu
1538   // backend to partition HLOs into parallel tasks.
1539   //
1540   // TODO(b/62783254) Replace these methods with a more general way to
1541   // annotate HLOs with backend-specific information.
1542   const std::vector<int64>& outer_dimension_partitions() const {
1543     return outer_dimension_partitions_;
1544   }
1545   void set_outer_dimension_partitions(
1546       const std::vector<int64>& outer_dimension_partitions);
1547 
1548   // Old methods kept for smooth subclassing transition BEGIN.
1549   // TODO(b/80131774): Remove this code.
1550 
1551   // Delegates to HloBatchNormInstruction::feature_index.
1552   int64 feature_index() const;
1553 
1554   // Delegates to HloBatchNormInstruction::epsilon.
1555   float epsilon() const;
1556 
1557   // Delegates to HloFftInstruction::fft_type.
1558   FftType fft_type() const;
1559 
1560   // Delegates to HloFftInstruction::fft_length.
1561   const std::vector<int64>& fft_length() const;
1562 
1563   // Delegates to HloChannelInstruction::channel_id.
1564   absl::optional<int64> channel_id() const;
1565   void set_channel_id(const absl::optional<int64>& channel_id);
1566 
1567   // Returns the dimension sizes or numbers associated with this instruction.
1568   virtual const std::vector<int64>& dimensions() const {
1569     LOG(FATAL) << "Unimplemented method.";
1570   }
1571   virtual int64 dimensions(int64 index) const {
1572     LOG(FATAL) << "Unimplemented method.";
1573   }
1574 
1575   // Delegates to HloConcatenateInstruction::concatenate_dimension.
1576   int64 concatenate_dimension() const;
1577 
1578   // Delegates to HloGetDimensionSizeInstruction::dimension.
1579   int64 dimension() const;
1580 
1581   // Delegates to HloReshapeInstruction::inferred_dimension.
1582   int64 inferred_dimension() const;
1583 
1584   // Returns whether this instruction does a rank-2 transposition.
1585   bool IsRank2Transpose() const;
1586 
1587   // Delegates to HloSliceInstruction::slice_start.
1588   int64 slice_starts(int64 dimension) const;
1589   const std::vector<int64>& slice_starts() const;
1590 
1591   // Delegates to HloSliceInstruction::slice_limits.
1592   int64 slice_limits(int64 dimension) const;
1593   const std::vector<int64>& slice_limits() const;
1594 
1595   // Delegates to HloSliceInstruction::slice_strides.
1596   int64 slice_strides(int64 dimension) const;
1597   const std::vector<int64>& slice_strides() const;
1598 
1599   // Returns the literal associated with this instruction.
1600   const Literal& literal() const;
1601 
1602   // Returns whether the instruction is a constant.
1603   bool IsConstant() const;
1604 
1605   // Delegate to HloConstantInstruction::RelayoutConstant.
1606   void RelayoutConstant(const Layout& new_layout,
1607                         const ShapeIndex& shape_index = {});
1608 
1609   // Delegates to HloTraceInstruction::TracingTag.
1610   string TracingTag() const;
1611 
1612   // Delegates to HloFusionInstruction::AddFusionOperand.
1613   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1614 
1615   // Delegates to HloFusionInstruction::MergeFusionInstruction.
1616   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
1617 
1618   // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
1619   void MergeFusionInstructionIntoMultiOutput(
1620       HloInstruction* instruction_to_merge);
1621 
1622   // Delegates to HloFusionInstruction::FuseInstruction.
1623   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
1624 
1625   // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
1626   HloInstruction* FuseInstructionIntoMultiOutput(
1627       HloInstruction* instruction_to_fuse);
1628 
1629   // Delegates to HloFusionInstruction::fused_instruction.
1630   HloComputation* fused_instructions_computation() const;
1631 
1632   // Delegates to HloFusionInstruction::fused_expression_root.
1633   HloInstruction* fused_expression_root() const;
1634 
1635   // Delegates to HloFusionInstruction::fused_instructions.
1636   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1637       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1638   fused_instructions() const;
1639 
1640   const tensorflow::gtl::iterator_range<
1641       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1642   fused_instructions();
1643 
1644   // Delegates to HloFusionInstruction::fused_instruction_count.
1645   int64 fused_instruction_count() const;
1646 
1647   // Delegates to HloFusionInstruction::fused_parameter.
1648   HloInstruction* fused_parameter(int64 parameter_number) const;
1649 
1650   // Delegates to HloFusionInstruction::fused_parameters.
1651   const std::vector<HloInstruction*>& fused_parameters() const;
1652 
1653   // Returns true if this instruction is a fusion instruction that generates
1654   // multiple outputs.
1655   bool IsMultiOutputFusion() const;
1656 
1657   // Delegates to HloFusionInstruction::fusion_kind.
1658   FusionKind fusion_kind() const;
1659 
1660   // Delegates to HloFusionInstruction::set_fusion_kind.
1661   void set_fusion_kind(FusionKind kind);
1662 
1663   // Delegates to HloRngInstruction::random_distribution.
1664   RandomDistribution random_distribution() const;
1665 
1666   // Delegates to HloParameterInstruction::parameter_number.
1667   int64 parameter_number() const;
1668 
1669   // Delegates to
1670   // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
1671   void set_parameter_replicated_at_leaf_buffers(
1672       absl::Span<const bool> parameter_replicated_at_leaf_buffers);
1673   void set_parameter_replicated_at_leaf_buffers(
1674       const std::vector<bool>& parameter_replicated_at_leaf_buffers);
1675 
1676   // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
1677   const absl::optional<std::vector<bool>>&
1678   parameter_replicated_at_leaf_buffers() const;
1679 
1680   // Delegates to HloGetTupleElementInstruction::tuple_index.
1681   int64 tuple_index() const;
1682 
1683   // Delegates to HloGetTupleElementInstruction::set_tuple_index.
1684   void set_tuple_index(int64 new_tuple_index);
1685 
1686   // Delegates to HloReducePrecisionInstruction::exponent_bits.
1687   int32 exponent_bits() const;
1688 
1689   // Delegates to HloReducePrecisionInstruction::mantissa_bits.
1690   int32 mantissa_bits() const;
1691 
1692   // Delegates to HloInfeedInstruction::infeed_config.
1693   string infeed_config() const;
1694 
1695   // Delegates to HloInfeedInstruction::set_infeed_config.
1696   void set_infeed_config(const string& config);
1697 
1698   // Returns the config for the Outfeed instruction.
1699   const string& outfeed_config() const;
1700 
1701   // Returns the shape for the Outfeed instruction.
1702   const Shape& outfeed_shape() const;
1703 
1704   // Delegates to HloCollectiveInstruction::replica_groups.
1705   const std::vector<ReplicaGroup>& replica_groups() const;
1706 
1707   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
1708   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
1709 
1710   // Returns data on the window in a windowed operation such as
1711   // convolution.
1712   virtual const Window& window() const {
1713     LOG(FATAL) << "Unimplemented method.";
1714   }
1715 
1716   // Sets the window data in a windowed operation such as convolution.
1717   virtual void set_window(const Window& window) {
1718     LOG(FATAL) << "Unimplemented method.";
1719   }
1720 
1721   // Returns the unique_indices field.
1722   virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }
1723 
1724   // Returns data on the dimension numbers used for a convolution operation,
1725   // which may be a kConvolution instruction or a kCustomCall that implements a
1726   // convolution.
1727   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
1728 
1729   // Sets the convolution dimension numbers on this instruction.  In general you
1730   // shouldn't need to call this; instead, specify the convolution dimension
1731   // numbers when you create the instruction.
1732   void set_convolution_dimension_numbers(
1733       const ConvolutionDimensionNumbers& dnums);
1734 
1735   // The number of feature groups. Must be a divisor of the input feature
1736   // dimension and output feature dimension.
1737   int64 feature_group_count() const;
1738 
1739   void set_feature_group_count(int64 feature_group_count);
1740 
1741   // The number of batch groups. Must be a divisor of the input batch dimension.
1742   int64 batch_group_count() const;
1743 
1744   void set_batch_group_count(int64 batch_group_count);
1745 
1746   // Delegates to HloSelectAndScatterInstruction::select.
1747   HloComputation* select() const;
1748 
1749   // Delegates to HloSelectAndScatterInstruction::scatter.
1750   HloComputation* scatter() const;
1751 
1752   // Delegates to HloSelectAndScatterInstruction::set_select.
1753   void set_select(HloComputation* computation);
1754 
1755   // Delegates to HloSelectAndScatterInstruction::set_scatter.
1756   void set_scatter(HloComputation* computation);
1757 
1758   // Delegates to HloCustomCallInstruction::custom_call_target.
1759   const string& custom_call_target() const;
1760 
1761   // Delegates to HloPadInstruction::padding_config.
1762   const PaddingConfig& padding_config() const;
1763 
1764   // Delegates to HloDynamicSliceInstruction::slice_sizes.
1765   int64 slice_sizes(int64 dimension) const;
1766 
1767   // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
1768   const std::vector<int64>& dynamic_slice_sizes() const;
1769 
1770   // Delegates to HloGatherInstruction::gather_dimension_numbers.
1771   const GatherDimensionNumbers& gather_dimension_numbers() const;
1772   // Delegates to HloGatherInstruction::gather_slice_sizes.
1773   absl::Span<const int64> gather_slice_sizes() const;
1774 
1775   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
1776   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
1777 
1778   // Delegates to HloDotInstruction::dot_dimension_numbers().
1779   const DotDimensionNumbers& dot_dimension_numbers() const;
1780 
1781   // Delegates to HloDomainInstruction::operand_side_metadata().
1782   const DomainMetadata& operand_side_metadata() const;
1783 
1784   // Delegates to HloDomainInstruction::user_side_metadata().
1785   const DomainMetadata& user_side_metadata() const;
1786 
1787   // Delegates to HloCompareInstruction::direction().
1788   ComparisonDirection comparison_direction() const;
1789 
1790   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
1791   const TriangularSolveOptions& triangular_solve_options() const;
1792 
1793   // Delegates to HloCholeskyInstruction::cholesky_options().
1794   const CholeskyOptions& cholesky_options() const;
1795 
1796   // Appends operand to the list of operands and adds this instruction as a user
1797   // of the operand.
1798   void AppendOperand(HloInstruction* operand);
1799 
1800   // Old methods kept for smooth subclassing transition END.
1801 
1802  protected:
1803   // Indicates how an instruction uses a value (such as an operand).
1804   //
1805   // Does it (a) not use it, (b) use it, or (c) use it multiple times?
1806   //
1807   // In the kUse case (i.e. (b)) we may either (i) use the value elementwise, or
1808   // (ii) use it after having permuted it somehow, e.g. through a reshape.  If
1809   // the use is a permuting use, we set permutation_instr to the instruction
1810   // that did the permuting.
1811   struct UseKind {
1812     enum Kind { kReuse, kUse, kNoUse };
1813 
1814     // Creates a UseKind that represents a use that permutes an instruction's
1815     // elements according to the given instruction.
1816     static UseKind Permuting(const HloInstruction* permutation_instr) {
1817       UseKind k(kUse);
1818       k.permutation_instr = permutation_instr;
1819       return k;
1820     }
1821 
1822     UseKind(Kind kind)  // NOLINT intentionally nonexplicit
1823         : kind(kind), permutation_instr(nullptr) {}
1824 
1825     friend bool operator==(UseKind a, Kind b) { return a.kind == b; }
1826     friend bool operator==(Kind a, UseKind b) { return b == a; }
1827 
1828     Kind kind;
1829     const HloInstruction* permutation_instr;
1830   };
1831 
1832   // Helper class for computing OperandElementUse for kFusion.
1833   class FusionReusesParamElements;
1834 
1835   // Internal constructor for a given opcode/shape, other fields must be filled
1836   // by factory methods.
1837   HloInstruction(HloOpcode opcode, const Shape& shape);
1838 
1839   void RemoveOperandAt(int index) {
1840     operands_.erase(operands_.begin() + index);
1841   }
1842 
1843   // Removes a list of operands with the given indices in ascending order.
1844   void RemoveOperandsAtAscendingIndices(
1845       absl::Span<const int> ascending_indices);
1846 
1847   void AppendComputation(HloComputation* computation) {
1848     called_computations_.push_back(computation);
1849   }
1850 
1851   void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
1852 
1853   void set_called_computation(int index, HloComputation* computation) {
1854     called_computations_[index] = computation;
1855   }
1856   // Indices of computations in called_computations_ for instructions which call
1857   // multiple computations.
1858   enum {
1859     // kWhile computations.
1860     kBodyComputationIndex = 0,
1861     kConditionComputationIndex = 1,
1862 
1863     // kSelectAndScatter computations.
1864     kSelectComputationIndex = 0,
1865     kScatterComputationIndex = 1,
1866 
1867     // kConditional computations.
1868     kTrueComputationIndex = 0,
1869     kFalseComputationIndex = 1,
1870   };
1871 
1872  private:
1873   // Implementation for non-common logic of CloneWithNewOperands.
1874   virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1875       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1876       HloCloneContext* context) const {
1877     // TODO(b/80131774): This should be pure virtual.
1878     LOG(FATAL) << "Unimplemented method.";
1879   }
1880 
1881   // Implementation for non-common logic of ExtraAttributesToString.
1882   virtual std::vector<string> ExtraAttributesToStringImpl(
1883       const HloPrintOptions& options) const {
1884     return {};
1885   }
1886 
1887   // Implementation for IsElementwise if operand_idx is nullopt and for
1888   // IsElementwiseOnOperand otherwise.
1889   //
1890   // NOTE: For all instructions other than kFusion, being elementwise on one of
1891   // the operands is equivalent to being elementwise on all the operands.
1892   virtual bool IsElementwiseImpl(
1893       const absl::optional<int64>& operand_idx) const;
1894 
1895   // Prints an operand to a string.
1896   virtual string OperandsToStringWithCanonicalNameMap(
1897       const HloPrintOptions& options,
1898       CanonicalNameMap* canonical_name_map) const;
1899 
1900   // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and
1901   // OperandsToStringWithCanonicalNameMap() functions.
1902   friend class HloComputation;
1903 
1904   // See comments on Identical().
1905   virtual bool IdenticalSlowPath(
1906       const HloInstruction& other,
1907       const std::function<bool(const HloComputation*, const HloComputation*)>&
1908           eq_computations) const;
1909 
1910   // Generates a hash value specific to a particular type of an instruction.
1911   // This function typically considers the inner root instruction.
1912   virtual uint64 InnerHash() const;
1913 
1914   // Creates an n-ary elementwise operation.
1915   static std::unique_ptr<HloInstruction> CreateNary(
1916       const Shape& shape, HloOpcode opcode,
1917       absl::Span<HloInstruction* const> operands);
1918 
1919   // Adds a user for this instruction.
1920   void AddUser(HloInstruction* user);
1921 
1922   // Removes a user for this instruction.
1923   void RemoveUser(HloInstruction* user);
1924 
1925   // Returns how this instruction uses elements of its operand at operand_num.
1926   UseKind OperandElementUse(int64 operand_num) const;
1927 
1928   // Helper for implementing backend_config().  Parses backend_config_ into the
1929   // given proto.
1930   Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
1931 
1932   int unique_id_;  // Unique to this HloInstruction within a HloModule
1933 
1934   // Opcode for this instruction.
1935   HloOpcode opcode_;
1936 
1937   // Instruction operands.
1938   InstructionVector operands_;
1939 
1940   // The set of control predecessors of this instruction.
1941   // Note that the order of the instructions in the vector influences the order
1942   // computed in HloComputation::ComputeInstructionPostOrder, which may
1943   // influence the result of the compilation by changing the scheduling. We are
1944   // not sure if it matters.
1945   std::vector<HloInstruction*> control_predecessors_;
1946 
1947   // The users of this instruction. Users are HLOs where this instruction is an
1948   // operand. The vector users_ and the map user_map_ contain identical members.
1949   // The map enables fast membership testing and the vector enables fast, stable
1950   // iteration. The value in the map contains the index of the instruction in
1951   // the vector, which enables fast removal.
1952   std::vector<HloInstruction*> users_;
1953   absl::flat_hash_map<const HloInstruction*, int64> user_map_;
1954 
1955   // The set of control successors of this instruction.
1956   std::vector<HloInstruction*> control_successors_;
1957 
1958   // The computation in which this instruction is contained.
1959   HloComputation* parent_ = nullptr;
1960 
1961   // Result shape of this instruction.
1962   Shape shape_;
1963 
1964   // The sharding, if one exists.
1965   // Uses std::shared_ptr to allow reuse of the same sharding object between
1966   // HloInstructions and other components, as HloSharding can be very large for
1967   // many-element tuples.
1968   std::shared_ptr<const HloSharding> sharding_;
1969 
1970   // Computations called by this instruction.
1971   std::vector<HloComputation*> called_computations_;
1972 
1973   // A trace instruction that consumes this instruction.
1974   //
1975   // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
1976   // an operand.
1977   HloInstruction* trace_instruction_ = nullptr;
1978 
1979   // The backend-specific configuration for how a backend should compile this
1980   // HLO. See the documentation on backend_config().
1981   string backend_config_;
1982 
1983   // Attributes passed from the frontend to give hints to the backend about
1984   // how to compile this HLO.
1985   // HLO -> HLO transforms are expected to preserve these attributes on a
1986   // "best effort" basis only.
1987   // For example:
1988   //    x = const(10), frontend_attributes={x}
1989   //    y = const(10), frontend_attributes={y}
1990   //    z = add(x,y), frontend_attributes={y}
1991   // Could be simplified to:
1992   //    z' = const(20), frontend_attributes={?}
1993   FrontendAttributes frontend_attributes_;
1994 
1995   // This field is assigned to true when backend_config_ is assigned to
1996   // a default configuration.
1997   bool is_default_config_ = false;
1998 
1999   // String identifier for instruction.
2000   string name_;
2001 
2002   // Metadata for debugging.
2003   OpMetadata metadata_;
2004 
2005   // The number of partitions per outer dimension (listed in order from
2006   // outer-most dimension first).
2007   std::vector<int64> outer_dimension_partitions_;
2008 
2009   TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
2010 };
2011 
2012 // Explicit instantiations in hlo_instruction.cc.
2013 extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
2014 extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2015 extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
2016 extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2017 
2018 string ToString(HloInstruction::FusionKind kind);
2019 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
2020     const string& kind_name);
2021 
2022 // Custom (de)stringification functions for protos that live inside
2023 // HloInstruction.
2024 string PaddingConfigToString(const PaddingConfig& padding);
2025 string FrontendAttributesToString(
2026     const FrontendAttributes& frontend_attributes);
2027 string OpMetadataToString(const OpMetadata& metadata);
2028 string RandomDistributionToString(const RandomDistribution& distribution);
2029 string PrecisionToString(const PrecisionConfig::Precision& precision);
2030 string ConvolutionDimensionNumbersToString(
2031     const ConvolutionDimensionNumbers& dnums);
2032 string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups);
2033 
2034 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
2035 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
2036 
2037 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
2038 
2039 // Map classes that guarantee a deterministic iteration order when the key is
2040 // an HloInstruction* or a const HloInstruction*.
2041 // To make the iteration order over the map deterministic, the comparator
2042 // should not be using the pointer values, but rather an intrinsic property of
2043 // the hlo. Exception: null pointer values compare less than non-null.
2044 struct HloPtrComparator {
2045   bool operator()(const HloInstruction* const& lhs,
2046                   const HloInstruction* const& rhs) const;
2047 };
2048 
2049 template <typename ValueT>
2050 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
2051 
2052 template <typename ValueT>
2053 using ConstHloInstructionMap =
2054     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
2055 
2056 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
2057 using ConstHloInstructionSet =
2058     std::set<const HloInstruction*, HloPtrComparator>;
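//
// For instance (sketch; `a` and `b` are assumed HloInstruction pointers):
//
//   HloInstructionMap<int64> order;
//   order[a] = 0;
//   order[b] = 1;
//   for (const auto& kv : order) { ... }  // iteration order is deterministic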
2059 
2060 }  // namespace xla
2061 
2062 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
2063