1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // All HloInstruction subclasses are put in this file. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/shape.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 26 namespace xla { 27 28 class HloBatchNormInstruction : public HloInstruction { 29 public: 30 // Returns feature_index field associated with the instruction. The index 31 // represents the index of the feature dimension. feature_index()32 int64 feature_index() const { return feature_index_; } 33 34 // Returns a epsilon value associated with the instruction. The is a small 35 // number added to the variance to avoid divide-by-zero error. epsilon()36 float epsilon() const { return epsilon_; } 37 38 // Returns a serialized representation of this instruction. 39 HloInstructionProto ToProto() const override; 40 41 protected: 42 explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, 43 HloInstruction* operand, 44 HloInstruction* scale, float epsilon, 45 int64 feature_index); 46 47 private: 48 std::vector<string> ExtraAttributesToStringImpl( 49 const HloPrintOptions& options) const override; 50 bool IdenticalSlowPath( 51 const HloInstruction& other, 52 const std::function<bool(const HloComputation*, const HloComputation*)>& 53 eq_computations) const override; 54 // A small float number added to the variance to avoid divide-by-zero error. 55 float epsilon_ = 0.0f; 56 57 // An integer value representing the index of the feature dimension. 58 int64 feature_index_ = -1; 59 }; 60 61 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { 62 public: 63 explicit HloBatchNormTrainingInstruction(const Shape& shape, 64 HloInstruction* operand, 65 HloInstruction* scale, 66 HloInstruction* offset, 67 float epsilon, int64 feature_index); 68 69 private: 70 // Implementation for non-common logic of CloneWithNewOperands. 71 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 72 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 73 HloCloneContext* context) const override; 74 }; 75 76 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { 77 public: 78 explicit HloBatchNormInferenceInstruction( 79 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 80 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, 81 float epsilon, int64 feature_index); 82 83 private: 84 // Implementation for non-common logic of CloneWithNewOperands. 85 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 86 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 87 HloCloneContext* context) const override; 88 }; 89 90 class HloBatchNormGradInstruction : public HloBatchNormInstruction { 91 public: 92 explicit HloBatchNormGradInstruction( 93 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 94 HloInstruction* mean, HloInstruction* variance, 95 HloInstruction* grad_output, float epsilon, int64 feature_index); 96 97 private: 98 // Implementation for non-common logic of CloneWithNewOperands. 99 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 100 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 101 HloCloneContext* context) const override; 102 }; 103 104 class HloFftInstruction : public HloInstruction { 105 public: 106 explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, 107 FftType fft_type, 108 absl::Span<const int64> fft_length); fft_type()109 FftType fft_type() const { return fft_type_; } 110 fft_length()111 const std::vector<int64>& fft_length() const { return fft_length_; } 112 113 // Returns a serialized representation of this instruction. 114 HloInstructionProto ToProto() const override; 115 116 private: 117 std::vector<string> ExtraAttributesToStringImpl( 118 const HloPrintOptions& options) const override; 119 bool IdenticalSlowPath( 120 const HloInstruction& other, 121 const std::function<bool(const HloComputation*, const HloComputation*)>& 122 eq_computations) const override; 123 124 // Implementation for non-common logic of CloneWithNewOperands. 125 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 126 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 127 HloCloneContext* context) const override; 128 129 // Describes FFT type for an FFT instruction. 130 FftType fft_type_ = FftType::FFT; 131 132 // Indicates the FFT length for an FFT instruction. 133 std::vector<int64> fft_length_; 134 }; 135 136 class HloCopyStartInstruction : public HloInstruction { 137 public: 138 explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand, 139 bool is_cross_program_prefetch); 140 is_cross_program_prefetch()141 bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; } 142 HloInstructionProto ToProto() const override; 143 144 private: 145 std::vector<string> ExtraAttributesToStringImpl( 146 const HloPrintOptions& options) const override; 147 bool IdenticalSlowPath( 148 const HloInstruction& other, 149 const std::function<bool(const HloComputation*, const HloComputation*)>& 150 eq_computations) const override; 151 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 152 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 153 HloCloneContext* context) const override; 154 155 bool is_cross_program_prefetch_; 156 }; 157 158 class HloCompareInstruction : public HloInstruction { 159 public: 160 explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, 161 HloInstruction* rhs, 162 ComparisonDirection direction, 163 absl::optional<Comparison::Type> type); direction()164 ComparisonDirection direction() const { return compare_.GetDirection(); } type()165 Comparison::Type type() const { return compare_.GetType(); } 166 HloInstructionProto ToProto() const override; 167 168 private: 169 std::vector<string> ExtraAttributesToStringImpl( 170 const HloPrintOptions& options) const override; 171 bool IdenticalSlowPath( 172 const HloInstruction& other, 173 const std::function<bool(const HloComputation*, const HloComputation*)>& 174 eq_computations) const override; 175 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 176 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 177 HloCloneContext* context) const override; 178 179 Comparison compare_; 180 }; 181 182 class HloTriangularSolveInstruction : public HloInstruction { 183 public: 184 explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, 185 HloInstruction* b, 186 const TriangularSolveOptions& options); triangular_solve_options()187 const TriangularSolveOptions& triangular_solve_options() const { 188 return triangular_solve_options_; 189 } 190 191 // Returns a serialized representation of this instruction. 192 HloInstructionProto ToProto() const override; 193 194 private: 195 std::vector<string> ExtraAttributesToStringImpl( 196 const HloPrintOptions& options) const override; 197 bool IdenticalSlowPath( 198 const HloInstruction& other, 199 const std::function<bool(const HloComputation*, const HloComputation*)>& 200 eq_computations) const override; 201 202 // Implementation for non-common logic of CloneWithNewOperands. 203 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 204 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 205 HloCloneContext* context) const override; 206 207 TriangularSolveOptions triangular_solve_options_; 208 }; 209 210 class HloCholeskyInstruction : public HloInstruction { 211 public: 212 explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, 213 const CholeskyOptions& options); cholesky_options()214 const CholeskyOptions& cholesky_options() const { return cholesky_options_; } 215 216 // Returns a serialized representation of this instruction. 217 HloInstructionProto ToProto() const override; 218 219 private: 220 std::vector<string> ExtraAttributesToStringImpl( 221 const HloPrintOptions& options) const override; 222 bool IdenticalSlowPath( 223 const HloInstruction& other, 224 const std::function<bool(const HloComputation*, const HloComputation*)>& 225 eq_computations) const override; 226 227 // Implementation for non-common logic of CloneWithNewOperands. 228 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 229 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 230 HloCloneContext* context) const override; 231 232 CholeskyOptions cholesky_options_; 233 }; 234 235 // Class that represents instructions that synchronize and transfer data between 236 // partitioned devices. Send/Recv and collective instructions (AllReduce, 237 // AllToAll, CollectivePermute) belong to this instruction type. A group of 238 // instructions (of the same opcode) with the same channel_id communicate during 239 // execution. 240 class HloChannelInstruction : public HloInstruction { 241 public: 242 // Returns the channel id associated with the instruction. The id is 243 // shared between each Send/Recv pair or a group of collective instructions 244 // and is globally unique to identify each channel. channel_id()245 absl::optional<int64> channel_id() const { return channel_id_; } 246 void set_channel_id(const absl::optional<int64>& channel_id); 247 248 // Whether this instruction is identical to `other` except for the values of 249 // channel IDs, as long as both have channel IDs or neither has a channel ID. IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations)250 virtual bool IdenticalSlowPathIgnoringChannelIdValues( 251 const HloInstruction& other, 252 const std::function<bool(const HloComputation*, const HloComputation*)>& 253 eq_computations) const { 254 return channel_id_.has_value() == other.channel_id().has_value(); 255 } 256 257 protected: 258 explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, 259 const absl::optional<int64>& channel_id); 260 261 HloInstructionProto ToProto() const override; 262 263 std::vector<string> ExtraAttributesToStringImpl( 264 const HloPrintOptions& options) const override; 265 266 // Do not override IdenticalSlowPath(). Override 267 // IdenticalSlowPathIgnoringChannelIdValues() instead. 268 bool IdenticalSlowPath( 269 const HloInstruction& other, 270 const std::function<bool(const HloComputation*, const HloComputation*)>& 271 eq_computations) const final; 272 273 absl::optional<int64> channel_id_; 274 }; 275 276 class HloSendRecvInstruction : public HloChannelInstruction { 277 public: 278 // Returns whether this send/recv instruction sends data to/from the host. is_host_transfer()279 bool is_host_transfer() const { return is_host_transfer_; } 280 281 // Returns a serialized representation of this instruction. 282 HloInstructionProto ToProto() const override; 283 284 protected: 285 explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, 286 int64 channel_id, bool is_host_transfer); 287 288 private: 289 std::vector<string> ExtraAttributesToStringImpl( 290 const HloPrintOptions& options) const override; 291 bool IdenticalSlowPathIgnoringChannelIdValues( 292 const HloInstruction& other, 293 const std::function<bool(const HloComputation*, const HloComputation*)>& 294 eq_computations) const override; 295 // Whether this send/recv instruction sends data to/from the host. 296 bool is_host_transfer_; 297 }; 298 299 class HloSendInstruction : public HloSendRecvInstruction { 300 public: 301 explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, 302 int64 channel_id, bool is_host_transfer); 303 304 private: 305 // Implementation for non-common logic of CloneWithNewOperands. 306 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 307 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 308 HloCloneContext* context) const override; 309 }; 310 311 class HloSendDoneInstruction : public HloSendRecvInstruction { 312 public: 313 explicit HloSendDoneInstruction(HloSendInstruction* operand, 314 bool is_host_transfer); 315 316 private: 317 // Implementation for non-common logic of CloneWithNewOperands. 318 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 319 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 320 HloCloneContext* context) const override; 321 }; 322 323 class HloRecvInstruction : public HloSendRecvInstruction { 324 public: 325 explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, 326 int64 channel_id, bool is_host_transfer); 327 328 private: 329 // Implementation for non-common logic of CloneWithNewOperands. 330 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 331 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 332 HloCloneContext* context) const override; 333 }; 334 335 class HloRecvDoneInstruction : public HloSendRecvInstruction { 336 public: 337 explicit HloRecvDoneInstruction(HloRecvInstruction* operand, 338 bool is_host_transfer); 339 340 private: 341 // Implementation for non-common logic of CloneWithNewOperands. 342 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 343 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 344 HloCloneContext* context) const override; 345 }; 346 347 class HloCollectiveInstruction : public HloChannelInstruction { 348 public: replica_groups()349 const std::vector<ReplicaGroup>& replica_groups() const { 350 return replica_groups_; 351 } 352 353 // Returns true if the layout of the AllReduce is enforced by XLA client (as 354 // the layout set in the shape). The only reason for the client to set the 355 // layout is to separately compile computations that communicate with 356 // AllReduce. Since this field is only set `true` by the client, the compiler 357 // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set 358 // `false` for all other cases. 359 // 360 // When this is `true`, there may be communication endpoints outside the 361 // current compilation unit, so the compiler considers this AllReduce as 362 // side-effecting to disable compiler transformations. The compiler is free to 363 // transform unconstrained AllReduces differently across compilation units. 364 // It is an error for an HloModule to have a mix of constrained and 365 // unconstrained AllReduce instructions (checked by HloVerifier). constrain_layout()366 bool constrain_layout() const { return constrain_layout_; } 367 368 protected: 369 explicit HloCollectiveInstruction( 370 HloOpcode opcode, const Shape& shape, 371 absl::Span<HloInstruction* const> operands, 372 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, 373 const absl::optional<int64>& channel_id); 374 375 HloInstructionProto ToProto() const override; 376 377 std::vector<string> ExtraAttributesToStringImpl( 378 const HloPrintOptions& options) const override; 379 bool IdenticalSlowPathIgnoringChannelIdValues( 380 const HloInstruction& other, 381 const std::function<bool(const HloComputation*, const HloComputation*)>& 382 eq_computations) const override; 383 384 std::vector<ReplicaGroup> replica_groups_; 385 bool constrain_layout_; 386 }; 387 388 class HloAllGatherInstruction : public HloCollectiveInstruction { 389 public: 390 explicit HloAllGatherInstruction( 391 const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, 392 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, 393 const absl::optional<int64>& channel_id, bool use_global_device_ids); 394 // Same as HloAllReduceInstruction::use_global_device_ids. use_global_device_ids()395 bool use_global_device_ids() const { return use_global_device_ids_; } 396 397 // The dimension on which data from different participants are concatenated. all_gather_dimension()398 int64 all_gather_dimension() const { return all_gather_dimension_; } 399 400 protected: 401 std::vector<string> ExtraAttributesToStringImpl( 402 const HloPrintOptions& options) const override; 403 HloInstructionProto ToProto() const override; 404 405 private: 406 bool IdenticalSlowPathIgnoringChannelIdValues( 407 const HloInstruction& other, 408 const std::function<bool(const HloComputation*, const HloComputation*)>& 409 eq_computations) const override; 410 411 // Implementation for non-common logic of CloneWithNewOperands. 412 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 413 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 414 HloCloneContext* context) const override; 415 416 int64 all_gather_dimension_; 417 bool use_global_device_ids_; 418 }; 419 420 class HloAllReduceInstruction : public HloCollectiveInstruction { 421 public: 422 explicit HloAllReduceInstruction( 423 const Shape& shape, absl::Span<HloInstruction* const> operands, 424 HloComputation* reduce_computation, 425 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, 426 const absl::optional<int64>& channel_id, bool use_global_device_ids); 427 428 // Returns true if the AllReduce does no communication, so it's equivalent 429 // to a mem copy. 430 bool IsNoop() const; 431 432 // Returns true if the ids in the ReplicaGroup config represent a global id of 433 // (replica_id * partition_count + partition_id) instead of a replica id. 434 // This enables more flexible grouping of devices if this all-reduce is both 435 // cross-partition and cross-replica. 436 // 437 // For example with 2 replicas and 4 partitions, 438 // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that 439 // group[0] = (0,0), (0,1), (1,0), (1,1) 440 // group[1] = (0,2), (0,3), (1,2), (1,3) 441 // where each pair is (replica_id, partition_id). use_global_device_ids()442 bool use_global_device_ids() const { return use_global_device_ids_; } 443 444 protected: 445 std::vector<string> ExtraAttributesToStringImpl( 446 const HloPrintOptions& options) const override; 447 HloInstructionProto ToProto() const override; 448 449 private: 450 bool IdenticalSlowPathIgnoringChannelIdValues( 451 const HloInstruction& other, 452 const std::function<bool(const HloComputation*, const HloComputation*)>& 453 eq_computations) const override; 454 455 // Implementation for non-common logic of CloneWithNewOperands. 456 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 457 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 458 HloCloneContext* context) const override; 459 460 bool use_global_device_ids_; 461 }; 462 463 class HloAllToAllInstruction : public HloCollectiveInstruction { 464 public: 465 explicit HloAllToAllInstruction( 466 const Shape& shape, absl::Span<HloInstruction* const> operands, 467 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, 468 const absl::optional<int64>& channel_id, 469 const absl::optional<int64>& split_dimension); 470 471 // AllToAll can optionally take a split dimension, which means that this 472 // AllToAll takes a single (flattened) array operand and produces an array 473 // output (instead of taking a list of operands and producing a tuple). 474 // 475 // split_dimension specifies which dimension in the operand is split across 476 // devices in each replica_group, and also means the concatenated dimension 477 // on the output (i.e., input and the output shapes are the same). split_dimension()478 absl::optional<int64> split_dimension() const { return split_dimension_; } set_split_dimension(int64 dim)479 void set_split_dimension(int64 dim) { split_dimension_ = dim; } 480 481 protected: 482 std::vector<string> ExtraAttributesToStringImpl( 483 const HloPrintOptions& options) const override; 484 HloInstructionProto ToProto() const override; 485 486 private: 487 bool IdenticalSlowPathIgnoringChannelIdValues( 488 const HloInstruction& other, 489 const std::function<bool(const HloComputation*, const HloComputation*)>& 490 eq_computations) const override; 491 492 // Implementation for non-common logic of CloneWithNewOperands. 493 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 494 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 495 HloCloneContext* context) const override; 496 497 absl::optional<int64> split_dimension_; 498 }; 499 500 class HloCollectivePermuteInstruction : public HloChannelInstruction { 501 public: 502 explicit HloCollectivePermuteInstruction( 503 HloOpcode opcode, const Shape& shape, HloInstruction* operand, 504 const std::vector<std::pair<int64, int64>>& source_target_pairs, 505 const absl::optional<int64>& channel_id); 506 source_target_pairs()507 const std::vector<std::pair<int64, int64>>& source_target_pairs() const { 508 return source_target_pairs_; 509 } 510 511 // Returns a serialized representation of this instruction. 512 HloInstructionProto ToProto() const override; 513 514 private: 515 std::vector<string> ExtraAttributesToStringImpl( 516 const HloPrintOptions& options) const override; 517 bool IdenticalSlowPathIgnoringChannelIdValues( 518 const HloInstruction& other, 519 const std::function<bool(const HloComputation*, const HloComputation*)>& 520 eq_computations) const override; 521 522 // Implementation for non-common logic of CloneWithNewOperands. 523 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 524 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 525 HloCloneContext* context) const override; 526 527 const std::vector<std::pair<int64, int64>> source_target_pairs_; 528 }; 529 530 class HloReverseInstruction : public HloInstruction { 531 public: 532 explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, 533 absl::Span<const int64> dimensions); 534 // Returns the dimension sizes or numbers associated with this instruction. dimensions()535 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)536 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()537 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 538 // Returns a serialized representation of this instruction. 539 HloInstructionProto ToProto() const override; 540 541 private: 542 std::vector<string> ExtraAttributesToStringImpl( 543 const HloPrintOptions& options) const override; 544 bool IdenticalSlowPath( 545 const HloInstruction& other, 546 const std::function<bool(const HloComputation*, const HloComputation*)>& 547 eq_computations) const override; 548 // Implementation for non-common logic of CloneWithNewOperands. 549 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 550 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 551 HloCloneContext* context) const override; 552 553 std::vector<int64> dimensions_; 554 }; 555 556 class HloConcatenateInstruction : public HloInstruction { 557 public: 558 explicit HloConcatenateInstruction(const Shape& shape, 559 absl::Span<HloInstruction* const> operands, 560 int64 dimension); 561 // Returns the dimension sizes or numbers associated with this instruction. dimensions()562 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)563 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()564 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 565 // Accessor for the dimension in which a concatenate HLO should occur. concatenate_dimension()566 int64 concatenate_dimension() const { return dimensions(0); } 567 // Returns a serialized representation of this instruction. 568 HloInstructionProto ToProto() const override; 569 570 private: 571 std::vector<string> ExtraAttributesToStringImpl( 572 const HloPrintOptions& options) const override; 573 bool IdenticalSlowPath( 574 const HloInstruction& other, 575 const std::function<bool(const HloComputation*, const HloComputation*)>& 576 eq_computations) const override; 577 // Implementation for non-common logic of CloneWithNewOperands. 578 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 579 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 580 HloCloneContext* context) const override; 581 582 std::vector<int64> dimensions_; 583 }; 584 585 class HloReduceInstruction : public HloInstruction { 586 public: 587 explicit HloReduceInstruction(const Shape& shape, 588 absl::Span<HloInstruction* const> args, 589 absl::Span<const int64> dimensions_to_reduce, 590 HloComputation* reduce_computation); 591 // Returns the dimension sizes or numbers associated with this instruction. dimensions()592 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)593 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()594 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 595 // Returns a serialized representation of this instruction. 596 HloInstructionProto ToProto() const override; 597 598 // Returns the number of input arrays (and, consequentially, the number of 599 // init values) this reduce has. input_count()600 int64 input_count() const { return operand_count() / 2; } 601 602 // Returns the input tensors to be reduced. inputs()603 absl::Span<HloInstruction* const> inputs() const { 604 return absl::MakeSpan(operands()).subspan(0, input_count()); 605 } 606 607 // Returns the init values of the reduction. init_values()608 absl::Span<HloInstruction* const> init_values() const { 609 return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); 610 } 611 612 private: 613 std::vector<string> ExtraAttributesToStringImpl( 614 const HloPrintOptions& options) const override; 615 bool IdenticalSlowPath( 616 const HloInstruction& other, 617 const std::function<bool(const HloComputation*, const HloComputation*)>& 618 eq_computations) const override; 619 // Implementation for non-common logic of CloneWithNewOperands. 620 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 621 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 622 HloCloneContext* context) const override; 623 624 std::vector<int64> dimensions_; 625 }; 626 627 class HloSortInstruction : public HloInstruction { 628 public: 629 explicit HloSortInstruction(const Shape& shape, int64 dimension, 630 absl::Span<HloInstruction* const> operands, 631 HloComputation* compare, bool is_stable); 632 // Returns the dimension sizes or numbers associated with this instruction. dimensions()633 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)634 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()635 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 636 // Returns the sort dimension for this instruction sort_dimension()637 int64 sort_dimension() const { return dimensions(0); } 638 // Returns a serialized representation of this instruction. 639 HloInstructionProto ToProto() const override; 640 // Returns the key operand to this instruction. keys()641 const HloInstruction* keys() const { return operand(0); } mutable_keys()642 HloInstruction* mutable_keys() { return mutable_operand(0); } 643 // Returns the number of value operands. values_count()644 int64 values_count() const { return operand_count() - 1; } is_stable()645 bool is_stable() const { return is_stable_; } 646 647 private: 648 std::vector<string> ExtraAttributesToStringImpl( 649 const HloPrintOptions& options) const override; 650 bool IdenticalSlowPath( 651 const HloInstruction& other, 652 const std::function<bool(const HloComputation*, const HloComputation*)>& 653 eq_computations) const override; 654 // Implementation for non-common logic of CloneWithNewOperands. 655 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 656 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 657 HloCloneContext* context) const override; 658 659 std::vector<int64> dimensions_; 660 bool is_stable_; 661 }; 662 663 class HloTransposeInstruction : public HloInstruction { 664 public: 665 explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, 666 absl::Span<const int64> dimensions); 667 // Returns the dimension sizes or numbers associated with this instruction. dimensions()668 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)669 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()670 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 671 // Returns whether this instruction does a rank-2 transposition. 672 bool IsRank2Transpose() const; 673 // Returns a serialized representation of this instruction. 674 HloInstructionProto ToProto() const override; 675 676 private: 677 std::vector<string> ExtraAttributesToStringImpl( 678 const HloPrintOptions& options) const override; 679 bool IdenticalSlowPath( 680 const HloInstruction& other, 681 const std::function<bool(const HloComputation*, const HloComputation*)>& 682 eq_computations) const override; 683 // Implementation for non-common logic of CloneWithNewOperands. 684 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 685 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 686 HloCloneContext* context) const override; 687 688 std::vector<int64> dimensions_; 689 }; 690 691 class HloBroadcastInstruction : public HloInstruction { 692 public: 693 explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, 694 absl::Span<const int64> broadcast_dimension); 695 // Returns the dimension sizes or numbers associated with this instruction. dimensions()696 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)697 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()698 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 699 // Returns a serialized representation of this instruction. 700 HloInstructionProto ToProto() const override; 701 702 private: 703 std::vector<string> ExtraAttributesToStringImpl( 704 const HloPrintOptions& options) const override; 705 bool IdenticalSlowPath( 706 const HloInstruction& other, 707 const std::function<bool(const HloComputation*, const HloComputation*)>& 708 eq_computations) const override; 709 // Implementation for non-common logic of CloneWithNewOperands. 710 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 711 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 712 HloCloneContext* context) const override; 713 714 std::vector<int64> dimensions_; 715 }; 716 717 class HloDynamicReshapeInstruction : public HloInstruction { 718 public: 719 explicit HloDynamicReshapeInstruction( 720 const Shape& shape, HloInstruction* data_operand, 721 absl::Span<HloInstruction* const> dim_sizes); 722 723 // Returns the input dim sizes dimensions, which is operands[1:] dim_sizes()724 absl::Span<HloInstruction* const> dim_sizes() const { 725 return absl::MakeSpan(operands()).subspan(1, operand_count()); 726 } 727 728 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 729 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 730 HloCloneContext* context) const override; 731 732 // Returns the input dim size dimension, which is operands[1+i] dim_sizes(int64 i)733 HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; } 734 }; 735 736 class HloReshapeInstruction : public HloInstruction { 737 public: 738 explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand, 739 int64 inferred_dimension); inferred_dimension()740 int64 inferred_dimension() const { return inferred_dimension_; } 741 HloInstructionProto ToProto() const override; 742 743 private: 744 std::vector<string> ExtraAttributesToStringImpl( 745 const HloPrintOptions& options) const override; 746 bool IdenticalSlowPath( 747 const HloInstruction& other, 748 const std::function<bool(const HloComputation*, const HloComputation*)>& 749 eq_computations) const override; 750 // Implementation for non-common logic of CloneWithNewOperands. 751 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 752 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 753 HloCloneContext* context) const override; 754 int64 inferred_dimension_; 755 }; 756 757 class HloMapInstruction : public HloInstruction { 758 public: 759 explicit HloMapInstruction(const Shape& shape, 760 absl::Span<HloInstruction* const> operands, 761 HloComputation* map_computation); 762 // Returns the dimension sizes or numbers associated with this instruction. dimensions()763 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)764 int64 dimensions(int64 index) const override { return dimensions()[index]; } mutable_dimensions()765 std::vector<int64>* mutable_dimensions() override { return &dimensions_; } 766 // Returns a serialized representation of this instruction. 767 HloInstructionProto ToProto() const override; 768 769 private: 770 bool IsElementwiseImpl( 771 const absl::optional<int64>& operand_idx) const override; 772 std::vector<string> ExtraAttributesToStringImpl( 773 const HloPrintOptions& options) const override; 774 bool IdenticalSlowPath( 775 const HloInstruction& other, 776 const std::function<bool(const HloComputation*, const HloComputation*)>& 777 eq_computations) const override; 778 // Implementation for non-common logic of CloneWithNewOperands. 779 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 780 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 781 HloCloneContext* context) const override; 782 783 std::vector<int64> dimensions_; 784 }; 785 786 class HloSliceInstruction : public HloInstruction { 787 public: 788 explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, 789 absl::Span<const int64> start_indices, 790 absl::Span<const int64> limit_indices, 791 absl::Span<const int64> strides); 792 793 HloInstructionProto ToProto() const override; 794 795 // Returns the start index in the given dimension for a slice node. slice_starts(int64 dimension)796 int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } slice_starts()797 const std::vector<int64>& slice_starts() const { return slice_starts_; } mutable_slice_starts()798 std::vector<int64>* mutable_slice_starts() { return &slice_starts_; } 799 800 // Returns the (exclusive) limit index in the given dimension for a slice 801 // node. slice_limits(int64 dimension)802 int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } slice_limits()803 const std::vector<int64>& slice_limits() const { return slice_limits_; } mutable_slice_limits()804 std::vector<int64>* mutable_slice_limits() { return &slice_limits_; } 805 806 // Returns the stride in the given dimension for a slice node. slice_strides(int64 dimension)807 int64 slice_strides(int64 dimension) const { 808 return slice_strides_[dimension]; 809 } slice_strides()810 const std::vector<int64>& slice_strides() const { return slice_strides_; } mutable_slice_strides()811 std::vector<int64>* mutable_slice_strides() { return &slice_strides_; } 812 813 private: 814 std::vector<string> ExtraAttributesToStringImpl( 815 const HloPrintOptions& options) const override; 816 bool IdenticalSlowPath( 817 const HloInstruction& other, 818 const std::function<bool(const HloComputation*, const HloComputation*)>& 819 eq_computations) const override; 820 // Implementation for non-common logic of CloneWithNewOperands. 821 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 822 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 823 HloCloneContext* context) const override; 824 825 // Describes the [begin, end) index range for a slice. 826 std::vector<int64> slice_starts_; 827 std::vector<int64> slice_limits_; 828 std::vector<int64> slice_strides_; 829 }; 830 831 class HloConstantInstruction : public HloInstruction { 832 public: 833 explicit HloConstantInstruction(Literal literal); 834 explicit HloConstantInstruction(Literal literal, const Shape& shape); 835 // Used when the literal is too large and dropped. 836 explicit HloConstantInstruction(const Shape& shape); 837 // Returns the literal associated with this instruction. literal()838 const Literal& literal() const { return *literal_; } 839 // Returns the (mutable) literal associated with this instruction. mutable_literal()840 Literal* mutable_literal() { return &literal_.value(); } 841 // Returns whether there is literal associated with this instruction. HasLiteral()842 bool HasLiteral() const { return literal_.has_value(); } 843 // Returns a serialized representation of this instruction. 844 HloInstructionProto ToProto() const override; 845 846 // Change the layout for an Constant Hlo instruction to match new_layout. For 847 // tuple shaped constants shape_index is the path to the internal array 848 // subshape whose layout needs to be changed. 849 void RelayoutConstant(const Layout& new_layout, 850 const ShapeIndex& shape_index = {}); 851 852 private: 853 bool IsElementwiseImpl( 854 const absl::optional<int64>& operand_idx) const override; 855 bool IdenticalSlowPath( 856 const HloInstruction& other, 857 const std::function<bool(const HloComputation*, const HloComputation*)>& 858 eq_computations) const override; 859 string OperandsToStringWithCanonicalNameMap( 860 const HloPrintOptions& options, 861 CanonicalNameMap* canonical_name_map) const override; 862 // Implementation for non-common logic of CloneWithNewOperands. 863 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 864 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 865 HloCloneContext* context) const override; 866 absl::optional<Literal> literal_; 867 }; 868 869 class HloTraceInstruction : public HloInstruction { 870 public: 871 explicit HloTraceInstruction(const string& tag, HloInstruction* operand); 872 // Returns a tag to be used in tracing. TracingTag()873 string TracingTag() const { return literal_.GetR1U8AsString(); } 874 // Returns a serialized representation of this instruction. 875 HloInstructionProto ToProto() const override; 876 877 private: 878 bool IdenticalSlowPath( 879 const HloInstruction& other, 880 const std::function<bool(const HloComputation*, const HloComputation*)>& 881 eq_computations) const override; 882 // Implementation for non-common logic of CloneWithNewOperands. 883 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 884 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 885 HloCloneContext* context) const override; 886 Literal literal_; 887 }; 888 889 class HloFusionInstruction : public HloInstruction { 890 public: 891 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 892 HloInstruction* fused_root); 893 894 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 895 absl::Span<HloInstruction* const> operands, 896 HloComputation* fusion_computation); 897 898 string ToCategory() const override; 899 // Returns a serialized representation of this instruction. 900 HloInstructionProto ToProto() const override; 901 902 // Adds a new operand the fusion instruction. 903 HloInstruction* AddFusionOperand(HloInstruction* new_operand); 904 905 // Merges the fused instructions from 'instruction_to_merge' into the 906 // fused instruction set of 'this', updating operands as necessary. 907 // 908 // Precondition: 'instruction_to_merge' must be an operand of 'this'. 909 void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); 910 911 // Merges the fused instructions from instruction_to_merge into the fused 912 // instruction set of 'this' and generates multioutput fusion instructions. 913 // All the users of instruction_to_merge will be redirected to 'this' 914 // instruction. instruction_to_merge will be removed from its parent 915 // computation. 916 void MergeFusionInstructionIntoMultiOutput( 917 HloFusionInstruction* instruction_to_merge); 918 919 // Fuses the given instruction in this fusion instruction. instruction_to_fuse 920 // is cloned and the clone is placed in the fusion 921 // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather 922 // than moved to cleanly handle the case where the instruction has a use 923 // outside the fusion instruction. Moving such an instruction into a fusion 924 // instruction would violate the single-result invariant of HLO instructions 925 // and significantly complicate code generation. FuseInstruction(HloInstruction * instruction_to_fuse)926 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { 927 return FuseInstructionInternal(instruction_to_fuse); 928 } 929 930 // Fuses the given instruction in this fusion instruction and generates a 931 // multioutput fusion instruction. A clone of the instruction_to_fuse will 932 // be part of the output of fusion instructions. The users of 933 // instruction_to_fuse will be redirected to this fusion instructions. 934 // instruction_to_fuse is unchanged otherwise. FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)935 HloInstruction* FuseInstructionIntoMultiOutput( 936 HloInstruction* instruction_to_fuse) { 937 return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); 938 } 939 940 // Returns the computation for this fused instruction. 941 HloComputation* fused_instructions_computation() const; 942 943 // Returns the root instruction of the fused expression contained within this 944 // fusion instruction. 945 HloInstruction* fused_expression_root() const; 946 947 // Returns the list of fused instructions inside this fusion instruction. The 948 // returned type is a range of HloInstruction*s. 949 const tensorflow::gtl::iterator_range<UnwrappingIterator< 950 std::list<std::unique_ptr<HloInstruction>>::const_iterator>> 951 fused_instructions() const; 952 953 const tensorflow::gtl::iterator_range< 954 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> 955 fused_instructions(); 956 957 // Gets the number of instructions inside this fusion instruction. 958 int64 fused_instruction_count() const; 959 960 // Returns the fused parameter instruction in this fusion instruction 961 // corresponding to the given parameter number. 962 HloInstruction* fused_parameter(int64 parameter_number) const; 963 964 // Returns the vector of fused parameters inside this fusion instruction. 965 const std::vector<HloInstruction*>& fused_parameters() const; 966 967 // Returns true if this instruction is a fusion instruction that generates 968 // multiple outputs. IsMultiOutputFusion()969 const bool IsMultiOutputFusion() const { 970 return fused_expression_root()->opcode() == HloOpcode::kTuple; 971 } 972 fusion_kind()973 FusionKind fusion_kind() const { return fusion_kind_; } 974 set_fusion_kind(FusionKind kind)975 void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } 976 977 // If multiple operands are the same instruction, keeps only one of them. 978 Status DeduplicateFusionOperands(); 979 980 private: 981 // Fuses the given instruction into this fusion instruction. 982 // instruction_to_fuse is cloned and the clone is placed in the fusion 983 // instruction. The users of instruction_to_fuse will be redirected to this 984 // fusion instruction. instruction_to_fuse is unchanged otherwise. When 985 // add_output is true, a clone of the instruction_to_fuse will be added as 986 // additional output resulting in a multi-output fusion. 987 HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, 988 bool add_output = false); 989 // Clones the given instruction_to_fuse and insert the clone into this fusion 990 // instruction. If add_output is true, a clone of instruction_to_fuse will 991 // be in the output of the this fusion instruction (part of the tuple of the 992 // fusion root). 993 HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, 994 bool add_output = false); 995 996 bool IsElementwiseImpl( 997 const absl::optional<int64>& operand_idx) const override; 998 std::vector<string> ExtraAttributesToStringImpl( 999 const HloPrintOptions& options) const override; 1000 bool IdenticalSlowPath( 1001 const HloInstruction& other, 1002 const std::function<bool(const HloComputation*, const HloComputation*)>& 1003 eq_computations) const override; 1004 uint64 InnerHash() const override; 1005 1006 // Implementation for non-common logic of CloneWithNewOperands. 1007 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1008 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1009 HloCloneContext* context) const override; 1010 1011 // The type of the fusion. Used by kFusion only. 1012 FusionKind fusion_kind_; 1013 }; 1014 1015 class HloRngInstruction : public HloInstruction { 1016 public: 1017 explicit HloRngInstruction(const Shape& shape, 1018 RandomDistribution distribution, 1019 absl::Span<HloInstruction* const> parameters); 1020 // Returns the random distribution for this rng node. random_distribution()1021 RandomDistribution random_distribution() const { return distribution_; } 1022 // Returns a serialized representation of this instruction. 1023 HloInstructionProto ToProto() const override; 1024 1025 private: 1026 bool IsElementwiseImpl( 1027 const absl::optional<int64>& operand_idx) const override; 1028 std::vector<string> ExtraAttributesToStringImpl( 1029 const HloPrintOptions& options) const override; 1030 bool IdenticalSlowPath( 1031 const HloInstruction& other, 1032 const std::function<bool(const HloComputation*, const HloComputation*)>& 1033 eq_computations) const override; 1034 // Implementation for non-common logic of CloneWithNewOperands. 1035 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1036 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1037 HloCloneContext* context) const override; 1038 1039 // The distribution requested for random number generation. 1040 RandomDistribution distribution_; 1041 }; 1042 1043 class HloParameterInstruction : public HloInstruction { 1044 public: 1045 explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, 1046 const string& name); parameter_number()1047 int64 parameter_number() const { return parameter_number_; } 1048 1049 // Sets and gets the whether all replicas will receive the same parameter data 1050 // for each leaf buffer in data parallelism. set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)1051 void set_parameter_replicated_at_leaf_buffers( 1052 absl::Span<const bool> parameter_replicated_at_leaf_buffers) { 1053 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 1054 parameter_replicated_at_leaf_buffers.size()); 1055 parameter_replicated_at_leaf_buffers_.emplace( 1056 parameter_replicated_at_leaf_buffers.begin(), 1057 parameter_replicated_at_leaf_buffers.end()); 1058 } set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)1059 void set_parameter_replicated_at_leaf_buffers( 1060 const std::vector<bool>& parameter_replicated_at_leaf_buffers) { 1061 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 1062 parameter_replicated_at_leaf_buffers.size()); 1063 parameter_replicated_at_leaf_buffers_ = 1064 parameter_replicated_at_leaf_buffers; 1065 } 1066 const absl::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()1067 parameter_replicated_at_leaf_buffers() const { 1068 return parameter_replicated_at_leaf_buffers_; 1069 } 1070 1071 // Returns a serialized representation of this instruction. 1072 HloInstructionProto ToProto() const override; 1073 1074 private: 1075 std::vector<string> ExtraAttributesToStringImpl( 1076 const HloPrintOptions& options) const override; 1077 bool IdenticalSlowPath( 1078 const HloInstruction& other, 1079 const std::function<bool(const HloComputation*, const HloComputation*)>& 1080 eq_computations) const override; 1081 string OperandsToStringWithCanonicalNameMap( 1082 const HloPrintOptions& options, 1083 CanonicalNameMap* canonical_name_map) const override; 1084 // Implementation for non-common logic of CloneWithNewOperands. 1085 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1086 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1087 HloCloneContext* context) const override; 1088 1089 int64 parameter_number_ = 0; 1090 1091 // Specifies whether each buffer has the same parameter value on all replicas 1092 // in data parallelism. 1093 absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_; 1094 }; 1095 1096 class HloGetTupleElementInstruction : public HloInstruction { 1097 public: 1098 explicit HloGetTupleElementInstruction(const Shape& shape, 1099 HloInstruction* operand, int64 index); 1100 // Returns the tuple index associated with this instruction. tuple_index()1101 int64 tuple_index() const { return tuple_index_; } 1102 // Sets the tuple index associated with this instruction. set_tuple_index(int64 new_tuple_index)1103 void set_tuple_index(int64 new_tuple_index) { 1104 tuple_index_ = new_tuple_index; 1105 } 1106 // Returns a serialized representation of this instruction. 1107 HloInstructionProto ToProto() const override; 1108 1109 private: 1110 std::vector<string> ExtraAttributesToStringImpl( 1111 const HloPrintOptions& options) const override; 1112 bool IdenticalSlowPath( 1113 const HloInstruction& other, 1114 const std::function<bool(const HloComputation*, const HloComputation*)>& 1115 eq_computations) const override; 1116 // Implementation for non-common logic of CloneWithNewOperands. 1117 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1118 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1119 HloCloneContext* context) const override; 1120 1121 int64 tuple_index_ = -1; 1122 }; 1123 1124 class HloReducePrecisionInstruction : public HloInstruction { 1125 public: 1126 explicit HloReducePrecisionInstruction(const Shape& shape, 1127 HloInstruction* operand, 1128 const int exponent_bits, 1129 const int mantissa_bits); 1130 // Returns the number of exponent bits for a reduce-precision node. exponent_bits()1131 int32 exponent_bits() const { return exponent_bits_; } 1132 // Returns the number of mantissa bits for a reduce-precision node. mantissa_bits()1133 int32 mantissa_bits() const { return mantissa_bits_; } 1134 // Returns a serialized representation of this instruction. 1135 HloInstructionProto ToProto() const override; 1136 1137 private: 1138 std::vector<string> ExtraAttributesToStringImpl( 1139 const HloPrintOptions& options) const override; 1140 bool IdenticalSlowPath( 1141 const HloInstruction& other, 1142 const std::function<bool(const HloComputation*, const HloComputation*)>& 1143 eq_computations) const override; 1144 // Implementation for non-common logic of CloneWithNewOperands. 1145 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1146 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1147 HloCloneContext* context) const override; 1148 1149 // The bit sizes for a reduce-precision operation. 1150 int32 exponent_bits_ = 0; 1151 int32 mantissa_bits_ = 0; 1152 }; 1153 1154 class HloInfeedInstruction : public HloInstruction { 1155 public: 1156 explicit HloInfeedInstruction(const Shape& infeed_shape, 1157 HloInstruction* token_operand, 1158 const string& config); 1159 // Returns the infeed configuration string. The infeed configuration includes 1160 // any metadata needed for the backend compiler (e.g., infeed buffer address) 1161 // and is target-dependent. infeed_config()1162 string infeed_config() const { return infeed_config_; } set_infeed_config(const string & config)1163 void set_infeed_config(const string& config) { infeed_config_ = config; } 1164 // Returns the shape of the data received by the infeed. This is not the same 1165 // as the shape of the infeed instruction which produces a tuple containing 1166 // the infeed data shape and a TOKEN. infeed_shape()1167 const Shape& infeed_shape() const { 1168 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); 1169 return ShapeUtil::GetSubshape(shape(), {0}); 1170 } 1171 // Returns a serialized representation of this instruction. 1172 HloInstructionProto ToProto() const override; 1173 1174 private: 1175 std::vector<string> ExtraAttributesToStringImpl( 1176 const HloPrintOptions& options) const override; 1177 bool IdenticalSlowPath( 1178 const HloInstruction& other, 1179 const std::function<bool(const HloComputation*, const HloComputation*)>& 1180 eq_computations) const override; 1181 // Implementation for non-common logic of CloneWithNewOperands. 1182 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1183 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1184 HloCloneContext* context) const override; 1185 1186 // The string representation of the infeed configuration. 1187 string infeed_config_; 1188 }; 1189 1190 class HloOutfeedInstruction : public HloInstruction { 1191 public: 1192 explicit HloOutfeedInstruction(const Shape& outfeed_shape, 1193 HloInstruction* operand, 1194 HloInstruction* token_operand, 1195 absl::string_view outfeed_config); 1196 // Returns the shape for the Outfeed instruction. outfeed_shape()1197 const Shape& outfeed_shape() const { return outfeed_shape_; } 1198 // Returns the mutable shape for the Outfeed instruction. mutable_outfeed_shape()1199 Shape* mutable_outfeed_shape() { return &outfeed_shape_; } 1200 // Returns the config for the Outfeed instruction. outfeed_config()1201 const string& outfeed_config() const { return outfeed_config_; } set_outfeed_config(const string & config)1202 void set_outfeed_config(const string& config) { outfeed_config_ = config; } 1203 // Returns a serialized representation of this instruction. 1204 HloInstructionProto ToProto() const override; 1205 1206 private: 1207 std::vector<string> ExtraAttributesToStringImpl( 1208 const HloPrintOptions& options) const override; 1209 bool IdenticalSlowPath( 1210 const HloInstruction& other, 1211 const std::function<bool(const HloComputation*, const HloComputation*)>& 1212 eq_computations) const override; 1213 // Implementation for non-common logic of CloneWithNewOperands. 1214 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1215 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1216 HloCloneContext* context) const override; 1217 1218 // Shape of outfeed request. 1219 Shape outfeed_shape_; 1220 // Outfeed configuration information, only present for kOutfeed. 1221 string outfeed_config_; 1222 }; 1223 1224 class HloConvolutionInstruction : public HloInstruction { 1225 public: 1226 explicit HloConvolutionInstruction( 1227 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 1228 int64 feature_group_count, int64 batch_group_count, const Window& window, 1229 const ConvolutionDimensionNumbers& dimension_numbers, 1230 const PrecisionConfig& precision_config); window()1231 const Window& window() const override { return window_; } set_window(const Window & window)1232 void set_window(const Window& window) override { window_ = window; } convolution_dimension_numbers()1233 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1234 return convolution_dimension_numbers_; 1235 } set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1236 void set_convolution_dimension_numbers( 1237 const ConvolutionDimensionNumbers& dnums) { 1238 convolution_dimension_numbers_ = dnums; 1239 } 1240 // The number of feature groups. Must be a divisor of the input feature 1241 // dimension and output feature dimension. feature_group_count()1242 int64 feature_group_count() const { return feature_group_count_; } set_feature_group_count(int64 num_feature_groups)1243 void set_feature_group_count(int64 num_feature_groups) { 1244 feature_group_count_ = num_feature_groups; 1245 } 1246 // The number of batch groups. Must be a divisor of the input batch dimension. batch_group_count()1247 int64 batch_group_count() const { return batch_group_count_; } set_batch_group_count(int64 num_batch_groups)1248 void set_batch_group_count(int64 num_batch_groups) { 1249 batch_group_count_ = num_batch_groups; 1250 } 1251 1252 // Returns the information used to tell the implementation information about 1253 // what sort of precision is requested. The meaning of the field is backend 1254 // specific. At the moment, it is only supported for kConvolution and kDot. 1255 // Transformations on one kDot or kConvolution to another will preserve this 1256 // information. Transformations to other HLOs will not preserve this 1257 // information but it is presumed that the alternate lowering is strictly 1258 // superior. precision_config()1259 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1260 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1261 1262 string ToCategory() const override; 1263 // Returns a serialized representation of this instruction. 1264 HloInstructionProto ToProto() const override; 1265 1266 private: 1267 std::vector<string> ExtraAttributesToStringImpl( 1268 const HloPrintOptions& options) const override; 1269 bool IdenticalSlowPath( 1270 const HloInstruction& other, 1271 const std::function<bool(const HloComputation*, const HloComputation*)>& 1272 eq_computations) const override; 1273 // Implementation for non-common logic of CloneWithNewOperands. 1274 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1275 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1276 HloCloneContext* context) const override; 1277 // The number of feature groups. Must be a divisor of the input feature 1278 // dimension and output feature dimension. 1279 int64 feature_group_count_; 1280 // The number of batch groups. Must be a divisor of the input batch dimension. 1281 int64 batch_group_count_; 1282 // Describes the window used for a convolution. 1283 Window window_; 1284 // Describes the dimension numbers used for a convolution. 1285 ConvolutionDimensionNumbers convolution_dimension_numbers_; 1286 // Information used to communicate to the implementation about the algorithm 1287 // used to produce results. See the documentation on precision_config(). 1288 PrecisionConfig precision_config_; 1289 }; 1290 1291 class HloReduceWindowInstruction : public HloInstruction { 1292 public: 1293 explicit HloReduceWindowInstruction(const Shape& shape, 1294 HloInstruction* operand, 1295 HloInstruction* init_value, 1296 const Window& window, 1297 HloComputation* reduce_computation); 1298 explicit HloReduceWindowInstruction( 1299 const Shape& shape, absl::Span<HloInstruction* const> operands, 1300 absl::Span<HloInstruction* const> init_values, const Window& window, 1301 HloComputation* reduce_computation); window()1302 const Window& window() const override { return window_; } set_window(const Window & window)1303 void set_window(const Window& window) override { window_ = window; } 1304 // Returns a serialized representation of this instruction. 1305 HloInstructionProto ToProto() const override; 1306 // Returns the number of input arrays (and, consequentially, the number of 1307 // init values) this reduce has. input_count()1308 int64 input_count() const { return operand_count() / 2; } 1309 // Returns the input tensors to be reduced. input_arrays()1310 absl::Span<HloInstruction* const> input_arrays() const { 1311 return absl::MakeSpan(operands()).subspan(0, input_count()); 1312 } 1313 // Returns the init values of the reduction. init_values()1314 absl::Span<HloInstruction* const> init_values() const { 1315 return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); 1316 } 1317 // Returns the shapes of input tensors to be reduced. input_array_shapes()1318 absl::InlinedVector<const Shape*, 2> input_array_shapes() const { 1319 absl::InlinedVector<const Shape*, 2> shapes; 1320 for (const auto* op : input_arrays()) { 1321 VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n"; 1322 shapes.push_back(&op->shape()); 1323 VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n"; 1324 } 1325 return shapes; 1326 } 1327 // Returns the init values of the reduction. init_value_shapes()1328 absl::InlinedVector<const Shape*, 2> init_value_shapes() const { 1329 absl::InlinedVector<const Shape*, 2> shapes; 1330 for (const auto* op : init_values()) { 1331 shapes.push_back(&op->shape()); 1332 } 1333 return shapes; 1334 } 1335 1336 private: 1337 std::vector<string> ExtraAttributesToStringImpl( 1338 const HloPrintOptions& options) const override; 1339 bool IdenticalSlowPath( 1340 const HloInstruction& other, 1341 const std::function<bool(const HloComputation*, const HloComputation*)>& 1342 eq_computations) const override; 1343 // Implementation for non-common logic of CloneWithNewOperands. 1344 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1345 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1346 HloCloneContext* context) const override; 1347 1348 Window window_; 1349 }; 1350 1351 class HloSelectAndScatterInstruction : public HloInstruction { 1352 public: 1353 explicit HloSelectAndScatterInstruction( 1354 const Shape& shape, HloInstruction* operand, HloComputation* select, 1355 const Window& window, HloInstruction* source, HloInstruction* init_value, 1356 HloComputation* scatter); window()1357 const Window& window() const override { return window_; } set_window(const Window & window)1358 void set_window(const Window& window) override { window_ = window; } 1359 // Gets/sets the select or scatter HloComputation for SelectAndScatter. The 1360 // setters should only be called by HloModule or HloComputation methods. select()1361 HloComputation* select() const { 1362 return called_computations()[kSelectComputationIndex]; 1363 } 1364 scatter()1365 HloComputation* scatter() const { 1366 return called_computations()[kScatterComputationIndex]; 1367 } 1368 set_select(HloComputation * computation)1369 void set_select(HloComputation* computation) { 1370 // Don't allow changing the computation for fused instructions so we don't 1371 // have to recompute called_instructions for the entire fusion instruction. 1372 CHECK(!IsFused()); 1373 set_called_computation(kSelectComputationIndex, computation); 1374 } 1375 set_scatter(HloComputation * computation)1376 void set_scatter(HloComputation* computation) { 1377 // Don't allow changing the computation for fused instructions so we don't 1378 // have to recompute called_instructions for the entire fusion instruction. 1379 CHECK(!IsFused()); 1380 set_called_computation(kScatterComputationIndex, computation); 1381 } 1382 // Returns a serialized representation of this instruction. 1383 HloInstructionProto ToProto() const override; 1384 1385 private: 1386 std::vector<string> ExtraAttributesToStringImpl( 1387 const HloPrintOptions& options) const override; 1388 bool IdenticalSlowPath( 1389 const HloInstruction& other, 1390 const std::function<bool(const HloComputation*, const HloComputation*)>& 1391 eq_computations) const override; 1392 // Implementation for non-common logic of CloneWithNewOperands. 1393 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1394 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1395 HloCloneContext* context) const override; 1396 Window window_; 1397 }; 1398 1399 class HloCustomCallInstruction : public HloInstruction { 1400 public: 1401 HloCustomCallInstruction(const Shape& shape, 1402 absl::Span<HloInstruction* const> operands, 1403 absl::string_view custom_call_target, string opaque); 1404 1405 // Constructor for a custom call with constrained layout. 'shape' and 1406 // 'operands_with_layout' must all have layouts. 1407 HloCustomCallInstruction(const Shape& shape, 1408 absl::Span<HloInstruction* const> operands, 1409 absl::string_view custom_call_target, string opaque, 1410 absl::Span<const Shape> operand_shapes_with_layout); 1411 1412 // Constructor for a custom call with a to_apply computation. 1413 HloCustomCallInstruction(const Shape& shape, 1414 absl::Span<HloInstruction* const> operands, 1415 HloComputation* to_apply, 1416 absl::string_view custom_call_target, string opaque); 1417 1418 // Constructor for a custom call with multiple computations. 1419 HloCustomCallInstruction( 1420 const Shape& shape, absl::Span<HloInstruction* const> operands, 1421 absl::Span<HloComputation* const> called_computations, 1422 absl::string_view custom_call_target, string opaque); 1423 window()1424 const Window& window() const override { 1425 CHECK(window_ != nullptr); 1426 return *window_; 1427 } 1428 set_window(const Window & window)1429 void set_window(const Window& window) override { 1430 window_ = absl::make_unique<Window>(window); 1431 } 1432 convolution_dimension_numbers()1433 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1434 CHECK(convolution_dimension_numbers_ != nullptr); 1435 return *convolution_dimension_numbers_; 1436 } 1437 set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1438 void set_convolution_dimension_numbers( 1439 const ConvolutionDimensionNumbers& dnums) { 1440 convolution_dimension_numbers_ = 1441 absl::make_unique<ConvolutionDimensionNumbers>(dnums); 1442 } 1443 // TODO(jpienaar): Remove this accessor in the follow up. opaque()1444 const string& opaque() const { return raw_backend_config_string(); } custom_call_target()1445 const string& custom_call_target() const { return custom_call_target_; } set_feature_group_count(int64 feature_group_count)1446 void set_feature_group_count(int64 feature_group_count) { 1447 feature_group_count_ = feature_group_count; 1448 } set_batch_group_count(int64 batch_group_count)1449 void set_batch_group_count(int64 batch_group_count) { 1450 batch_group_count_ = batch_group_count; 1451 } 1452 // Sets whether this custom call has a side-effect - by default a custom call 1453 // has no side-effects. set_custom_call_has_side_effect(bool custom_call_has_side_effect)1454 void set_custom_call_has_side_effect(bool custom_call_has_side_effect) { 1455 custom_call_has_side_effect_ = custom_call_has_side_effect; 1456 } feature_group_count()1457 int64 feature_group_count() const { return feature_group_count_; } batch_group_count()1458 int64 batch_group_count() const { return batch_group_count_; } custom_call_has_side_effect()1459 bool custom_call_has_side_effect() const { 1460 return custom_call_has_side_effect_; 1461 } 1462 // Returns padding type used for ops like convolution. padding_type()1463 PaddingType padding_type() const { return padding_type_; } 1464 set_padding_type(PaddingType padding_type)1465 void set_padding_type(PaddingType padding_type) { 1466 padding_type_ = padding_type; 1467 } 1468 1469 // Returns the literal associated with this instruction. literal()1470 const Literal& literal() const { return *literal_; } 1471 // Set the value of literal to a new one. set_literal(Literal && literal)1472 void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); } 1473 // Returns whether there is literal associated with this instruction. HasLiteral()1474 bool HasLiteral() const { return literal_.has_value(); } 1475 precision_config()1476 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1477 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1478 1479 // Returns a serialized representation of this instruction. 1480 HloInstructionProto ToProto() const override; 1481 1482 // Returns whether the result and operand layouts are constrained. layout_constrained()1483 bool layout_constrained() const { return layout_constrained_; } 1484 1485 // Returns the shapes (with layout) of the operands. CHECKs if this custom 1486 // call does not have constrained layouts. operand_shapes_with_layout()1487 const std::vector<Shape>& operand_shapes_with_layout() const { 1488 CHECK(layout_constrained()); 1489 return operand_shapes_with_layout_; 1490 } 1491 // Gets a list of output/operand buffer pairs that alias each other, where the 1492 // output buffer is represented as a ShapeIndex, and the operand buffer is 1493 // represented as the operand index and the ShapeIndex. By default this list 1494 // is empty. 1495 const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>& output_to_operand_aliasing()1496 output_to_operand_aliasing() const { 1497 return output_to_operand_aliasing_; 1498 } 1499 // Sets the list of output/operand buffer pairs that alias each other. set_output_to_operand_aliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> aliasing)1500 void set_output_to_operand_aliasing( 1501 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>> 1502 aliasing) { 1503 output_to_operand_aliasing_ = std::move(aliasing); 1504 } 1505 1506 private: 1507 std::vector<string> ExtraAttributesToStringImpl( 1508 const HloPrintOptions& options) const override; 1509 bool IdenticalSlowPath( 1510 const HloInstruction& other, 1511 const std::function<bool(const HloComputation*, const HloComputation*)>& 1512 eq_computations) const override; 1513 // Implementation for non-common logic of CloneWithNewOperands. 1514 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1515 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1516 HloCloneContext* context) const override; 1517 // Name of a global symbol to call. 1518 string custom_call_target_; 1519 // Describes the window in a windowed operation such as convolution. 1520 std::unique_ptr<Window> window_; 1521 // Describes the dimension numbers used for a convolution. 1522 std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; 1523 // The number of feature groups. This is used for grouped convolutions. 1524 int64 feature_group_count_; 1525 int64 batch_group_count_; 1526 // Whether the result and operand layouts are constrained. 1527 bool layout_constrained_; 1528 // Information used to communicate to the implementation about the algorithm 1529 // used to produce results for convolution instructions. 1530 PrecisionConfig precision_config_; 1531 // Describes the padding type for convolution instructions. 1532 PaddingType padding_type_; 1533 // For layout-constrained custom calls, this vector holds the shape with 1534 // layout for each operand. 1535 std::vector<Shape> operand_shapes_with_layout_; 1536 // Whether this custom call has a side-effect. 1537 bool custom_call_has_side_effect_; 1538 // A list of output/operand buffer pairs that alias each other. See comment of 1539 // output_to_operand_aliasing(). 1540 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>> 1541 output_to_operand_aliasing_; 1542 absl::optional<Literal> literal_; 1543 }; 1544 1545 class HloPadInstruction : public HloInstruction { 1546 public: 1547 explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, 1548 HloInstruction* padding_value, 1549 const PaddingConfig& padding_config); 1550 // Returns the padding configuration for a pad node. padding_config()1551 const PaddingConfig& padding_config() const { return padding_config_; } mutable_padding_config()1552 PaddingConfig* mutable_padding_config() { return &padding_config_; } 1553 // Returns the padding value. padding_value()1554 const HloInstruction* padding_value() const { return operand(1); } mutable_padding_value()1555 HloInstruction* mutable_padding_value() { return mutable_operand(1); } 1556 // Returns a serialized representation of this instruction. 1557 HloInstructionProto ToProto() const override; 1558 1559 private: 1560 std::vector<string> ExtraAttributesToStringImpl( 1561 const HloPrintOptions& options) const override; 1562 bool IdenticalSlowPath( 1563 const HloInstruction& other, 1564 const std::function<bool(const HloComputation*, const HloComputation*)>& 1565 eq_computations) const override; 1566 // Implementation for non-common logic of CloneWithNewOperands. 1567 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1568 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1569 HloCloneContext* context) const override; 1570 1571 // The padding configuration that describes the edge padding and interior 1572 // padding of this pad instruction. 1573 PaddingConfig padding_config_; 1574 }; 1575 1576 class HloDynamicIndexInstruction : public HloInstruction { 1577 public: HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1578 explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) 1579 : HloInstruction(opcode, shape) {} 1580 virtual int64 first_index_operand_number() const = 0; 1581 1582 // Returns a subspan of operands which represent the start indices. index_operands()1583 absl::Span<HloInstruction* const> index_operands() const { 1584 return absl::MakeSpan(operands()).subspan(first_index_operand_number()); 1585 } 1586 1587 // Returns the shapes of the index operands. index_shapes()1588 std::vector<Shape> index_shapes() const { 1589 std::vector<Shape> shapes; 1590 auto indices = index_operands(); 1591 for (const HloInstruction* index : indices) { 1592 shapes.push_back(index->shape()); 1593 } 1594 return shapes; 1595 } 1596 }; 1597 1598 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { 1599 public: 1600 explicit HloDynamicSliceInstruction(const Shape& shape, 1601 HloInstruction* operand, 1602 HloInstruction* start_indices, 1603 absl::Span<const int64> slice_sizes); 1604 explicit HloDynamicSliceInstruction( 1605 const Shape& shape, HloInstruction* operand, 1606 absl::Span<HloInstruction* const> start_indices, 1607 absl::Span<const int64> slice_sizes); 1608 // Old methods kept for smooth subclassing transition END. 1609 // Returns the size of the slice in the given dimension for a dynamic 1610 // slice node. slice_sizes(int64 dimension)1611 int64 slice_sizes(int64 dimension) const { 1612 return dynamic_slice_sizes_[dimension]; 1613 } dynamic_slice_sizes()1614 const std::vector<int64>& dynamic_slice_sizes() const { 1615 return dynamic_slice_sizes_; 1616 } 1617 // Returns a serialized representation of this instruction. 1618 HloInstructionProto ToProto() const override; 1619 first_index_operand_number()1620 int64 first_index_operand_number() const override { return 1; } 1621 1622 private: 1623 std::vector<string> ExtraAttributesToStringImpl( 1624 const HloPrintOptions& options) const override; 1625 bool IdenticalSlowPath( 1626 const HloInstruction& other, 1627 const std::function<bool(const HloComputation*, const HloComputation*)>& 1628 eq_computations) const override; 1629 // Implementation for non-common logic of CloneWithNewOperands. 1630 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1631 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1632 HloCloneContext* context) const override; 1633 1634 // Describes the [start, start + size) range size for a dynamic slice 1635 // ('start' is specified dynamically in the second operand of the operation). 1636 std::vector<int64> dynamic_slice_sizes_; 1637 }; 1638 1639 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { 1640 public: 1641 explicit HloDynamicUpdateSliceInstruction(const Shape& shape, 1642 HloInstruction* operand, 1643 HloInstruction* update, 1644 HloInstruction* start_indices); 1645 explicit HloDynamicUpdateSliceInstruction( 1646 const Shape& shape, HloInstruction* operand, HloInstruction* update, 1647 absl::Span<HloInstruction* const> start_indices); 1648 first_index_operand_number()1649 int64 first_index_operand_number() const override { return 2; } 1650 }; 1651 1652 class HloGatherInstruction : public HloInstruction { 1653 public: 1654 explicit HloGatherInstruction( 1655 const Shape& shape, HloInstruction* operand, 1656 HloInstruction* start_indices, 1657 const GatherDimensionNumbers& gather_dim_numbers, 1658 absl::Span<const int64> slice_sizes, bool indices_are_sorted); gather_dimension_numbers()1659 const GatherDimensionNumbers& gather_dimension_numbers() const { 1660 CHECK(gather_dimension_numbers_ != nullptr); 1661 return *gather_dimension_numbers_; 1662 } gather_slice_sizes()1663 absl::Span<const int64> gather_slice_sizes() const { 1664 return gather_slice_sizes_; 1665 } indices_are_sorted()1666 bool indices_are_sorted() const { return indices_are_sorted_; } set_indices_are_sorted(bool indices_are_sorted)1667 void set_indices_are_sorted(bool indices_are_sorted) { 1668 indices_are_sorted_ = indices_are_sorted; 1669 } 1670 // Returns a serialized representation of this instruction. 1671 HloInstructionProto ToProto() const override; 1672 1673 // Creates an instance of GatherDimensionNumbers. 1674 static GatherDimensionNumbers MakeGatherDimNumbers( 1675 absl::Span<const int64> offset_dims, 1676 absl::Span<const int64> collapsed_slice_dims, 1677 absl::Span<const int64> start_index_map, int64 index_vector_dim); 1678 // Returns the dump string of the given gather dimension numbers. 1679 static string GatherDimensionNumbersToString( 1680 const GatherDimensionNumbers& gather_dimension_numbers); 1681 1682 private: 1683 std::vector<string> ExtraAttributesToStringImpl( 1684 const HloPrintOptions& options) const override; 1685 bool IdenticalSlowPath( 1686 const HloInstruction& other, 1687 const std::function<bool(const HloComputation*, const HloComputation*)>& 1688 eq_computations) const override; 1689 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1690 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1691 HloCloneContext* context) const override; 1692 1693 std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; 1694 std::vector<int64> gather_slice_sizes_; 1695 bool indices_are_sorted_; 1696 }; 1697 1698 class HloScatterInstruction : public HloInstruction { 1699 public: 1700 explicit HloScatterInstruction( 1701 const Shape& shape, HloInstruction* operand, 1702 HloInstruction* scatter_indices, HloInstruction* updates, 1703 HloComputation* update_computation, 1704 const ScatterDimensionNumbers& scatter_dim_numbers, 1705 bool indices_are_sorted, bool unique_indices); scatter_dimension_numbers()1706 const ScatterDimensionNumbers& scatter_dimension_numbers() const { 1707 CHECK(scatter_dimension_numbers_ != nullptr); 1708 return *scatter_dimension_numbers_; 1709 } indices_are_sorted()1710 bool indices_are_sorted() const { return indices_are_sorted_; } set_indices_are_sorted(bool indices_are_sorted)1711 void set_indices_are_sorted(bool indices_are_sorted) { 1712 indices_are_sorted_ = indices_are_sorted; 1713 } unique_indices()1714 bool unique_indices() const override { return unique_indices_; } 1715 // Returns a serialized representation of this instruction. 1716 HloInstructionProto ToProto() const override; 1717 1718 // Creates an instance of ScatterDimensionNumbers. 1719 static ScatterDimensionNumbers MakeScatterDimNumbers( 1720 absl::Span<const int64> update_window_dims, 1721 absl::Span<const int64> inserted_window_dims, 1722 absl::Span<const int64> scatter_dims_to_operand_dims, 1723 int64 index_vector_dim); 1724 // Returns the dump string of the given scatter dimension numbers. 1725 static string ScatterDimensionNumbersToString( 1726 const ScatterDimensionNumbers& scatter_dimension_numbers); 1727 1728 private: 1729 std::vector<string> ExtraAttributesToStringImpl( 1730 const HloPrintOptions& options) const override; 1731 bool IdenticalSlowPath( 1732 const HloInstruction& other, 1733 const std::function<bool(const HloComputation*, const HloComputation*)>& 1734 eq_computations) const override; 1735 // Implementation for non-common logic of CloneWithNewOperands. 1736 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1737 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1738 HloCloneContext* context) const override; 1739 1740 std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; 1741 bool indices_are_sorted_; 1742 bool unique_indices_; 1743 }; 1744 1745 class HloIotaInstruction : public HloInstruction { 1746 public: 1747 explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); 1748 // Returns the dimension sizes or numbers associated with this instruction. iota_dimension()1749 int64 iota_dimension() const { return iota_dimension_; } 1750 // Returns a serialized representation of this instruction. 1751 HloInstructionProto ToProto() const override; 1752 1753 private: 1754 std::vector<string> ExtraAttributesToStringImpl( 1755 const HloPrintOptions& options) const override; 1756 bool IdenticalSlowPath( 1757 const HloInstruction& other, 1758 const std::function<bool(const HloComputation*, const HloComputation*)>& 1759 eq_computations) const override; 1760 // Implementation for non-common logic of CloneWithNewOperands. 1761 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1762 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1763 HloCloneContext* context) const override; 1764 1765 const int64 iota_dimension_; 1766 }; 1767 1768 class HloDotInstruction : public HloInstruction { 1769 public: 1770 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch 1771 // dimensions specified in 'dimension_numbers'. 1772 explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, 1773 HloInstruction* rhs, 1774 const DotDimensionNumbers& dimension_numbers, 1775 const PrecisionConfig& precision_config); 1776 1777 // Returns data on the dimension numbers used for a dot operation. dot_dimension_numbers()1778 const DotDimensionNumbers& dot_dimension_numbers() const { 1779 return dot_dimension_numbers_; 1780 } 1781 1782 // Returns the information used to tell the implementation information about 1783 // what sort of precision is requested. The meaning of the field is backend 1784 // specific. At the moment, it is only supported for kConvolution and kDot. 1785 // Transformations on one kDot or kConvolution to another will preserve this 1786 // information. Transformations to other HLOs will not preserve this 1787 // information but it is presumed that the alternate lowering is strictly 1788 // superior. precision_config()1789 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1790 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1791 1792 // Returns a serialized representation of this instruction. 1793 HloInstructionProto ToProto() const override; 1794 1795 private: 1796 std::vector<string> ExtraAttributesToStringImpl( 1797 const HloPrintOptions& options) const override; 1798 bool IdenticalSlowPath( 1799 const HloInstruction& other, 1800 const std::function<bool(const HloComputation*, const HloComputation*)>& 1801 eq_computations) const override; 1802 // Implementation for non-common logic of CloneWithNewOperands. 1803 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1804 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1805 HloCloneContext* context) const override; 1806 // Returns the dump string of the dot dimension numbers. 1807 string DotDimensionNumbersToString() const; 1808 1809 // Describes the dimension numbers used for a dot. 1810 DotDimensionNumbers dot_dimension_numbers_; 1811 1812 // Information used to communicate to the implementation about the algorithm 1813 // used to produce results. See the documentation on precision_config(). 1814 PrecisionConfig precision_config_; 1815 }; 1816 1817 class HloDomainInstruction : public HloInstruction { 1818 public: 1819 explicit HloDomainInstruction( 1820 const Shape& shape, HloInstruction* operand, 1821 std::unique_ptr<DomainMetadata> operand_side_metadata, 1822 std::unique_ptr<DomainMetadata> user_side_metadata); 1823 1824 // Returns a serialized representation of this instruction. 1825 HloInstructionProto ToProto() const override; 1826 1827 // Retrieves the operand side metadata of a kDomain instruction. operand_side_metadata()1828 const DomainMetadata& operand_side_metadata() const { 1829 return *operand_side_metadata_; 1830 } 1831 // Retrieves the user side metadata of a kDomain instruction. user_side_metadata()1832 const DomainMetadata& user_side_metadata() const { 1833 return *user_side_metadata_; 1834 } 1835 1836 private: 1837 std::vector<string> ExtraAttributesToStringImpl( 1838 const HloPrintOptions& options) const override; 1839 bool IdenticalSlowPath( 1840 const HloInstruction& other, 1841 const std::function<bool(const HloComputation*, const HloComputation*)>& 1842 eq_computations) const override; 1843 // Implementation for non-common logic of CloneWithNewOperands. 1844 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1845 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1846 HloCloneContext* context) const override; 1847 1848 std::unique_ptr<DomainMetadata> operand_side_metadata_; 1849 std::unique_ptr<DomainMetadata> user_side_metadata_; 1850 }; 1851 1852 class HloGetDimensionSizeInstruction : public HloInstruction { 1853 public: 1854 explicit HloGetDimensionSizeInstruction(const Shape& shape, 1855 HloInstruction* operand, 1856 int64 dimension); 1857 1858 // Returns the dimension sizes or numbers associated with this instruction. dimension()1859 int64 dimension() const { return dimension_; } 1860 // Returns a serialized representation of this instruction. 1861 HloInstructionProto ToProto() const override; 1862 1863 private: 1864 std::vector<string> ExtraAttributesToStringImpl( 1865 const HloPrintOptions& options) const override; 1866 bool IdenticalSlowPath( 1867 const HloInstruction& other, 1868 const std::function<bool(const HloComputation*, const HloComputation*)>& 1869 eq_computations) const override; 1870 // Implementation for non-common logic of CloneWithNewOperands. 1871 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1872 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1873 HloCloneContext* context) const override; 1874 1875 int64 dimension_; 1876 }; 1877 1878 class HloSetDimensionSizeInstruction : public HloInstruction { 1879 public: 1880 explicit HloSetDimensionSizeInstruction(const Shape& shape, 1881 HloInstruction* operand, 1882 HloInstruction* val, int64 dimension); 1883 1884 // Returns the dimension sizes or numbers associated with this instruction. dimension()1885 int64 dimension() const { return dimension_; } 1886 // Returns a serialized representation of this instruction. 1887 HloInstructionProto ToProto() const override; 1888 1889 private: 1890 std::vector<string> ExtraAttributesToStringImpl( 1891 const HloPrintOptions& options) const override; 1892 bool IdenticalSlowPath( 1893 const HloInstruction& other, 1894 const std::function<bool(const HloComputation*, const HloComputation*)>& 1895 eq_computations) const override; 1896 // Implementation for non-common logic of CloneWithNewOperands. 1897 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1898 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1899 HloCloneContext* context) const override; 1900 1901 int64 dimension_; 1902 }; 1903 1904 class HloRngGetAndUpdateStateInstruction : public HloInstruction { 1905 public: 1906 explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, int64 delta); 1907 1908 // Returns the delta value. delta()1909 int64 delta() const { return delta_; } set_delta(int64 delta)1910 void set_delta(int64 delta) { delta_ = delta; } 1911 // Returns a serialized representation of this instruction. 1912 HloInstructionProto ToProto() const override; 1913 1914 private: 1915 std::vector<string> ExtraAttributesToStringImpl( 1916 const HloPrintOptions& options) const override; 1917 bool IdenticalSlowPath( 1918 const HloInstruction& other, 1919 const std::function<bool(const HloComputation*, const HloComputation*)>& 1920 eq_computations) const override; 1921 // Implementation for non-common logic of CloneWithNewOperands. 1922 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1923 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1924 HloCloneContext* context) const override; 1925 1926 int64 delta_; 1927 }; 1928 1929 class HloRngBitGeneratorInstruction : public HloInstruction { 1930 public: 1931 HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state, 1932 RandomAlgorithm algorithm); 1933 algorithm()1934 RandomAlgorithm algorithm() const { return algorithm_; } 1935 HloInstructionProto ToProto() const override; 1936 1937 private: 1938 std::vector<string> ExtraAttributesToStringImpl( 1939 const HloPrintOptions& options) const override; 1940 bool IdenticalSlowPath( 1941 const HloInstruction& other, 1942 const std::function<bool(const HloComputation*, const HloComputation*)>& 1943 eq_computations) const override; 1944 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1945 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1946 HloCloneContext* context) const override; 1947 1948 RandomAlgorithm algorithm_; 1949 }; 1950 1951 } // namespace xla 1952 1953 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 1954