/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the
// user has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

string PrintName(const string& name, bool print_ids);

// A bunch of switches that control how the HLO text should be printed.
class HloPrintOptions {
 public:
  enum class PrintSubcomputationMode {
    kOff,         // Do not print anything about subcomputations.
    kNameOnly,    // Only print the name of subcomputations.
    kFullBodies,  // Print the full bodies of subcomputations.
  };

  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
        print_metadata_(true),
        print_backend_config_(true),
        compact_operands_(false),
        include_layout_in_shapes_(true),
        print_result_shape_(true),
        print_operand_shape_(true),
        print_operand_names_(true),
        print_program_shape_(true),
        print_percent_(true),
        print_control_dependencies_(true),
        canonicalize_instruction_names_(false),
        indent_amount_(0),
        is_in_nested_computation_(false),
        print_ids_(true),
        canonicalize_computations_(false),
        print_extra_attributes_(true) {}

  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(false)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // Options to produce a fingerprint of an HLO.
  static HloPrintOptions Fingerprint() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(true)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true)
        .set_print_ids(false)
        .set_canonicalize_computations(true);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, result shapes will be printed.
  HloPrintOptions& set_print_result_shape(bool value) {
    print_result_shape_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, all printed names include unique identifiers.
  HloPrintOptions& set_print_ids(bool value) {
    print_ids_ = value;
    return *this;
  }

  // If true, the HLO includes its attributes.
  HloPrintOptions& set_print_extra_attributes(bool value) {
    print_extra_attributes_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out (note that in this
  // case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, include the layout in any shapes that are printed (instruction
  // and operands).
  HloPrintOptions& set_include_layout_in_shapes(bool value) {
    include_layout_in_shapes_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' names. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // If true, canonicalizes computations, sorting by computations' names.
  HloPrintOptions& set_canonicalize_computations(bool value) {
    canonicalize_computations_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  // Instructions are selected for printing by a predicate function
  // (`set_print_instruction`). We also print their surrounding instructions
  // for ease of reading. We print `leading_and_trailing_instructions_number`
  // instructions before and after the qualified ones inside a computation.
  HloPrintOptions& set_leading_and_trailing_instructions_number(int value) {
    leading_and_trailing_instructions_number_ = value;
    return *this;
  }
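
  // Illustrative usage (a sketch, not additional API): the setters above
  // return *this, so option sets are typically built by chaining them onto a
  // default-constructed object or one of the factory presets, e.g.
  //
  //   HloPrintOptions options = HloPrintOptions::ShortParsable()
  //                                 .set_print_metadata(true)
  //                                 .set_indent_amount(2);
  //   string text = instr->ToString(options);  // `instr`: an HloInstruction*.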
  // A callback which takes an HloInstruction*, its string representation,
  // the indentation level of the resulting block, and a bool variable
  // indicating whether the instruction is root or not. The return value is a
  // string which is used for this instruction during printing.
  using FormatInstructionFunc =
      std::function<string(const HloInstruction*, const string&, int, bool)>;

  HloPrintOptions& set_format_instruction(FormatInstructionFunc callback) {
    format_instruction_ = callback;
    return *this;
  }

  using HloInstructionPredicate = std::function<bool(const HloInstruction*)>;

  // A callback which takes an HloInstruction* and returns whether it should
  // be printed or not.
  HloPrintOptions& set_print_instruction(HloInstructionPredicate callback) {
    print_instruction_ = callback;
    return *this;
  }

  using HloComputationPredicate = std::function<bool(const HloComputation*)>;

  // A callback which takes an HloComputation* and returns whether it should
  // be printed or not.
  HloPrintOptions& set_print_computation(HloComputationPredicate callback) {
    print_computation_ = callback;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool compact_operands() const { return compact_operands_; }
  bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
  bool print_result_shape() const { return print_result_shape_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  bool print_ids() const { return print_ids_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool print_extra_attributes() const { return print_extra_attributes_; }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  bool canonicalize_computations() const { return canonicalize_computations_; }
  int indent_amount() const { return indent_amount_; }
  bool is_in_nested_computation() const { return is_in_nested_computation_; }
  int leading_and_trailing_instructions_number() const {
    return leading_and_trailing_instructions_number_;
  }
  string format_instruction(const HloInstruction* instr,
                            const string& instr_name, int indent,
                            bool is_root) const {
    return format_instruction_(instr, instr_name, indent, is_root);
  }
  bool print_instruction(const HloInstruction* instr) const {
    return print_instruction_(instr);
  }
  bool print_computation(const HloComputation* comp) const {
    return print_computation_(comp);
  }

 private:
  bool print_large_constants_;
  PrintSubcomputationMode print_subcomputation_mode_;
  bool print_metadata_;
  bool print_backend_config_;
  bool compact_operands_;
  bool include_layout_in_shapes_;
  bool print_result_shape_;
  bool print_operand_shape_;
  bool print_operand_names_;
  bool print_program_shape_;
  bool print_percent_;
  bool print_control_dependencies_;
  bool canonicalize_instruction_names_;
  int indent_amount_;
  bool is_in_nested_computation_;
  bool print_ids_;
  bool canonicalize_computations_;
  bool print_extra_attributes_;
  int leading_and_trailing_instructions_number_ = 3;
  FormatInstructionFunc format_instruction_ = [](const HloInstruction* instr,
                                                 const string& instr_name,
                                                 int indent, bool is_root) {
    return absl::StrCat(string(2 * indent, ' '), is_root ? "ROOT " : "",
                        instr_name);
  };
  HloInstructionPredicate print_instruction_ =
      [](const HloInstruction* instr) { return true; };
  HloComputationPredicate print_computation_ =
      [](const HloComputation* comp) { return true; };
};

// For canonical string output, we need to have a canonical way to rename
// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
// where <xxx> is an index starting from 0.
class CanonicalNameMap {
 public:
  CanonicalNameMap() : index(0) {}

  string LookupOrInsert(const string& old_name) {
    auto iter = canonical_name_map.find(old_name);
    if (iter != canonical_name_map.end()) {
      return iter->second;
    }

    string new_name = absl::StrCat("tmp_", index++);
    canonical_name_map[old_name] = new_name;
    return new_name;
  }

  void Clear() {
    canonical_name_map.clear();
    index = 0;
  }

 private:
  int64 index;
  absl::flat_hash_map<string, string> canonical_name_map;
};

// HLO instructions are the atomic unit of the high-level compiler's IR.
//
// HloInstructions live inside of an HloComputation, which is analogous to a
// function in other programming languages. Nodes have no total order within
// their computation. Instead, they have a partial ordering determined by their
// data and control dependencies.
//
// HLO does not have basic blocks or explicit "branch" instructions. Instead,
// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
// control flow. For example, the kConditional HLO executes one of two possible
// computations, depending on the runtime value of a predicate.
//
// HLO is pure (mostly). It has no concept of mutable state. Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
 public:
  // A fusion node computes the same value a call to its fusion computation
  // would compute. However, the choice of fusion kind dictates codegen
  // strategy for the backend.
  //
  // To generate code for a kFusion HloInstruction, most backends do something
  // like the following:
  //
  // 1) Identify the "primary" HloInstruction of the fused computation.
  // 2) Emit code that does the work of the primary node, creating its inputs
  //    and transforming its outputs as specified by the fused computation.
  //
  // In step (2), the code emitted is usually similar to the code that would be
  // emitted for an *unfused* version of the primary node, except that
  //
  //  - when the primary node reads an element of one of its operands, instead
  //    of loading the value from memory, it *computes* the value based on the
  //    contents of the fused computation.
  //  - when the primary node outputs a value, instead of storing it to memory,
  //    it forwards the value to its users, which then perform additional
  //    computations before the value is finally stored to memory at the root
  //    of the fusion node.
  //
  // An HloInstruction's FusionKind helps us find the kFusion instruction's
  // primary node, and can also affect how we generate code in step (2).
  //
  //  - kInput: The primary node is the root of the fused instruction.
  //
  //  - kOutput: The primary node is not the root of the fused instruction.
  //    This fusion kind requires that one operand buffer of the fusion
  //    instruction be able to alias the output buffer. This constraint is
  //    usually enough to let backends find the primary node unambiguously.
  //
  //  - kLoop: The primary node is the root of the fused computation, but,
  //    unlike in input fusion, we prescribe a specific implementation for
  //    codegen. Rather than generating code that looks like the code we'd emit
  //    for an unfused version of the primary/root node, we emit code that
  //    generates one element of the root at a time.
  //
  //  - kCustom: Custom category for backend-specific fusions that don't fit
  //    into the above patterns.
  //
  // Not all backends support all fusion kinds, and given a particular fused
  // computation, it's not in general safe to change its fusion kind. Creation
  // of fusion nodes is always backend-specific.
  //
  // For elementwise ops (e.g. kAdd), most backends would emit a
  // one-element-at-a-time implementation for the unfused version, so loop
  // fusion and input fusion are probably equivalent if the root node is
  // elementwise. They're not necessarily equivalent e.g. for kReduce, where an
  // implementation might emit something more sophisticated for an unfused or
  // input-fusion reduce, but will emit the naive code that reduces one element
  // at a time for loop fusion with a reduce as the root.
  //
  // Another way to think of loop fusion is that it's equivalent to input
  // fusion, but where the root node is an implicit identity node, whose
  // unfused implementation is "read one element, write one element".
  //
  // TODO(b/79869434): This categorization scheme is not great. For one thing,
  // input and loop fusion are basically the same thing: There is no reason for
  // the HLO to encode backend-specific decisions about how e.g. a reduce
  // that's the root of a fusion should be lowered. In addition, this scheme as
  // written doesn't work for multi-output fusion, where the primary node is
  // never actually the root (which is a kTuple instruction that gathers the
  // multiple outputs of the fusion).
  enum class FusionKind {
    kLoop,
    kInput,
    kOutput,
    kCustom,
  };

  virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }

  // Detaches an instruction from its operands and users. That is, remove the
  // instruction from each operand's user set and user's operand set.
  void DetachFromOperandsAndUsers();

  // Creates an instruction from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction id to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(
      int64 parameter_number, const Shape& shape, const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

  // Creates an Iota instruction.
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
                                                    int64 iota_dimension);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the
  // computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  //
  // The parameters to the instruction are interpreted as follows:
  //
  //  - If `distribution` is RNG_UNIFORM, generates a number in range
  //    [param0, param1).
  //
  //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
  //    with mean `param0` and standard deviation `param1`.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      absl::Span<HloInstruction* const> parameters);

  // Creates a stateless random bit generator instruction that fills a shape
  // with random bits.
  static std::unique_ptr<HloInstruction> CreateRngBitGenerator(
      const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm);

  // Creates an instruction to update the random number generator state to
  // reflect the new state after `delta` units of 32 random bits are generated
  // and returns the old state.
  static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState(
      const Shape& shape, int64 delta);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);
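
  // Illustrative sketch (not additional API): the factories above compose to
  // build small graphs by hand, e.g. an f32 scalar addition, assuming a shape
  // `r0f32` such as ShapeUtil::MakeShape(F32, {}):
  //
  //   auto x = HloInstruction::CreateParameter(0, r0f32, "x");
  //   auto y = HloInstruction::CreateParameter(1, r0f32, "y");
  //   auto add = HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd,
  //                                           x.get(), y.get());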
  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index).
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* map_computation);

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64 feature_group_count, int64 batch_group_count, const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      absl::Span<const int64> fft_length);

  // Creates a copy-start op, indicating whether this is a cross-program
  // prefetch or not.
  static std::unique_ptr<HloInstruction> CreateCopyStart(
      const Shape& shape, HloInstruction* operand,
      bool is_cross_program_prefetch = false);

  // Creates a compare op, performing the comparison specified in direction.
  static std::unique_ptr<HloInstruction> CreateCompare(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      Comparison::Direction direction,
      absl::optional<Comparison::Type> type = absl::nullopt);

  static std::unique_ptr<HloInstruction> CreateTriangularSolve(
      const Shape& shape, HloInstruction* a, HloInstruction* b,
      const TriangularSolveOptions& options);

  static std::unique_ptr<HloInstruction> CreateCholesky(
      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);
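
  // Illustrative sketch for CreateDot above (not additional API): a plain
  // [m, k] x [k, n] matmul, assuming `lhs`, `rhs`, and the result shape
  // `dot_shape` already exist.
  //
  //   DotDimensionNumbers dnums;
  //   dnums.add_lhs_contracting_dimensions(1);
  //   dnums.add_rhs_contracting_dimensions(0);
  //   auto dot = HloInstruction::CreateDot(dot_shape, lhs, rhs, dnums,
  //                                        PrecisionConfig());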
  // Creates an all-gather op, which concats the operands of all participants
  // along all_gather_dimension. The replica_groups, channel_id, and
  // use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants.
  static std::unique_ptr<HloInstruction> CreateAllGather(
      const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Creates a cross-replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, given 4 replicas, replica_groups={{0,2},{1,3}} means that
  // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
  // subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // An all-to-all op takes N array operands of the same shape and scatters
  // them to N replicas. Each replica gathers the results into a tuple.
  //
  // For example, suppose we have 3 replicas, with replica i passing inputs
  // [a_i, b_i, c_i] to its all-to-all op. Then the resulting tuples are
  //
  //   replica 0: (a_0, a_1, a_2)
  //   replica 1: (b_0, b_1, b_2)
  //   replica 2: (c_0, c_1, c_2).
  //
  // If replica_groups is set, the op is sharded and the replicas are permuted.
  // To explain by way of example, suppose we have replica_groups={{1,2},{3,0}}.
  // Then each replica passes two operands, say [a_i, b_i], and the result is
  //
  //   replica 0: (b_3, b_0)
  //   replica 1: (a_1, a_2)
  //   replica 2: (b_1, b_2)
  //   replica 3: (a_3, a_0).
  //
  // All replica groups must have the same number of elements, and the number
  // of operands must be equal to the size of one replica group. Each replica
  // must appear in exactly one group.
  //
  // Note that this instruction is different from the all-to-all op in
  // xla_builder.h. The version in XlaBuilder takes one input and slices it,
  // and then concatenates the results into a single array. This instruction
  // takes multiple inputs and returns a tuple; it doesn't slice or
  // concatenate. It is used to implement the higher-level instruction in
  // XlaBuilder.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id,
      const absl::optional<int64>& split_dimension = absl::nullopt);
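
  // Illustrative sketch (not additional API): building the `replica_groups`
  // argument used by the collective factories above, here two subgroups
  // {0,2} and {1,3}; `shape`, `operand`, and the `add` reduction computation
  // are assumed to exist elsewhere.
  //
  //   std::vector<ReplicaGroup> groups(2);
  //   groups[0].add_replica_ids(0);
  //   groups[0].add_replica_ids(2);
  //   groups[1].add_replica_ids(1);
  //   groups[1].add_replica_ids(3);
  //   auto crs = HloInstruction::CreateAllReduce(
  //       shape, {operand}, add, groups, /*constrain_layout=*/false,
  //       /*channel_id=*/absl::nullopt, /*use_global_device_ids=*/false);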
  // Creates a communication instruction that permutes data cross replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not
  // a target_replica_id in any pair, the output on that replica is a tensor
  // consisting of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs,
      const absl::optional<int64>& channel_id);

  // Creates a communication instruction that initiates the start of
  // CollectivePermute.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs,
      const absl::optional<int64>& channel_id);

  // Creates an instruction that returns a U32 replica ID.
  static std::unique_ptr<HloInstruction> CreateReplicaId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates an instruction that returns a U32 partition ID.
  static std::unique_ptr<HloInstruction> CreatePartitionId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates a conversion instruction, where operand is the data to convert
  // and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates a bitcast instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcast(
      const Shape& shape, HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from
  // the Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed, *not* the shape of the infeed instruction,
  // which is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed, *not* the shape of the outfeed
  // instruction, which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a
  // unique send instruction in another computation that has the same
  // channel id.
  // If is_host_transfer is true, then this Recv operation transfers data from
  // the host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> start_indices,
      absl::Span<const int64> limit_indices, absl::Span<const int64> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' at the given 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. For example, let f
  // be the function to apply, which takes 2 arguments, an accumulator and the
  // current value. Let init be an initial value (which is normally chosen to
  // be the identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  //   f(...f(f(init, value0), value1), ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  //   f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
  //   f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  //           ..., inputN.value1)
  //   ...
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);
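
  // Illustrative sketch for CreateReduce above (not additional API): summing
  // an f32[8,16] array over dimension 1, assuming `input`, a scalar zero
  // constant `zero`, and an `add` HloComputation* already exist.
  //
  //   auto sum = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {8}), input, zero,
  //       /*dimensions_to_reduce=*/{1}, add);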
  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The reduce_computation being applied now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The operands and init_values now each
  // contain a span of N input arrays and N initial values.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values, const Window& window,
      HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a select-and-scatter instruction, which scatters the `source`
  // array to the selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);
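
  // Illustrative sketch for CreateBroadcast above (not additional API):
  // broadcasting an f32[16] vector `vec` (assumed to exist) into an f32[8,16]
  // result, mapping the operand's dimension 0 to output dimension 1.
  //
  //   auto bcast = HloInstruction::CreateBroadcast(
  //       ShapeUtil::MakeShape(F32, {8, 16}), vec,
  //       /*broadcast_dimensions=*/{1});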
  // Creates a reshape instruction, where the operand is flattened in
  // row-major order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand,
      int64 inferred_dimension = -1);

  // Creates a dynamic reshape instruction. Similar to reshape but dynamic
  // dimension sizes are provided as additional variadic arguments.
  //
  // Precondition: dim_sizes.size() == shape.rank()
  static std::unique_ptr<HloInstruction> CreateDynamicReshape(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an n-ary sort op with a 'compare' computation which is used for
  // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
  // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
  // specific index positions which should be compared, and should return a
  // PRED. 'is_stable' specifies whether stable sorting is required.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64 dimension,
      absl::Span<HloInstruction* const> operands, HloComputation* compare,
      bool is_stable);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  //   int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg,
      HloComputation* false_computation);

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* branch_index,
      absl::Span<HloComputation* const> branch_computations,
      absl::Span<HloInstruction* const> branch_computation_args);

  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the
  // given operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call
  // target to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, string opaque = "");

  // Overload with a to_apply computation.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* to_apply, absl::string_view custom_call_target,
      string opaque = "");

  // Overload with multiple computations. The called computations can have
  // different function signatures.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloComputation* const> called_computations,
      absl::string_view custom_call_target, string opaque = "");

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout, string opaque = "");

  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an AfterAll instruction used for joining or creating new values
  // of token type which thread through side-effecting operations. Operands
  // must all be tokens, and there must be at least one operand.
  static std::unique_ptr<HloInstruction> CreateAfterAll(
      absl::Span<HloInstruction* const> operands);
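
  // Illustrative sketch for the custom-call factories above (not additional
  // API): a two-operand custom call; `shape`, `a`, and `b` are assumed to
  // exist, and "my_target" plus the opaque string are hypothetical.
  //
  //   auto call = HloInstruction::CreateCustomCall(
  //       shape, {a, b}, "my_target", /*opaque=*/"backend-specific config");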
  // Creates an AfterAll instruction which creates a token type out of thin air
  // (no operands). This is a separate method from CreateAfterAll to facilitate
  // the removal of operand-less AfterAll instructions.
  // TODO(b/110532604): Remove this capability of creating a token from nothing
  // when we plumb a primordial token from the entry computation.
  static std::unique_ptr<HloInstruction> CreateToken();

  static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
      const Shape& shape, HloInstruction* operand, int64 dimension);

  static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
      const Shape& shape, HloInstruction* operand, HloInstruction* val,
      int64 dimension);

  static std::unique_ptr<HloInstruction> CreateAddDependency(
      HloInstruction* data_operand, HloInstruction* token_operand);

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }

  // Returns true if this instruction has a side effect, irrespective of
  // whether any called computations may contain an instruction with side
  // effects.
  bool HasSideEffectNoRecurse() const;

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64 i) const;

  // Returns the ith operand to this instruction.
  HloInstruction* mutable_operand(int64 i);

  // Returns the number of operands to this instruction.
  int64 operand_count() const { return operands_.size(); }

  // Returns the vector of operands of this instruction.
  using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }

  // Returns the vector of unique operands, in the same order they are found
  // within the operand vector.
  InstructionVector unique_operands() const;

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64 operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64 user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns the index of the user in the users() vector.
  //
  // Precondition: `user` is a user of the instruction.
  int64 UserId(HloInstruction* user);

  // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_map_, this);
  }

  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if the two instructions do not
  // belong to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Drops all control predecessors and successors from this HLO instruction.
  Status DropAllControlDeps();

  // Copies the control predecessors and successors on this HLO instruction to
  // `inst`. Does not do a deep copy so this makes sense only if `inst` and
  // this HLO are in the same module.
  //
  // Depending on the use cases we see in practice, in the future we may
  // consider folding the logic here into Clone, CloneWithNewOperands and
  // ReplaceAllUsesWith by treating control dependencies like data
  // dependencies.
  Status CopyAllControlDepsFrom(const HloInstruction* inst);

  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before (after)
  // the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }

  // Returns true if "other" performs the same computation as this instruction.
  bool Identical(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/false);
  }

  // Same as Identical() but ignores channel ID value mismatches, as long as
  // both have channel IDs or neither has a channel ID.
  bool IdenticalIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    return IdenticalInternal(other, eq_operands, eq_computations,
                             layout_sensitive,
                             /*ignore_channel_id_values=*/true);
  }

  // Generates a hash value of an HLO instruction. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO instructions,
  // with respect to HloInstruction::Identical() method.
  //
  // Uses hash_operand function to compute hash values of its operands.
  // At the very top level, hash_operand should be non-recursive to prevent
  // non-termination.
  uint64 Hash(
      const std::function<uint64(const HloInstruction*)>& hash_operand) const;

  // Calls the above method with non-recursive hash_operand function.
  uint64 Hash() const;

  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  //
  // If user is a fusion instruction, this function will remove any duplicated
  // operands of it which could be created due to this replacement.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Same as ReplaceUseWith(), but new_producer can have a different shape.
  Status ReplaceUseWithDifferentShape(HloInstruction* user,
                                      HloInstruction* new_producer);

  // Replaces the specified operand with new_operand. The old and new operands
  // must have compatible shapes ignoring floating-point precision.
  //
  // This function does NOT remove duplicated operands even if this instruction
  // is a fusion, so that the existing operand numbers do not change.
  Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand);

  // Same as ReplaceOperandWith(), but new_operand can have a different shape.
  Status ReplaceOperandWithDifferentShape(int64 operand_num,
                                          HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a
  // use of this instruction to avoid introducing cycles into the graph.
  //
  // If this instruction is the root of its computation, sets the computation's
  // root to new_producer.
  //
  // The new producer must have a compatible shape ignoring floating-point
  // precision.
  //
  // If a user is a fusion instruction, this function will remove any
  // duplicated operands of it which could be created due to this replacement.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);

  // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
  Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);

  // Same as ReplaceAllUsesWith, but only replaces the given set of users.
  Status ReplaceUsesWith(absl::Span<HloInstruction* const> users,
                         HloInstruction* new_producer);
  Status ReplaceAllUsesWithDifferentShape(
      absl::Span<HloInstruction* const> users, HloInstruction* new_producer);

  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // the visit is complete. If ignore_control_predecessors is true,
  // instructions only reachable via control dependencies will not be visited,
  // and the postorder will not take control dependencies into account. It is
  // as if the control dependencies didn't exist in the graph at all.
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                bool call_finish_visit = true,
                bool ignore_control_predecessors = false);
  Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
                bool ignore_control_predecessors = false) const {
    return const_cast<HloInstruction*>(this)->Accept(
        visitor, call_finish_visit, ignore_control_predecessors);
  }
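
  // Illustrative sketch for Accept above (not additional API): running a
  // visitor over the graph rooted at `root`; `MyVisitor` stands in for a
  // hypothetical DfsHloVisitor subclass defined elsewhere.
  //
  //   MyVisitor visitor;
  //   Status status = root->Accept(&visitor);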
1245 using CompareFunction = 1246 std::function<bool(const HloInstruction*, const HloInstruction*)>; 1247 Status AcceptWithOperandOrder(DfsHloVisitor* visitor, 1248 const CompareFunction& operand_order, 1249 bool call_finish_visit = true); 1250 1251 // Visit this instruction and only this instruction with the given visitor. 1252 template <typename HloInstructionPtr> 1253 Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor); 1254 1255 // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. 1256 // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the 1257 // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. 1258 std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() 1259 const; 1260 LatestNonGteAncestorAndIndex()1261 std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() { 1262 auto rv = 1263 const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex(); 1264 return {const_cast<HloInstruction*>(rv.first), rv.second}; 1265 } 1266 1267 // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. 1268 const HloInstruction* LatestNonGteAncestor() const; 1269 LatestNonGteAncestor()1270 HloInstruction* LatestNonGteAncestor() { 1271 return const_cast<HloInstruction*>( 1272 const_cast<const HloInstruction*>(this)->LatestNonGteAncestor()); 1273 } 1274 1275 // Returns true whether this instruction is effectively a bitcast. Currently, 1276 // this means it either is a bitcast, or it is a transpose that is effectively 1277 // a bitcast. 1278 bool IsEffectiveBitcast() const; 1279 1280 // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. 1281 // The setter should only be called by HloModule or HloComputation methods. 1282 // 1283 // Precondition: The instruction has a valid to_apply_ field. 1284 HloComputation* to_apply() const; 1285 void set_to_apply(HloComputation* to_apply); 1286 1287 // Gets/sets the while_condition or while_body HloComputation for While. The 1288 // setters should only be called by HloModule or HloComputation methods. 1289 // 1290 // Precondition: The instruction is a While instruction. 1291 HloComputation* while_condition() const; 1292 HloComputation* while_body() const; 1293 void set_while_condition(HloComputation* while_condition); 1294 void set_while_body(HloComputation* while_body); 1295 1296 HloInstruction* while_init() const; 1297 1298 // Gets/sets the true and false HloComputation for Conditional. 1299 // 1300 // Precondition: The instruction is a predicated Conditional instruction. 1301 HloComputation* true_computation() const; 1302 HloComputation* false_computation() const; 1303 1304 // Gets the branch HloComputations for Conditional. 1305 // 1306 // Precondition: The instruction is a Conditional instruction. 1307 const std::vector<HloComputation*>& branch_computations() const; 1308 int branch_count() const; 1309 HloComputation* branch_computation(int b) const; 1310 // Sets a branch HloComputation for Conditional. 1311 // The setter should only be called by HloModule or HloComputation methods. 1312 // 1313 // Precondition: The instruction is a Conditional instruction. 1314 void set_branch_computation(int b, HloComputation* computation); 1315 1316 // Returns a string for the signature of this instruction if considered as a 1317 // function, e.g. the signature of an F32 add is (F32, F32) -> F32. 1318 string SignatureString() const; 1319 1320 // Returns a debugging string that represents this instruction. 
1321 // 1322 // (We express the default options using an overload rather than a default 1323 // param because gdb ignores default params, but does resolve overloads.) 1324 // 1325 // TODO(b/73348663): Make ToString() adaptive to the size of the string by 1326 // default, backing off on providing full information for very large strings, 1327 // or provide a different name for a ToString-like function that does that. ToString()1328 string ToString() const { return ToString(HloPrintOptions()); } 1329 string ToString(const HloPrintOptions& options) const; 1330 1331 // Components of the ToString() representation: 1332 1333 // Returns a string representation of the operand list. 1334 string OperandsToString(const HloPrintOptions& options) const; 1335 1336 // Returns string representation of op-specific attributes. 1337 std::vector<string> ExtraAttributesToString( 1338 const HloPrintOptions& options) const; 1339 1340 // As ToString, but returns a shorter string. 1341 string ToShortString() const; 1342 1343 // Prints an instruction to a string. 1344 // 1345 // The canonical string representation needs to name operands and instruction 1346 // names in a consistent way. This is implemented through the 1347 // canonical_name_map. 1348 string ToStringWithCanonicalNameMap( 1349 const HloPrintOptions& options, 1350 CanonicalNameMap* canonical_name_map) const; 1351 1352 // Returns a serialized representation of this instruction. 1353 virtual HloInstructionProto ToProto() const; 1354 1355 // Returns a category for the HLO. This could be something like "convolution" 1356 // or "elementwise". 1357 virtual string ToCategory() const; 1358 1359 // Returns a logging instruction, if the output of this instruction is logged. 1360 // 1361 // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace 1362 HloInstruction* tracing() const; 1363 void set_tracing(HloInstruction* trace_instruction); 1364 1365 // Returns true if this instruction is fused, ie contained within a fusion 1366 // instruction. 1367 bool IsFused() const; 1368 1369 bool IsLoopFusion() const; 1370 bool IsInputFusion() const; 1371 bool IsOutputFusion() const; 1372 bool IsCustomFusion() const; 1373 1374 // Returns true if this instruction can be legally fused into a fusion 1375 // instruction. 1376 bool IsFusible() const; 1377 1378 bool IsCustomCall(absl::string_view target) const; 1379 1380 // Returns the sharding applied to this operator. 1381 // REQUIRES: has_sharding() is true. sharding()1382 const HloSharding& sharding() const { 1383 CHECK(has_sharding()); 1384 return *sharding_; 1385 } sharding_ptr()1386 std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; } 1387 1388 // Returns the sharding applied to this operator, or default_ if none exists. sharding_or_default(const HloSharding & default_)1389 const HloSharding& sharding_or_default(const HloSharding& default_) const { 1390 return sharding_ ? *sharding_ : default_; 1391 } 1392 // Returns the sharding unique device, if any. sharding_unique_device()1393 absl::optional<int64> sharding_unique_device() const { 1394 if (sharding_ == nullptr) { 1395 return absl::optional<int64>(); 1396 } 1397 return sharding_->UniqueDevice(); 1398 } 1399 // Sets the sharding of this operator. Should only be called by HloModule or 1400 // HloComputation methods. 
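// For example, a sketch of assigning a sharding and reading it back; in
// practice this is driven by HloModule/HloComputation, and
// HloSharding::Replicate() is assumed to be available from hlo_sharding.h:
//
//   instr->set_sharding(HloSharding::AssignDevice(0));
//   CHECK(instr->has_sharding());
//   const HloSharding& s =
//       instr->sharding_or_default(HloSharding::Replicate());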
set_sharding(const HloSharding & sharding)1401 void set_sharding(const HloSharding& sharding) { 1402 sharding_ = std::make_shared<const HloSharding>(sharding); 1403 } set_sharding(std::shared_ptr<const HloSharding> sharding)1404 void set_sharding(std::shared_ptr<const HloSharding> sharding) { 1405 sharding_ = std::move(sharding); 1406 } 1407 void set_single_sharding(const HloSharding& sharding); 1408 // Sets a sharding that assigns the current instruction to the given device. set_device_sharding(int64 device)1409 void set_device_sharding(int64 device) { 1410 set_single_sharding(HloSharding::AssignDevice(device)); 1411 } 1412 // Remove any sharding from this operator. clear_sharding()1413 void clear_sharding() { sharding_ = nullptr; } 1414 // Return true if this operator has a sharding assigned. has_sharding()1415 bool has_sharding() const { return sharding_ != nullptr; } 1416 // Checks whether the instruction has compatible sharding with the other 1417 // instruction. has_compatible_sharding(const HloInstruction * other)1418 bool has_compatible_sharding(const HloInstruction* other) const { 1419 if (!has_sharding()) { 1420 return !other->has_sharding(); 1421 } 1422 return other->has_sharding() ? sharding() == other->sharding() : false; 1423 } 1424 1425 // When creating a new instruction which either replaces, or shifts up (kCopy 1426 // insertion case), another instruction, we need to make sure that certain 1427 // properties of the new instruction are copied into the derived one. As of 1428 // today, the metadata and sharding will be propagated to the derived 1429 // instruction. 1430 void SetupDerivedInstruction(HloInstruction* derived_instruction) const; 1431 1432 // Clones the HLO instruction. The clone will have the same opcode, shape, and 1433 // operands. After creation the clone has no uses. "this" (the instruction 1434 // cloned from) is not changed. Suffix is the string to append to the name of 1435 // the instruction to form the name of the cloned instruction. 1436 // Ignores the control predecessors and successors of this HLO instruction. 1437 std::unique_ptr<HloInstruction> Clone( 1438 const string& suffix = "clone", HloCloneContext* context = nullptr) const; 1439 1440 // Clones the HLO instruction as above but with new shape. 1441 std::unique_ptr<HloInstruction> CloneWithNewShape( 1442 const Shape& shape, const string& suffix = "clone", 1443 HloCloneContext* context = nullptr) const; 1444 1445 // Clones the HLO instruction as above but with new shape and operands. 1446 std::unique_ptr<HloInstruction> CloneWithNewOperands( 1447 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1448 HloCloneContext* context = nullptr) const; 1449 1450 // Returns the computations this instruction directly calls (if any). called_computations()1451 const std::vector<HloComputation*>& called_computations() const { 1452 return called_computations_; 1453 } 1454 1455 // Replaces all called computations based on a map function. This is needed 1456 // when we clone hlo_computations and want the instructions to point 1457 // to the newly cloned nodes. ReplaceCalledComputations(std::function<HloComputation * (HloComputation *)> map_function)1458 void ReplaceCalledComputations( 1459 std::function<HloComputation*(HloComputation*)> map_function) { 1460 for (int64 i = 0; i < called_computations_.size(); ++i) { 1461 called_computations_[i] = map_function(called_computations_[i]); 1462 } 1463 } 1464 1465 // Clears out the called computations.
1466 // 1467 // This is, in particular, necessary when inlining function bodies into their 1468 // caller. If there were side-effecting operations in the called computations, 1469 // the call itself is considered side-effecting and thus cannot be removed. By 1470 // clearing out the computations, we reflect the fact that all side-effecting 1471 // properties have been reflected in the caller, and make the call HLO 1472 // removable. ClearCalledComputations()1473 void ClearCalledComputations() { called_computations_.clear(); } 1474 1475 // Returns true if this instruction performs an elementwise operation on 1476 // `operand_idx`-th operand. An instruction is elementwise on an operand iff, 1477 // to compute the output at index {i_0,i_1,...,i_n}, the only element required 1478 // from the operand (if any) is the element at {i_0,i_1,...,i_n}. 1479 // 1480 // Note on performance: when this instruction is kFusion, this method, in the 1481 // worst case, scans all fused instructions. We could speed this up by 1482 // caching. 1483 bool IsElementwiseOnOperand(int64 operand_idx) const; 1484 1485 // Returns true if this instruction is elementwise on all its operands. 1486 bool IsElementwise() const; 1487 1488 static bool IsOpElementwise(HloOpcode opcode); 1489 1490 // Returns true if this is a cross module all-reduce instruction. 1491 bool IsCrossModuleAllReduce() const; 1492 1493 // Returns true if this is a cross-replica all-reduce instruction. 1494 bool IsCrossReplicaAllReduce() const; 1495 1496 // Returns true if this instruction is binary and elementwise. 1497 bool IsElementwiseBinary() const; 1498 1499 // Returns whether this instruction may reuse elements of its `i`th operand. ReusesOperandElements(int64 i)1500 bool ReusesOperandElements(int64 i) const { 1501 return OperandElementUse(i) == UseKind::kReuse; 1502 } 1503 1504 // Returns the indices that the given operand appear in the operand list of 1505 // this instruction. Note that an instruction can use the same operand 1506 // multiple times. 1507 absl::InlinedVector<int64, 4> OperandIndices( 1508 const HloInstruction* operand) const; 1509 1510 // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If 1511 // this reshape merely inserts or deletes 1-sized dimensions, return the input 1512 // indices of the deleted dimensions and the output indices of the inserted 1513 // dimensions. 1514 // 1515 // Precondition: this op must be a reshape. 1516 std::tuple<bool, std::vector<int64>, std::vector<int64>> 1517 ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; 1518 1519 // Gets the string identifier for this instruction. name()1520 const string& name() const { return name_; } 1521 1522 // Sets the string identifier for this instruction. Name will be sanitized to 1523 // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". SetAndSanitizeName(const string & name)1524 void SetAndSanitizeName(const string& name) { 1525 name_ = NameUniquer::GetSanitizedName(name); 1526 } 1527 1528 // Use the given NameUniquer to select a unique name for the instruction based 1529 // on the instruction's existing name. 1530 void UniquifyName(NameUniquer* name_uniquer); 1531 1532 // Clear the unique ID of the instruction so that it can be re-assigned, such 1533 // as for the purpose of compacting the instruction unique IDs. 
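// For example, a sketch of the intended re-numbering sequence (normally
// driven by HloModule when compacting ids; shown for illustration only):
//
//   instr->ClearUniqueIdInternal();  // unique_id() becomes -1.
//   instr->SetUniqueId(7);           // CHECK-fails if an id was already set.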
ClearUniqueIdInternal()1534 void ClearUniqueIdInternal() { unique_id_ = -1; } 1535 1536 // Set the unique id for this instruction to "id" SetUniqueId(int id)1537 void SetUniqueId(int id) { 1538 CHECK_EQ(unique_id_, -1); // Should not be assigned already 1539 CHECK_GE(id, 0); 1540 unique_id_ = id; 1541 } 1542 1543 // Return the unique ID assigned to this node via SetUniqueId (or -1 1544 // if no id has been assigned yet). unique_id()1545 int unique_id() const { return unique_id_; } 1546 1547 // Returns the backend-specific configuration for how a backend should compile 1548 // this HLO. The meaning of the field is backend specific. Not for use before 1549 // or during general HLO optimization, since HLO optimizations do not preserve 1550 // this field and they cannot interpret it due to its meaning being backend 1551 // specific. Except for CustomCall, where this field is preserved and no 1552 // general HLO optimization needs to interpret it. 1553 // 1554 // ConfigProto should be a protobuf Message type. 1555 template <typename ConfigProto> backend_config()1556 StatusOr<ConfigProto> backend_config() const { 1557 ConfigProto proto; 1558 TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); 1559 return std::move(proto); 1560 } 1561 Status set_backend_config(const tensorflow::protobuf::Message& proto); 1562 set_frontend_attributes(FrontendAttributes frontend_attributes)1563 void set_frontend_attributes(FrontendAttributes frontend_attributes) { 1564 frontend_attributes_ = std::move(frontend_attributes); 1565 } 1566 frontend_attributes()1567 const FrontendAttributes& frontend_attributes() const { 1568 return frontend_attributes_; 1569 } 1570 1571 // Getter/setter for raw JSON-encoded backend config. Prefer the 1572 // functions above that deal in proto Messages where possible. raw_backend_config_string()1573 const string& raw_backend_config_string() const { return backend_config_; } set_raw_backend_config_string(string config_str)1574 void set_raw_backend_config_string(string config_str) { 1575 backend_config_ = std::move(config_str); 1576 } 1577 is_default_config()1578 bool is_default_config() const { return is_default_config_; } set_default_config()1579 void set_default_config() { is_default_config_ = true; } 1580 1581 // Returns a string representation of a proto in the format used by 1582 // raw_backend_config_string. 1583 // 1584 // This is morally equivalent to: 1585 // 1586 // HloInstruction instr; 1587 // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); 1588 // return instr.raw_backend_config_string(); 1589 // 1590 static StatusOr<string> BackendConfigToRawString( 1591 const tensorflow::protobuf::Message& proto); 1592 1593 // Returns the information used to tell the implementation information about 1594 // what sort of precision is requested. The meaning of the field is backend 1595 // specific. At the moment, it is only supported for kConvolution and kDot. 1596 // Transformations on one kDot or kConvolution to another will preserve this 1597 // information. Transformations to other HLOs will not preserve this 1598 // information but it is presumed that the alternate lowering is strictly 1599 // superior. 1600 // Precondition: opcode must be kConvolution or kDot. 1601 const PrecisionConfig& precision_config() const; 1602 PrecisionConfig* mutable_precision_config(); 1603 1604 // Sets the debug metadata for this instruction, excluding creation_pass_id, 1605 // which should never be copied anywhere. 
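// For example, a sketch of propagating debug metadata from one instruction to
// another; the op name "example_op" is purely illustrative:
//
//   OpMetadata md = source->metadata();
//   md.set_op_name("example_op");
//   target->set_metadata(md);  // target's creation_pass_id is preserved.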
set_metadata(const OpMetadata & metadata)1606 void set_metadata(const OpMetadata& metadata) { 1607 int64 creation_pass_id = metadata_.creation_pass_id(); 1608 metadata_ = metadata; 1609 metadata_.set_creation_pass_id(creation_pass_id); 1610 } set_creation_pass_id(int64 pass_id)1611 void set_creation_pass_id(int64 pass_id) { 1612 metadata_.set_creation_pass_id(pass_id); 1613 } set_metadata_op_name(const std::string & name)1614 void set_metadata_op_name(const std::string& name) { 1615 metadata_.set_op_name(name); 1616 } set_logical_creation_pass_id(int64 pass_id)1617 void set_logical_creation_pass_id(int64 pass_id) { 1618 metadata_.set_logical_creation_pass_id(pass_id); 1619 } metadata()1620 const OpMetadata& metadata() const { return metadata_; } 1621 1622 // Set/get the computation containing this instruction. set_parent should only 1623 // be called by HloComputation methods which add/remove instructions to 1624 // computations. set_parent(HloComputation * computation)1625 void set_parent(HloComputation* computation) { parent_ = computation; } parent()1626 const HloComputation* parent() const { return parent_; } parent()1627 HloComputation* parent() { return parent_; } 1628 1629 // Returns the module for this instruction. 1630 HloModule* GetModule() const; 1631 1632 // Get/Set the number of partitions per outer dimension (in order, starting 1633 // with outer-most dimension first). Currently used by the parallel cpu 1634 // backend to partition HLOs into parallel tasks. 1635 // 1636 // TODO(b/62783254) Replace these methods with a more general way to 1637 // annotate HLOs with backend-specific information. outer_dimension_partitions()1638 const std::vector<int64>& outer_dimension_partitions() const { 1639 return outer_dimension_partitions_; 1640 } 1641 void set_outer_dimension_partitions( 1642 const std::vector<int64>& outer_dimension_partitions); 1643 1644 // Old methods kept for smooth subclassing transition BEGIN. 1645 // TODO(b/80131774): Remove this code. 1646 1647 // Delegates to HloBatchNormInstruction::feature_index. 1648 int64 feature_index() const; 1649 1650 // Delegates to HloBatchNormInstruction::epsilon. 1651 float epsilon() const; 1652 1653 // Delegates to HloFftInstruction::fft_type. 1654 FftType fft_type() const; 1655 1656 // Delegates to HloFftInstruction::fft_length. 1657 const std::vector<int64>& fft_length() const; 1658 1659 // Delegates to HloChannelInstruction::channel_id. 1660 absl::optional<int64> channel_id() const; 1661 void set_channel_id(const absl::optional<int64>& channel_id); 1662 1663 // Returns the dimension sizes or numbers associated with this instruction. dimensions()1664 virtual const std::vector<int64>& dimensions() const { 1665 LOG(FATAL) << "Unimplemented method."; 1666 } dimensions(int64 index)1667 virtual int64 dimensions(int64 index) const { 1668 LOG(FATAL) << "Unimplemented method."; 1669 } mutable_dimensions()1670 virtual std::vector<int64>* mutable_dimensions() { 1671 LOG(FATAL) << "Unimplemented method."; 1672 } 1673 1674 // Delegates to HloConcatenateInstruction::concatenate_dimension. 1675 int64 concatenate_dimension() const; 1676 1677 // Delegates to HloGetDimensionSizeInstruction::dimension. 1678 int64 dimension() const; 1679 1680 // Delegates to HloReshapeInstruction::inferred_dimension. 1681 int64 inferred_dimension() const; 1682 1683 // Returns whether this instruction does a rank-2 transposition. 1684 bool IsRank2Transpose() const; 1685 1686 // Delegates to HloSliceInstruction::slice_start. 
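// For example, a sketch of reading the bounds of a static slice; `slice` is
// assumed to be a kSlice instruction, so the delegating accessors below are
// valid:
//
//   for (int64 d = 0; d < slice->shape().rank(); ++d) {
//     const int64 start = slice->slice_starts(d);
//     const int64 limit = slice->slice_limits(d);
//     const int64 stride = slice->slice_strides(d);
//   }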
1687 int64 slice_starts(int64 dimension) const; 1688 const std::vector<int64>& slice_starts() const; 1689 std::vector<int64>* mutable_slice_starts(); 1690 1691 // Delegates to HloSliceInstruction::slice_limits. 1692 int64 slice_limits(int64 dimension) const; 1693 const std::vector<int64>& slice_limits() const; 1694 std::vector<int64>* mutable_slice_limits(); 1695 1696 // Delegates to HloSliceInstruction::slice_strides. 1697 int64 slice_strides(int64 dimension) const; 1698 const std::vector<int64>& slice_strides() const; 1699 std::vector<int64>* mutable_slice_strides(); 1700 1701 // Returns the literal associated with this instruction. 1702 const Literal& literal() const; 1703 1704 // Returns whether the instruction is a constant. 1705 bool IsConstant() const; 1706 1707 // Delegate to HloConstantInstruction::RelayoutConstant. 1708 void RelayoutConstant(const Layout& new_layout, 1709 const ShapeIndex& shape_index = {}); 1710 1711 // Delegates to HloTraceInstruction::TracingTag. 1712 string TracingTag() const; 1713 1714 // Delegates to HloFusionInstruction::AddFusionOperand. 1715 HloInstruction* AddFusionOperand(HloInstruction* new_operand); 1716 1717 // Delegates to HloFusionInstruction::MergeFusionInstruction. 1718 void MergeFusionInstruction(HloInstruction* instruction_to_merge); 1719 1720 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. 1721 void MergeFusionInstructionIntoMultiOutput( 1722 HloInstruction* instruction_to_merge); 1723 1724 // Delegates to HloFusionInstruction::FuseInstruction. 1725 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); 1726 1727 // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput. 1728 HloInstruction* FuseInstructionIntoMultiOutput( 1729 HloInstruction* instruction_to_fuse); 1730 1731 // Delegates to HloFusionInstruction::fused_instruction. 1732 HloComputation* fused_instructions_computation() const; 1733 1734 // Delegates to HloFusionInstruction::fused_expression_root. 1735 HloInstruction* fused_expression_root() const; 1736 1737 // Delegates to HloFusionInstruction::fused_instructions. 1738 const tensorflow::gtl::iterator_range<UnwrappingIterator< 1739 std::list<std::unique_ptr<HloInstruction>>::const_iterator>> 1740 fused_instructions() const; 1741 1742 const tensorflow::gtl::iterator_range< 1743 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> 1744 fused_instructions(); 1745 1746 // Delegates to HloFusionInstruction::fused_instruction_count. 1747 int64 fused_instruction_count() const; 1748 1749 // Delegates to HloFusionInstruction::fused_parameter. 1750 HloInstruction* fused_parameter(int64 parameter_number) const; 1751 1752 // Delegates to HloFusionInstruction::fused_parameters. 1753 const std::vector<HloInstruction*>& fused_parameters() const; 1754 1755 // Returns true if this instruction is a fusion instruction that generates 1756 // multiple outputs. 1757 const bool IsMultiOutputFusion() const; 1758 1759 // Delegates to HloFusionInstruction::fusion_kind. 1760 FusionKind fusion_kind() const; 1761 1762 // Delegates to HloFusionInstruction::set_fusion_kind. 1763 void set_fusion_kind(FusionKind kind); 1764 1765 // Delegates to HloRngInstruction::random_distribution. 1766 RandomDistribution random_distribution() const; 1767 1768 // Delegates to HloParameterInstruction::parameter_number. 1769 int64 parameter_number() const; 1770 1771 // Delegates to 1772 // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers. 
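// For example, a sketch marking both leaf buffers of a two-element tuple
// parameter as replicated; `param` is assumed to be a kParameter instruction:
//
//   param->set_parameter_replicated_at_leaf_buffers(
//       std::vector<bool>{true, true});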
1773 void set_parameter_replicated_at_leaf_buffers( 1774 absl::Span<const bool> parameter_replicated_at_leaf_buffers); 1775 void set_parameter_replicated_at_leaf_buffers( 1776 const std::vector<bool>& parameter_replicated_at_leaf_buffers); 1777 1778 // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers. 1779 const absl::optional<std::vector<bool>>& 1780 parameter_replicated_at_leaf_buffers() const; 1781 1782 // Delegates to HloGetTupleElementInstruction::tuple_index. 1783 int64 tuple_index() const; 1784 1785 // Delegates to HloGetTupleElementInstruction::set_tuple_index. 1786 void set_tuple_index(int64 new_tuple_index); 1787 1788 // Delegates to HloReducePrecisionInstruction::exponent_bits. 1789 int32 exponent_bits() const; 1790 1791 // Delegates to HloReducePrecisionInstruction::mantissa_bits. 1792 int32 mantissa_bits() const; 1793 1794 // Delegates to HloInfeedInstruction::infeed_config. 1795 string infeed_config() const; 1796 1797 // Delegates to HloInfeedInstruction::set_infeed_config. 1798 void set_infeed_config(const string& config); 1799 1800 // Returns the config for the Outfeed instruction. 1801 const string& outfeed_config() const; 1802 1803 // Delegates to HloOutfeedInstruction::set_outfeed_config. 1804 void set_outfeed_config(const string& config); 1805 1806 // Returns the shape for the Outfeed instruction. 1807 const Shape& outfeed_shape() const; 1808 1809 // Returns the mutable shape for the Outfeed instruction. 1810 Shape* mutable_outfeed_shape(); 1811 1812 // Delegates to HloCollectiveInstruction::replica_groups. 1813 const std::vector<ReplicaGroup>& replica_groups() const; 1814 1815 // Delegates to HloCollectivePermuteInstruction::source_target_pairs. 1816 const std::vector<std::pair<int64, int64>>& source_target_pairs() const; 1817 1818 // Returns data on the window in a windowed operation such as 1819 // convolution. window()1820 virtual const Window& window() const { 1821 LOG(FATAL) << "Unimplemented method."; 1822 } 1823 1824 // Sets the window data in a windowed operation such as convolution. set_window(const Window & window)1825 virtual void set_window(const Window& window) { 1826 LOG(FATAL) << "Unimplemented method."; 1827 } 1828 1829 // Returns the unique_indices field. unique_indices()1830 virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; } 1831 1832 // Returns data on the dimension numbers used for a convolution operation, 1833 // which may be a kConvolution instruction or a kCustomCall that implements a 1834 // convolution. 1835 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const; 1836 1837 // Sets the convolution dimension numbers on this instruction. In general you 1838 // shouldn't need to call this; instead, specify the convolution dimension 1839 // numbers when you create the instruction. 1840 void set_convolution_dimension_numbers( 1841 const ConvolutionDimensionNumbers& dnums); 1842 1843 // The number of feature groups. Must be a divisor of the input feature 1844 // dimension and output feature dimension. 1845 int64 feature_group_count() const; 1846 1847 void set_feature_group_count(int64 feature_group_count); 1848 1849 // The number of batch groups. Must be a divisor of the input batch dimension 1850 int64 batch_group_count() const; 1851 1852 void set_batch_group_count(int64 batch_group_count); 1853 1854 // Delegates to HloSelectAndScatterInstruction::select. 1855 HloComputation* select() const; 1856 1857 // Delegates to HloSelectAndScatterInstruction::scatter. 
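// For example, a sketch retrieving both computations of a select-and-scatter
// op; `sas` is assumed to be a kSelectAndScatter instruction:
//
//   HloComputation* select_fn = sas->select();
//   HloComputation* scatter_fn = sas->scatter();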
1858 HloComputation* scatter() const; 1859 1860 // Delegates to HloSelectAndScatterInstruction::set_select. 1861 void set_select(HloComputation* computation); 1862 1863 // Delegates to HloSelectAndScatterInstruction::set_scatter. 1864 void set_scatter(HloComputation* computation); 1865 1866 // Delegates to HloCustomCallInstruction::custom_call_target. 1867 const string& custom_call_target() const; 1868 1869 // Delegates to HloPadInstruction::padding_config. 1870 const PaddingConfig& padding_config() const; 1871 PaddingConfig* mutable_padding_config(); 1872 1873 // Delegates to HloConvolutionInstruction::padding_type. 1874 PaddingType padding_type() const; 1875 1876 // Delegates to HloDynamicSliceInstruction::slice_sizes. 1877 int64 slice_sizes(int64 dimension) const; 1878 1879 // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. 1880 const std::vector<int64>& dynamic_slice_sizes() const; 1881 1882 // Delegates to HloGatherInstruction::gather_dimension_numbers. 1883 const GatherDimensionNumbers& gather_dimension_numbers() const; 1884 // Delegates to HloGatherInstruction::gather_slice_sizes. 1885 absl::Span<const int64> gather_slice_sizes() const; 1886 1887 // Delegates to HloScatterInstruction::scatter_dimension_numbers(). 1888 const ScatterDimensionNumbers& scatter_dimension_numbers() const; 1889 1890 // Delegates to HloDotInstruction::dot_dimension_numbers(). 1891 const DotDimensionNumbers& dot_dimension_numbers() const; 1892 1893 // Delegates to HloDomainInstruction::operand_side_metadata(). 1894 const DomainMetadata& operand_side_metadata() const; 1895 1896 // Delegates to HloDomainInstruction::user_side_metadata(). 1897 const DomainMetadata& user_side_metadata() const; 1898 1899 // Delegates to HloCopyStartInstruction::is_cross_program_prefetch(). 1900 bool is_cross_program_prefetch() const; 1901 1902 // Delegates to HloCompareInstruction::direction(). 1903 ComparisonDirection comparison_direction() const; 1904 1905 // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). 1906 const TriangularSolveOptions& triangular_solve_options() const; 1907 1908 // Delegates to HloCholeskyInstruction::cholesky_options(). 1909 const CholeskyOptions& cholesky_options() const; 1910 1911 // Appends operand to the list of operands and adds this instruction as a user 1912 // of the operand. 1913 void AppendOperand(HloInstruction* operand); 1914 1915 // Old methods kept for smooth subclassing transition END. 1916 1917 protected: 1918 // Indicates how an instruction uses a value (such as an operand). 1919 // 1920 // Does it (a) not use it, (b) use it, or (c) use it multiple times? 1921 // 1922 // In the kUse case (i.e. (b)) we may either (i) use the value elementwise, or 1923 // (ii) use it after having permuted it somehow, e.g. through a reshape. If 1924 // the use is a permuting use, we set permutation_instr to the instruction 1925 // that did the permuting. 1926 struct UseKind { 1927 enum Kind { kReuse, kUse, kNoUse }; 1928 1929 // Creates a UseKind that represents a use that permutes an instruction's 1930 // elements according to the given instruction. 
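// For example, a transpose-like use of an operand might be recorded as
// (sketch only; `transpose_instr` is the permuting instruction):
//
//   UseKind use = UseKind::Permuting(transpose_instr);
//   // use.kind == kUse and use.permutation_instr == transpose_instr.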
PermutingUseKind1931 static UseKind Permuting(const HloInstruction* permutation_instr) { 1932 UseKind k(kUse); 1933 k.permutation_instr = permutation_instr; 1934 return k; 1935 } 1936 UseKindUseKind1937 UseKind(Kind kind) // NOLINT intentionally nonexplicit 1938 : kind(kind), permutation_instr(nullptr) {} 1939 1940 bool friend operator==(UseKind a, Kind b) { return a.kind == b; } 1941 bool friend operator==(Kind a, UseKind b) { return b == a; } 1942 1943 Kind kind; 1944 const HloInstruction* permutation_instr; 1945 }; 1946 1947 // Helper class for computing OperandElementUse for kFusion. 1948 class FusionReusesParamElements; 1949 1950 // Internal constructor for a given opcode/shape, other fields must be filled 1951 // by factory methods. 1952 HloInstruction(HloOpcode opcode, const Shape& shape); 1953 RemoveAllOperands()1954 void RemoveAllOperands() { operands_.clear(); } 1955 RemoveOperandAt(int index)1956 void RemoveOperandAt(int index) { 1957 operands_.erase(operands_.begin() + index); 1958 } 1959 1960 // Removes a list of operands with the given indices in ascending order. 1961 void RemoveOperandsAtAscendingIndices( 1962 absl::Span<const int> ascending_indices); 1963 AppendComputation(HloComputation * computation)1964 void AppendComputation(HloComputation* computation) { 1965 called_computations_.push_back(computation); 1966 } 1967 DetachFrom(HloInstruction * usee)1968 void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } 1969 set_called_computation(int index,HloComputation * computation)1970 void set_called_computation(int index, HloComputation* computation) { 1971 called_computations_[index] = computation; 1972 } 1973 // Indices of computations in called_computations_ for instructions which call 1974 // multiple computations. 1975 enum { 1976 // kWhile computations. 1977 kBodyComputationIndex = 0, 1978 kConditionComputationIndex = 1, 1979 1980 // kSelectAndScatter computations. 1981 kSelectComputationIndex = 0, 1982 kScatterComputationIndex = 1, 1983 1984 // kConditional computations. 1985 kTrueComputationIndex = 0, 1986 kFalseComputationIndex = 1, 1987 }; 1988 1989 private: 1990 friend class HloComputation; 1991 1992 bool IdenticalInternal( 1993 const HloInstruction& other, 1994 const std::function<bool(const HloInstruction*, const HloInstruction*)>& 1995 eq_operands, 1996 const std::function<bool(const HloComputation*, const HloComputation*)>& 1997 eq_computations, 1998 bool layout_sensitive, bool ignore_channel_id_values) const; 1999 2000 // Implementation for non-common logic of CloneWithNewOperands. CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context)2001 virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 2002 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 2003 HloCloneContext* context) const { 2004 // TODO(b/80131774): This should be pure virtual. 2005 LOG(FATAL) << "Unimplemented method."; 2006 } 2007 2008 // Implementation for non-common logic of ExtraAttributesToString. ExtraAttributesToStringImpl(const HloPrintOptions & options)2009 virtual std::vector<string> ExtraAttributesToStringImpl( 2010 const HloPrintOptions& options) const { 2011 return {}; 2012 } 2013 2014 // Implementation for IsElementwise if operand_idx is nullopt and for 2015 // IsElementwiseOnOperand if otherwise. 2016 // 2017 // NOTE: For all instructions other than kFusion, being elementwise on one of 2018 // the operands is equivalent to being elementwise on all the operands. 
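// A hypothetical subclass that is elementwise on every operand could override
// this as (sketch only):
//
//   bool IsElementwiseImpl(
//       const absl::optional<int64>& operand_idx) const override {
//     return true;
//   }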
2019 virtual bool IsElementwiseImpl( 2020 const absl::optional<int64>& operand_idx) const; 2021 2022 // Prints an operand to a string. Accessed by friend class HloInstruction. 2023 virtual string OperandsToStringWithCanonicalNameMap( 2024 const HloPrintOptions& options, 2025 CanonicalNameMap* canonical_name_map) const; 2026 2027 // See comments on Identical(). 2028 virtual bool IdenticalSlowPath( 2029 const HloInstruction& other, 2030 const std::function<bool(const HloComputation*, const HloComputation*)>& 2031 eq_computations) const; 2032 2033 // Generates a hash value specific to a particular type of an instruction. 2034 // This function typically considers the inner root instruction. 2035 virtual uint64 InnerHash() const; 2036 2037 // Creates an n-ary elementwise operation. 2038 static std::unique_ptr<HloInstruction> CreateNary( 2039 const Shape& shape, HloOpcode opcode, 2040 absl::Span<HloInstruction* const> operands); 2041 2042 // Adds a user for this instruction. 2043 void AddUser(HloInstruction* user); 2044 2045 // Removes a user for this instruction. 2046 void RemoveUser(HloInstruction* user); 2047 2048 // Returns how this instruction uses elements of its operand at operand_num. 2049 UseKind OperandElementUse(int64 operand_num) const; 2050 2051 // Helper for implementing backend_config(). Parses backend_config_ into the 2052 // given proto. 2053 Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; 2054 2055 // Mark this instruction as dead. Accessed by friend class HloInstruction. MarkAsDead()2056 void MarkAsDead() { marked_as_dead_ = true; } 2057 2058 // Has this instruction been marked as dead? Accessed by friend class 2059 // HloInstruction. IsMarkedAsDead()2060 bool IsMarkedAsDead() const { return marked_as_dead_; } 2061 2062 int unique_id_; // Unique to this HloInstruction within a HloModule 2063 2064 // Opcode for this instruction. 2065 HloOpcode opcode_; 2066 2067 // Instruction operands. 2068 InstructionVector operands_; 2069 2070 // The set of control predecessors of this instruction. 2071 // Note that the order of the instructions in the vector influences the order 2072 // computed in HloComputation::ComputeInstructionPostOrder, which may 2073 // influence the result of the compilation by changing the scheduling. We are 2074 // not sure if it matters. 2075 std::vector<HloInstruction*> control_predecessors_; 2076 2077 // The users of this instruction. Users are HLOs where this instruction is an 2078 // operand. The vector users_ and the map user_map_ contain identical members. 2079 // The map enables fast membership testing and the vector enables fast, stable 2080 // iteration. The value in the map contains the index of the instruction in 2081 // the vector, which enables fast removal. 2082 std::vector<HloInstruction*> users_; 2083 absl::flat_hash_map<const HloInstruction*, int64> user_map_; 2084 2085 // The set of control successors of this instruction. 2086 std::vector<HloInstruction*> control_successors_; 2087 2088 // The computation in which this instruction is contained. 2089 HloComputation* parent_ = nullptr; 2090 2091 // Result shape of this instruction. 2092 Shape shape_; 2093 2094 // The sharding, if one exists. 2095 // Uses std::shared_ptr to allow reuse of the same sharding object between 2096 // HloInstructions and other components as HloSharding can be very large for 2097 // tuples with many elements. 2098 std::shared_ptr<const HloSharding> sharding_; 2099 2100 // Computations called by this instruction.
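// For multi-computation opcodes this vector is indexed via the
// k*ComputationIndex constants above; e.g. for a kWhile instruction,
// called_computations_[kBodyComputationIndex] is the loop body.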
2101 std::vector<HloComputation*> called_computations_; 2102 2103 // A trace instruction that consumes this instruction. 2104 // 2105 // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as 2106 // an operand. 2107 HloInstruction* trace_instruction_ = nullptr; 2108 2109 // The backend-specific configuration for how a backend should compile this 2110 // HLO. See the documentation on backend_config(). 2111 string backend_config_; 2112 2113 // Attributes passed from the frontend to give hints to the backend about 2114 // how to compile this HLO. 2115 // HLO -> HLO transforms are expected to preserve these attributes on a 2116 // "best effort" basis only. 2117 // For example: 2118 // x = const(10, frontend_attributes={x} 2119 // y = const(10, frontend_attributes={y} 2120 // z = add(x,y), frontend_attributes={y} 2121 // Could be simplified to: 2122 // z' = const(20), frontend_attributes={?} 2123 FrontendAttributes frontend_attributes_; 2124 2125 // This field is assigned to true when backend_config_ is assigned to 2126 // a default configuration. 2127 bool is_default_config_ = false; 2128 2129 // True if this instruction has already been detached from its user and 2130 // operands. 2131 bool cleaned_up_ = false; 2132 2133 // String identifier for instruction. 2134 string name_; 2135 2136 // Metadata for debugging. 2137 OpMetadata metadata_; 2138 2139 // The number of partitions per outer dimension (listed in order from 2140 // outer-most dimension first). 2141 std::vector<int64> outer_dimension_partitions_; 2142 2143 // Intrusive flag used by HloComputation, whether this instruction has 2144 // been marked as dead. 2145 bool marked_as_dead_; 2146 2147 TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); 2148 }; 2149 2150 // Explicit instantiations in hlo_instruction.cc. 2151 extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool); 2152 extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool); 2153 extern template Status HloInstruction::Visit(DfsHloVisitor* visitor); 2154 extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); 2155 2156 string ToString(HloInstruction::FusionKind kind); 2157 StatusOr<HloInstruction::FusionKind> StringToFusionKind( 2158 const string& kind_name); 2159 2160 // Custom (de)stringification functions for protos that live inside 2161 // HloInstruction. 2162 string PaddingConfigToString(const PaddingConfig& padding); 2163 string FrontendAttributesToString( 2164 const FrontendAttributes& frontend_attributes); 2165 string RandomAlgorithmToString(const RandomAlgorithm& algorithm); 2166 string RandomDistributionToString(const RandomDistribution& distribution); 2167 string PrecisionToString(const PrecisionConfig::Precision& precision); 2168 string ConvolutionDimensionNumbersToString( 2169 const ConvolutionDimensionNumbers& dnums); 2170 string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups); 2171 2172 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name); 2173 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); 2174 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name); 2175 2176 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); 2177 2178 // Map classes that guarantee a deterministic iteration order when the key is 2179 // an HloInstruction* or a const HloInstruction*. 
2180 // To make the iteration order over the map deterministic, the comparator 2181 // should not use the pointer values, but rather an intrinsic property of 2182 // the HLO. Exception: null pointer values compare less than non-null. 2183 struct HloPtrComparator { 2184 bool operator()(const HloInstruction* const& lhs, 2185 const HloInstruction* const& rhs) const; 2186 }; 2187 2188 template <typename ValueT> 2189 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>; 2190 2191 template <typename ValueT> 2192 using ConstHloInstructionMap = 2193 std::map<const HloInstruction*, ValueT, HloPtrComparator>; 2194 2195 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>; 2196 using ConstHloInstructionSet = 2197 std::set<const HloInstruction*, HloPtrComparator>; 2198 2199 } // namespace xla 2200 2201 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 2202