/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the user
// has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

string PrintName(const string& name, bool print_ids);

// A bunch of switches that control how the HLO text should be printed.
class HloPrintOptions {
 public:
  enum class PrintSubcomputationMode {
    kOff,                  // Do not print anything about subcomputations.
    kNameOnly,             // Only print the name of subcomputations.
    kFullBodies,           // Print the full bodies of subcomputations.
    kNonSequentialBodies,  // Print the full bodies of subcomputations that are
                           // not in a sequential context.
  };

  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_only_essential_constants_(false),
        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
        print_metadata_(true),
        print_backend_config_(true),
        print_infeed_outfeed_config_(true),
        compact_operands_(false),
        include_layout_in_shapes_(true),
        print_result_shape_(true),
        print_operand_shape_(true),
        print_operand_names_(true),
        print_operand_index_annotation_interval_(5),
        print_program_shape_(true),
        print_percent_(true),
        print_control_dependencies_(true),
        canonicalize_instruction_names_(false),
        indent_amount_(0),
        is_in_nested_computation_(false),
        print_ids_(true),
        canonicalize_computations_(false),
        print_extra_attributes_(true) {}

  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_operand_index_annotation_interval(0)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(false)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_operand_index_annotation_interval(0)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // Options to produce a fingerprint of an HLO.
  static HloPrintOptions Fingerprint() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(
            PrintSubcomputationMode::kNonSequentialBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_infeed_outfeed_config(false)
        .set_print_only_essential_constants(true)
        .set_compact_operands(true)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_operand_index_annotation_interval(0)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true)
        .set_print_ids(false)
        .set_canonicalize_computations(true);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  // If true, only integer, all-zero, or all-one constants will be printed out.
  HloPrintOptions& set_print_only_essential_constants(bool value) {
    print_only_essential_constants_ = value;
    return *this;
  }

  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, infeed_config and outfeed_config will be printed.
  HloPrintOptions& set_print_infeed_outfeed_config(bool value) {
    print_infeed_outfeed_config_ = value;
    return *this;
  }

  // If true, result shapes will be printed.
  HloPrintOptions& set_print_result_shape(bool value) {
    print_result_shape_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // Sets the interval at which operands are annotated with their /*index=*/
  // position; 0 means never print the annotation.
  HloPrintOptions& set_print_operand_index_annotation_interval(int64_t value) {
    print_operand_index_annotation_interval_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, all printed names include unique identifiers.
  HloPrintOptions& set_print_ids(bool value) {
    print_ids_ = value;
    return *this;
  }

  // If true, the HLO includes its attributes.
  HloPrintOptions& set_print_extra_attributes(bool value) {
    print_extra_attributes_ = value;
    return *this;
  }

  // If true, the program shape of HLO computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, only a part of the operands will be printed out (note that in
  // this case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, include the layout in any shapes that are printed (instruction
  // and operands).
  HloPrintOptions& set_include_layout_in_shapes(bool value) {
    include_layout_in_shapes_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' names. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // If true, canonicalizes computations, sorting by computations' names.
  HloPrintOptions& set_canonicalize_computations(bool value) {
    canonicalize_computations_ = value;
    return *this;
  }

  // The indent of the HLO text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  // Instructions are selected for printing by a predicate function
  // (`set_print_instruction`). We also print their surrounding instructions
  // for ease of reading. We print `leading_and_trailing_instructions_number`
  // instructions before and after the qualified ones inside a computation.
  HloPrintOptions& set_leading_and_trailing_instructions_number(int value) {
    leading_and_trailing_instructions_number_ = value;
    return *this;
  }

  // A callback which takes an HloInstruction*, its string representation,
  // the indentation level of the resulting block, and a bool indicating
  // whether the instruction is the root or not. The return value is a string
  // which is used for this instruction during printing.
  using FormatInstructionFunc =
      std::function<string(const HloInstruction*, const string&, int, bool)>;

  HloPrintOptions& set_format_instruction(FormatInstructionFunc callback) {
    format_instruction_ = callback;
    return *this;
  }

  using HloInstructionPredicate = std::function<bool(const HloInstruction*)>;

  // A callback which takes an HloInstruction* and returns whether it should be
  // printed or not.
  HloPrintOptions& set_print_instruction(HloInstructionPredicate callback) {
    print_instruction_ = callback;
    return *this;
  }

  using HloComputationPredicate = std::function<bool(const HloComputation*)>;

  // A callback which takes an HloComputation* and returns whether it should be
  // printed or not.
  HloPrintOptions& set_print_computation(HloComputationPredicate callback) {
    print_computation_ = callback;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  bool print_only_essential_constants() const {
    return print_only_essential_constants_;
  }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool print_infeed_outfeed_config() const {
    return print_infeed_outfeed_config_;
  }
  bool compact_operands() const { return compact_operands_; }
  bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
  bool print_result_shape() const { return print_result_shape_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  int64 print_operand_index_annotation_interval() const {
    return print_operand_index_annotation_interval_;
  }
  bool print_ids() const { return print_ids_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool print_extra_attributes() const { return print_extra_attributes_; }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  bool canonicalize_computations() const { return canonicalize_computations_; }
  int indent_amount() const { return indent_amount_; }
  bool is_in_nested_computation() const { return is_in_nested_computation_; }
  int leading_and_trailing_instructions_number() const {
    return leading_and_trailing_instructions_number_;
  }
  string format_instruction(const HloInstruction* instr,
                            const string& instr_name, int indent,
                            bool is_root) const {
    return format_instruction_(instr, instr_name, indent, is_root);
  }
  bool print_instruction(const HloInstruction* instr) const {
    return print_instruction_(instr);
  }
  bool print_computation(const HloComputation* comp) const {
    return print_computation_(comp);
  }

 private:
  bool print_large_constants_;
  bool print_only_essential_constants_;
  PrintSubcomputationMode print_subcomputation_mode_;
  bool print_metadata_;
  bool print_backend_config_;
  bool print_infeed_outfeed_config_;
  bool compact_operands_;
  bool include_layout_in_shapes_;
  bool print_result_shape_;
  bool print_operand_shape_;
  bool print_operand_names_;
  // The interval between the /*index=*/ annotated operands. 0 means never
  // print the annotation, 1 means print the annotation for every operand.
  int64 print_operand_index_annotation_interval_;
  bool print_program_shape_;
  bool print_percent_;
  bool print_control_dependencies_;
  bool canonicalize_instruction_names_;
  int indent_amount_;
  bool is_in_nested_computation_;
  bool print_ids_;
  bool canonicalize_computations_;
  bool print_extra_attributes_;
  int leading_and_trailing_instructions_number_ = 3;
  FormatInstructionFunc format_instruction_ = [](const HloInstruction* instr,
                                                 const string& instr_name,
                                                 int indent, bool is_root) {
    return absl::StrCat(string(2 * indent, ' '), is_root ? "ROOT " : "",
                        instr_name);
  };
  HloInstructionPredicate print_instruction_ = [](const HloInstruction* instr) {
    return true;
  };
  HloComputationPredicate print_computation_ = [](const HloComputation* comp) {
    return true;
  };
};
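
// Illustrative usage sketch (not part of the API surface; names hypothetical):
// since every setter returns *this, options are typically built by chaining,
// starting from either the default constructor or one of the presets above.
//
//   HloPrintOptions options = HloPrintOptions::Canonical()
//                                 .set_print_metadata(true)
//                                 .set_indent_amount(2);
//   // Assuming some HloInstruction* `instr` is available:
//   // string text = instr->ToString(options);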

// For canonical string output, we need to have a canonical way to rename
// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
// where <xxx> is an index starting from 0.
class CanonicalNameMap {
 public:
  CanonicalNameMap() : index(0) {}

  string LookupOrInsert(const string& old_name) {
    auto iter = canonical_name_map.find(old_name);
    if (iter != canonical_name_map.end()) {
      return iter->second;
    }

    string new_name = absl::StrCat("tmp_", index++);
    canonical_name_map[old_name] = new_name;
    return new_name;
  }
  void Clear() {
    canonical_name_map.clear();
    index = 0;
  }

 private:
  int64 index;
  absl::flat_hash_map<string, string> canonical_name_map;
};
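
// Illustrative sketch of the renaming behavior (instruction names hypothetical):
//
//   CanonicalNameMap map;
//   map.LookupOrInsert("param.1");  // returns "tmp_0"
//   map.LookupOrInsert("add.7");    // returns "tmp_1"
//   map.LookupOrInsert("param.1");  // returns "tmp_0" again (already mapped)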

// HLO instructions are the atomic unit of the high-level compiler's IR.
//
// HloInstructions live inside of an HloComputation, which is analogous to a
// function in other programming languages.  Nodes have no total order within
// their computation.  Instead, they have a partial ordering determined by their
// data and control dependencies.
//
// HLO does not have basic blocks or explicit "branch" instructions.  Instead,
// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
// control flow.  For example, the kConditional HLO executes one of two possible
// computations, depending on the runtime value of a predicate.
//
// HLO is pure (mostly).  It has no concept of mutable state.  Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
 public:
  // A fusion node computes the same value a call to its fusion computation
  // would compute.  However, the choice of fusion kind dictates codegen
  // strategy for the backend.
  //
  // To generate code for a kFusion HloInstruction, most backends do something
  // like the following:
  //
  // 1) Identify the "primary" HloInstruction of the fused computation.
  // 2) Emit code that does the work of the primary node, creating its inputs
  //    and transforming its outputs as specified by the fused computation.
  //
  // In step (2), the code emitted is usually similar to the code that would be
  // emitted for an *unfused* version of the primary node, except that
  //
  //  - when the primary node reads an element of one of its operands, instead
  //    of loading the value from memory, it *computes* the value based on the
  //    contents of the fused computation.
  //  - when the primary node outputs a value, instead of storing it to memory,
  //    it forwards the value to its users, which then perform additional
  //    computations before the value is finally stored to memory at the root of
  //    the fusion node.
  //
  // An HloInstruction's FusionKind helps us find the kFusion instruction's
  // primary node, and can also affect how we generate code in step (2).
  //
  //  - kInput: The primary node is the root of the fused instruction.
  //
  //  - kOutput: The primary node is not the root of the fused instruction.
  //    This fusion kind requires that one operand buffer of the fusion
  //    instruction be able to alias the output buffer.  This constraint is
  //    usually enough to let backends find the primary node unambiguously.
  //
  //  - kLoop: The primary node is the root of the fused computation, but,
  //    unlike in input fusion, we prescribe a specific implementation for
  //    codegen.  Rather than generating code that looks like the code we'd emit
  //    for an unfused version of the primary/root node, we emit code that
  //    generates one element of the root at a time.
  //
  //  - kCustom: Custom category for backend-specific fusions that don't fit
  //    into the above patterns.
  //
  // Not all backends support all fusion kinds, and given a particular fused
  // computation, it's not in general safe to change its fusion kind.  Creation
  // of fusion nodes is always backend-specific.
  //
  // For elementwise ops (e.g. kAdd), most backends would emit a
  // one-element-at-a-time implementation for the unfused version, so loop
  // fusion and input fusion are probably equivalent if the root node is
  // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
  // implementation might emit something more sophisticated for an unfused or
  // input-fusion reduce, but will emit the naive code that reduces one element
  // at a time for loop fusion with a reduce as the root.
  //
  // Another way to think of loop fusion is that it's equivalent to input
  // fusion, but where the root node is an implicit identity node, whose
  // unfused implementation is "read one element, write one element".
  //
  // TODO(b/79869434): This categorization scheme is not great.  For one thing,
  // input and loop fusion are basically the same thing: There is no reason for
  // the HLO to encode backend-specific decisions about how e.g. a reduce that's
  // the root of a fusion should be lowered.  In addition, this scheme as
  // written doesn't work for multi-output fusion, where the primary node is
  // never actually the root (which is a kTuple instruction that gathers the
  // multiple outputs of the fusion).
  enum class FusionKind {
    kLoop,
    kInput,
    kOutput,
    kCustom,
  };
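
  // Illustrative sketch (hypothetical HLO-text form, simplified) of a kLoop
  // fusion whose fused computation adds two arrays elementwise; here the
  // fused root is the primary node:
  //
  //   fused_computation {
  //     p0 = f32[8] parameter(0)
  //     p1 = f32[8] parameter(1)
  //     ROOT add = f32[8] add(p0, p1)
  //   }
  //   ...
  //   fusion = f32[8] fusion(x, y), kind=kLoop, calls=fused_computation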

  virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }

  // Detaches an instruction from its operands and users. That is, remove the
  // instruction from each operand's user set and each user's operand set.
  void DetachFromOperandsAndUsers();

  // Creates an instruction from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction id to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map = {},
      bool prohibit_empty_literal = true);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(
      int64_t parameter_number, const Shape& shape, const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

  // Creates an Iota instruction.
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
                                                    int64_t iota_dimension);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64_t index);

  // Creates a trace instruction that logs the input operand in the computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  //
  // The parameters to the instruction are interpreted as follows:
  //
  //  - If `distribution` is RNG_UNIFORM, generates a number in range
  //    [param0, param1).
  //
  //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
  //    with mean `param0` and standard deviation `param1`.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      absl::Span<HloInstruction* const> parameters);
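
  // Illustrative sketch (operand names hypothetical): a scalar F32 uniform
  // draw in [0, 1) could be created roughly as
  //
  //   auto rng = HloInstruction::CreateRng(
  //       ShapeUtil::MakeShape(F32, {}), RNG_UNIFORM, {zero, one});
  //
  // where `zero` and `one` are constant HloInstruction*s holding 0.0f and 1.0f.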

  // Creates a stateless random bit generator instruction that fills a shape
  // with random bits.
  static std::unique_ptr<HloInstruction> CreateRngBitGenerator(
      const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm);

  // Creates an instruction to update the random number generator state to
  // reflect the new state after `delta` units of 32 random bits are generated
  // and returns the old state.
  static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState(
      const Shape& shape, int64_t delta);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index).
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* map_computation);

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64_t feature_group_count, int64_t batch_group_count,
      const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      absl::Span<const int64> fft_length);

  // Creates a copy-start op, indicating whether this is a cross-program
  // prefetch or not.
  static std::unique_ptr<HloInstruction> CreateCopyStart(
      const Shape& shape, HloInstruction* operand,
      bool is_cross_program_prefetch = false);

  // Creates a compare op, performing the comparison specified in direction.
  static std::unique_ptr<HloInstruction> CreateCompare(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      Comparison::Direction direction,
      absl::optional<Comparison::Type> type = absl::nullopt);

  static std::unique_ptr<HloInstruction> CreateTriangularSolve(
      const Shape& shape, HloInstruction* a, HloInstruction* b,
      const TriangularSolveOptions& options);

  static std::unique_ptr<HloInstruction> CreateCholesky(
      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);

  // Creates a dot op with operands 'lhs' and 'rhs', with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates an all-gather op, which concats the operands of all participants
  // along all_gather_dimension. The replica_groups, channel_id, and
  // use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants.
  static std::unique_ptr<HloInstruction> CreateAllGather(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Creates an all-gather-start op, which concats the operands of all
  // participants along all_gather_dimension. The replica_groups, channel_id,
  // and use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants. Needs to be used in
  // conjunction with an AllGatherDone op that synchronizes and returns the
  // result.
  static std::unique_ptr<HloInstruction> CreateAllGatherStart(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Creates a cross replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, if we have 4 replicas, then replica_groups={{0,2},{1,3}} means
  // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied across modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Creates a reduce-scatter operation which reduces its inputs across the
  // given replica groups and then scatters the reduced data across the N
  // participants.
  static std::unique_ptr<HloInstruction> CreateReduceScatter(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids,
      int64_t scatter_dimension);

  // Creates an asynchronous cross replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, if we have 4 replicas, then replica_groups={{0,2},{1,3}} means
  // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied across modules.
  static std::unique_ptr<HloInstruction> CreateAllReduceStart(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // An all-to-all op takes N array operands of the same shape and scatters them
  // to N replicas.  Each replica gathers the results into a tuple.
  //
  // For example, suppose we have 3 replicas, with replica i passing inputs
  // [a_i, b_i, c_i] to its all-to-all op.  Then the resulting tuples are
  //
  //   replica 0: (a_0, a_1, a_2)
  //   replica 1: (b_0, b_1, b_2)
  //   replica 2: (c_0, c_1, c_2).
  //
  // If replica_groups is set, the op is sharded and the replicas are permuted.
  // To explain by way of example, suppose we have replica_groups={{1,2},{3,0}}.
  // Then each replica passes two operands, say [a_i, b_i], and the result is
  //
  //   replica 0: (b_3, b_0)
  //   replica 1: (a_1, a_2)
  //   replica 2: (b_1, b_2)
  //   replica 3: (a_3, a_0).
  //
  // All replica groups must have the same number of elements, and the number of
  // operands must be equal to the size of one replica group.  Each replica must
  // appear in exactly one group.
  //
  // Note that this instruction is different than the all-to-all op in
  // xla_builder.h.  The version in XlaBuilder takes one input and slices it,
  // and then concatenates the results into a single array.  This instruction
  // takes multiple inputs and returns a tuple; it doesn't slice or concatenate.
  // It is used to implement the higher-level instruction in XlaBuilder.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id,
      const absl::optional<int64>& split_dimension = absl::nullopt);

  // Creates a communication instruction that permutes data across replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
  // target_replica_id in any pair, the output on that replica is a tensor
  // consisting of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const absl::optional<int64_t>& channel_id);

  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* input, HloInstruction* output,
      HloInstruction* input_start_indices, HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const absl::optional<int64_t>& channel_id);

  // Creates a communication instruction that initiates the start of
  // CollectivePermute.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const absl::optional<int64_t>& channel_id);

  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* input, HloInstruction* output,
      HloInstruction* input_start_indices, HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const absl::optional<int64_t>& channel_id);

  // Creates an instruction that returns a U32 replica ID.
  static std::unique_ptr<HloInstruction> CreateReplicaId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates an instruction that returns a U32 partition ID.
  static std::unique_ptr<HloInstruction> CreatePartitionId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates a conversion instruction, where operand is the data to convert and
  // shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from the
  // Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed, *not* the shape of the infeed instruction, which
  // is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed, *not* the shape of the outfeed instruction,
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64_t channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id.  If
  // is_host_transfer is true, then this Recv operation transfers data from the
  // host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64_t channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> start_indices,
      absl::Span<const int64> limit_indices, absl::Span<const int64> strides);

  // Creates a dynamic slice instruction, where the first operand is sliced at
  // start indices specified in the subsequent operands, with sizes specified
  // in 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' at the given 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. For example, let f be
  // the function to apply, which takes 2 arguments, an accumulator and the
  // current value. Let init be an initial value (which is normally chosen to be
  // the identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  // f(... f(f(init, value0), value1) ..., valueN)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);
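
  // Illustrative sketch (shapes and operand names hypothetical): reducing an
  // f32[2,3] operand over dimension 1 with an add computation and init 0
  // yields an f32[2] result.
  //
  //   // Assumes `operand`, `zero`, and `add_computation` already exist:
  //   auto reduce = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {2}), operand, zero,
  //       /*dimensions_to_reduce=*/{1}, add_computation);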
911 
912   // A more general, multiple-argument version of the above.
913   // The function to apply, f, now takes N arguments:
914   // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
915   // init_valueN], and returns an N-tuple. The performed computation is (for
916   // commutative and associative f operators) equivalent to:
917   //
918   // f_1 = f(init0, ...  initN, input0.value0, ..., inputN.value0)
919   // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
920   // ..., inputN.value1)
921   // ...
922   static std::unique_ptr<HloInstruction> CreateReduce(
923       const Shape& shape, absl::Span<HloInstruction* const> operands,
924       absl::Span<HloInstruction* const> init_values,
925       absl::Span<const int64> dimensions_to_reduce,
926       HloComputation* reduce_computation);
927 
928   // Creates a reduce-window instruction, where the computation (given
929   // by the handle) is applied window-wise at each valid window
930   // position in the operand.
931   static std::unique_ptr<HloInstruction> CreateReduceWindow(
932       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
933       const Window& window, HloComputation* reduce_computation);
934 
935   // A more general, multiple-argument version of the above.
936   // The reduce_computation being applied,now takes N arguments:
937   // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
938   // valueN], and returns an N-tuple. The operands and init_values now each
939   // contain a span of N input arrays and n initial values.
940   static std::unique_ptr<HloInstruction> CreateReduceWindow(
941       const Shape& shape, absl::Span<HloInstruction* const> operands,
942       absl::Span<HloInstruction* const> init_values, const Window& window,
943       HloComputation* reduce_computation);
944 
945   // Creates a batch-norm-training instruction.
946   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
947       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
948       HloInstruction* offset, float epsilon, int64_t feature_index);
949 
950   // Creates a batch-norm-inference instruction.
951   static std::unique_ptr<HloInstruction> CreateBatchNormInference(
952       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
953       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
954       float epsilon, int64_t feature_index);
955 
956   // Creates a batch-norm-grad instruction.
957   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
958       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
959       HloInstruction* mean, HloInstruction* variance,
960       HloInstruction* grad_output, float epsilon, int64_t feature_index);
961 
962   // Creates a scatter computation that scatters the `source` array to the
963   // selected indices of each window.
964   static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
965       const Shape& shape, HloInstruction* operand, HloComputation* select,
966       const Window& window, HloInstruction* source, HloInstruction* init_value,
967       HloComputation* scatter);
968 
969   // Creates a broadcast instruction.
970   static std::unique_ptr<HloInstruction> CreateBroadcast(
971       const Shape& shape, HloInstruction* operand,
972       absl::Span<const int64> broadcast_dimensions);
973 
974   // Creates a sequence of instructions that performs an explicit broadcast of
975   // the operand to the target shape.
976   //
977   // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
978   // returned as a unique_ptr for API consistency with other factory methods in
979   // this interface.
980   //
981   // TODO(b/72173833) Ideally HloComputations would always be present, and so
982   // the adder being passed by the caller would not be necessary.
983   static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
984       const Shape& output_shape, HloInstruction* operand,
985       const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
986           adder);
987 
988   // Creates a pad instruction, where the operand is padded on the edges and
989   // between the elements with the given padding value.
990   static std::unique_ptr<HloInstruction> CreatePad(
991       const Shape& shape, HloInstruction* operand,
992       HloInstruction* padding_value, const PaddingConfig& padding_config);
993 
994   // Creates a reshape instruction, where the operand is flattened row-major
995   // order and then reshaped to the given result shape.
996   static std::unique_ptr<HloInstruction> CreateReshape(
997       const Shape& shape, HloInstruction* operand,
998       int64_t inferred_dimension = -1);
999 
1000   // Creates a dynamic reshape instruction. Similar to reshape but dynamic
1001   // dimensions sizes are provided as additional variadic arguments.
1002   //
1003   // Precondition: dim_sizes.size() == shape.rank()
1004   static std::unique_ptr<HloInstruction> CreateDynamicReshape(
1005       const Shape& shape, HloInstruction* data_operand,
1006       absl::Span<HloInstruction* const> dim_sizes);
1007 
1008   // Creates a transpose instruction which permutes the operand dimensions.
1009   static std::unique_ptr<HloInstruction> CreateTranspose(
1010       const Shape& shape, HloInstruction* operand,
1011       absl::Span<const int64> dimensions);
1012 
1013   // Creates a n-ary sort op with a 'compare' computation which is used for
1014   // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
1015   // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
1016   // specific index positions which should be compared, and should return a
1017   // PRED. 'is_stable' specifies whether stable sorting is required.
1018   static std::unique_ptr<HloInstruction> CreateSort(
1019       const Shape& shape, int64_t dimension,
1020       absl::Span<HloInstruction* const> operands, HloComputation* compare,
1021       bool is_stable);
1022 
1023   // Creates a while instruction, given a condition computation, a body
1024   // computation, and the initial value for the input of the computations. For
1025   // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
1026   // corresponds to the C code below.
1027   // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
1028   static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
1029                                                      HloComputation* condition,
1030                                                      HloComputation* body,
1031                                                      HloInstruction* init);
1032 
1033   static std::unique_ptr<HloInstruction> CreateConditional(
1034       const Shape& shape, HloInstruction* pred,
1035       HloInstruction* true_computation_arg, HloComputation* true_computation,
1036       HloInstruction* false_computation_arg, HloComputation* false_computation);
1037 
1038   static std::unique_ptr<HloInstruction> CreateConditional(
1039       const Shape& shape, HloInstruction* branch_index,
1040       absl::Span<HloComputation* const> branch_computations,
1041       absl::Span<HloInstruction* const> branch_computation_args);
1042 
1043   static std::unique_ptr<HloInstruction> CreateGather(
1044       const Shape& shape, HloInstruction* operand,
1045       HloInstruction* start_indices,
1046       const GatherDimensionNumbers& gather_dim_numbers,
1047       absl::Span<const int64> slice_sizes, bool indices_are_sorted);
1048 
1049   static std::unique_ptr<HloInstruction> CreateScatter(
1050       const Shape& shape, HloInstruction* operand,
1051       HloInstruction* scatter_indices, HloInstruction* updates,
1052       HloComputation* update_computation,
1053       const ScatterDimensionNumbers& scatter_dim_numbers,
1054       bool indices_are_sorted, bool unique_indices);
1055 
1056   // Creates a kDomain instruction which delimits an HLO domain which have
1057   // the provided user and operand side metadata.
1058   static std::unique_ptr<HloInstruction> CreateDomain(
1059       const Shape& shape, HloInstruction* operand,
1060       std::unique_ptr<DomainMetadata> operand_side_metadata,
1061       std::unique_ptr<DomainMetadata> user_side_metadata);
1062 
1063   // Creates a fusion instruction. A fusion instruction contains one or more
1064   // fused instructions forming an expression with a single root
1065   // "fused_root". Additional instructions can be added to the fusion
1066   // instruction with the method FuseInstruction.
1067   static std::unique_ptr<HloInstruction> CreateFusion(
1068       const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
1069 
1070   static std::unique_ptr<HloInstruction> CreateFusion(
1071       const Shape& shape, FusionKind fusion_kind,
1072       absl::Span<HloInstruction* const> operands,
1073       HloComputation* fusion_computation);
1074 
1075   // Creates a call instruction that applies the given computation on the given
1076   // operands. "shape" is the resultant shape.
1077   static std::unique_ptr<HloInstruction> CreateCall(
1078       const Shape& shape, absl::Span<HloInstruction* const> operands,
1079       HloComputation* computation);
1080 
1081   // Creates a custom call instruction that applies the given custom call target
1082   // to the given operands. "opaque" can be an arbitrary string with a
1083   // backend-specific interpretation. "shape" is the resultant shape.
1084   static std::unique_ptr<HloInstruction> CreateCustomCall(
1085       const Shape& shape, absl::Span<HloInstruction* const> operands,
1086       absl::string_view custom_call_target, string opaque = "",
1087       CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
1088 
1089   // Overload with a to_apply computation.
1090   static std::unique_ptr<HloInstruction> CreateCustomCall(
1091       const Shape& shape, absl::Span<HloInstruction* const> operands,
1092       HloComputation* to_apply, absl::string_view custom_call_target,
1093       string opaque = "",
1094       CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
1095 
1096   // Overload with multiple computations. The called computations can have
1097   // different function signatures.
1098   static std::unique_ptr<HloInstruction> CreateCustomCall(
1099       const Shape& shape, absl::Span<HloInstruction* const> operands,
1100       absl::Span<HloComputation* const> called_computations,
1101       absl::string_view custom_call_target, string opaque = "",
1102       CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
1103 
1104   // Overload which constrains the layouts of the operand and result. 'shape'
1105   // and 'operand_shapes_with_layout' must have layouts.
1106   // 'operand_shapes_with_layout' must have a compatible element for each
1107   // operand.
1108   static std::unique_ptr<HloInstruction> CreateCustomCall(
1109       const Shape& shape, absl::Span<HloInstruction* const> operands,
1110       absl::string_view custom_call_target,
1111       absl::Span<const Shape> operand_shapes_with_layout, string opaque = "",
1112       CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
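  // A minimal usage sketch of the basic overload (illustrative only): the
  // target name "my_custom_op", the shape `out_shape`, and the operand
  // `input` below are assumptions, not part of this header.
  //
  //   auto custom = HloInstruction::CreateCustomCall(
  //       out_shape, {input}, "my_custom_op", /*opaque=*/"");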
1113 
1114   // Creates a tuple instruction with the given elements. This is a convenience
1115   // wrapper around CreateVariadic.
1116   static std::unique_ptr<HloInstruction> CreateTuple(
1117       absl::Span<HloInstruction* const> elements);
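  // Illustrative sketch: `a` and `b` stand for previously created
  // HloInstruction pointers; they are assumptions for the example.
  //
  //   auto tuple = HloInstruction::CreateTuple({a, b});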
1118 
1119   // Creates a reverse instruction, which reverses the order of the elements
1120   // in the specified dimensions.
1121   static std::unique_ptr<HloInstruction> CreateReverse(
1122       const Shape& shape, HloInstruction* operand,
1123       absl::Span<const int64> dimensions);
1124 
1125   // Creates an AfterAll instruction used for joining or creating new values of
1126   // token type which thread through side-effecting operations. Operands must
1127   // all be tokens, and there must be at least one operand.
1128   static std::unique_ptr<HloInstruction> CreateAfterAll(
1129       absl::Span<HloInstruction* const> operands);
1130 
1131   // Creates an AfterAll instruction which creates a token type out of thin air
1132   // (no operands). This is a separate method from CreateAfterAll to facilitate
1133   // the removal of operand-less AfterAll instructions.
1134   // TODO(b/110532604): Remove this capability of creating a token from nothing
1135   // when we plumb a primordial token from the entry computation.
1136   static std::unique_ptr<HloInstruction> CreateToken();
1137 
1138   static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
1139       const Shape& shape, HloInstruction* operand, int64_t dimension);
1140 
1141   static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
1142       const Shape& shape, HloInstruction* operand, HloInstruction* val,
1143       int64_t dimension);
1144 
1145   static std::unique_ptr<HloInstruction> CreateAddDependency(
1146       HloInstruction* data_operand, HloInstruction* token_operand);
1147 
1148   // Returns the opcode for this instruction.
1149   HloOpcode opcode() const { return opcode_; }
1150   HloOpcode* mutable_opcode() { return &opcode_; }
1151 
1152   // Returns true if this instruction has a side effect, irrespective of whether
1153   // any called computations may contain an instruction with side effects.
1154   bool HasSideEffectNoRecurse() const;
1155 
1156   // Returns true if this instruction has a side effect. An instruction has a
1157   // side effect if it uses certain opcodes or calls a computation with a side
1158   // effect.
1159   bool HasSideEffect() const;
1160 
1161   // Returns the result shape of this instruction.
1162   const Shape& shape() const;
1163 
1164   // Returns the (mutable) result shape of this instruction.
1165   Shape* mutable_shape() { return &shape_; }
1166 
1167   // Returns the ith operand to this instruction.
1168   const HloInstruction* operand(int64_t i) const;
1169 
1170   // Returns the ith operand to this instruction.
1171   HloInstruction* mutable_operand(int64_t i);
1172 
1173   // Returns the number of operands to this instruction.
1174   int64 operand_count() const { return operands_.size(); }
1175 
1176   // Returns the vector of operands of this instruction.
1177   using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
1178   const InstructionVector& operands() const { return operands_; }
1179 
1180   // Returns the vector of unique operands, in the same order they are found
1181   // within the operand vector.
1182   InstructionVector unique_operands() const;
1183 
1184   // Returns the index of 'target' in the operands sequence.
1185   // Precondition: target must be an operand (or a fatal error will occur).
1186   int64 operand_index(const HloInstruction* target) const;
1187 
1188   // Returns the number of users of this instruction.
1189   int64 user_count() const { return users_.size(); }
1190 
1191   // Returns the users of this instruction.
1192   const std::vector<HloInstruction*>& users() const { return users_; }
1193 
1194   // Returns the index of the user in the users() vector.
1195   //
1196   // Precondition: `user` is a user of the instruction.
1197   int64 UserId(HloInstruction* user);
1198 
1199   // Returns true if this instruction is a user of 'instruction'.
1200   bool IsUserOf(const HloInstruction* instruction) const {
1201     return ContainsKey(instruction->user_map_, this);
1202   }
1203 
1204   // Adds a control dependency from this instruction to the given
1205   // instruction. This instruction becomes a control predecessor of
1206   // 'instruction', and 'instruction' becomes a control successor of this
1207   // instruction. Returns an error status if either of the given instructions
1208   // does not belong to the same computation.
1209   //
1210   // This is used to enforce an additional ordering requirement that is not
1211   // captured by normal data dependencies, such as ordering among Send or Recv
1212   // operations to avoid deadlock.
1213   Status AddControlDependencyTo(HloInstruction* instruction);
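  // Illustrative sketch of the Send/Recv ordering use case mentioned above
  // (`send` and `recv` are assumed instructions in the same computation, used
  // in a function that returns Status):
  //
  //   TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv));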
1214 
1215   // Removes a previously added control dependency from this instruction to
1216   // 'instruction'.
1217   Status RemoveControlDependencyTo(HloInstruction* instruction);
1218 
1219   // Drops all control predecessors and successors from this HLO instruction.
1220   Status DropAllControlDeps();
1221 
1222   // Copies the control predecessors and successors on this HLO instruction to
1223   // `inst`.  Does not do a deep copy so this makes sense only if `inst` and
1224   // this HLO are in the same module.
1225   //
1226   // Depending on the use cases we see in practice, in the future we may
1227   // consider folding the logic here into Clone, CloneWithNewOperands and
1228   // ReplaceAllUsesWith by treating control dependencies like data dependencies.
1229   Status CopyAllControlDepsFrom(const HloInstruction* inst);
1230 
1231   // Returns the set of control predecessors (successors) of this
1232   // instruction. Control predecessors (successors) must execute before (after)
1233   // the current instruction.
1234   const std::vector<HloInstruction*>& control_predecessors() const {
1235     return control_predecessors_;
1236   }
1237   const std::vector<HloInstruction*>& control_successors() const {
1238     return control_successors_;
1239   }
1240 
1241   // Returns true if "other" performs the same computation as this instruction.
1242   bool Identical(
1243       const HloInstruction& other,
1244       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
1245           eq_operands = std::equal_to<const HloInstruction*>(),
1246       const std::function<bool(const HloComputation*, const HloComputation*)>&
1247           eq_computations = std::equal_to<const HloComputation*>(),
1248       bool layout_sensitive = true) const {
1249     return IdenticalInternal(other, eq_operands, eq_computations,
1250                              layout_sensitive,
1251                              /*ignore_channel_id_values=*/false);
1252   }
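  // Illustrative sketch: a layout-sensitive equality check using the default
  // operand and computation comparators (`a` and `b` are assumed
  // HloInstruction pointers).
  //
  //   bool same = a->Identical(*b);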
1253 
1254   // Same as Identical() but ignores channel ID value mismatches, as long as
1255   // both have channel IDs or neither has a channel ID.
1256   bool IdenticalIgnoringChannelIdValues(
1257       const HloInstruction& other,
1258       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
1259           eq_operands = std::equal_to<const HloInstruction*>(),
1260       const std::function<bool(const HloComputation*, const HloComputation*)>&
1261           eq_computations = std::equal_to<const HloComputation*>(),
1262       bool layout_sensitive = true) const {
1263     return IdenticalInternal(other, eq_operands, eq_computations,
1264                              layout_sensitive,
1265                              /*ignore_channel_id_values=*/true);
1266   }
1267 
1268   // Generates a hash value of an HLO instruction. Hash considers
1269   // information on opcode, shape, operands, and typically a root instruction.
1270   // This function returns the same hash value for equivalent HLO instructions,
1271   // with respect to HloInstruction::Identical() method.
1272   //
1273   // Uses hash_operand function to compute hash values of its operands.
1274   // At the very top level, hash_operand should be non-recursive to prevent
1275   // non-termination.
1276   uint64 Hash(
1277       const std::function<uint64(const HloInstruction*)>& hash_operand) const;
1278 
1279   // Calls the above method with non-recursive hash_operand function.
1280   uint64 Hash() const;
1281 
1282   // Returns whether the instruction has a constant operand.
1283   bool HasConstantOperand() const;
1284 
1285   // Replaces the use of this instruction in "user" with "new_producer". Note
1286   // that there might be multiple uses of this instruction in "user"; all will
1287   // be replaced.
1288   //
1289   // If user is a fusion instruction, this function will remove any duplicated
1290   // operands of it which could be created due to this replacement.
1291   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
1292 
1293   // Same as ReplaceUseWith(), but new_producer can have a different shape.
1294   Status ReplaceUseWithDifferentShape(HloInstruction* user,
1295                                       HloInstruction* new_producer);
1296 
1297   // Replaces the specified operand with new_operand. The old and new operands
1298   // must have compatible shapes ignoring floating-point precision.
1299   //
1300   // This function does NOT remove duplicated operands even if this instruction
1301   // is a fusion, so that the existing operand numbers do not change.
1302   Status ReplaceOperandWith(int64_t operand_num, HloInstruction* new_operand);
1303 
1304   // Same as ReplaceOperandWith(), but new_operand can have a different shape.
1305   Status ReplaceOperandWithDifferentShape(int64_t operand_num,
1306                                           HloInstruction* new_operand);
1307 
1308   // Replaces all uses of this instruction with the new producer. If
1309   // new_producer is a user of this instruction then new_producer remains a use
1310   // of this instruction to avoid introducing cycles into the graph.
1311   //
1312   // If this instruction is the root of its computation, sets the computation's
1313   // root to new_producer.
1314   //
1315   // The new producer must have a compatible shape ignoring floating-point
1316   // precision.
1317   //
1318   // If a user is a fusion instruction, this function will remove any duplicated
1319   // operands of it which could be created due to this replacement.
1320   Status ReplaceAllUsesWith(HloInstruction* new_producer);
1321 
1322   // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
1323   Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
1324 
1325   // Same as ReplaceAllUsesWith, but only replace given set of users.
1326   Status ReplaceUsesWith(absl::Span<HloInstruction* const> users,
1327                          HloInstruction* new_producer);
1328   Status ReplaceAllUsesWithDifferentShape(
1329       absl::Span<HloInstruction* const> users, HloInstruction* new_producer);
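  // Illustrative sketch: rewiring every user of `old_instr` to `new_instr`,
  // both assumed to be instructions with compatible shapes.
  //
  //   TF_RETURN_IF_ERROR(old_instr->ReplaceAllUsesWith(new_instr));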
1330 
1331   // Performs a postorder DFS visit using this node as the root. If
1332   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
1333   // complete. If ignore_control_predecessors is true, instructions only
1334   // reachable via control dependencies will not be visited, and the postorder
1335   // will not take control dependencies into account. It is as if the control
1336   // dependencies didn't exist in the graph at all.
1337   template <typename HloInstructionPtr>
1338   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
1339                 bool call_finish_visit = true,
1340                 bool ignore_control_predecessors = false);
1341   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
1342                 bool ignore_control_predecessors = false) const {
1343     return const_cast<HloInstruction*>(this)->Accept(
1344         visitor, call_finish_visit, ignore_control_predecessors);
1345   }
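  // Illustrative sketch: running a DfsHloVisitor subclass over the graph
  // rooted at this instruction (`my_visitor` is an assumed visitor instance).
  //
  //   TF_RETURN_IF_ERROR(root->Accept(&my_visitor));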
1346 
1347   // Same as Accept() above, but the order of operand and control predecessor
1348   // visitation is determined by the given operand order; if compare(A, B) ==
1349   // true, A is visited before B.
1350   using CompareFunction =
1351       std::function<bool(const HloInstruction*, const HloInstruction*)>;
1352   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
1353                                 const CompareFunction& operand_order,
1354                                 bool call_finish_visit = true);
1355 
1356   // Visit this instruction and only this instruction with the given visitor.
1357   template <typename HloInstructionPtr>
1358   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
1359 
1360   // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
1361   // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
1362   // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
1363   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
1364       const;
1365 
1366   std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
1367     auto rv =
1368         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
1369     return {const_cast<HloInstruction*>(rv.first), rv.second};
1370   }
1371 
1372   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
1373   const HloInstruction* LatestNonGteAncestor() const;
1374 
1375   HloInstruction* LatestNonGteAncestor() {
1376     return const_cast<HloInstruction*>(
1377         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
1378   }
1379 
1380   // Returns whether this instruction is effectively a bitcast. Currently,
1381   // this means it either is a bitcast, or it is a transpose that is effectively
1382   // a bitcast.
1383   bool IsEffectiveBitcast() const;
1384 
1385   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
1386   // The setter should only be called by HloModule or HloComputation methods.
1387   //
1388   // Precondition: The instruction has a valid to_apply_ field.
1389   HloComputation* to_apply() const;
1390   void set_to_apply(HloComputation* to_apply);
1391 
1392   // Gets/sets the while_condition or while_body HloComputation for While. The
1393   // setters should only be called by HloModule or HloComputation methods.
1394   //
1395   // Precondition: The instruction is a While instruction.
1396   HloComputation* while_condition() const;
1397   HloComputation* while_body() const;
1398   void set_while_condition(HloComputation* while_condition);
1399   void set_while_body(HloComputation* while_body);
1400 
1401   HloInstruction* while_init() const;
1402 
1403   // Gets/sets the true and false HloComputation for Conditional.
1404   //
1405   // Precondition: The instruction is a predicated Conditional instruction.
1406   HloComputation* true_computation() const;
1407   HloComputation* false_computation() const;
1408 
1409   // Gets the branch HloComputations for Conditional.
1410   //
1411   // Precondition: The instruction is a Conditional instruction.
1412   const std::vector<HloComputation*>& branch_computations() const;
1413   int branch_count() const;
1414   HloComputation* branch_computation(int b) const;
1415   // Sets a branch HloComputation for Conditional.
1416   // The setter should only be called by HloModule or HloComputation methods.
1417   //
1418   // Precondition: The instruction is a Conditional instruction.
1419   void set_branch_computation(int b, HloComputation* computation);
1420 
1421   // Returns a string for the signature of this instruction if considered as a
1422   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
1423   string SignatureString() const;
1424 
1425   // Returns a debugging string that represents this instruction.
1426   //
1427   // (We express the default options using an overload rather than a default
1428   // param because gdb ignores default params, but does resolve overloads.)
1429   //
1430   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
1431   // default, backing off on providing full information for very large strings,
1432   // or provide a different name for a ToString-like function that does that.
1433   string ToString() const { return ToString(HloPrintOptions()); }
1434   string ToString(const HloPrintOptions& options) const;
1435 
1436   // Components of the ToString() representation:
1437 
1438   // Returns a string representation of the operand list.
1439   string OperandsToString(const HloPrintOptions& options) const;
1440 
1441   // Returns string representation of op-specific attributes.
1442   std::vector<string> ExtraAttributesToString(
1443       const HloPrintOptions& options) const;
1444 
1445   // As ToString, but returns a shorter string.
1446   string ToShortString() const;
1447 
1448   // Prints an instruction to a string.
1449   //
1450   // The canonical string representation needs to name operands and instruction
1451   // names in a consistent way. This is implemented through the
1452   // canonical_name_map.
1453   string ToStringWithCanonicalNameMap(
1454       const HloPrintOptions& options,
1455       CanonicalNameMap* canonical_name_map) const;
1456 
1457   // Returns a serialized representation of this instruction.
1458   virtual HloInstructionProto ToProto() const;
1459 
1460   // Returns a category for the HLO. This could be something like "convolution"
1461   // or "elementwise".
1462   virtual string ToCategory() const;
1463 
1464   // Returns a logging instruction, if the output of this instruction is logged.
1465   //
1466   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
1467   HloInstruction* tracing() const;
1468   void set_tracing(HloInstruction* trace_instruction);
1469 
1470   // Returns true if this instruction is fused, i.e., contained within a fusion
1471   // instruction.
1472   bool IsFused() const;
1473 
1474   bool IsLoopFusion() const;
1475   bool IsInputFusion() const;
1476   bool IsOutputFusion() const;
1477   bool IsCustomFusion() const;
1478 
1479   // Returns true if this instruction can be legally fused into a fusion
1480   // instruction.
1481   bool IsFusible() const;
1482 
1483   bool IsCustomCall(absl::string_view target) const;
1484 
1485   // Returns the sharding applied to this operator.
1486   // REQUIRES: has_sharding() is true.
1487   const HloSharding& sharding() const {
1488     CHECK(has_sharding());
1489     return *sharding_;
1490   }
1491   std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
1492 
1493   // Returns the sharding applied to this operator, or default_ if none exists.
1494   const HloSharding& sharding_or_default(const HloSharding& default_) const {
1495     return sharding_ ? *sharding_ : default_;
1496   }
1497   // Returns the sharding unique device, if any.
1498   absl::optional<int64> sharding_unique_device() const {
1499     if (sharding_ == nullptr) {
1500       return absl::optional<int64>();
1501     }
1502     return sharding_->UniqueDevice();
1503   }
1504   // Sets the sharding of this operator. Should only be called by HloModule or
1505   // HloComputation methods.
1506   void set_sharding(const HloSharding& sharding) {
1507     sharding_ = std::make_shared<const HloSharding>(sharding);
1508   }
1509   void set_sharding(std::shared_ptr<const HloSharding> sharding) {
1510     sharding_ = std::move(sharding);
1511   }
1512   void set_single_sharding(const HloSharding& sharding);
1513   // Sets a sharding that assigns the current instruction to device.
1514   void set_device_sharding(int64_t device) {
1515     set_single_sharding(HloSharding::AssignDevice(device));
1516   }
1517   // Remove any sharding from this operator.
1518   void clear_sharding() { sharding_ = nullptr; }
1519   // Return true if this operator has a sharding assigned.
1520   bool has_sharding() const { return sharding_ != nullptr; }
1521   // Checks whether the instruction has compatible sharding with the other
1522   // instruction.
1523   bool has_compatible_sharding(const HloInstruction* other) const {
1524     if (!has_sharding()) {
1525       return !other->has_sharding();
1526     }
1527     return other->has_sharding() ? sharding() == other->sharding() : false;
1528   }
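  // Illustrative sketch: pin an unsharded instruction to device 0, then read
  // the sharding back with a replicated fallback. HloSharding::Replicate() as
  // the fallback is an assumption of this example, not defined in this header.
  //
  //   if (!instr->has_sharding()) {
  //     instr->set_device_sharding(/*device=*/0);
  //   }
  //   const HloSharding replicated = HloSharding::Replicate();
  //   const HloSharding& s = instr->sharding_or_default(replicated);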
1529 
1530   // When creating a new instruction which either replaces, or shifts up (kCopy
1531   // insertion case), another instruction, we need to make sure that certain
1532   // properties of the new instruction are copied into the derived one. As of
1533   // today, the metadata and sharding will be propagated to the derived
1534   // instruction.
1535   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
1536 
1537   // Clones the HLO instruction. The clone will have the same opcode, shape, and
1538   // operands. After creation the clone has no uses. "this" (the instruction
1539   // cloned from) is not changed. Suffix is the string to append to the name of
1540   // the instruction to form the name of the cloned instruction.
1541   // Ignores the control predecessors and successors of this HLO instruction.
1542   std::unique_ptr<HloInstruction> Clone(
1543       const string& suffix = "clone", HloCloneContext* context = nullptr) const;
1544 
1545   // Clones the HLO instruction as above but with new shape.
1546   std::unique_ptr<HloInstruction> CloneWithNewShape(
1547       const Shape& shape, const string& suffix = "clone",
1548       HloCloneContext* context = nullptr) const;
1549 
1550   // Clones the HLO instruction as above but with new shape and operands.
1551   std::unique_ptr<HloInstruction> CloneWithNewOperands(
1552       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1553       HloCloneContext* context = nullptr) const;
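  // Illustrative sketch: cloning `add` with two assumed replacement operands
  // whose shapes are compatible with the original result shape.
  //
  //   std::unique_ptr<HloInstruction> copy =
  //       add->CloneWithNewOperands(add->shape(), {new_lhs, new_rhs});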
1554 
1555   // Returns the computations this instruction directly calls (if any).
1556   const std::vector<HloComputation*>& called_computations() const {
1557     return called_computations_;
1558   }
1559 
1560   // Replaces all called computations based on a map function. This is needed
1561   // when we clone hlo_computations and want to let the instructions point
1562   // to the newly cloned nodes.
1563   void ReplaceCalledComputations(
1564       std::function<HloComputation*(HloComputation*)> map_function) {
1565     for (int64_t i = 0; i < called_computations_.size(); ++i) {
1566       called_computations_[i] = map_function(called_computations_[i]);
1567     }
1568   }
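  // Illustrative sketch: repointing called computations through a hypothetical
  // `clone_map` from original to cloned computations.
  //
  //   instr->ReplaceCalledComputations(
  //       [&](HloComputation* c) { return clone_map.at(c); });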
1569 
1570   // Clears out the called computations.
1571   //
1572   // This is, in particular, necessary when inlining function bodies into their
1573   // caller. If there were side-effecting operations in the called computations,
1574   // the call itself is considered side-effecting and thus cannot be removed. By
1575   // clearing out the computations, we reflect the fact that all side-effecting
1576   // properties have been reflected in the caller, and make the call HLO
1577   // removable.
1578   virtual void ClearCalledComputations() { called_computations_.clear(); }
1579 
1580   // Returns true if this instruction performs an elementwise operation on
1581   // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
1582   // to compute the output at index {i_0,i_1,...,i_n}, the only element required
1583   // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
1584   //
1585   // Note on performance: when this instruction is kFusion, this method, in the
1586   // worst case, scans all fused instructions. We could speed this up by
1587   // caching.
1588   bool IsElementwiseOnOperand(int64_t operand_idx) const;
1589 
1590   // Returns true if this instruction is elementwise on all its operands.
1591   bool IsElementwise() const;
1592 
1593   static bool IsOpElementwise(HloOpcode opcode);
1594 
1595   // Returns true if this is a cross module all-reduce instruction.
1596   bool IsCrossModuleAllReduce() const;
1597 
1598   // Returns true if this is a cross-replica all-reduce instruction.
1599   bool IsCrossReplicaAllReduce() const;
1600 
1601   // Returns true if this instruction is binary and elementwise.
1602   bool IsElementwiseBinary() const;
1603 
1604   // Returns whether this instruction may reuse elements of its `i`th operand.
1605   bool ReusesOperandElements(int64_t i) const {
1606     return OperandElementUse(i) == UseKind::kReuse;
1607   }
1608 
1609   // Returns the indices at which the given operand appears in the operand list of
1610   // this instruction. Note that an instruction can use the same operand
1611   // multiple times.
1612   absl::InlinedVector<int64, 4> OperandIndices(
1613       const HloInstruction* operand) const;
1614 
1615   // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
1616   // this reshape merely inserts or deletes 1-sized dimensions, return the input
1617   // indices of the deleted dimensions and the output indices of the inserted
1618   // dimensions.
1619   //
1620   // Precondition: this op must be a reshape.
1621   std::tuple<bool, std::vector<int64>, std::vector<int64>>
1622   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
1623 
1624   // Gets the string identifier for this instruction.
1625   const string& name() const { return name_; }
1626 
1627   // Sets the string identifier for this instruction. Name will be sanitized to
1628   // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
1629   void SetAndSanitizeName(const string& name) {
1630     name_ = NameUniquer::GetSanitizedName(name);
1631   }
1632 
1633   // Use the given NameUniquer to select a unique name for the instruction based
1634   // on the instruction's existing name.
1635   void UniquifyName(NameUniquer* name_uniquer);
1636 
1637   // Clear the unique ID of the instruction so that it can be re-assigned, such
1638   // as for the purpose of compacting the instruction unique IDs.
1639   void ClearUniqueIdInternal() { unique_id_ = -1; }
1640 
1641   // Set the unique id for this instruction to "id"
1642   void SetUniqueId(int id) {
1643     CHECK_EQ(unique_id_, -1);  // Should not be assigned already
1644     CHECK_GE(id, 0);
1645     unique_id_ = id;
1646   }
1647 
1648   // Return the unique ID assigned to this node via SetUniqueId (or -1
1649   // if no id has been assigned yet).
1650   int unique_id() const { return unique_id_; }
1651 
1652   // Returns the backend-specific configuration for how a backend should compile
1653   // this HLO. The meaning of the field is backend specific. Not for use before
1654   // or during general HLO optimization, since HLO optimizations do not preserve
1655   // this field and they cannot interpret it due to its meaning being backend
1656   // specific. Except for CustomCall, where this field is preserved and no
1657   // general HLO optimization needs to interpret it.
1658   //
1659   // ConfigProto should be a protobuf Message type.
1660   template <typename ConfigProto>
1661   StatusOr<ConfigProto> backend_config() const {
1662     ConfigProto proto;
1663     TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
1664     return std::move(proto);
1665   }
1666   Status set_backend_config(const tensorflow::protobuf::Message& proto);
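  // Illustrative sketch: `MyBackendConfig` is a hypothetical backend-defined
  // proto message; the call parses the raw backend config into it.
  //
  //   TF_ASSIGN_OR_RETURN(MyBackendConfig cfg,
  //                       instr->backend_config<MyBackendConfig>());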
1667 
1668   void set_frontend_attributes(FrontendAttributes frontend_attributes) {
1669     frontend_attributes_ = std::move(frontend_attributes);
1670   }
1671 
1672   void add_frontend_attributes(FrontendAttributes frontend_attributes) {
1673     frontend_attributes_.mutable_map()->insert(
1674         frontend_attributes.map().begin(), frontend_attributes.map().end());
1675   }
1676 
1677   const FrontendAttributes& frontend_attributes() const {
1678     return frontend_attributes_;
1679   }
1680 
1681   // Getter/setter for raw JSON-encoded backend config.  Prefer the
1682   // functions above that deal in proto Messages where possible.
1683   const string& raw_backend_config_string() const { return backend_config_; }
1684   void set_raw_backend_config_string(string config_str) {
1685     backend_config_ = std::move(config_str);
1686   }
1687 
1688   bool is_default_config() const { return is_default_config_; }
1689   void set_default_config() { is_default_config_ = true; }
1690 
1691   // Returns a string representation of a proto in the format used by
1692   // raw_backend_config_string.
1693   //
1694   // This is morally equivalent to:
1695   //
1696   //   HloInstruction instr;
1697   //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
1698   //   return instr.raw_backend_config_string();
1699   //
1700   static StatusOr<string> BackendConfigToRawString(
1701       const tensorflow::protobuf::Message& proto);
1702 
1703   // Returns the information used to tell the implementation information about
1704   // what sort of precision is requested. The meaning of the field is backend
1705   // specific. At the moment, it is only supported for kConvolution and kDot.
1706   // Transformations on one kDot or kConvolution to another will preserve this
1707   // information. Transformations to other HLOs will not preserve this
1708   // information but it is presumed that the alternate lowering is strictly
1709   // superior.
1710   // Precondition: opcode must be kConvolution or kDot.
1711   const PrecisionConfig& precision_config() const;
1712   PrecisionConfig* mutable_precision_config();
1713 
1714   // Sets the debug metadata for this instruction, excluding creation_pass_id,
1715   // which should never be copied anywhere.
1716   void set_metadata(const OpMetadata& metadata) {
1717     int64_t creation_pass_id = metadata_.creation_pass_id();
1718     metadata_ = metadata;
1719     metadata_.set_creation_pass_id(creation_pass_id);
1720   }
1721 
1722   void set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes) {
1723     metadata_.set_size_of_generated_code_in_bytes(code_size_in_bytes);
1724   }
1725   void set_size_of_memory_working_set_in_bytes(
1726       int64_t working_set_size_in_bytes) {
1727     metadata_.set_size_of_memory_working_set_in_bytes(
1728         working_set_size_in_bytes);
1729   }
1730   void set_creation_pass_id(int64_t pass_id) {
1731     metadata_.set_creation_pass_id(pass_id);
1732   }
1733   void set_metadata_op_name(const std::string& name) {
1734     metadata_.set_op_name(name);
1735   }
1736   void set_logical_creation_pass_id(int64_t pass_id) {
1737     metadata_.set_logical_creation_pass_id(pass_id);
1738   }
1739   const OpMetadata& metadata() const { return metadata_; }
1740 
1741   // Set/get the computation containing this instruction. set_parent should only
1742   // be called by HloComputation methods which add/remove instructions to
1743   // computations.
1744   void set_parent(HloComputation* computation) { parent_ = computation; }
1745   const HloComputation* parent() const { return parent_; }
1746   HloComputation* parent() { return parent_; }
1747 
1748   // Returns the module for this instruction.
1749   HloModule* GetModule() const;
1750 
1751   // Get/Set the number of partitions per outer dimension (in order, starting
1752   // with outer-most dimension first). Currently used by the parallel cpu
1753   // backend to partition HLOs into parallel tasks.
1754   //
1755   // TODO(b/62783254) Replace these methods with a more general way to
1756   // annotate HLOs with backend-specific information.
1757   const std::vector<int64>& outer_dimension_partitions() const {
1758     return outer_dimension_partitions_;
1759   }
1760   void set_outer_dimension_partitions(
1761       const std::vector<int64>& outer_dimension_partitions);
1762 
1763   // Old methods kept for smooth subclassing transition BEGIN.
1764   // TODO(b/80131774): Remove this code.
1765 
1766   // Delegates to HloBatchNormInstruction::feature_index.
1767   int64 feature_index() const;
1768 
1769   // Delegates to HloBatchNormInstruction::epsilon.
1770   float epsilon() const;
1771 
1772   // Delegates to HloFftInstruction::fft_type.
1773   FftType fft_type() const;
1774 
1775   // Delegates to HloFftInstruction::fft_length.
1776   const std::vector<int64>& fft_length() const;
1777 
1778   // Delegates to HloChannelInstruction::channel_id.
1779   absl::optional<int64> channel_id() const;
1780   void set_channel_id(const absl::optional<int64>& channel_id);
1781 
1782   // Returns the dimension sizes or numbers associated with this instruction.
1783   virtual const std::vector<int64>& dimensions() const {
1784     LOG(FATAL) << "Unimplemented method.";
1785   }
1786   virtual int64 dimensions(int64_t index) const {
1787     LOG(FATAL) << "Unimplemented method.";
1788   }
1789   virtual std::vector<int64>* mutable_dimensions() {
1790     LOG(FATAL) << "Unimplemented method.";
1791   }
1792 
1793   // Delegates to HloConcatenateInstruction::concatenate_dimension.
1794   int64 concatenate_dimension() const;
1795 
1796   // Delegates to HloGetDimensionSizeInstruction::dimension.
1797   int64 dimension() const;
1798 
1799   // Delegates to HloReshapeInstruction::inferred_dimension.
1800   int64 inferred_dimension() const;
1801 
1802   // Returns whether this instruction does a rank-2 transposition.
1803   bool IsRank2Transpose() const;
1804 
1805   // Delegates to HloSliceInstruction::slice_start.
1806   int64 slice_starts(int64_t dimension) const;
1807   const std::vector<int64>& slice_starts() const;
1808   std::vector<int64>* mutable_slice_starts();
1809 
1810   // Delegates to HloSliceInstruction::slice_limits.
1811   int64 slice_limits(int64_t dimension) const;
1812   const std::vector<int64>& slice_limits() const;
1813   std::vector<int64>* mutable_slice_limits();
1814 
1815   // Delegates to HloSliceInstruction::slice_strides.
1816   int64 slice_strides(int64_t dimension) const;
1817   const std::vector<int64>& slice_strides() const;
1818   std::vector<int64>* mutable_slice_strides();
1819 
1820   // Returns the literal associated with this instruction.
1821   const Literal& literal() const;
1822 
1823   // Returns whether the instruction is a constant.
1824   bool IsConstant() const;
1825 
1826   // Delegate to HloConstantInstruction::RelayoutConstant.
1827   void RelayoutConstant(const Layout& new_layout,
1828                         const ShapeIndex& shape_index = {});
1829 
1830   // Delegates to HloTraceInstruction::TracingTag.
1831   string TracingTag() const;
1832 
1833   // Delegates to HloFusionInstruction::AddFusionOperand.
1834   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1835 
1836   // Delegates to HloFusionInstruction::MergeFusionInstruction.
1837   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
1838 
1839   // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
1840   void MergeFusionInstructionIntoMultiOutput(
1841       HloInstruction* instruction_to_merge);
1842 
1843   // Delegates to HloFusionInstruction::FuseInstruction.
1844   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
1845 
1846   // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
1847   HloInstruction* FuseInstructionIntoMultiOutput(
1848       HloInstruction* instruction_to_fuse);
1849 
1850   // Delegates to HloFusionInstruction::fused_instruction.
1851   HloComputation* fused_instructions_computation() const;
1852 
1853   // Delegates to HloFusionInstruction::fused_expression_root.
1854   HloInstruction* fused_expression_root() const;
1855 
1856   // Delegates to HloFusionInstruction::fused_instructions.
1857   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1858       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1859   fused_instructions() const;
1860 
1861   const tensorflow::gtl::iterator_range<
1862       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1863   fused_instructions();
1864 
1865   // Delegates to HloFusionInstruction::fused_instruction_count.
1866   int64 fused_instruction_count() const;
1867 
1868   // Delegates to HloFusionInstruction::fused_parameter.
1869   HloInstruction* fused_parameter(int64_t parameter_number) const;
1870 
1871   // Delegates to HloFusionInstruction::fused_parameters.
1872   const std::vector<HloInstruction*>& fused_parameters() const;
1873 
1874   // Returns true if this instruction is a fusion instruction that generates
1875   // multiple outputs.
1876   const bool IsMultiOutputFusion() const;
1877 
1878   // Delegates to HloFusionInstruction::fusion_kind.
1879   FusionKind fusion_kind() const;
1880 
1881   // Delegates to HloFusionInstruction::set_fusion_kind.
1882   void set_fusion_kind(FusionKind kind);
1883 
1884   // Delegates to HloRngInstruction::random_distribution.
1885   RandomDistribution random_distribution() const;
1886 
1887   // Delegates to HloParameterInstruction::parameter_number.
1888   int64 parameter_number() const;
1889 
1890   // Delegates to
1891   // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
1892   void set_parameter_replicated_at_leaf_buffers(
1893       absl::Span<const bool> parameter_replicated_at_leaf_buffers);
1894   void set_parameter_replicated_at_leaf_buffers(
1895       const std::vector<bool>& parameter_replicated_at_leaf_buffers);
1896 
1897   // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
1898   const absl::optional<std::vector<bool>>&
1899   parameter_replicated_at_leaf_buffers() const;
1900 
1901   // Delegates to HloGetTupleElementInstruction::tuple_index.
1902   int64 tuple_index() const;
1903 
1904   // Delegates to HloGetTupleElementInstruction::set_tuple_index.
1905   void set_tuple_index(int64_t new_tuple_index);
1906 
1907   // Delegates to HloReducePrecisionInstruction::exponent_bits.
1908   int32 exponent_bits() const;
1909 
1910   // Delegates to HloReducePrecisionInstruction::mantissa_bits.
1911   int32 mantissa_bits() const;
1912 
1913   // Delegates to HloInfeedInstruction::infeed_config.
1914   string infeed_config() const;
1915 
1916   // Delegates to HloInfeedInstruction::set_infeed_config.
1917   void set_infeed_config(const string& config);
1918 
1919   // Returns the config for the Outfeed instruction.
1920   const string& outfeed_config() const;
1921 
1922   // Delegates to HloOutfeedInstruction::set_outfeed_config.
1923   void set_outfeed_config(const string& config);
1924 
1925   // Returns the shape for the Outfeed instruction.
1926   const Shape& outfeed_shape() const;
1927 
1928   // Returns the mutable shape for the Outfeed instruction.
1929   Shape* mutable_outfeed_shape();
1930 
1931   // Delegates to HloCollectiveInstruction::replica_groups.
1932   const std::vector<ReplicaGroup>& replica_groups() const;
1933 
1934   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
1935   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
1936 
1937   // Returns data on the window in a windowed operation such as
1938   // convolution.
1939   virtual const Window& window() const {
1940     LOG(FATAL) << "Unimplemented method.";
1941   }
1942 
1943   // Sets the window data in a windowed operation such as convolution.
1944   virtual void set_window(const Window& window) {
1945     LOG(FATAL) << "Unimplemented method.";
1946   }
1947 
1948   // Returns the unique_indices field.
1949   virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }
1950 
1951   // Returns data on the dimension numbers used for a convolution operation,
1952   // which may be a kConvolution instruction or a kCustomCall that implements a
1953   // convolution.
1954   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
1955 
1956   // Sets the convolution dimension numbers on this instruction.  In general you
1957   // shouldn't need to call this; instead, specify the convolution dimension
1958   // numbers when you create the instruction.
1959   void set_convolution_dimension_numbers(
1960       const ConvolutionDimensionNumbers& dnums);
1961 
1962   // The number of feature groups. Must be a divisor of the input feature
1963   // dimension and output feature dimension.
1964   int64 feature_group_count() const;
1965 
1966   void set_feature_group_count(int64_t feature_group_count);
1967 
1968   // The number of batch groups. Must be a divisor of the input batch dimension.
1969   int64 batch_group_count() const;
1970 
1971   void set_batch_group_count(int64_t batch_group_count);
1972 
1973   // Delegates to HloSelectAndScatterInstruction::select.
1974   HloComputation* select() const;
1975 
1976   // Delegates to HloSelectAndScatterInstruction::scatter.
1977   HloComputation* scatter() const;
1978 
1979   // Delegates to HloSelectAndScatterInstruction::set_select.
1980   void set_select(HloComputation* computation);
1981 
1982   // Delegates to HloSelectAndScatterInstruction::set_scatter.
1983   void set_scatter(HloComputation* computation);
1984 
1985   // Delegates to HloCustomCallInstruction::custom_call_target.
1986   const string& custom_call_target() const;
1987 
1988   // Delegates to HloPadInstruction::padding_config.
1989   const PaddingConfig& padding_config() const;
1990   PaddingConfig* mutable_padding_config();
1991 
1992   // Delegates to HloConvolutionInstruction::padding_type.
1993   PaddingType padding_type() const;
1994 
1995   // Delegates to HloDynamicSliceInstruction::slice_sizes.
1996   int64 slice_sizes(int64_t dimension) const;
1997 
1998   // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
1999   const std::vector<int64>& dynamic_slice_sizes() const;
2000 
2001   // Delegates to HloCollectivePermuteInstruction::dynamic_slice_sizes.
2002   const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const;
2003 
2004   // Delegates to HloGatherInstruction::gather_dimension_numbers.
2005   const GatherDimensionNumbers& gather_dimension_numbers() const;
2006   // Delegates to HloGatherInstruction::gather_slice_sizes.
2007   absl::Span<const int64> gather_slice_sizes() const;
2008 
2009   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
2010   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
2011 
2012   // Delegates to HloDotInstruction::dot_dimension_numbers().
2013   const DotDimensionNumbers& dot_dimension_numbers() const;
2014 
2015   // Delegates to HloDomainInstruction::operand_side_metadata().
2016   const DomainMetadata& operand_side_metadata() const;
2017 
2018   // Delegates to HloDomainInstruction::user_side_metadata().
2019   const DomainMetadata& user_side_metadata() const;
2020 
2021   // Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
2022   bool is_cross_program_prefetch() const;
2023 
2024   // Delegates to HloCompareInstruction::direction().
2025   ComparisonDirection comparison_direction() const;
2026 
2027   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
2028   const TriangularSolveOptions& triangular_solve_options() const;
2029 
2030   // Delegates to HloCholeskyInstruction::cholesky_options().
2031   const CholeskyOptions& cholesky_options() const;
2032 
2033   // Appends operand to the list of operands and adds this instruction as a user
2034   // of the operand.
2035   void AppendOperand(HloInstruction* operand);
2036 
2037   // Old methods kept for smooth subclassing transition END.
2038 
2039  protected:
2040   // Indicates how an instruction uses a value (such as an operand).
2041   //
2042   // Does it (a) not use it, (b) use it, or (c) use it multiple times?
2043   //
2044   // In the kUse case (i.e. (b)) we may either (i) use the value elementwise, or
2045   // (ii) use it after having permuted it somehow, e.g. through a reshape.  If
2046   // the use is a permuting use, we set permutation_instr to the instruction
2047   // that did the permuting.
2048   struct UseKind {
2049     enum Kind { kReuse, kUse, kNoUse };
2050 
2051     // Creates a UseKind that represents a use that permutes an instruction's
2052     // elements according to the given instruction.
2053     static UseKind Permuting(const HloInstruction* permutation_instr) {
2054       UseKind k(kUse);
2055       k.permutation_instr = permutation_instr;
2056       return k;
2057     }
2058 
2059     UseKind(Kind kind)  // NOLINT intentionally nonexplicit
2060         : kind(kind), permutation_instr(nullptr) {}
2061 
2062     bool friend operator==(UseKind a, Kind b) { return a.kind == b; }
2063     bool friend operator==(Kind a, UseKind b) { return b == a; }
2064 
2065     Kind kind;
2066     const HloInstruction* permutation_instr;
2067   };
2068 
2069   // Helper class for computing OperandElementUse for kFusion.
2070   class FusionReusesParamElements;
2071 
2072   // Internal constructor for a given opcode/shape, other fields must be filled
2073   // by factory methods.
2074   HloInstruction(HloOpcode opcode, const Shape& shape);
2075 
2076   void RemoveAllOperands() { operands_.clear(); }
2077 
2078   void RemoveOperandAt(int index) {
2079     operands_.erase(operands_.begin() + index);
2080   }
2081 
2082   // Removes a list of operands with the given indices in ascending order.
2083   void RemoveOperandsAtAscendingIndices(
2084       absl::Span<const int> ascending_indices);
2085 
2086   void AppendComputation(HloComputation* computation) {
2087     called_computations_.push_back(computation);
2088   }
2089 
2090   void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
2091 
2092   void set_called_computation(int index, HloComputation* computation) {
2093     called_computations_[index] = computation;
2094   }
2095   // Indices of computations in called_computations_ for instructions which call
2096   // multiple computations.
2097   enum {
2098     // kWhile computations.
2099     kBodyComputationIndex = 0,
2100     kConditionComputationIndex = 1,
2101 
2102     // kSelectAndScatter computations.
2103     kSelectComputationIndex = 0,
2104     kScatterComputationIndex = 1,
2105 
2106     // kConditional computations.
2107     kTrueComputationIndex = 0,
2108     kFalseComputationIndex = 1,
2109   };
2110 
2111  private:
2112   friend class HloComputation;
2113 
2114   bool IdenticalInternal(
2115       const HloInstruction& other,
2116       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2117           eq_operands,
2118       const std::function<bool(const HloComputation*, const HloComputation*)>&
2119           eq_computations,
2120       bool layout_sensitive, bool ignore_channel_id_values) const;
2121 
2122   // Implementation for non-common logic of CloneWithNewOperands.
2123   virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2124       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2125       HloCloneContext* context) const {
2126     // TODO(b/80131774): This should be pure virtual.
2127     LOG(FATAL) << "Unimplemented method.";
2128   }
2129 
2130   // Implementation for non-common logic of ExtraAttributesToString.
2131   virtual std::vector<string> ExtraAttributesToStringImpl(
2132       const HloPrintOptions& options) const {
2133     return {};
2134   }
2135 
2136   // Implementation for IsElementwise if operand_idx is nullopt and for
2137   // IsElementwiseOnOperand if otherwise.
2138   //
2139   // NOTE: For all instructions other than kFusion, being elementwise on one of
2140   // the operands is equivalent to being elementwise on all the operands.
2141   virtual bool IsElementwiseImpl(
2142       const absl::optional<int64>& operand_idx) const;
2143 
2144   // Prints an operand to a string. Accessed by friend class HloInstruction.
2145   virtual string OperandsToStringWithCanonicalNameMap(
2146       const HloPrintOptions& options,
2147       CanonicalNameMap* canonical_name_map) const;
2148 
2149   // See comments on Identical().
2150   virtual bool IdenticalSlowPath(
2151       const HloInstruction& other,
2152       const std::function<bool(const HloComputation*, const HloComputation*)>&
2153           eq_computations) const;
2154 
2155   // Generates a hash value specific to a particular type of an instruction.
2156   // This function typically considers the inner root instruction.
2157   virtual uint64 InnerHash() const;
2158 
2159   // Creates an n-ary elementwise operation.
2160   static std::unique_ptr<HloInstruction> CreateNary(
2161       const Shape& shape, HloOpcode opcode,
2162       absl::Span<HloInstruction* const> operands);
2163 
2164   // Adds a user for this instruction.
2165   void AddUser(HloInstruction* user);
2166 
2167   // Removes a user for this instruction.
2168   void RemoveUser(HloInstruction* user);
2169 
2170   // Returns how this instruction uses elements of its operand at operand_num.
2171   UseKind OperandElementUse(int64_t operand_num) const;
2172 
2173   // Helper for implementing backend_config().  Parses backend_config_ into the
2174   // given proto.
2175   Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
2176 
2177   // Mark this instruction as dead. Accessed by friend class HloInstruction.
2178   void MarkAsDead() { marked_as_dead_ = true; }
2179 
2180   // Has this instruction been marked as dead? Accessed by friend class
2181   // HloInstruction.
2182   bool IsMarkedAsDead() const { return marked_as_dead_; }
2183 
2184   int unique_id_;  // Unique to this HloInstruction within a HloModule
2185 
2186   // Opcode for this instruction.
2187   HloOpcode opcode_;
2188 
2189   // Instruction operands.
2190   InstructionVector operands_;
2191 
2192   // The set of control predecessors of this instruction.
2193   // Note that the order of the instructions in the vector influences the order
2194   // computed in HloComputation::ComputeInstructionPostOrder, which may
2195   // influence the result of the compilation by changing the scheduling. We are
2196   // not sure if it matters.
2197   std::vector<HloInstruction*> control_predecessors_;
2198 
2199   // The users of this instruction. Users are HLOs where this instruction is an
2200   // operand. The vector users_ and the map user_map_ contain identical members.
2201   // The map enables fast membership testing and the vector enables fast, stable
2202   // iteration. The value in the map contains the index of the instruction in
2203   // the vector, which enables fast removal.
2204   std::vector<HloInstruction*> users_;
2205   absl::flat_hash_map<const HloInstruction*, int64> user_map_;
2206 
2207   // The set of control successors of this instruction.
2208   std::vector<HloInstruction*> control_successors_;
2209 
2210   // The computation in which this instruction is contained.
2211   HloComputation* parent_ = nullptr;
2212 
2213   // Result shape of this instruction.
2214   Shape shape_;
2215 
2216   // The sharding, if one exists.
2217   // Uses std::shared_ptr to allow reuse of the same sharding object between
2218   // HloInstructions and other components as HloSharding can be very large for
2219   // many element tuples.
2220   std::shared_ptr<const HloSharding> sharding_;
2221 
2222   // Computations called by this instruction.
2223   std::vector<HloComputation*> called_computations_;
2224 
2225   // A trace instruction that consumes this instruction.
2226   //
2227   // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
2228   // an operand.
2229   HloInstruction* trace_instruction_ = nullptr;
2230 
2231   // The backend-specific configuration for how a backend should compile this
2232   // HLO. See the documentation on backend_config().
2233   string backend_config_;
2234 
2235   // Attributes passed from the frontend to give hints to the backend about
2236   // how to compile this HLO.
2237   // HLO -> HLO transforms are expected to preserve these attributes on a
2238   // "best effort" basis only.
2239   // For example:
2240   //    x = const(10, frontend_attributes={x}
2241   //    y = const(10, frontend_attributes={y}
2242   //    z = add(x,y), frontend_attributes={y}
2243   // Could be simplified to:
2244   //    z' = const(20), frontend_attributes={?}
2245   FrontendAttributes frontend_attributes_;
2246 
2247   // This field is assigned to true when backend_config_ is assigned to
2248   // a default configuration.
2249   bool is_default_config_ = false;
2250 
2251   // True if this instruction has already been detached from its user and
2252   // operands.
2253   bool cleaned_up_ = false;
2254 
2255   // String identifier for instruction.
2256   string name_;
2257 
2258   // Metadata for debugging.
2259   OpMetadata metadata_;
2260 
2261   // The number of partitions per outer dimension (listed in order from
2262   // outer-most dimension first).
2263   std::vector<int64> outer_dimension_partitions_;
2264 
2265   // Intrusive flag used by HloComputation, whether this instruction has
2266   // been marked as dead.
2267   bool marked_as_dead_;
2268 
2269   TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
2270 };
2271 
2272 // Explicit instantiations in hlo_instruction.cc.
2273 extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
2274 extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2275 extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
2276 extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2277 
2278 string ToString(HloInstruction::FusionKind kind);
2279 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
2280     const string& kind_name);
2281 
2282 // Custom (de)stringification functions for protos that live inside
2283 // HloInstruction.
2284 string PaddingConfigToString(const PaddingConfig& padding);
2285 string FrontendAttributesToString(
2286     const FrontendAttributes& frontend_attributes);
2287 string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
2288 string RandomDistributionToString(const RandomDistribution& distribution);
2289 string PrecisionToString(const PrecisionConfig::Precision& precision);
2290 string ConvolutionDimensionNumbersToString(
2291     const ConvolutionDimensionNumbers& dnums);
2292 string ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups);
2293 
2294 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name);
2295 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
2296 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
2297 StatusOr<CustomCallSchedule> StringToCustomCallSchedule(absl::string_view name);
2298 StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
2299     absl::string_view name);
2300 
2301 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
2302 
2303 // Map classes that guarantee a deterministic iteration order when the key is
2304 // an HloInstruction* or a const HloInstruction*.
2305 // To make the iteration order over the map deterministic, the comparator
2306 // should not be using the pointer values, but rather an intrinsic property of
2307 // the hlo. Exception: null pointer values compare less than non-null.
2308 struct HloPtrComparator {
2309   bool operator()(const HloInstruction* const& lhs,
2310                   const HloInstruction* const& rhs) const;
2311 };
2312 
2313 template <typename ValueT>
2314 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
2315 
2316 template <typename ValueT>
2317 using ConstHloInstructionMap =
2318     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
2319 
2320 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
2321 using ConstHloInstructionSet =
2322     std::set<const HloInstruction*, HloPtrComparator>;
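// Illustrative sketch: HloInstructionMap iterates in a deterministic order,
// unlike hash maps keyed on raw pointers (`root` is an assumed instruction).
//
//   HloInstructionMap<int64> depth;
//   depth[root] = 0;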
2323 
2324 }  // namespace xla
2325 
2326 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
2327