• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // HLO instructions are in DAG form and represent the computations that the user
17 // has built up via the XLA service interface. They are ultimately lowered
18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
19 // form; e.g. see DfsHloVisitor.
20 
21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
23 
24 #include <functional>
25 #include <iosfwd>
26 #include <list>
27 #include <memory>
28 #include <optional>
29 #include <ostream>
30 #include <set>
31 #include <string>
32 #include <tuple>
33 #include <utility>
34 #include <vector>
35 
36 #include "absl/container/flat_hash_map.h"
37 #include "absl/container/flat_hash_set.h"
38 #include "absl/container/inlined_vector.h"
39 #include "absl/hash/hash.h"
40 #include "absl/strings/str_cat.h"
41 #include "absl/strings/string_view.h"
42 #include "absl/types/span.h"
43 #include "tensorflow/compiler/xla/comparison_util.h"
44 #include "tensorflow/compiler/xla/iterator_util.h"
45 #include "tensorflow/compiler/xla/literal.h"
46 #include "tensorflow/compiler/xla/map_util.h"
47 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
48 #include "tensorflow/compiler/xla/service/hlo.pb.h"
49 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
50 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
51 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
52 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
53 #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
54 #include "tensorflow/compiler/xla/service/name_uniquer.h"
55 #include "tensorflow/compiler/xla/shape_tree.h"
56 #include "tensorflow/compiler/xla/types.h"
57 #include "tensorflow/compiler/xla/xla_data.pb.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/lib/gtl/iterator_range.h"
60 #include "tensorflow/core/platform/logging.h"
61 #include "tensorflow/core/platform/protobuf.h"
62 
63 namespace xla {
64 
65 class HloComputation;
66 class HloModule;
67 
68 std::string PrintName(const std::string& name, bool print_ids);
69 
70 // A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
 public:
  enum class PrintSubcomputationMode {
    kOff,                  // Do not print anything about subcomputations.
    kNameOnly,             // Only print the name of subcomputations.
    kFullBodies,           // Print the full bodies of subcomputations.
    kNonSequentialBodies,  // Print the full bodies of subcomputations that are
                           // not in a sequential context.
  };

  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation. The defaults are supplied by the
  // in-class member initializers below.
  HloPrintOptions() = default;

  // Options for a short, parsable rendering: metadata, backend configs,
  // operand shapes, operand index annotations, program shapes, '%' prefixes,
  // and control dependencies are all suppressed.
  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_operand_index_annotation_interval(0)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        // Compact option won't print name of operand_k, where k > a threshold.
        // Canonical (and Fingerprint) must include all operands.
        .set_compact_operands(false)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        // No index annotations as they are only for ease of human inspection.
        .set_print_operand_index_annotation_interval(0)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // Options to produce a fingerprint of an HLO instruction.
  // Based on Canonical() with some important changes commented below.
  static HloPrintOptions Fingerprint() {
    return Canonical()
        // Exclude because they do not affect HLO optimizations.
        .set_print_infeed_outfeed_config(false)
        // Exclude floating point constant literals that are not all zeros, all
        // ones, or integers because they may be randomly initialized weights,
        // which may be changed across different runs.
        .set_print_only_essential_constants(true)
        // Remove "id" in "name.id" (after period) because it can be
        // non-deterministic. This mainly affects computations' names because
        // canonicalized instructions' names are in "tmp_id" format.
        .set_print_ids(false)
        // Sort computations.
        .set_canonicalize_computations(true);
  }

  // Options to produce a fingerprint of an HLO module and computation.
  // Shorter (and therefore faster) than Fingerprint().
  static HloPrintOptions ModuleFingerprint() {
    return Fingerprint()
        // Operand shapes can be inferred from output shapes and canonicalized
        // names when we have an entire computation.
        .set_print_operand_shape(false);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  // If true, only integer, all-zero, and all-one constants will be printed
  // out.
  HloPrintOptions& set_print_only_essential_constants(bool value) {
    print_only_essential_constants_ = value;
    return *this;
  }

  // Controls how subcomputations (e.g. bodies of kWhile/kFusion) are printed.
  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, infeed_config and outfeed_config will be printed.
  HloPrintOptions& set_print_infeed_outfeed_config(bool value) {
    print_infeed_outfeed_config_ = value;
    return *this;
  }

  // If true, result shapes will be printed.
  HloPrintOptions& set_print_result_shape(bool value) {
    print_result_shape_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // Sets the interval between /*index=*/ annotations on printed operands.
  // 0 means never print the annotation; 1 means annotate every operand.
  HloPrintOptions& set_print_operand_index_annotation_interval(int64_t value) {
    print_operand_index_annotation_interval_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, all printed names include unique identifiers.
  HloPrintOptions& set_print_ids(bool value) {
    print_ids_ = value;
    return *this;
  }

  // If true, the HLO includes its attributes.
  HloPrintOptions& set_print_extra_attributes(bool value) {
    print_extra_attributes_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, uses the async operation syntax sugar to print async-start,
  // async-update, and async-done HLOs. If the syntax sugar is enabled, the
  // computations called by these instructions will not be printed and instead
  // the root of the called computation will be printed instead of these
  // instructions and -start, -update, and -done suffixes will be appended to
  // the opcode of the async operation. For example, for an HLO module where the
  // syntax sugar is off:
  //
  // HloModule Module
  //
  // %AsyncOp (p0.1: f32[10]) -> f32[20] {
  //   %p0.1 = f32[10]{0} parameter(0)
  //   ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %p0.1),
  //                                  custom_call_target="foo"
  // }
  //
  // ENTRY %Entry (p0: f32[10]) -> f32[20] {
  //   %p0 = f32[10]{0} parameter(0)
  //   %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(%p0),
  //                                                    calls=%AsyncOp
  //   %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(
  //                                                         %async-start),
  //                                                     calls=%AsyncOp
  //   ROOT %async-done = f32[20]{0} async-done(%async-update), calls=%AsyncOp
  // }
  //
  // will be printed as following if the syntax sugar is enabled:
  //
  // HloModule Module
  //
  // ENTRY %Entry (p0: f32[10]) -> f32[20] {
  //   %p0 = f32[10]{0} parameter(0)
  //   %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(%p0),
  //                                                    custom_call_target="foo"
  //   %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(
  //                                                         %async-start),
  //                                                    custom_call_target="foo"
  //   ROOT %async-done = f32[20]{0} custom-call-done(%async-update),
  //                                                    custom_call_target="foo"
  // }
  HloPrintOptions& set_syntax_sugar_async_ops(bool value) {
    syntax_sugar_async_ops_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out (note that in this
  // case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, include the layout in any shapes that are printed (instruction
  // and operands).
  HloPrintOptions& set_include_layout_in_shapes(bool value) {
    include_layout_in_shapes_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' name. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // If true, canonicalizes computations, sorting by computations' names.
  HloPrintOptions& set_canonicalize_computations(bool value) {
    canonicalize_computations_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  bool print_only_essential_constants() const {
    return print_only_essential_constants_;
  }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool print_infeed_outfeed_config() const {
    return print_infeed_outfeed_config_;
  }
  bool compact_operands() const { return compact_operands_; }
  bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
  bool print_result_shape() const { return print_result_shape_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  int64_t print_operand_index_annotation_interval() const {
    return print_operand_index_annotation_interval_;
  }
  bool print_ids() const { return print_ids_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool print_extra_attributes() const { return print_extra_attributes_; }
  bool syntax_sugar_async_ops() const { return syntax_sugar_async_ops_; }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  bool canonicalize_computations() const { return canonicalize_computations_; }
  int indent_amount() const { return indent_amount_; }
  // Fixed: the getter previously returned `int` although the underlying field
  // (and the setter parameter) is `bool`.
  bool is_in_nested_computation() const { return is_in_nested_computation_; }

 private:
  bool print_large_constants_ = false;
  bool print_only_essential_constants_ = false;
  PrintSubcomputationMode print_subcomputation_mode_ =
      PrintSubcomputationMode::kNameOnly;
  bool print_metadata_ = true;
  bool print_backend_config_ = true;
  bool print_infeed_outfeed_config_ = true;
  bool compact_operands_ = false;
  bool include_layout_in_shapes_ = true;
  bool print_result_shape_ = true;
  bool print_operand_shape_ = true;
  bool print_operand_names_ = true;
  // The interval between the /*index=*/ annotated operands. 0 means never
  // print the annotation, 1 means print annotation for every operand.
  int64_t print_operand_index_annotation_interval_ = 5;
  bool print_program_shape_ = true;
  bool print_percent_ = true;
  bool print_control_dependencies_ = true;
  bool canonicalize_instruction_names_ = false;
  int indent_amount_ = 0;
  bool is_in_nested_computation_ = false;
  bool print_ids_ = true;
  bool canonicalize_computations_ = false;
  bool print_extra_attributes_ = true;
  bool syntax_sugar_async_ops_ = true;
};
403 
404 // For canonical string output, we need to have a canonical way to rename
405 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
406 // where <xxx> is an index starting from 0.
407 class CanonicalNameMap {
408  public:
LookupOrInsert(const std::string & name)409   const std::string& LookupOrInsert(const std::string& name) {
410     std::string& canonical_name = canonical_name_map_[name];
411     if (canonical_name.empty()) {
412       absl::StrAppend(&canonical_name, "tmp_", canonical_name_map_.size() - 1);
413     }
414     return canonical_name;
415   }
416 
417  private:
418   absl::flat_hash_map<std::string, std::string> canonical_name_map_;
419 };
420 
421 // HLO instructions are the atomic unit of the high-level compiler's IR.
422 //
423 // HloInstructions live inside of an HloComputation, which is analogous to a
424 // function in other programming languages.  Nodes have no total order within
425 // their computation.  Instead, they have a partial ordering determined by their
426 // data and control dependencies.
427 //
428 // HLO does not have basic blocks or explicit "branch" instructions.  Instead,
429 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
430 // control flow.  For example, the kConditional HLO executes one of two possible
431 // computations, depending on the runtime value of a predicate.
432 //
433 // HLO is pure (mostly).  It has no concept of mutable state.  Instead, data
434 // values are produced by one HLO and flow into consumers across dependency
435 // edges.
436 class HloInstruction {
437  public:
438   // A fusion node computes the same value a call to its fusion computation
439   // would compute.  However, the choice of fusion kind dictates codegen
440   // strategy for the backend.
441   //
442   // To generate code for a kFusion HloInstruction, most backends do something
443   // like the following:
444   //
445   // 1) Identify the "primary" HloInstruction of the fused computation.
446   // 2) Emit code that does the work of the primary node, creating its inputs
447   //    and transforming its outputs as specified by the fused computation.
448   //
449   // In step (2), the code emitted is usually similar to the code that would be
450   // emitted for an *unfused* version of the primary node, except that
451   //
452   //  - when the primary node reads an element of one of its operands, instead
453   //    of loading the value from memory, it *computes* the value based on the
454   //    contents of the fused computation.
455   //  - when the primary node outputs a value, instead of storing it to memory,
456   //    it forwards the value to its users, which then perform additional
457   //    computations before the value is finally stored to memory at the root of
458   //    the fusion node.
459   //
460   // An HloInstruction's FusionKind helps us find the kFusion instruction's
461   // primary node, and can also affect how we generate code in step (2).
462   //
463   //  - kInput: The primary node is the root of the fused instruction.
464   //
465   //  - kOutput: The primary node is not the root of the fused instruction.
466   //    This fusion kind requires that one operand buffer of the fusion
467   //    instruction be able to alias the output buffer.  This constraint is
468   //    usually enough to let backends find the primary node unambiguously.
469   //
470   //  - kLoop: The primary node is the root of the fused computation, but,
471   //    unlike in input fusion, we prescribe a specific implementation for
472   //    codegen.  Rather than generating code that looks like the code we'd emit
473   //    for an unfused version of the primary/root node, we emit code that
474   //    generates one element of the root at a time.
475   //
476   //  - kCustom: Custom category for backend-specific fusions that don't fit
477   //    into the above patterns.
478   //
479   // Not all backends support all fusion kinds, and given a particular fused
480   // computation, it's not in general safe to change its fusion kind.  Creation
481   // of fusion nodes is always backend-specific.
482   //
483   // For elementwise ops (e.g. kAdd), most backends would emit a
484   // one-element-at-a-time implementation for the unfused version, so loop
485   // fusion and input fusion are probably equivalent if the root node is
486   // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
487   // implementation might emit something more sophisticated for an unfused or
488   // input-fusion reduce, but will emit the naive code that reduces one element
489   // at a time for loop fusion with a reduce as the root.
490   //
491   // Another way to think of loop fusion is that it's equivalent to input
492   // fusion, but where the root node is an implicit identity node, whose
493   // unfused implementation is "read one element, write one element".
494   //
495   // TODO(b/79869434): This categorization scheme is not great.  For one thing,
496   // input and loop fusion are basically the same thing: There is no reason for
497   // the HLO to encode backend-specific decisions about how e.g. a reduce that's
498   // the root of a fusion should be lowered.  In addition, this scheme as
499   // written doesn't work for multi-output fusion, where the primary node is
500   // never actually the root (which is a kTuple instruction that gathers the
501   // multiple outputs of the fusion).
502   enum class FusionKind {
503     kLoop,
504     kInput,
505     kOutput,
506     kCustom,
507   };
508 
509   inline static constexpr char kMainExecutionThread[] = "main";
510 
~HloInstruction()511   virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }
512 
513   // Detaches an instruction from its operands and users. That is, remove the
514   // instruction from each operand's user set and user's operand set.
515   void DetachFromOperandsAndUsers();
516 
517   // Adds a derived instruciton to the parent compuation of this instruction.
518   // Also update setup the new instruction as a derived instruction.
519   HloInstruction* AddInstruction(
520       std::unique_ptr<HloInstruction> derived_instruction);
521 
522   // Creates an instruction from the given proto. Arguments:
523   //
524   //   proto: the proto to convert from.
525   //   instruction_map: a map from instruction id to HloInstruction*. This map
526   //     must contain all operands of the newly constructed instruction.
527   //   computation_map: a map from computation id to HloComputation*. This map
528   //     must contain all computations which the newly constructed instruction
529   //     calls.
530   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
531       const HloInstructionProto& proto,
532       const absl::flat_hash_map<int64_t, HloInstruction*>& instruction_map,
533       const absl::flat_hash_map<int64_t, HloComputation*>& computation_map = {},
534       bool prohibit_empty_literal = true);
535 
536   // Creates a parameter-retrieving instruction.
537   static std::unique_ptr<HloInstruction> CreateParameter(
538       int64_t parameter_number, const Shape& shape, const std::string& name);
539 
540   // Creates a literal constant instruction.
541   static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
542 
543   // Creates an Iota instruction.
544   static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
545                                                     int64_t iota_dimension);
546 
547   // Creates a get tuple element instruction.
548   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
549       const Shape& shape, HloInstruction* operand, int64_t index);
550 
551   // Creates a get tuple element instruction.
552   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
553       HloInstruction* operand, int64_t index);
554 
555   // Creates a random number generation instruction that fills a shape with
556   // random numbers from a given distribution.
557   //
558   // The parameters to the instruction are interpreted as follows:
559   //
560   //  - If `distribution` is RNG_UNIFORM, generates a number in range
561   //    [param0, param1).
562   //
563   //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
564   //    with mean `param0` and standard deviation `param1`.
565   static std::unique_ptr<HloInstruction> CreateRng(
566       const Shape& shape, RandomDistribution distribution,
567       absl::Span<HloInstruction* const> parameters);
568 
569   // Creates a stateless random bit generator instruction that fills a shape
570   // with random bits.
571   static std::unique_ptr<HloInstruction> CreateRngBitGenerator(
572       const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm);
573 
574   // Creates an instruction to update the random number generator state to
575   // reflect the new state after `delta` units of 32 random bits are generated
576   // and returns the old state.
577   static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState(
578       const Shape& shape, int64_t delta);
579 
580   // Creates a unary instruction (one operand).
581   // Precondition: opcode must be a legitimate unary operation.
582   static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
583                                                      HloOpcode opcode,
584                                                      HloInstruction* operand);
585 
586   // Creates a binary instruction (two operands).
587   // Precondition: opcode must be a legitimate binary operation.
588   static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
589                                                       HloOpcode opcode,
590                                                       HloInstruction* lhs,
591                                                       HloInstruction* rhs);
592 
593   // Creates a ternary instruction (three operands).
594   // Precondition: opcode must be a legitimate ternary operation.
595   static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
596                                                        HloOpcode opcode,
597                                                        HloInstruction* lhs,
598                                                        HloInstruction* rhs,
599                                                        HloInstruction* ehs);
600 
601   // Creates a variadic instruction (variable number of operands).
602   // Precondition: opcode must be a legitimate variadic operation.
603   static std::unique_ptr<HloInstruction> CreateVariadic(
604       const Shape& shape, HloOpcode opcode,
605       absl::Span<HloInstruction* const> operands);
606 
607   // Creates a map instruction, where the computation (given by the handle) is
608   // applied element-wise to every element in operands (across the operands,
609   // at a given index)
610   static std::unique_ptr<HloInstruction> CreateMap(
611       const Shape& shape, absl::Span<HloInstruction* const> operands,
612       HloComputation* map_computation);
613 
614   // Creates a convolution op, where rhs is the convolutional filter
615   // and window describes how the filter is applied to lhs.
616   static std::unique_ptr<HloInstruction> CreateConvolve(
617       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
618       int64_t feature_group_count, int64_t batch_group_count,
619       const Window& window,
620       const ConvolutionDimensionNumbers& dimension_numbers,
621       const PrecisionConfig& precision_config);
622 
623   // Creates an FFT op, of the type indicated by fft_type.
624   static std::unique_ptr<HloInstruction> CreateFft(
625       const Shape& shape, HloInstruction* operand, FftType fft_type,
626       absl::Span<const int64_t> fft_length);
627 
628   // Creates an asynchronous start, update, and done op.
629   static std::unique_ptr<HloInstruction> CreateAsyncStart(
630       const Shape& shape, absl::Span<HloInstruction* const> operands,
631       HloComputation* async_computation,
632       std::optional<int64_t> async_group_id = std::nullopt,
633       absl::string_view async_execution_thread = kMainExecutionThread);
634   static std::unique_ptr<HloInstruction> CreateAsyncUpdate(
635       const Shape& shape, HloInstruction* operand,
636       HloComputation* async_computation,
637       std::optional<int64_t> async_group_id = std::nullopt,
638       absl::string_view async_execution_thread = kMainExecutionThread);
639   static std::unique_ptr<HloInstruction> CreateAsyncDone(
640       const Shape& shape, HloInstruction* operand,
641       HloComputation* async_computation,
642       std::optional<int64_t> async_group_id = std::nullopt,
643       absl::string_view async_execution_thread = kMainExecutionThread);
644 
645   // Creates a copy-start op, indicating whether this is a cross-program
646   // prefetch or not.
647   static std::unique_ptr<HloInstruction> CreateCopyStart(
648       const Shape& shape, HloInstruction* operand,
649       bool is_cross_program_prefetch = false);
650 
651   // Creates a compare op, performing the comparison specified in direction.
652   static std::unique_ptr<HloInstruction> CreateCompare(
653       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
654       Comparison::Direction direction,
655       std::optional<Comparison::Type> type = std::nullopt);
656 
  // Creates a triangular-solve op on operands `a` (the triangular matrix)
  // and `b` (the right-hand side), configured by `options` (side, lower vs.
  // upper, transpose — see TriangularSolveOptions).
  static std::unique_ptr<HloInstruction> CreateTriangularSolve(
      const Shape& shape, HloInstruction* a, HloInstruction* b,
      const TriangularSolveOptions& options);

  // Creates a Cholesky decomposition op on operand `a`, configured by
  // `options` (see CholeskyOptions).
  static std::unique_ptr<HloInstruction> CreateCholesky(
      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'. 'precision_config' controls
  // the backend numerical precision for the contraction.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);
677 
  // Creates an all-gather op, which concats the operands of all participants
  // along all_gather_dimension. The replica_groups, channel_id, and
  // use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants.
  static std::unique_ptr<HloInstruction> CreateAllGather(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // Creates an all-gather-start op, which concats the operands of all
  // participants along all_gather_dimension. The replica_groups, channel_id,
  // and use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants. Needs to be used in
  // conjunction with an AllGatherDone op that synchronizes and returns the
  // result.
  static std::unique_ptr<HloInstruction> CreateAllGatherStart(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);
701 
  // Creates a cross replica reduction op.
  //
  // `reduce_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica id. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // Creates a reduce-scatter operation which reduces its inputs across the
  // given replica groups and then scatters the reduced data across the N
  // participants along `scatter_dimension`.
  static std::unique_ptr<HloInstruction> CreateReduceScatter(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids,
      int64_t scatter_dimension);

  // Creates an asynchronous cross replica reduction op.
  //
  // `reduce_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica id. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduceStart(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);
749 
  // An all-to-all op takes N array operands of the same shape and scatters them
  // to N replicas.  Each replica gathers the results into a tuple.
  //
  // For example, suppose we have 3 replicas, with replica i passing inputs
  // [a_i, b_i, c_i] to its all-to-all op.  Then the resulting tuples are
  //
  //   replica 0: (a_0, a_1, a_2)
  //   replica 1: (b_0, b_1, b_2)
  //   replica 2: (c_0, c_1, c_2).
  //
  // If replica_groups is set, the op is sharded and the replicas are permuted.
  // To explain by way of example, suppose we have replica_groups={{1,2},{3,0}}.
  // Then each replica passes two operands, say [a_i, b_i], and the result is
  //
  //   replica 0: (b_3, b_0)
  //   replica 1: (a_1, a_2)
  //   replica 2: (b_1, b_2)
  //   replica 3: (a_3, a_0).
  //
  // All replica groups must have the same number of elements, and the number of
  // operands must be equal to the size of one replica group.  Each replica must
  // appear in exactly one group.
  //
  // Note that in addition to supporting this instruction, XlaBuilder also
  // supports a higher-level instruction which takes one input and slices it,
  // performs AllToAll and then concatenates the results into a single array.
  //
  // NOTE(review): `split_dimension`, when set, appears to select the array
  // (single-operand split/concat) variant — confirm against the .cc.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id,
      const std::optional<int64_t>& split_dimension = std::nullopt);
781 
  // Creates a communication instruction that permutes data cross replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
  // target_replica_id in any pair, the output on that replica is a tensor
  // consists of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const std::optional<int64_t>& channel_id);

  // Overload taking explicit `input`/`output` buffers and per-buffer start
  // indices; `slice_sizes` gives the extent of each transferred slice.
  // NOTE(review): semantics of this in-place variant are not visible here —
  // confirm against the factory implementation.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* input, HloInstruction* output,
      HloInstruction* input_start_indices, HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const std::optional<int64_t>& channel_id);

  // Creates a communication instruction that initiates the start of
  // CollectivePermute.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const std::optional<int64_t>& channel_id);

  // Asynchronous-start counterpart of the in-place CollectivePermute
  // overload above.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* input, HloInstruction* output,
      HloInstruction* input_start_indices, HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const std::optional<int64_t>& channel_id);
812 
  // Creates an instruction that returns a U32 replica ID. The default result
  // shape is a scalar U32.
  static std::unique_ptr<HloInstruction> CreateReplicaId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates an instruction that returns a U32 partition ID. The default
  // result shape is a scalar U32.
  static std::unique_ptr<HloInstruction> CreatePartitionId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates a conversion instruction, where operand is the data to convert and
  // shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);
835 
  // Creates an infeed instruction, which reads data of the given shape from the
  // Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed *not* the shape of the infeed instruction which
  // is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const std::string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed *not* the shape of the outfeed instruction
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64_t channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id.  If
  // is_host_transfer is true, then this Recv operation transfers data from the
  // host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64_t channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);
877 
  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices. `strides` gives the step between consecutive
  // elements taken in each dimension.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> start_indices,
      absl::Span<const int64_t> limit_indices,
      absl::Span<const int64_t> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64_t> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t dimension);
905 
  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. For example, let f be
  // the function to apply, which takes 2 arguments, an accumulator and the
  // current value. Let init be an initial value (which is normally chosen to be
  // the identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  // f(f(init, value0), value1), ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64_t> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  // f_1 = f(init0, ...  initN, input0.value0, ..., inputN.value0)
  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  // ..., inputN.value1)
  // ...
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64_t> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Helper version where the operands are given by a single instruction which
  // either is a tuple of size `init_values`, or a single input, in which case
  // size of `init_values` is one.
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* tuple_of_instructions,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64_t> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The reduce_computation being applied now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The operands and init_values now each
  // contain a span of N input arrays and N initial values.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values, const Window& window,
      HloComputation* reduce_computation);
959 
  // Creates a batch-norm-training instruction, normalizing `operand` over
  // `feature_index` with the given `scale` and `offset`; `epsilon` guards
  // against division by zero.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64_t feature_index);

  // Creates a batch-norm-inference instruction, like the training variant but
  // using precomputed `mean` and `variance` statistics.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64_t feature_index);

  // Creates a batch-norm-grad instruction, computing gradients of the
  // batch-norm operation with respect to `grad_output`.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64_t feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> broadcast_dimensions);
988 
  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand,
      int64_t inferred_dimension = -1);

  // Creates a dynamic reshape instruction. Similar to reshape but dynamic
  // dimensions sizes are provided as additional variadic arguments.
  //
  // Precondition: dim_sizes.size() == shape.rank()
  static std::unique_ptr<HloInstruction> CreateDynamicReshape(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> dimensions);
1027 
  // Creates an n-ary sort op with a 'compare' computation which is used for
  // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
  // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
  // specific index positions which should be compared, and should return a
  // PRED. 'is_stable' specifies whether stable sorting is required.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64_t dimension,
      absl::Span<HloInstruction* const> operands, HloComputation* compare,
      bool is_stable);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  // int32_t i = 1; int32_t result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  // Creates a conditional instruction: runs true_computation on
  // true_computation_arg when pred is true, otherwise false_computation on
  // false_computation_arg.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  // N-way variant of the above: branch_index selects which of
  // branch_computations runs on the corresponding branch_computation_args
  // element.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* branch_index,
      absl::Span<HloComputation* const> branch_computations,
      absl::Span<HloInstruction* const> branch_computation_args);
1057 
  // Creates a gather instruction that gathers slices of `operand` (of extents
  // `slice_sizes`) at the positions given by `start_indices`, as described by
  // `gather_dim_numbers`. `indices_are_sorted` asserts the indices are sorted
  // (a hint backends may exploit).
  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted);

  // Creates a scatter instruction that combines `updates` into `operands` at
  // the locations selected by `scatter_indices`, using `update_computation`
  // to merge each update with the existing value, as described by
  // `scatter_dim_numbers`. `indices_are_sorted` and `unique_indices` assert
  // properties of the indices that backends may exploit.
  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloInstruction* scatter_indices,
      absl::Span<HloInstruction* const> updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Single-operand convenience overload of the above.
  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  // Overload taking an explicit fusion computation and operands; `prefix` is
  // prepended to the instruction name.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation, absl::string_view prefix = "");
1097 
  // Creates a call instruction that applies the given computation on the given
  // operands. "shape" is the resultant shape.
  //
  // NOTE(review): this overload takes the root of the computation to call
  // rather than the computation itself — confirm how the computation is
  // derived in the .cc.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, HloInstruction* called_computation_root);

  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call target
  // to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Overload with a to_apply computation.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* to_apply, absl::string_view custom_call_target,
      std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Overload with multiple computations. The called computations can have
  // different function signatures.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloComputation* const> called_computations,
      absl::string_view custom_call_target, std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout,
      std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
1140 
  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> dimensions);

  // Creates an AfterAll instruction used for joining or creating new values of
  // token type which thread through side-effecting operations. Operands must
  // all be tokens, and there must be at least one operand.
  static std::unique_ptr<HloInstruction> CreateAfterAll(
      absl::Span<HloInstruction* const> operands);

  // Creates an AfterAll instruction which creates a token type out of thin air
  // (no operands). This is a separate method from CreateAfterAll to facilitate
  // the removal of operand-less AfterAll instructions.
  // TODO(b/110532604): Remove this capability of creating a token from nothing
  // when we plumb a primordial token from the entry computation.
  static std::unique_ptr<HloInstruction> CreateToken();

  // Creates a get-dimension-size instruction returning the (possibly dynamic)
  // size of dimension `dimension` of `operand`.
  static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
      const Shape& shape, HloInstruction* operand, int64_t dimension);

  // Creates a set-dimension-size instruction setting the dynamic size of
  // dimension `dimension` of `operand` to `val`.
  static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
      const Shape& shape, HloInstruction* operand, HloInstruction* val,
      int64_t dimension);

  // Creates an add-dependency instruction that forwards `data_operand` while
  // establishing an ordering dependency on `token_operand`.
  static std::unique_ptr<HloInstruction> CreateAddDependency(
      HloInstruction* data_operand, HloInstruction* token_operand);
1174 
  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }
  // Returns a mutable pointer to the opcode, allowing it to be rewritten
  // in place.
  HloOpcode* mutable_opcode() { return &opcode_; }

  // Returns whether this instruction is the root of its parent computation.
  bool IsRoot() const;

  // Does this instruction have no users. An instruction is dead when its
  // result is never consumed: it has no users and is not the computation
  // root.
  bool IsDead() const { return users_.empty() && !IsRoot(); }

  // Returns true if this instruction has a side effect, irrespective of whether
  // any called computations may contain an instruction with side effects.
  bool HasSideEffectNoRecurse() const;

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64_t i) const;

  // Returns the ith operand to this instruction (mutable).
  HloInstruction* mutable_operand(int64_t i);

  // Returns the number of operands to this instruction.
  int64_t operand_count() const { return operands_.size(); }
1208 
  // Returns the vector of operands of this instruction.
  using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }
  // Note: returns the operand vector by value (a copy of the pointers),
  // not a reference.
  InstructionVector mutable_operands() { return operands_; }

  // Returns the vector of unique operands, in the same order they are found
  // within the operand vector.
  InstructionVector unique_operands() const;

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64_t operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64_t user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns the index of the user in the users() vector.
  //
  // Precondition: `user` is a user of the instruction.
  int64_t UserId(HloInstruction* user);

  // Returns true if this instruction is a user of 'instruction', i.e. this
  // appears in 'instruction's user map.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_map_, this);
  }
1237 
  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if either of the given instructions
  // does not belong to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Drops all control predecessors and successors from this HLO instruction.
  Status DropAllControlDeps();

  // Copies the control predecessors and successors on this HLO instruction to
  // `inst`.  Does not do a deep copy so this makes sense only if `inst` and
  // this HLO are in the same module.
  //
  // Depending on the use cases we see in practice, in the future we may
  // consider folding the logic here into Clone, CloneWithNewOperands and
  // ReplaceAllUsesWith by treating control dependencies like data dependencies.
  Status CopyAllControlDepsFrom(const HloInstruction* inst);

  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before (after)
  // the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }
1274 
  // Returns true if "other" performs the same computation as this instruction.
  // Strictest variant: channel IDs must match and operand order is
  // significant.
  bool Identical(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/false,
                             /*ignore_commutative_operand_order=*/false);
  }

  // Same as Identical() but ignores the order of commutative operands (e.g.
  // considers add(a,b) equal to add(b,a)).
  bool IdenticalIgnoringCommutativeOperandOrder(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/false,
                             /*ignore_commutative_operand_order=*/true);
  }

  // Same as Identical() but ignores channel ID value mismatches, as long as
  // both have channel IDs or neither has a channel ID.
  // Note: as the flag below shows, this variant ALSO ignores commutative
  // operand order, like IdenticalIgnoringCommutativeOperandOrder().
  bool IdenticalIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/true,
                             /*ignore_commutative_operand_order=*/true);
  }
1318 
  // Generates a hash value of an HLO instruction. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO instructions,
  // with respect to HloInstruction::Identical() method.
  // TODO(majnemer): Make the comment here more crisp & accurate.
  template <typename H>
  friend H AbslHashValue(H h, const HloInstruction& hlo) {
    h = H::combine(std::move(h), hlo.opcode(), hlo.shape());

    // Operand shapes and operand count are folded in for every op except
    // cross-module all-reduce. NOTE(review): the reason for the exclusion is
    // not visible here -- presumably cross-module all-reduces must hash
    // equal across modules despite differing operands; confirm.
    if (!hlo.IsCrossModuleAllReduce()) {
      for (size_t i = 0; i < hlo.operands().size(); ++i) {
        h = H::combine(std::move(h), hlo.operand(i)->shape());
      }
      h = H::combine(std::move(h), hlo.operand_count());
    }

    // Fusions additionally hash the fused root, fusion kind and instruction/
    // parameter counts so structurally different fusions do not collide.
    if (hlo.opcode() == HloOpcode::kFusion) {
      h = H::combine(std::move(h), *hlo.fused_expression_root(),
                     hlo.fusion_kind(), hlo.fused_instruction_count(),
                     hlo.fused_parameters().size());
    }
    return h;
  }
1342 
  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  //
  // If user is a fusion instruction, this function will remove any duplicated
  // operands of it which could be created due to this replacement.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Same as ReplaceUseWith(), but new_producer can have a different shape.
  Status ReplaceUseWithDifferentShape(HloInstruction* user,
                                      HloInstruction* new_producer);

  // Same as ReplaceUseWith but only replaces the use at the given operand
  // number.
  Status ReplaceUseWith(HloInstruction* user, int operand_number,
                        HloInstruction* new_producer);
  Status ReplaceUseWithDifferentShape(HloInstruction* user, int operand_number,
                                      HloInstruction* new_producer);

  // Replaces the specified operand with new_operand. The old and new operands
  // must have compatible shapes ignoring floating-point precision.
  //
  // This function does NOT remove duplicated operands even if this instruction
  // is a fusion, so that the existing operand numbers do not change.
  Status ReplaceOperandWith(int64_t operand_num, HloInstruction* new_operand);

  // Same as ReplaceOperandWith(), but new_operand can have a different shape.
  Status ReplaceOperandWithDifferentShape(int64_t operand_num,
                                          HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a use
  // of this instruction to avoid introducing cycles into the graph.
  //
  // If this instruction is the root of its computation, sets the computation's
  // root to new_producer.
  //
  // The new producer must have a compatible shape ignoring floating-point
  // precision.
  //
  // If a user is a fusion instruction, this function will remove any duplicated
  // operands of it which could be created due to this replacement.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);

  // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
  Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);

  // Same as ReplaceAllUsesWith, but only replace given set of users.
  Status ReplaceUsesWith(absl::Span<HloInstruction* const> users,
                         HloInstruction* new_producer);
  // Shape-relaxed variant of ReplaceUsesWith() for a specific user set.
  Status ReplaceAllUsesWithDifferentShape(
      absl::Span<HloInstruction* const> users, HloInstruction* new_producer);
1398 
  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // complete. If ignore_control_predecessors is true, instructions only
  // reachable via control dependencies will not be visited, and the postorder
  // will not take control dependencies into account. It is as if the control
  // dependencies didn't exist in the graph at all.
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                bool call_finish_visit = true,
                bool ignore_control_predecessors = false);
  // Const overload: delegates to the non-const Accept above via const_cast.
  // Presumably safe because ConstDfsHloVisitor only exposes const access to
  // the visited instructions -- confirm against the visitor definition.
  Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
                bool ignore_control_predecessors = false) const {
    return const_cast<HloInstruction*>(this)->Accept(
        visitor, call_finish_visit, ignore_control_predecessors);
  }

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  using CompareFunction =
      std::function<bool(const HloInstruction*, const HloInstruction*)>;
  Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
                                const CompareFunction& operand_order,
                                bool call_finish_visit = true);

  // Visit this instruction and only this instruction with the given visitor.
  template <typename HloInstructionPtr>
  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);

  // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
  // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
  // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
  std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
      const;

  // Non-const overload: delegates to the const version and casts away
  // constness on the returned ancestor pointer.
  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
    auto rv =
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
    return {const_cast<HloInstruction*>(rv.first), rv.second};
  }

  // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
  const HloInstruction* LatestNonGteAncestor() const;

  // Non-const overload: delegates to the const version above.
  HloInstruction* LatestNonGteAncestor() {
    return const_cast<HloInstruction*>(
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
  }
1447 
  // Returns true whether this instruction is effectively a bitcast. Currently,
  // this means it either is a bitcast, or it is a transpose that is effectively
  // a bitcast.
  bool IsEffectiveBitcast() const;

  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction has a valid to_apply_ field.
  HloComputation* to_apply() const;
  void set_to_apply(HloComputation* computation);

  // Gets/sets the while_condition or while_body HloComputation for While. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a While instruction.
  HloComputation* while_condition() const;
  HloComputation* while_body() const;
  void set_while_condition(HloComputation* computation);
  void set_while_body(HloComputation* computation);

  // Returns the init operand of a While instruction.
  HloInstruction* while_init() const;

  // Gets/sets the true and false HloComputation for Conditional.
  //
  // Precondition: The instruction is a predicated Conditional instruction.
  HloComputation* true_computation() const;
  HloComputation* false_computation() const;

  // Gets the branch HloComputations for Conditional.
  //
  // Precondition: The instruction is a Conditional instruction.
  const std::vector<HloComputation*>& branch_computations() const;
  int branch_count() const;
  HloComputation* branch_computation(int b) const;
  // Sets a branch HloComputation for Conditional.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a Conditional instruction.
  void set_branch_computation(int b, HloComputation* computation);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  std::string SignatureString() const;
1492 
  // Returns a debugging string that represents this instruction.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  //
  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
  // default, backing off on providing full information for very large strings,
  // or provide a different name for a ToString-like function that does that.
  std::string ToString() const { return ToString(HloPrintOptions()); }
  std::string ToString(const HloPrintOptions& options) const;

  // Components of the ToString() representation:

  // Returns a string representation of the operand list.
  std::string OperandsToString(const HloPrintOptions& options) const;

  // Returns string representation of op-specific attributes.
  std::vector<std::string> ExtraAttributesToString(
      const HloPrintOptions& options) const;

  // As ToString, but returns a shorter string.
  std::string ToShortString() const;

  // Prints an instruction to a string.
  //
  // The canonical string representation needs to name operands and instruction
  // names in a consistent way. This is implemented through the
  // canonical_name_map.
  std::string ToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Returns a serialized representation of this instruction.
  virtual HloInstructionProto ToProto() const;

  // Returns a category for the HLO. This could be something like "convolution"
  // or "elementwise".
  virtual std::string ToCategory() const;

  // Returns true if this instruction is fused, ie contained within a fusion
  // instruction.
  bool IsFused() const;

  // Predicates on the kind of fusion this instruction is (if any).
  bool IsLoopFusion() const;
  bool IsInputFusion() const;
  bool IsOutputFusion() const;
  bool IsCustomFusion() const;

  // Returns true if this instruction can be legally fused into a fusion
  // instruction.
  bool IsFusible() const;

  // Returns true if this is a CustomCall with the given call target.
  bool IsCustomCall(absl::string_view target) const;
1546 
  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  // Returns the sharding as a shared_ptr; null when no sharding is set.
  // Useful when the sharding needs to outlive this instruction.
  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }

  // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }
1559   // Returns the sharding unique device, if any.
sharding_unique_device()1560   std::optional<int64_t> sharding_unique_device() const {
1561     if (sharding_ == nullptr) {
1562       return std::optional<int64_t>();
1563     }
1564     return sharding_->UniqueDevice();
1565   }
  // Sets the sharding of this operator. Should only be called by HloModule or
  // HloComputation methods.
  void set_sharding(const HloSharding& sharding) {
    sharding_ = std::make_shared<const HloSharding>(sharding);
  }
  // Move-in overload: takes shared ownership of an existing sharding object.
  void set_sharding(std::shared_ptr<const HloSharding> sharding) {
    sharding_ = std::move(sharding);
  }
  void set_single_sharding(const HloSharding& sharding);
  // Sets a sharding that assigns the current instruction to device.
  void set_device_sharding(int64_t device) {
    set_single_sharding(HloSharding::AssignDevice(device));
  }
  // Remove any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }
  // Return true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }
1583   // Checks whether the instruction has compatible sharding with the other
1584   // instruction.
has_compatible_sharding(const HloInstruction * other)1585   bool has_compatible_sharding(const HloInstruction* other) const {
1586     if (!has_sharding()) {
1587       return !other->has_sharding();
1588     }
1589     return other->has_sharding() ? sharding() == other->sharding() : false;
1590   }
1591 
  // When creating a new instruction which either replaces, or shifts up (kCopy
  // insertion case), another instruction, we need to make sure the certain
  // properties of the new instruction are copied into the derived one. As of
  // today, the metadata and sharding will be propagated to the derived
  // instruction.
  void SetupDerivedInstruction(HloInstruction* derived_instruction) const;

  // Clones the HLO instruction. The clone will have the same opcode, shape, and
  // operands. After creation the clone has no uses. "this" (the instruction
  // cloned from) is not changed. Suffix is the string to append to the name of
  // the instruction to form the name of the cloned instruction.
  // Ignores the control predecessors and successors of this HLO instruction.
  std::unique_ptr<HloInstruction> Clone(
      const std::string& suffix = "clone",
      HloCloneContext* context = nullptr) const;

  // Clones the HLO instruction as above but with new shape.
  std::unique_ptr<HloInstruction> CloneWithNewShape(
      const Shape& shape, const std::string& suffix = "clone",
      HloCloneContext* context = nullptr) const;

  // Clones the HLO instruction as above but with new shape and operands.
  std::unique_ptr<HloInstruction> CloneWithNewOperands(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context = nullptr) const;

  // Returns the computations this instruction directly calls (if any).
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }
1622 
1623   // Replaces all called computations based on a map function. This is needed
1624   // when we clone hlo_computations and want to let the instructions to point
1625   // to the newly cloned nodes.
ReplaceCalledComputations(std::function<HloComputation * (HloComputation *)> map_function)1626   void ReplaceCalledComputations(
1627       std::function<HloComputation*(HloComputation*)> map_function) {
1628     for (int64_t i = 0; i < called_computations_.size(); ++i) {
1629       called_computations_[i] = map_function(called_computations_[i]);
1630     }
1631   }
1632 
  // Clears out the called computations.
  //
  // This is, in particular, necessary when inlining function bodies into their
  // caller. If there were side-effecting operations in the called computations,
  // the call itself is considered side-effecting and thus cannot be removed. By
  // clearing out the computations, we reflect the fact that all side-effecting
  // properties have been reflected in the caller, and make the call HLO
  // removable.
  virtual void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
  // to compute the output at index {i_0,i_1,...,i_n}, the only element required
  // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in the
  // worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64_t operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if the given opcode is always elementwise, independent of an
  // instruction's operands.
  static bool IsOpElementwise(HloOpcode opcode);

  // Returns true if this is a cross module all-reduce instruction.
  bool IsCrossModuleAllReduce() const;

  // Returns true if this is a cross-replica all-reduce instruction.
  bool IsCrossReplicaAllReduce() const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64_t i) const;

  // Returns the indices that the given operand appear in the operand list of
  // this instruction. Note that an instruction can use the same operand
  // multiple times.
  absl::InlinedVector<int64_t, 4> OperandIndices(
      const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, return the input
  // indices of the deleted dimensions and the output indices of the inserted
  // dimensions.
  //
  // Precondition: this op must be a reshape.
  std::optional<ShapeUtil::ShapeEqualityDescriptor>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
1684 
  // Gets the string identifier for this instruction.
  const std::string& name() const { return name_; }

  // Sets the string identifier for this instruction. Name will be sanitized to
  // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
  //
  // See also HloModule::SetAndUniquifyInstrName(), which does this plus
  // UniqufyName().
  void SetAndSanitizeName(absl::string_view name) {
    name_ = NameUniquer::GetSanitizedName(name);
  }

  // Use the given NameUniquer to select a unique name for the instruction based
  // on the instruction's existing name.
  //
  // See also HloModule::SetAndUniquifyInstrName(), which does this plus
  // SetAndSanitizeName().
  void UniquifyName(NameUniquer* name_uniquer);

  // Clear the unique ID of the instruction so that it can be re-assigned, such
  // as for the purpose of compacting the instruction unique IDs.
  // -1 is the sentinel for "unassigned" (see SetUniqueId / unique_id below).
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // Set the unique id for this instruction to "id"
  void SetUniqueId(int id) {
    CHECK_EQ(unique_id_, -1);  // Should not be assigned already
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Return the unique ID assigned to this node via SetUniqueId (or -1
  // if no id has been assigned yet).
  int unique_id() const { return unique_id_; }
1718 
  // SFINAE helper: enabled only for protobuf Message subclasses.
  template <typename T>
  using EnableIfProto = typename std::enable_if_t<
      std::is_base_of<tensorflow::protobuf::Message, T>::value>;

  // Returns the backend-specific configuration for how a backend should compile
  // this HLO. The meaning of the field is backend specific. Not for use before
  // or during general HLO optimization, since HLO optimizations do not preserve
  // this field and they cannot interpret it due to its meaning being backend
  // specific. Except for CustomCall, where this field is preserved and no
  // general HLO optimization needs to interpret it.
  //
  // ConfigProto should be a protobuf Message type.
  template <typename ConfigProto, EnableIfProto<ConfigProto>* = nullptr>
  StatusOr<ConfigProto> backend_config() const {
    ConfigProto proto;
    TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
    // std::move is needed here for the implicit conversion into StatusOr.
    return std::move(proto);
  }

  // Stores `proto` as this instruction's backend config.
  Status set_backend_config(const tensorflow::protobuf::Message& proto) {
    backend_config_ = proto;
    return OkStatus();
  }

  bool has_backend_config() const { return !backend_config_.empty(); }

  void clear_backend_config() { backend_config_.clear(); }

  void CopyBackendConfigFrom(const HloInstruction* other) {
    backend_config_ = other->backend_config_.Clone();
  }

  // Replaces all frontend attributes with the given set.
  void set_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_ = std::move(frontend_attributes);
  }

  // Merges the given attributes in; existing keys are kept (map::insert does
  // not overwrite entries that are already present).
  void add_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_.mutable_map()->insert(
        frontend_attributes.map().begin(), frontend_attributes.map().end());
  }

  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Getter/setter for raw JSON-encoded backend config.  Prefer the
  // functions above that deal in proto Messages where possible.
  const std::string& raw_backend_config_string() const {
    return backend_config_.GetRawString();
  }
  void set_raw_backend_config_string(std::string config_str) {
    backend_config_ = std::move(config_str);
  }

  bool is_default_config() const { return is_default_config_; }
  void set_default_config() { is_default_config_ = true; }
1775 
  // Returns a string representation of a proto in the format used by
  // raw_backend_config_string.
  //
  // This is morally equivalent to:
  //
  //   HloInstruction instr;
  //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
  //   return instr.raw_backend_config_string();
  //
  static StatusOr<std::string> BackendConfigToRawString(
      const tensorflow::protobuf::Message& proto);

  // Returns the information used to tell the implementation information about
  // what sort of precision is requested. The meaning of the field is backend
  // specific. At the moment, it is only supported for kConvolution and kDot.
  // Transformations on one kDot or kConvolution to another will preserve this
  // information. Transformations to other HLOs will not preserve this
  // information but it is presumed that the alternate lowering is strictly
  // superior.
  // Precondition: opcode must be kConvolution or kDot.
  const PrecisionConfig& precision_config() const;
  PrecisionConfig* mutable_precision_config();

  // Sets the debug metadata for this instruction, excluding creation_pass_id,
  // which should never be copied anywhere.
  void set_metadata(const OpMetadata& metadata) {
    // Preserve the existing creation_pass_id across the wholesale assignment.
    int64_t creation_pass_id = metadata_.creation_pass_id();
    metadata_ = metadata;
    metadata_.set_creation_pass_id(creation_pass_id);
  }

  void set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes) {
    metadata_.set_size_of_generated_code_in_bytes(code_size_in_bytes);
  }
  void set_size_of_memory_working_set_in_bytes(
      int64_t working_set_size_in_bytes) {
    metadata_.set_size_of_memory_working_set_in_bytes(
        working_set_size_in_bytes);
  }
  void set_creation_pass_id(int64_t pass_id) {
    metadata_.set_creation_pass_id(pass_id);
  }
  void set_metadata_op_name(const std::string& name) {
    metadata_.set_op_name(name);
  }
  void set_logical_creation_pass_id(int64_t pass_id) {
    metadata_.set_logical_creation_pass_id(pass_id);
  }
  void set_metadata_replaced_op(absl::string_view replaced_op) {
    metadata_.set_replaced_op(std::string(replaced_op));
  }
  const OpMetadata& metadata() const { return metadata_; }
1828 
  // Set/get the computation containing this instruction. set_parent should only
  // be called by HloComputation methods which add/remove instructions to
  // computations.
  void set_parent(HloComputation* computation) { parent_ = computation; }
  const HloComputation* parent() const { return parent_; }
  HloComputation* parent() { return parent_; }

  // Returns the module for this instruction.
  HloModule* GetModule() const;

  // Get/Set the number of partitions per outer dimension (in order, starting
  // with outer-most dimension first). Currently used by the parallel cpu
  // backend to partition HLOs into parallel tasks.
  //
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64_t>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64_t>& outer_dimension_partitions);

  // A method that sorts users_, control_predecessors_, and control_successors_
  // according to the orders used in sorted_instruction. The sorting is used
  // during cloning, to make clone behavior match uncloned behavior.
  void SortInstructionUsersAndControlLists(
      const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
      const HloInstruction& sorted_instruction);
1857 
1858   // Old methods kept for smooth subclassing transition BEGIN.
1859   // TODO(b/80131774): Remove this code.
1860 
1861   // Delegates to HloBatchNormInstruction::feature_index.
1862   int64_t feature_index() const;
1863 
1864   // Delegates to HloBatchNormInstruction::epsilon.
1865   float epsilon() const;
1866 
1867   // Delegates to HloFftInstruction::fft_type.
1868   FftType fft_type() const;
1869 
1870   // Delegates to HloFftInstruction::fft_length.
1871   const std::vector<int64_t>& fft_length() const;
1872 
1873   // Delegates to HloChannelInstruction::channel_id.
1874   std::optional<int64_t> channel_id() const;
1875   void set_channel_id(const std::optional<int64_t>& channel_id);
1876 
  // Returns the dimension sizes or numbers associated with this instruction.
  // The base-class implementation crashes via LOG(FATAL); subclasses that
  // actually carry dimension data override this.
  virtual absl::Span<const int64_t> dimensions() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Returns the dimension value at `index`. Crashes (via dimensions()) on
  // instruction kinds that do not carry dimension data.
  int64_t dimensions(int64_t index) const { return dimensions()[index]; }

  // Mutable access to the dimensions. Like dimensions(), the base-class
  // implementation crashes; subclasses storing dimensions override it.
  virtual std::vector<int64_t>* mutable_dimensions() {
    LOG(FATAL) << "Unimplemented method.";
  }
1887 
1888   // Delegates to HloConcatenateInstruction::concatenate_dimension.
1889   virtual int64_t concatenate_dimension() const;
1890 
1891   // Delegates to HloGetDimensionSizeInstruction::dimension.
1892   int64_t dimension() const;
1893 
1894   // Delegates to HloReshapeInstruction::inferred_dimension.
1895   int64_t inferred_dimension() const;
1896 
1897   // Returns whether this instruction does a rank-2 transposition.
1898   bool IsRank2Transpose() const;
1899 
1900   // Delegates to HloSliceInstruction::slice_start.
1901   int64_t slice_starts(int64_t dimension) const;
1902   const std::vector<int64_t>& slice_starts() const;
1903   std::vector<int64_t>* mutable_slice_starts();
1904 
1905   // Delegates to HloSliceInstruction::slice_limits.
1906   int64_t slice_limits(int64_t dimension) const;
1907   const std::vector<int64_t>& slice_limits() const;
1908   std::vector<int64_t>* mutable_slice_limits();
1909 
1910   // Delegates to HloSliceInstruction::slice_strides.
1911   int64_t slice_strides(int64_t dimension) const;
1912   const std::vector<int64_t>& slice_strides() const;
1913   std::vector<int64_t>* mutable_slice_strides();
1914 
1915   // Returns the literal associated with this instruction.
1916   const Literal& literal() const;
1917 
1918   // Returns whether the instruction is a constant.
1919   bool IsConstant() const;
1920 
1921   // Delegate to HloConstantInstruction::RelayoutConstant.
1922   void RelayoutConstant(const Layout& new_layout,
1923                         const ShapeIndex& shape_index = {});
1924 
1925   // Delegates to
1926   // HloCallableInstruction::AppendInstructionIntoCalledComputation.
1927   HloInstruction* AppendInstructionIntoCalledComputation(
1928       HloInstruction* instruction_to_append, bool add_output = false);
1929 
1930   // Delegates to HloFusionInstruction::AddFusionOperand.
1931   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1932 
1933   // Delegates to HloFusionInstruction::MergeFusionInstruction.
1934   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
1935 
1936   // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
1937   void MergeFusionInstructionIntoMultiOutput(
1938       HloInstruction* instruction_to_merge);
1939 
1940   // Delegates to HloFusionInstruction::FuseInstruction.
1941   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
1942 
1943   // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
1944   HloInstruction* FuseInstructionIntoMultiOutput(
1945       HloInstruction* instruction_to_fuse);
1946 
1947   // Delegates to HloFusionInstruction::fused_instruction.
1948   HloComputation* fused_instructions_computation() const;
1949 
1950   // Delegates to HloFusionInstruction::fused_expression_root.
1951   HloInstruction* fused_expression_root() const;
1952 
1953   // Delegates to HloFusionInstruction::fused_instructions.
1954   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1955       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1956   fused_instructions() const;
1957 
1958   const tensorflow::gtl::iterator_range<
1959       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1960   fused_instructions();
1961 
1962   // Delegates to HloFusionInstruction::fused_instruction_count.
1963   int64_t fused_instruction_count() const;
1964 
1965   // Delegates to HloFusionInstruction::fused_parameter.
1966   HloInstruction* fused_parameter(int64_t parameter_number) const;
1967 
1968   // Delegates to HloFusionInstruction::fused_parameters.
1969   const std::vector<HloInstruction*>& fused_parameters() const;
1970 
1971   // Returns true if this instruction is a fusion instruction that generates
1972   // multiple outputs.
1973   const bool IsMultiOutputFusion() const;
1974 
1975   // Delegates to HloFusionInstruction::fusion_kind.
1976   FusionKind fusion_kind() const;
1977 
1978   // Delegates to HloFusionInstruction::set_fusion_kind.
1979   void set_fusion_kind(FusionKind kind);
1980 
1981   // Delegates to HloRngInstruction::random_distribution.
1982   RandomDistribution random_distribution() const;
1983 
1984   // Delegates to HloParameterInstruction::parameter_number.
1985   int64_t parameter_number() const;
1986 
1987   // Delegates to
1988   // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
1989   void set_parameter_replicated_at_leaf_buffers(
1990       absl::Span<const bool> parameter_replicated_at_leaf_buffers);
1991   void set_parameter_replicated_at_leaf_buffers(
1992       const std::vector<bool>& parameter_replicated_at_leaf_buffers);
1993 
1994   // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
1995   const std::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()
1996       const;
1997 
1998   // Delegates to HloGetTupleElementInstruction::tuple_index.
1999   int64_t tuple_index() const;
2000 
2001   // Delegates to HloGetTupleElementInstruction::set_tuple_index.
2002   void set_tuple_index(int64_t new_tuple_index);
2003 
2004   // Delegates to HloReducePrecisionInstruction::exponent_bits.
2005   int32_t exponent_bits() const;
2006 
2007   // Delegates to HloReducePrecisionInstruction::mantissa_bits.
2008   int32_t mantissa_bits() const;
2009 
2010   // Delegates to HloInfeedInstruction::infeed_config.
2011   std::string infeed_config() const;
2012 
2013   // Delegates to HloInfeedInstruction::set_infeed_config.
2014   void set_infeed_config(const std::string& config);
2015 
2016   // Returns the config for the Outfeed instruction.
2017   const std::string& outfeed_config() const;
2018 
2019   // Delegates to HloOutfeedInstruction::set_outfeed_config.
2020   void set_outfeed_config(const std::string& config);
2021 
2022   // Returns the shape for the Outfeed instruction.
2023   const Shape& outfeed_shape() const;
2024 
2025   // Returns the mutable shape for the Outfeed instruction.
2026   Shape* mutable_outfeed_shape();
2027 
2028   // Delegates to HloCollectiveInstruction::replica_groups.
2029   const std::vector<ReplicaGroup>& replica_groups() const;
2030 
2031   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
2032   const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const;
2033 
  // Returns data on the window in a windowed operation such as
  // convolution. The base-class implementation crashes via LOG(FATAL);
  // subclasses that carry a window override it.
  virtual const Window& window() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Sets the window data in a windowed operation such as convolution.
  // Crashes unless overridden by a subclass that stores a window.
  virtual void set_window(const Window& window) {
    LOG(FATAL) << "Unimplemented method.";
  }
2044 
  // Returns the unique_indices field. The base-class implementation crashes
  // via LOG(FATAL); overridden by subclasses that carry this field.
  virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }
2047 
2048   // Returns data on the dimension numbers used for a convolution operation,
2049   // which may be a kConvolution instruction or a kCustomCall that implements a
2050   // convolution.
2051   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
2052 
2053   // Sets the convolution dimension numbers on this instruction.  In general you
2054   // shouldn't need to call this; instead, specify the convolution dimension
2055   // numbers when you create the instruction.
2056   void set_convolution_dimension_numbers(
2057       const ConvolutionDimensionNumbers& dnums);
2058 
2059   // The number of feature groups. Must be a divisor of the input feature
2060   // dimension and output feature dimension.
2061   int64_t feature_group_count() const;
2062 
2063   void set_feature_group_count(int64_t feature_group_count);
2064 
2065   // The number of batch groups. Must be a divisor of the input batch dimension
2066   int64_t batch_group_count() const;
2067 
2068   void set_batch_group_count(int64_t batch_group_count);
2069 
2070   // Delegates to HloSelectAndScatterInstruction::select.
2071   HloComputation* select() const;
2072 
2073   // Delegates to HloSelectAndScatterInstruction::scatter.
2074   HloComputation* scatter() const;
2075 
2076   // Delegates to HloSelectAndScatterInstruction::set_select.
2077   void set_select(HloComputation* computation);
2078 
2079   // Delegates to HloSelectAndScatterInstruction::set_scatter.
2080   void set_scatter(HloComputation* computation);
2081 
2082   // Delegates to HloCustomCallInstruction::custom_call_target.
2083   const std::string& custom_call_target() const;
2084   void set_custom_call_target(absl::string_view target);
2085 
2086   // Delegates to HloPadInstruction::padding_config.
2087   const PaddingConfig& padding_config() const;
2088   PaddingConfig* mutable_padding_config();
2089 
2090   // Delegates to HloConvolutionInstruction::padding_type.
2091   PaddingType padding_type() const;
2092 
2093   // Delegates to HloDynamicSliceInstruction::slice_sizes.
2094   int64_t slice_sizes(int64_t dimension) const;
2095 
2096   // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
2097   const std::vector<int64_t>& dynamic_slice_sizes() const;
2098 
2099   // Delegates to HloCollectivePermuteInstruction::dynamic_slice_sizes.
2100   const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const;
2101 
2102   // Delegates to HloGatherInstruction::gather_dimension_numbers.
2103   const GatherDimensionNumbers& gather_dimension_numbers() const;
2104   // Delegates to HloGatherInstruction::gather_slice_sizes.
2105   absl::Span<const int64_t> gather_slice_sizes() const;
2106 
2107   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
2108   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
2109 
2110   // Delegates to HloDotInstruction::dot_dimension_numbers().
2111   const DotDimensionNumbers& dot_dimension_numbers() const;
2112 
2113   // Delegates to HloDomainInstruction::operand_side_metadata().
2114   const DomainMetadata& operand_side_metadata() const;
2115 
2116   // Delegates to HloDomainInstruction::user_side_metadata().
2117   const DomainMetadata& user_side_metadata() const;
2118 
2119   // Returns true if the instruction is an async-start, async-update, or
2120   // async-done.
2121   bool IsAsynchronous() const;
2122 
  // Returns the computation that will be executed asynchronously.
2124   HloComputation* async_wrapped_computation() const;
2125 
  // Delegates to HloAsyncInstruction::async_wrapped_instruction().
2127   HloInstruction* async_wrapped_instruction() const;
2128 
  // Delegates to HloAsyncInstruction::async_wrapped_opcode().
2130   HloOpcode async_wrapped_opcode() const;
2131 
2132   // Delegates to HloAsyncInstruction::async_group_id().
2133   std::optional<int64_t> async_group_id() const;
2134 
2135   // Delegates to HloAsyncInstruction::set_async_group_id().
2136   void set_async_group_id(std::optional<int64_t> async_group_id);
2137 
2138   // Delegates to HloAsyncInstruction::async_execution_thread().
2139   absl::string_view async_execution_thread() const;
2140 
2141   // Delegates to HloAsyncInstruction::set_async_execution_thread().
2142   void set_async_execution_thread(absl::string_view async_execution_thread);
2143 
2144   // Delegates to
2145   // HloCallableInstruction::RecursivelySetComputationsThreadName().
2146   void set_called_computations_execution_thread(
2147       absl::string_view async_execution_thread,
2148       bool skip_async_execution_thread_overwrite);
2149 
2150   // Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
2151   bool is_cross_program_prefetch() const;
2152 
2153   // Delegates to HloCompareInstruction::direction().
2154   ComparisonDirection comparison_direction() const;
2155   // Delegates to HloCompareInstruction::order().
2156   ComparisonOrder comparison_order() const;
2157 
2158   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
2159   const TriangularSolveOptions& triangular_solve_options() const;
2160 
2161   // Delegates to HloCholeskyInstruction::cholesky_options().
2162   const CholeskyOptions& cholesky_options() const;
2163 
2164   // Appends operand to the list of operands and adds this instruction as a user
2165   // of the operand.
2166   void AppendOperand(HloInstruction* operand);
2167 
2168   // Old methods kept for smooth subclassing transition END.
2169 
2170  protected:
2171   // Internal constructor for a given opcode/shape, other fields must be filled
2172   // by factory methods.
2173   HloInstruction(HloOpcode opcode, const Shape& shape);
2174 
  // Removes all operands. Note: this only clears this instruction's operand
  // list; it does not update the former operands' user lists.
  void RemoveAllOperands() { operands_.clear(); }
2176 
  // Removes the operand at `index`; subsequent operands shift down one slot.
  // No bounds checking is performed: `index` must be in [0, operand count).
  // Does not update the removed operand's user list.
  void RemoveOperandAt(int index) {
    operands_.erase(operands_.begin() + index);
  }
2180 
2181   // Removes a list of operands with the given indices in ascending order.
2182   void RemoveOperandsAtAscendingIndices(
2183       absl::Span<const int> ascending_indices);
2184 
  // Appends `computation` to the list of computations called by this
  // instruction.
  void AppendComputation(HloComputation* computation) {
    called_computations_.push_back(computation);
  }
2188 
  // Removes this instruction from `usee`'s user list.
  void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
2190 
  // Replaces the called computation stored at `index`. No bounds checking is
  // performed: `index` must be a valid slot in called_computations_.
  void set_called_computation(int index, HloComputation* computation) {
    called_computations_[index] = computation;
  }
  // Indices of computations in called_computations_ for instructions which call
  // multiple computations. Note that slots 0 and 1 are reused across opcodes:
  // which pair of names applies depends on this instruction's opcode.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };
2209 
2210  private:
2211   friend class HloComputation;
  // Wrapper class of string format and protobuf format of BackendConfig.
  // Holds the config as a parsed proto, a raw string, or both; when the proto
  // is set, the string form is derived lazily from it.
  class BackendConfigRep {
   public:
    // Returns the parsed proto form, or null if only the raw string is set.
    const tensorflow::protobuf::Message* GetProtoPtr() const {
      return proto_.get();
    }

    // Returns the string form; when proto_ is set, raw_string_ acts as a lazy
    // cache of its serialization (see member comment below).
    const std::string& GetRawString() const;

    BackendConfigRep Clone() const;

    bool operator==(const BackendConfigRep& other) const;
    bool operator!=(const BackendConfigRep& other) const {
      return !(*this == other);
    }

    // True iff neither a proto nor a raw string has been set.
    bool empty() const { return proto_ == nullptr && raw_string_.empty(); }

    // Resets to the empty state, dropping both representations.
    void clear() {
      proto_.reset();
      raw_string_.clear();
    }

    BackendConfigRep& operator=(std::string raw_string);
    BackendConfigRep& operator=(const tensorflow::protobuf::Message& proto);
    void SetProto(const tensorflow::protobuf::Message& proto);

   private:
    std::unique_ptr<tensorflow::protobuf::Message> proto_;
    // If proto_ is not null, raw_string_ is a lazy cache of its string format.
    mutable std::string raw_string_;
  };
2244 
2245   bool IdenticalInternal(
2246       const HloInstruction& other,
2247       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2248           eq_operands,
2249       const std::function<bool(const HloComputation*, const HloComputation*)>&
2250           eq_computations,
2251       bool layout_sensitive, bool ignore_channel_id_values,
2252       bool ignore_commutative_operand_order) const;
2253 
  // Implementation for non-common logic of CloneWithNewOperands. The
  // base-class implementation crashes via LOG(FATAL); concrete subclasses
  // supply the opcode-specific clone logic.
  virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const {
    // TODO(b/80131774): This should be pure virtual.
    LOG(FATAL) << "Unimplemented method.";
  }
2261 
  // Implementation for non-common logic of ExtraAttributesToString. The
  // base-class default contributes no extra attributes; subclasses with
  // opcode-specific attributes override it.
  virtual std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const {
    return {};
  }
2267 
2268   // Implementation for IsElementwise if operand_idx is nullopt and for
2269   // IsElementwiseOnOperand if otherwise.
2270   //
2271   // NOTE: For all instructions other than kFusion, being elementwise on one of
2272   // the operands is equivalent to being elementwise on all the operands.
2273   virtual bool IsElementwiseImpl(
2274       const std::optional<int64_t>& operand_idx) const;
2275 
2276   // Prints an operand to a string. Accessed by friend class HloInstruction.
2277   virtual std::string OperandsToStringWithCanonicalNameMap(
2278       const HloPrintOptions& options,
2279       CanonicalNameMap* canonical_name_map) const;
2280 
2281   // See comments on Identical().
2282   virtual bool IdenticalSlowPath(
2283       const HloInstruction& other,
2284       const std::function<bool(const HloComputation*, const HloComputation*)>&
2285           eq_computations) const;
2286 
2287   // Creates an n-ary elementwise operation.
2288   static std::unique_ptr<HloInstruction> CreateNary(
2289       const Shape& shape, HloOpcode opcode,
2290       absl::Span<HloInstruction* const> operands);
2291 
2292   // Adds a user for this instruction.
2293   void AddUser(HloInstruction* user);
2294 
2295   // Removes a user for this instruction.
2296   void RemoveUser(HloInstruction* user);
2297 
2298   // Helper for implementing backend_config().  Parses backend_config_ into the
2299   // given proto.
2300   Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
2301 
  // Mark this instruction as dead. Accessed by friend class HloComputation
  // (the declared friend above; the flag is described at marked_as_dead_).
  void MarkAsDead() { marked_as_dead_ = true; }
2304 
  // Has this instruction been marked as dead? Accessed by friend class
  // HloComputation (the declared friend above).
  bool IsMarkedAsDead() const { return marked_as_dead_; }
2308 
2309   int unique_id_;  // Unique to this HloInstruction within a HloModule
2310 
2311   // Opcode for this instruction.
2312   HloOpcode opcode_;
2313 
2314   // Instruction operands.
2315   InstructionVector operands_;
2316 
2317   // The set of control predecessors of this instruction.
2318   // Note that the order of the instructions in the vector influences the order
2319   // computed in HloComputation::ComputeInstructionPostOrder, which may
2320   // influence the result of the compilation by changing the scheduling. We are
2321   // not sure if it matters.
2322   std::vector<HloInstruction*> control_predecessors_;
2323 
2324   // The users of this instruction. Users are HLOs where this instruction is an
2325   // operand. The vector users_ and the map user_map_ contain identical members.
2326   // The map enables fast membership testing and the vector enables fast, stable
2327   // iteration. The value in the map contains the index of the instruction in
  // the vector, which enables fast removal.
2329   std::vector<HloInstruction*> users_;
2330   absl::flat_hash_map<const HloInstruction*, int64_t> user_map_;
2331 
2332   // The set of control successors of this instruction.
2333   std::vector<HloInstruction*> control_successors_;
2334 
2335   // The computation in which this instruction is contained.
2336   HloComputation* parent_ = nullptr;
2337 
2338   // Result shape of this instruction.
2339   Shape shape_;
2340 
2341   // The sharding, if one exists.
2342   // Uses std::shared_ptr to allow reuse of the same sharding object between
2343   // HloInstructions and other components as HloSharding can be very large for
2344   // many element tuples.
2345   std::shared_ptr<const HloSharding> sharding_;
2346 
2347   // Computations called by this instruction.
2348   std::vector<HloComputation*> called_computations_;
2349 
2350   // A trace instruction that consumes this instruction.
2351   //
2352   // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
2353   // an operand.
2354   HloInstruction* trace_instruction_ = nullptr;
2355 
2356   // The backend-specific configuration for how a backend should compile this
2357   // HLO. See the documentation on backend_config().
2358   mutable BackendConfigRep backend_config_;
2359 
2360   // Attributes passed from the frontend to give hints to the backend about
2361   // how to compile this HLO.
2362   // HLO -> HLO transforms are expected to preserve these attributes on a
2363   // "best effort" basis only.
2364   // For example:
2365   //    x = const(10, frontend_attributes={x}
2366   //    y = const(10, frontend_attributes={y}
2367   //    z = add(x,y), frontend_attributes={y}
2368   // Could be simplified to:
2369   //    z' = const(20), frontend_attributes={?}
2370   FrontendAttributes frontend_attributes_;
2371 
2372   // This field is assigned to true when backend_config_ is assigned to
2373   // a default configuration.
2374   bool is_default_config_ = false;
2375 
2376   // True if this instruction has already been detached from its user and
2377   // operands.
2378   bool cleaned_up_ = false;
2379 
2380   // String identifier for instruction.
2381   std::string name_;
2382 
2383   // Metadata for debugging.
2384   OpMetadata metadata_;
2385 
2386   // The number of partitions per outer dimension (listed in order from
2387   // outer-most dimension first).
2388   std::vector<int64_t> outer_dimension_partitions_;
2389 
2390   // Intrusive flag used by HloComputation, whether this instruction has
2391   // been marked as dead.
2392   bool marked_as_dead_;
2393 
2394   HloInstruction(const HloInstruction&) = delete;
2395   HloInstruction& operator=(const HloInstruction&) = delete;
2396 };
2397 
2398 // Explicit instantiations in hlo_instruction.cc.
2399 extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
2400 extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2401 extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
2402 extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2403 
2404 std::string ToString(HloInstruction::FusionKind kind);
2405 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
2406     const std::string& kind_name);
2407 
2408 // Custom (de)stringification functions for protos that live inside
2409 // HloInstruction.
2410 std::string PaddingConfigToString(const PaddingConfig& padding);
2411 std::string FrontendAttributesToString(
2412     const FrontendAttributes& frontend_attributes);
2413 std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
2414 std::string RandomDistributionToString(const RandomDistribution& distribution);
2415 std::string PrecisionToString(const PrecisionConfig::Precision& precision);
2416 std::string DotDimensionNumbersToString(const DotDimensionNumbers& dnums);
2417 std::string ConvolutionDimensionNumbersToString(
2418     const ConvolutionDimensionNumbers& dnums);
2419 std::string ReplicaGroupsToString(
2420     absl::Span<const ReplicaGroup> replica_groups);
2421 
2422 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const std::string& name);
2423 StatusOr<RandomDistribution> StringToRandomDistribution(
2424     const std::string& name);
2425 StatusOr<PrecisionConfig::Precision> StringToPrecision(const std::string& name);
2426 StatusOr<CustomCallSchedule> StringToCustomCallSchedule(absl::string_view name);
2427 StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
2428     absl::string_view name);
2429 
2430 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
2431 
2432 // Map classes that guarantee a deterministic iteration order when the key is
2433 // an HloInstruction* or a const HloInstruction*.
2434 // To make the iteration order over the map deterministic, the comparator
2435 // should not be using the pointer values, but rather an intrinsic property of
2436 // the hlo. Exception: null pointer values compare less than non-null.
struct HloPtrComparator {
  // Defined out of line. Per the comment above, the ordering is deterministic
  // and not based on pointer values (null compares less than non-null).
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};
2441 
2442 template <typename ValueT>
2443 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
2444 
2445 template <typename ValueT>
2446 using ConstHloInstructionMap =
2447     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
2448 
2449 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
2450 using ConstHloInstructionSet =
2451     std::set<const HloInstruction*, HloPtrComparator>;
2452 
2453 }  // namespace xla
2454 
2455 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
2456