1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // All HloInstruction subclasses are put in this file. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 20 21 #include "absl/container/inlined_vector.h" 22 #include "absl/memory/memory.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/shape.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 27 namespace xla { 28 29 class HloBatchNormInstruction : public HloInstruction { 30 public: 31 // Returns feature_index field associated with the instruction. The index 32 // represents the index of the feature dimension. feature_index()33 int64 feature_index() const { return feature_index_; } 34 35 // Returns a epsilon value associated with the instruction. The is a small 36 // number added to the variance to avoid divide-by-zero error. epsilon()37 float epsilon() const { return epsilon_; } 38 39 // Returns a serialized representation of this instruction. 40 HloInstructionProto ToProto() const override; 41 42 protected: 43 explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, 44 HloInstruction* operand, 45 HloInstruction* scale, float epsilon, 46 int64_t feature_index); 47 48 private: 49 std::vector<string> ExtraAttributesToStringImpl( 50 const HloPrintOptions& options) const override; 51 bool IdenticalSlowPath( 52 const HloInstruction& other, 53 const std::function<bool(const HloComputation*, const HloComputation*)>& 54 eq_computations) const override; 55 // A small float number added to the variance to avoid divide-by-zero error. 56 float epsilon_ = 0.0f; 57 58 // An integer value representing the index of the feature dimension. 59 int64 feature_index_ = -1; 60 }; 61 62 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { 63 public: 64 explicit HloBatchNormTrainingInstruction( 65 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 66 HloInstruction* offset, float epsilon, int64_t feature_index); 67 68 private: 69 // Implementation for non-common logic of CloneWithNewOperands. 70 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 71 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 72 HloCloneContext* context) const override; 73 }; 74 75 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { 76 public: 77 explicit HloBatchNormInferenceInstruction( 78 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 79 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, 80 float epsilon, int64_t feature_index); 81 82 private: 83 // Implementation for non-common logic of CloneWithNewOperands. 84 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 85 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 86 HloCloneContext* context) const override; 87 }; 88 89 class HloBatchNormGradInstruction : public HloBatchNormInstruction { 90 public: 91 explicit HloBatchNormGradInstruction( 92 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 93 HloInstruction* mean, HloInstruction* variance, 94 HloInstruction* grad_output, float epsilon, int64_t feature_index); 95 96 private: 97 // Implementation for non-common logic of CloneWithNewOperands. 98 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 99 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 100 HloCloneContext* context) const override; 101 }; 102 103 class HloFftInstruction : public HloInstruction { 104 public: 105 explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, 106 FftType fft_type, 107 absl::Span<const int64> fft_length); fft_type()108 FftType fft_type() const { return fft_type_; } 109 fft_length()110 const std::vector<int64>& fft_length() const { return fft_length_; } 111 112 // Returns a serialized representation of this instruction. 113 HloInstructionProto ToProto() const override; 114 115 private: 116 std::vector<string> ExtraAttributesToStringImpl( 117 const HloPrintOptions& options) const override; 118 bool IdenticalSlowPath( 119 const HloInstruction& other, 120 const std::function<bool(const HloComputation*, const HloComputation*)>& 121 eq_computations) const override; 122 123 // Implementation for non-common logic of CloneWithNewOperands. 124 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 125 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 126 HloCloneContext* context) const override; 127 128 // Describes FFT type for an FFT instruction. 129 FftType fft_type_ = FftType::FFT; 130 131 // Indicates the FFT length for an FFT instruction. 132 std::vector<int64> fft_length_; 133 }; 134 135 class HloCopyStartInstruction : public HloInstruction { 136 public: 137 explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand, 138 bool is_cross_program_prefetch); 139 is_cross_program_prefetch()140 bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; } 141 HloInstructionProto ToProto() const override; 142 143 private: 144 std::vector<string> ExtraAttributesToStringImpl( 145 const HloPrintOptions& options) const override; 146 bool IdenticalSlowPath( 147 const HloInstruction& other, 148 const std::function<bool(const HloComputation*, const HloComputation*)>& 149 eq_computations) const override; 150 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 151 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 152 HloCloneContext* context) const override; 153 154 bool is_cross_program_prefetch_; 155 }; 156 157 class HloCompareInstruction : public HloInstruction { 158 public: 159 explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, 160 HloInstruction* rhs, 161 ComparisonDirection direction, 162 absl::optional<Comparison::Type> type); direction()163 ComparisonDirection direction() const { return compare_.GetDirection(); } type()164 Comparison::Type type() const { return compare_.GetType(); } 165 HloInstructionProto ToProto() const override; 166 167 private: 168 std::vector<string> ExtraAttributesToStringImpl( 169 const HloPrintOptions& options) const override; 170 bool IdenticalSlowPath( 171 const HloInstruction& other, 172 const std::function<bool(const HloComputation*, const HloComputation*)>& 173 eq_computations) const override; 174 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 175 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 176 HloCloneContext* context) const override; 177 178 Comparison compare_; 179 }; 180 181 class HloTriangularSolveInstruction : public HloInstruction { 182 public: 183 explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, 184 HloInstruction* b, 185 const TriangularSolveOptions& options); triangular_solve_options()186 const TriangularSolveOptions& triangular_solve_options() const { 187 return triangular_solve_options_; 188 } 189 190 // Returns a serialized representation of this instruction. 191 HloInstructionProto ToProto() const override; 192 193 private: 194 std::vector<string> ExtraAttributesToStringImpl( 195 const HloPrintOptions& options) const override; 196 bool IdenticalSlowPath( 197 const HloInstruction& other, 198 const std::function<bool(const HloComputation*, const HloComputation*)>& 199 eq_computations) const override; 200 201 // Implementation for non-common logic of CloneWithNewOperands. 202 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 203 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 204 HloCloneContext* context) const override; 205 206 TriangularSolveOptions triangular_solve_options_; 207 }; 208 209 class HloCholeskyInstruction : public HloInstruction { 210 public: 211 explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, 212 const CholeskyOptions& options); cholesky_options()213 const CholeskyOptions& cholesky_options() const { return cholesky_options_; } 214 215 // Returns a serialized representation of this instruction. 216 HloInstructionProto ToProto() const override; 217 218 private: 219 std::vector<string> ExtraAttributesToStringImpl( 220 const HloPrintOptions& options) const override; 221 bool IdenticalSlowPath( 222 const HloInstruction& other, 223 const std::function<bool(const HloComputation*, const HloComputation*)>& 224 eq_computations) const override; 225 226 // Implementation for non-common logic of CloneWithNewOperands. 227 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 228 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 229 HloCloneContext* context) const override; 230 231 CholeskyOptions cholesky_options_; 232 }; 233 234 // Class that represents instructions that synchronize and transfer data between 235 // partitioned devices. Send/Recv and collective instructions (AllReduce, 236 // AllToAll, CollectivePermute) belong to this instruction type. A group of 237 // instructions (of the same opcode) with the same channel_id communicate during 238 // execution. 239 class HloChannelInstruction : public HloInstruction { 240 public: 241 // Returns the channel id associated with the instruction. The id is 242 // shared between each Send/Recv pair or a group of collective instructions 243 // and is globally unique to identify each channel. channel_id()244 absl::optional<int64> channel_id() const { return channel_id_; } 245 void set_channel_id(const absl::optional<int64>& channel_id); 246 247 // Whether this instruction is identical to `other` except for the values of 248 // channel IDs, as long as both have channel IDs or neither has a channel ID. IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations)249 virtual bool IdenticalSlowPathIgnoringChannelIdValues( 250 const HloInstruction& other, 251 const std::function<bool(const HloComputation*, const HloComputation*)>& 252 eq_computations) const { 253 return channel_id_.has_value() == other.channel_id().has_value(); 254 } 255 256 protected: 257 explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, 258 const absl::optional<int64>& channel_id); 259 260 HloInstructionProto ToProto() const override; 261 262 std::vector<string> ExtraAttributesToStringImpl( 263 const HloPrintOptions& options) const override; 264 265 // Do not override IdenticalSlowPath(). Override 266 // IdenticalSlowPathIgnoringChannelIdValues() instead. 267 bool IdenticalSlowPath( 268 const HloInstruction& other, 269 const std::function<bool(const HloComputation*, const HloComputation*)>& 270 eq_computations) const final; 271 272 absl::optional<int64> channel_id_; 273 }; 274 275 class HloSendRecvInstruction : public HloChannelInstruction { 276 public: 277 // Returns whether this send/recv instruction sends data to/from the host. is_host_transfer()278 bool is_host_transfer() const { return is_host_transfer_; } 279 280 // Returns a serialized representation of this instruction. 281 HloInstructionProto ToProto() const override; 282 283 protected: 284 explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, 285 int64_t channel_id, bool is_host_transfer); 286 287 private: 288 std::vector<string> ExtraAttributesToStringImpl( 289 const HloPrintOptions& options) const override; 290 bool IdenticalSlowPathIgnoringChannelIdValues( 291 const HloInstruction& other, 292 const std::function<bool(const HloComputation*, const HloComputation*)>& 293 eq_computations) const override; 294 // Whether this send/recv instruction sends data to/from the host. 295 bool is_host_transfer_; 296 }; 297 298 class HloSendInstruction : public HloSendRecvInstruction { 299 public: 300 explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, 301 int64_t channel_id, bool is_host_transfer); 302 303 private: 304 // Implementation for non-common logic of CloneWithNewOperands. 305 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 306 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 307 HloCloneContext* context) const override; 308 }; 309 310 class HloSendDoneInstruction : public HloSendRecvInstruction { 311 public: 312 explicit HloSendDoneInstruction(HloSendInstruction* operand, 313 bool is_host_transfer); 314 315 private: 316 // Implementation for non-common logic of CloneWithNewOperands. 317 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 318 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 319 HloCloneContext* context) const override; 320 }; 321 322 class HloRecvInstruction : public HloSendRecvInstruction { 323 public: 324 explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, 325 int64_t channel_id, bool is_host_transfer); 326 327 private: 328 // Implementation for non-common logic of CloneWithNewOperands. 329 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 330 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 331 HloCloneContext* context) const override; 332 }; 333 334 class HloRecvDoneInstruction : public HloSendRecvInstruction { 335 public: 336 explicit HloRecvDoneInstruction(HloRecvInstruction* operand, 337 bool is_host_transfer); 338 339 private: 340 // Implementation for non-common logic of CloneWithNewOperands. 341 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 342 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 343 HloCloneContext* context) const override; 344 }; 345 346 class HloCollectiveInstruction : public HloChannelInstruction { 347 public: replica_groups()348 const std::vector<ReplicaGroup>& replica_groups() const { 349 return replica_groups_; 350 } 351 352 // Returns true if the layout of the AllReduce is enforced by XLA client (as 353 // the layout set in the shape). The only reason for the client to set the 354 // layout is to separately compile computations that communicate with 355 // AllReduce. Since this field is only set `true` by the client, the compiler 356 // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set 357 // `false` for all other cases. 358 // 359 // When this is `true`, there may be communication endpoints outside the 360 // current compilation unit, so the compiler considers this AllReduce as 361 // side-effecting to disable compiler transformations. The compiler is free to 362 // transform unconstrained AllReduces differently across compilation units. 363 // It is an error for an HloModule to have a mix of constrained and 364 // unconstrained AllReduce instructions (checked by HloVerifier). constrain_layout()365 bool constrain_layout() const { return constrain_layout_; } 366 367 protected: 368 explicit HloCollectiveInstruction( 369 HloOpcode opcode, const Shape& shape, 370 absl::Span<HloInstruction* const> operands, 371 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 372 const absl::optional<int64>& channel_id); 373 374 HloInstructionProto ToProto() const override; 375 376 std::vector<string> ExtraAttributesToStringImpl( 377 const HloPrintOptions& options) const override; 378 bool IdenticalSlowPathIgnoringChannelIdValues( 379 const HloInstruction& other, 380 const std::function<bool(const HloComputation*, const HloComputation*)>& 381 eq_computations) const override; 382 383 std::vector<ReplicaGroup> replica_groups_; 384 bool constrain_layout_; 385 }; 386 387 class HloAllGatherInstruction : public HloCollectiveInstruction { 388 public: 389 explicit HloAllGatherInstruction( 390 HloOpcode opcode, const Shape& shape, 391 absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension, 392 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 393 const absl::optional<int64>& channel_id, bool use_global_device_ids); 394 // Same as HloAllReduceInstruction::use_global_device_ids. use_global_device_ids()395 bool use_global_device_ids() const { return use_global_device_ids_; } 396 397 // The dimension on which data from different participants are concatenated. all_gather_dimension()398 int64 all_gather_dimension() const { return all_gather_dimension_; } 399 set_all_gather_dimension(int64_t dim)400 void set_all_gather_dimension(int64_t dim) { all_gather_dimension_ = dim; } 401 402 protected: 403 std::vector<string> ExtraAttributesToStringImpl( 404 const HloPrintOptions& options) const override; 405 HloInstructionProto ToProto() const override; 406 407 private: 408 bool IdenticalSlowPathIgnoringChannelIdValues( 409 const HloInstruction& other, 410 const std::function<bool(const HloComputation*, const HloComputation*)>& 411 eq_computations) const override; 412 413 // Implementation for non-common logic of CloneWithNewOperands. 414 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 415 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 416 HloCloneContext* context) const override; 417 418 int64 all_gather_dimension_; 419 bool use_global_device_ids_; 420 }; 421 422 // Base class for all-reduce and all-reduce scatter instructions. 423 class HloAllReduceInstructionBase : public HloCollectiveInstruction { 424 public: 425 explicit HloAllReduceInstructionBase( 426 HloOpcode opcode, const Shape& shape, 427 absl::Span<HloInstruction* const> operands, 428 HloComputation* reduce_computation, 429 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 430 const absl::optional<int64>& channel_id, bool use_global_device_ids); 431 432 // Returns true if the ids in the ReplicaGroup config represent a global id of 433 // (replica_id * partition_count + partition_id) instead of a replica id. 434 // This enables more flexible grouping of devices if this all-reduce is both 435 // cross-partition and cross-replica. 436 // 437 // For example with 2 replicas and 4 partitions, 438 // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that 439 // group[0] = (0,0), (0,1), (1,0), (1,1) 440 // group[1] = (0,2), (0,3), (1,2), (1,3) 441 // where each pair is (replica_id, partition_id). use_global_device_ids()442 bool use_global_device_ids() const { return use_global_device_ids_; } 443 444 protected: 445 std::vector<string> ExtraAttributesToStringImpl( 446 const HloPrintOptions& options) const override; 447 HloInstructionProto ToProto() const override; 448 449 bool IdenticalSlowPathIgnoringChannelIdValues( 450 const HloInstruction& other, 451 const std::function<bool(const HloComputation*, const HloComputation*)>& 452 eq_computations) const override; 453 454 private: 455 bool use_global_device_ids_; 456 }; 457 458 class HloAllReduceInstruction : public HloAllReduceInstructionBase { 459 public: 460 using HloAllReduceInstructionBase::HloAllReduceInstructionBase; 461 462 // Returns true if the AllReduce does no communication, so it's equivalent 463 // to a mem copy. 464 bool IsNoop() const; 465 466 private: 467 // Implementation for non-common logic of CloneWithNewOperands. 468 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 469 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 470 HloCloneContext* context) const override; 471 }; 472 473 class HloReduceScatterInstruction : public HloAllReduceInstructionBase { 474 public: 475 explicit HloReduceScatterInstruction( 476 const Shape& shape, absl::Span<HloInstruction* const> operands, 477 HloComputation* reduce_computation, 478 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, 479 const absl::optional<int64>& channel_id, bool use_global_device_ids, 480 int64_t scatter_dimension); 481 482 // The dimension on which reduced data is scattered to different participants. scatter_dimension()483 int64 scatter_dimension() const { return scatter_dimension_; } 484 485 protected: 486 std::vector<string> ExtraAttributesToStringImpl( 487 const HloPrintOptions& options) const override; 488 HloInstructionProto ToProto() const override; 489 490 private: 491 bool IdenticalSlowPathIgnoringChannelIdValues( 492 const HloInstruction& other, 493 const std::function<bool(const HloComputation*, const HloComputation*)>& 494 eq_computations) const override; 495 496 // Implementation for non-common logic of CloneWithNewOperands. 497 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 498 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 499 HloCloneContext* context) const override; 500 501 int64 scatter_dimension_; 502 }; 503 504 class HloAllToAllInstruction : public HloCollectiveInstruction { 505 public: 506 explicit HloAllToAllInstruction(const Shape& shape, 507 absl::Span<HloInstruction* const> operands, 508 absl::Span<const ReplicaGroup> replica_groups, 509 bool constrain_layout, 510 const absl::optional<int64>& channel_id, 511 const absl::optional<int64>& split_dimension); 512 513 // AllToAll can optionally take a split dimension, which means that this 514 // AllToAll takes a single (flattened) array operand and produces an array 515 // output (instead of taking a list of operands and producing a tuple). 516 // 517 // split_dimension specifies which dimension in the operand is split across 518 // devices in each replica_group, and also means the concatenated dimension 519 // on the output (i.e., input and the output shapes are the same). split_dimension()520 absl::optional<int64> split_dimension() const { return split_dimension_; } set_split_dimension(int64_t dim)521 void set_split_dimension(int64_t dim) { split_dimension_ = dim; } 522 523 protected: 524 std::vector<string> ExtraAttributesToStringImpl( 525 const HloPrintOptions& options) const override; 526 HloInstructionProto ToProto() const override; 527 528 private: 529 bool IdenticalSlowPathIgnoringChannelIdValues( 530 const HloInstruction& other, 531 const std::function<bool(const HloComputation*, const HloComputation*)>& 532 eq_computations) const override; 533 534 // Implementation for non-common logic of CloneWithNewOperands. 535 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 536 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 537 HloCloneContext* context) const override; 538 539 absl::optional<int64> split_dimension_; 540 }; 541 542 class HloCollectivePermuteInstruction : public HloChannelInstruction { 543 public: 544 explicit HloCollectivePermuteInstruction( 545 HloOpcode opcode, const Shape& shape, HloInstruction* operand, 546 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs, 547 const absl::optional<int64_t>& channel_id); 548 549 explicit HloCollectivePermuteInstruction( 550 HloOpcode opcode, const Shape& shape, HloInstruction* input, 551 HloInstruction* output, HloInstruction* input_start_indices, 552 HloInstruction* output_start_indices, 553 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs, 554 absl::Span<const std::vector<int64_t>> slice_sizes, 555 const absl::optional<int64_t>& channel_id); 556 source_target_pairs()557 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const { 558 return source_target_pairs_; 559 } 560 dynamic_slice_sizes_list()561 const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const { 562 return slice_sizes_; 563 } 564 565 // Returns a serialized representation of this instruction. 566 HloInstructionProto ToProto() const override; 567 568 private: 569 std::vector<string> ExtraAttributesToStringImpl( 570 const HloPrintOptions& options) const override; 571 bool IdenticalSlowPathIgnoringChannelIdValues( 572 const HloInstruction& other, 573 const std::function<bool(const HloComputation*, const HloComputation*)>& 574 eq_computations) const override; 575 576 // Implementation for non-common logic of CloneWithNewOperands. 577 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 578 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 579 HloCloneContext* context) const override; 580 581 const std::vector<std::pair<int64_t, int64_t>> source_target_pairs_; 582 const std::vector<std::vector<int64_t>> slice_sizes_; 583 }; 584 585 class HloReverseInstruction : public HloInstruction { 586 public: 587 explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, 588 absl::Span<const int64> dimensions); 589 // Returns the dimension sizes or numbers associated with this instruction. dimensions()590 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)591 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()592 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 593 // Returns a serialized representation of this instruction. 594 HloInstructionProto ToProto() const override; 595 596 private: 597 std::vector<string> ExtraAttributesToStringImpl( 598 const HloPrintOptions& options) const override; 599 bool IdenticalSlowPath( 600 const HloInstruction& other, 601 const std::function<bool(const HloComputation*, const HloComputation*)>& 602 eq_computations) const override; 603 // Implementation for non-common logic of CloneWithNewOperands. 604 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 605 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 606 HloCloneContext* context) const override; 607 608 std::vector<int64> dimensions_; 609 }; 610 611 class HloConcatenateInstruction : public HloInstruction { 612 public: 613 explicit HloConcatenateInstruction(const Shape& shape, 614 absl::Span<HloInstruction* const> operands, 615 int64_t dimension); 616 // Returns the dimension sizes or numbers associated with this instruction. dimensions()617 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)618 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()619 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 620 // Accessor for the dimension in which a concatenate HLO should occur. concatenate_dimension()621 int64 concatenate_dimension() const { return dimensions(0); } 622 // Returns a serialized representation of this instruction. 623 HloInstructionProto ToProto() const override; 624 625 private: 626 std::vector<string> ExtraAttributesToStringImpl( 627 const HloPrintOptions& options) const override; 628 bool IdenticalSlowPath( 629 const HloInstruction& other, 630 const std::function<bool(const HloComputation*, const HloComputation*)>& 631 eq_computations) const override; 632 // Implementation for non-common logic of CloneWithNewOperands. 633 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 634 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 635 HloCloneContext* context) const override; 636 637 std::vector<int64> dimensions_; 638 }; 639 640 class HloReduceInstruction : public HloInstruction { 641 public: 642 explicit HloReduceInstruction(const Shape& shape, 643 absl::Span<HloInstruction* const> args, 644 absl::Span<const int64> dimensions_to_reduce, 645 HloComputation* reduce_computation); 646 // Returns the dimension sizes or numbers associated with this instruction. dimensions()647 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)648 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()649 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 650 // Returns a serialized representation of this instruction. 651 HloInstructionProto ToProto() const override; 652 653 // Returns the number of input arrays (and, consequentially, the number of 654 // init values) this reduce has. input_count()655 int64 input_count() const { return operand_count() / 2; } 656 657 // Returns the input tensors to be reduced. inputs()658 absl::Span<HloInstruction* const> inputs() const { 659 return absl::MakeSpan(operands()).subspan(0, input_count()); 660 } 661 662 // Returns the init values of the reduction. init_values()663 absl::Span<HloInstruction* const> init_values() const { 664 return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); 665 } 666 667 private: 668 std::vector<string> ExtraAttributesToStringImpl( 669 const HloPrintOptions& options) const override; 670 bool IdenticalSlowPath( 671 const HloInstruction& other, 672 const std::function<bool(const HloComputation*, const HloComputation*)>& 673 eq_computations) const override; 674 // Implementation for non-common logic of CloneWithNewOperands. 675 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 676 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 677 HloCloneContext* context) const override; 678 679 std::vector<int64> dimensions_; 680 }; 681 682 class HloSortInstruction : public HloInstruction { 683 public: 684 explicit HloSortInstruction(const Shape& shape, int64_t dimension, 685 absl::Span<HloInstruction* const> operands, 686 HloComputation* compare, bool is_stable); 687 // Returns the dimension sizes or numbers associated with this instruction. dimensions()688 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)689 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()690 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 691 // Returns the sort dimension for this instruction sort_dimension()692 int64 sort_dimension() const { return dimensions(0); } 693 // Returns a serialized representation of this instruction. 694 HloInstructionProto ToProto() const override; 695 // Returns the key operand to this instruction. keys()696 const HloInstruction* keys() const { return operand(0); } mutable_keys()697 HloInstruction* mutable_keys() { return mutable_operand(0); } 698 // Returns the number of value operands. values_count()699 int64 values_count() const { return operand_count() - 1; } is_stable()700 bool is_stable() const { return is_stable_; } 701 702 private: 703 std::vector<string> ExtraAttributesToStringImpl( 704 const HloPrintOptions& options) const override; 705 bool IdenticalSlowPath( 706 const HloInstruction& other, 707 const std::function<bool(const HloComputation*, const HloComputation*)>& 708 eq_computations) const override; 709 // Implementation for non-common logic of CloneWithNewOperands. 710 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 711 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 712 HloCloneContext* context) const override; 713 714 std::vector<int64> dimensions_; 715 bool is_stable_; 716 }; 717 718 class HloTransposeInstruction : public HloInstruction { 719 public: 720 explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, 721 absl::Span<const int64> dimensions); 722 // Returns the dimension sizes or numbers associated with this instruction. dimensions()723 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)724 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()725 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 726 // Returns whether this instruction does a rank-2 transposition. 727 bool IsRank2Transpose() const; 728 // Returns a serialized representation of this instruction. 729 HloInstructionProto ToProto() const override; 730 731 private: 732 std::vector<string> ExtraAttributesToStringImpl( 733 const HloPrintOptions& options) const override; 734 bool IdenticalSlowPath( 735 const HloInstruction& other, 736 const std::function<bool(const HloComputation*, const HloComputation*)>& 737 eq_computations) const override; 738 // Implementation for non-common logic of CloneWithNewOperands. 739 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 740 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 741 HloCloneContext* context) const override; 742 743 std::vector<int64> dimensions_; 744 }; 745 746 class HloBroadcastInstruction : public HloInstruction { 747 public: 748 explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, 749 absl::Span<const int64> broadcast_dimension); 750 // Returns the dimension sizes or numbers associated with this instruction. dimensions()751 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)752 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()753 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 754 // Returns a serialized representation of this instruction. 755 HloInstructionProto ToProto() const override; 756 757 private: 758 std::vector<string> ExtraAttributesToStringImpl( 759 const HloPrintOptions& options) const override; 760 bool IdenticalSlowPath( 761 const HloInstruction& other, 762 const std::function<bool(const HloComputation*, const HloComputation*)>& 763 eq_computations) const override; 764 // Implementation for non-common logic of CloneWithNewOperands. 765 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 766 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 767 HloCloneContext* context) const override; 768 769 std::vector<int64> dimensions_; 770 }; 771 772 class HloDynamicReshapeInstruction : public HloInstruction { 773 public: 774 explicit HloDynamicReshapeInstruction( 775 const Shape& shape, HloInstruction* data_operand, 776 absl::Span<HloInstruction* const> dim_sizes); 777 778 // Returns the input dim sizes dimensions, which is operands[1:] dim_sizes()779 absl::Span<HloInstruction* const> dim_sizes() const { 780 return absl::MakeSpan(operands()).subspan(1, operand_count()); 781 } 782 783 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 784 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 785 HloCloneContext* context) const override; 786 787 // Returns the input dim size dimension, which is operands[1+i] dim_sizes(int64_t i)788 HloInstruction* dim_sizes(int64_t i) const { return operands()[i + 1]; } 789 }; 790 791 class HloReshapeInstruction : public HloInstruction { 792 public: 793 explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand, 794 int64_t inferred_dimension); inferred_dimension()795 int64 inferred_dimension() const { return inferred_dimension_; } 796 HloInstructionProto ToProto() const override; 797 798 private: 799 std::vector<string> ExtraAttributesToStringImpl( 800 const HloPrintOptions& options) const override; 801 bool IdenticalSlowPath( 802 const HloInstruction& other, 803 const std::function<bool(const HloComputation*, const HloComputation*)>& 804 eq_computations) const override; 805 // Implementation for non-common logic of CloneWithNewOperands. 806 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 807 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 808 HloCloneContext* context) const override; 809 int64 inferred_dimension_; 810 }; 811 812 class HloMapInstruction : public HloInstruction { 813 public: 814 explicit HloMapInstruction(const Shape& shape, 815 absl::Span<HloInstruction* const> operands, 816 HloComputation* map_computation); 817 // Returns the dimension sizes or numbers associated with this instruction. dimensions()818 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64_t index)819 int64 dimensions(int64_t index) const override { return dimensions()[index]; } mutable_dimensions()820 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 821 // Returns a serialized representation of this instruction. 822 HloInstructionProto ToProto() const override; 823 824 private: 825 bool IsElementwiseImpl( 826 const absl::optional<int64>& operand_idx) const override; 827 std::vector<string> ExtraAttributesToStringImpl( 828 const HloPrintOptions& options) const override; 829 bool IdenticalSlowPath( 830 const HloInstruction& other, 831 const std::function<bool(const HloComputation*, const HloComputation*)>& 832 eq_computations) const override; 833 // Implementation for non-common logic of CloneWithNewOperands. 834 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 835 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 836 HloCloneContext* context) const override; 837 838 std::vector<int64> dimensions_; 839 }; 840 841 class HloSliceInstruction : public HloInstruction { 842 public: 843 explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, 844 absl::Span<const int64> start_indices, 845 absl::Span<const int64> limit_indices, 846 absl::Span<const int64> strides); 847 848 HloInstructionProto ToProto() const override; 849 850 // Returns the start index in the given dimension for a slice node. slice_starts(int64_t dimension)851 int64 slice_starts(int64_t dimension) const { 852 return slice_starts_[dimension]; 853 } slice_starts()854 const std::vector<int64>& slice_starts() const { return slice_starts_; } mutable_slice_starts()855 std::vector<int64>* mutable_slice_starts() { return &slice_starts_; } 856 857 // Returns the (exclusive) limit index in the given dimension for a slice 858 // node. slice_limits(int64_t dimension)859 int64 slice_limits(int64_t dimension) const { 860 return slice_limits_[dimension]; 861 } slice_limits()862 const std::vector<int64>& slice_limits() const { return slice_limits_; } mutable_slice_limits()863 std::vector<int64>* mutable_slice_limits() { return &slice_limits_; } 864 865 // Returns the stride in the given dimension for a slice node. slice_strides(int64_t dimension)866 int64 slice_strides(int64_t dimension) const { 867 return slice_strides_[dimension]; 868 } slice_strides()869 const std::vector<int64>& slice_strides() const { return slice_strides_; } mutable_slice_strides()870 std::vector<int64>* mutable_slice_strides() { return &slice_strides_; } 871 872 private: 873 std::vector<string> ExtraAttributesToStringImpl( 874 const HloPrintOptions& options) const override; 875 bool IdenticalSlowPath( 876 const HloInstruction& other, 877 const std::function<bool(const HloComputation*, const HloComputation*)>& 878 eq_computations) const override; 879 // Implementation for non-common logic of CloneWithNewOperands. 880 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 881 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 882 HloCloneContext* context) const override; 883 884 // Describes the [begin, end) index range for a slice. 885 std::vector<int64> slice_starts_; 886 std::vector<int64> slice_limits_; 887 std::vector<int64> slice_strides_; 888 }; 889 890 class HloConstantInstruction : public HloInstruction { 891 public: 892 explicit HloConstantInstruction(Literal literal); 893 explicit HloConstantInstruction(Literal literal, const Shape& shape); 894 // Used when the literal is too large and dropped. 895 explicit HloConstantInstruction(const Shape& shape); 896 // Returns the literal associated with this instruction. literal()897 const Literal& literal() const { return *literal_; } 898 // Returns the (mutable) literal associated with this instruction. mutable_literal()899 Literal* mutable_literal() { return &literal_.value(); } 900 // Returns whether there is literal associated with this instruction. HasLiteral()901 bool HasLiteral() const { return literal_.has_value(); } 902 // Returns a serialized representation of this instruction. 903 HloInstructionProto ToProto() const override; 904 905 // Change the layout for an Constant Hlo instruction to match new_layout. For 906 // tuple shaped constants shape_index is the path to the internal array 907 // subshape whose layout needs to be changed. 908 void RelayoutConstant(const Layout& new_layout, 909 const ShapeIndex& shape_index = {}); 910 911 private: 912 bool IsElementwiseImpl( 913 const absl::optional<int64>& operand_idx) const override; 914 bool IdenticalSlowPath( 915 const HloInstruction& other, 916 const std::function<bool(const HloComputation*, const HloComputation*)>& 917 eq_computations) const override; 918 string OperandsToStringWithCanonicalNameMap( 919 const HloPrintOptions& options, 920 CanonicalNameMap* canonical_name_map) const override; 921 // Implementation for non-common logic of CloneWithNewOperands. 922 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 923 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 924 HloCloneContext* context) const override; 925 absl::optional<Literal> literal_; 926 }; 927 928 class HloTraceInstruction : public HloInstruction { 929 public: 930 explicit HloTraceInstruction(const string& tag, HloInstruction* operand); 931 // Returns a tag to be used in tracing. TracingTag()932 string TracingTag() const { return literal_.GetR1U8AsString(); } 933 // Returns a serialized representation of this instruction. 934 HloInstructionProto ToProto() const override; 935 936 private: 937 bool IdenticalSlowPath( 938 const HloInstruction& other, 939 const std::function<bool(const HloComputation*, const HloComputation*)>& 940 eq_computations) const override; 941 // Implementation for non-common logic of CloneWithNewOperands. 942 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 943 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 944 HloCloneContext* context) const override; 945 Literal literal_; 946 }; 947 948 class HloFusionInstruction : public HloInstruction { 949 public: 950 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 951 HloInstruction* fused_root); 952 953 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 954 absl::Span<HloInstruction* const> operands, 955 HloComputation* fusion_computation); 956 957 ~HloFusionInstruction() override; 958 959 void ClearCalledComputations() override; 960 961 // When a fusion instruction is being destructed, clear the back pointer of 962 // its fusion computation, to avoid referencing freed memory. 963 void ClearFusionComputationInstruction(); 964 965 string ToCategory() const override; 966 // Returns a serialized representation of this instruction. 967 HloInstructionProto ToProto() const override; 968 969 // Adds a new operand the fusion instruction. 970 HloInstruction* AddFusionOperand(HloInstruction* new_operand); 971 972 // Merges the fused instructions from 'instruction_to_merge' into the 973 // fused instruction set of 'this', updating operands as necessary. 974 // 975 // Precondition: 'instruction_to_merge' must be an operand of 'this'. 976 void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); 977 978 // Merges the fused instructions from instruction_to_merge into the fused 979 // instruction set of 'this' and generates multioutput fusion instructions. 980 // All the users of instruction_to_merge will be redirected to 'this' 981 // instruction. instruction_to_merge will be removed from its parent 982 // computation. 983 void MergeFusionInstructionIntoMultiOutput( 984 HloFusionInstruction* instruction_to_merge); 985 986 // Fuses the given instruction in this fusion instruction. instruction_to_fuse 987 // is cloned and the clone is placed in the fusion 988 // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather 989 // than moved to cleanly handle the case where the instruction has a use 990 // outside the fusion instruction. Moving such an instruction into a fusion 991 // instruction would violate the single-result invariant of HLO instructions 992 // and significantly complicate code generation. FuseInstruction(HloInstruction * instruction_to_fuse)993 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { 994 return FuseInstructionInternal(instruction_to_fuse); 995 } 996 997 // Fuses the given instruction in this fusion instruction and generates a 998 // multioutput fusion instruction. A clone of the instruction_to_fuse will 999 // be part of the output of fusion instructions. The users of 1000 // instruction_to_fuse will be redirected to this fusion instructions. 1001 // instruction_to_fuse is unchanged otherwise. FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)1002 HloInstruction* FuseInstructionIntoMultiOutput( 1003 HloInstruction* instruction_to_fuse) { 1004 return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); 1005 } 1006 1007 // Returns the computation for this fused instruction. 1008 HloComputation* fused_instructions_computation() const; 1009 1010 // Returns the root instruction of the fused expression contained within this 1011 // fusion instruction. 1012 HloInstruction* fused_expression_root() const; 1013 1014 // Returns the list of fused instructions inside this fusion instruction. The 1015 // returned type is a range of HloInstruction*s. 1016 const tensorflow::gtl::iterator_range<UnwrappingIterator< 1017 std::list<std::unique_ptr<HloInstruction>>::const_iterator>> 1018 fused_instructions() const; 1019 1020 const tensorflow::gtl::iterator_range< 1021 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> 1022 fused_instructions(); 1023 1024 // Gets the number of instructions inside this fusion instruction. 1025 int64 fused_instruction_count() const; 1026 1027 // Returns the fused parameter instruction in this fusion instruction 1028 // corresponding to the given parameter number. 1029 HloInstruction* fused_parameter(int64_t parameter_number) const; 1030 1031 // Returns the vector of fused parameters inside this fusion instruction. 1032 const std::vector<HloInstruction*>& fused_parameters() const; 1033 1034 // Returns true if this instruction is a fusion instruction that generates 1035 // multiple outputs. IsMultiOutputFusion()1036 const bool IsMultiOutputFusion() const { 1037 return fused_expression_root()->opcode() == HloOpcode::kTuple; 1038 } 1039 fusion_kind()1040 FusionKind fusion_kind() const { return fusion_kind_; } 1041 set_fusion_kind(FusionKind kind)1042 void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } 1043 1044 // If multiple operands are the same instruction, keeps only one of them. 1045 Status DeduplicateFusionOperands(); 1046 1047 private: 1048 // Fuses the given instruction into this fusion instruction. 1049 // instruction_to_fuse is cloned and the clone is placed in the fusion 1050 // instruction. The users of instruction_to_fuse will be redirected to this 1051 // fusion instruction. instruction_to_fuse is unchanged otherwise. When 1052 // add_output is true, a clone of the instruction_to_fuse will be added as 1053 // additional output resulting in a multi-output fusion. 1054 HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, 1055 bool add_output = false); 1056 // Clones the given instruction_to_fuse and insert the clone into this fusion 1057 // instruction. If add_output is true, a clone of instruction_to_fuse will 1058 // be in the output of the this fusion instruction (part of the tuple of the 1059 // fusion root). 1060 HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, 1061 bool add_output = false); 1062 1063 bool IsElementwiseImpl( 1064 const absl::optional<int64>& operand_idx) const override; 1065 std::vector<string> ExtraAttributesToStringImpl( 1066 const HloPrintOptions& options) const override; 1067 bool IdenticalSlowPath( 1068 const HloInstruction& other, 1069 const std::function<bool(const HloComputation*, const HloComputation*)>& 1070 eq_computations) const override; 1071 uint64 InnerHash() const override; 1072 1073 // Implementation for non-common logic of CloneWithNewOperands. 1074 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1075 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1076 HloCloneContext* context) const override; 1077 1078 // The type of the fusion. Used by kFusion only. 1079 FusionKind fusion_kind_; 1080 }; 1081 1082 class HloRngInstruction : public HloInstruction { 1083 public: 1084 explicit HloRngInstruction(const Shape& shape, 1085 RandomDistribution distribution, 1086 absl::Span<HloInstruction* const> parameters); 1087 // Returns the random distribution for this rng node. random_distribution()1088 RandomDistribution random_distribution() const { return distribution_; } 1089 // Returns a serialized representation of this instruction. 1090 HloInstructionProto ToProto() const override; 1091 1092 private: 1093 bool IsElementwiseImpl( 1094 const absl::optional<int64>& operand_idx) const override; 1095 std::vector<string> ExtraAttributesToStringImpl( 1096 const HloPrintOptions& options) const override; 1097 bool IdenticalSlowPath( 1098 const HloInstruction& other, 1099 const std::function<bool(const HloComputation*, const HloComputation*)>& 1100 eq_computations) const override; 1101 // Implementation for non-common logic of CloneWithNewOperands. 1102 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1103 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1104 HloCloneContext* context) const override; 1105 1106 // The distribution requested for random number generation. 1107 RandomDistribution distribution_; 1108 }; 1109 1110 class HloParameterInstruction : public HloInstruction { 1111 public: 1112 explicit HloParameterInstruction(int64_t parameter_number, const Shape& shape, 1113 const string& name); parameter_number()1114 int64 parameter_number() const { return parameter_number_; } 1115 1116 // Sets and gets the whether all replicas will receive the same parameter data 1117 // for each leaf buffer in data parallelism. set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)1118 void set_parameter_replicated_at_leaf_buffers( 1119 absl::Span<const bool> parameter_replicated_at_leaf_buffers) { 1120 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 1121 parameter_replicated_at_leaf_buffers.size()); 1122 parameter_replicated_at_leaf_buffers_.emplace( 1123 parameter_replicated_at_leaf_buffers.begin(), 1124 parameter_replicated_at_leaf_buffers.end()); 1125 } set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)1126 void set_parameter_replicated_at_leaf_buffers( 1127 const std::vector<bool>& parameter_replicated_at_leaf_buffers) { 1128 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 1129 parameter_replicated_at_leaf_buffers.size()); 1130 parameter_replicated_at_leaf_buffers_ = 1131 parameter_replicated_at_leaf_buffers; 1132 } 1133 const absl::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()1134 parameter_replicated_at_leaf_buffers() const { 1135 return parameter_replicated_at_leaf_buffers_; 1136 } 1137 1138 // Returns a serialized representation of this instruction. 1139 HloInstructionProto ToProto() const override; 1140 1141 private: 1142 std::vector<string> ExtraAttributesToStringImpl( 1143 const HloPrintOptions& options) const override; 1144 bool IdenticalSlowPath( 1145 const HloInstruction& other, 1146 const std::function<bool(const HloComputation*, const HloComputation*)>& 1147 eq_computations) const override; 1148 string OperandsToStringWithCanonicalNameMap( 1149 const HloPrintOptions& options, 1150 CanonicalNameMap* canonical_name_map) const override; 1151 // Implementation for non-common logic of CloneWithNewOperands. 1152 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1153 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1154 HloCloneContext* context) const override; 1155 1156 int64 parameter_number_ = 0; 1157 1158 // Specifies whether each buffer has the same parameter value on all replicas 1159 // in data parallelism. 1160 absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_; 1161 }; 1162 1163 class HloGetTupleElementInstruction : public HloInstruction { 1164 public: 1165 explicit HloGetTupleElementInstruction(const Shape& shape, 1166 HloInstruction* operand, 1167 int64_t index); 1168 // Returns the tuple index associated with this instruction. tuple_index()1169 int64 tuple_index() const { return tuple_index_; } 1170 // Sets the tuple index associated with this instruction. set_tuple_index(int64_t new_tuple_index)1171 void set_tuple_index(int64_t new_tuple_index) { 1172 tuple_index_ = new_tuple_index; 1173 } 1174 // Returns a serialized representation of this instruction. 1175 HloInstructionProto ToProto() const override; 1176 1177 private: 1178 std::vector<string> ExtraAttributesToStringImpl( 1179 const HloPrintOptions& options) const override; 1180 bool IdenticalSlowPath( 1181 const HloInstruction& other, 1182 const std::function<bool(const HloComputation*, const HloComputation*)>& 1183 eq_computations) const override; 1184 // Implementation for non-common logic of CloneWithNewOperands. 1185 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1186 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1187 HloCloneContext* context) const override; 1188 1189 int64 tuple_index_ = -1; 1190 }; 1191 1192 class HloReducePrecisionInstruction : public HloInstruction { 1193 public: 1194 explicit HloReducePrecisionInstruction(const Shape& shape, 1195 HloInstruction* operand, 1196 const int exponent_bits, 1197 const int mantissa_bits); 1198 // Returns the number of exponent bits for a reduce-precision node. exponent_bits()1199 int32 exponent_bits() const { return exponent_bits_; } 1200 // Returns the number of mantissa bits for a reduce-precision node. mantissa_bits()1201 int32 mantissa_bits() const { return mantissa_bits_; } 1202 // Returns a serialized representation of this instruction. 1203 HloInstructionProto ToProto() const override; 1204 1205 private: 1206 std::vector<string> ExtraAttributesToStringImpl( 1207 const HloPrintOptions& options) const override; 1208 bool IdenticalSlowPath( 1209 const HloInstruction& other, 1210 const std::function<bool(const HloComputation*, const HloComputation*)>& 1211 eq_computations) const override; 1212 // Implementation for non-common logic of CloneWithNewOperands. 1213 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1214 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1215 HloCloneContext* context) const override; 1216 1217 // The bit sizes for a reduce-precision operation. 1218 int32 exponent_bits_ = 0; 1219 int32 mantissa_bits_ = 0; 1220 }; 1221 1222 class HloInfeedInstruction : public HloInstruction { 1223 public: 1224 explicit HloInfeedInstruction(const Shape& infeed_shape, 1225 HloInstruction* token_operand, 1226 const string& config); 1227 // Returns the infeed configuration string. The infeed configuration includes 1228 // any metadata needed for the backend compiler (e.g., infeed buffer address) 1229 // and is target-dependent. infeed_config()1230 string infeed_config() const { return infeed_config_; } set_infeed_config(const string & config)1231 void set_infeed_config(const string& config) { infeed_config_ = config; } 1232 // Returns the shape of the data received by the infeed. This is not the same 1233 // as the shape of the infeed instruction which produces a tuple containing 1234 // the infeed data shape and a TOKEN. infeed_shape()1235 const Shape& infeed_shape() const { 1236 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); 1237 return ShapeUtil::GetSubshape(shape(), {0}); 1238 } 1239 // Returns a serialized representation of this instruction. 1240 HloInstructionProto ToProto() const override; 1241 1242 private: 1243 std::vector<string> ExtraAttributesToStringImpl( 1244 const HloPrintOptions& options) const override; 1245 bool IdenticalSlowPath( 1246 const HloInstruction& other, 1247 const std::function<bool(const HloComputation*, const HloComputation*)>& 1248 eq_computations) const override; 1249 // Implementation for non-common logic of CloneWithNewOperands. 1250 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1251 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1252 HloCloneContext* context) const override; 1253 1254 // The string representation of the infeed configuration. 1255 string infeed_config_; 1256 }; 1257 1258 class HloOutfeedInstruction : public HloInstruction { 1259 public: 1260 explicit HloOutfeedInstruction(const Shape& outfeed_shape, 1261 HloInstruction* operand, 1262 HloInstruction* token_operand, 1263 absl::string_view outfeed_config); 1264 // Returns the shape for the Outfeed instruction. outfeed_shape()1265 const Shape& outfeed_shape() const { return outfeed_shape_; } 1266 // Returns the mutable shape for the Outfeed instruction. mutable_outfeed_shape()1267 Shape* mutable_outfeed_shape() { return &outfeed_shape_; } 1268 // Returns the config for the Outfeed instruction. outfeed_config()1269 const string& outfeed_config() const { return outfeed_config_; } set_outfeed_config(const string & config)1270 void set_outfeed_config(const string& config) { outfeed_config_ = config; } 1271 // Returns a serialized representation of this instruction. 1272 HloInstructionProto ToProto() const override; 1273 1274 private: 1275 std::vector<string> ExtraAttributesToStringImpl( 1276 const HloPrintOptions& options) const override; 1277 bool IdenticalSlowPath( 1278 const HloInstruction& other, 1279 const std::function<bool(const HloComputation*, const HloComputation*)>& 1280 eq_computations) const override; 1281 // Implementation for non-common logic of CloneWithNewOperands. 1282 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1283 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1284 HloCloneContext* context) const override; 1285 1286 // Shape of outfeed request. 1287 Shape outfeed_shape_; 1288 // Outfeed configuration information, only present for kOutfeed. 1289 string outfeed_config_; 1290 }; 1291 1292 class HloConvolutionInstruction : public HloInstruction { 1293 public: 1294 explicit HloConvolutionInstruction( 1295 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 1296 int64_t feature_group_count, int64_t batch_group_count, 1297 const Window& window, 1298 const ConvolutionDimensionNumbers& dimension_numbers, 1299 const PrecisionConfig& precision_config); window()1300 const Window& window() const override { return window_; } set_window(const Window & window)1301 void set_window(const Window& window) override { window_ = window; } convolution_dimension_numbers()1302 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1303 return convolution_dimension_numbers_; 1304 } set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1305 void set_convolution_dimension_numbers( 1306 const ConvolutionDimensionNumbers& dnums) { 1307 convolution_dimension_numbers_ = dnums; 1308 } 1309 // The number of feature groups. Must be a divisor of the input feature 1310 // dimension and output feature dimension. feature_group_count()1311 int64 feature_group_count() const { return feature_group_count_; } set_feature_group_count(int64_t num_feature_groups)1312 void set_feature_group_count(int64_t num_feature_groups) { 1313 feature_group_count_ = num_feature_groups; 1314 } 1315 // The number of batch groups. Must be a divisor of the input batch dimension. batch_group_count()1316 int64 batch_group_count() const { return batch_group_count_; } set_batch_group_count(int64_t num_batch_groups)1317 void set_batch_group_count(int64_t num_batch_groups) { 1318 batch_group_count_ = num_batch_groups; 1319 } 1320 1321 // Returns the information used to tell the implementation information about 1322 // what sort of precision is requested. The meaning of the field is backend 1323 // specific. At the moment, it is only supported for kConvolution and kDot. 1324 // Transformations on one kDot or kConvolution to another will preserve this 1325 // information. Transformations to other HLOs will not preserve this 1326 // information but it is presumed that the alternate lowering is strictly 1327 // superior. precision_config()1328 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1329 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1330 1331 string ToCategory() const override; 1332 // Returns a serialized representation of this instruction. 1333 HloInstructionProto ToProto() const override; 1334 1335 private: 1336 std::vector<string> ExtraAttributesToStringImpl( 1337 const HloPrintOptions& options) const override; 1338 bool IdenticalSlowPath( 1339 const HloInstruction& other, 1340 const std::function<bool(const HloComputation*, const HloComputation*)>& 1341 eq_computations) const override; 1342 // Implementation for non-common logic of CloneWithNewOperands. 1343 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1344 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1345 HloCloneContext* context) const override; 1346 // The number of feature groups. Must be a divisor of the input feature 1347 // dimension and output feature dimension. 1348 int64 feature_group_count_; 1349 // The number of batch groups. Must be a divisor of the input batch dimension. 1350 int64 batch_group_count_; 1351 // Describes the window used for a convolution. 1352 Window window_; 1353 // Describes the dimension numbers used for a convolution. 1354 ConvolutionDimensionNumbers convolution_dimension_numbers_; 1355 // Information used to communicate to the implementation about the algorithm 1356 // used to produce results. See the documentation on precision_config(). 1357 PrecisionConfig precision_config_; 1358 }; 1359 1360 class HloReduceWindowInstruction : public HloInstruction { 1361 public: 1362 explicit HloReduceWindowInstruction(const Shape& shape, 1363 HloInstruction* operand, 1364 HloInstruction* init_value, 1365 const Window& window, 1366 HloComputation* reduce_computation); 1367 explicit HloReduceWindowInstruction( 1368 const Shape& shape, absl::Span<HloInstruction* const> operands, 1369 absl::Span<HloInstruction* const> init_values, const Window& window, 1370 HloComputation* reduce_computation); window()1371 const Window& window() const override { return window_; } set_window(const Window & window)1372 void set_window(const Window& window) override { window_ = window; } 1373 // Returns a serialized representation of this instruction. 1374 HloInstructionProto ToProto() const override; 1375 // Returns the number of input arrays (and, consequentially, the number of 1376 // init values) this reduce has. input_count()1377 int64 input_count() const { return operand_count() / 2; } 1378 // Returns the input tensors to be reduced. inputs()1379 absl::Span<HloInstruction* const> inputs() const { 1380 return absl::MakeSpan(operands()).subspan(0, input_count()); 1381 } 1382 // Returns the init values of the reduction. init_values()1383 absl::Span<HloInstruction* const> init_values() const { 1384 return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); 1385 } 1386 // Returns the shapes of input tensors to be reduced. input_shapes()1387 absl::InlinedVector<const Shape*, 2> input_shapes() const { 1388 absl::InlinedVector<const Shape*, 2> shapes; 1389 for (const auto* op : inputs()) { 1390 VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n"; 1391 shapes.push_back(&op->shape()); 1392 VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n"; 1393 } 1394 return shapes; 1395 } 1396 // Returns the init values of the reduction. init_value_shapes()1397 absl::InlinedVector<const Shape*, 2> init_value_shapes() const { 1398 absl::InlinedVector<const Shape*, 2> shapes; 1399 for (const auto* op : init_values()) { 1400 shapes.push_back(&op->shape()); 1401 } 1402 return shapes; 1403 } 1404 // Returns the shapes of the reduced output tensors. output_shapes()1405 absl::InlinedVector<const Shape*, 2> output_shapes() const { 1406 absl::InlinedVector<const Shape*, 2> shapes; 1407 if (shape().IsArray()) { 1408 shapes.push_back(&shape()); 1409 } else { 1410 for (const Shape& tuple_element_shape : shape().tuple_shapes()) { 1411 shapes.push_back(&tuple_element_shape); 1412 } 1413 } 1414 return shapes; 1415 } 1416 1417 private: 1418 std::vector<string> ExtraAttributesToStringImpl( 1419 const HloPrintOptions& options) const override; 1420 bool IdenticalSlowPath( 1421 const HloInstruction& other, 1422 const std::function<bool(const HloComputation*, const HloComputation*)>& 1423 eq_computations) const override; 1424 // Implementation for non-common logic of CloneWithNewOperands. 1425 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1426 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1427 HloCloneContext* context) const override; 1428 1429 Window window_; 1430 }; 1431 1432 class HloSelectAndScatterInstruction : public HloInstruction { 1433 public: 1434 explicit HloSelectAndScatterInstruction( 1435 const Shape& shape, HloInstruction* operand, HloComputation* select, 1436 const Window& window, HloInstruction* source, HloInstruction* init_value, 1437 HloComputation* scatter); window()1438 const Window& window() const override { return window_; } set_window(const Window & window)1439 void set_window(const Window& window) override { window_ = window; } 1440 // Gets/sets the select or scatter HloComputation for SelectAndScatter. The 1441 // setters should only be called by HloModule or HloComputation methods. select()1442 HloComputation* select() const { 1443 return called_computations()[kSelectComputationIndex]; 1444 } 1445 scatter()1446 HloComputation* scatter() const { 1447 return called_computations()[kScatterComputationIndex]; 1448 } 1449 set_select(HloComputation * computation)1450 void set_select(HloComputation* computation) { 1451 // Don't allow changing the computation for fused instructions so we don't 1452 // have to recompute called_instructions for the entire fusion instruction. 1453 CHECK(!IsFused()); 1454 set_called_computation(kSelectComputationIndex, computation); 1455 } 1456 set_scatter(HloComputation * computation)1457 void set_scatter(HloComputation* computation) { 1458 // Don't allow changing the computation for fused instructions so we don't 1459 // have to recompute called_instructions for the entire fusion instruction. 1460 CHECK(!IsFused()); 1461 set_called_computation(kScatterComputationIndex, computation); 1462 } 1463 // Returns a serialized representation of this instruction. 1464 HloInstructionProto ToProto() const override; 1465 1466 private: 1467 std::vector<string> ExtraAttributesToStringImpl( 1468 const HloPrintOptions& options) const override; 1469 bool IdenticalSlowPath( 1470 const HloInstruction& other, 1471 const std::function<bool(const HloComputation*, const HloComputation*)>& 1472 eq_computations) const override; 1473 // Implementation for non-common logic of CloneWithNewOperands. 1474 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1475 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1476 HloCloneContext* context) const override; 1477 Window window_; 1478 }; 1479 1480 class HloCustomCallInstruction : public HloInstruction { 1481 public: 1482 HloCustomCallInstruction(const Shape& shape, 1483 absl::Span<HloInstruction* const> operands, 1484 absl::string_view custom_call_target, string opaque, 1485 CustomCallApiVersion api_version); 1486 1487 // Constructor for a custom call with constrained layout. 'shape' and 1488 // 'operands_with_layout' must all have layouts. 1489 HloCustomCallInstruction(const Shape& shape, 1490 absl::Span<HloInstruction* const> operands, 1491 absl::string_view custom_call_target, string opaque, 1492 absl::Span<const Shape> operand_shapes_with_layout, 1493 CustomCallApiVersion api_version); 1494 1495 // Constructor for a custom call with a to_apply computation. 1496 HloCustomCallInstruction(const Shape& shape, 1497 absl::Span<HloInstruction* const> operands, 1498 HloComputation* to_apply, 1499 absl::string_view custom_call_target, string opaque, 1500 CustomCallApiVersion api_version); 1501 1502 // Constructor for a custom call with multiple computations. 1503 HloCustomCallInstruction( 1504 const Shape& shape, absl::Span<HloInstruction* const> operands, 1505 absl::Span<HloComputation* const> called_computations, 1506 absl::string_view custom_call_target, string opaque, 1507 CustomCallApiVersion api_version); 1508 window()1509 const Window& window() const override { 1510 CHECK(window_ != nullptr); 1511 return *window_; 1512 } 1513 set_window(const Window & window)1514 void set_window(const Window& window) override { 1515 window_ = absl::make_unique<Window>(window); 1516 } 1517 convolution_dimension_numbers()1518 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1519 CHECK(convolution_dimension_numbers_ != nullptr); 1520 return *convolution_dimension_numbers_; 1521 } 1522 set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1523 void set_convolution_dimension_numbers( 1524 const ConvolutionDimensionNumbers& dnums) { 1525 convolution_dimension_numbers_ = 1526 absl::make_unique<ConvolutionDimensionNumbers>(dnums); 1527 } 1528 // TODO(jpienaar): Remove this accessor in the follow up. opaque()1529 const string& opaque() const { return raw_backend_config_string(); } custom_call_target()1530 const string& custom_call_target() const { return custom_call_target_; } set_feature_group_count(int64_t feature_group_count)1531 void set_feature_group_count(int64_t feature_group_count) { 1532 feature_group_count_ = feature_group_count; 1533 } set_batch_group_count(int64_t batch_group_count)1534 void set_batch_group_count(int64_t batch_group_count) { 1535 batch_group_count_ = batch_group_count; 1536 } 1537 // Sets whether this custom call has a side-effect - by default a custom call 1538 // has no side-effects. set_custom_call_has_side_effect(bool custom_call_has_side_effect)1539 void set_custom_call_has_side_effect(bool custom_call_has_side_effect) { 1540 custom_call_has_side_effect_ = custom_call_has_side_effect; 1541 } feature_group_count()1542 int64 feature_group_count() const { return feature_group_count_; } batch_group_count()1543 int64 batch_group_count() const { return batch_group_count_; } custom_call_has_side_effect()1544 bool custom_call_has_side_effect() const { 1545 return custom_call_has_side_effect_; 1546 } 1547 // Returns padding type used for ops like convolution. padding_type()1548 PaddingType padding_type() const { return padding_type_; } 1549 set_padding_type(PaddingType padding_type)1550 void set_padding_type(PaddingType padding_type) { 1551 padding_type_ = padding_type; 1552 } 1553 1554 // Returns the literal associated with this instruction. literal()1555 const Literal& literal() const { return *literal_; } 1556 // Set the value of literal to a new one. set_literal(Literal && literal)1557 void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); } 1558 // Returns whether there is literal associated with this instruction. HasLiteral()1559 bool HasLiteral() const { return literal_.has_value(); } 1560 precision_config()1561 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1562 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1563 1564 // Returns a serialized representation of this instruction. 1565 HloInstructionProto ToProto() const override; 1566 1567 // Returns whether the result and operand layouts are constrained. layout_constrained()1568 bool layout_constrained() const { return layout_constrained_; } 1569 1570 // Returns the shapes (with layout) of the operands. CHECKs if this custom 1571 // call does not have constrained layouts. operand_shapes_with_layout()1572 const std::vector<Shape>& operand_shapes_with_layout() const { 1573 CHECK(layout_constrained()); 1574 return operand_shapes_with_layout_; 1575 } 1576 // Gets a list of output/operand buffer pairs that alias each other, where the 1577 // output buffer is represented as a ShapeIndex, and the operand buffer is 1578 // represented as the operand index and the ShapeIndex. By default this list 1579 // is empty. 1580 const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>& output_to_operand_aliasing()1581 output_to_operand_aliasing() const { 1582 return output_to_operand_aliasing_; 1583 } 1584 // Sets the list of output/operand buffer pairs that alias each other. set_output_to_operand_aliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> aliasing)1585 void set_output_to_operand_aliasing( 1586 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>> 1587 aliasing) { 1588 output_to_operand_aliasing_ = std::move(aliasing); 1589 } set_custom_call_schedule(CustomCallSchedule custom_call_schedule)1590 void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) { 1591 custom_call_schedule_ = custom_call_schedule; 1592 } custom_call_schedule()1593 CustomCallSchedule custom_call_schedule() const { 1594 return custom_call_schedule_; 1595 } set_api_version(CustomCallApiVersion api_version)1596 void set_api_version(CustomCallApiVersion api_version) { 1597 api_version_ = api_version; 1598 } api_version()1599 CustomCallApiVersion api_version() const { return api_version_; } 1600 1601 private: 1602 std::vector<string> ExtraAttributesToStringImpl( 1603 const HloPrintOptions& options) const override; 1604 bool IdenticalSlowPath( 1605 const HloInstruction& other, 1606 const std::function<bool(const HloComputation*, const HloComputation*)>& 1607 eq_computations) const override; 1608 // Implementation for non-common logic of CloneWithNewOperands. 1609 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1610 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1611 HloCloneContext* context) const override; 1612 // Name of a global symbol to call. 1613 string custom_call_target_; 1614 // Describes the window in a windowed operation such as convolution. 1615 std::unique_ptr<Window> window_; 1616 // Describes the dimension numbers used for a convolution. 1617 std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; 1618 // The number of feature groups. This is used for grouped convolutions. 1619 int64 feature_group_count_; 1620 int64 batch_group_count_; 1621 // Whether the result and operand layouts are constrained. 1622 bool layout_constrained_; 1623 // Information used to communicate to the implementation about the algorithm 1624 // used to produce results for convolution instructions. 1625 PrecisionConfig precision_config_; 1626 // Describes the padding type for convolution instructions. 1627 PaddingType padding_type_; 1628 // For layout-constrained custom calls, this vector holds the shape with 1629 // layout for each operand. 1630 std::vector<Shape> operand_shapes_with_layout_; 1631 // Whether this custom call has a side-effect. 1632 bool custom_call_has_side_effect_; 1633 // A list of output/operand buffer pairs that alias each other. See comment of 1634 // output_to_operand_aliasing(). 1635 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>> 1636 output_to_operand_aliasing_; 1637 absl::optional<Literal> literal_; 1638 // A custom-call schedule hint. 1639 CustomCallSchedule custom_call_schedule_; 1640 // The version of the API used by the custom call function. 1641 // TODO(b/189822916): Remove this field when all clients are migrated to the 1642 // status-returning API. 1643 CustomCallApiVersion api_version_; 1644 }; 1645 1646 class HloPadInstruction : public HloInstruction { 1647 public: 1648 explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, 1649 HloInstruction* padding_value, 1650 const PaddingConfig& padding_config); 1651 // Returns the padding configuration for a pad node. padding_config()1652 const PaddingConfig& padding_config() const { return padding_config_; } mutable_padding_config()1653 PaddingConfig* mutable_padding_config() { return &padding_config_; } 1654 // Returns the padding value. padding_value()1655 const HloInstruction* padding_value() const { return operand(1); } mutable_padding_value()1656 HloInstruction* mutable_padding_value() { return mutable_operand(1); } 1657 // Returns a serialized representation of this instruction. 1658 HloInstructionProto ToProto() const override; 1659 1660 private: 1661 std::vector<string> ExtraAttributesToStringImpl( 1662 const HloPrintOptions& options) const override; 1663 bool IdenticalSlowPath( 1664 const HloInstruction& other, 1665 const std::function<bool(const HloComputation*, const HloComputation*)>& 1666 eq_computations) const override; 1667 // Implementation for non-common logic of CloneWithNewOperands. 1668 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1669 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1670 HloCloneContext* context) const override; 1671 1672 // The padding configuration that describes the edge padding and interior 1673 // padding of this pad instruction. 1674 PaddingConfig padding_config_; 1675 }; 1676 1677 class HloDynamicIndexInstruction : public HloInstruction { 1678 public: HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1679 explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) 1680 : HloInstruction(opcode, shape) {} 1681 virtual int64 first_index_operand_number() const = 0; 1682 1683 // Returns a subspan of operands which represent the start indices. index_operands()1684 absl::Span<HloInstruction* const> index_operands() const { 1685 return absl::MakeSpan(operands()).subspan(first_index_operand_number()); 1686 } 1687 1688 // Returns the shapes of the index operands. index_shapes()1689 std::vector<Shape> index_shapes() const { 1690 std::vector<Shape> shapes; 1691 auto indices = index_operands(); 1692 for (const HloInstruction* index : indices) { 1693 shapes.push_back(index->shape()); 1694 } 1695 return shapes; 1696 } 1697 }; 1698 1699 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { 1700 public: 1701 explicit HloDynamicSliceInstruction(const Shape& shape, 1702 HloInstruction* operand, 1703 HloInstruction* start_indices, 1704 absl::Span<const int64> slice_sizes); 1705 explicit HloDynamicSliceInstruction( 1706 const Shape& shape, HloInstruction* operand, 1707 absl::Span<HloInstruction* const> start_indices, 1708 absl::Span<const int64> slice_sizes); 1709 // Old methods kept for smooth subclassing transition END. 1710 // Returns the size of the slice in the given dimension for a dynamic 1711 // slice node. slice_sizes(int64_t dimension)1712 int64 slice_sizes(int64_t dimension) const { 1713 return dynamic_slice_sizes_[dimension]; 1714 } dynamic_slice_sizes()1715 const std::vector<int64>& dynamic_slice_sizes() const { 1716 return dynamic_slice_sizes_; 1717 } 1718 // Returns a serialized representation of this instruction. 1719 HloInstructionProto ToProto() const override; 1720 first_index_operand_number()1721 int64 first_index_operand_number() const override { return 1; } 1722 1723 private: 1724 std::vector<string> ExtraAttributesToStringImpl( 1725 const HloPrintOptions& options) const override; 1726 bool IdenticalSlowPath( 1727 const HloInstruction& other, 1728 const std::function<bool(const HloComputation*, const HloComputation*)>& 1729 eq_computations) const override; 1730 // Implementation for non-common logic of CloneWithNewOperands. 1731 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1732 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1733 HloCloneContext* context) const override; 1734 1735 // Describes the [start, start + size) range size for a dynamic slice 1736 // ('start' is specified dynamically in the second operand of the operation). 1737 std::vector<int64> dynamic_slice_sizes_; 1738 }; 1739 1740 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { 1741 public: 1742 explicit HloDynamicUpdateSliceInstruction(const Shape& shape, 1743 HloInstruction* operand, 1744 HloInstruction* update, 1745 HloInstruction* start_indices); 1746 explicit HloDynamicUpdateSliceInstruction( 1747 const Shape& shape, HloInstruction* operand, HloInstruction* update, 1748 absl::Span<HloInstruction* const> start_indices); 1749 first_index_operand_number()1750 int64 first_index_operand_number() const override { return 2; } 1751 }; 1752 1753 class HloGatherInstruction : public HloInstruction { 1754 public: 1755 explicit HloGatherInstruction( 1756 const Shape& shape, HloInstruction* operand, 1757 HloInstruction* start_indices, 1758 const GatherDimensionNumbers& gather_dim_numbers, 1759 absl::Span<const int64> slice_sizes, bool indices_are_sorted); gather_dimension_numbers()1760 const GatherDimensionNumbers& gather_dimension_numbers() const { 1761 CHECK(gather_dimension_numbers_ != nullptr); 1762 return *gather_dimension_numbers_; 1763 } gather_slice_sizes()1764 absl::Span<const int64> gather_slice_sizes() const { 1765 return gather_slice_sizes_; 1766 } indices_are_sorted()1767 bool indices_are_sorted() const { return indices_are_sorted_; } set_indices_are_sorted(bool indices_are_sorted)1768 void set_indices_are_sorted(bool indices_are_sorted) { 1769 indices_are_sorted_ = indices_are_sorted; 1770 } 1771 // Returns a serialized representation of this instruction. 1772 HloInstructionProto ToProto() const override; 1773 1774 // Creates an instance of GatherDimensionNumbers. 1775 static GatherDimensionNumbers MakeGatherDimNumbers( 1776 absl::Span<const int64> offset_dims, 1777 absl::Span<const int64> collapsed_slice_dims, 1778 absl::Span<const int64> start_index_map, int64_t index_vector_dim); 1779 // Returns the dump string of the given gather dimension numbers. 1780 static string GatherDimensionNumbersToString( 1781 const GatherDimensionNumbers& gather_dimension_numbers); 1782 1783 private: 1784 std::vector<string> ExtraAttributesToStringImpl( 1785 const HloPrintOptions& options) const override; 1786 bool IdenticalSlowPath( 1787 const HloInstruction& other, 1788 const std::function<bool(const HloComputation*, const HloComputation*)>& 1789 eq_computations) const override; 1790 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1791 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1792 HloCloneContext* context) const override; 1793 1794 std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; 1795 std::vector<int64> gather_slice_sizes_; 1796 bool indices_are_sorted_; 1797 }; 1798 1799 class HloScatterInstruction : public HloInstruction { 1800 public: 1801 explicit HloScatterInstruction( 1802 const Shape& shape, HloInstruction* operand, 1803 HloInstruction* scatter_indices, HloInstruction* updates, 1804 HloComputation* update_computation, 1805 const ScatterDimensionNumbers& scatter_dim_numbers, 1806 bool indices_are_sorted, bool unique_indices); scatter_dimension_numbers()1807 const ScatterDimensionNumbers& scatter_dimension_numbers() const { 1808 CHECK(scatter_dimension_numbers_ != nullptr); 1809 return *scatter_dimension_numbers_; 1810 } indices_are_sorted()1811 bool indices_are_sorted() const { return indices_are_sorted_; } set_indices_are_sorted(bool indices_are_sorted)1812 void set_indices_are_sorted(bool indices_are_sorted) { 1813 indices_are_sorted_ = indices_are_sorted; 1814 } unique_indices()1815 bool unique_indices() const override { return unique_indices_; } 1816 // Returns a serialized representation of this instruction. 1817 HloInstructionProto ToProto() const override; 1818 1819 // Creates an instance of ScatterDimensionNumbers. 1820 static ScatterDimensionNumbers MakeScatterDimNumbers( 1821 absl::Span<const int64> update_window_dims, 1822 absl::Span<const int64> inserted_window_dims, 1823 absl::Span<const int64> scatter_dims_to_operand_dims, 1824 int64_t index_vector_dim); 1825 // Returns the dump string of the given scatter dimension numbers. 1826 static string ScatterDimensionNumbersToString( 1827 const ScatterDimensionNumbers& scatter_dimension_numbers); 1828 1829 private: 1830 std::vector<string> ExtraAttributesToStringImpl( 1831 const HloPrintOptions& options) const override; 1832 bool IdenticalSlowPath( 1833 const HloInstruction& other, 1834 const std::function<bool(const HloComputation*, const HloComputation*)>& 1835 eq_computations) const override; 1836 // Implementation for non-common logic of CloneWithNewOperands. 1837 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1838 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1839 HloCloneContext* context) const override; 1840 1841 std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; 1842 bool indices_are_sorted_; 1843 bool unique_indices_; 1844 }; 1845 1846 class HloIotaInstruction : public HloInstruction { 1847 public: 1848 explicit HloIotaInstruction(const Shape& shape, int64_t iota_dimension); 1849 // Returns the dimension sizes or numbers associated with this instruction. iota_dimension()1850 int64 iota_dimension() const { return iota_dimension_; } 1851 // Returns a serialized representation of this instruction. 1852 HloInstructionProto ToProto() const override; 1853 1854 private: 1855 std::vector<string> ExtraAttributesToStringImpl( 1856 const HloPrintOptions& options) const override; 1857 bool IdenticalSlowPath( 1858 const HloInstruction& other, 1859 const std::function<bool(const HloComputation*, const HloComputation*)>& 1860 eq_computations) const override; 1861 // Implementation for non-common logic of CloneWithNewOperands. 1862 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1863 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1864 HloCloneContext* context) const override; 1865 1866 const int64 iota_dimension_; 1867 }; 1868 1869 class HloDotInstruction : public HloInstruction { 1870 public: 1871 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch 1872 // dimensions specified in 'dimension_numbers'. 1873 explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, 1874 HloInstruction* rhs, 1875 const DotDimensionNumbers& dimension_numbers, 1876 const PrecisionConfig& precision_config); 1877 1878 // Returns data on the dimension numbers used for a dot operation. dot_dimension_numbers()1879 const DotDimensionNumbers& dot_dimension_numbers() const { 1880 return dot_dimension_numbers_; 1881 } 1882 1883 // Returns the information used to tell the implementation information about 1884 // what sort of precision is requested. The meaning of the field is backend 1885 // specific. At the moment, it is only supported for kConvolution and kDot. 1886 // Transformations on one kDot or kConvolution to another will preserve this 1887 // information. Transformations to other HLOs will not preserve this 1888 // information but it is presumed that the alternate lowering is strictly 1889 // superior. precision_config()1890 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1891 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1892 1893 // Returns a serialized representation of this instruction. 1894 HloInstructionProto ToProto() const override; 1895 1896 private: 1897 std::vector<string> ExtraAttributesToStringImpl( 1898 const HloPrintOptions& options) const override; 1899 bool IdenticalSlowPath( 1900 const HloInstruction& other, 1901 const std::function<bool(const HloComputation*, const HloComputation*)>& 1902 eq_computations) const override; 1903 // Implementation for non-common logic of CloneWithNewOperands. 1904 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1905 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1906 HloCloneContext* context) const override; 1907 // Returns the dump string of the dot dimension numbers. 1908 string DotDimensionNumbersToString() const; 1909 1910 // Describes the dimension numbers used for a dot. 1911 DotDimensionNumbers dot_dimension_numbers_; 1912 1913 // Information used to communicate to the implementation about the algorithm 1914 // used to produce results. See the documentation on precision_config(). 1915 PrecisionConfig precision_config_; 1916 }; 1917 1918 class HloDomainInstruction : public HloInstruction { 1919 public: 1920 explicit HloDomainInstruction( 1921 const Shape& shape, HloInstruction* operand, 1922 std::unique_ptr<DomainMetadata> operand_side_metadata, 1923 std::unique_ptr<DomainMetadata> user_side_metadata); 1924 1925 // Returns a serialized representation of this instruction. 1926 HloInstructionProto ToProto() const override; 1927 1928 // Retrieves the operand side metadata of a kDomain instruction. operand_side_metadata()1929 const DomainMetadata& operand_side_metadata() const { 1930 return *operand_side_metadata_; 1931 } 1932 // Retrieves the user side metadata of a kDomain instruction. user_side_metadata()1933 const DomainMetadata& user_side_metadata() const { 1934 return *user_side_metadata_; 1935 } 1936 1937 private: 1938 std::vector<string> ExtraAttributesToStringImpl( 1939 const HloPrintOptions& options) const override; 1940 bool IdenticalSlowPath( 1941 const HloInstruction& other, 1942 const std::function<bool(const HloComputation*, const HloComputation*)>& 1943 eq_computations) const override; 1944 // Implementation for non-common logic of CloneWithNewOperands. 1945 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1946 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1947 HloCloneContext* context) const override; 1948 1949 std::unique_ptr<DomainMetadata> operand_side_metadata_; 1950 std::unique_ptr<DomainMetadata> user_side_metadata_; 1951 }; 1952 1953 class HloGetDimensionSizeInstruction : public HloInstruction { 1954 public: 1955 explicit HloGetDimensionSizeInstruction(const Shape& shape, 1956 HloInstruction* operand, 1957 int64_t dimension); 1958 1959 // Returns the dimension sizes or numbers associated with this instruction. dimension()1960 int64 dimension() const { return dimension_; } 1961 // Returns a serialized representation of this instruction. 1962 HloInstructionProto ToProto() const override; 1963 1964 private: 1965 std::vector<string> ExtraAttributesToStringImpl( 1966 const HloPrintOptions& options) const override; 1967 bool IdenticalSlowPath( 1968 const HloInstruction& other, 1969 const std::function<bool(const HloComputation*, const HloComputation*)>& 1970 eq_computations) const override; 1971 // Implementation for non-common logic of CloneWithNewOperands. 1972 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1973 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1974 HloCloneContext* context) const override; 1975 1976 int64 dimension_; 1977 }; 1978 1979 class HloSetDimensionSizeInstruction : public HloInstruction { 1980 public: 1981 explicit HloSetDimensionSizeInstruction(const Shape& shape, 1982 HloInstruction* operand, 1983 HloInstruction* val, 1984 int64_t dimension); 1985 1986 // Returns the dimension sizes or numbers associated with this instruction. dimension()1987 int64 dimension() const { return dimension_; } 1988 // Returns a serialized representation of this instruction. 1989 HloInstructionProto ToProto() const override; 1990 1991 private: 1992 std::vector<string> ExtraAttributesToStringImpl( 1993 const HloPrintOptions& options) const override; 1994 bool IdenticalSlowPath( 1995 const HloInstruction& other, 1996 const std::function<bool(const HloComputation*, const HloComputation*)>& 1997 eq_computations) const override; 1998 // Implementation for non-common logic of CloneWithNewOperands. 1999 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 2000 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 2001 HloCloneContext* context) const override; 2002 2003 int64 dimension_; 2004 }; 2005 2006 class HloRngGetAndUpdateStateInstruction : public HloInstruction { 2007 public: 2008 explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, 2009 int64_t delta); 2010 2011 // Returns the delta value. delta()2012 int64 delta() const { return delta_; } set_delta(int64_t delta)2013 void set_delta(int64_t delta) { delta_ = delta; } 2014 // Returns a serialized representation of this instruction. 2015 HloInstructionProto ToProto() const override; 2016 2017 private: 2018 std::vector<string> ExtraAttributesToStringImpl( 2019 const HloPrintOptions& options) const override; 2020 bool IdenticalSlowPath( 2021 const HloInstruction& other, 2022 const std::function<bool(const HloComputation*, const HloComputation*)>& 2023 eq_computations) const override; 2024 // Implementation for non-common logic of CloneWithNewOperands. 2025 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 2026 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 2027 HloCloneContext* context) const override; 2028 2029 int64 delta_; 2030 }; 2031 2032 class HloRngBitGeneratorInstruction : public HloInstruction { 2033 public: 2034 HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state, 2035 RandomAlgorithm algorithm); 2036 algorithm()2037 RandomAlgorithm algorithm() const { return algorithm_; } 2038 HloInstructionProto ToProto() const override; 2039 2040 private: 2041 std::vector<string> ExtraAttributesToStringImpl( 2042 const HloPrintOptions& options) const override; 2043 bool IdenticalSlowPath( 2044 const HloInstruction& other, 2045 const std::function<bool(const HloComputation*, const HloComputation*)>& 2046 eq_computations) const override; 2047 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 2048 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 2049 HloCloneContext* context) const override; 2050 2051 RandomAlgorithm algorithm_; 2052 }; 2053 2054 } // namespace xla 2055 2056 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 2057