1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // HLO instructions are in DAG form and represent the computations that the user 17 // has built up via the XLA service interface. They are ultimately lowered 18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered 19 // form; e.g. see DfsHloVisitor. 20 21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 23 24 #include <functional> 25 #include <iosfwd> 26 #include <list> 27 #include <memory> 28 #include <set> 29 #include <string> 30 #include <tuple> 31 #include <vector> 32 33 #include "absl/container/flat_hash_map.h" 34 #include "absl/container/flat_hash_set.h" 35 #include "absl/container/inlined_vector.h" 36 #include "absl/memory/memory.h" 37 #include "absl/strings/str_cat.h" 38 #include "absl/strings/string_view.h" 39 #include "absl/types/span.h" 40 #include "tensorflow/compiler/xla/comparison_util.h" 41 #include "tensorflow/compiler/xla/iterator_util.h" 42 #include "tensorflow/compiler/xla/literal.h" 43 #include "tensorflow/compiler/xla/map_util.h" 44 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 45 #include "tensorflow/compiler/xla/service/hlo.pb.h" 46 #include "tensorflow/compiler/xla/service/hlo_clone_context.h" 47 #include 
"tensorflow/compiler/xla/service/hlo_domain_metadata.h" 48 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 49 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 50 #include "tensorflow/compiler/xla/service/name_uniquer.h" 51 #include "tensorflow/compiler/xla/shape_tree.h" 52 #include "tensorflow/compiler/xla/types.h" 53 #include "tensorflow/compiler/xla/xla_data.pb.h" 54 #include "tensorflow/core/lib/core/status.h" 55 #include "tensorflow/core/lib/gtl/iterator_range.h" 56 #include "tensorflow/core/platform/logging.h" 57 #include "tensorflow/core/platform/macros.h" 58 #include "tensorflow/core/platform/protobuf.h" 59 #include "tensorflow/core/platform/types.h" 60 61 namespace xla { 62 63 class HloComputation; 64 class HloModule; 65 66 // A bunch of switches that control how the hlo text should be printed. 67 class HloPrintOptions { 68 public: 69 enum class PrintSubcomputationMode { 70 kOff, // Do not print anything about subcomputations. 71 kNameOnly, // Only print the name of subcomputations. 72 kFullBodies, // Print the full bodies of subcomputations. 73 }; 74 75 // Constructs the default print options: don't print large constants, don't 76 // compact operands, no indentation. 
HloPrintOptions()77 HloPrintOptions() 78 : print_large_constants_(false), 79 print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), 80 print_metadata_(true), 81 print_backend_config_(true), 82 compact_operands_(false), 83 print_operand_shape_(true), 84 print_operand_names_(true), 85 print_program_shape_(true), 86 print_percent_(true), 87 print_control_dependencies_(true), 88 canonicalize_instruction_names_(false), 89 indent_amount_(0), 90 is_in_nested_computation_(false) {} 91 ShortParsable()92 static HloPrintOptions ShortParsable() { 93 return HloPrintOptions() 94 .set_print_large_constants(true) 95 .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) 96 .set_print_metadata(false) 97 .set_print_backend_config(false) 98 .set_print_operand_shape(false) 99 .set_print_program_shape(false) 100 .set_print_percent(false) 101 .set_print_control_dependencies(false); 102 } 103 104 // Options to produce the canonical string representing an isomorphic 105 // computation graph. Canonical()106 static HloPrintOptions Canonical() { 107 return HloPrintOptions() 108 .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) 109 .set_print_metadata(false) 110 .set_print_backend_config(false) 111 .set_compact_operands(true) 112 .set_print_operand_names(false) 113 .set_print_operand_shape(true) 114 .set_print_program_shape(false) 115 .set_print_percent(false) 116 .set_print_control_dependencies(false) 117 .set_canonicalize_instruction_names(true); 118 } 119 120 // If true, large constants will be printed out. set_print_large_constants(bool value)121 HloPrintOptions& set_print_large_constants(bool value) { 122 print_large_constants_ = value; 123 return *this; 124 } 125 set_print_subcomputation_mode(PrintSubcomputationMode value)126 HloPrintOptions& set_print_subcomputation_mode( 127 PrintSubcomputationMode value) { 128 print_subcomputation_mode_ = value; 129 return *this; 130 } 131 132 // If true, metadata will be printed. 
set_print_metadata(bool value)133 HloPrintOptions& set_print_metadata(bool value) { 134 print_metadata_ = value; 135 return *this; 136 } 137 138 // If true, backend_config will be printed. set_print_backend_config(bool value)139 HloPrintOptions& set_print_backend_config(bool value) { 140 print_backend_config_ = value; 141 return *this; 142 } 143 144 // If true, operands' shapes will be printed. set_print_operand_shape(bool value)145 HloPrintOptions& set_print_operand_shape(bool value) { 146 print_operand_shape_ = value; 147 return *this; 148 } 149 150 // If true, the operand names will be printed. set_print_operand_names(bool value)151 HloPrintOptions& set_print_operand_names(bool value) { 152 print_operand_names_ = value; 153 return *this; 154 } 155 156 // If true, program shape of hlo computations will be printed. set_print_program_shape(bool value)157 HloPrintOptions& set_print_program_shape(bool value) { 158 print_program_shape_ = value; 159 return *this; 160 } 161 162 // If true, names will be printed with prefix '%'. set_print_percent(bool value)163 HloPrintOptions& set_print_percent(bool value) { 164 print_percent_ = value; 165 return *this; 166 } 167 168 // If true, control dependencies will be printed. set_print_control_dependencies(bool value)169 HloPrintOptions& set_print_control_dependencies(bool value) { 170 print_control_dependencies_ = value; 171 return *this; 172 } 173 174 // If true, only a part of operands will be printed out (note that in this 175 // case the text will not be parsable). set_compact_operands(bool value)176 HloPrintOptions& set_compact_operands(bool value) { 177 compact_operands_ = value; 178 return *this; 179 } 180 181 // If true, canonicalizes instructions' name. Instead of using "%foo.1" as 182 // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. 
set_canonicalize_instruction_names(bool value)183 HloPrintOptions& set_canonicalize_instruction_names(bool value) { 184 canonicalize_instruction_names_ = value; 185 return *this; 186 } 187 188 // The indent of the hlo text block. set_indent_amount(int value)189 HloPrintOptions& set_indent_amount(int value) { 190 indent_amount_ = value; 191 return *this; 192 } 193 194 // If true, indicates the instruction being printed is inside a nested 195 // computation. set_is_in_nested_computation(bool value)196 HloPrintOptions& set_is_in_nested_computation(bool value) { 197 is_in_nested_computation_ = value; 198 return *this; 199 } 200 print_large_constants()201 bool print_large_constants() const { return print_large_constants_; } print_subcomputation_mode()202 PrintSubcomputationMode print_subcomputation_mode() const { 203 return print_subcomputation_mode_; 204 } print_metadata()205 bool print_metadata() const { return print_metadata_; } print_backend_config()206 bool print_backend_config() const { return print_backend_config_; } compact_operands()207 bool compact_operands() const { return compact_operands_; } print_operand_shape()208 bool print_operand_shape() const { return print_operand_shape_; } print_operand_names()209 bool print_operand_names() const { return print_operand_names_; } print_program_shape()210 bool print_program_shape() const { return print_program_shape_; } print_percent()211 bool print_percent() const { return print_percent_; } print_control_dependencies()212 bool print_control_dependencies() const { 213 return print_control_dependencies_; 214 } canonicalize_instruction_names()215 bool canonicalize_instruction_names() const { 216 return canonicalize_instruction_names_; 217 } indent_amount()218 int indent_amount() const { return indent_amount_; } is_in_nested_computation()219 int is_in_nested_computation() const { return is_in_nested_computation_; } 220 221 private: 222 bool print_large_constants_; 223 PrintSubcomputationMode print_subcomputation_mode_; 
224 bool print_metadata_; 225 bool print_backend_config_; 226 bool compact_operands_; 227 bool print_operand_shape_; 228 bool print_operand_names_; 229 bool print_program_shape_; 230 bool print_percent_; 231 bool print_control_dependencies_; 232 bool canonicalize_instruction_names_; 233 int indent_amount_; 234 bool is_in_nested_computation_; 235 }; 236 237 // For canonical string output, we need to have a canonical way to rename 238 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>", 239 // where <xxx> is an index starting from 0. 240 class CanonicalNameMap { 241 public: CanonicalNameMap()242 CanonicalNameMap() : index(0) {} 243 LookupOrInsert(const string & old_name)244 string LookupOrInsert(const string& old_name) { 245 auto iter = canonical_name_map.find(old_name); 246 if (iter != canonical_name_map.end()) { 247 return iter->second; 248 } 249 250 string new_name = absl::StrCat("tmp_", index++); 251 canonical_name_map[old_name] = new_name; 252 return new_name; 253 } Clear()254 void Clear() { 255 canonical_name_map.clear(); 256 index = 0; 257 } 258 259 private: 260 int64 index; 261 absl::flat_hash_map<string, string> canonical_name_map; 262 }; 263 264 // HLO instructions are the atomic unit of the high-level compiler's IR. 265 // 266 // HloInstructions live inside of an HloComputation, which is analogous to a 267 // function in other programming languages. Nodes have no total order within 268 // their computation. Instead, they have a partial ordering determined by their 269 // data and control dependencies. 270 // 271 // HLO does not have basic blocks or explicit "branch" instructions. Instead, 272 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode 273 // control flow. For example, the kConditional HLO executes one of two possible 274 // computations, depending on the runtime value of a predicate. 275 // 276 // HLO is pure (mostly). It has no concept of mutable state. 
// Instead, data values are produced by one HLO and flow into consumers across
// dependency edges.
class HloInstruction {
 public:
  // A fusion node computes the same value a call to its fusion computation
  // would compute.  However, the choice of fusion kind dictates codegen
  // strategy for the backend.
  //
  // To generate code for a kFusion HloInstruction, most backends do something
  // like the following:
  //
  // 1) Identify the "primary" HloInstruction of the fused computation.
  // 2) Emit code that does the work of the primary node, creating its inputs
  //    and transforming its outputs as specified by the fused computation.
  //
  // In step (2), the code emitted is usually similar to the code that would be
  // emitted for an *unfused* version of the primary node, except that
  //
  //  - when the primary node reads an element of one of its operands, instead
  //    of loading the value from memory, it *computes* the value based on the
  //    contents of the fused computation.
  //  - when the primary node outputs a value, instead of storing it to memory,
  //    it forwards the value to its users, which then perform additional
  //    computations before the value is finally stored to memory at the root of
  //    the fusion node.
  //
  // An HloInstruction's FusionKind helps us find the kFusion instruction's
  // primary node, and can also affect how we generate code in step (2).
  //
  //  - kInput: The primary node is the root of the fused instruction.
  //
  //  - kOutput: The primary node is not the root of the fused instruction.
  //    This fusion kind requires that one operand buffer of the fusion
  //    instruction be able to alias the output buffer.  This constraint is
  //    usually enough to let backends find the primary node unambiguously.
  //
  //  - kLoop: The primary node is the root of the fused computation, but,
  //    unlike in input fusion, we prescribe a specific implementation for
  //    codegen.  Rather than generating code that looks like the code we'd emit
  //    for an unfused version of the primary/root node, we emit code that
  //    generates one element of the root at a time.
  //
  //  - kCustom: Custom category for backend-specific fusions that don't fit
  //    into the above patterns.
  //
  // Not all backends support all fusion kinds, and given a particular fused
  // computation, it's not in general safe to change its fusion kind.  Creation
  // of fusion nodes is always backend-specific.
  //
  // For elementwise ops (e.g. kAdd), most backends would emit a
  // one-element-at-a-time implementation for the unfused version, so loop
  // fusion and input fusion are probably equivalent if the root node is
  // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
  // implementation might emit something more sophisticated for an unfused or
  // input-fusion reduce, but will emit the naive code that reduces one element
  // at a time for loop fusion with a reduce as the root.
  //
  // Another way to think of loop fusion is that it's equivalent to input
  // fusion, but where the root node is an implicit identity node, whose
  // unfused implementation is "read one element, write one element".
  //
  // TODO(b/79869434): This categorization scheme is not great.  For one thing,
  // input and loop fusion are basically the same thing: There is no reason for
  // the HLO to encode backend-specific decisions about how e.g. a reduce that's
  // the root of a fusion should be lowered.  In addition, this scheme as
  // written doesn't work for multi-output fusion, where the primary node is
  // never actually the root (which is a kTuple instruction that gathers the
  // multiple outputs of the fusion).
  enum class FusionKind {
    kLoop,
    kInput,
    kOutput,
    kCustom,
  };

  virtual ~HloInstruction();

  // Creates an instruction from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction id to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
                                                         const Shape& shape,
                                                         const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

  // Creates an Iota instruction.
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
                                                    int64 iota_dimension);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  //
  // The parameters to the instruction are interpreted as follows:
  //
  //  - If `distribution` is RNG_UNIFORM, generates a number in range
  //    [param0, param1).
  //
  //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
  //    with mean `param0` and standard deviation `param1`.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      absl::Span<HloInstruction* const> parameters);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index).
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* map_computation);

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64 feature_group_count, int64 batch_group_count, const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      absl::Span<const int64> fft_length);

  // Creates a compare op, performing the comparison specified in direction.
  static std::unique_ptr<HloInstruction> CreateCompare(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      ComparisonDirection direction);

  static std::unique_ptr<HloInstruction> CreateTriangularSolve(
      const Shape& shape, HloInstruction* a, HloInstruction* b,
      const TriangularSolveOptions& options);

  static std::unique_ptr<HloInstruction> CreateCholesky(
      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates a cross replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica id. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
  //
  // `all_reduce_id`: for Allreduce nodes from different modules, if they have
  // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
  // not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      const std::vector<ReplicaGroup>& replica_groups,
      absl::string_view barrier, const absl::optional<int64>& all_reduce_id);

  // This op handles the communication of an Alltoall operation. On each core,
  // the operands are N ops in the same shape, where N is the number of cores
  // participating in the Alltoall. Then the N operands are scattered to N
  // cores, e.g., the ith operand is sent to the ith core. Then each core
  // gathers the received data into a tuple.
  //
  // - `replica_groups`: each ReplicaGroup contains a list of replica id. If
  // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall
  // will be applied within subgroups in the specified order. For example,
  // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied
  // within replica 1, 2, 3, and in the gather phase, the received blocks will
  // be concatenated in the order of 1, 2, 3; another Alltoall will be applied
  // within replica 4, 5, 0, and the concatenation order is 4, 5, 0.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups);

  // Creates a communication instruction that permutes data across replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
  // target_replica_id in any pair, the output on that replica is a tensor
  // consisting of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  // Creates an instruction that returns a U32 replica ID.
  static std::unique_ptr<HloInstruction> CreateReplicaId();

  // Creates a conversion instruction, where operand is the data to convert and
  // shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from the
  // Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed *not* the shape of the infeed instruction which
  // is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed *not* the shape of the outfeed instruction
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id. If
  // is_host_transfer is true, then this Recv operation transfers data from the
  // host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> start_indices,
      absl::Span<const int64> limit_indices, absl::Span<const int64> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. For example, let f be
  // the function to apply, which takes 2 arguments, an accumulator and the
  // current value. Let init be an initial value (which is normally chosen to be
  // the identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  // f(f(init, value0), value1), ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // init_valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  // ..., inputN.value1)
  // ...
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates a n-ary sort op with a 'compare' computation which is used for
  // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
  // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
  // specific index positions which should be compared, and should return a
  // PRED. 'is_stable' specifies whether stable sorting is required.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64 dimension,
      absl::Span<HloInstruction* const> operands, HloComputation* compare,
      bool is_stable);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  // Creates a conditional instruction that executes true_computation or
  // false_computation depending on the runtime value of `pred`.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  // N-way variant: executes the branch computation selected by the runtime
  // value of `branch_index`.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* branch_index,
      absl::Span<HloComputation* const> branch_computations,
      absl::Span<HloInstruction* const> branch_computation_args);

  // Creates a gather instruction that gathers slices of `operand` at positions
  // given by `start_indices`, per `gather_dim_numbers` and `slice_sizes`.
  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64> slice_sizes);

  // Creates a scatter instruction that applies `update_computation` to combine
  // `updates` into `operand` at positions given by `scatter_indices`.
  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  // Overload that wraps an existing fusion computation over the given operands.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the given
  // operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call target
  // to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, absl::string_view opaque = "");

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout,
      absl::string_view opaque = "");

  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
786 static std::unique_ptr<HloInstruction> CreateReverse( 787 const Shape& shape, HloInstruction* operand, 788 absl::Span<const int64> dimensions); 789 790 // Creates a Afterall instruction used for joining or creating new values of 791 // token type which thread through side-effecting operations. Operands must 792 // all be tokens, and there must be at least one operand. 793 static std::unique_ptr<HloInstruction> CreateAfterAll( 794 absl::Span<HloInstruction* const> operands); 795 796 // Creates an AfterAll instruction which creates a token type out of thin air 797 // (no operands). This is a separate method from CreateAfterAll to facility 798 // the removal of operand-less AfterAll instructions. 799 // TODO(b/110532604): Remove this capability of creating a token from nothing 800 // when we plumb a primordial token from the entry computation. 801 static std::unique_ptr<HloInstruction> CreateToken(); 802 803 static std::unique_ptr<HloInstruction> CreateGetDimensionSize( 804 const Shape& shape, HloInstruction* operand, int64 dimension); 805 806 static std::unique_ptr<HloInstruction> CreateAddDependency( 807 HloInstruction* data_operand, HloInstruction* token_operand); 808 809 // Returns the opcode for this instruction. opcode()810 HloOpcode opcode() const { return opcode_; } 811 812 // Returns true if this instruction has a side effect, irrespective of whether 813 // any called computations may contain an instruction with side effects. 814 bool HasSideEffectNoRecurse() const; 815 816 // Returns true if this instruction has a side effect. An instruction has a 817 // side effect if it uses certain opcodes or calls a computation with a side 818 // effect. 819 bool HasSideEffect() const; 820 821 // Returns the result shape of this instruction. 822 const Shape& shape() const; 823 824 // Returns the (mutable) result shape of this instruction. mutable_shape()825 Shape* mutable_shape() { return &shape_; } 826 827 // Returns the ith operand to this instruction. 
828 const HloInstruction* operand(int64 i) const; 829 830 // Returns the ith operand to this instruction. 831 HloInstruction* mutable_operand(int64 i); 832 833 // Returns the number of operands to this instruction. operand_count()834 int64 operand_count() const { return operands_.size(); } 835 836 // Returns the vector of operands of this instruction. 837 using InstructionVector = absl::InlinedVector<HloInstruction*, 2>; operands()838 const InstructionVector& operands() const { return operands_; } 839 840 // Returns the vector of unique operands, in the same order they are found 841 // within the operand vector. 842 InstructionVector unique_operands() const; 843 844 // Returns the index of 'target' in the operands sequence. 845 // Precondition: target must be an operand (or a fatal error will occur). 846 int64 operand_index(const HloInstruction* target) const; 847 848 // Returns the number of users of this instruction. user_count()849 int64 user_count() const { return users_.size(); } 850 851 // Returns the users of this instruction. users()852 const std::vector<HloInstruction*>& users() const { return users_; } 853 854 // Returns true if this instruction is a user of 'instruction'. IsUserOf(const HloInstruction * instruction)855 bool IsUserOf(const HloInstruction* instruction) const { 856 return ContainsKey(instruction->user_set_, this); 857 } 858 859 // Adds a control dependency from this instruction to the given 860 // instruction. This instruction becomes a control predecessor of 861 // 'instruction', and 'instruction' becomes a control successor of this 862 // instruction. Returns an error status if either of the given instructions 863 // does not belong to the same computation. 864 // 865 // This is used to enforce an additional ordering requirement that is not 866 // captured by normal data dependencies, such as ordering among Send or Recv 867 // operations to avoid deadlock. 
868 Status AddControlDependencyTo(HloInstruction* instruction); 869 870 // Removes a previously added control dependency from this instruction to 871 // 'instruction'. 872 Status RemoveControlDependencyTo(HloInstruction* instruction); 873 874 // Drops all control predecessors and successors from this HLO instruction. 875 Status DropAllControlDeps(); 876 877 // Copies the control predecessors and successors on this HLO instruction to 878 // `inst`. Does not do a deep copy so this makes sense only if `inst` and 879 // this HLO are in the same module. 880 // 881 // Depending on the use cases we see in practice, in the future we may 882 // consider folding the logic here into Clone, CloneWithNewOperands and 883 // ReplaceAllUsesWith by treating control dependencies like data dependencies. 884 Status CopyAllControlDepsFrom(const HloInstruction* inst); 885 886 // Returns the set of control predecessors (successors) of this 887 // instruction. Control predecessors (successors) must execute before (after) 888 // the current instruction. control_predecessors()889 const std::vector<HloInstruction*>& control_predecessors() const { 890 return control_predecessors_; 891 } control_successors()892 const std::vector<HloInstruction*>& control_successors() const { 893 return control_successors_; 894 } 895 896 // Returns true if "other" performs the same computation as this instruction. 897 bool Identical( 898 const HloInstruction& other, 899 const std::function<bool(const HloInstruction*, const HloInstruction*)>& 900 eq_operands = std::equal_to<const HloInstruction*>(), 901 const std::function<bool(const HloComputation*, const HloComputation*)>& 902 eq_computations = std::equal_to<const HloComputation*>(), 903 bool layout_sensitive = true) const { 904 // An instruction is always identical to itself. 905 if (this == &other) { 906 return true; 907 } 908 909 // Identical instruction must have the same opcode, shape, and identical 910 // operands. 
911 if (opcode() != other.opcode()) { 912 return false; 913 } 914 if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) 915 : ShapeUtil::Compatible(shape(), other.shape()))) { 916 return false; 917 } 918 if (operands().size() != other.operands().size()) { 919 return false; 920 } 921 922 // Two AllReduces are Identical if they have the same all_reduce_id. 923 // Their operands don't have to be Identical. 924 if (!IsCrossModuleAllReduce()) { 925 // Use an explicit loop rather than ContainerEquals, because copying 926 // around std::functions may be too expensive in some cases. 927 for (size_t i = 0; i < operands().size(); ++i) { 928 if (!eq_operands(operand(i), other.operand(i))) { 929 return false; 930 } 931 } 932 } 933 934 if (backend_config_ != other.backend_config_) { 935 return false; 936 } 937 938 return IdenticalSlowPath(other, eq_computations); 939 } 940 941 // Generates a hash value of an HLO instruction. Hash considers 942 // information on opcode, shape, operands, and typically a root instruction. 943 // This function returns the same hash value for equivalent HLO instructions, 944 // with respect to HloInstruction::Identical() method. 945 // 946 // Uses hash_operand function to compute hash values of its operands. 947 // At the very top level, hash_operand should be non-recursive to prevent 948 // non-termination. 949 uint64 Hash( 950 const std::function<uint64(const HloInstruction*)>& hash_operand) const; 951 952 // Calls the above method with non-recursive hash_operand function. 953 uint64 Hash() const; 954 955 // Returns whether the instruction has a constant operand. 956 bool HasConstantOperand() const; 957 958 // Replaces the use of this instruction in "user" with "new_producer". Note 959 // that there might be multiple uses of this instruction in "user"; all will 960 // be replaced. 961 // 962 // If user is a fusion instruction, this function will remove any duplicated 963 // operands of it which could be created due to this replacement. 
964 Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); 965 966 // Same as ReplaceUseWith(), but new_producer can have a different shape. 967 Status ReplaceUseWithDifferentShape(HloInstruction* user, 968 HloInstruction* new_producer); 969 970 // Replaces the specified operand with new_operand. The old and new operands 971 // must have compatible shapes ignoring floating-point precision. 972 // 973 // This function does NOT remove duplicated operands even if this instruction 974 // is a fusion, so that the existing operand numbers do not change. 975 Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand); 976 977 // Same as ReplaceOperandWith(), but new_operand can have a different shape. 978 Status ReplaceOperandWithDifferentShape(int64 operand_num, 979 HloInstruction* new_operand); 980 981 // Replaces all uses of this instruction with the new producer. If 982 // new_producer is a user of this instruction then new_producer remains a use 983 // of this instruction to avoid introducing cycles into the graph. 984 // 985 // If this instruction is the root of its computation, sets the computation's 986 // root to new_producer. 987 // 988 // The new producer must have a compatible shape ignoring floating-point 989 // precision. 990 // 991 // If a user is a fusion instruction, this function will remove any duplicated 992 // operands of it which could be created due to this replacement. 993 Status ReplaceAllUsesWith(HloInstruction* new_producer); 994 995 // Same as ReplaceAllUsesWith, but new_producer can have a different shape. 996 Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); 997 998 // Performs a postorder DFS visit using this node as the root. If 999 // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when 1000 // complete. 
If ignore_control_predecessors is true, instructions only 1001 // reachable via control dependencies will not be visited, and the postorder 1002 // will not take control dependencies into account. It is as if the control 1003 // dependencies didn't exist in the graph at all. 1004 template <typename HloInstructionPtr> 1005 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor, 1006 bool call_finish_visit = true, 1007 bool ignore_control_predecessors = false); 1008 Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true, 1009 bool ignore_control_predecessors = false) const { 1010 return const_cast<HloInstruction*>(this)->Accept( 1011 visitor, call_finish_visit, ignore_control_predecessors); 1012 } 1013 1014 // Same as Accept() above, but the order of operand and control predecessor 1015 // visitation is determined by the given operand order; if compare(A, B) == 1016 // true, A is visited before B. 1017 using CompareFunction = 1018 std::function<bool(const HloInstruction*, const HloInstruction*)>; 1019 Status AcceptWithOperandOrder(DfsHloVisitor* visitor, 1020 const CompareFunction& operand_order, 1021 bool call_finish_visit = true); 1022 1023 // Performs a postorder DFS visit using this node as the root. Calls the given 1024 // visitor function at each instruction. 1025 Status Accept(const std::function<Status(HloInstruction*)>& visitor_func); 1026 Status Accept( 1027 const std::function<Status(const HloInstruction*)>& visitor_func) const; 1028 1029 // Visit this instruction and only this instruction with the given visitor. 1030 template <typename HloInstructionPtr> 1031 Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor); 1032 1033 // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. 1034 // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the 1035 // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. 
1036 std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() 1037 const; 1038 LatestNonGteAncestorAndIndex()1039 std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() { 1040 auto rv = 1041 const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex(); 1042 return {const_cast<HloInstruction*>(rv.first), rv.second}; 1043 } 1044 1045 // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. 1046 const HloInstruction* LatestNonGteAncestor() const; 1047 LatestNonGteAncestor()1048 HloInstruction* LatestNonGteAncestor() { 1049 return const_cast<HloInstruction*>( 1050 const_cast<const HloInstruction*>(this)->LatestNonGteAncestor()); 1051 } 1052 1053 // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. 1054 // The setter should only be called by HloModule or HloComputation methods. 1055 // 1056 // Precondition: The instruction has a valid to_apply_ field. 1057 HloComputation* to_apply() const; 1058 void set_to_apply(HloComputation* to_apply); 1059 1060 // Gets/sets the while_condition or while_body HloComputation for While. The 1061 // setters should only be called by HloModule or HloComputation methods. 1062 // 1063 // Precondition: The instruction is a While instruction. 1064 HloComputation* while_condition() const; 1065 HloComputation* while_body() const; 1066 void set_while_condition(HloComputation* while_condition); 1067 void set_while_body(HloComputation* while_body); 1068 1069 HloInstruction* while_init() const; 1070 1071 // Gets/sets the true and false HloComputation for Conditional. 1072 // 1073 // Precondition: The instruction is a predicated Conditional instruction. 1074 HloComputation* true_computation() const; 1075 HloComputation* false_computation() const; 1076 1077 // Gets the branch HloComputations for Conditional. 1078 // 1079 // Precondition: The instruction is a Conditional instruction. 
1080 const std::vector<HloComputation*>& branch_computations() const; 1081 int branch_count() const; 1082 HloComputation* branch_computation(int b) const; 1083 // Sets a branch HloComputation for Conditional. 1084 // The setter should only be called by HloModule or HloComputation methods. 1085 // 1086 // Precondition: The instruction is a Conditional instruction. 1087 void set_branch_computation(int b, HloComputation* computation); 1088 1089 // Returns a string for the signature of this instruction if considered as a 1090 // function, e.g. the signature of an F32 add is (F32, F32) -> F32. 1091 string SignatureString() const; 1092 1093 // Returns a debugging string that represents this instruction. 1094 // 1095 // (We express the default options using an overload rather than a default 1096 // param because gdb ignores default params, but does resolve overloads.) 1097 // 1098 // TODO(b/73348663): Make ToString() adaptive to the size of the string by 1099 // default, backing off on providing full information for very large strings, 1100 // or provide a different name for a ToString-like function that does that. ToString()1101 string ToString() const { return ToString(HloPrintOptions()); } 1102 string ToString(const HloPrintOptions& options) const; 1103 1104 // Components of the ToString() representation: 1105 1106 // Returns a string representation of the operand list. 1107 string OperandsToString(const HloPrintOptions& options) const; 1108 1109 // Returns string representation of op-specific attributes. 1110 std::vector<string> ExtraAttributesToString( 1111 const HloPrintOptions& options) const; 1112 1113 // As ToString, but returns a shorter string. 1114 string ToShortString() const; 1115 1116 // Returns a serialized representation of this instruction. 1117 virtual HloInstructionProto ToProto() const; 1118 1119 // Returns a category for the HLO. This could be something like "convolution" 1120 // or "elementwise". 
1121 virtual string ToCategory() const; 1122 1123 // Returns a logging instruction, if the output of this instruction is logged. 1124 // 1125 // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace 1126 HloInstruction* tracing() const; 1127 void set_tracing(HloInstruction* trace_instruction); 1128 1129 // Returns true if this instruction is fused, ie contained within a fusion 1130 // instruction. 1131 bool IsFused() const; 1132 1133 // Returns true if this instruction can be legally fused into a fusion 1134 // instruction. 1135 bool IsFusible() const; 1136 1137 // Returns the sharding applied to this operator. 1138 // REQUIRES: has_sharding() is true. sharding()1139 const HloSharding& sharding() const { 1140 CHECK(has_sharding()); 1141 return *sharding_; 1142 } sharding_ptr()1143 std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; } 1144 1145 // Returns the sharding applied to this operator, or default_ if none exists. sharding_or_default(const HloSharding & default_)1146 const HloSharding& sharding_or_default(const HloSharding& default_) const { 1147 return sharding_ ? *sharding_ : default_; 1148 } 1149 // Returns the sharding unique device, if any. sharding_unique_device()1150 absl::optional<int64> sharding_unique_device() const { 1151 if (sharding_ == nullptr) { 1152 return absl::optional<int64>(); 1153 } 1154 return sharding_->UniqueDevice(); 1155 } 1156 // Sets the sharding of this operator. Should only be called by HloModule or 1157 // HloComputation methods. 
set_sharding(const HloSharding & sharding)1158 void set_sharding(const HloSharding& sharding) { 1159 sharding_ = std::make_shared<const HloSharding>(sharding); 1160 } set_sharding(std::shared_ptr<const HloSharding> sharding)1161 void set_sharding(std::shared_ptr<const HloSharding> sharding) { 1162 sharding_ = std::move(sharding); 1163 } 1164 void set_single_sharding(const HloSharding& sharding); 1165 // Sets a sharding that assigns the current instruction to device. set_device_sharding(int64 device)1166 void set_device_sharding(int64 device) { 1167 set_single_sharding(HloSharding::AssignDevice(device)); 1168 } 1169 // Remove any sharding from this operator. clear_sharding()1170 void clear_sharding() { sharding_ = nullptr; } 1171 // Return true if this operator has a sharding assigned. has_sharding()1172 bool has_sharding() const { return sharding_ != nullptr; } 1173 // Checks whether the instruction has compatible sharding with the other 1174 // instruction. has_compatible_sharding(const HloInstruction * other)1175 bool has_compatible_sharding(const HloInstruction* other) const { 1176 if (!has_sharding()) { 1177 return !other->has_sharding(); 1178 } 1179 return other->has_sharding() ? sharding() == other->sharding() : false; 1180 } 1181 1182 // When creating a new instruction which either replaces, or shifts up (kCopy 1183 // insertion case), another instruction, we need to make sure the certain 1184 // properties of the new instruction are copied into the derived one. As of 1185 // today, the metadata and sharding will be propagated to the derived 1186 // instruction. 1187 void SetupDerivedInstruction(HloInstruction* derived_instruction) const; 1188 1189 // Clones the HLO instruction. The clone will have the same opcode, shape, and 1190 // operands. After creation the clone has no uses. "this" (the instruction 1191 // cloned from) is not changed. Suffix is the string to append to the name of 1192 // the instruction to form the name of the cloned instruction. 
1193 // Ignores the control predecessors and successors of this HLO instruction. 1194 std::unique_ptr<HloInstruction> Clone( 1195 const string& suffix = "clone", HloCloneContext* context = nullptr) const; 1196 1197 // Clones the HLO instruction as above but with new shape and operands. 1198 std::unique_ptr<HloInstruction> CloneWithNewOperands( 1199 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1200 HloCloneContext* context = nullptr) const; 1201 1202 // Returns the computations this instruction directly calls (if any). called_computations()1203 const std::vector<HloComputation*>& called_computations() const { 1204 return called_computations_; 1205 } 1206 1207 // Replaces all called computations based on a map function. This is needed 1208 // when we clone hlo_computations and want to let the instructions to point 1209 // to the newly cloned nodes. ReplaceCalledComputations(std::function<HloComputation * (HloComputation *)> map_function)1210 void ReplaceCalledComputations( 1211 std::function<HloComputation*(HloComputation*)> map_function) { 1212 for (int64 i = 0; i < called_computations_.size(); ++i) { 1213 called_computations_[i] = map_function(called_computations_[i]); 1214 } 1215 } 1216 1217 // Clears out the called computations. 1218 // 1219 // This is, in particular, necessary when inlining function bodies into their 1220 // caller. If there were side-effecting operations in the called computations, 1221 // the call itself is considered side-effecting and thus cannot be removed. By 1222 // clearing out the computations, we reflect the fact that all side-effecting 1223 // properties have been reflected in the caller, and make the call HLO 1224 // removable. ClearCalledComputations()1225 void ClearCalledComputations() { called_computations_.clear(); } 1226 1227 // Returns true if this instruction performs an elementwise operation on 1228 // `operand_idx`-th operand. 
An instruction is elementwise on an operand iff, 1229 // to compute the output at index {i_0,i_1,...,i_n}, the only element required 1230 // from the operand (if any) is the element at {i_0,i_1,...,i_n}. 1231 // 1232 // Note on performance: when this instruction is kFusion, this method, in the 1233 // worst case, scans all fused instructions. We could speed this up by 1234 // caching. 1235 bool IsElementwiseOnOperand(int64 operand_idx) const; 1236 1237 // Returns true if this instruction is elementwise on all its operands. 1238 bool IsElementwise() const; 1239 1240 // Returns true if this is a cross module all-reduce instruction. 1241 bool IsCrossModuleAllReduce() const; 1242 1243 // Returns true if this is a cross-replica all-reduce instruction. 1244 bool IsCrossReplicaAllReduce() const; 1245 1246 // Returns true if this instruction is binary and elementwise. 1247 bool IsElementwiseBinary() const; 1248 1249 // Returns whether this instruction may reuse elements of its `i`th operand. ReusesOperandElements(int64 i)1250 bool ReusesOperandElements(int64 i) const { 1251 return OperandElementUse(i) == UseKind::kReuse; 1252 } 1253 1254 // Returns the indices that the given operand appear in the operand list of 1255 // this instruction. Note that an instruction can use the same operand 1256 // multiple times. 1257 std::vector<int64> OperandIndices(const HloInstruction* operand) const; 1258 1259 // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If 1260 // this reshape merely inserts or deletes 1-sized dimensions, return the input 1261 // indices of the deleted dimensions and the output indices of the inserted 1262 // dimensions. 1263 // 1264 // Precondition: this op must be a reshape. 1265 std::tuple<bool, std::vector<int64>, std::vector<int64>> 1266 ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; 1267 1268 // Gets the string identifier for this instruction. 
name()1269 const string& name() const { return name_; } 1270 1271 // Sets the string identifier for this instruction. Name will be sanitized to 1272 // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". SetAndSanitizeName(const string & name)1273 void SetAndSanitizeName(const string& name) { 1274 name_ = NameUniquer::GetSanitizedName(name); 1275 } 1276 1277 // Use the given NameUniquer to select a unique name for the instruction based 1278 // on the instruction's existing name. 1279 void UniquifyName(NameUniquer* name_uniquer); 1280 1281 // Clear the unique ID of the instruction so that it can be re-assigned, such 1282 // as for the purpose of compacting the instruction unique IDs. ClearUniqueIdInternal()1283 void ClearUniqueIdInternal() { unique_id_ = -1; } 1284 1285 // Set the unique id for this instruction to "id" SetUniqueId(int id)1286 void SetUniqueId(int id) { 1287 CHECK_EQ(unique_id_, -1); // Should not be assigned already 1288 CHECK_GE(id, 0); 1289 unique_id_ = id; 1290 } 1291 1292 // Return the unique ID assigned to this node via SetUniqueId (or -1 1293 // if no id has been assigned yet). unique_id()1294 int unique_id() const { return unique_id_; } 1295 1296 // Returns the backend-specific configuration for how a backend should compile 1297 // this HLO. The meaning of the field is backend specific. Not for use before 1298 // or during general HLO optimization, since HLO optimizations do not preserve 1299 // this field and they cannot interpret it due to its meaning being backend 1300 // specific. 1301 // 1302 // ConfigProto should be a protobuf Message type. 1303 template <typename ConfigProto> backend_config()1304 StatusOr<ConfigProto> backend_config() const { 1305 ConfigProto proto; 1306 TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); 1307 return std::move(proto); 1308 } 1309 Status set_backend_config(const tensorflow::protobuf::Message& proto); 1310 1311 // Getter/setter for raw JSON-encoded backend config. 
Prefer the 1312 // functions above that deal in proto Messages where possible. raw_backend_config_string()1313 const string& raw_backend_config_string() const { return backend_config_; } set_raw_backend_config_string(string config_str)1314 void set_raw_backend_config_string(string config_str) { 1315 backend_config_ = std::move(config_str); 1316 } 1317 is_default_config()1318 bool is_default_config() const { return is_default_config_; } set_default_config()1319 void set_default_config() { is_default_config_ = true; } 1320 1321 // Returns a string representation of a proto in the format used by 1322 // raw_backend_config_string. 1323 // 1324 // This is morally equivalent to: 1325 // 1326 // HloInstruction instr; 1327 // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); 1328 // return instr.raw_backend_config_string(); 1329 // 1330 static StatusOr<string> BackendConfigToRawString( 1331 const tensorflow::protobuf::Message& proto); 1332 1333 // Returns the information used to tell the implementation information about 1334 // what sort of precision is requested. The meaning of the field is backend 1335 // specific. At the moment, it is only supported for kConvolution and kDot. 1336 // Transformations on one kDot or kConvolution to another will preserve this 1337 // information. Transformations to other HLOs will not preserve this 1338 // information but it is presumed that the alternate lowering is strictly 1339 // superior. 1340 // Precondition: opcode must be kConvolution or kDot. 1341 const PrecisionConfig& precision_config() const; 1342 PrecisionConfig* mutable_precision_config(); 1343 1344 // Sets the debug metadata for this instruction. set_metadata(const OpMetadata & metadata)1345 void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } metadata()1346 const OpMetadata& metadata() const { return metadata_; } 1347 1348 // Set/get the computation containing this instruction. 
set_parent should only 1349 // be called by HloComputation methods which add/remove instructions to 1350 // computations. set_parent(HloComputation * computation)1351 void set_parent(HloComputation* computation) { parent_ = computation; } parent()1352 const HloComputation* parent() const { return parent_; } parent()1353 HloComputation* parent() { return parent_; } 1354 1355 // Returns the module for this instruction. 1356 HloModule* GetModule() const; 1357 1358 // Returns whether we could assign input and output layouts to this 1359 // instruction to make it a bitcast. 1360 bool CouldBeBitcast() const; 1361 1362 // Get/Set the number of partitions per outer dimension (in order, starting 1363 // with outer-most dimension first). Currently used by the parallel cpu 1364 // backend to partition HLOs into parallel tasks. 1365 // 1366 // TODO(b/62783254) Replace these methods with a more general way to 1367 // annotate HLOs with backend-specific information. outer_dimension_partitions()1368 const std::vector<int64>& outer_dimension_partitions() const { 1369 return outer_dimension_partitions_; 1370 } 1371 void set_outer_dimension_partitions( 1372 const std::vector<int64>& outer_dimension_partitions); 1373 1374 // Old methods kept for smooth subclassing transition BEGIN. 1375 // TODO(b/80131774): Remove this code. 1376 1377 // Delegates to HloBatchNormInstruction::feature_index. 1378 int64 feature_index() const; 1379 1380 // Delegates to HloBatchNormInstruction::epsilon. 1381 float epsilon() const; 1382 1383 // Delegates to HloFftInstruction::fft_type. 1384 FftType fft_type() const; 1385 1386 // Delegates to HloFftInstruction::fft_length. 1387 const std::vector<int64>& fft_length() const; 1388 1389 // Delegates to HloSendRecvInstruction::channel_id. 1390 int64 channel_id() const; 1391 1392 // Returns the dimension sizes or numbers associated with this instruction. 
  // Subclasses that carry dimension data override these; the base
  // implementations crash because not every HLO has dimensions.
  virtual const std::vector<int64>& dimensions() const {
    LOG(FATAL) << "Unimplemented method.";
  }
  virtual int64 dimensions(int64 index) const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Delegates to HloConcatenateInstruction::concatenate_dimension.
  int64 concatenate_dimension() const;

  // Delegates to HloGetDimensionSizeInstruction::dimension.
  int64 dimension() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Delegates to HloSliceInstruction::slice_start.
  int64 slice_starts(int64 dimension) const;
  const std::vector<int64>& slice_starts() const;

  // Delegates to HloSliceInstruction::slice_limits.
  int64 slice_limits(int64 dimension) const;
  const std::vector<int64>& slice_limits() const;

  // Delegates to HloSliceInstruction::slice_strides.
  int64 slice_strides(int64 dimension) const;
  const std::vector<int64>& slice_strides() const;

  // Returns the literal associated with this instruction.
  const Literal& literal() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Delegate to HloConstantInstruction::RelayoutConstant.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  // Delegates to HloTraceInstruction::TracingTag.
  string TracingTag() const;

  // Delegates to HloFusionInstruction::AddFusionOperand.
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Delegates to HloFusionInstruction::MergeFusionInstruction.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
  void MergeFusionInstructionIntoMultiOutput(
      HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::FuseInstruction.
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::fused_instruction.
  HloComputation* fused_instructions_computation() const;

  // Delegates to HloFusionInstruction::fused_expression_root.
  HloInstruction* fused_expression_root() const;

  // Delegates to HloFusionInstruction::fused_instructions.
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;

  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Delegates to HloFusionInstruction::fused_instruction_count.
  int64 fused_instruction_count() const;

  // Delegates to HloFusionInstruction::fused_parameter.
  HloInstruction* fused_parameter(int64 parameter_number) const;

  // Delegates to HloFusionInstruction::fused_parameters.
  const std::vector<HloInstruction*>& fused_parameters() const;

  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  // NOTE(review): the top-level `const` on the return type is meaningless for
  // a by-value bool (clang-tidy readability-const-return-type); dropping it
  // requires also updating the out-of-line definition, so left as-is here.
  const bool IsMultiOutputFusion() const;

  // Delegates to HloFusionInstruction::fusion_kind.
  FusionKind fusion_kind() const;

  // Delegates to HloFusionInstruction::set_fusion_kind.
  void set_fusion_kind(FusionKind kind);

  // Delegates to HloRngInstruction::random_distribution.
  RandomDistribution random_distribution() const;

  // Delegates to HloParameterInstruction::parameter_number.
  int64 parameter_number() const;

  // Delegates to
  // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
  void set_parameter_replicated_at_leaf_buffers(
      absl::Span<const bool> parameter_replicated_at_leaf_buffers);

  // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
  const absl::optional<std::vector<bool>>&
  parameter_replicated_at_leaf_buffers() const;

  // Delegates to HloGetTupleElementInstruction::tuple_index.
  int64 tuple_index() const;

  // Delegates to HloReducePrecisionInstruction::exponent_bits.
  int32 exponent_bits() const;

  // Delegates to HloReducePrecisionInstruction::mantissa_bits.
  int32 mantissa_bits() const;

  // Delegates to HloInfeedInstruction::infeed_config.
  string infeed_config() const;

  // Delegates to HloInfeedInstruction::set_infeed_config.
  void set_infeed_config(const string& config);

  // Returns the config for the Outfeed instruction.
  const string& outfeed_config() const;

  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const;

  // Delegates to HloCollectiveInstruction::replica_groups.
  const std::vector<ReplicaGroup>& replica_groups() const;

  // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
  const std::vector<std::pair<int64, int64>>& source_target_pairs() const;

  // Delegates to HloAllReduceInstruction::all_reduce_barrier.
  string all_reduce_barrier() const;
  void set_all_reduce_barrier(const string& barrier);

  // Delegates to HloAllReduceInstruction::all_reduce_id.
  absl::optional<int64> all_reduce_id() const;
  void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);

  // Returns data on the window in a windowed operation such as
  // convolution.
  // Overridden by windowed subclasses; the base implementation crashes.
  virtual const Window& window() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Sets the window data in a windowed operation such as convolution.
  virtual void set_window(const Window& window) {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Returns data on the dimension numbers used for a convolution operation,
  // which may be a kConvolution instruction or a kCustomCall that implements a
  // convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;

  // Sets the convolution dimension numbers on this instruction. In general you
  // shouldn't need to call this; instead, specify the convolution dimension
  // numbers when you create the instruction.
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums);

  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64 feature_group_count() const;

  void set_feature_group_count(int64 feature_group_count);

  // The number of batch groups. Must be a divisor of the input batch
  // dimension.
  int64 batch_group_count() const;

  void set_batch_group_count(int64 batch_group_count);

  // Delegates to HloSelectAndScatterInstruction::select.
  HloComputation* select() const;

  // Delegates to HloSelectAndScatterInstruction::scatter.
  HloComputation* scatter() const;

  // Delegates to HloSelectAndScatterInstruction::set_select.
  void set_select(HloComputation* computation);

  // Delegates to HloSelectAndScatterInstruction::set_scatter.
  void set_scatter(HloComputation* computation);

  // Delegates to HloCustomCallInstruction::custom_call_target.
  const string& custom_call_target() const;

  // Delegates to HloPadInstruction::padding_config.
  const PaddingConfig& padding_config() const;

  // Delegates to HloDynamicSliceInstruction::slice_sizes.
  int64 slice_sizes(int64 dimension) const;

  // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
  const std::vector<int64>& dynamic_slice_sizes() const;

  // Delegates to HloGatherInstruction::gather_dimension_numbers.
  const GatherDimensionNumbers& gather_dimension_numbers() const;
  // Delegates to HloGatherInstruction::gather_slice_sizes.
  absl::Span<const int64> gather_slice_sizes() const;

  // Delegates to HloScatterInstruction::scatter_dimension_numbers().
  const ScatterDimensionNumbers& scatter_dimension_numbers() const;

  // Delegates to HloDotInstruction::dot_dimension_numbers().
  const DotDimensionNumbers& dot_dimension_numbers() const;

  // Delegates to HloDomainInstruction::operand_side_metadata().
  const DomainMetadata& operand_side_metadata() const;

  // Delegates to HloDomainInstruction::user_side_metadata().
  const DomainMetadata& user_side_metadata() const;

  // Delegates to HloCompareInstruction::direction().
  ComparisonDirection comparison_direction() const;

  // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
  const TriangularSolveOptions& triangular_solve_options() const;

  // Delegates to HloCholeskyInstruction::cholesky_options().
  const CholeskyOptions& cholesky_options() const;

  // Old methods kept for smooth subclassing transition END.

 protected:
  // How an instruction uses the elements of one of its operands.
  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
  // Helper class for computing OperandElementUse for kFusion.
  class FusionReusesParamElements;

  // Internal constructor for a given opcode/shape, other fields must be filled
  // by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  // Appends operand to the list of operands and adds this instruction as a user
  // of the operand.
  void AppendOperand(HloInstruction* operand);

  // Erases the operand at `index`; does not touch that operand's user list.
  void RemoveOperandAt(int index) {
    operands_.erase(operands_.begin() + index);
  }

  // Removes a list of operands with the given indices in ascending order.
  void RemoveOperandsAtAscendingIndices(
      absl::Span<const int> ascending_indices);

  // Appends `computation` to the list of computations called by this
  // instruction.
  void AppendComputation(HloComputation* computation) {
    called_computations_.push_back(computation);
  }

  // Removes this instruction from `usee`'s user list.
  void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }

  // Replaces the called computation at `index` (see the enum below for the
  // index meanings of multi-computation instructions).
  void set_called_computation(int index, HloComputation* computation) {
    called_computations_[index] = computation;
  }

  // Indices of computations in called_computations_ for instructions which call
  // multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const {
    // TODO(b/80131774): This should be pure virtual.
    LOG(FATAL) << "Unimplemented method.";
  }

  // Implementation for non-common logic of ExtraAttributesToString.
  virtual std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const {
    return {};
  }

  // Implementation for IsElementwise if operand_idx is nullopt and for
  // IsElementwiseOnOperand if otherwise.
  //
  // NOTE: For all instructions other than kFusion, being elementwise on one of
  // the operands is equivalent to being elementwise on all the operands.
  virtual bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const;

  // Prints an instruction to a string.
  //
  // The canonical string representation needs to name operands and instruction
  // names in a consistent way. This is implemented through the
  // canonical_name_map.
  string ToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Prints an operand to a string.
  virtual string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Allow HloComputation to access the ToStringWithCanonicalNameMap() and
  // OperandsToStringWithCanonicalNameMap() functions.
  friend class HloComputation;

  // See comments on Identical().
  virtual bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const;

  // Generates a hash value specific to a particular type of an instruction.
  // This function typically considers the inner root instruction.
  virtual uint64 InnerHash() const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Returns how this instruction uses elements of its `i`th operand.
  UseKind OperandElementUse(int64 i) const;

  // Helper for implementing backend_config(). Parses backend_config_ into the
  // given proto.
  Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;

  int unique_id_;  // Unique to this HloInstruction within a HloModule

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  // Note that the order of the instructions in the vector influences the order
  // computed in HloComputation::ComputeInstructionPostOrder, which may
  // influence the result of the compilation by changing the scheduling. We are
  // not sure if it matters.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is an
  // operand. The vector users_ and the set user_set_ contain identical
  // members. The set enables fast membership testing and the vector enables
  // fast, stable iteration.
  std::vector<HloInstruction*> users_;
  absl::flat_hash_set<const HloInstruction*> user_set_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Result shape of this instruction.
  Shape shape_;

  // The sharding, if one exists.
  // Uses std::shared_ptr to allow reuse of the same sharding object between
  // HloInstructions and other components as HloSharding can be very large for
  // many element tuples.
  std::shared_ptr<const HloSharding> sharding_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
  // an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The backend-specific configuration for how a backend should compile this
  // HLO. See the documentation on backend_config().
  string backend_config_;

  // This field is assigned to true when backend_config_ is assigned to
  // a default configuration.
  bool is_default_config_ = false;

  // String identifier for instruction.
  string name_;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order from
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

// Explicit instantiations in hlo_instruction.cc.
extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);

// Stringification for the fusion-kind enum and its inverse parse.
string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums);

StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};

// NOTE(review): the std::map/std::set aliases below appear to rely on a
// transitive include of <map>; consider adding #include <map> explicitly
// (include-what-you-use).
template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_