1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // HLO instructions are in DAG form and represent the computations that the user 17 // has built up via the XLA service interface. They are ultimately lowered 18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered 19 // form; e.g. see DfsHloVisitor. 20 21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 23 24 #include <functional> 25 #include <iosfwd> 26 #include <list> 27 #include <memory> 28 #include <set> 29 #include <string> 30 #include <tuple> 31 #include <vector> 32 33 #include "absl/container/flat_hash_map.h" 34 #include "absl/container/flat_hash_set.h" 35 #include "absl/container/inlined_vector.h" 36 #include "absl/memory/memory.h" 37 #include "absl/strings/str_cat.h" 38 #include "absl/strings/string_view.h" 39 #include "absl/types/span.h" 40 #include "tensorflow/compiler/xla/comparison_util.h" 41 #include "tensorflow/compiler/xla/iterator_util.h" 42 #include "tensorflow/compiler/xla/literal.h" 43 #include "tensorflow/compiler/xla/map_util.h" 44 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 45 #include "tensorflow/compiler/xla/service/hlo.pb.h" 46 #include "tensorflow/compiler/xla/service/hlo_clone_context.h" 47 #include 
"tensorflow/compiler/xla/service/hlo_domain_metadata.h" 48 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 49 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 50 #include "tensorflow/compiler/xla/service/name_uniquer.h" 51 #include "tensorflow/compiler/xla/shape_tree.h" 52 #include "tensorflow/compiler/xla/types.h" 53 #include "tensorflow/compiler/xla/xla_data.pb.h" 54 #include "tensorflow/core/lib/core/status.h" 55 #include "tensorflow/core/lib/gtl/iterator_range.h" 56 #include "tensorflow/core/platform/logging.h" 57 #include "tensorflow/core/platform/macros.h" 58 #include "tensorflow/core/platform/protobuf.h" 59 #include "tensorflow/core/platform/types.h" 60 61 namespace xla { 62 63 class HloComputation; 64 class HloModule; 65 66 string PrintName(const string& name, bool print_ids); 67 68 // A bunch of switches that control how the hlo text should be printed. 69 class HloPrintOptions { 70 public: 71 enum class PrintSubcomputationMode { 72 kOff, // Do not print anything about subcomputations. 73 kNameOnly, // Only print the name of subcomputations. 74 kFullBodies, // Print the full bodies of subcomputations. 75 kNonSequentialBodies, // Print the full bodies of subcomputations that are 76 // not in a sequential context. 77 }; 78 79 // Constructs the default print options: don't print large constants, don't 80 // compact operands, no indentation. 
HloPrintOptions()81 HloPrintOptions() 82 : print_large_constants_(false), 83 print_only_essential_constants_(false), 84 print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), 85 print_metadata_(true), 86 print_backend_config_(true), 87 print_infeed_outfeed_config_(true), 88 compact_operands_(false), 89 include_layout_in_shapes_(true), 90 print_result_shape_(true), 91 print_operand_shape_(true), 92 print_operand_names_(true), 93 print_operand_index_annotation_interval_(5), 94 print_program_shape_(true), 95 print_percent_(true), 96 print_control_dependencies_(true), 97 canonicalize_instruction_names_(false), 98 indent_amount_(0), 99 is_in_nested_computation_(false), 100 print_ids_(true), 101 canonicalize_computations_(false), 102 print_extra_attributes_(true) {} 103 ShortParsable()104 static HloPrintOptions ShortParsable() { 105 return HloPrintOptions() 106 .set_print_large_constants(true) 107 .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) 108 .set_print_metadata(false) 109 .set_print_backend_config(false) 110 .set_print_operand_shape(false) 111 .set_print_operand_index_annotation_interval(0) 112 .set_print_program_shape(false) 113 .set_print_percent(false) 114 .set_print_control_dependencies(false); 115 } 116 117 // Options to produce the canonical string representing an isomorphic 118 // computation graph. Canonical()119 static HloPrintOptions Canonical() { 120 return HloPrintOptions() 121 .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) 122 .set_print_metadata(false) 123 .set_print_backend_config(false) 124 .set_compact_operands(false) 125 .set_print_operand_names(false) 126 .set_print_operand_shape(true) 127 .set_print_operand_index_annotation_interval(0) 128 .set_print_program_shape(false) 129 .set_print_percent(false) 130 .set_print_control_dependencies(false) 131 .set_canonicalize_instruction_names(true); 132 } 133 134 // Options to produce a fingerprint of an HLO. 
Fingerprint()135 static HloPrintOptions Fingerprint() { 136 return HloPrintOptions() 137 .set_print_subcomputation_mode( 138 PrintSubcomputationMode::kNonSequentialBodies) 139 .set_print_metadata(false) 140 .set_print_backend_config(false) 141 .set_print_infeed_outfeed_config(false) 142 .set_print_only_essential_constants(true) 143 .set_compact_operands(true) 144 .set_print_operand_names(false) 145 .set_print_operand_shape(true) 146 .set_print_operand_index_annotation_interval(0) 147 .set_print_program_shape(false) 148 .set_print_percent(false) 149 .set_print_control_dependencies(false) 150 .set_canonicalize_instruction_names(true) 151 .set_print_ids(false) 152 .set_canonicalize_computations(true); 153 } 154 155 // If true, large constants will be printed out. set_print_large_constants(bool value)156 HloPrintOptions& set_print_large_constants(bool value) { 157 print_large_constants_ = value; 158 return *this; 159 } 160 161 // If true, only integer, all-zero, are all-one constants will be printed out. set_print_only_essential_constants(bool value)162 HloPrintOptions& set_print_only_essential_constants(bool value) { 163 print_only_essential_constants_ = value; 164 return *this; 165 } 166 set_print_subcomputation_mode(PrintSubcomputationMode value)167 HloPrintOptions& set_print_subcomputation_mode( 168 PrintSubcomputationMode value) { 169 print_subcomputation_mode_ = value; 170 return *this; 171 } 172 173 // If true, metadata will be printed. set_print_metadata(bool value)174 HloPrintOptions& set_print_metadata(bool value) { 175 print_metadata_ = value; 176 return *this; 177 } 178 179 // If true, backend_config will be printed. set_print_backend_config(bool value)180 HloPrintOptions& set_print_backend_config(bool value) { 181 print_backend_config_ = value; 182 return *this; 183 } 184 185 // If true, infeed_config and outfeed_config will be printed. 
set_print_infeed_outfeed_config(bool value)186 HloPrintOptions& set_print_infeed_outfeed_config(bool value) { 187 print_infeed_outfeed_config_ = value; 188 return *this; 189 } 190 191 // If true, result shapes will be printed. set_print_result_shape(bool value)192 HloPrintOptions& set_print_result_shape(bool value) { 193 print_result_shape_ = value; 194 return *this; 195 } 196 197 // If true, operands' shapes will be printed. set_print_operand_shape(bool value)198 HloPrintOptions& set_print_operand_shape(bool value) { 199 print_operand_shape_ = value; 200 return *this; 201 } 202 203 // If true, operands' shapes will be printed. set_print_operand_index_annotation_interval(int64_t value)204 HloPrintOptions& set_print_operand_index_annotation_interval(int64_t value) { 205 print_operand_index_annotation_interval_ = value; 206 return *this; 207 } 208 209 // If true, the operand names will be printed. set_print_operand_names(bool value)210 HloPrintOptions& set_print_operand_names(bool value) { 211 print_operand_names_ = value; 212 return *this; 213 } 214 215 // If true, all printed names include unique identifiers. set_print_ids(bool value)216 HloPrintOptions& set_print_ids(bool value) { 217 print_ids_ = value; 218 return *this; 219 } 220 221 // If true, the HLO includes its attributes. set_print_extra_attributes(bool value)222 HloPrintOptions& set_print_extra_attributes(bool value) { 223 print_extra_attributes_ = value; 224 return *this; 225 } 226 227 // If true, program shape of hlo computations will be printed. set_print_program_shape(bool value)228 HloPrintOptions& set_print_program_shape(bool value) { 229 print_program_shape_ = value; 230 return *this; 231 } 232 233 // If true, names will be printed with prefix '%'. set_print_percent(bool value)234 HloPrintOptions& set_print_percent(bool value) { 235 print_percent_ = value; 236 return *this; 237 } 238 239 // If true, control dependencies will be printed. 
set_print_control_dependencies(bool value)240 HloPrintOptions& set_print_control_dependencies(bool value) { 241 print_control_dependencies_ = value; 242 return *this; 243 } 244 245 // If true, only a part of operands will be printed out (note that in this 246 // case the text will not be parsable). set_compact_operands(bool value)247 HloPrintOptions& set_compact_operands(bool value) { 248 compact_operands_ = value; 249 return *this; 250 } 251 252 // If true, include the layout in any shapes that are printed (instruction 253 // and operands). set_include_layout_in_shapes(bool value)254 HloPrintOptions& set_include_layout_in_shapes(bool value) { 255 include_layout_in_shapes_ = value; 256 return *this; 257 } 258 259 // If true, canonicalizes instructions' name. Instead of using "%foo.1" as 260 // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. set_canonicalize_instruction_names(bool value)261 HloPrintOptions& set_canonicalize_instruction_names(bool value) { 262 canonicalize_instruction_names_ = value; 263 return *this; 264 } 265 266 // If true, canonicalizes computations, sorting by computations' names. set_canonicalize_computations(bool value)267 HloPrintOptions& set_canonicalize_computations(bool value) { 268 canonicalize_computations_ = value; 269 return *this; 270 } 271 272 // The indent of the hlo text block. set_indent_amount(int value)273 HloPrintOptions& set_indent_amount(int value) { 274 indent_amount_ = value; 275 return *this; 276 } 277 278 // If true, indicates the instruction being printed is inside a nested 279 // computation. set_is_in_nested_computation(bool value)280 HloPrintOptions& set_is_in_nested_computation(bool value) { 281 is_in_nested_computation_ = value; 282 return *this; 283 } 284 285 // Instructions are selected for printing by a predicate function 286 // (`set_print_instructions`). We also print their surrounding instructions 287 // for ease of reading. 
We print `leading_and_trailing_instructions_number` 288 // instructions before and after the qualified ones inside a computation. set_leading_and_trailing_instructions_number(int value)289 HloPrintOptions& set_leading_and_trailing_instructions_number(int value) { 290 leading_and_trailing_instructions_number_ = value; 291 return *this; 292 } 293 294 // A callback which takes an HloInstruction*, its string representation, 295 // the indentation level of the resulting block, and a 296 // bool variable indicating whether the instruction is root or not. The return 297 // value is a string which is used for this instruction during printing. 298 using FormatInstructionFunc = 299 std::function<string(const HloInstruction*, const string&, int, bool)>; 300 set_format_instruction(FormatInstructionFunc callback)301 HloPrintOptions& set_format_instruction(FormatInstructionFunc callback) { 302 format_instruction_ = callback; 303 return *this; 304 } 305 306 using HloInstructionPredicate = std::function<bool(const HloInstruction*)>; 307 308 // A callback which takes an HloInstruction* and returns whether it should be 309 // printed or not. set_print_instruction(HloInstructionPredicate callback)310 HloPrintOptions& set_print_instruction(HloInstructionPredicate callback) { 311 print_instruction_ = callback; 312 return *this; 313 } 314 315 using HloComputationPredicate = std::function<bool(const HloComputation*)>; 316 317 // A callback which takes an HloComputation* and returns whether it should be 318 // printed or not. 
set_print_computation(HloComputationPredicate callback)319 HloPrintOptions& set_print_computation(HloComputationPredicate callback) { 320 print_computation_ = callback; 321 return *this; 322 } 323 print_large_constants()324 bool print_large_constants() const { return print_large_constants_; } print_only_essential_constants()325 bool print_only_essential_constants() const { 326 return print_only_essential_constants_; 327 } print_subcomputation_mode()328 PrintSubcomputationMode print_subcomputation_mode() const { 329 return print_subcomputation_mode_; 330 } print_metadata()331 bool print_metadata() const { return print_metadata_; } print_backend_config()332 bool print_backend_config() const { return print_backend_config_; } print_infeed_outfeed_config()333 bool print_infeed_outfeed_config() const { 334 return print_infeed_outfeed_config_; 335 } compact_operands()336 bool compact_operands() const { return compact_operands_; } include_layout_in_shapes()337 bool include_layout_in_shapes() const { return include_layout_in_shapes_; } print_result_shape()338 bool print_result_shape() const { return print_result_shape_; } print_operand_shape()339 bool print_operand_shape() const { return print_operand_shape_; } print_operand_names()340 bool print_operand_names() const { return print_operand_names_; } print_operand_index_annotation_interval()341 int64 print_operand_index_annotation_interval() const { 342 return print_operand_index_annotation_interval_; 343 } print_ids()344 bool print_ids() const { return print_ids_; } print_program_shape()345 bool print_program_shape() const { return print_program_shape_; } print_percent()346 bool print_percent() const { return print_percent_; } print_control_dependencies()347 bool print_control_dependencies() const { 348 return print_control_dependencies_; 349 } print_extra_attributes()350 bool print_extra_attributes() const { return print_extra_attributes_; } canonicalize_instruction_names()351 bool canonicalize_instruction_names() const { 
352 return canonicalize_instruction_names_; 353 } canonicalize_computations()354 bool canonicalize_computations() const { return canonicalize_computations_; } indent_amount()355 int indent_amount() const { return indent_amount_; } is_in_nested_computation()356 int is_in_nested_computation() const { return is_in_nested_computation_; } leading_and_trailing_instructions_number()357 int leading_and_trailing_instructions_number() const { 358 return leading_and_trailing_instructions_number_; 359 } format_instruction(const HloInstruction * instr,const string & instr_name,int indent,bool is_root)360 string format_instruction(const HloInstruction* instr, 361 const string& instr_name, int indent, 362 bool is_root) const { 363 return format_instruction_(instr, instr_name, indent, is_root); 364 } print_instruction(const HloInstruction * instr)365 bool print_instruction(const HloInstruction* instr) const { 366 return print_instruction_(instr); 367 } print_computation(const HloComputation * comp)368 bool print_computation(const HloComputation* comp) const { 369 return print_computation_(comp); 370 } 371 372 private: 373 bool print_large_constants_; 374 bool print_only_essential_constants_; 375 PrintSubcomputationMode print_subcomputation_mode_; 376 bool print_metadata_; 377 bool print_backend_config_; 378 bool print_infeed_outfeed_config_; 379 bool compact_operands_; 380 bool include_layout_in_shapes_; 381 bool print_result_shape_; 382 bool print_operand_shape_; 383 bool print_operand_names_; 384 // The interval between the /*index=*/ annotated operands. 0 means never print 385 // the annotation, 1 means print annotation for every operand. 
386 int64 print_operand_index_annotation_interval_; 387 bool print_program_shape_; 388 bool print_percent_; 389 bool print_control_dependencies_; 390 bool canonicalize_instruction_names_; 391 int indent_amount_; 392 bool is_in_nested_computation_; 393 bool print_ids_; 394 bool canonicalize_computations_; 395 bool print_extra_attributes_; 396 int leading_and_trailing_instructions_number_ = 3; 397 FormatInstructionFunc format_instruction_ = [](const HloInstruction* instr, 398 const string& instr_name, 399 int indent, bool is_root) { 400 return absl::StrCat(string(2 * indent, ' '), is_root ? "ROOT " : "", 401 instr_name); 402 }; 403 HloInstructionPredicate print_instruction_ = [](const HloInstruction* instr) { 404 return true; 405 }; 406 HloComputationPredicate print_computation_ = [](const HloComputation* comp) { 407 return true; 408 }; 409 }; 410 411 // For canonical string output, we need to have a canonical way to rename 412 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>", 413 // where <xxx> is an index starting from 0. 414 class CanonicalNameMap { 415 public: CanonicalNameMap()416 CanonicalNameMap() : index(0) {} 417 LookupOrInsert(const string & old_name)418 string LookupOrInsert(const string& old_name) { 419 auto iter = canonical_name_map.find(old_name); 420 if (iter != canonical_name_map.end()) { 421 return iter->second; 422 } 423 424 string new_name = absl::StrCat("tmp_", index++); 425 canonical_name_map[old_name] = new_name; 426 return new_name; 427 } Clear()428 void Clear() { 429 canonical_name_map.clear(); 430 index = 0; 431 } 432 433 private: 434 int64 index; 435 absl::flat_hash_map<string, string> canonical_name_map; 436 }; 437 438 // HLO instructions are the atomic unit of the high-level compiler's IR. 439 // 440 // HloInstructions live inside of an HloComputation, which is analogous to a 441 // function in other programming languages. Nodes have no total order within 442 // their computation. 
Instead, they have a partial ordering determined by their 443 // data and control dependencies. 444 // 445 // HLO does not have basic blocks or explicit "branch" instructions. Instead, 446 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode 447 // control flow. For example, the kConditional HLO executes one of two possible 448 // computations, depending on the runtime value of a predicate. 449 // 450 // HLO is pure (mostly). It has no concept of mutable state. Instead, data 451 // values are produced by one HLO and flow into consumers across dependency 452 // edges. 453 class HloInstruction { 454 public: 455 // A fusion node computes the same value a call to its fusion computation 456 // would compute. However, the choice of fusion kind dictates codegen 457 // strategy for the backend. 458 // 459 // To generate code for a kFusion HloInstruction, most backends do something 460 // like the following: 461 // 462 // 1) Identify the "primary" HloInstruction of the fused computation. 463 // 2) Emit code that does the work of the primary node, creating its inputs 464 // and transforming its outputs as specified by the fused computation. 465 // 466 // In step (2), the code emitted is usually similar to the code that would be 467 // emitted for an *unfused* version of the primary node, except that 468 // 469 // - when the primary node reads an element of one of its operands, instead 470 // of loading the value from memory, it *computes* the value based on the 471 // contents of the fused computation. 472 // - when the primary node outputs a value, instead of storing it to memory, 473 // it forwards the value to its users, which then perform additional 474 // computations before the value is finally stored to memory at the root of 475 // the fusion node. 476 // 477 // An HloInstruction's FusionKind helps us find the kFusion instruction's 478 // primary node, and can also affect how we generate code in step (2). 
479 // 480 // - kInput: The primary node is the root of the fused instruction. 481 // 482 // - kOutput: The primary node is not the root of the fused instruction. 483 // This fusion kind requires that one operand buffer of the fusion 484 // instruction be able to alias the output buffer. This constraint is 485 // usually enough to let backends find the primary node unambiguously. 486 // 487 // - kLoop: The primary node is the root of the fused computation, but, 488 // unlike in input fusion, we prescribe a specific implementation for 489 // codegen. Rather than generating code that looks like the code we'd emit 490 // for an unfused version of the primary/root node, we emit code that 491 // generates one element of the root at a time. 492 // 493 // - kCustom: Custom category for backend-specific fusions that don't fit 494 // into the above patterns. 495 // 496 // Not all backends support all fusion kinds, and given a particular fused 497 // computation, it's not in general safe to change its fusion kind. Creation 498 // of fusion nodes is always backend-specific. 499 // 500 // For elementwise ops (e.g. kAdd), most backends would emit a 501 // one-element-at-a-time implementation for the unfused version, so loop 502 // fusion and input fusion are probably equivalent if the root node is 503 // elementwise. They're not necessarily equivalent e.g. for kReduce, where an 504 // implementation might emit something more sophisticated for an unfused or 505 // input-fusion reduce, but will emit the naive code that reduces one element 506 // at a time for loop fusion with a reduce as the root. 507 // 508 // Another way to think of loop fusion is that it's equivalent to input 509 // fusion, but where the root node is an implicit identity node, whose 510 // unfused implementation is "read one element, write one element". 511 // 512 // TODO(b/79869434): This categorization scheme is not great. 
For one thing, 513 // input and loop fusion are basically the same thing: There is no reason for 514 // the HLO to encode backend-specific decisions about how e.g. a reduce that's 515 // the root of a fusion should be lowered. In addition, this scheme as 516 // written doesn't work for multi-output fusion, where the primary node is 517 // never actually the root (which is a kTuple instruction that gathers the 518 // multiple outputs of the fusion). 519 enum class FusionKind { 520 kLoop, 521 kInput, 522 kOutput, 523 kCustom, 524 }; 525 ~HloInstruction()526 virtual ~HloInstruction() { DetachFromOperandsAndUsers(); } 527 528 // Detaches an instruction from its operands and users. That is, remove the 529 // instruction from each operand's user set and user's operand set. 530 void DetachFromOperandsAndUsers(); 531 532 // Creates an instruction from the given proto. Arguments: 533 // 534 // proto: the proto to convert from. 535 // instruction_map: a map from instruction id to HloInstruction*. This map 536 // must contain all operands of the newly constructed instruction. 537 // computation_map: a map from computation id to HloComputation*. This map 538 // must contain all computations which the newly constructed instruction 539 // calls. 540 static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto( 541 const HloInstructionProto& proto, 542 const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, 543 const absl::flat_hash_map<int64, HloComputation*>& computation_map = {}, 544 bool prohibit_empty_literal = true); 545 546 // Creates a parameter-retrieving instruction. 547 static std::unique_ptr<HloInstruction> CreateParameter( 548 int64_t parameter_number, const Shape& shape, const string& name); 549 550 // Creates a literal constant instruction. 551 static std::unique_ptr<HloInstruction> CreateConstant(Literal literal); 552 553 // Creates an Iota instruction. 
554 static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape, 555 int64_t iota_dimension); 556 557 // Creates a get tuple element instruction. 558 static std::unique_ptr<HloInstruction> CreateGetTupleElement( 559 const Shape& shape, HloInstruction* operand, int64_t index); 560 561 // Creates a trace instruction that logs the input operand in the computation. 562 static std::unique_ptr<HloInstruction> CreateTrace(const string& tag, 563 HloInstruction* operand); 564 565 // Creates a random number generation instruction that fills a shape with 566 // random numbers from a given distribution. 567 // 568 // The parameters to the instruction are interpreted as follows: 569 // 570 // - If `distribution` is RNG_UNIFORM, generates a number in range 571 // [param0, param1). 572 // 573 // - If `distribution` is RNG_NORMAL, generates a normally-distributed value 574 // with mean `param0` and standard deviation `param1`. 575 static std::unique_ptr<HloInstruction> CreateRng( 576 const Shape& shape, RandomDistribution distribution, 577 absl::Span<HloInstruction* const> parameters); 578 579 // Creates a stateless random bit generator instruction that fills a shape 580 // with random bits. 581 static std::unique_ptr<HloInstruction> CreateRngBitGenerator( 582 const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm); 583 584 // Creates an instruction to update the random number generator state to 585 // reflect the new state after `delta` units of 32 random bits are generated 586 // and returns the old state. 587 static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState( 588 const Shape& shape, int64_t delta); 589 590 // Creates a unary instruction (one operand). 591 // Precondition: opcode must be a legitimate unary operation. 592 static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape, 593 HloOpcode opcode, 594 HloInstruction* operand); 595 596 // Creates a binary instruction (two operands). 
597 // Precondition: opcode must be a legitimate binary operation. 598 static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape, 599 HloOpcode opcode, 600 HloInstruction* lhs, 601 HloInstruction* rhs); 602 603 // Creates a ternary instruction (three operands). 604 // Precondition: opcode must be a legitimate ternary operation. 605 static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape, 606 HloOpcode opcode, 607 HloInstruction* lhs, 608 HloInstruction* rhs, 609 HloInstruction* ehs); 610 611 // Creates a variadic instruction (variable number of operands). 612 // Precondition: opcode must be a legitimate variadic operation. 613 static std::unique_ptr<HloInstruction> CreateVariadic( 614 const Shape& shape, HloOpcode opcode, 615 absl::Span<HloInstruction* const> operands); 616 617 // Creates a map instruction, where the computation (given by the handle) is 618 // applied element-wise to every element in operands (across the operands, 619 // at a given index) 620 static std::unique_ptr<HloInstruction> CreateMap( 621 const Shape& shape, absl::Span<HloInstruction* const> operands, 622 HloComputation* map_computation); 623 624 // Creates a convolution op, where rhs is the convolutional filter 625 // and window describes how the filter is applied to lhs. 626 static std::unique_ptr<HloInstruction> CreateConvolve( 627 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 628 int64_t feature_group_count, int64_t batch_group_count, 629 const Window& window, 630 const ConvolutionDimensionNumbers& dimension_numbers, 631 const PrecisionConfig& precision_config); 632 633 // Creates an FFT op, of the type indicated by fft_type. 634 static std::unique_ptr<HloInstruction> CreateFft( 635 const Shape& shape, HloInstruction* operand, FftType fft_type, 636 absl::Span<const int64> fft_length); 637 638 // Creates a copy-start op, indicating whether this is a cross-program 639 // prefetch or not. 
640 static std::unique_ptr<HloInstruction> CreateCopyStart( 641 const Shape& shape, HloInstruction* operand, 642 bool is_cross_program_prefetch = false); 643 644 // Creates a compare op, performing the comparison specified in direction. 645 static std::unique_ptr<HloInstruction> CreateCompare( 646 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 647 Comparison::Direction direction, 648 absl::optional<Comparison::Type> type = absl::nullopt); 649 650 static std::unique_ptr<HloInstruction> CreateTriangularSolve( 651 const Shape& shape, HloInstruction* a, HloInstruction* b, 652 const TriangularSolveOptions& options); 653 654 static std::unique_ptr<HloInstruction> CreateCholesky( 655 const Shape& shape, HloInstruction* a, const CholeskyOptions& options); 656 657 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch 658 // dimensions specified in 'dimension_numbers'. 659 static std::unique_ptr<HloInstruction> CreateDot( 660 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 661 const DotDimensionNumbers& dimension_numbers, 662 const PrecisionConfig& precision_config); 663 664 // Creates a reduce-precision op, where operand is the data to reduce in 665 // precision, and exponent_bits and mantissa_bits describe the precision to 666 // reduce it to. 667 static std::unique_ptr<HloInstruction> CreateReducePrecision( 668 const Shape& shape, HloInstruction* operand, const int exponent_bits, 669 const int mantissa_bits); 670 671 // Creates an all-gather op, which concats the operands of all participants 672 // along all_gather_dimension. The replica_groups, channel_id, and 673 // use_global_device_ids arguments are identical to those in all-reduce, 674 // except that the order of the group members determines the concatenation 675 // order of inputs from different participants. 
676 static std::unique_ptr<HloInstruction> CreateAllGather( 677 const Shape& shape, absl::Span<HloInstruction* const> operands, 678 int64_t all_gather_dimension, 679 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 680 const absl::optional<int64>& channel_id, bool use_global_device_ids); 681 682 // Creates an all-gather-start op, which concats the operands of all 683 // participants 684 // along all_gather_dimension. The replica_groups, channel_id, and 685 // use_global_device_ids arguments are identical to those in all-reduce, 686 // except that the order of the group members determines the concatenation 687 // order of inputs from different participants. Needs to be used in 688 // conjunction of a AllGatherDone op that synchronizes and returns the result. 689 static std::unique_ptr<HloInstruction> CreateAllGatherStart( 690 const Shape& shape, absl::Span<HloInstruction* const> operands, 691 int64_t all_gather_dimension, 692 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 693 const absl::optional<int64>& channel_id, bool use_global_device_ids); 694 695 // Creates a cross replica reduction op. 696 // 697 // `reduction_computation`: the reduction function. 698 // 699 // `replica_groups`: each ReplicaGroup contains a list of replica id. If 700 // empty, all replicas belong to one group in the order of 0 - (n-1). 701 // Allreduce will be applied within subgroups. 702 // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, 703 // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. 704 // 705 // `channel_id`: for Allreduce nodes from different modules, if 706 // they have the same channel_id, they will be 'Allreduce'd. If 707 // empty, Allreduce will not be applied cross modules. 
708 static std::unique_ptr<HloInstruction> CreateAllReduce( 709 const Shape& shape, absl::Span<HloInstruction* const> operands, 710 HloComputation* reduce_computation, 711 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 712 const absl::optional<int64>& channel_id, bool use_global_device_ids); 713 714 // Creates a reduce-scatter operation which reduces its inputs across the 715 // given replica groups and then scatters the reduced data across the N 716 // participants. 717 static std::unique_ptr<HloInstruction> CreateReduceScatter( 718 const Shape& shape, absl::Span<HloInstruction* const> operands, 719 HloComputation* reduce_computation, 720 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 721 const absl::optional<int64>& channel_id, bool use_global_device_ids, 722 int64_t scatter_dimension); 723 724 // Creates an asynchronous cross replica reduction op. 725 // 726 // `reduction_computation`: the reduction function. 727 // 728 // `replica_groups`: each ReplicaGroup contains a list of replica id. If 729 // empty, all replicas belong to one group in the order of 0 - (n-1). 730 // Allreduce will be applied within subgroups. 731 // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, 732 // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. 733 // 734 // `channel_id`: for Allreduce nodes from different modules, if 735 // they have the same channel_id, they will be 'Allreduce'd. If 736 // empty, Allreduce will not be applied cross modules. 737 static std::unique_ptr<HloInstruction> CreateAllReduceStart( 738 const Shape& shape, absl::Span<HloInstruction* const> operands, 739 HloComputation* reduce_computation, 740 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 741 const absl::optional<int64>& channel_id, bool use_global_device_ids); 742 743 // An all-to-all op takes N array operands of the same shape and scatters them 744 // to N replicas. 
Each replica gathers the results into a tuple. 745 // 746 // For example, suppose we have 3 replicas, with replica i passing inputs 747 // [a_i, b_i, c_i] to its all-to-all op. Then the resulting tuples are 748 // 749 // replica 0: (a_0, a_1, a_2) 750 // replica 1: (b_0, b_1, b_2) 751 // replica 2: (c_0, c_1, c_2). 752 // 753 // If replica_groups is set, the op is sharded and the replicas are permuted. 754 // To explain by way of example, suppose we have replica_groups={{1,2},{3,0}}. 755 // Then each replica passes two operands, say [a_i, b_i], and the result is 756 // 757 // replica 0: (b_3, b_0) 758 // replica 1: (a_1, a_2) 759 // replica 2: (b_1, b_2) 760 // replica 3: (a_3, a_0). 761 // 762 // All replica groups must have the same number of elements, and the number of 763 // operands must be equal to the size of one replica group. Each replica must 764 // appear in exactly one group. 765 // 766 // Note that this instruction is different than the all-to-all op in 767 // xla_builder.h. The version in XlaBuilder takes one input and slices it, 768 // and then concatenates the results into a single array. This instruction 769 // takes multiple inputs and returns a tuple; it doesn't slice or concatenate. 770 // It is used to implement the higher-level instruction in XlaBuilder. 771 static std::unique_ptr<HloInstruction> CreateAllToAll( 772 const Shape& shape, absl::Span<HloInstruction* const> operands, 773 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 774 const absl::optional<int64>& channel_id, 775 const absl::optional<int64>& split_dimension = absl::nullopt); 776 777 // Creates a communication instruction that permutes data cross replicas. 778 // Data is sent/received according to the (source_replica_id, 779 // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a 780 // target_replica_id in any pair, the output on that replica is a tensor 781 // consists of 0(s) in `shape`. 
782 static std::unique_ptr<HloInstruction> CreateCollectivePermute( 783 const Shape& shape, HloInstruction* operand, 784 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs, 785 const absl::optional<int64_t>& channel_id); 786 787 static std::unique_ptr<HloInstruction> CreateCollectivePermute( 788 const Shape& shape, HloInstruction* input, HloInstruction* output, 789 HloInstruction* input_start_indices, HloInstruction* output_start_indices, 790 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs, 791 absl::Span<const std::vector<int64_t>> slice_sizes, 792 const absl::optional<int64_t>& channel_id); 793 794 // Creates a communication instruction that initiates the start of 795 // CollectivePermute. 796 static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart( 797 const Shape& shape, HloInstruction* operand, 798 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs, 799 const absl::optional<int64_t>& channel_id); 800 801 static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart( 802 const Shape& shape, HloInstruction* input, HloInstruction* output, 803 HloInstruction* input_start_indices, HloInstruction* output_start_indices, 804 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs, 805 absl::Span<const std::vector<int64_t>> slice_sizes, 806 const absl::optional<int64_t>& channel_id); 807 808 // Creates an instruction that returns a U32 replica ID. 809 static std::unique_ptr<HloInstruction> CreateReplicaId( 810 const Shape& shape = ShapeUtil::MakeShape(U32, {})); 811 812 // Creates an instruction that returns a U32 partition ID. 813 static std::unique_ptr<HloInstruction> CreatePartitionId( 814 const Shape& shape = ShapeUtil::MakeShape(U32, {})); 815 816 // Creates a conversion instruction, where operand is the data to convert and 817 // shape is the target shape for the conversion. 
818 static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape, 819 HloInstruction* operand); 820 821 // Creates a bitcast instruction, where operand is the data to 822 // convert and shape is the target shape for the conversion. 823 static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape, 824 HloInstruction* operand); 825 826 // Creates a bitcast conversion instruction, where operand is the data to 827 // convert and shape is the target shape for the conversion. 828 static std::unique_ptr<HloInstruction> CreateBitcastConvert( 829 const Shape& shape, HloInstruction* operand); 830 831 // Creates an infeed instruction, which reads data of the given shape from the 832 // Infeed interface of the device. infeed_shape is the shape of the data 833 // received from the infeed *not* the shape of the infeed instruction which 834 // is a tuple containing the infeed_shape and the TOKEN. 835 static std::unique_ptr<HloInstruction> CreateInfeed( 836 const Shape& infeed_shape, HloInstruction* token_operand, 837 const string& config); 838 839 // Creates an outfeed instruction, which outputs data. outfeed_shape is the 840 // shape of the data being outfed *not* the shape of the outfeed instruction 841 // which is a TOKEN. 842 static std::unique_ptr<HloInstruction> CreateOutfeed( 843 const Shape& outfeed_shape, HloInstruction* operand, 844 HloInstruction* token_operand, absl::string_view outfeed_config); 845 846 // Creates an asynchronous send instruction with the given channel id, which 847 // initiates sending the operand data to a unique receive instruction in 848 // another computation that has the same channel id. If is_host_transfer is 849 // true, then this Send operation transfers data to the host. 850 static std::unique_ptr<HloInstruction> CreateSend( 851 HloInstruction* operand, HloInstruction* token, int64_t channel_id, 852 bool is_host_transfer = false); 853 854 // Blocks until data transfer for the Send instruction (operand) is complete. 
855 // The operand must be kSend. 856 static std::unique_ptr<HloInstruction> CreateSendDone( 857 HloInstruction* operand, bool is_host_transfer = false); 858 859 // Creates an asynchronous receive instruction with the given channel id, 860 // which allocates resources to receive data of the given shape from a unique 861 // send instruction in another computation that has the same channel id. If 862 // is_host_transfer is true, then this Send operation transfers data from the 863 // host. 864 static std::unique_ptr<HloInstruction> CreateRecv( 865 const Shape& shape, HloInstruction* token, int64_t channel_id, 866 bool is_host_transfer = false); 867 868 // Blocks until data transfer for the Recv instruction (operand) is complete 869 // and returns the receive buffer. The operand must be kRecv. 870 static std::unique_ptr<HloInstruction> CreateRecvDone( 871 HloInstruction* operand, bool is_host_transfer = false); 872 873 // Creates a slice instruction, where the operand is sliced by the given 874 // start/limit indices. 875 static std::unique_ptr<HloInstruction> CreateSlice( 876 const Shape& shape, HloInstruction* operand, 877 absl::Span<const int64> start_indices, 878 absl::Span<const int64> limit_indices, absl::Span<const int64> strides); 879 880 // Creates a slice instruction, where the first operand is sliced by 881 // start indices specified in the second operand, and by size specified in 882 // 'slice_sizes'. 883 static std::unique_ptr<HloInstruction> CreateDynamicSlice( 884 const Shape& shape, HloInstruction* operand, 885 absl::Span<HloInstruction* const> start_indices, 886 absl::Span<const int64> slice_sizes); 887 888 // Creates a dynamic update slice instruction, which updates a slice 889 // of 'operand' with 'update' and 'start_indices'. 
890 static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice( 891 const Shape& shape, HloInstruction* operand, HloInstruction* update, 892 absl::Span<HloInstruction* const> start_indices); 893 894 // Creates a concatenate instruction, where the operands are concatenated on 895 // the provided dimension. 896 static std::unique_ptr<HloInstruction> CreateConcatenate( 897 const Shape& shape, absl::Span<HloInstruction* const> operands, 898 int64_t dimension); 899 900 // Creates a reduce instruction, where the computation (given by the handle) 901 // is applied successively to every element in operand. For example, let f be 902 // the function to apply, which takes 2 arguments, an accumulator and the 903 // current value. Let init be an initial value (which is normally chosen to be 904 // the identity element for f, e.g. 0 if f is addition). 905 // Then the reduce HLO will compute: 906 // f(f(init, value0), value1), ...) 907 static std::unique_ptr<HloInstruction> CreateReduce( 908 const Shape& shape, HloInstruction* operand, HloInstruction* init_value, 909 absl::Span<const int64> dimensions_to_reduce, 910 HloComputation* reduce_computation); 911 912 // A more general, multiple-argument version of the above. 913 // The function to apply, f, now takes N arguments: 914 // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., 915 // init_valueN], and returns an N-tuple. The performed computation is (for 916 // commutative and associative f operators) equivalent to: 917 // 918 // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0) 919 // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, 920 // ..., inputN.value1) 921 // ... 
922 static std::unique_ptr<HloInstruction> CreateReduce( 923 const Shape& shape, absl::Span<HloInstruction* const> operands, 924 absl::Span<HloInstruction* const> init_values, 925 absl::Span<const int64> dimensions_to_reduce, 926 HloComputation* reduce_computation); 927 928 // Creates a reduce-window instruction, where the computation (given 929 // by the handle) is applied window-wise at each valid window 930 // position in the operand. 931 static std::unique_ptr<HloInstruction> CreateReduceWindow( 932 const Shape& shape, HloInstruction* operand, HloInstruction* init_value, 933 const Window& window, HloComputation* reduce_computation); 934 935 // A more general, multiple-argument version of the above. 936 // The reduce_computation being applied,now takes N arguments: 937 // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., 938 // valueN], and returns an N-tuple. The operands and init_values now each 939 // contain a span of N input arrays and n initial values. 940 static std::unique_ptr<HloInstruction> CreateReduceWindow( 941 const Shape& shape, absl::Span<HloInstruction* const> operands, 942 absl::Span<HloInstruction* const> init_values, const Window& window, 943 HloComputation* reduce_computation); 944 945 // Creates a batch-norm-training instruction. 946 static std::unique_ptr<HloInstruction> CreateBatchNormTraining( 947 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 948 HloInstruction* offset, float epsilon, int64_t feature_index); 949 950 // Creates a batch-norm-inference instruction. 951 static std::unique_ptr<HloInstruction> CreateBatchNormInference( 952 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 953 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, 954 float epsilon, int64_t feature_index); 955 956 // Creates a batch-norm-grad instruction. 
957 static std::unique_ptr<HloInstruction> CreateBatchNormGrad( 958 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 959 HloInstruction* mean, HloInstruction* variance, 960 HloInstruction* grad_output, float epsilon, int64_t feature_index); 961 962 // Creates a scatter computation that scatters the `source` array to the 963 // selected indices of each window. 964 static std::unique_ptr<HloInstruction> CreateSelectAndScatter( 965 const Shape& shape, HloInstruction* operand, HloComputation* select, 966 const Window& window, HloInstruction* source, HloInstruction* init_value, 967 HloComputation* scatter); 968 969 // Creates a broadcast instruction. 970 static std::unique_ptr<HloInstruction> CreateBroadcast( 971 const Shape& shape, HloInstruction* operand, 972 absl::Span<const int64> broadcast_dimensions); 973 974 // Creates a sequence of instructions that performs an explicit broadcast of 975 // the operand to the target shape. 976 // 977 // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is 978 // returned as a unique_ptr for API consistency with other factory methods in 979 // this interface. 980 // 981 // TODO(b/72173833) Ideally HloComputations would always be present, and so 982 // the adder being passed by the caller would not be necessary. 983 static std::unique_ptr<HloInstruction> CreateBroadcastSequence( 984 const Shape& output_shape, HloInstruction* operand, 985 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>& 986 adder); 987 988 // Creates a pad instruction, where the operand is padded on the edges and 989 // between the elements with the given padding value. 990 static std::unique_ptr<HloInstruction> CreatePad( 991 const Shape& shape, HloInstruction* operand, 992 HloInstruction* padding_value, const PaddingConfig& padding_config); 993 994 // Creates a reshape instruction, where the operand is flattened row-major 995 // order and then reshaped to the given result shape. 
996 static std::unique_ptr<HloInstruction> CreateReshape( 997 const Shape& shape, HloInstruction* operand, 998 int64_t inferred_dimension = -1); 999 1000 // Creates a dynamic reshape instruction. Similar to reshape but dynamic 1001 // dimensions sizes are provided as additional variadic arguments. 1002 // 1003 // Precondition: dim_sizes.size() == shape.rank() 1004 static std::unique_ptr<HloInstruction> CreateDynamicReshape( 1005 const Shape& shape, HloInstruction* data_operand, 1006 absl::Span<HloInstruction* const> dim_sizes); 1007 1008 // Creates a transpose instruction which permutes the operand dimensions. 1009 static std::unique_ptr<HloInstruction> CreateTranspose( 1010 const Shape& shape, HloInstruction* operand, 1011 absl::Span<const int64> dimensions); 1012 1013 // Creates a n-ary sort op with a 'compare' computation which is used for 1014 // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, 1015 // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at 1016 // specific index positions which should be compared, and should return a 1017 // PRED. 'is_stable' specifies whether stable sorting is required. 1018 static std::unique_ptr<HloInstruction> CreateSort( 1019 const Shape& shape, int64_t dimension, 1020 absl::Span<HloInstruction* const> operands, HloComputation* compare, 1021 bool is_stable); 1022 1023 // Creates a while instruction, given a condition computation, a body 1024 // computation, and the initial value for the input of the computations. For 1025 // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 1026 // corresponds to the C code below. 
1027 // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 } 1028 static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape, 1029 HloComputation* condition, 1030 HloComputation* body, 1031 HloInstruction* init); 1032 1033 static std::unique_ptr<HloInstruction> CreateConditional( 1034 const Shape& shape, HloInstruction* pred, 1035 HloInstruction* true_computation_arg, HloComputation* true_computation, 1036 HloInstruction* false_computation_arg, HloComputation* false_computation); 1037 1038 static std::unique_ptr<HloInstruction> CreateConditional( 1039 const Shape& shape, HloInstruction* branch_index, 1040 absl::Span<HloComputation* const> branch_computations, 1041 absl::Span<HloInstruction* const> branch_computation_args); 1042 1043 static std::unique_ptr<HloInstruction> CreateGather( 1044 const Shape& shape, HloInstruction* operand, 1045 HloInstruction* start_indices, 1046 const GatherDimensionNumbers& gather_dim_numbers, 1047 absl::Span<const int64> slice_sizes, bool indices_are_sorted); 1048 1049 static std::unique_ptr<HloInstruction> CreateScatter( 1050 const Shape& shape, HloInstruction* operand, 1051 HloInstruction* scatter_indices, HloInstruction* updates, 1052 HloComputation* update_computation, 1053 const ScatterDimensionNumbers& scatter_dim_numbers, 1054 bool indices_are_sorted, bool unique_indices); 1055 1056 // Creates a kDomain instruction which delimits an HLO domain which have 1057 // the provided user and operand side metadata. 1058 static std::unique_ptr<HloInstruction> CreateDomain( 1059 const Shape& shape, HloInstruction* operand, 1060 std::unique_ptr<DomainMetadata> operand_side_metadata, 1061 std::unique_ptr<DomainMetadata> user_side_metadata); 1062 1063 // Creates a fusion instruction. A fusion instruction contains one or more 1064 // fused instructions forming an expression with a single root 1065 // "fused_root". Additional instructions can be added to the fusion 1066 // instruction with the method FuseInstruction. 
1067 static std::unique_ptr<HloInstruction> CreateFusion( 1068 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); 1069 1070 static std::unique_ptr<HloInstruction> CreateFusion( 1071 const Shape& shape, FusionKind fusion_kind, 1072 absl::Span<HloInstruction* const> operands, 1073 HloComputation* fusion_computation); 1074 1075 // Creates a call instruction that applies the given computation on the given 1076 // operands. "shape" is the resultant shape. 1077 static std::unique_ptr<HloInstruction> CreateCall( 1078 const Shape& shape, absl::Span<HloInstruction* const> operands, 1079 HloComputation* computation); 1080 1081 // Creates a custom call instruction that applies the given custom call target 1082 // to the given operands. "opaque" can be an arbitrary string with a 1083 // backend-specific interpretation. "shape" is the resultant shape. 1084 static std::unique_ptr<HloInstruction> CreateCustomCall( 1085 const Shape& shape, absl::Span<HloInstruction* const> operands, 1086 absl::string_view custom_call_target, string opaque = "", 1087 CustomCallApiVersion api_version = API_VERSION_ORIGINAL); 1088 1089 // Overload with a to_apply computation. 1090 static std::unique_ptr<HloInstruction> CreateCustomCall( 1091 const Shape& shape, absl::Span<HloInstruction* const> operands, 1092 HloComputation* to_apply, absl::string_view custom_call_target, 1093 string opaque = "", 1094 CustomCallApiVersion api_version = API_VERSION_ORIGINAL); 1095 1096 // Overload with multiple computations. The called computations can have 1097 // different function signatures. 1098 static std::unique_ptr<HloInstruction> CreateCustomCall( 1099 const Shape& shape, absl::Span<HloInstruction* const> operands, 1100 absl::Span<HloComputation* const> called_computations, 1101 absl::string_view custom_call_target, string opaque = "", 1102 CustomCallApiVersion api_version = API_VERSION_ORIGINAL); 1103 1104 // Overload which constrains the layouts of the operand and result. 
'shape' 1105 // and 'operand_shapes_with_layout' must have layouts. 1106 // 'operand_shapes_with_layout' must have a compatible element for each 1107 // operand. 1108 static std::unique_ptr<HloInstruction> CreateCustomCall( 1109 const Shape& shape, absl::Span<HloInstruction* const> operands, 1110 absl::string_view custom_call_target, 1111 absl::Span<const Shape> operand_shapes_with_layout, string opaque = "", 1112 CustomCallApiVersion api_version = API_VERSION_ORIGINAL); 1113 1114 // Creates a tuple instruction with the given elements. This is a convenience 1115 // wrapper around CreateVariadic. 1116 static std::unique_ptr<HloInstruction> CreateTuple( 1117 absl::Span<HloInstruction* const> elements); 1118 1119 // Creates a reverse instruction, which reverses the order of the elements 1120 // in the specified dimensions. 1121 static std::unique_ptr<HloInstruction> CreateReverse( 1122 const Shape& shape, HloInstruction* operand, 1123 absl::Span<const int64> dimensions); 1124 1125 // Creates a Afterall instruction used for joining or creating new values of 1126 // token type which thread through side-effecting operations. Operands must 1127 // all be tokens, and there must be at least one operand. 1128 static std::unique_ptr<HloInstruction> CreateAfterAll( 1129 absl::Span<HloInstruction* const> operands); 1130 1131 // Creates an AfterAll instruction which creates a token type out of thin air 1132 // (no operands). This is a separate method from CreateAfterAll to facility 1133 // the removal of operand-less AfterAll instructions. 1134 // TODO(b/110532604): Remove this capability of creating a token from nothing 1135 // when we plumb a primordial token from the entry computation. 
1136 static std::unique_ptr<HloInstruction> CreateToken(); 1137 1138 static std::unique_ptr<HloInstruction> CreateGetDimensionSize( 1139 const Shape& shape, HloInstruction* operand, int64_t dimension); 1140 1141 static std::unique_ptr<HloInstruction> CreateSetDimensionSize( 1142 const Shape& shape, HloInstruction* operand, HloInstruction* val, 1143 int64_t dimension); 1144 1145 static std::unique_ptr<HloInstruction> CreateAddDependency( 1146 HloInstruction* data_operand, HloInstruction* token_operand); 1147 1148 // Returns the opcode for this instruction. opcode()1149 HloOpcode opcode() const { return opcode_; } mutable_opcode()1150 HloOpcode* mutable_opcode() { return &opcode_; } 1151 1152 // Returns true if this instruction has a side effect, irrespective of whether 1153 // any called computations may contain an instruction with side effects. 1154 bool HasSideEffectNoRecurse() const; 1155 1156 // Returns true if this instruction has a side effect. An instruction has a 1157 // side effect if it uses certain opcodes or calls a computation with a side 1158 // effect. 1159 bool HasSideEffect() const; 1160 1161 // Returns the result shape of this instruction. 1162 const Shape& shape() const; 1163 1164 // Returns the (mutable) result shape of this instruction. mutable_shape()1165 Shape* mutable_shape() { return &shape_; } 1166 1167 // Returns the ith operand to this instruction. 1168 const HloInstruction* operand(int64_t i) const; 1169 1170 // Returns the ith operand to this instruction. 1171 HloInstruction* mutable_operand(int64_t i); 1172 1173 // Returns the number of operands to this instruction. operand_count()1174 int64 operand_count() const { return operands_.size(); } 1175 1176 // Returns the vector of operands of this instruction. 
1177 using InstructionVector = absl::InlinedVector<HloInstruction*, 2>; operands()1178 const InstructionVector& operands() const { return operands_; } 1179 1180 // Returns the vector of unique operands, in the same order they are found 1181 // within the operand vector. 1182 InstructionVector unique_operands() const; 1183 1184 // Returns the index of 'target' in the operands sequence. 1185 // Precondition: target must be an operand (or a fatal error will occur). 1186 int64 operand_index(const HloInstruction* target) const; 1187 1188 // Returns the number of users of this instruction. user_count()1189 int64 user_count() const { return users_.size(); } 1190 1191 // Returns the users of this instruction. users()1192 const std::vector<HloInstruction*>& users() const { return users_; } 1193 1194 // Returns the index of the user in the users() vector. 1195 // 1196 // Precondition: `user` is a user of the instruction. 1197 int64 UserId(HloInstruction* user); 1198 1199 // Returns true if this instruction is a user of 'instruction'. IsUserOf(const HloInstruction * instruction)1200 bool IsUserOf(const HloInstruction* instruction) const { 1201 return ContainsKey(instruction->user_map_, this); 1202 } 1203 1204 // Adds a control dependency from this instruction to the given 1205 // instruction. This instruction becomes a control predecessor of 1206 // 'instruction', and 'instruction' becomes a control successor of this 1207 // instruction. Returns an error status if either of the given instructions 1208 // does not belong to the same computation. 1209 // 1210 // This is used to enforce an additional ordering requirement that is not 1211 // captured by normal data dependencies, such as ordering among Send or Recv 1212 // operations to avoid deadlock. 1213 Status AddControlDependencyTo(HloInstruction* instruction); 1214 1215 // Removes a previously added control dependency from this instruction to 1216 // 'instruction'. 
1217 Status RemoveControlDependencyTo(HloInstruction* instruction); 1218 1219 // Drops all control predecessors and successors from this HLO instruction. 1220 Status DropAllControlDeps(); 1221 1222 // Copies the control predecessors and successors on this HLO instruction to 1223 // `inst`. Does not do a deep copy so this makes sense only if `inst` and 1224 // this HLO are in the same module. 1225 // 1226 // Depending on the use cases we see in practice, in the future we may 1227 // consider folding the logic here into Clone, CloneWithNewOperands and 1228 // ReplaceAllUsesWith by treating control dependencies like data dependencies. 1229 Status CopyAllControlDepsFrom(const HloInstruction* inst); 1230 1231 // Returns the set of control predecessors (successors) of this 1232 // instruction. Control predecessors (successors) must execute before (after) 1233 // the current instruction. control_predecessors()1234 const std::vector<HloInstruction*>& control_predecessors() const { 1235 return control_predecessors_; 1236 } control_successors()1237 const std::vector<HloInstruction*>& control_successors() const { 1238 return control_successors_; 1239 } 1240 1241 // Returns true if "other" performs the same computation as this instruction. 1242 bool Identical( 1243 const HloInstruction& other, 1244 const std::function<bool(const HloInstruction*, const HloInstruction*)>& 1245 eq_operands = std::equal_to<const HloInstruction*>(), 1246 const std::function<bool(const HloComputation*, const HloComputation*)>& 1247 eq_computations = std::equal_to<const HloComputation*>(), 1248 bool layout_sensitive = true) const { 1249 return IdenticalInternal(other, eq_operands, eq_computations, 1250 layout_sensitive, 1251 /*ignore_channel_id_values=*/false); 1252 } 1253 1254 // Same as Identical() but ignores channel ID value mismatches, as long as 1255 // both have channel IDs or neither has a channel ID. 
1256 bool IdenticalIgnoringChannelIdValues( 1257 const HloInstruction& other, 1258 const std::function<bool(const HloInstruction*, const HloInstruction*)>& 1259 eq_operands = std::equal_to<const HloInstruction*>(), 1260 const std::function<bool(const HloComputation*, const HloComputation*)>& 1261 eq_computations = std::equal_to<const HloComputation*>(), 1262 bool layout_sensitive = true) const { 1263 return IdenticalInternal(other, eq_operands, eq_computations, 1264 layout_sensitive, 1265 /*ignore_channel_id_values=*/true); 1266 } 1267 1268 // Generates a hash value of an HLO instruction. Hash considers 1269 // information on opcode, shape, operands, and typically a root instruction. 1270 // This function returns the same hash value for equivalent HLO instructions, 1271 // with respect to HloInstruction::Identical() method. 1272 // 1273 // Uses hash_operand function to compute hash values of its operands. 1274 // At the very top level, hash_operand should be non-recursive to prevent 1275 // non-termination. 1276 uint64 Hash( 1277 const std::function<uint64(const HloInstruction*)>& hash_operand) const; 1278 1279 // Calls the above method with non-recursive hash_operand function. 1280 uint64 Hash() const; 1281 1282 // Returns whether the instruction has a constant operand. 1283 bool HasConstantOperand() const; 1284 1285 // Replaces the use of this instruction in "user" with "new_producer". Note 1286 // that there might be multiple uses of this instruction in "user"; all will 1287 // be replaced. 1288 // 1289 // If user is a fusion instruction, this function will remove any duplicated 1290 // operands of it which could be created due to this replacement. 1291 Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); 1292 1293 // Same as ReplaceUseWith(), but new_producer can have a different shape. 
1294 Status ReplaceUseWithDifferentShape(HloInstruction* user, 1295 HloInstruction* new_producer); 1296 1297 // Replaces the specified operand with new_operand. The old and new operands 1298 // must have compatible shapes ignoring floating-point precision. 1299 // 1300 // This function does NOT remove duplicated operands even if this instruction 1301 // is a fusion, so that the existing operand numbers do not change. 1302 Status ReplaceOperandWith(int64_t operand_num, HloInstruction* new_operand); 1303 1304 // Same as ReplaceOperandWith(), but new_operand can have a different shape. 1305 Status ReplaceOperandWithDifferentShape(int64_t operand_num, 1306 HloInstruction* new_operand); 1307 1308 // Replaces all uses of this instruction with the new producer. If 1309 // new_producer is a user of this instruction then new_producer remains a use 1310 // of this instruction to avoid introducing cycles into the graph. 1311 // 1312 // If this instruction is the root of its computation, sets the computation's 1313 // root to new_producer. 1314 // 1315 // The new producer must have a compatible shape ignoring floating-point 1316 // precision. 1317 // 1318 // If a user is a fusion instruction, this function will remove any duplicated 1319 // operands of it which could be created due to this replacement. 1320 Status ReplaceAllUsesWith(HloInstruction* new_producer); 1321 1322 // Same as ReplaceAllUsesWith, but new_producer can have a different shape. 1323 Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); 1324 1325 // Same as ReplaceAllUsesWith, but only replace given set of users. 1326 Status ReplaceUsesWith(absl::Span<HloInstruction* const> users, 1327 HloInstruction* new_producer); 1328 Status ReplaceAllUsesWithDifferentShape( 1329 absl::Span<HloInstruction* const> users, HloInstruction* new_producer); 1330 1331 // Performs a postorder DFS visit using this node as the root. 
If 1332 // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when 1333 // complete. If ignore_control_predecessors is true, instructions only 1334 // reachable via control dependencies will not be visited, and the postorder 1335 // will not take control dependencies into account. It is as if the control 1336 // dependencies didn't exist in the graph at all. 1337 template <typename HloInstructionPtr> 1338 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor, 1339 bool call_finish_visit = true, 1340 bool ignore_control_predecessors = false); 1341 Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true, 1342 bool ignore_control_predecessors = false) const { 1343 return const_cast<HloInstruction*>(this)->Accept( 1344 visitor, call_finish_visit, ignore_control_predecessors); 1345 } 1346 1347 // Same as Accept() above, but the order of operand and control predecessor 1348 // visitation is determined by the given operand order; if compare(A, B) == 1349 // true, A is visited before B. 1350 using CompareFunction = 1351 std::function<bool(const HloInstruction*, const HloInstruction*)>; 1352 Status AcceptWithOperandOrder(DfsHloVisitor* visitor, 1353 const CompareFunction& operand_order, 1354 bool call_finish_visit = true); 1355 1356 // Visit this instruction and only this instruction with the given visitor. 1357 template <typename HloInstructionPtr> 1358 Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor); 1359 1360 // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. 1361 // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the 1362 // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. 
1363 std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() 1364 const; 1365 LatestNonGteAncestorAndIndex()1366 std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() { 1367 auto rv = 1368 const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex(); 1369 return {const_cast<HloInstruction*>(rv.first), rv.second}; 1370 } 1371 1372 // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. 1373 const HloInstruction* LatestNonGteAncestor() const; 1374 LatestNonGteAncestor()1375 HloInstruction* LatestNonGteAncestor() { 1376 return const_cast<HloInstruction*>( 1377 const_cast<const HloInstruction*>(this)->LatestNonGteAncestor()); 1378 } 1379 1380 // Returns true whether this instruction is effectively a bitcast. Currently, 1381 // this means it either is a bitcast, or it is a transpose that is effectively 1382 // a bitcast. 1383 bool IsEffectiveBitcast() const; 1384 1385 // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. 1386 // The setter should only be called by HloModule or HloComputation methods. 1387 // 1388 // Precondition: The instruction has a valid to_apply_ field. 1389 HloComputation* to_apply() const; 1390 void set_to_apply(HloComputation* to_apply); 1391 1392 // Gets/sets the while_condition or while_body HloComputation for While. The 1393 // setters should only be called by HloModule or HloComputation methods. 1394 // 1395 // Precondition: The instruction is a While instruction. 1396 HloComputation* while_condition() const; 1397 HloComputation* while_body() const; 1398 void set_while_condition(HloComputation* while_condition); 1399 void set_while_body(HloComputation* while_body); 1400 1401 HloInstruction* while_init() const; 1402 1403 // Gets/sets the true and false HloComputation for Conditional. 1404 // 1405 // Precondition: The instruction is a predicated Conditional instruction. 
  HloComputation* true_computation() const;
  HloComputation* false_computation() const;

  // Gets the branch HloComputations for Conditional.
  //
  // Precondition: The instruction is a Conditional instruction.
  const std::vector<HloComputation*>& branch_computations() const;
  int branch_count() const;
  HloComputation* branch_computation(int b) const;
  // Sets a branch HloComputation for Conditional.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a Conditional instruction.
  void set_branch_computation(int b, HloComputation* computation);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  string SignatureString() const;

  // Returns a debugging string that represents this instruction.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  //
  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
  // default, backing off on providing full information for very large strings,
  // or provide a different name for a ToString-like function that does that.
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Components of the ToString() representation:

  // Returns a string representation of the operand list.
  string OperandsToString(const HloPrintOptions& options) const;

  // Returns string representation of op-specific attributes.
  std::vector<string> ExtraAttributesToString(
      const HloPrintOptions& options) const;

  // As ToString, but returns a shorter string.
  string ToShortString() const;

  // Prints an instruction to a string.
  //
  // The canonical string representation needs to name operands and instruction
  // names in a consistent way. This is implemented through the
  // canonical_name_map.
  string ToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Returns a serialized representation of this instruction.
  virtual HloInstructionProto ToProto() const;

  // Returns a category for the HLO. This could be something like "convolution"
  // or "elementwise".
  virtual string ToCategory() const;

  // Returns a logging instruction, if the output of this instruction is
  // logged.
  //
  // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
  HloInstruction* tracing() const;
  void set_tracing(HloInstruction* trace_instruction);

  // Returns true if this instruction is fused, ie contained within a fusion
  // instruction.
  bool IsFused() const;

  bool IsLoopFusion() const;
  bool IsInputFusion() const;
  bool IsOutputFusion() const;
  bool IsCustomFusion() const;

  // Returns true if this instruction can be legally fused into a fusion
  // instruction.
  bool IsFusible() const;

  // Returns true if this is a kCustomCall with the given call target.
  bool IsCustomCall(absl::string_view target) const;

  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  // Returns the shared_ptr holding the sharding; may be null when unsharded.
  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }

  // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }
  // Returns the sharding unique device, if any.
sharding_unique_device()1498 absl::optional<int64> sharding_unique_device() const { 1499 if (sharding_ == nullptr) { 1500 return absl::optional<int64>(); 1501 } 1502 return sharding_->UniqueDevice(); 1503 } 1504 // Sets the sharding of this operator. Should only be called by HloModule or 1505 // HloComputation methods. set_sharding(const HloSharding & sharding)1506 void set_sharding(const HloSharding& sharding) { 1507 sharding_ = std::make_shared<const HloSharding>(sharding); 1508 } set_sharding(std::shared_ptr<const HloSharding> sharding)1509 void set_sharding(std::shared_ptr<const HloSharding> sharding) { 1510 sharding_ = std::move(sharding); 1511 } 1512 void set_single_sharding(const HloSharding& sharding); 1513 // Sets a sharding that assigns the current instruction to device. set_device_sharding(int64_t device)1514 void set_device_sharding(int64_t device) { 1515 set_single_sharding(HloSharding::AssignDevice(device)); 1516 } 1517 // Remove any sharding from this operator. clear_sharding()1518 void clear_sharding() { sharding_ = nullptr; } 1519 // Return true if this operator has a sharding assigned. has_sharding()1520 bool has_sharding() const { return sharding_ != nullptr; } 1521 // Checks whether the instruction has compatible sharding with the other 1522 // instruction. has_compatible_sharding(const HloInstruction * other)1523 bool has_compatible_sharding(const HloInstruction* other) const { 1524 if (!has_sharding()) { 1525 return !other->has_sharding(); 1526 } 1527 return other->has_sharding() ? sharding() == other->sharding() : false; 1528 } 1529 1530 // When creating a new instruction which either replaces, or shifts up (kCopy 1531 // insertion case), another instruction, we need to make sure the certain 1532 // properties of the new instruction are copied into the derived one. As of 1533 // today, the metadata and sharding will be propagated to the derived 1534 // instruction. 
1535 void SetupDerivedInstruction(HloInstruction* derived_instruction) const; 1536 1537 // Clones the HLO instruction. The clone will have the same opcode, shape, and 1538 // operands. After creation the clone has no uses. "this" (the instruction 1539 // cloned from) is not changed. Suffix is the string to append to the name of 1540 // the instruction to form the name of the cloned instruction. 1541 // Ignores the control predecessors and successors of this HLO instruction. 1542 std::unique_ptr<HloInstruction> Clone( 1543 const string& suffix = "clone", HloCloneContext* context = nullptr) const; 1544 1545 // Clones the HLO instruction as above but with new shape. 1546 std::unique_ptr<HloInstruction> CloneWithNewShape( 1547 const Shape& shape, const string& suffix = "clone", 1548 HloCloneContext* context = nullptr) const; 1549 1550 // Clones the HLO instruction as above but with new shape and operands. 1551 std::unique_ptr<HloInstruction> CloneWithNewOperands( 1552 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1553 HloCloneContext* context = nullptr) const; 1554 1555 // Returns the computations this instruction directly calls (if any). called_computations()1556 const std::vector<HloComputation*>& called_computations() const { 1557 return called_computations_; 1558 } 1559 1560 // Replaces all called computations based on a map function. This is needed 1561 // when we clone hlo_computations and want to let the instructions to point 1562 // to the newly cloned nodes. ReplaceCalledComputations(std::function<HloComputation * (HloComputation *)> map_function)1563 void ReplaceCalledComputations( 1564 std::function<HloComputation*(HloComputation*)> map_function) { 1565 for (int64_t i = 0; i < called_computations_.size(); ++i) { 1566 called_computations_[i] = map_function(called_computations_[i]); 1567 } 1568 } 1569 1570 // Clears out the called computations. 
  //
  // This is, in particular, necessary when inlining function bodies into their
  // caller. If there were side-effecting operations in the called
  // computations, the call itself is considered side-effecting and thus cannot
  // be removed. By clearing out the computations, we reflect the fact that all
  // side-effecting properties have been reflected in the caller, and make the
  // call HLO removable.
  virtual void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
  // to compute the output at index {i_0,i_1,...,i_n}, the only element
  // required from the operand (if any) is the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in the
  // worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64_t operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if the given opcode is always elementwise.
  static bool IsOpElementwise(HloOpcode opcode);

  // Returns true if this is a cross module all-reduce instruction.
  bool IsCrossModuleAllReduce() const;

  // Returns true if this is a cross-replica all-reduce instruction.
  bool IsCrossReplicaAllReduce() const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64_t i) const {
    return OperandElementUse(i) == UseKind::kReuse;
  }

  // Returns the indices that the given operand appear in the operand list of
  // this instruction. Note that an instruction can use the same operand
  // multiple times.
  absl::InlinedVector<int64, 4> OperandIndices(
      const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, return the
  // input indices of the deleted dimensions and the output indices of the
  // inserted dimensions.
  //
  // Precondition: this op must be a reshape.
  std::tuple<bool, std::vector<int64>, std::vector<int64>>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;

  // Gets the string identifier for this instruction.
  const string& name() const { return name_; }

  // Sets the string identifier for this instruction. Name will be sanitized to
  // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
  void SetAndSanitizeName(const string& name) {
    name_ = NameUniquer::GetSanitizedName(name);
  }

  // Use the given NameUniquer to select a unique name for the instruction
  // based on the instruction's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Clear the unique ID of the instruction so that it can be re-assigned, such
  // as for the purpose of compacting the instruction unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // Set the unique id for this instruction to "id". CHECK-fails if the
  // instruction already has an id or if "id" is negative.
  void SetUniqueId(int id) {
    CHECK_EQ(unique_id_, -1);  // Should not be assigned already
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Return the unique ID assigned to this node via SetUniqueId (or -1
  // if no id has been assigned yet).
  int unique_id() const { return unique_id_; }

  // Returns the backend-specific configuration for how a backend should
  // compile this HLO. The meaning of the field is backend specific. Not for
  // use before or during general HLO optimization, since HLO optimizations do
  // not preserve this field and they cannot interpret it due to its meaning
  // being backend specific. Except for CustomCall, where this field is
  // preserved and no general HLO optimization needs to interpret it.
  //
  // ConfigProto should be a protobuf Message type.
  template <typename ConfigProto>
  StatusOr<ConfigProto> backend_config() const {
    ConfigProto proto;
    TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
    return std::move(proto);
  }
  // Serializes `proto` and stores it as the raw backend config string.
  Status set_backend_config(const tensorflow::protobuf::Message& proto);

  // Replaces the frontend attributes wholesale.
  void set_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_ = std::move(frontend_attributes);
  }

  // Merges the given attributes into the existing map. Note: map::insert does
  // not overwrite values for keys that are already present.
  void add_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_.mutable_map()->insert(
        frontend_attributes.map().begin(), frontend_attributes.map().end());
  }

  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Getter/setter for raw JSON-encoded backend config. Prefer the
  // functions above that deal in proto Messages where possible.
  const string& raw_backend_config_string() const { return backend_config_; }
  void set_raw_backend_config_string(string config_str) {
    backend_config_ = std::move(config_str);
  }

  // True until set_default_config() is called; meaning of the flag is
  // determined by users of the backend config.
  bool is_default_config() const { return is_default_config_; }
  void set_default_config() { is_default_config_ = true; }

  // Returns a string representation of a proto in the format used by
  // raw_backend_config_string.
  //
  // This is morally equivalent to:
  //
  //   HloInstruction instr;
  //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
  //   return instr.raw_backend_config_string();
  //
  static StatusOr<string> BackendConfigToRawString(
      const tensorflow::protobuf::Message& proto);

  // Returns the information used to tell the implementation information about
  // what sort of precision is requested. The meaning of the field is backend
  // specific. At the moment, it is only supported for kConvolution and kDot.
  // Transformations on one kDot or kConvolution to another will preserve this
  // information. Transformations to other HLOs will not preserve this
  // information but it is presumed that the alternate lowering is strictly
  // superior.
  // Precondition: opcode must be kConvolution or kDot.
  const PrecisionConfig& precision_config() const;
  PrecisionConfig* mutable_precision_config();

  // Sets the debug metadata for this instruction, excluding creation_pass_id,
  // which should never be copied anywhere.
  // Replaces the metadata while deliberately preserving the existing
  // creation_pass_id (it must never be copied from another instruction).
  void set_metadata(const OpMetadata& metadata) {
    int64_t creation_pass_id = metadata_.creation_pass_id();
    metadata_ = metadata;
    metadata_.set_creation_pass_id(creation_pass_id);
  }

  void set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes) {
    metadata_.set_size_of_generated_code_in_bytes(code_size_in_bytes);
  }
  void set_size_of_memory_working_set_in_bytes(
      int64_t working_set_size_in_bytes) {
    metadata_.set_size_of_memory_working_set_in_bytes(
        working_set_size_in_bytes);
  }
  void set_creation_pass_id(int64_t pass_id) {
    metadata_.set_creation_pass_id(pass_id);
  }
  void set_metadata_op_name(const std::string& name) {
    metadata_.set_op_name(name);
  }
  void set_logical_creation_pass_id(int64_t pass_id) {
    metadata_.set_logical_creation_pass_id(pass_id);
  }
  const OpMetadata& metadata() const { return metadata_; }

  // Set/get the computation containing this instruction. set_parent should
  // only be called by HloComputation methods which add/remove instructions to
  // computations.
  void set_parent(HloComputation* computation) { parent_ = computation; }
  const HloComputation* parent() const { return parent_; }
  HloComputation* parent() { return parent_; }

  // Returns the module for this instruction.
  HloModule* GetModule() const;

  // Get/Set the number of partitions per outer dimension (in order, starting
  // with outer-most dimension first). Currently used by the parallel cpu
  // backend to partition HLOs into parallel tasks.
  //
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64>& outer_dimension_partitions);

  // Old methods kept for smooth subclassing transition BEGIN.
  // TODO(b/80131774): Remove this code.

  // Delegates to HloBatchNormInstruction::feature_index.
  int64 feature_index() const;

  // Delegates to HloBatchNormInstruction::epsilon.
  float epsilon() const;

  // Delegates to HloFftInstruction::fft_type.
  FftType fft_type() const;

  // Delegates to HloFftInstruction::fft_length.
  const std::vector<int64>& fft_length() const;

  // Delegates to HloChannelInstruction::channel_id.
  absl::optional<int64> channel_id() const;
  void set_channel_id(const absl::optional<int64>& channel_id);

  // Returns the dimension sizes or numbers associated with this instruction.
  // Base implementations are crash stubs; subclasses that carry dimension
  // data override them.
  virtual const std::vector<int64>& dimensions() const {
    LOG(FATAL) << "Unimplemented method.";
  }
  virtual int64 dimensions(int64_t index) const {
    LOG(FATAL) << "Unimplemented method.";
  }
  virtual std::vector<int64>* mutable_dimensions() {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Delegates to HloConcatenateInstruction::concatenate_dimension.
  int64 concatenate_dimension() const;

  // Delegates to HloGetDimensionSizeInstruction::dimension.
  int64 dimension() const;

  // Delegates to HloReshapeInstruction::inferred_dimension.
  int64 inferred_dimension() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Delegates to HloSliceInstruction::slice_start.
  int64 slice_starts(int64_t dimension) const;
  const std::vector<int64>& slice_starts() const;
  std::vector<int64>* mutable_slice_starts();

  // Delegates to HloSliceInstruction::slice_limits.
  int64 slice_limits(int64_t dimension) const;
  const std::vector<int64>& slice_limits() const;
  std::vector<int64>* mutable_slice_limits();

  // Delegates to HloSliceInstruction::slice_strides.
  int64 slice_strides(int64_t dimension) const;
  const std::vector<int64>& slice_strides() const;
  std::vector<int64>* mutable_slice_strides();

  // Returns the literal associated with this instruction.
  const Literal& literal() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Delegate to HloConstantInstruction::RelayoutConstant.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  // Delegates to HloTraceInstruction::TracingTag.
  string TracingTag() const;

  // Delegates to HloFusionInstruction::AddFusionOperand.
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Delegates to HloFusionInstruction::MergeFusionInstruction.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
  void MergeFusionInstructionIntoMultiOutput(
      HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::FuseInstruction.
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::fused_instructions_computation.
  HloComputation* fused_instructions_computation() const;

  // Delegates to HloFusionInstruction::fused_expression_root.
  HloInstruction* fused_expression_root() const;

  // Delegates to HloFusionInstruction::fused_instructions.
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;

  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Delegates to HloFusionInstruction::fused_instruction_count.
  int64 fused_instruction_count() const;

  // Delegates to HloFusionInstruction::fused_parameter.
  HloInstruction* fused_parameter(int64_t parameter_number) const;

  // Delegates to HloFusionInstruction::fused_parameters.
  const std::vector<HloInstruction*>& fused_parameters() const;

  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  // NOTE(review): the top-level const on the bool return value is meaningless;
  // removing it changes the declared function type, so it must be fixed
  // together with the out-of-line definition.
  const bool IsMultiOutputFusion() const;

  // Delegates to HloFusionInstruction::fusion_kind.
  FusionKind fusion_kind() const;

  // Delegates to HloFusionInstruction::set_fusion_kind.
  void set_fusion_kind(FusionKind kind);

  // Delegates to HloRngInstruction::random_distribution.
  RandomDistribution random_distribution() const;

  // Delegates to HloParameterInstruction::parameter_number.
  int64 parameter_number() const;

  // Delegates to
  // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
  void set_parameter_replicated_at_leaf_buffers(
      absl::Span<const bool> parameter_replicated_at_leaf_buffers);
  void set_parameter_replicated_at_leaf_buffers(
      const std::vector<bool>& parameter_replicated_at_leaf_buffers);

  // Delegates to
  // HloParameterInstruction::parameter_replicated_at_leaf_buffers.
  const absl::optional<std::vector<bool>>&
  parameter_replicated_at_leaf_buffers() const;

  // Delegates to HloGetTupleElementInstruction::tuple_index.
  int64 tuple_index() const;

  // Delegates to HloGetTupleElementInstruction::set_tuple_index.
  void set_tuple_index(int64_t new_tuple_index);

  // Delegates to HloReducePrecisionInstruction::exponent_bits.
  int32 exponent_bits() const;

  // Delegates to HloReducePrecisionInstruction::mantissa_bits.
  int32 mantissa_bits() const;

  // Delegates to HloInfeedInstruction::infeed_config.
  string infeed_config() const;

  // Delegates to HloInfeedInstruction::set_infeed_config.
  void set_infeed_config(const string& config);

  // Returns the config for the Outfeed instruction.
  const string& outfeed_config() const;

  // Delegates to HloOutfeedInstruction::set_outfeed_config.
  void set_outfeed_config(const string& config);

  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const;

  // Returns the mutable shape for the Outfeed instruction.
  Shape* mutable_outfeed_shape();

  // Delegates to HloCollectiveInstruction::replica_groups.
  const std::vector<ReplicaGroup>& replica_groups() const;

  // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
  const std::vector<std::pair<int64, int64>>& source_target_pairs() const;

  // Returns data on the window in a windowed operation such as
  // convolution. Base implementation is a crash stub; windowed-op subclasses
  // override it.
  virtual const Window& window() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Sets the window data in a windowed operation such as convolution.
  virtual void set_window(const Window& window) {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Returns the unique_indices field. Base implementation is a crash stub.
  virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }

  // Returns data on the dimension numbers used for a convolution operation,
  // which may be a kConvolution instruction or a kCustomCall that implements a
  // convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;

  // Sets the convolution dimension numbers on this instruction. In general
  // you shouldn't need to call this; instead, specify the convolution
  // dimension numbers when you create the instruction.
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums);

  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64 feature_group_count() const;

  void set_feature_group_count(int64_t feature_group_count);

  // The number of batch groups. Must be a divisor of the input batch
  // dimension.
  int64 batch_group_count() const;

  void set_batch_group_count(int64_t batch_group_count);

  // Delegates to HloSelectAndScatterInstruction::select.
  HloComputation* select() const;

  // Delegates to HloSelectAndScatterInstruction::scatter.
  HloComputation* scatter() const;

  // Delegates to HloSelectAndScatterInstruction::set_select.
  void set_select(HloComputation* computation);

  // Delegates to HloSelectAndScatterInstruction::set_scatter.
  void set_scatter(HloComputation* computation);

  // Delegates to HloCustomCallInstruction::custom_call_target.
  const string& custom_call_target() const;

  // Delegates to HloPadInstruction::padding_config.
  const PaddingConfig& padding_config() const;
  PaddingConfig* mutable_padding_config();

  // Delegates to HloConvolutionInstruction::padding_type.
  PaddingType padding_type() const;

  // Delegates to HloDynamicSliceInstruction::slice_sizes.
  int64 slice_sizes(int64_t dimension) const;

  // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
  const std::vector<int64>& dynamic_slice_sizes() const;

  // Delegates to HloCollectivePermuteInstruction::dynamic_slice_sizes.
  const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const;

  // Delegates to HloGatherInstruction::gather_dimension_numbers.
  const GatherDimensionNumbers& gather_dimension_numbers() const;
  // Delegates to HloGatherInstruction::gather_slice_sizes.
  absl::Span<const int64> gather_slice_sizes() const;

  // Delegates to HloScatterInstruction::scatter_dimension_numbers().
  const ScatterDimensionNumbers& scatter_dimension_numbers() const;

  // Delegates to HloDotInstruction::dot_dimension_numbers().
  const DotDimensionNumbers& dot_dimension_numbers() const;

  // Delegates to HloDomainInstruction::operand_side_metadata().
  const DomainMetadata& operand_side_metadata() const;

  // Delegates to HloDomainInstruction::user_side_metadata().
  const DomainMetadata& user_side_metadata() const;

  // Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
  bool is_cross_program_prefetch() const;

  // Delegates to HloCompareInstruction::direction().
  ComparisonDirection comparison_direction() const;

  // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
  const TriangularSolveOptions& triangular_solve_options() const;

  // Delegates to HloCholeskyInstruction::cholesky_options().
  const CholeskyOptions& cholesky_options() const;

  // Appends operand to the list of operands and adds this instruction as a
  // user of the operand.
2035 void AppendOperand(HloInstruction* operand); 2036 2037 // Old methods kept for smooth subclassing transition END. 2038 2039 protected: 2040 // Indicates how an instruction uses a value (such as an operand). 2041 // 2042 // Does it (a) not use it, (b) use it, or (c) use it multiple times? 2043 // 2044 // In the kUse case (i.e. (b)) we may either (i) use the value elementwise, or 2045 // (ii) use it after having permuted it somehow, e.g. through a reshape. If 2046 // the use is a permuting use, we set permutation_instr to the instruction 2047 // that did the permuting. 2048 struct UseKind { 2049 enum Kind { kReuse, kUse, kNoUse }; 2050 2051 // Creates a UseKind that represents a use that permutes an instruction's 2052 // elements according to the given instruction. PermutingUseKind2053 static UseKind Permuting(const HloInstruction* permutation_instr) { 2054 UseKind k(kUse); 2055 k.permutation_instr = permutation_instr; 2056 return k; 2057 } 2058 UseKindUseKind2059 UseKind(Kind kind) // NOLINT intentionally nonexplicit 2060 : kind(kind), permutation_instr(nullptr) {} 2061 2062 bool friend operator==(UseKind a, Kind b) { return a.kind == b; } 2063 bool friend operator==(Kind a, UseKind b) { return b == a; } 2064 2065 Kind kind; 2066 const HloInstruction* permutation_instr; 2067 }; 2068 2069 // Helper class for computing OperandElementUse for kFusion. 2070 class FusionReusesParamElements; 2071 2072 // Internal constructor for a given opcode/shape, other fields must be filled 2073 // by factory methods. 2074 HloInstruction(HloOpcode opcode, const Shape& shape); 2075 RemoveAllOperands()2076 void RemoveAllOperands() { operands_.clear(); } 2077 RemoveOperandAt(int index)2078 void RemoveOperandAt(int index) { 2079 operands_.erase(operands_.begin() + index); 2080 } 2081 2082 // Removes a list of operands with the given indices in ascending order. 
2083 void RemoveOperandsAtAscendingIndices( 2084 absl::Span<const int> ascending_indices); 2085 AppendComputation(HloComputation * computation)2086 void AppendComputation(HloComputation* computation) { 2087 called_computations_.push_back(computation); 2088 } 2089 DetachFrom(HloInstruction * usee)2090 void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } 2091 set_called_computation(int index,HloComputation * computation)2092 void set_called_computation(int index, HloComputation* computation) { 2093 called_computations_[index] = computation; 2094 } 2095 // Indices of computations in called_computations_ for instructions which call 2096 // multiple computations. 2097 enum { 2098 // kWhile computations. 2099 kBodyComputationIndex = 0, 2100 kConditionComputationIndex = 1, 2101 2102 // kSelectAndScatter computations. 2103 kSelectComputationIndex = 0, 2104 kScatterComputationIndex = 1, 2105 2106 // kConditional computations. 2107 kTrueComputationIndex = 0, 2108 kFalseComputationIndex = 1, 2109 }; 2110 2111 private: 2112 friend class HloComputation; 2113 2114 bool IdenticalInternal( 2115 const HloInstruction& other, 2116 const std::function<bool(const HloInstruction*, const HloInstruction*)>& 2117 eq_operands, 2118 const std::function<bool(const HloComputation*, const HloComputation*)>& 2119 eq_computations, 2120 bool layout_sensitive, bool ignore_channel_id_values) const; 2121 2122 // Implementation for non-common logic of CloneWithNewOperands. CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context)2123 virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 2124 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 2125 HloCloneContext* context) const { 2126 // TODO(b/80131774): This should be pure virtual. 2127 LOG(FATAL) << "Unimplemented method."; 2128 } 2129 2130 // Implementation for non-common logic of ExtraAttributesToString. 
  // Base implementation contributes no extra attributes; subclasses with
  // additional state override this.
  virtual std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const {
    return {};
  }

  // Implementation for IsElementwise if operand_idx is nullopt and for
  // IsElementwiseOnOperand if otherwise.
  //
  // NOTE: For all instructions other than kFusion, being elementwise on one of
  // the operands is equivalent to being elementwise on all the operands.
  virtual bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const;

  // Prints an operand to a string. Accessed by friend class HloInstruction.
  virtual string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // See comments on Identical().
  virtual bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const;

  // Generates a hash value specific to a particular type of an instruction.
  // This function typically considers the inner root instruction.
  virtual uint64 InnerHash() const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Returns how this instruction uses elements of its operand at operand_num.
  // NOTE(review): the `int64_t` here is inconsistent with the `int64` alias
  // used elsewhere in this header — consider unifying (verify the matching
  // definition in hlo_instruction.cc first).
  UseKind OperandElementUse(int64_t operand_num) const;

  // Helper for implementing backend_config(). Parses backend_config_ into the
  // given proto.
  Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;

  // Mark this instruction as dead.
  // Accessed by friend class HloComputation (the only friend declared in this
  // class; the original comment's "HloInstruction" appears to be a typo).
  void MarkAsDead() { marked_as_dead_ = true; }

  // Has this instruction been marked as dead? Accessed by friend class
  // HloComputation (see the note on MarkAsDead above).
  bool IsMarkedAsDead() const { return marked_as_dead_; }

  int unique_id_;  // Unique to this HloInstruction within a HloModule

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  // Note that the order of the instructions in the vector influences the order
  // computed in HloComputation::ComputeInstructionPostOrder, which may
  // influence the result of the compilation by changing the scheduling. We are
  // not sure if it matters.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is an
  // operand. The vector users_ and the map user_map_ contain identical
  // members. The map enables fast membership testing and the vector enables
  // fast, stable iteration. The value in the map contains the index of the
  // instruction in the vector, which enables fast removal.
  std::vector<HloInstruction*> users_;
  absl::flat_hash_map<const HloInstruction*, int64> user_map_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Result shape of this instruction.
  Shape shape_;

  // The sharding, if one exists.
  // Uses std::shared_ptr to allow reuse of the same sharding object between
  // HloInstructions and other components as HloSharding can be very large for
  // many element tuples.
2220 std::shared_ptr<const HloSharding> sharding_; 2221 2222 // Computations called by this instruction. 2223 std::vector<HloComputation*> called_computations_; 2224 2225 // A trace instruction that consumes this instruction. 2226 // 2227 // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as 2228 // an operand. 2229 HloInstruction* trace_instruction_ = nullptr; 2230 2231 // The backend-specific configuration for how a backend should compile this 2232 // HLO. See the documentation on backend_config(). 2233 string backend_config_; 2234 2235 // Attributes passed from the frontend to give hints to the backend about 2236 // how to compile this HLO. 2237 // HLO -> HLO transforms are expected to preserve these attributes on a 2238 // "best effort" basis only. 2239 // For example: 2240 // x = const(10, frontend_attributes={x} 2241 // y = const(10, frontend_attributes={y} 2242 // z = add(x,y), frontend_attributes={y} 2243 // Could be simplified to: 2244 // z' = const(20), frontend_attributes={?} 2245 FrontendAttributes frontend_attributes_; 2246 2247 // This field is assigned to true when backend_config_ is assigned to 2248 // a default configuration. 2249 bool is_default_config_ = false; 2250 2251 // True if this instruction has already been detached from its user and 2252 // operands. 2253 bool cleaned_up_ = false; 2254 2255 // String identifier for instruction. 2256 string name_; 2257 2258 // Metadata for debugging. 2259 OpMetadata metadata_; 2260 2261 // The number of partitions per outer dimension (listed in order from 2262 // outer-most dimension first). 2263 std::vector<int64> outer_dimension_partitions_; 2264 2265 // Intrusive flag used by HloComputation, whether this instruction has 2266 // been marked as dead. 2267 bool marked_as_dead_; 2268 2269 TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); 2270 }; 2271 2272 // Explicit instantiations in hlo_instruction.cc. 
extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);

// Conversions between a fusion kind and its text form.
string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string FrontendAttributesToString(
    const FrontendAttributes& frontend_attributes);
string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums);
string ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups);

// Parsing directions of the stringification functions above; each returns an
// error Status when the name does not match a known value.
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
StatusOr<CustomCallSchedule> StringToCustomCallSchedule(absl::string_view name);
StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
    absl::string_view name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo.
// Exception: null pointer values compare less than non-null.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};

// NOTE(review): std::map below requires <map>, but this header's visible
// include list has <set> only — it currently relies on a transitive include;
// consider adding #include <map> at the top of the file.
template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_