1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // HLO instructions are in DAG form and represent the computations that the user 17 // has built up via the XLA service interface. They are ultimately lowered 18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered 19 // form; e.g. see DfsHloVisitor. 20 21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 23 24 #include <functional> 25 #include <iosfwd> 26 #include <list> 27 #include <memory> 28 #include <optional> 29 #include <ostream> 30 #include <set> 31 #include <string> 32 #include <tuple> 33 #include <utility> 34 #include <vector> 35 36 #include "absl/container/flat_hash_map.h" 37 #include "absl/container/flat_hash_set.h" 38 #include "absl/container/inlined_vector.h" 39 #include "absl/hash/hash.h" 40 #include "absl/strings/str_cat.h" 41 #include "absl/strings/string_view.h" 42 #include "absl/types/span.h" 43 #include "tensorflow/compiler/xla/comparison_util.h" 44 #include "tensorflow/compiler/xla/iterator_util.h" 45 #include "tensorflow/compiler/xla/literal.h" 46 #include "tensorflow/compiler/xla/map_util.h" 47 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 48 #include "tensorflow/compiler/xla/service/hlo.pb.h" 49 #include 
"tensorflow/compiler/xla/service/hlo_clone_context.h" 50 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" 51 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 52 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 53 #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h" 54 #include "tensorflow/compiler/xla/service/name_uniquer.h" 55 #include "tensorflow/compiler/xla/shape_tree.h" 56 #include "tensorflow/compiler/xla/types.h" 57 #include "tensorflow/compiler/xla/xla_data.pb.h" 58 #include "tensorflow/core/lib/core/status.h" 59 #include "tensorflow/core/lib/gtl/iterator_range.h" 60 #include "tensorflow/core/platform/logging.h" 61 #include "tensorflow/core/platform/protobuf.h" 62 63 namespace xla { 64 65 class HloComputation; 66 class HloModule; 67 68 std::string PrintName(const std::string& name, bool print_ids); 69 70 // A bunch of switches that control how the hlo text should be printed. 71 class HloPrintOptions { 72 public: 73 enum class PrintSubcomputationMode { 74 kOff, // Do not print anything about subcomputations. 75 kNameOnly, // Only print the name of subcomputations. 76 kFullBodies, // Print the full bodies of subcomputations. 77 kNonSequentialBodies, // Print the full bodies of subcomputations that are 78 // not in a sequential context. 79 }; 80 81 // Constructs the default print options: don't print large constants, don't 82 // compact operands, no indentation. 
HloPrintOptions()83 HloPrintOptions() 84 : print_large_constants_(false), 85 print_only_essential_constants_(false), 86 print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), 87 print_metadata_(true), 88 print_backend_config_(true), 89 print_infeed_outfeed_config_(true), 90 compact_operands_(false), 91 include_layout_in_shapes_(true), 92 print_result_shape_(true), 93 print_operand_shape_(true), 94 print_operand_names_(true), 95 print_operand_index_annotation_interval_(5), 96 print_program_shape_(true), 97 print_percent_(true), 98 print_control_dependencies_(true), 99 canonicalize_instruction_names_(false), 100 indent_amount_(0), 101 is_in_nested_computation_(false), 102 print_ids_(true), 103 canonicalize_computations_(false), 104 print_extra_attributes_(true), 105 syntax_sugar_async_ops_(true) {} 106 ShortParsable()107 static HloPrintOptions ShortParsable() { 108 return HloPrintOptions() 109 .set_print_large_constants(true) 110 .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) 111 .set_print_metadata(false) 112 .set_print_backend_config(false) 113 .set_print_operand_shape(false) 114 .set_print_operand_index_annotation_interval(0) 115 .set_print_program_shape(false) 116 .set_print_percent(false) 117 .set_print_control_dependencies(false); 118 } 119 120 // Options to produce the canonical string representing an isomorphic 121 // computation graph. Canonical()122 static HloPrintOptions Canonical() { 123 return HloPrintOptions() 124 .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) 125 .set_print_metadata(false) 126 .set_print_backend_config(false) 127 // Compact option won't print name of operand_k, where k > a threshold. 128 // Canonical (and Fingerprint) must include all operands. 129 .set_compact_operands(false) 130 .set_print_operand_names(false) 131 .set_print_operand_shape(true) 132 // No index annotations as they are only for ease of human inspection. 
133 .set_print_operand_index_annotation_interval(0) 134 .set_print_program_shape(false) 135 .set_print_percent(false) 136 .set_print_control_dependencies(false) 137 .set_canonicalize_instruction_names(true); 138 } 139 140 // Options to produce a fingerprint of an HLO instruction. 141 // Based on Canonical() with some important changes commented below. Fingerprint()142 static HloPrintOptions Fingerprint() { 143 return Canonical() 144 // Exclude because they do not affect HLO optimizations. 145 .set_print_infeed_outfeed_config(false) 146 // Exclude floating point constant literals that are not all zeros, all 147 // ones, or integers because they may be randomly initialized weights, 148 // which may be changed across different runs. 149 .set_print_only_essential_constants(true) 150 // Remove "id" in "name.id" (after period) because it can be 151 // non-deterministic. This mainly affects computations' names because 152 // canonicalized instructions' names are in "tmp_id" format. 153 .set_print_ids(false) 154 // Sort computations. 155 .set_canonicalize_computations(true); 156 } 157 158 // Options to produce a fingerprint of an HLO module and computation. 159 // Shorter (and therefore faster) than Fingerprint(). ModuleFingerprint()160 static HloPrintOptions ModuleFingerprint() { 161 return Fingerprint() 162 // Operand shapes can be inferred from output shapes and canonicalized 163 // names when we have an entire computation. 164 .set_print_operand_shape(false); 165 } 166 167 // If true, large constants will be printed out. set_print_large_constants(bool value)168 HloPrintOptions& set_print_large_constants(bool value) { 169 print_large_constants_ = value; 170 return *this; 171 } 172 173 // If true, only integer, all-zero, are all-one constants will be printed out. 
set_print_only_essential_constants(bool value)174 HloPrintOptions& set_print_only_essential_constants(bool value) { 175 print_only_essential_constants_ = value; 176 return *this; 177 } 178 set_print_subcomputation_mode(PrintSubcomputationMode value)179 HloPrintOptions& set_print_subcomputation_mode( 180 PrintSubcomputationMode value) { 181 print_subcomputation_mode_ = value; 182 return *this; 183 } 184 185 // If true, metadata will be printed. set_print_metadata(bool value)186 HloPrintOptions& set_print_metadata(bool value) { 187 print_metadata_ = value; 188 return *this; 189 } 190 191 // If true, backend_config will be printed. set_print_backend_config(bool value)192 HloPrintOptions& set_print_backend_config(bool value) { 193 print_backend_config_ = value; 194 return *this; 195 } 196 197 // If true, infeed_config and outfeed_config will be printed. set_print_infeed_outfeed_config(bool value)198 HloPrintOptions& set_print_infeed_outfeed_config(bool value) { 199 print_infeed_outfeed_config_ = value; 200 return *this; 201 } 202 203 // If true, result shapes will be printed. set_print_result_shape(bool value)204 HloPrintOptions& set_print_result_shape(bool value) { 205 print_result_shape_ = value; 206 return *this; 207 } 208 209 // If true, operands' shapes will be printed. set_print_operand_shape(bool value)210 HloPrintOptions& set_print_operand_shape(bool value) { 211 print_operand_shape_ = value; 212 return *this; 213 } 214 215 // If true, operands' shapes will be printed. set_print_operand_index_annotation_interval(int64_t value)216 HloPrintOptions& set_print_operand_index_annotation_interval(int64_t value) { 217 print_operand_index_annotation_interval_ = value; 218 return *this; 219 } 220 221 // If true, the operand names will be printed. set_print_operand_names(bool value)222 HloPrintOptions& set_print_operand_names(bool value) { 223 print_operand_names_ = value; 224 return *this; 225 } 226 227 // If true, all printed names include unique identifiers. 
set_print_ids(bool value)228 HloPrintOptions& set_print_ids(bool value) { 229 print_ids_ = value; 230 return *this; 231 } 232 233 // If true, the HLO includes its attributes. set_print_extra_attributes(bool value)234 HloPrintOptions& set_print_extra_attributes(bool value) { 235 print_extra_attributes_ = value; 236 return *this; 237 } 238 239 // If true, program shape of hlo computations will be printed. set_print_program_shape(bool value)240 HloPrintOptions& set_print_program_shape(bool value) { 241 print_program_shape_ = value; 242 return *this; 243 } 244 245 // If true, names will be printed with prefix '%'. set_print_percent(bool value)246 HloPrintOptions& set_print_percent(bool value) { 247 print_percent_ = value; 248 return *this; 249 } 250 251 // If true, control dependencies will be printed. set_print_control_dependencies(bool value)252 HloPrintOptions& set_print_control_dependencies(bool value) { 253 print_control_dependencies_ = value; 254 return *this; 255 } 256 257 // If true, uses the async operation syntax sugar to print async-start, 258 // async-update, and async-done HLOs. If the syntax sugar is enabled, the 259 // computations called by these instructions will not be printed and instead 260 // the root of the called computation will be printed instead of these 261 // instructions and -start, -update, and -done suffixes will be appended to 262 // the opcode of the async operation. 
For example, for an HLO module where the 263 // syntax sugar is off: 264 // 265 // HloModule Module 266 // 267 // %AsyncOp (p0.1: f32[10]) -> f32[20] { 268 // %p0.1 = f32[10]{0} parameter(0) 269 // ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %p0.1), 270 // custom_call_target="foo" 271 // } 272 // 273 // ENTRY %Entry (p0: f32[10]) -> f32[20] { 274 // %p0 = f32[10]{0} parameter(0) 275 // %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(%p0), 276 // calls=%AsyncOp 277 // %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update( 278 // %async-start), 279 // calls=%AsyncOp 280 // ROOT %async-done = f32[20]{0} async-done(%async-update), calls=%AsyncOp 281 // } 282 // 283 // will be printed as following if the syntax sugar is enabled: 284 // 285 // HloModule Module 286 // 287 // ENTRY %Entry (p0: f32[10]) -> f32[20] { 288 // %p0 = f32[10]{0} parameter(0) 289 // %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(%p0), 290 // custom_call_target="foo" 291 // %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update( 292 // %async-start), 293 // custom_call_target="foo" 294 // ROOT %async-done = f32[20]{0} custom-call-done(%async-update), 295 // custom_call_target="foo" 296 // } set_syntax_sugar_async_ops(bool value)297 HloPrintOptions& set_syntax_sugar_async_ops(bool value) { 298 syntax_sugar_async_ops_ = value; 299 return *this; 300 } 301 302 // If true, only a part of operands will be printed out (note that in this 303 // case the text will not be parsable). set_compact_operands(bool value)304 HloPrintOptions& set_compact_operands(bool value) { 305 compact_operands_ = value; 306 return *this; 307 } 308 309 // If true, include the layout in any shapes that are printed (instruction 310 // and operands). set_include_layout_in_shapes(bool value)311 HloPrintOptions& set_include_layout_in_shapes(bool value) { 312 include_layout_in_shapes_ = value; 313 return *this; 314 } 315 316 // If true, canonicalizes instructions' name. 
Instead of using "%foo.1" as 317 // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. set_canonicalize_instruction_names(bool value)318 HloPrintOptions& set_canonicalize_instruction_names(bool value) { 319 canonicalize_instruction_names_ = value; 320 return *this; 321 } 322 323 // If true, canonicalizes computations, sorting by computations' names. set_canonicalize_computations(bool value)324 HloPrintOptions& set_canonicalize_computations(bool value) { 325 canonicalize_computations_ = value; 326 return *this; 327 } 328 329 // The indent of the hlo text block. set_indent_amount(int value)330 HloPrintOptions& set_indent_amount(int value) { 331 indent_amount_ = value; 332 return *this; 333 } 334 335 // If true, indicates the instruction being printed is inside a nested 336 // computation. set_is_in_nested_computation(bool value)337 HloPrintOptions& set_is_in_nested_computation(bool value) { 338 is_in_nested_computation_ = value; 339 return *this; 340 } 341 print_large_constants()342 bool print_large_constants() const { return print_large_constants_; } print_only_essential_constants()343 bool print_only_essential_constants() const { 344 return print_only_essential_constants_; 345 } print_subcomputation_mode()346 PrintSubcomputationMode print_subcomputation_mode() const { 347 return print_subcomputation_mode_; 348 } print_metadata()349 bool print_metadata() const { return print_metadata_; } print_backend_config()350 bool print_backend_config() const { return print_backend_config_; } print_infeed_outfeed_config()351 bool print_infeed_outfeed_config() const { 352 return print_infeed_outfeed_config_; 353 } compact_operands()354 bool compact_operands() const { return compact_operands_; } include_layout_in_shapes()355 bool include_layout_in_shapes() const { return include_layout_in_shapes_; } print_result_shape()356 bool print_result_shape() const { return print_result_shape_; } print_operand_shape()357 bool print_operand_shape() const { return print_operand_shape_; 
} print_operand_names()358 bool print_operand_names() const { return print_operand_names_; } print_operand_index_annotation_interval()359 int64_t print_operand_index_annotation_interval() const { 360 return print_operand_index_annotation_interval_; 361 } print_ids()362 bool print_ids() const { return print_ids_; } print_program_shape()363 bool print_program_shape() const { return print_program_shape_; } print_percent()364 bool print_percent() const { return print_percent_; } print_control_dependencies()365 bool print_control_dependencies() const { 366 return print_control_dependencies_; 367 } print_extra_attributes()368 bool print_extra_attributes() const { return print_extra_attributes_; } syntax_sugar_async_ops()369 bool syntax_sugar_async_ops() const { return syntax_sugar_async_ops_; } canonicalize_instruction_names()370 bool canonicalize_instruction_names() const { 371 return canonicalize_instruction_names_; 372 } canonicalize_computations()373 bool canonicalize_computations() const { return canonicalize_computations_; } indent_amount()374 int indent_amount() const { return indent_amount_; } is_in_nested_computation()375 int is_in_nested_computation() const { return is_in_nested_computation_; } 376 377 private: 378 bool print_large_constants_; 379 bool print_only_essential_constants_; 380 PrintSubcomputationMode print_subcomputation_mode_; 381 bool print_metadata_; 382 bool print_backend_config_; 383 bool print_infeed_outfeed_config_; 384 bool compact_operands_; 385 bool include_layout_in_shapes_; 386 bool print_result_shape_; 387 bool print_operand_shape_; 388 bool print_operand_names_; 389 // The interval between the /*index=*/ annotated operands. 0 means never print 390 // the annotation, 1 means print annotation for every operand. 
391 int64_t print_operand_index_annotation_interval_; 392 bool print_program_shape_; 393 bool print_percent_; 394 bool print_control_dependencies_; 395 bool canonicalize_instruction_names_; 396 int indent_amount_; 397 bool is_in_nested_computation_; 398 bool print_ids_; 399 bool canonicalize_computations_; 400 bool print_extra_attributes_; 401 bool syntax_sugar_async_ops_; 402 }; 403 404 // For canonical string output, we need to have a canonical way to rename 405 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>", 406 // where <xxx> is an index starting from 0. 407 class CanonicalNameMap { 408 public: LookupOrInsert(const std::string & name)409 const std::string& LookupOrInsert(const std::string& name) { 410 std::string& canonical_name = canonical_name_map_[name]; 411 if (canonical_name.empty()) { 412 absl::StrAppend(&canonical_name, "tmp_", canonical_name_map_.size() - 1); 413 } 414 return canonical_name; 415 } 416 417 private: 418 absl::flat_hash_map<std::string, std::string> canonical_name_map_; 419 }; 420 421 // HLO instructions are the atomic unit of the high-level compiler's IR. 422 // 423 // HloInstructions live inside of an HloComputation, which is analogous to a 424 // function in other programming languages. Nodes have no total order within 425 // their computation. Instead, they have a partial ordering determined by their 426 // data and control dependencies. 427 // 428 // HLO does not have basic blocks or explicit "branch" instructions. Instead, 429 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode 430 // control flow. For example, the kConditional HLO executes one of two possible 431 // computations, depending on the runtime value of a predicate. 432 // 433 // HLO is pure (mostly). It has no concept of mutable state. Instead, data 434 // values are produced by one HLO and flow into consumers across dependency 435 // edges. 
436 class HloInstruction { 437 public: 438 // A fusion node computes the same value a call to its fusion computation 439 // would compute. However, the choice of fusion kind dictates codegen 440 // strategy for the backend. 441 // 442 // To generate code for a kFusion HloInstruction, most backends do something 443 // like the following: 444 // 445 // 1) Identify the "primary" HloInstruction of the fused computation. 446 // 2) Emit code that does the work of the primary node, creating its inputs 447 // and transforming its outputs as specified by the fused computation. 448 // 449 // In step (2), the code emitted is usually similar to the code that would be 450 // emitted for an *unfused* version of the primary node, except that 451 // 452 // - when the primary node reads an element of one of its operands, instead 453 // of loading the value from memory, it *computes* the value based on the 454 // contents of the fused computation. 455 // - when the primary node outputs a value, instead of storing it to memory, 456 // it forwards the value to its users, which then perform additional 457 // computations before the value is finally stored to memory at the root of 458 // the fusion node. 459 // 460 // An HloInstruction's FusionKind helps us find the kFusion instruction's 461 // primary node, and can also affect how we generate code in step (2). 462 // 463 // - kInput: The primary node is the root of the fused instruction. 464 // 465 // - kOutput: The primary node is not the root of the fused instruction. 466 // This fusion kind requires that one operand buffer of the fusion 467 // instruction be able to alias the output buffer. This constraint is 468 // usually enough to let backends find the primary node unambiguously. 469 // 470 // - kLoop: The primary node is the root of the fused computation, but, 471 // unlike in input fusion, we prescribe a specific implementation for 472 // codegen. 
Rather than generating code that looks like the code we'd emit 473 // for an unfused version of the primary/root node, we emit code that 474 // generates one element of the root at a time. 475 // 476 // - kCustom: Custom category for backend-specific fusions that don't fit 477 // into the above patterns. 478 // 479 // Not all backends support all fusion kinds, and given a particular fused 480 // computation, it's not in general safe to change its fusion kind. Creation 481 // of fusion nodes is always backend-specific. 482 // 483 // For elementwise ops (e.g. kAdd), most backends would emit a 484 // one-element-at-a-time implementation for the unfused version, so loop 485 // fusion and input fusion are probably equivalent if the root node is 486 // elementwise. They're not necessarily equivalent e.g. for kReduce, where an 487 // implementation might emit something more sophisticated for an unfused or 488 // input-fusion reduce, but will emit the naive code that reduces one element 489 // at a time for loop fusion with a reduce as the root. 490 // 491 // Another way to think of loop fusion is that it's equivalent to input 492 // fusion, but where the root node is an implicit identity node, whose 493 // unfused implementation is "read one element, write one element". 494 // 495 // TODO(b/79869434): This categorization scheme is not great. For one thing, 496 // input and loop fusion are basically the same thing: There is no reason for 497 // the HLO to encode backend-specific decisions about how e.g. a reduce that's 498 // the root of a fusion should be lowered. In addition, this scheme as 499 // written doesn't work for multi-output fusion, where the primary node is 500 // never actually the root (which is a kTuple instruction that gathers the 501 // multiple outputs of the fusion). 
502 enum class FusionKind { 503 kLoop, 504 kInput, 505 kOutput, 506 kCustom, 507 }; 508 509 inline static constexpr char kMainExecutionThread[] = "main"; 510 ~HloInstruction()511 virtual ~HloInstruction() { DetachFromOperandsAndUsers(); } 512 513 // Detaches an instruction from its operands and users. That is, remove the 514 // instruction from each operand's user set and user's operand set. 515 void DetachFromOperandsAndUsers(); 516 517 // Adds a derived instruction to the parent computation of this instruction. 518 // Also sets up the new instruction as a derived instruction. 519 HloInstruction* AddInstruction( 520 std::unique_ptr<HloInstruction> derived_instruction); 521 522 // Creates an instruction from the given proto. Arguments: 523 // 524 // proto: the proto to convert from. 525 // instruction_map: a map from instruction id to HloInstruction*. This map 526 // must contain all operands of the newly constructed instruction. 527 // computation_map: a map from computation id to HloComputation*. This map 528 // must contain all computations which the newly constructed instruction 529 // calls. 530 static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto( 531 const HloInstructionProto& proto, 532 const absl::flat_hash_map<int64_t, HloInstruction*>& instruction_map, 533 const absl::flat_hash_map<int64_t, HloComputation*>& computation_map = {}, 534 bool prohibit_empty_literal = true); 535 536 // Creates a parameter-retrieving instruction. 537 static std::unique_ptr<HloInstruction> CreateParameter( 538 int64_t parameter_number, const Shape& shape, const std::string& name); 539 540 // Creates a literal constant instruction. 541 static std::unique_ptr<HloInstruction> CreateConstant(Literal literal); 542 543 // Creates an Iota instruction. 544 static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape, 545 int64_t iota_dimension); 546 547 // Creates a get tuple element instruction.
548 static std::unique_ptr<HloInstruction> CreateGetTupleElement( 549 const Shape& shape, HloInstruction* operand, int64_t index); 550 551 // Creates a get tuple element instruction. 552 static std::unique_ptr<HloInstruction> CreateGetTupleElement( 553 HloInstruction* operand, int64_t index); 554 555 // Creates a random number generation instruction that fills a shape with 556 // random numbers from a given distribution. 557 // 558 // The parameters to the instruction are interpreted as follows: 559 // 560 // - If `distribution` is RNG_UNIFORM, generates a number in range 561 // [param0, param1). 562 // 563 // - If `distribution` is RNG_NORMAL, generates a normally-distributed value 564 // with mean `param0` and standard deviation `param1`. 565 static std::unique_ptr<HloInstruction> CreateRng( 566 const Shape& shape, RandomDistribution distribution, 567 absl::Span<HloInstruction* const> parameters); 568 569 // Creates a stateless random bit generator instruction that fills a shape 570 // with random bits. 571 static std::unique_ptr<HloInstruction> CreateRngBitGenerator( 572 const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm); 573 574 // Creates an instruction to update the random number generator state to 575 // reflect the new state after `delta` units of 32 random bits are generated 576 // and returns the old state. 577 static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState( 578 const Shape& shape, int64_t delta); 579 580 // Creates a unary instruction (one operand). 581 // Precondition: opcode must be a legitimate unary operation. 582 static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape, 583 HloOpcode opcode, 584 HloInstruction* operand); 585 586 // Creates a binary instruction (two operands). 587 // Precondition: opcode must be a legitimate binary operation. 
588 static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape, 589 HloOpcode opcode, 590 HloInstruction* lhs, 591 HloInstruction* rhs); 592 593 // Creates a ternary instruction (three operands). 594 // Precondition: opcode must be a legitimate ternary operation. 595 static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape, 596 HloOpcode opcode, 597 HloInstruction* lhs, 598 HloInstruction* rhs, 599 HloInstruction* ehs); 600 601 // Creates a variadic instruction (variable number of operands). 602 // Precondition: opcode must be a legitimate variadic operation. 603 static std::unique_ptr<HloInstruction> CreateVariadic( 604 const Shape& shape, HloOpcode opcode, 605 absl::Span<HloInstruction* const> operands); 606 607 // Creates a map instruction, where the computation (given by the handle) is 608 // applied element-wise to every element in operands (across the operands, 609 // at a given index) 610 static std::unique_ptr<HloInstruction> CreateMap( 611 const Shape& shape, absl::Span<HloInstruction* const> operands, 612 HloComputation* map_computation); 613 614 // Creates a convolution op, where rhs is the convolutional filter 615 // and window describes how the filter is applied to lhs. 616 static std::unique_ptr<HloInstruction> CreateConvolve( 617 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 618 int64_t feature_group_count, int64_t batch_group_count, 619 const Window& window, 620 const ConvolutionDimensionNumbers& dimension_numbers, 621 const PrecisionConfig& precision_config); 622 623 // Creates an FFT op, of the type indicated by fft_type. 624 static std::unique_ptr<HloInstruction> CreateFft( 625 const Shape& shape, HloInstruction* operand, FftType fft_type, 626 absl::Span<const int64_t> fft_length); 627 628 // Creates an asynchronous start, update, and done op. 
629 static std::unique_ptr<HloInstruction> CreateAsyncStart( 630 const Shape& shape, absl::Span<HloInstruction* const> operands, 631 HloComputation* async_computation, 632 std::optional<int64_t> async_group_id = std::nullopt, 633 absl::string_view async_execution_thread = kMainExecutionThread); 634 static std::unique_ptr<HloInstruction> CreateAsyncUpdate( 635 const Shape& shape, HloInstruction* operand, 636 HloComputation* async_computation, 637 std::optional<int64_t> async_group_id = std::nullopt, 638 absl::string_view async_execution_thread = kMainExecutionThread); 639 static std::unique_ptr<HloInstruction> CreateAsyncDone( 640 const Shape& shape, HloInstruction* operand, 641 HloComputation* async_computation, 642 std::optional<int64_t> async_group_id = std::nullopt, 643 absl::string_view async_execution_thread = kMainExecutionThread); 644 645 // Creates a copy-start op, indicating whether this is a cross-program 646 // prefetch or not. 647 static std::unique_ptr<HloInstruction> CreateCopyStart( 648 const Shape& shape, HloInstruction* operand, 649 bool is_cross_program_prefetch = false); 650 651 // Creates a compare op, performing the comparison specified in direction. 652 static std::unique_ptr<HloInstruction> CreateCompare( 653 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 654 Comparison::Direction direction, 655 std::optional<Comparison::Type> type = std::nullopt); 656 657 static std::unique_ptr<HloInstruction> CreateTriangularSolve( 658 const Shape& shape, HloInstruction* a, HloInstruction* b, 659 const TriangularSolveOptions& options); 660 661 static std::unique_ptr<HloInstruction> CreateCholesky( 662 const Shape& shape, HloInstruction* a, const CholeskyOptions& options); 663 664 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch 665 // dimensions specified in 'dimension_numbers'. 
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates an all-gather op, which concats the operands of all participants
  // along all_gather_dimension. The replica_groups, channel_id, and
  // use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants.
  static std::unique_ptr<HloInstruction> CreateAllGather(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // Creates an all-gather-start op, which concats the operands of all
  // participants along all_gather_dimension. The replica_groups, channel_id,
  // and use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants. Needs to be used in
  // conjunction with an AllGatherDone op that synchronizes and returns the
  // result.
  static std::unique_ptr<HloInstruction> CreateAllGatherStart(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // Creates a cross replica reduction op.
  //
  // `reduce_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica id. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // Creates a reduce-scatter operation which reduces its inputs across the
  // given replica groups and then scatters the reduced data across the N
  // participants.
  static std::unique_ptr<HloInstruction> CreateReduceScatter(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids,
      int64_t scatter_dimension);

  // Creates an asynchronous cross replica reduction op.
  //
  // `reduce_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica id. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduceStart(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // An all-to-all op takes N array operands of the same shape and scatters
  // them to N replicas. Each replica gathers the results into a tuple.
  //
  // For example, suppose we have 3 replicas, with replica i passing inputs
  // [a_i, b_i, c_i] to its all-to-all op. Then the resulting tuples are
  //
  //   replica 0: (a_0, a_1, a_2)
  //   replica 1: (b_0, b_1, b_2)
  //   replica 2: (c_0, c_1, c_2).
  //
  // If replica_groups is set, the op is sharded and the replicas are
  // permuted. To explain by way of example, suppose we have
  // replica_groups={{1,2},{3,0}}. Then each replica passes two operands, say
  // [a_i, b_i], and the result is
  //
  //   replica 0: (b_3, b_0)
  //   replica 1: (a_1, a_2)
  //   replica 2: (b_1, b_2)
  //   replica 3: (a_3, a_0).
  //
  // All replica groups must have the same number of elements, and the number
  // of operands must be equal to the size of one replica group. Each replica
  // must appear in exactly one group.
  //
  // Note that in addition to supporting this instruction, XlaBuilder also
  // supports a higher-level instruction which takes one input and slices it,
  // performs AllToAll and then concatenates the results into a single array.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id,
      const std::optional<int64_t>& split_dimension = std::nullopt);

  // Creates a communication instruction that permutes data cross replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not
  // a target_replica_id in any pair, the output on that replica is a tensor
  // consists of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const std::optional<int64_t>& channel_id);

  // Overload taking explicit input/output buffers plus per-pair start indices
  // and slice sizes. NOTE(review): presumably the in-place (buffered)
  // collective-permute form — confirm against the implementation.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* input, HloInstruction* output,
      HloInstruction* input_start_indices, HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const std::optional<int64_t>& channel_id);

  // Creates a communication instruction that initiates the start of
  // CollectivePermute.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const std::optional<int64_t>& channel_id);

  // Overload taking explicit input/output buffers plus per-pair start indices
  // and slice sizes, mirroring the buffered CreateCollectivePermute overload.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* input, HloInstruction* output,
      HloInstruction* input_start_indices, HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const std::optional<int64_t>& channel_id);

  // Creates an instruction that returns a U32 replica ID.
  static std::unique_ptr<HloInstruction> CreateReplicaId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates an instruction that returns a U32 partition ID.
  static std::unique_ptr<HloInstruction> CreatePartitionId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates a conversion instruction, where operand is the data to convert
  // and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from
  // the Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed *not* the shape of the infeed instruction which
  // is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const std::string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed *not* the shape of the outfeed instruction
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64_t channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is
  // complete. The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a
  // unique send instruction in another computation that has the same channel
  // id. If is_host_transfer is true, then this Recv operation transfers data
  // from the host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64_t channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> start_indices,
      absl::Span<const int64_t> limit_indices,
      absl::Span<const int64_t> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64_t> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64_t dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand.
  // For example, let f be the function to apply, which takes 2 arguments, an
  // accumulator and the current value. Let init be an initial value (which is
  // normally chosen to be the identity element for f, e.g. 0 if f is
  // addition).
  // Then the reduce HLO will compute:
  // f(f(init, value0), value1), ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64_t> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  // ..., inputN.value1)
  // ...
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64_t> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Helper version where the operands are given by a single instruction which
  // either is a tuple of size `init_values`, or a single input, in which case
  // size of `init_values` is one.
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* tuple_of_instructions,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64_t> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The reduce_computation being applied now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The operands and init_values now each
  // contain a span of N input arrays and N initial values.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values, const Window& window,
      HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64_t feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64_t feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64_t feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand,
      int64_t inferred_dimension = -1);

  // Creates a dynamic reshape instruction. Similar to reshape but dynamic
  // dimensions sizes are provided as additional variadic arguments.
  //
  // Precondition: dim_sizes.size() == shape.rank()
  static std::unique_ptr<HloInstruction> CreateDynamicReshape(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> dimensions);

  // Creates a n-ary sort op with a 'compare' computation which is used for
  // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
  // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand
  // at specific index positions which should be compared, and should return a
  // PRED. 'is_stable' specifies whether stable sorting is required.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64_t dimension,
      absl::Span<HloInstruction* const> operands, HloComputation* compare,
      bool is_stable);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  // int32_t i = 1; int32_t result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  // Creates a two-branch conditional: runs true_computation on
  // true_computation_arg when pred is true, otherwise false_computation on
  // false_computation_arg.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  // N-way conditional: branch_index selects which of branch_computations to
  // run on the corresponding element of branch_computation_args.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* branch_index,
      absl::Span<HloComputation* const> branch_computations,
      absl::Span<HloInstruction* const> branch_computation_args);

  // Creates a gather instruction that gathers slices of size `slice_sizes`
  // out of `operand` at the positions described by `start_indices`, per
  // `gather_dim_numbers`.
  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted);

  // Creates a scatter instruction that applies `update_computation` to
  // combine `updates` into `operands` at the positions described by
  // `scatter_indices`, per `scatter_dim_numbers`.
  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloInstruction* scatter_indices,
      absl::Span<HloInstruction* const> updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Single-operand/single-update convenience overload of the above.
  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  // Overload that fuses an existing computation over the given operands.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation, absl::string_view prefix = "");

  // Creates a call instruction that applies the given computation on the
  // given operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, HloInstruction* called_computation_root);

  // Overload that calls an existing computation over the given operands.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call
  // target to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Overload with a to_apply computation.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* to_apply, absl::string_view custom_call_target,
      std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Overload with multiple computations. The called computations can have
  // different function signatures.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloComputation* const> called_computations,
      absl::string_view custom_call_target, std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout,
      std::string opaque = "",
      CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

  // Creates a tuple instruction with the given elements. This is a
  // convenience wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> dimensions);

  // Creates an AfterAll instruction used for joining or creating new values
  // of token type which thread through side-effecting operations. Operands
  // must all be tokens, and there must be at least one operand.
  static std::unique_ptr<HloInstruction> CreateAfterAll(
      absl::Span<HloInstruction* const> operands);

  // Creates an AfterAll instruction which creates a token type out of thin
  // air (no operands). This is a separate method from CreateAfterAll to
  // facilitate the removal of operand-less AfterAll instructions.
  // TODO(b/110532604): Remove this capability of creating a token from
  // nothing when we plumb a primordial token from the entry computation.
  static std::unique_ptr<HloInstruction> CreateToken();

  // Creates a get-dimension-size instruction for the given dimension of the
  // operand.
  static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
      const Shape& shape, HloInstruction* operand, int64_t dimension);

  // Creates a set-dimension-size instruction that sets the given dimension of
  // the operand to `val`.
  static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
      const Shape& shape, HloInstruction* operand, HloInstruction* val,
      int64_t dimension);

  // Creates an add-dependency instruction over `data_operand` and the token
  // `token_operand`.
  static std::unique_ptr<HloInstruction> CreateAddDependency(
      HloInstruction* data_operand, HloInstruction* token_operand);

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }
  // Returns a mutable pointer to this instruction's opcode.
  HloOpcode* mutable_opcode() { return &opcode_; }

  // Returns whether this instruction is the root of its parent computation.
  bool IsRoot() const;

  // Does this instruction have no users.
  bool IsDead() const { return users_.empty() && !IsRoot(); }

  // Returns true if this instruction has a side effect, irrespective of
  // whether any called computations may contain an instruction with side
  // effects.
  bool HasSideEffectNoRecurse() const;

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64_t i) const;

  // Returns the ith operand to this instruction.
  HloInstruction* mutable_operand(int64_t i);

  // Returns the number of operands to this instruction.
  int64_t operand_count() const { return operands_.size(); }

  // Returns the vector of operands of this instruction.
  using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }
  // NOTE(review): this returns the operand vector BY VALUE (a copy), so
  // mutating the returned vector does not change this instruction's operand
  // list — confirm this is the intended contract for a "mutable_" accessor.
  InstructionVector mutable_operands() { return operands_; }

  // Returns the vector of unique operands, in the same order they are found
  // within the operand vector.
  InstructionVector unique_operands() const;

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64_t operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64_t user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns the index of the user in the users() vector.
  //
  // Precondition: `user` is a user of the instruction.
  int64_t UserId(HloInstruction* user);

  // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_map_, this);
  }

  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if either of the given instructions
  // does not belong to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Drops all control predecessors and successors from this HLO instruction.
  Status DropAllControlDeps();

  // Copies the control predecessors and successors on this HLO instruction to
  // `inst`. Does not do a deep copy so this makes sense only if `inst` and
  // this HLO are in the same module.
  //
  // Depending on the use cases we see in practice, in the future we may
  // consider folding the logic here into Clone, CloneWithNewOperands and
  // ReplaceAllUsesWith by treating control dependencies like data
  // dependencies.
  Status CopyAllControlDepsFrom(const HloInstruction* inst);

  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before
  // (after) the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }

  // Returns true if "other" performs the same computation as this
  // instruction.
  bool Identical(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/false,
                             /*ignore_commutative_operand_order=*/false);
  }

  // Same as Identical() but ignores the order of commutative operands (e.g.
  // considers add(a,b) equal to add(b,a)).
  bool IdenticalIgnoringCommutativeOperandOrder(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/false,
                             /*ignore_commutative_operand_order=*/true);
  }

  // Same as Identical() but ignores channel ID value mismatches, as long as
  // both have channel IDs or neither has a channel ID.
  bool IdenticalIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/true,
                             /*ignore_commutative_operand_order=*/true);
  }

  // Generates a hash value of an HLO instruction. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO
  // instructions, with respect to HloInstruction::Identical() method.
  // TODO(majnemer): Make the comment here more crisp & accurate.
  template <typename H>
  friend H AbslHashValue(H h, const HloInstruction& hlo) {
    h = H::combine(std::move(h), hlo.opcode(), hlo.shape());

    // Operand shapes/count are skipped for cross-module all-reduce.
    if (!hlo.IsCrossModuleAllReduce()) {
      for (size_t i = 0; i < hlo.operands().size(); ++i) {
        h = H::combine(std::move(h), hlo.operand(i)->shape());
      }
      h = H::combine(std::move(h), hlo.operand_count());
    }

    // Fusions additionally hash the fused root, kind and fused sizes.
    if (hlo.opcode() == HloOpcode::kFusion) {
      h = H::combine(std::move(h), *hlo.fused_expression_root(),
                     hlo.fusion_kind(), hlo.fused_instruction_count(),
                     hlo.fused_parameters().size());
    }
    return h;
  }

  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
1349 // 1350 // If user is a fusion instruction, this function will remove any duplicated 1351 // operands of it which could be created due to this replacement. 1352 Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); 1353 1354 // Same as ReplaceUseWith(), but new_producer can have a different shape. 1355 Status ReplaceUseWithDifferentShape(HloInstruction* user, 1356 HloInstruction* new_producer); 1357 1358 // Same as ReplaceUseWith but only replaces the use at the given operand 1359 // number. 1360 Status ReplaceUseWith(HloInstruction* user, int operand_number, 1361 HloInstruction* new_producer); 1362 Status ReplaceUseWithDifferentShape(HloInstruction* user, int operand_number, 1363 HloInstruction* new_producer); 1364 1365 // Replaces the specified operand with new_operand. The old and new operands 1366 // must have compatible shapes ignoring floating-point precision. 1367 // 1368 // This function does NOT remove duplicated operands even if this instruction 1369 // is a fusion, so that the existing operand numbers do not change. 1370 Status ReplaceOperandWith(int64_t operand_num, HloInstruction* new_operand); 1371 1372 // Same as ReplaceOperandWith(), but new_operand can have a different shape. 1373 Status ReplaceOperandWithDifferentShape(int64_t operand_num, 1374 HloInstruction* new_operand); 1375 1376 // Replaces all uses of this instruction with the new producer. If 1377 // new_producer is a user of this instruction then new_producer remains a use 1378 // of this instruction to avoid introducing cycles into the graph. 1379 // 1380 // If this instruction is the root of its computation, sets the computation's 1381 // root to new_producer. 1382 // 1383 // The new producer must have a compatible shape ignoring floating-point 1384 // precision. 1385 // 1386 // If a user is a fusion instruction, this function will remove any duplicated 1387 // operands of it which could be created due to this replacement. 
1388 Status ReplaceAllUsesWith(HloInstruction* new_producer); 1389 1390 // Same as ReplaceAllUsesWith, but new_producer can have a different shape. 1391 Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); 1392 1393 // Same as ReplaceAllUsesWith, but only replace given set of users. 1394 Status ReplaceUsesWith(absl::Span<HloInstruction* const> users, 1395 HloInstruction* new_producer); 1396 Status ReplaceAllUsesWithDifferentShape( 1397 absl::Span<HloInstruction* const> users, HloInstruction* new_producer); 1398 1399 // Performs a postorder DFS visit using this node as the root. If 1400 // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when 1401 // complete. If ignore_control_predecessors is true, instructions only 1402 // reachable via control dependencies will not be visited, and the postorder 1403 // will not take control dependencies into account. It is as if the control 1404 // dependencies didn't exist in the graph at all. 1405 template <typename HloInstructionPtr> 1406 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor, 1407 bool call_finish_visit = true, 1408 bool ignore_control_predecessors = false); 1409 Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true, 1410 bool ignore_control_predecessors = false) const { 1411 return const_cast<HloInstruction*>(this)->Accept( 1412 visitor, call_finish_visit, ignore_control_predecessors); 1413 } 1414 1415 // Same as Accept() above, but the order of operand and control predecessor 1416 // visitation is determined by the given operand order; if compare(A, B) == 1417 // true, A is visited before B. 1418 using CompareFunction = 1419 std::function<bool(const HloInstruction*, const HloInstruction*)>; 1420 Status AcceptWithOperandOrder(DfsHloVisitor* visitor, 1421 const CompareFunction& operand_order, 1422 bool call_finish_visit = true); 1423 1424 // Visit this instruction and only this instruction with the given visitor. 
  template <typename HloInstructionPtr>
  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);

  // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
  // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
  // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
  std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
      const;

  // Non-const overload; forwards to the const version above.
  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
    auto rv =
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
    return {const_cast<HloInstruction*>(rv.first), rv.second};
  }

  // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
  const HloInstruction* LatestNonGteAncestor() const;

  // Non-const overload; forwards to the const version above.
  HloInstruction* LatestNonGteAncestor() {
    return const_cast<HloInstruction*>(
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
  }

  // Returns true if this instruction is effectively a bitcast. Currently,
  // this means it either is a bitcast, or it is a transpose that is effectively
  // a bitcast.
  bool IsEffectiveBitcast() const;

  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction has a valid to_apply_ field.
  HloComputation* to_apply() const;
  void set_to_apply(HloComputation* computation);

  // Gets/sets the while_condition or while_body HloComputation for While. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a While instruction.
  HloComputation* while_condition() const;
  HloComputation* while_body() const;
  void set_while_condition(HloComputation* computation);
  void set_while_body(HloComputation* computation);

  // Returns the init operand of a While instruction.
  HloInstruction* while_init() const;

  // Gets/sets the true and false HloComputation for Conditional.
  //
  // Precondition: The instruction is a predicated Conditional instruction.
  HloComputation* true_computation() const;
  HloComputation* false_computation() const;

  // Gets the branch HloComputations for Conditional.
  //
  // Precondition: The instruction is a Conditional instruction.
  const std::vector<HloComputation*>& branch_computations() const;
  int branch_count() const;
  HloComputation* branch_computation(int b) const;
  // Sets a branch HloComputation for Conditional.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a Conditional instruction.
  void set_branch_computation(int b, HloComputation* computation);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  std::string SignatureString() const;

  // Returns a debugging string that represents this instruction.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  //
  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
  // default, backing off on providing full information for very large strings,
  // or provide a different name for a ToString-like function that does that.
ToString()1501 std::string ToString() const { return ToString(HloPrintOptions()); } 1502 std::string ToString(const HloPrintOptions& options) const; 1503 1504 // Components of the ToString() representation: 1505 1506 // Returns a string representation of the operand list. 1507 std::string OperandsToString(const HloPrintOptions& options) const; 1508 1509 // Returns string representation of op-specific attributes. 1510 std::vector<std::string> ExtraAttributesToString( 1511 const HloPrintOptions& options) const; 1512 1513 // As ToString, but returns a shorter string. 1514 std::string ToShortString() const; 1515 1516 // Prints an instruction to a string. 1517 // 1518 // The canonical string representation needs to name operands and instruction 1519 // names in a consistent way. This is implemented through the 1520 // canonical_name_map. 1521 std::string ToStringWithCanonicalNameMap( 1522 const HloPrintOptions& options, 1523 CanonicalNameMap* canonical_name_map) const; 1524 1525 // Returns a serialized representation of this instruction. 1526 virtual HloInstructionProto ToProto() const; 1527 1528 // Returns a category for the HLO. This could be something like "convolution" 1529 // or "elementwise". 1530 virtual std::string ToCategory() const; 1531 1532 // Returns true if this instruction is fused, ie contained within a fusion 1533 // instruction. 1534 bool IsFused() const; 1535 1536 bool IsLoopFusion() const; 1537 bool IsInputFusion() const; 1538 bool IsOutputFusion() const; 1539 bool IsCustomFusion() const; 1540 1541 // Returns true if this instruction can be legally fused into a fusion 1542 // instruction. 1543 bool IsFusible() const; 1544 1545 bool IsCustomCall(absl::string_view target) const; 1546 1547 // Returns the sharding applied to this operator. 1548 // REQUIRES: has_sharding() is true. 
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  // Returns the sharding as a shared_ptr; null when no sharding is set.
  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }

  // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }
  // Returns the sharding unique device, if any.
  std::optional<int64_t> sharding_unique_device() const {
    if (sharding_ == nullptr) {
      // No sharding at all means no unique device assignment.
      return std::optional<int64_t>();
    }
    return sharding_->UniqueDevice();
  }
  // Sets the sharding of this operator. Should only be called by HloModule or
  // HloComputation methods.
  void set_sharding(const HloSharding& sharding) {
    sharding_ = std::make_shared<const HloSharding>(sharding);
  }
  // Overload that shares an already-built HloSharding without copying it.
  void set_sharding(std::shared_ptr<const HloSharding> sharding) {
    sharding_ = std::move(sharding);
  }
  void set_single_sharding(const HloSharding& sharding);
  // Sets a sharding that assigns the current instruction to device.
  void set_device_sharding(int64_t device) {
    set_single_sharding(HloSharding::AssignDevice(device));
  }
  // Remove any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }
  // Return true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }
  // Checks whether the instruction has compatible sharding with the other
  // instruction. Two instructions are compatible when both are unsharded, or
  // when both are sharded with equal shardings.
  bool has_compatible_sharding(const HloInstruction* other) const {
    if (!has_sharding()) {
      return !other->has_sharding();
    }
    return other->has_sharding() ? sharding() == other->sharding() : false;
  }

  // When creating a new instruction which either replaces, or shifts up (kCopy
  // insertion case), another instruction, we need to make sure certain
  // properties of the new instruction are copied into the derived one. As of
  // today, the metadata and sharding will be propagated to the derived
  // instruction.
  void SetupDerivedInstruction(HloInstruction* derived_instruction) const;

  // Clones the HLO instruction. The clone will have the same opcode, shape, and
  // operands. After creation the clone has no uses. "this" (the instruction
  // cloned from) is not changed. Suffix is the string to append to the name of
  // the instruction to form the name of the cloned instruction.
  // Ignores the control predecessors and successors of this HLO instruction.
  std::unique_ptr<HloInstruction> Clone(
      const std::string& suffix = "clone",
      HloCloneContext* context = nullptr) const;

  // Clones the HLO instruction as above but with new shape.
  std::unique_ptr<HloInstruction> CloneWithNewShape(
      const Shape& shape, const std::string& suffix = "clone",
      HloCloneContext* context = nullptr) const;

  // Clones the HLO instruction as above but with new shape and operands.
  std::unique_ptr<HloInstruction> CloneWithNewOperands(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context = nullptr) const;

  // Returns the computations this instruction directly calls (if any).
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want to let the instructions point
  // to the newly cloned nodes.
  void ReplaceCalledComputations(
      std::function<HloComputation*(HloComputation*)> map_function) {
    for (int64_t i = 0; i < called_computations_.size(); ++i) {
      called_computations_[i] = map_function(called_computations_[i]);
    }
  }

  // Clears out the called computations.
  //
  // This is, in particular, necessary when inlining function bodies into their
  // caller. If there were side-effecting operations in the called computations,
  // the call itself is considered side-effecting and thus cannot be removed. By
  // clearing out the computations, we reflect the fact that all side-effecting
  // properties have been reflected in the caller, and make the call HLO
  // removable.
  virtual void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
  // to compute the output at index {i_0,i_1,...,i_n}, the only element required
  // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in the
  // worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64_t operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if the given opcode is always elementwise.
  static bool IsOpElementwise(HloOpcode opcode);

  // Returns true if this is a cross module all-reduce instruction.
  bool IsCrossModuleAllReduce() const;

  // Returns true if this is a cross-replica all-reduce instruction.
  bool IsCrossReplicaAllReduce() const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64_t i) const;

  // Returns the indices that the given operand appear in the operand list of
  // this instruction. Note that an instruction can use the same operand
  // multiple times.
  absl::InlinedVector<int64_t, 4> OperandIndices(
      const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, return the input
  // indices of the deleted dimensions and the output indices of the inserted
  // dimensions.
  //
  // Precondition: this op must be a reshape.
  std::optional<ShapeUtil::ShapeEqualityDescriptor>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;

  // Gets the string identifier for this instruction.
  const std::string& name() const { return name_; }

  // Sets the string identifier for this instruction. Name will be sanitized to
  // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
  //
  // See also HloModule::SetAndUniquifyInstrName(), which does this plus
  // UniquifyName().
  void SetAndSanitizeName(absl::string_view name) {
    name_ = NameUniquer::GetSanitizedName(name);
  }

  // Use the given NameUniquer to select a unique name for the instruction based
  // on the instruction's existing name.
1699 // 1700 // See also HloModule::SetAndUniquifyInstrName(), which does this plus 1701 // SetAndSanitizeName(). 1702 void UniquifyName(NameUniquer* name_uniquer); 1703 1704 // Clear the unique ID of the instruction so that it can be re-assigned, such 1705 // as for the purpose of compacting the instruction unique IDs. ClearUniqueIdInternal()1706 void ClearUniqueIdInternal() { unique_id_ = -1; } 1707 1708 // Set the unique id for this instruction to "id" SetUniqueId(int id)1709 void SetUniqueId(int id) { 1710 CHECK_EQ(unique_id_, -1); // Should not be assigned already 1711 CHECK_GE(id, 0); 1712 unique_id_ = id; 1713 } 1714 1715 // Return the unique ID assigned to this node via SetUniqueId (or -1 1716 // if no id has been assigned yet). unique_id()1717 int unique_id() const { return unique_id_; } 1718 1719 template <typename T> 1720 using EnableIfProto = typename std::enable_if_t< 1721 std::is_base_of<tensorflow::protobuf::Message, T>::value>; 1722 1723 // Returns the backend-specific configuration for how a backend should compile 1724 // this HLO. The meaning of the field is backend specific. Not for use before 1725 // or during general HLO optimization, since HLO optimizations do not preserve 1726 // this field and they cannot interpret it due to its meaning being backend 1727 // specific. Except for CustomCall, where this field is preserved and no 1728 // general HLO optimization needs to interpret it. 1729 // 1730 // ConfigProto should be a protobuf Message type. 
  template <typename ConfigProto, EnableIfProto<ConfigProto>* = nullptr>
  StatusOr<ConfigProto> backend_config() const {
    ConfigProto proto;
    TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
    return std::move(proto);
  }

  // Stores the given proto as this instruction's backend config.
  Status set_backend_config(const tensorflow::protobuf::Message& proto) {
    backend_config_ = proto;
    return OkStatus();
  }

  bool has_backend_config() const { return !backend_config_.empty(); }

  void clear_backend_config() { backend_config_.clear(); }

  void CopyBackendConfigFrom(const HloInstruction* other) {
    backend_config_ = other->backend_config_.Clone();
  }

  // Replaces all frontend attributes with the given set.
  void set_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_ = std::move(frontend_attributes);
  }

  // Merges the given attributes into the existing set; for duplicate keys the
  // already-present entry wins (map insert does not overwrite).
  void add_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_.mutable_map()->insert(
        frontend_attributes.map().begin(), frontend_attributes.map().end());
  }

  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Getter/setter for raw JSON-encoded backend config. Prefer the
  // functions above that deal in proto Messages where possible.
  const std::string& raw_backend_config_string() const {
    return backend_config_.GetRawString();
  }
  void set_raw_backend_config_string(std::string config_str) {
    backend_config_ = std::move(config_str);
  }

  bool is_default_config() const { return is_default_config_; }
  void set_default_config() { is_default_config_ = true; }

  // Returns a string representation of a proto in the format used by
  // raw_backend_config_string.
  //
  // This is morally equivalent to:
  //
  //   HloInstruction instr;
  //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
  //   return instr.raw_backend_config_string();
  //
  static StatusOr<std::string> BackendConfigToRawString(
      const tensorflow::protobuf::Message& proto);

  // Returns the information used to tell the implementation information about
  // what sort of precision is requested. The meaning of the field is backend
  // specific. At the moment, it is only supported for kConvolution and kDot.
  // Transformations on one kDot or kConvolution to another will preserve this
  // information. Transformations to other HLOs will not preserve this
  // information but it is presumed that the alternate lowering is strictly
  // superior.
  // Precondition: opcode must be kConvolution or kDot.
  const PrecisionConfig& precision_config() const;
  PrecisionConfig* mutable_precision_config();

  // Sets the debug metadata for this instruction, excluding creation_pass_id,
  // which should never be copied anywhere.
set_metadata(const OpMetadata & metadata)1801 void set_metadata(const OpMetadata& metadata) { 1802 int64_t creation_pass_id = metadata_.creation_pass_id(); 1803 metadata_ = metadata; 1804 metadata_.set_creation_pass_id(creation_pass_id); 1805 } 1806 set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes)1807 void set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes) { 1808 metadata_.set_size_of_generated_code_in_bytes(code_size_in_bytes); 1809 } set_size_of_memory_working_set_in_bytes(int64_t working_set_size_in_bytes)1810 void set_size_of_memory_working_set_in_bytes( 1811 int64_t working_set_size_in_bytes) { 1812 metadata_.set_size_of_memory_working_set_in_bytes( 1813 working_set_size_in_bytes); 1814 } set_creation_pass_id(int64_t pass_id)1815 void set_creation_pass_id(int64_t pass_id) { 1816 metadata_.set_creation_pass_id(pass_id); 1817 } set_metadata_op_name(const std::string & name)1818 void set_metadata_op_name(const std::string& name) { 1819 metadata_.set_op_name(name); 1820 } set_logical_creation_pass_id(int64_t pass_id)1821 void set_logical_creation_pass_id(int64_t pass_id) { 1822 metadata_.set_logical_creation_pass_id(pass_id); 1823 } set_metadata_replaced_op(absl::string_view replaced_op)1824 void set_metadata_replaced_op(absl::string_view replaced_op) { 1825 metadata_.set_replaced_op(std::string(replaced_op)); 1826 } metadata()1827 const OpMetadata& metadata() const { return metadata_; } 1828 1829 // Set/get the computation containing this instruction. set_parent should only 1830 // be called by HloComputation methods which add/remove instructions to 1831 // computations. set_parent(HloComputation * computation)1832 void set_parent(HloComputation* computation) { parent_ = computation; } parent()1833 const HloComputation* parent() const { return parent_; } parent()1834 HloComputation* parent() { return parent_; } 1835 1836 // Returns the module for this instruction. 
  HloModule* GetModule() const;

  // Get/Set the number of partitions per outer dimension (in order, starting
  // with outer-most dimension first). Currently used by the parallel cpu
  // backend to partition HLOs into parallel tasks.
  //
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64_t>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64_t>& outer_dimension_partitions);

  // A method that sorts users_, control_predecessors_, and control_successors_
  // according to the orders used in sorted_instruction. The sorting is used
  // during cloning, to make clone behavior match uncloned behavior.
  void SortInstructionUsersAndControlLists(
      const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
      const HloInstruction& sorted_instruction);

  // Old methods kept for smooth subclassing transition BEGIN.
  // TODO(b/80131774): Remove this code.

  // Delegates to HloBatchNormInstruction::feature_index.
  int64_t feature_index() const;

  // Delegates to HloBatchNormInstruction::epsilon.
  float epsilon() const;

  // Delegates to HloFftInstruction::fft_type.
  FftType fft_type() const;

  // Delegates to HloFftInstruction::fft_length.
  const std::vector<int64_t>& fft_length() const;

  // Delegates to HloChannelInstruction::channel_id.
  std::optional<int64_t> channel_id() const;
  void set_channel_id(const std::optional<int64_t>& channel_id);

  // Returns the dimension sizes or numbers associated with this instruction.
  // The base implementation crashes; subclasses that carry dimensions
  // override it.
  virtual absl::Span<const int64_t> dimensions() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  int64_t dimensions(int64_t index) const { return dimensions()[index]; }

  // Base implementation crashes; overridden by dimension-bearing subclasses.
  virtual std::vector<int64_t>* mutable_dimensions() {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Delegates to HloConcatenateInstruction::concatenate_dimension.
  virtual int64_t concatenate_dimension() const;

  // Delegates to HloGetDimensionSizeInstruction::dimension.
  int64_t dimension() const;

  // Delegates to HloReshapeInstruction::inferred_dimension.
  int64_t inferred_dimension() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Delegates to HloSliceInstruction::slice_start.
  int64_t slice_starts(int64_t dimension) const;
  const std::vector<int64_t>& slice_starts() const;
  std::vector<int64_t>* mutable_slice_starts();

  // Delegates to HloSliceInstruction::slice_limits.
  int64_t slice_limits(int64_t dimension) const;
  const std::vector<int64_t>& slice_limits() const;
  std::vector<int64_t>* mutable_slice_limits();

  // Delegates to HloSliceInstruction::slice_strides.
  int64_t slice_strides(int64_t dimension) const;
  const std::vector<int64_t>& slice_strides() const;
  std::vector<int64_t>* mutable_slice_strides();

  // Returns the literal associated with this instruction.
  const Literal& literal() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Delegates to HloConstantInstruction::RelayoutConstant.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  // Delegates to
  // HloCallableInstruction::AppendInstructionIntoCalledComputation.
  HloInstruction* AppendInstructionIntoCalledComputation(
      HloInstruction* instruction_to_append, bool add_output = false);

  // Delegates to HloFusionInstruction::AddFusionOperand.
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Delegates to HloFusionInstruction::MergeFusionInstruction.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
  void MergeFusionInstructionIntoMultiOutput(
      HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::FuseInstruction.
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::fused_instruction.
  HloComputation* fused_instructions_computation() const;

  // Delegates to HloFusionInstruction::fused_expression_root.
  HloInstruction* fused_expression_root() const;

  // Delegates to HloFusionInstruction::fused_instructions.
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;

  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Delegates to HloFusionInstruction::fused_instruction_count.
  int64_t fused_instruction_count() const;

  // Delegates to HloFusionInstruction::fused_parameter.
  HloInstruction* fused_parameter(int64_t parameter_number) const;

  // Delegates to HloFusionInstruction::fused_parameters.
  const std::vector<HloInstruction*>& fused_parameters() const;

  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  // NOTE(review): the `const` on the returned `bool` is meaningless for a
  // by-value return and could be dropped.
  const bool IsMultiOutputFusion() const;

  // Delegates to HloFusionInstruction::fusion_kind.
  FusionKind fusion_kind() const;

  // Delegates to HloFusionInstruction::set_fusion_kind.
  void set_fusion_kind(FusionKind kind);

  // Delegates to HloRngInstruction::random_distribution.
  RandomDistribution random_distribution() const;

  // Delegates to HloParameterInstruction::parameter_number.
  int64_t parameter_number() const;

  // Delegates to
  // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
  void set_parameter_replicated_at_leaf_buffers(
      absl::Span<const bool> parameter_replicated_at_leaf_buffers);
  void set_parameter_replicated_at_leaf_buffers(
      const std::vector<bool>& parameter_replicated_at_leaf_buffers);

  // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
  const std::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()
      const;

  // Delegates to HloGetTupleElementInstruction::tuple_index.
  int64_t tuple_index() const;

  // Delegates to HloGetTupleElementInstruction::set_tuple_index.
  void set_tuple_index(int64_t new_tuple_index);

  // Delegates to HloReducePrecisionInstruction::exponent_bits.
  int32_t exponent_bits() const;

  // Delegates to HloReducePrecisionInstruction::mantissa_bits.
  int32_t mantissa_bits() const;

  // Delegates to HloInfeedInstruction::infeed_config.
  std::string infeed_config() const;

  // Delegates to HloInfeedInstruction::set_infeed_config.
  void set_infeed_config(const std::string& config);

  // Returns the config for the Outfeed instruction.
  const std::string& outfeed_config() const;

  // Delegates to HloOutfeedInstruction::set_outfeed_config.
  void set_outfeed_config(const std::string& config);

  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const;

  // Returns the mutable shape for the Outfeed instruction.
  Shape* mutable_outfeed_shape();

  // Delegates to HloCollectiveInstruction::replica_groups.
  const std::vector<ReplicaGroup>& replica_groups() const;

  // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
  const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const;

  // Returns data on the window in a windowed operation such as
  // convolution. The base implementation crashes; windowed subclasses
  // override it.
  virtual const Window& window() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Sets the window data in a windowed operation such as convolution.
  // Base implementation crashes; overridden by windowed subclasses.
  virtual void set_window(const Window& window) {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Returns the unique_indices field. Base implementation crashes;
  // overridden by scatter-like subclasses.
  virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }

  // Returns data on the dimension numbers used for a convolution operation,
  // which may be a kConvolution instruction or a kCustomCall that implements a
  // convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;

  // Sets the convolution dimension numbers on this instruction. In general you
  // shouldn't need to call this; instead, specify the convolution dimension
  // numbers when you create the instruction.
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums);

  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64_t feature_group_count() const;

  void set_feature_group_count(int64_t feature_group_count);

  // The number of batch groups. Must be a divisor of the input batch dimension
  int64_t batch_group_count() const;

  void set_batch_group_count(int64_t batch_group_count);

  // Delegates to HloSelectAndScatterInstruction::select.
  HloComputation* select() const;

  // Delegates to HloSelectAndScatterInstruction::scatter.
  HloComputation* scatter() const;

  // Delegates to HloSelectAndScatterInstruction::set_select.
  void set_select(HloComputation* computation);

  // Delegates to HloSelectAndScatterInstruction::set_scatter.
  void set_scatter(HloComputation* computation);

  // Delegates to HloCustomCallInstruction::custom_call_target.
  const std::string& custom_call_target() const;
  void set_custom_call_target(absl::string_view target);

  // Delegates to HloPadInstruction::padding_config.
  const PaddingConfig& padding_config() const;
  PaddingConfig* mutable_padding_config();

  // Delegates to HloConvolutionInstruction::padding_type.
  PaddingType padding_type() const;

  // Delegates to HloDynamicSliceInstruction::slice_sizes.
  int64_t slice_sizes(int64_t dimension) const;

  // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
  const std::vector<int64_t>& dynamic_slice_sizes() const;

  // Delegates to HloCollectivePermuteInstruction::dynamic_slice_sizes.
  const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const;

  // Delegates to HloGatherInstruction::gather_dimension_numbers.
  const GatherDimensionNumbers& gather_dimension_numbers() const;
  // Delegates to HloGatherInstruction::gather_slice_sizes.
  absl::Span<const int64_t> gather_slice_sizes() const;

  // Delegates to HloScatterInstruction::scatter_dimension_numbers().
2108 const ScatterDimensionNumbers& scatter_dimension_numbers() const; 2109 2110 // Delegates to HloDotInstruction::dot_dimension_numbers(). 2111 const DotDimensionNumbers& dot_dimension_numbers() const; 2112 2113 // Delegates to HloDomainInstruction::operand_side_metadata(). 2114 const DomainMetadata& operand_side_metadata() const; 2115 2116 // Delegates to HloDomainInstruction::user_side_metadata(). 2117 const DomainMetadata& user_side_metadata() const; 2118 2119 // Returns true if the instruction is an async-start, async-update, or 2120 // async-done. 2121 bool IsAsynchronous() const; 2122 2123 // Returns the computation that will executed asynchronously. 2124 HloComputation* async_wrapped_computation() const; 2125 2126 // Delagates to HloAsyncInstruction::async_wrapped_instruction(). 2127 HloInstruction* async_wrapped_instruction() const; 2128 2129 // Delagates to HloAsyncInstruction::async_wrapped_opcode(). 2130 HloOpcode async_wrapped_opcode() const; 2131 2132 // Delegates to HloAsyncInstruction::async_group_id(). 2133 std::optional<int64_t> async_group_id() const; 2134 2135 // Delegates to HloAsyncInstruction::set_async_group_id(). 2136 void set_async_group_id(std::optional<int64_t> async_group_id); 2137 2138 // Delegates to HloAsyncInstruction::async_execution_thread(). 2139 absl::string_view async_execution_thread() const; 2140 2141 // Delegates to HloAsyncInstruction::set_async_execution_thread(). 2142 void set_async_execution_thread(absl::string_view async_execution_thread); 2143 2144 // Delegates to 2145 // HloCallableInstruction::RecursivelySetComputationsThreadName(). 2146 void set_called_computations_execution_thread( 2147 absl::string_view async_execution_thread, 2148 bool skip_async_execution_thread_overwrite); 2149 2150 // Delegates to HloCopyStartInstruction::is_cross_program_prefetch(). 2151 bool is_cross_program_prefetch() const; 2152 2153 // Delegates to HloCompareInstruction::direction(). 
  ComparisonDirection comparison_direction() const;
  // Delegates to HloCompareInstruction::order().
  ComparisonOrder comparison_order() const;

  // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
  const TriangularSolveOptions& triangular_solve_options() const;

  // Delegates to HloCholeskyInstruction::cholesky_options().
  const CholeskyOptions& cholesky_options() const;

  // Appends operand to the list of operands and adds this instruction as a user
  // of the operand.
  void AppendOperand(HloInstruction* operand);

  // Old methods kept for smooth subclassing transition END.

 protected:
  // Internal constructor for a given opcode/shape, other fields must be filled
  // by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  // Clears the operand list. Note: only this instruction's operand vector is
  // cleared; the former operands' user lists are not updated here.
  void RemoveAllOperands() { operands_.clear(); }

  // Removes the operand at the given position in operands_.
  void RemoveOperandAt(int index) {
    operands_.erase(operands_.begin() + index);
  }

  // Removes a list of operands with the given indices in ascending order.
  void RemoveOperandsAtAscendingIndices(
      absl::Span<const int> ascending_indices);

  // Appends computation to the list of computations called by this
  // instruction.
  void AppendComputation(HloComputation* computation) {
    called_computations_.push_back(computation);
  }

  // Removes this instruction from usee's user list.
  void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }

  // Replaces the called computation at the given index (see the enum below for
  // the index conventions).
  void set_called_computation(int index, HloComputation* computation) {
    called_computations_[index] = computation;
  }

  // Indices of computations in called_computations_ for instructions which call
  // multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };

 private:
  friend class HloComputation;

  // Wrapper class of string format and protobuf format of BackendConfig.
  class BackendConfigRep {
   public:
    // Returns the cached proto form, or nullptr if only the raw string is set.
    const tensorflow::protobuf::Message* GetProtoPtr() const {
      return proto_.get();
    }

    const std::string& GetRawString() const;

    BackendConfigRep Clone() const;

    bool operator==(const BackendConfigRep& other) const;
    bool operator!=(const BackendConfigRep& other) const {
      return !(*this == other);
    }

    // True when neither a proto nor a raw string is stored.
    bool empty() const { return proto_ == nullptr && raw_string_.empty(); }

    // Drops both representations.
    void clear() {
      proto_.reset();
      raw_string_.clear();
    }

    BackendConfigRep& operator=(std::string raw_string);
    BackendConfigRep& operator=(const tensorflow::protobuf::Message& proto);
    void SetProto(const tensorflow::protobuf::Message& proto);

   private:
    std::unique_ptr<tensorflow::protobuf::Message> proto_;
    // If proto_ is not null, raw_string_ is a lazy cache of its string format.
    mutable std::string raw_string_;
  };

  bool IdenticalInternal(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations,
      bool layout_sensitive, bool ignore_channel_id_values,
      bool ignore_commutative_operand_order) const;

  // Implementation for non-common logic of CloneWithNewOperands.
  virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const {
    // TODO(b/80131774): This should be pure virtual.
    LOG(FATAL) << "Unimplemented method.";
  }

  // Implementation for non-common logic of ExtraAttributesToString.
  virtual std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const {
    return {};
  }

  // Implementation for IsElementwise if operand_idx is nullopt and for
  // IsElementwiseOnOperand if otherwise.
  //
  // NOTE: For all instructions other than kFusion, being elementwise on one of
  // the operands is equivalent to being elementwise on all the operands.
  virtual bool IsElementwiseImpl(
      const std::optional<int64_t>& operand_idx) const;

  // Prints an operand to a string. Accessed by friend class HloInstruction.
  virtual std::string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // See comments on Identical().
  virtual bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Helper for implementing backend_config(). Parses backend_config_ into the
  // given proto.
  Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;

  // Mark this instruction as dead. Accessed by friend class HloComputation.
  void MarkAsDead() { marked_as_dead_ = true; }

  // Has this instruction been marked as dead? Accessed by friend class
  // HloComputation.
  bool IsMarkedAsDead() const { return marked_as_dead_; }

  int unique_id_;  // Unique to this HloInstruction within a HloModule

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  // Note that the order of the instructions in the vector influences the order
  // computed in HloComputation::ComputeInstructionPostOrder, which may
  // influence the result of the compilation by changing the scheduling. We are
  // not sure if it matters.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is an
  // operand. The vector users_ and the map user_map_ contain identical members.
  // The map enables fast membership testing and the vector enables fast, stable
  // iteration. The value in the map contains the index of the instruction in
  // the vector, which enables fast removal.
  std::vector<HloInstruction*> users_;
  absl::flat_hash_map<const HloInstruction*, int64_t> user_map_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Result shape of this instruction.
  Shape shape_;

  // The sharding, if one exists.
  // Uses std::shared_ptr to allow reuse of the same sharding object between
  // HloInstructions and other components as HloSharding can be very large for
  // many element tuples.
  std::shared_ptr<const HloSharding> sharding_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this
  // as an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The backend-specific configuration for how a backend should compile this
  // HLO. See the documentation on backend_config().
  mutable BackendConfigRep backend_config_;

  // Attributes passed from the frontend to give hints to the backend about
  // how to compile this HLO.
  // HLO -> HLO transforms are expected to preserve these attributes on a
  // "best effort" basis only.
  // For example:
  //   x = const(10, frontend_attributes={x})
  //   y = const(10, frontend_attributes={y})
  //   z = add(x,y), frontend_attributes={y}
  // Could be simplified to:
  //   z' = const(20), frontend_attributes={?}
  FrontendAttributes frontend_attributes_;

  // This field is assigned to true when backend_config_ is assigned to
  // a default configuration.
  bool is_default_config_ = false;

  // True if this instruction has already been detached from its user and
  // operands.
  bool cleaned_up_ = false;

  // String identifier for instruction.
  std::string name_;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order from
  // outer-most dimension first).
  std::vector<int64_t> outer_dimension_partitions_;

  // Intrusive flag used by HloComputation, whether this instruction has
  // been marked as dead.
  bool marked_as_dead_;

  // HloInstructions are not copyable; use Clone() instead.
  HloInstruction(const HloInstruction&) = delete;
  HloInstruction& operator=(const HloInstruction&) = delete;
};

// Explicit instantiations in hlo_instruction.cc.
extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);

std::string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const std::string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
std::string PaddingConfigToString(const PaddingConfig& padding);
std::string FrontendAttributesToString(
    const FrontendAttributes& frontend_attributes);
std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
std::string RandomDistributionToString(const RandomDistribution& distribution);
std::string PrecisionToString(const PrecisionConfig::Precision& precision);
std::string DotDimensionNumbersToString(const DotDimensionNumbers& dnums);
std::string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums);
std::string ReplicaGroupsToString(
    absl::Span<const ReplicaGroup> replica_groups);

StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const std::string& name);
StatusOr<RandomDistribution> StringToRandomDistribution(
    const std::string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const std::string& name);
StatusOr<CustomCallSchedule> StringToCustomCallSchedule(absl::string_view name);
StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
    absl::string_view name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};

// NOTE(review): these aliases use std::map, but <map> does not appear in this
// file's include list (unlike <set>) — confirm it is pulled in transitively or
// add the include.
template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_