/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the
// user has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

string PrintName(const string& name, bool print_ids);

// A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
 public:
  enum class PrintSubcomputationMode {
    kOff,         // Do not print anything about subcomputations.
    kNameOnly,    // Only print the name of subcomputations.
    kFullBodies,  // Print the full bodies of subcomputations.
  };

  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation.
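  //
  // Usage sketch (illustrative only): the setters below all return *this, so
  // options are usually built by chaining, either from the default
  // constructor or from one of the preset factories such as ShortParsable().
  // Here `instr` stands for some existing HloInstruction*:
  //
  //   string text = instr->ToString(
  //       HloPrintOptions().set_print_metadata(false).set_compact_operands(
  //           true));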
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
        print_metadata_(true),
        print_backend_config_(true),
        compact_operands_(false),
        include_layout_in_shapes_(true),
        print_operand_shape_(true),
        print_operand_names_(true),
        print_program_shape_(true),
        print_percent_(true),
        print_control_dependencies_(true),
        canonicalize_instruction_names_(false),
        indent_amount_(0),
        is_in_nested_computation_(false),
        print_ids_(true),
        canonicalize_computations_(false) {}

  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(false)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // Options to produce a fingerprint of an HLO.
  static HloPrintOptions Fingerprint() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(true)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true)
        .set_print_ids(false)
        .set_canonicalize_computations(true);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, all printed names include unique identifiers.
  HloPrintOptions& set_print_ids(bool value) {
    print_ids_ = value;
    return *this;
  }

  // If true, the program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out (note that in this
  // case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, include the layout in any shapes that are printed (instruction
  // and operands).
  HloPrintOptions& set_include_layout_in_shapes(bool value) {
    include_layout_in_shapes_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' names. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // If true, canonicalizes computations, sorting by computations' names.
  HloPrintOptions& set_canonicalize_computations(bool value) {
    canonicalize_computations_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  // Instructions are selected for printing by a predicate function
  // (`set_print_instruction`). We also print their surrounding instructions
  // for ease of reading. We print `leading_and_trailing_instructions_number`
  // instructions before and after the qualified ones inside a computation.
  HloPrintOptions& set_leading_and_trailing_instructions_number(int value) {
    leading_and_trailing_instructions_number_ = value;
    return *this;
  }

  // A callback which takes an HloInstruction*, its string representation,
  // the indentation level of the resulting block, and a bool variable
  // indicating whether the instruction is root or not. The return value is a
  // string which is used for this instruction during printing.
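  //
  // A minimal sketch of such a callback (illustrative only; `options` stands
  // for some HloPrintOptions instance), which mirrors the default formatter
  // defined below -- it indents the line and tags the root instruction:
  //
  //   options.set_format_instruction(
  //       [](const HloInstruction* instr, const string& instr_text, int indent,
  //          bool is_root) {
  //         return absl::StrCat(string(2 * indent, ' '),
  //                             is_root ? "ROOT " : "", instr_text);
  //       });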
  using FormatInstructionFunc =
      std::function<string(const HloInstruction*, const string&, int, bool)>;

  HloPrintOptions& set_format_instruction(FormatInstructionFunc callback) {
    format_instruction_ = callback;
    return *this;
  }

  using HloInstructionPredicate = std::function<bool(const HloInstruction*)>;

  // A callback which takes an HloInstruction* and returns whether it should
  // be printed or not.
  HloPrintOptions& set_print_instruction(HloInstructionPredicate callback) {
    print_instruction_ = callback;
    return *this;
  }

  using HloComputationPredicate = std::function<bool(const HloComputation*)>;

  // A callback which takes an HloComputation* and returns whether it should
  // be printed or not.
  HloPrintOptions& set_print_computation(HloComputationPredicate callback) {
    print_computation_ = callback;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool compact_operands() const { return compact_operands_; }
  bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  bool print_ids() const { return print_ids_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  bool canonicalize_computations() const { return canonicalize_computations_; }
  int indent_amount() const { return indent_amount_; }
  bool is_in_nested_computation() const { return is_in_nested_computation_; }
  int leading_and_trailing_instructions_number() const {
    return leading_and_trailing_instructions_number_;
  }
  string format_instruction(const HloInstruction* instr,
                            const string& instr_name, int indent,
                            bool is_root) const {
    return format_instruction_(instr, instr_name, indent, is_root);
  }
  bool print_instruction(const HloInstruction* instr) const {
    return print_instruction_(instr);
  }
  bool print_computation(const HloComputation* comp) const {
    return print_computation_(comp);
  }

 private:
  bool print_large_constants_;
  PrintSubcomputationMode print_subcomputation_mode_;
  bool print_metadata_;
  bool print_backend_config_;
  bool compact_operands_;
  bool include_layout_in_shapes_;
  bool print_operand_shape_;
  bool print_operand_names_;
  bool print_program_shape_;
  bool print_percent_;
  bool print_control_dependencies_;
  bool canonicalize_instruction_names_;
  int indent_amount_;
  bool is_in_nested_computation_;
  bool print_ids_;
  bool canonicalize_computations_;
  int leading_and_trailing_instructions_number_ = 3;
  FormatInstructionFunc format_instruction_ = [](const HloInstruction* instr,
                                                 const string& instr_name,
                                                 int indent, bool is_root) {
    return absl::StrCat(string(2 * indent, ' '), is_root ? "ROOT " : "",
                        instr_name);
  };
  HloInstructionPredicate print_instruction_ = [](const HloInstruction* instr) {
    return true;
  };
  HloComputationPredicate print_computation_ = [](const HloComputation* comp) {
    return true;
  };
};

// For canonical string output, we need to have a canonical way to rename
// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
// where <xxx> is an index starting from 0.
class CanonicalNameMap {
 public:
  CanonicalNameMap() : index(0) {}

  string LookupOrInsert(const string& old_name) {
    auto iter = canonical_name_map.find(old_name);
    if (iter != canonical_name_map.end()) {
      return iter->second;
    }

    string new_name = absl::StrCat("tmp_", index++);
    canonical_name_map[old_name] = new_name;
    return new_name;
  }
  void Clear() {
    canonical_name_map.clear();
    index = 0;
  }

 private:
  int64 index;
  absl::flat_hash_map<string, string> canonical_name_map;
};

// HLO instructions are the atomic unit of the high-level compiler's IR.
//
// HloInstructions live inside of an HloComputation, which is analogous to a
// function in other programming languages. Nodes have no total order within
// their computation. Instead, they have a partial ordering determined by
// their data and control dependencies.
//
// HLO does not have basic blocks or explicit "branch" instructions. Instead,
// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
// control flow. For example, the kConditional HLO executes one of two possible
// computations, depending on the runtime value of a predicate.
//
// HLO is pure (mostly). It has no concept of mutable state. Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
 public:
  // A fusion node computes the same value a call to its fusion computation
  // would compute. However, the choice of fusion kind dictates codegen
  // strategy for the backend.
  //
  // To generate code for a kFusion HloInstruction, most backends do something
  // like the following:
  //
  // 1) Identify the "primary" HloInstruction of the fused computation.
  // 2) Emit code that does the work of the primary node, creating its inputs
  //    and transforming its outputs as specified by the fused computation.
  //
  // In step (2), the code emitted is usually similar to the code that would
  // be emitted for an *unfused* version of the primary node, except that
  //
  //  - when the primary node reads an element of one of its operands, instead
  //    of loading the value from memory, it *computes* the value based on the
  //    contents of the fused computation.
  //  - when the primary node outputs a value, instead of storing it to
  //    memory, it forwards the value to its users, which then perform
  //    additional computations before the value is finally stored to memory
  //    at the root of the fusion node.
  //
  // An HloInstruction's FusionKind helps us find the kFusion instruction's
  // primary node, and can also affect how we generate code in step (2).
  //
  //  - kInput: The primary node is the root of the fused instruction.
  //
  //  - kOutput: The primary node is not the root of the fused instruction.
  //    This fusion kind requires that one operand buffer of the fusion
  //    instruction be able to alias the output buffer. This constraint is
  //    usually enough to let backends find the primary node unambiguously.
  //
  //  - kLoop: The primary node is the root of the fused computation, but,
  //    unlike in input fusion, we prescribe a specific implementation for
  //    codegen. Rather than generating code that looks like the code we'd
  //    emit for an unfused version of the primary/root node, we emit code
  //    that generates one element of the root at a time.
  //
  //  - kCustom: Custom category for backend-specific fusions that don't fit
  //    into the above patterns.
  //
  // Not all backends support all fusion kinds, and given a particular fused
  // computation, it's not in general safe to change its fusion kind. Creation
  // of fusion nodes is always backend-specific.
  //
  // For elementwise ops (e.g. kAdd), most backends would emit a
  // one-element-at-a-time implementation for the unfused version, so loop
  // fusion and input fusion are probably equivalent if the root node is
  // elementwise. They're not necessarily equivalent e.g. for kReduce, where
  // an implementation might emit something more sophisticated for an unfused
  // or input-fusion reduce, but will emit the naive code that reduces one
  // element at a time for loop fusion with a reduce as the root.
  //
  // Another way to think of loop fusion is that it's equivalent to input
  // fusion, but where the root node is an implicit identity node, whose
  // unfused implementation is "read one element, write one element".
  //
  // TODO(b/79869434): This categorization scheme is not great. For one thing,
  // input and loop fusion are basically the same thing: There is no reason
  // for the HLO to encode backend-specific decisions about how e.g. a reduce
  // that's the root of a fusion should be lowered. In addition, this scheme
  // as written doesn't work for multi-output fusion, where the primary node
  // is never actually the root (which is a kTuple instruction that gathers
  // the multiple outputs of the fusion).
  enum class FusionKind {
    kLoop,
    kInput,
    kOutput,
    kCustom,
  };

  virtual ~HloInstruction();

  // Creates an instruction from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction id to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
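  //
  // Deserialization sketch (illustrative only; `proto`, `operand` and
  // `called` stand for values the caller has already reconstructed):
  //
  //   absl::flat_hash_map<int64, HloInstruction*> instruction_map =
  //       {{proto.operand_ids(0), operand}};
  //   absl::flat_hash_map<int64, HloComputation*> computation_map =
  //       {{proto.called_computation_ids(0), called}};
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instr,
  //                       HloInstruction::CreateFromProto(
  //                           proto, instruction_map, computation_map));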
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(
      int64 parameter_number, const Shape& shape, const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

  // Creates an Iota instruction.
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
                                                    int64 iota_dimension);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the
  // computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  //
  // The parameters to the instruction are interpreted as follows:
  //
  //  - If `distribution` is RNG_UNIFORM, generates a number in range
  //    [param0, param1).
  //
  //  - If `distribution` is RNG_NORMAL, generates a normally-distributed
  //    value with mean `param0` and standard deviation `param1`.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      absl::Span<HloInstruction* const> parameters);

  // Creates an instruction to update the random number generator state to
  // reflect the new state after `delta` units of 32 random bits are generated
  // and returns the old state.
  static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState(
      const Shape& shape, int64 delta);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
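  //
  // The factory methods above compose into small graphs; a hedged sketch
  // (shapes and names are illustrative only):
  //
  //   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  //   auto p0 = HloInstruction::CreateParameter(0, r0f32, "p0");
  //   auto p1 = HloInstruction::CreateParameter(1, r0f32, "p1");
  //   auto add = HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd,
  //                                           p0.get(), p1.get());
  //   auto tuple = HloInstruction::CreateVariadic(
  //       ShapeUtil::MakeTupleShape({r0f32, r0f32}), HloOpcode::kTuple,
  //       {p0.get(), add.get()});
  //
  // The created instructions still have to be added to an HloComputation
  // (e.g. via HloComputation::Builder) before they can be executed.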
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index).
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* map_computation);

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64 feature_group_count, int64 batch_group_count, const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      absl::Span<const int64> fft_length);

  // Creates a compare op, performing the comparison specified in direction.
  static std::unique_ptr<HloInstruction> CreateCompare(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      ComparisonDirection direction);

  static std::unique_ptr<HloInstruction> CreateTriangularSolve(
      const Shape& shape, HloInstruction* a, HloInstruction* b,
      const TriangularSolveOptions& options);

  static std::unique_ptr<HloInstruction> CreateCholesky(
      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates a cross replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, if we have 4 replicas, then replica_groups={{0,2},{1,3}}
  // means replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
  // subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if they have
  // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
  // not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id);

  // An all-to-all op takes N array operands of the same shape and scatters
  // them to N replicas.
  // Each replica gathers the results into a tuple.
  //
  // For example, suppose we have 3 replicas, with replica i passing inputs
  // [a_i, b_i, c_i] to its all-to-all op. Then the resulting tuples are
  //
  //   replica 0: (a_0, a_1, a_2)
  //   replica 1: (b_0, b_1, b_2)
  //   replica 2: (c_0, c_1, c_2).
  //
  // If replica_groups is set, the op is sharded and the replicas are
  // permuted. To explain by way of example, suppose we have
  // replica_groups={{1,2},{3,0}}. Then each replica passes two operands, say
  // [a_i, b_i], and the result is
  //
  //   replica 0: (b_3, b_0)
  //   replica 1: (a_1, a_2)
  //   replica 2: (b_1, b_2)
  //   replica 3: (a_3, a_0).
  //
  // All replica groups must have the same number of elements, and the number
  // of operands must be equal to the size of one replica group. Each replica
  // must appear in exactly one group.
  //
  // Note that this instruction is different from the all-to-all op in
  // xla_builder.h. The version in XlaBuilder takes one input and slices it,
  // and then concatenates the results into a single array. This instruction
  // takes multiple inputs and returns a tuple; it doesn't slice or
  // concatenate. It is used to implement the higher-level instruction in
  // XlaBuilder.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups,
      const absl::optional<int64>& channel_id,
      const absl::optional<int64>& split_dimension = absl::nullopt);

  // Creates a communication instruction that permutes data across replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not
  // a target_replica_id in any pair, the output on that replica is a tensor
  // consisting of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs,
      const absl::optional<int64>& channel_id);

  // Creates an instruction that returns a U32 replica ID.
  static std::unique_ptr<HloInstruction> CreateReplicaId();

  // Creates an instruction that returns a U32 partition ID.
  static std::unique_ptr<HloInstruction> CreatePartitionId();

  // Creates a conversion instruction, where operand is the data to convert
  // and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from
  // the Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed *not* the shape of the infeed instruction which
  // is a tuple containing the infeed_shape and the TOKEN.
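  //
  // For example (sketch; where the token comes from is up to the caller), an
  // infeed of an f32[10] buffer produces an instruction whose shape is the
  // tuple (f32[10], token[]):
  //
  //   auto token = HloInstruction::CreateToken();
  //   auto infeed = HloInstruction::CreateInfeed(
  //       ShapeUtil::MakeShape(F32, {10}), token.get(), /*config=*/"");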
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed *not* the shape of the outfeed instruction
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a
  // unique send instruction in another computation that has the same channel
  // id. If is_host_transfer is true, then this Recv operation transfers data
  // from the host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> start_indices,
      absl::Span<const int64> limit_indices, absl::Span<const int64> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand.
  // For example, let f be the function to apply, which takes 2 arguments, an
  // accumulator and the current value. Let init be an initial value (which is
  // normally chosen to be the identity element for f, e.g. 0 if f is
  // addition).
  // Then the reduce HLO will compute:
  //   f(... f(f(init, value0), value1) ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  //   f_1 = f(init0, ..., initN, input0.value0, ..., inputN.value0)
  //   f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  //           ..., inputN.value1)
  //   ...
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand,
      int64 inferred_dimension = -1);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an n-ary sort op with a 'compare' computation which is used for
  // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
  // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand
  // at specific index positions which should be compared, and should return a
  // PRED. 'is_stable' specifies whether stable sorting is required.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64 dimension,
      absl::Span<HloInstruction* const> operands, HloComputation* compare,
      bool is_stable);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  //   int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* branch_index,
      absl::Span<HloComputation* const> branch_computations,
      absl::Span<HloInstruction* const> branch_computation_args);

  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
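  //
  // Sketch (illustrative only; `operand_md` and `user_md` stand for
  // DomainMetadata implementations the caller has constructed elsewhere, and
  // `operand` for an existing HloInstruction*):
  //
  //   auto domain = HloInstruction::CreateDomain(
  //       operand->shape(), operand, std::move(operand_md),
  //       std::move(user_md));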
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the given
  // operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call
  // target to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, string opaque = "");

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout, string opaque = "");

  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an AfterAll instruction used for joining or creating new values
  // of token type which thread through side-effecting operations. Operands
  // must all be tokens, and there must be at least one operand.
  static std::unique_ptr<HloInstruction> CreateAfterAll(
      absl::Span<HloInstruction* const> operands);

  // Creates an AfterAll instruction which creates a token type out of thin
  // air (no operands). This is a separate method from CreateAfterAll to
  // facilitate the removal of operand-less AfterAll instructions.
  // TODO(b/110532604): Remove this capability of creating a token from
  // nothing when we plumb a primordial token from the entry computation.
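  //
  // Sketch (illustrative only): a freshly created token is typically used as
  // the input token of the first side-effecting op in a computation, e.g.
  //
  //   auto token = HloInstruction::CreateToken();
  //   auto outfeed = HloInstruction::CreateOutfeed(
  //       data->shape(), data, token.get(), /*outfeed_config=*/"");
  //
  // where `data` stands for some existing HloInstruction*.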
  static std::unique_ptr<HloInstruction> CreateToken();

  static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
      const Shape& shape, HloInstruction* operand, int64 dimension);

  static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
      const Shape& shape, HloInstruction* operand, HloInstruction* val,
      int64 dimension);

  static std::unique_ptr<HloInstruction> CreateAddDependency(
      HloInstruction* data_operand, HloInstruction* token_operand);

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }

  // Returns true if this instruction has a side effect, irrespective of
  // whether any called computations may contain an instruction with side
  // effects.
  bool HasSideEffectNoRecurse() const;

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64 i) const;

  // Returns the ith operand to this instruction.
  HloInstruction* mutable_operand(int64 i);

  // Returns the number of operands to this instruction.
  int64 operand_count() const { return operands_.size(); }

  // Returns the vector of operands of this instruction.
  using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }

  // Returns the vector of unique operands, in the same order they are found
  // within the operand vector.
  InstructionVector unique_operands() const;

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64 operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64 user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns the index of the user in the users() vector.
  //
  // Precondition: `user` is a user of the instruction.
  int64 UserId(HloInstruction* user);

  // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_map_, this);
  }

  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if either of the given instructions
  // does not belong to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
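  //
  // For example (sketch; `send` and `recv` stand for instructions that live
  // in the same computation):
  //
  //   TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv));
  //   // Now `send` is a control predecessor of `recv`, so `recv` cannot be
  //   // scheduled before `send`.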
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Drops all control predecessors and successors from this HLO instruction.
  Status DropAllControlDeps();

  // Copies the control predecessors and successors on this HLO instruction to
  // `inst`. Does not do a deep copy so this makes sense only if `inst` and
  // this HLO are in the same module.
  //
  // Depending on the use cases we see in practice, in the future we may
  // consider folding the logic here into Clone, CloneWithNewOperands and
  // ReplaceAllUsesWith by treating control dependencies like data
  // dependencies.
  Status CopyAllControlDepsFrom(const HloInstruction* inst);

  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before (after)
  // the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }

  // Returns true if "other" performs the same computation as this instruction.
  bool Identical(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    // An instruction is always identical to itself.
    if (this == &other) {
      return true;
    }

    // Identical instructions must have the same opcode, shape, and identical
    // operands.
    if (opcode() != other.opcode()) {
      return false;
    }
    if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
                           : ShapeUtil::Compatible(shape(), other.shape()))) {
      return false;
    }
    if (operands().size() != other.operands().size()) {
      return false;
    }

    // Two AllReduces are Identical if they have the same channel_id. Their
    // operands don't have to be Identical.
    if (!IsCrossModuleAllReduce()) {
      // Use an explicit loop rather than ContainerEquals, because copying
      // around std::functions may be too expensive in some cases.
      for (size_t i = 0; i < operands().size(); ++i) {
        if (!eq_operands(operand(i), other.operand(i))) {
          return false;
        }
      }
    }

    if (backend_config_ != other.backend_config_) {
      return false;
    }

    return IdenticalSlowPath(other, eq_computations);
  }

  // Generates a hash value of an HLO instruction. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO instructions,
  // with respect to HloInstruction::Identical() method.
  //
  // Uses hash_operand function to compute hash values of its operands.
  // At the very top level, hash_operand should be non-recursive to prevent
  // non-termination.
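  //
  // A minimal non-recursive hash_operand of the kind the note above assumes
  // (illustrative only; `instr` stands for some HloInstruction*):
  //
  //   uint64 h = instr->Hash([](const HloInstruction* operand) {
  //     return static_cast<uint64>(operand->opcode());
  //   });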
  uint64 Hash(
      const std::function<uint64(const HloInstruction*)>& hash_operand) const;

  // Calls the above method with non-recursive hash_operand function.
  uint64 Hash() const;

  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  //
  // If user is a fusion instruction, this function will remove any duplicated
  // operands of it which could be created due to this replacement.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Same as ReplaceUseWith(), but new_producer can have a different shape.
  Status ReplaceUseWithDifferentShape(HloInstruction* user,
                                      HloInstruction* new_producer);

  // Replaces the specified operand with new_operand. The old and new operands
  // must have compatible shapes ignoring floating-point precision.
  //
  // This function does NOT remove duplicated operands even if this instruction
  // is a fusion, so that the existing operand numbers do not change.
  Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand);

  // Same as ReplaceOperandWith(), but new_operand can have a different shape.
  Status ReplaceOperandWithDifferentShape(int64 operand_num,
                                          HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a
  // use of this instruction to avoid introducing cycles into the graph.
  //
  // If this instruction is the root of its computation, sets the computation's
  // root to new_producer.
  //
  // The new producer must have a compatible shape ignoring floating-point
  // precision.
  //
  // If a user is a fusion instruction, this function will remove any
  // duplicated operands of it which could be created due to this replacement.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);

  // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
  Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);

  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // complete. If ignore_control_predecessors is true, instructions only
  // reachable via control dependencies will not be visited, and the postorder
  // will not take control dependencies into account. It is as if the control
  // dependencies didn't exist in the graph at all.
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                bool call_finish_visit = true,
                bool ignore_control_predecessors = false);
  Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
                bool ignore_control_predecessors = false) const {
    return const_cast<HloInstruction*>(this)->Accept(
        visitor, call_finish_visit, ignore_control_predecessors);
  }

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
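  //
  // One possible operand order (illustrative only; `root` and `visitor` stand
  // for a caller-provided instruction and DfsHloVisitor): visit operands with
  // smaller unique ids first.
  //
  //   HloInstruction::CompareFunction by_id =
  //       [](const HloInstruction* a, const HloInstruction* b) {
  //         return a->unique_id() < b->unique_id();
  //       };
  //   TF_RETURN_IF_ERROR(root->AcceptWithOperandOrder(&visitor, by_id));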
  using CompareFunction =
      std::function<bool(const HloInstruction*, const HloInstruction*)>;
  Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
                                const CompareFunction& operand_order,
                                bool call_finish_visit = true);

  // Visit this instruction and only this instruction with the given visitor.
  template <typename HloInstructionPtr>
  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);

  // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
  // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
  // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
  std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
      const;

  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
    auto rv =
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
    return {const_cast<HloInstruction*>(rv.first), rv.second};
  }

  // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
  const HloInstruction* LatestNonGteAncestor() const;

  HloInstruction* LatestNonGteAncestor() {
    return const_cast<HloInstruction*>(
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
  }

  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction has a valid to_apply_ field.
  HloComputation* to_apply() const;
  void set_to_apply(HloComputation* to_apply);

  // Gets/sets the while_condition or while_body HloComputation for While. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a While instruction.
  HloComputation* while_condition() const;
  HloComputation* while_body() const;
  void set_while_condition(HloComputation* while_condition);
  void set_while_body(HloComputation* while_body);

  HloInstruction* while_init() const;

  // Gets/sets the true and false HloComputation for Conditional.
  //
  // Precondition: The instruction is a predicated Conditional instruction.
  HloComputation* true_computation() const;
  HloComputation* false_computation() const;

  // Gets the branch HloComputations for Conditional.
  //
  // Precondition: The instruction is a Conditional instruction.
  const std::vector<HloComputation*>& branch_computations() const;
  int branch_count() const;
  HloComputation* branch_computation(int b) const;
  // Sets a branch HloComputation for Conditional.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a Conditional instruction.
  void set_branch_computation(int b, HloComputation* computation);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  string SignatureString() const;

  // Returns a debugging string that represents this instruction.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  //
  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
  // default, backing off on providing full information for very large strings,
  // or provide a different name for a ToString-like function that does that.
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Components of the ToString() representation:

  // Returns a string representation of the operand list.
  string OperandsToString(const HloPrintOptions& options) const;

  // Returns a string representation of op-specific attributes.
  std::vector<string> ExtraAttributesToString(
      const HloPrintOptions& options) const;

  // As ToString, but returns a shorter string.
  string ToShortString() const;

  // Prints an instruction to a string.
  //
  // The canonical string representation needs to name operands and instruction
  // names in a consistent way. This is implemented through the
  // canonical_name_map.
  string ToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Returns a serialized representation of this instruction.
  virtual HloInstructionProto ToProto() const;

  // Returns a category for the HLO. This could be something like "convolution"
  // or "elementwise".
  virtual string ToCategory() const;

  // Returns a logging instruction, if the output of this instruction is
  // logged.
  //
  // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
  HloInstruction* tracing() const;
  void set_tracing(HloInstruction* trace_instruction);

  // Returns true if this instruction is fused, i.e. contained within a fusion
  // instruction.
  bool IsFused() const;

  bool IsLoopFusion() const;
  bool IsInputFusion() const;
  bool IsOutputFusion() const;
  bool IsCustomFusion() const;

  // Returns true if this instruction can be legally fused into a fusion
  // instruction.
  bool IsFusible() const;

  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }

  // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }
  // Returns the sharding's unique device, if any.
  absl::optional<int64> sharding_unique_device() const {
    if (sharding_ == nullptr) {
      return absl::optional<int64>();
    }
    return sharding_->UniqueDevice();
  }
  // Sets the sharding of this operator. Should only be called by HloModule or
  // HloComputation methods.
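  //
  // A minimal illustrative sketch (assigns the whole op to device 0 and then
  // clears the assignment again; `instr` is a hypothetical HloInstruction*):
  //
  //   instr->set_sharding(HloSharding::AssignDevice(/*device=*/0));
  //   CHECK(instr->has_sharding());
  //   instr->clear_sharding();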
  void set_sharding(const HloSharding& sharding) {
    sharding_ = std::make_shared<const HloSharding>(sharding);
  }
  void set_sharding(std::shared_ptr<const HloSharding> sharding) {
    sharding_ = std::move(sharding);
  }
  void set_single_sharding(const HloSharding& sharding);
  // Sets a sharding that assigns the current instruction to device.
  void set_device_sharding(int64 device) {
    set_single_sharding(HloSharding::AssignDevice(device));
  }
  // Removes any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }
  // Returns true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }
  // Checks whether the instruction has compatible sharding with the other
  // instruction.
  bool has_compatible_sharding(const HloInstruction* other) const {
    if (!has_sharding()) {
      return !other->has_sharding();
    }
    return other->has_sharding() ? sharding() == other->sharding() : false;
  }

  // When creating a new instruction that either replaces or shifts up (as in
  // kCopy insertion) another instruction, we need to make sure that certain
  // properties of the new instruction are copied into the derived one. As of
  // today, the metadata and sharding are propagated to the derived
  // instruction.
  void SetupDerivedInstruction(HloInstruction* derived_instruction) const;

  // Clones the HLO instruction. The clone will have the same opcode, shape,
  // and operands. After creation the clone has no uses. "this" (the
  // instruction cloned from) is not changed. Suffix is the string to append to
  // the name of the instruction to form the name of the cloned instruction.
  // Ignores the control predecessors and successors of this HLO instruction.
  std::unique_ptr<HloInstruction> Clone(
      const string& suffix = "clone", HloCloneContext* context = nullptr) const;

  // Clones the HLO instruction as above but with new shape and operands.
  std::unique_ptr<HloInstruction> CloneWithNewOperands(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context = nullptr) const;

  // Returns the computations this instruction directly calls (if any).
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want the instructions to point to the
  // newly cloned nodes.
  void ReplaceCalledComputations(
      std::function<HloComputation*(HloComputation*)> map_function) {
    for (int64 i = 0; i < called_computations_.size(); ++i) {
      called_computations_[i] = map_function(called_computations_[i]);
    }
  }

  // Clears out the called computations.
  //
  // This is, in particular, necessary when inlining function bodies into their
  // caller. If there were side-effecting operations in the called
  // computations, the call itself is considered side-effecting and thus cannot
  // be removed.
  // By clearing out the computations, we record that all side-effecting
  // behavior has been absorbed into the caller, which makes the call HLO
  // removable.
  void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on the
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
  // to compute the output at index {i_0,i_1,...,i_n}, the only element required
  // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in the
  // worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64 operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if this is a cross-module all-reduce instruction.
  bool IsCrossModuleAllReduce() const;

  // Returns true if this is a cross-replica all-reduce instruction.
  bool IsCrossReplicaAllReduce() const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64 i) const {
    return OperandElementUse(i) == UseKind::kReuse;
  }

  // Returns the indices at which the given operand appears in the operand list
  // of this instruction. Note that an instruction can use the same operand
  // multiple times.
  absl::InlinedVector<int64, 4> OperandIndices(
      const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, returns the
  // input indices of the deleted dimensions and the output indices of the
  // inserted dimensions.
  //
  // Precondition: this op must be a reshape.
  std::tuple<bool, std::vector<int64>, std::vector<int64>>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;

  // Gets the string identifier for this instruction.
  const string& name() const { return name_; }

  // Sets the string identifier for this instruction. Name will be sanitized to
  // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
  void SetAndSanitizeName(const string& name) {
    name_ = NameUniquer::GetSanitizedName(name);
  }

  // Uses the given NameUniquer to select a unique name for the instruction
  // based on the instruction's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Clears the unique ID of the instruction so that it can be re-assigned,
  // such as for the purpose of compacting the instruction unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // Sets the unique id for this instruction to "id".
  void SetUniqueId(int id) {
    CHECK_EQ(unique_id_, -1);  // Should not be assigned already
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the unique ID assigned to this node via SetUniqueId (or -1
  // if no id has been assigned yet).
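  //
  // A minimal illustrative sketch of the re-assignment idiom described above
  // (`instr` and `new_id` are hypothetical; `new_id` must be >= 0):
  //
  //   instr->ClearUniqueIdInternal();        // back to the unassigned state
  //   instr->SetUniqueId(new_id);
  //   CHECK_EQ(instr->unique_id(), new_id);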
  int unique_id() const { return unique_id_; }

  // Returns the backend-specific configuration for how a backend should
  // compile this HLO. The meaning of the field is backend specific. Not for
  // use before or during general HLO optimization, since HLO optimizations do
  // not preserve this field and cannot interpret it, its meaning being backend
  // specific. The exception is CustomCall, for which this field is preserved
  // and which no general HLO optimization needs to interpret.
  //
  // ConfigProto should be a protobuf Message type.
  template <typename ConfigProto>
  StatusOr<ConfigProto> backend_config() const {
    ConfigProto proto;
    TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
    return std::move(proto);
  }
  Status set_backend_config(const tensorflow::protobuf::Message& proto);

  void set_frontend_attributes(FrontendAttributes frontend_attributes) {
    frontend_attributes_ = std::move(frontend_attributes);
  }

  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Getter/setter for raw JSON-encoded backend config. Prefer the
  // functions above that deal in proto Messages where possible.
  const string& raw_backend_config_string() const { return backend_config_; }
  void set_raw_backend_config_string(string config_str) {
    backend_config_ = std::move(config_str);
  }

  bool is_default_config() const { return is_default_config_; }
  void set_default_config() { is_default_config_ = true; }

  // Returns a string representation of a proto in the format used by
  // raw_backend_config_string.
  //
  // This is morally equivalent to:
  //
  //   HloInstruction instr;
  //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
  //   return instr.raw_backend_config_string();
  //
  static StatusOr<string> BackendConfigToRawString(
      const tensorflow::protobuf::Message& proto);

  // Returns the information that tells the implementation what sort of
  // precision is requested. The meaning of the field is backend specific. At
  // the moment, it is only supported for kConvolution and kDot.
  // Transformations from one kDot or kConvolution to another will preserve
  // this information. Transformations to other HLOs will not preserve this
  // information, but it is presumed that the alternate lowering is strictly
  // superior.
  // Precondition: opcode must be kConvolution or kDot.
  const PrecisionConfig& precision_config() const;
  PrecisionConfig* mutable_precision_config();

  // Sets the debug metadata for this instruction.
  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
  const OpMetadata& metadata() const { return metadata_; }

  // Set/get the computation containing this instruction. set_parent should
  // only be called by HloComputation methods which add/remove instructions to
  // computations.
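  //
  // A minimal illustrative sketch of navigating from an instruction to its
  // enclosing computation and module (`instr` is a hypothetical
  // HloInstruction*):
  //
  //   HloComputation* computation = instr->parent();
  //   HloModule* module = instr->GetModule();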
  void set_parent(HloComputation* computation) { parent_ = computation; }
  const HloComputation* parent() const { return parent_; }
  HloComputation* parent() { return parent_; }

  // Returns the module for this instruction.
  HloModule* GetModule() const;

  // Returns whether we could assign input and output layouts to this
  // instruction to make it a bitcast.
  bool CouldBeBitcast() const;

  // Get/Set the number of partitions per outer dimension (in order, outer-most
  // dimension first). Currently used by the parallel CPU backend to partition
  // HLOs into parallel tasks.
  //
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64>& outer_dimension_partitions);

  // Old methods kept for smooth subclassing transition BEGIN.
  // TODO(b/80131774): Remove this code.

  // Delegates to HloBatchNormInstruction::feature_index.
  int64 feature_index() const;

  // Delegates to HloBatchNormInstruction::epsilon.
  float epsilon() const;

  // Delegates to HloFftInstruction::fft_type.
  FftType fft_type() const;

  // Delegates to HloFftInstruction::fft_length.
  const std::vector<int64>& fft_length() const;

  // Delegates to HloChannelInstruction::channel_id.
  absl::optional<int64> channel_id() const;
  void set_channel_id(const absl::optional<int64>& channel_id);

  // Returns the dimension sizes or numbers associated with this instruction.
  virtual const std::vector<int64>& dimensions() const {
    LOG(FATAL) << "Unimplemented method.";
  }
  virtual int64 dimensions(int64 index) const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Delegates to HloConcatenateInstruction::concatenate_dimension.
  int64 concatenate_dimension() const;

  // Delegates to HloGetDimensionSizeInstruction::dimension.
  int64 dimension() const;

  // Delegates to HloReshapeInstruction::inferred_dimension.
  int64 inferred_dimension() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Delegates to HloSliceInstruction::slice_start.
  int64 slice_starts(int64 dimension) const;
  const std::vector<int64>& slice_starts() const;

  // Delegates to HloSliceInstruction::slice_limits.
  int64 slice_limits(int64 dimension) const;
  const std::vector<int64>& slice_limits() const;

  // Delegates to HloSliceInstruction::slice_strides.
  int64 slice_strides(int64 dimension) const;
  const std::vector<int64>& slice_strides() const;

  // Returns the literal associated with this instruction.
  const Literal& literal() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Delegates to HloConstantInstruction::RelayoutConstant.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  // Delegates to HloTraceInstruction::TracingTag.
  string TracingTag() const;

  // Delegates to HloFusionInstruction::AddFusionOperand.
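  //
  // A minimal illustrative sketch (assumes `fusion` is an existing kFusion
  // instruction and `operand` lives in the same computation; the exact
  // semantics of the returned instruction are defined by
  // HloFusionInstruction):
  //
  //   if (fusion->opcode() == HloOpcode::kFusion) {
  //     HloInstruction* result = fusion->AddFusionOperand(operand);
  //   }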
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Delegates to HloFusionInstruction::MergeFusionInstruction.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
  void MergeFusionInstructionIntoMultiOutput(
      HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::FuseInstruction.
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::fused_instructions_computation.
  HloComputation* fused_instructions_computation() const;

  // Delegates to HloFusionInstruction::fused_expression_root.
  HloInstruction* fused_expression_root() const;

  // Delegates to HloFusionInstruction::fused_instructions.
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;

  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Delegates to HloFusionInstruction::fused_instruction_count.
  int64 fused_instruction_count() const;

  // Delegates to HloFusionInstruction::fused_parameter.
  HloInstruction* fused_parameter(int64 parameter_number) const;

  // Delegates to HloFusionInstruction::fused_parameters.
  const std::vector<HloInstruction*>& fused_parameters() const;

  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  const bool IsMultiOutputFusion() const;

  // Delegates to HloFusionInstruction::fusion_kind.
  FusionKind fusion_kind() const;

  // Delegates to HloFusionInstruction::set_fusion_kind.
  void set_fusion_kind(FusionKind kind);

  // Delegates to HloRngInstruction::random_distribution.
  RandomDistribution random_distribution() const;

  // Delegates to HloParameterInstruction::parameter_number.
  int64 parameter_number() const;

  // Delegates to
  // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
  void set_parameter_replicated_at_leaf_buffers(
      absl::Span<const bool> parameter_replicated_at_leaf_buffers);
  void set_parameter_replicated_at_leaf_buffers(
      const std::vector<bool>& parameter_replicated_at_leaf_buffers);

  // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
  const absl::optional<std::vector<bool>>&
  parameter_replicated_at_leaf_buffers() const;

  // Delegates to HloGetTupleElementInstruction::tuple_index.
  int64 tuple_index() const;

  // Delegates to HloGetTupleElementInstruction::set_tuple_index.
  void set_tuple_index(int64 new_tuple_index);

  // Delegates to HloReducePrecisionInstruction::exponent_bits.
  int32 exponent_bits() const;

  // Delegates to HloReducePrecisionInstruction::mantissa_bits.
  int32 mantissa_bits() const;

  // Delegates to HloInfeedInstruction::infeed_config.
  string infeed_config() const;

  // Delegates to HloInfeedInstruction::set_infeed_config.
  void set_infeed_config(const string& config);

  // Returns the config for the Outfeed instruction.
  const string& outfeed_config() const;

  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const;

  // Delegates to HloCollectiveInstruction::replica_groups.
  const std::vector<ReplicaGroup>& replica_groups() const;

  // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
  const std::vector<std::pair<int64, int64>>& source_target_pairs() const;

  // Returns data on the window in a windowed operation such as convolution.
  virtual const Window& window() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Sets the window data in a windowed operation such as convolution.
  virtual void set_window(const Window& window) {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Returns the unique_indices field.
  virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }

  // Returns data on the dimension numbers used for a convolution operation,
  // which may be a kConvolution instruction or a kCustomCall that implements a
  // convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;

  // Sets the convolution dimension numbers on this instruction. In general you
  // shouldn't need to call this; instead, specify the convolution dimension
  // numbers when you create the instruction.
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums);

  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64 feature_group_count() const;

  void set_feature_group_count(int64 feature_group_count);

  // The number of batch groups. Must be a divisor of the input batch
  // dimension.
  int64 batch_group_count() const;

  void set_batch_group_count(int64 batch_group_count);

  // Delegates to HloSelectAndScatterInstruction::select.
  HloComputation* select() const;

  // Delegates to HloSelectAndScatterInstruction::scatter.
  HloComputation* scatter() const;

  // Delegates to HloSelectAndScatterInstruction::set_select.
  void set_select(HloComputation* computation);

  // Delegates to HloSelectAndScatterInstruction::set_scatter.
  void set_scatter(HloComputation* computation);

  // Delegates to HloCustomCallInstruction::custom_call_target.
  const string& custom_call_target() const;

  // Delegates to HloPadInstruction::padding_config.
  const PaddingConfig& padding_config() const;

  // Delegates to HloDynamicSliceInstruction::slice_sizes.
  int64 slice_sizes(int64 dimension) const;

  // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
  const std::vector<int64>& dynamic_slice_sizes() const;

  // Delegates to HloGatherInstruction::gather_dimension_numbers.
  const GatherDimensionNumbers& gather_dimension_numbers() const;
  // Delegates to HloGatherInstruction::gather_slice_sizes.
  absl::Span<const int64> gather_slice_sizes() const;

  // Delegates to HloScatterInstruction::scatter_dimension_numbers().
  const ScatterDimensionNumbers& scatter_dimension_numbers() const;

  // Delegates to HloDotInstruction::dot_dimension_numbers().
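  //
  // A minimal illustrative sketch: guard a delegating accessor by opcode
  // before calling it on an arbitrary instruction (`instr` is hypothetical):
  //
  //   if (instr->opcode() == HloOpcode::kDot) {
  //     const DotDimensionNumbers& dnums = instr->dot_dimension_numbers();
  //   }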
  const DotDimensionNumbers& dot_dimension_numbers() const;

  // Delegates to HloDomainInstruction::operand_side_metadata().
  const DomainMetadata& operand_side_metadata() const;

  // Delegates to HloDomainInstruction::user_side_metadata().
  const DomainMetadata& user_side_metadata() const;

  // Delegates to HloCompareInstruction::direction().
  ComparisonDirection comparison_direction() const;

  // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
  const TriangularSolveOptions& triangular_solve_options() const;

  // Delegates to HloCholeskyInstruction::cholesky_options().
  const CholeskyOptions& cholesky_options() const;

  // Appends operand to the list of operands and adds this instruction as a
  // user of the operand.
  void AppendOperand(HloInstruction* operand);

  // Old methods kept for smooth subclassing transition END.

 protected:
  // Indicates how an instruction uses a value (such as an operand).
  //
  // Does it (a) not use it, (b) use it, or (c) use it multiple times?
  //
  // In the kUse case (i.e. (b)) we may either (i) use the value elementwise,
  // or (ii) use it after having permuted it somehow, e.g. through a reshape.
  // If the use is a permuting use, we set permutation_instr to the instruction
  // that did the permuting.
  struct UseKind {
    enum Kind { kReuse, kUse, kNoUse };

    // Creates a UseKind that represents a use that permutes an instruction's
    // elements according to the given instruction.
    static UseKind Permuting(const HloInstruction* permutation_instr) {
      UseKind k(kUse);
      k.permutation_instr = permutation_instr;
      return k;
    }

    UseKind(Kind kind)  // NOLINT intentionally nonexplicit
        : kind(kind), permutation_instr(nullptr) {}

    friend bool operator==(UseKind a, Kind b) { return a.kind == b; }
    friend bool operator==(Kind a, UseKind b) { return b == a; }

    Kind kind;
    const HloInstruction* permutation_instr;
  };

  // Helper class for computing OperandElementUse for kFusion.
  class FusionReusesParamElements;

  // Internal constructor for a given opcode/shape; other fields must be filled
  // in by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  void RemoveOperandAt(int index) {
    operands_.erase(operands_.begin() + index);
  }

  // Removes a list of operands with the given indices in ascending order.
  void RemoveOperandsAtAscendingIndices(
      absl::Span<const int> ascending_indices);

  void AppendComputation(HloComputation* computation) {
    called_computations_.push_back(computation);
  }

  void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }

  void set_called_computation(int index, HloComputation* computation) {
    called_computations_[index] = computation;
  }

  // Indices of computations in called_computations_ for instructions which
  // call multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const {
    // TODO(b/80131774): This should be pure virtual.
    LOG(FATAL) << "Unimplemented method.";
  }

  // Implementation for non-common logic of ExtraAttributesToString.
  virtual std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const {
    return {};
  }

  // Implementation for IsElementwise if operand_idx is nullopt and for
  // IsElementwiseOnOperand otherwise.
  //
  // NOTE: For all instructions other than kFusion, being elementwise on one of
  // the operands is equivalent to being elementwise on all the operands.
  virtual bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const;

  // Prints an operand to a string.
  virtual string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Allow HloComputation to access the ToStringWithCanonicalNameMap() and
  // OperandsToStringWithCanonicalNameMap() functions.
  friend class HloComputation;

  // See comments on Identical().
  virtual bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const;

  // Generates a hash value specific to a particular type of an instruction.
  // This function typically considers the inner root instruction.
  virtual uint64 InnerHash() const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Returns how this instruction uses elements of its operand at operand_num.
  UseKind OperandElementUse(int64 operand_num) const;

  // Helper for implementing backend_config(). Parses backend_config_ into the
  // given proto.
  Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;

  int unique_id_;  // Unique to this HloInstruction within a HloModule

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  // Note that the order of the instructions in the vector influences the order
  // computed in HloComputation::ComputeInstructionPostOrder, which may
  // influence the result of the compilation by changing the scheduling. We are
  // not sure if it matters.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is
  // an operand.
  // The vector users_ and the map user_map_ contain identical members. The
  // map enables fast membership testing and the vector enables fast, stable
  // iteration. The value in the map is the index of the instruction in the
  // vector, which enables fast removal.
  std::vector<HloInstruction*> users_;
  absl::flat_hash_map<const HloInstruction*, int64> user_map_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Result shape of this instruction.
  Shape shape_;

  // The sharding, if one exists.
  // Uses std::shared_ptr to allow reuse of the same sharding object between
  // HloInstructions and other components, as HloSharding can be very large for
  // tuples with many elements.
  std::shared_ptr<const HloSharding> sharding_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this
  // as an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The backend-specific configuration for how a backend should compile this
  // HLO. See the documentation on backend_config().
  string backend_config_;

  // Attributes passed from the frontend to give hints to the backend about
  // how to compile this HLO.
  // HLO -> HLO transforms are expected to preserve these attributes on a
  // "best effort" basis only.
  // For example:
  //   x = const(10), frontend_attributes={x}
  //   y = const(10), frontend_attributes={y}
  //   z = add(x,y), frontend_attributes={y}
  // could be simplified to:
  //   z' = const(20), frontend_attributes={?}
  FrontendAttributes frontend_attributes_;

  // This field is set to true when backend_config_ is assigned a default
  // configuration.
  bool is_default_config_ = false;

  // String identifier for instruction.
  string name_;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order,
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

// Explicit instantiations in hlo_instruction.cc.
extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);

string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
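//
// A minimal illustrative sketch (`instr` is a hypothetical HloInstruction*;
// the first call assumes `instr` is a kPad instruction so that
// padding_config() is valid):
//
//   string padding = PaddingConfigToString(instr->padding_config());
//   string meta = OpMetadataToString(instr->metadata());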
string PaddingConfigToString(const PaddingConfig& padding);
string FrontendAttributesToString(
    const FrontendAttributes& frontend_attributes);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums);
string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups);

StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not use the pointer values, but rather an intrinsic property of the
// HLO. Exception: null pointer values compare less than non-null.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};

template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_