1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // All HloInstruction subclasses are put in this file. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 24 namespace xla { 25 26 class HloBatchNormInstruction : public HloInstruction { 27 public: 28 // Returns feature_index field associated with the instruction. The index 29 // represents the index of the feature dimension. feature_index()30 int64 feature_index() const { return feature_index_; } 31 32 // Returns a epsilon value associated with the instruction. The is a small 33 // number added to the variance to avoid divide-by-zero error. epsilon()34 float epsilon() const { return epsilon_; } 35 36 // Returns a serialized representation of this instruction. 37 HloInstructionProto ToProto() const override; 38 39 protected: 40 explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, 41 HloInstruction* operand, 42 HloInstruction* scale, float epsilon, 43 int64 feature_index); 44 45 private: 46 std::vector<string> ExtraAttributesToStringImpl( 47 const HloPrintOptions& options) const override; 48 bool IdenticalSlowPath( 49 const HloInstruction& other, 50 const std::function<bool(const HloComputation*, const HloComputation*)>& 51 eq_computations) const override; 52 // A small float number added to the variance to avoid divide-by-zero error. 53 float epsilon_ = 0.0f; 54 55 // An integer value representing the index of the feature dimension. 56 int64 feature_index_ = -1; 57 }; 58 59 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { 60 public: 61 explicit HloBatchNormTrainingInstruction(const Shape& shape, 62 HloInstruction* operand, 63 HloInstruction* scale, 64 HloInstruction* offset, 65 float epsilon, int64 feature_index); 66 67 private: 68 // Implementation for non-common logic of CloneWithNewOperands. 69 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 70 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 71 HloCloneContext* context) const override; 72 }; 73 74 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { 75 public: 76 explicit HloBatchNormInferenceInstruction( 77 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 78 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, 79 float epsilon, int64 feature_index); 80 81 private: 82 // Implementation for non-common logic of CloneWithNewOperands. 83 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 84 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 85 HloCloneContext* context) const override; 86 }; 87 88 class HloBatchNormGradInstruction : public HloBatchNormInstruction { 89 public: 90 explicit HloBatchNormGradInstruction( 91 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 92 HloInstruction* mean, HloInstruction* variance, 93 HloInstruction* grad_output, float epsilon, int64 feature_index); 94 95 private: 96 // Implementation for non-common logic of CloneWithNewOperands. 97 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 98 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 99 HloCloneContext* context) const override; 100 }; 101 102 class HloFftInstruction : public HloInstruction { 103 public: 104 explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, 105 FftType fft_type, 106 absl::Span<const int64> fft_length); fft_type()107 FftType fft_type() const { return fft_type_; } 108 fft_length()109 const std::vector<int64>& fft_length() const { return fft_length_; } 110 111 // Returns a serialized representation of this instruction. 112 HloInstructionProto ToProto() const override; 113 114 private: 115 std::vector<string> ExtraAttributesToStringImpl( 116 const HloPrintOptions& options) const override; 117 bool IdenticalSlowPath( 118 const HloInstruction& other, 119 const std::function<bool(const HloComputation*, const HloComputation*)>& 120 eq_computations) const override; 121 122 // Implementation for non-common logic of CloneWithNewOperands. 123 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 124 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 125 HloCloneContext* context) const override; 126 127 // Describes FFT type for an FFT instruction. 128 FftType fft_type_ = FftType::FFT; 129 130 // Indicates the FFT length for an FFT instruction. 131 std::vector<int64> fft_length_; 132 }; 133 134 class HloCompareInstruction : public HloInstruction { 135 public: 136 explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, 137 HloInstruction* rhs, 138 ComparisonDirection direction); direction()139 ComparisonDirection direction() const { return direction_; } 140 HloInstructionProto ToProto() const override; 141 142 private: 143 std::vector<string> ExtraAttributesToStringImpl( 144 const HloPrintOptions& options) const override; 145 bool IdenticalSlowPath( 146 const HloInstruction& other, 147 const std::function<bool(const HloComputation*, const HloComputation*)>& 148 eq_computations) const override; 149 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 150 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 151 HloCloneContext* context) const override; 152 153 ComparisonDirection direction_; 154 }; 155 156 class HloTriangularSolveInstruction : public HloInstruction { 157 public: 158 explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, 159 HloInstruction* b, 160 const TriangularSolveOptions& options); triangular_solve_options()161 const TriangularSolveOptions& triangular_solve_options() const { 162 return triangular_solve_options_; 163 } 164 165 // Returns a serialized representation of this instruction. 166 HloInstructionProto ToProto() const override; 167 168 private: 169 std::vector<string> ExtraAttributesToStringImpl( 170 const HloPrintOptions& options) const override; 171 bool IdenticalSlowPath( 172 const HloInstruction& other, 173 const std::function<bool(const HloComputation*, const HloComputation*)>& 174 eq_computations) const override; 175 176 // Implementation for non-common logic of CloneWithNewOperands. 177 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 178 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 179 HloCloneContext* context) const override; 180 181 TriangularSolveOptions triangular_solve_options_; 182 }; 183 184 class HloCholeskyInstruction : public HloInstruction { 185 public: 186 explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, 187 const CholeskyOptions& options); cholesky_options()188 const CholeskyOptions& cholesky_options() const { return cholesky_options_; } 189 190 // Returns a serialized representation of this instruction. 191 HloInstructionProto ToProto() const override; 192 193 private: 194 std::vector<string> ExtraAttributesToStringImpl( 195 const HloPrintOptions& options) const override; 196 bool IdenticalSlowPath( 197 const HloInstruction& other, 198 const std::function<bool(const HloComputation*, const HloComputation*)>& 199 eq_computations) const override; 200 201 // Implementation for non-common logic of CloneWithNewOperands. 202 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 203 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 204 HloCloneContext* context) const override; 205 206 CholeskyOptions cholesky_options_; 207 }; 208 209 class HloSendRecvInstruction : public HloInstruction { 210 public: 211 // Returns the channel id associated with the instruction. The id is 212 // shared between each Send/Recv pair and is globally unique to identify each 213 // channel. channel_id()214 int64 channel_id() const { return channel_id_; } 215 216 // Returns whether this send/recv instruction sends data to/from the host. is_host_transfer()217 bool is_host_transfer() const { return is_host_transfer_; } 218 219 // Returns a serialized representation of this instruction. 220 HloInstructionProto ToProto() const override; 221 222 protected: 223 explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, 224 int64 channel_id, bool is_host_transfer); 225 226 private: 227 std::vector<string> ExtraAttributesToStringImpl( 228 const HloPrintOptions& options) const override; 229 bool IdenticalSlowPath( 230 const HloInstruction& other, 231 const std::function<bool(const HloComputation*, const HloComputation*)>& 232 eq_computations) const override; 233 // Represents a unique identifier for each Send/Recv instruction pair. 234 int64 channel_id_; 235 236 // Whether this send/recv instruction sends data to/from the host. 237 bool is_host_transfer_; 238 }; 239 240 class HloSendInstruction : public HloSendRecvInstruction { 241 public: 242 explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, 243 int64 channel_id, bool is_host_transfer); 244 245 private: 246 // Implementation for non-common logic of CloneWithNewOperands. 247 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 248 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 249 HloCloneContext* context) const override; 250 }; 251 252 class HloSendDoneInstruction : public HloSendRecvInstruction { 253 public: 254 explicit HloSendDoneInstruction(HloSendInstruction* operand, 255 bool is_host_transfer); 256 257 private: 258 // Implementation for non-common logic of CloneWithNewOperands. 259 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 260 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 261 HloCloneContext* context) const override; 262 }; 263 264 class HloRecvInstruction : public HloSendRecvInstruction { 265 public: 266 explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, 267 int64 channel_id, bool is_host_transfer); 268 269 private: 270 // Implementation for non-common logic of CloneWithNewOperands. 271 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 272 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 273 HloCloneContext* context) const override; 274 }; 275 276 class HloRecvDoneInstruction : public HloSendRecvInstruction { 277 public: 278 explicit HloRecvDoneInstruction(HloRecvInstruction* operand, 279 bool is_host_transfer); 280 281 private: 282 // Implementation for non-common logic of CloneWithNewOperands. 283 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 284 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 285 HloCloneContext* context) const override; 286 }; 287 288 class HloCollectiveInstruction : public HloInstruction { 289 public: replica_groups()290 const std::vector<ReplicaGroup>& replica_groups() const { 291 return replica_groups_; 292 } 293 294 protected: 295 explicit HloCollectiveInstruction( 296 HloOpcode opcode, const Shape& shape, 297 absl::Span<HloInstruction* const> operands, 298 const std::vector<ReplicaGroup>& replica_groups); 299 300 HloInstructionProto ToProto() const override; 301 302 std::vector<string> ExtraAttributesToStringImpl( 303 const HloPrintOptions& options) const override; 304 bool IdenticalSlowPath( 305 const HloInstruction& other, 306 const std::function<bool(const HloComputation*, const HloComputation*)>& 307 eq_computations) const override; 308 309 std::vector<ReplicaGroup> replica_groups_; 310 }; 311 312 class HloAllReduceInstruction : public HloCollectiveInstruction { 313 public: 314 explicit HloAllReduceInstruction( 315 const Shape& shape, absl::Span<HloInstruction* const> operands, 316 HloComputation* reduce_computation, 317 const std::vector<ReplicaGroup>& replica_groups, 318 absl::string_view barrier, const absl::optional<int64>& all_reduce_id); 319 320 // Returns the barrier config used for the AllReduce implementation of 321 // each backend. all_reduce_barrier()322 string all_reduce_barrier() const { return all_reduce_barrier_; } set_all_reduce_barrier(string barrier)323 void set_all_reduce_barrier(string barrier) { all_reduce_barrier_ = barrier; } 324 all_reduce_id()325 absl::optional<int64> all_reduce_id() const { return all_reduce_id_; } 326 void set_all_reduce_id(const absl::optional<int64>& all_reduce_id); 327 328 // Returns a serialized representation of this instruction. 329 HloInstructionProto ToProto() const override; 330 331 // Returns true if the AllReduce does no communication, so it's equivalent 332 // to a mem copy. 333 bool IsNoop() const; 334 335 private: 336 std::vector<string> ExtraAttributesToStringImpl( 337 const HloPrintOptions& options) const override; 338 bool IdenticalSlowPath( 339 const HloInstruction& other, 340 const std::function<bool(const HloComputation*, const HloComputation*)>& 341 eq_computations) const override; 342 343 // Implementation for non-common logic of CloneWithNewOperands. 344 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 345 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 346 HloCloneContext* context) const override; 347 348 // The string representation of the barrier config used for AllReduce. 349 string all_reduce_barrier_; 350 351 // For Allreduce nodes from different modules, if they have the same 352 // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be 353 // applied cross modules. 354 absl::optional<int64> all_reduce_id_; 355 }; 356 357 class HloAllToAllInstruction : public HloCollectiveInstruction { 358 public: 359 explicit HloAllToAllInstruction( 360 const Shape& shape, absl::Span<HloInstruction* const> operands, 361 const std::vector<ReplicaGroup>& replica_groups); 362 363 private: 364 // Implementation for non-common logic of CloneWithNewOperands. 365 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 366 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 367 HloCloneContext* context) const override; 368 }; 369 370 class HloCollectivePermuteInstruction : public HloInstruction { 371 public: 372 explicit HloCollectivePermuteInstruction( 373 const Shape& shape, HloInstruction* operand, 374 const std::vector<std::pair<int64, int64>>& source_target_pairs); 375 source_target_pairs()376 const std::vector<std::pair<int64, int64>>& source_target_pairs() const { 377 return source_target_pairs_; 378 } 379 380 // Returns a serialized representation of this instruction. 381 HloInstructionProto ToProto() const override; 382 383 private: 384 std::vector<string> ExtraAttributesToStringImpl( 385 const HloPrintOptions& options) const override; 386 bool IdenticalSlowPath( 387 const HloInstruction& other, 388 const std::function<bool(const HloComputation*, const HloComputation*)>& 389 eq_computations) const override; 390 391 // Implementation for non-common logic of CloneWithNewOperands. 392 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 393 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 394 HloCloneContext* context) const override; 395 396 const std::vector<std::pair<int64, int64>> source_target_pairs_; 397 }; 398 399 class HloReverseInstruction : public HloInstruction { 400 public: 401 explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, 402 absl::Span<const int64> dimensions); 403 // Returns the dimension sizes or numbers associated with this instruction. dimensions()404 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)405 int64 dimensions(int64 index) const override { return dimensions()[index]; } 406 // Returns a serialized representation of this instruction. 407 HloInstructionProto ToProto() const override; 408 409 private: 410 std::vector<string> ExtraAttributesToStringImpl( 411 const HloPrintOptions& options) const override; 412 bool IdenticalSlowPath( 413 const HloInstruction& other, 414 const std::function<bool(const HloComputation*, const HloComputation*)>& 415 eq_computations) const override; 416 // Implementation for non-common logic of CloneWithNewOperands. 417 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 418 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 419 HloCloneContext* context) const override; 420 421 std::vector<int64> dimensions_; 422 }; 423 424 class HloConcatenateInstruction : public HloInstruction { 425 public: 426 explicit HloConcatenateInstruction(const Shape& shape, 427 absl::Span<HloInstruction* const> operands, 428 int64 dimension); 429 // Returns the dimension sizes or numbers associated with this instruction. dimensions()430 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)431 int64 dimensions(int64 index) const override { return dimensions()[index]; } 432 // Accessor for the dimension in which a concatenate HLO should occur. concatenate_dimension()433 int64 concatenate_dimension() const { return dimensions(0); } 434 // Returns a serialized representation of this instruction. 435 HloInstructionProto ToProto() const override; 436 437 private: 438 std::vector<string> ExtraAttributesToStringImpl( 439 const HloPrintOptions& options) const override; 440 bool IdenticalSlowPath( 441 const HloInstruction& other, 442 const std::function<bool(const HloComputation*, const HloComputation*)>& 443 eq_computations) const override; 444 // Implementation for non-common logic of CloneWithNewOperands. 445 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 446 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 447 HloCloneContext* context) const override; 448 449 std::vector<int64> dimensions_; 450 }; 451 452 class HloReduceInstruction : public HloInstruction { 453 public: 454 explicit HloReduceInstruction(const Shape& shape, 455 absl::Span<HloInstruction* const> args, 456 absl::Span<const int64> dimensions_to_reduce, 457 HloComputation* reduce_computation); 458 // Returns the dimension sizes or numbers associated with this instruction. dimensions()459 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)460 int64 dimensions(int64 index) const override { return dimensions()[index]; } 461 // Returns a serialized representation of this instruction. 462 HloInstructionProto ToProto() const override; 463 464 // Returns the number of input arrays (and, consequentially, the number of 465 // init values) this reduce has. input_count()466 int64 input_count() const { return operand_count() / 2; } 467 468 // Returns the input tensors to be reduced. inputs()469 absl::Span<HloInstruction* const> inputs() const { 470 return absl::MakeSpan(operands()).subspan(0, input_count()); 471 } 472 473 // Returns the init values of the reduction. init_values()474 absl::Span<HloInstruction* const> init_values() const { 475 return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); 476 } 477 478 private: 479 std::vector<string> ExtraAttributesToStringImpl( 480 const HloPrintOptions& options) const override; 481 bool IdenticalSlowPath( 482 const HloInstruction& other, 483 const std::function<bool(const HloComputation*, const HloComputation*)>& 484 eq_computations) const override; 485 // Implementation for non-common logic of CloneWithNewOperands. 486 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 487 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 488 HloCloneContext* context) const override; 489 490 std::vector<int64> dimensions_; 491 }; 492 493 class HloSortInstruction : public HloInstruction { 494 public: 495 explicit HloSortInstruction(const Shape& shape, int64 dimension, 496 absl::Span<HloInstruction* const> operands, 497 HloComputation* compare, bool is_stable); 498 // Returns the dimension sizes or numbers associated with this instruction. dimensions()499 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)500 int64 dimensions(int64 index) const override { return dimensions()[index]; } 501 // Returns the sort dimension for this instruction sort_dimension()502 int64 sort_dimension() const { return dimensions(0); } 503 // Returns a serialized representation of this instruction. 504 HloInstructionProto ToProto() const override; 505 // Returns the key operand to this instruction. keys()506 const HloInstruction* keys() const { return operand(0); } mutable_keys()507 HloInstruction* mutable_keys() { return mutable_operand(0); } 508 // Returns the number of value operands. values_count()509 int64 values_count() const { return operand_count() - 1; } is_stable()510 bool is_stable() const { return is_stable_; } 511 512 private: 513 std::vector<string> ExtraAttributesToStringImpl( 514 const HloPrintOptions& options) const override; 515 bool IdenticalSlowPath( 516 const HloInstruction& other, 517 const std::function<bool(const HloComputation*, const HloComputation*)>& 518 eq_computations) const override; 519 // Implementation for non-common logic of CloneWithNewOperands. 520 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 521 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 522 HloCloneContext* context) const override; 523 524 std::vector<int64> dimensions_; 525 bool is_stable_; 526 }; 527 528 class HloTransposeInstruction : public HloInstruction { 529 public: 530 explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, 531 absl::Span<const int64> dimensions); 532 // Returns the dimension sizes or numbers associated with this instruction. dimensions()533 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)534 int64 dimensions(int64 index) const override { return dimensions()[index]; } 535 // Returns whether this instruction does a rank-2 transposition. 536 bool IsRank2Transpose() const; 537 // Returns a serialized representation of this instruction. 538 HloInstructionProto ToProto() const override; 539 540 private: 541 std::vector<string> ExtraAttributesToStringImpl( 542 const HloPrintOptions& options) const override; 543 bool IdenticalSlowPath( 544 const HloInstruction& other, 545 const std::function<bool(const HloComputation*, const HloComputation*)>& 546 eq_computations) const override; 547 // Implementation for non-common logic of CloneWithNewOperands. 548 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 549 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 550 HloCloneContext* context) const override; 551 552 std::vector<int64> dimensions_; 553 }; 554 555 class HloBroadcastInstruction : public HloInstruction { 556 public: 557 explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, 558 absl::Span<const int64> broadcast_dimension); 559 // Returns the dimension sizes or numbers associated with this instruction. dimensions()560 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)561 int64 dimensions(int64 index) const override { return dimensions()[index]; } 562 // Returns a serialized representation of this instruction. 563 HloInstructionProto ToProto() const override; 564 565 private: 566 std::vector<string> ExtraAttributesToStringImpl( 567 const HloPrintOptions& options) const override; 568 bool IdenticalSlowPath( 569 const HloInstruction& other, 570 const std::function<bool(const HloComputation*, const HloComputation*)>& 571 eq_computations) const override; 572 // Implementation for non-common logic of CloneWithNewOperands. 573 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 574 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 575 HloCloneContext* context) const override; 576 577 std::vector<int64> dimensions_; 578 }; 579 580 class HloMapInstruction : public HloInstruction { 581 public: 582 explicit HloMapInstruction(const Shape& shape, 583 absl::Span<HloInstruction* const> operands, 584 HloComputation* map_computation); 585 // Returns the dimension sizes or numbers associated with this instruction. dimensions()586 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)587 int64 dimensions(int64 index) const override { return dimensions()[index]; } 588 // Returns a serialized representation of this instruction. 589 HloInstructionProto ToProto() const override; 590 591 private: 592 bool IsElementwiseImpl( 593 const absl::optional<int64>& operand_idx) const override; 594 std::vector<string> ExtraAttributesToStringImpl( 595 const HloPrintOptions& options) const override; 596 bool IdenticalSlowPath( 597 const HloInstruction& other, 598 const std::function<bool(const HloComputation*, const HloComputation*)>& 599 eq_computations) const override; 600 // Implementation for non-common logic of CloneWithNewOperands. 601 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 602 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 603 HloCloneContext* context) const override; 604 605 std::vector<int64> dimensions_; 606 }; 607 608 class HloSliceInstruction : public HloInstruction { 609 public: 610 explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, 611 absl::Span<const int64> start_indices, 612 absl::Span<const int64> limit_indices, 613 absl::Span<const int64> strides); 614 615 HloInstructionProto ToProto() const override; 616 617 // Returns the start index in the given dimension for a slice node. slice_starts(int64 dimension)618 int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } slice_starts()619 const std::vector<int64>& slice_starts() const { return slice_starts_; } 620 621 // Returns the (exclusive) limit index in the given dimension for a slice 622 // node. slice_limits(int64 dimension)623 int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } slice_limits()624 const std::vector<int64>& slice_limits() const { return slice_limits_; } 625 626 // Returns the stride in the given dimension for a slice node. slice_strides(int64 dimension)627 int64 slice_strides(int64 dimension) const { 628 return slice_strides_[dimension]; 629 } slice_strides()630 const std::vector<int64>& slice_strides() const { return slice_strides_; } 631 632 private: 633 std::vector<string> ExtraAttributesToStringImpl( 634 const HloPrintOptions& options) const override; 635 bool IdenticalSlowPath( 636 const HloInstruction& other, 637 const std::function<bool(const HloComputation*, const HloComputation*)>& 638 eq_computations) const override; 639 // Implementation for non-common logic of CloneWithNewOperands. 640 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 641 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 642 HloCloneContext* context) const override; 643 644 // Describes the [begin, end) index range for a slice. 645 std::vector<int64> slice_starts_; 646 std::vector<int64> slice_limits_; 647 std::vector<int64> slice_strides_; 648 }; 649 650 class HloConstantInstruction : public HloInstruction { 651 public: 652 explicit HloConstantInstruction(Literal literal); 653 // Used when the literal is too large and dropped. 654 explicit HloConstantInstruction(const Shape& shape); 655 // Returns the literal associated with this instruction. literal()656 const Literal& literal() const { return *literal_; } 657 // Returns whether there is literal associated with this instruction. HasLiteral()658 bool HasLiteral() const { return literal_.has_value(); } 659 // Returns a serialized representation of this instruction. 660 HloInstructionProto ToProto() const override; 661 662 // Change the layout for an Constant Hlo instruction to match new_layout. For 663 // tuple shaped constants shape_index is the path to the internal array 664 // subshape whose layout needs to be changed. 665 void RelayoutConstant(const Layout& new_layout, 666 const ShapeIndex& shape_index = {}); 667 668 private: 669 bool IsElementwiseImpl( 670 const absl::optional<int64>& operand_idx) const override; 671 bool IdenticalSlowPath( 672 const HloInstruction& other, 673 const std::function<bool(const HloComputation*, const HloComputation*)>& 674 eq_computations) const override; 675 string OperandsToStringWithCanonicalNameMap( 676 const HloPrintOptions& options, 677 CanonicalNameMap* canonical_name_map) const override; 678 // Implementation for non-common logic of CloneWithNewOperands. 679 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 680 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 681 HloCloneContext* context) const override; 682 absl::optional<Literal> literal_; 683 }; 684 685 class HloTraceInstruction : public HloInstruction { 686 public: 687 explicit HloTraceInstruction(const string& tag, HloInstruction* operand); 688 // Returns a tag to be used in tracing. TracingTag()689 string TracingTag() const { return literal_.GetR1U8AsString(); } 690 // Returns a serialized representation of this instruction. 691 HloInstructionProto ToProto() const override; 692 693 private: 694 bool IdenticalSlowPath( 695 const HloInstruction& other, 696 const std::function<bool(const HloComputation*, const HloComputation*)>& 697 eq_computations) const override; 698 // Implementation for non-common logic of CloneWithNewOperands. 699 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 700 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 701 HloCloneContext* context) const override; 702 Literal literal_; 703 }; 704 705 class HloFusionInstruction : public HloInstruction { 706 public: 707 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 708 HloInstruction* fused_root); 709 710 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 711 absl::Span<HloInstruction* const> operands, 712 HloComputation* fusion_computation); 713 714 string ToCategory() const override; 715 // Returns a serialized representation of this instruction. 716 HloInstructionProto ToProto() const override; 717 718 // Adds a new operand the fusion instruction. 719 HloInstruction* AddFusionOperand(HloInstruction* new_operand); 720 721 // Merges the fused instructions from 'instruction_to_merge' into the 722 // fused instruction set of 'this', updating operands as necessary. 723 // 724 // Predondition: 'instruction_to_merge' must be an operand of 'this'. 725 void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); 726 727 // Merges the fused instructions from instruction_to_merge into the fused 728 // instruction set of 'this' and generates multioutput fusion instructions. 729 // All the users of instruction_to_merge will be redirected to 'this' 730 // instruction. instruction_to_merge will be removed from its parent 731 // computation. 732 void MergeFusionInstructionIntoMultiOutput( 733 HloFusionInstruction* instruction_to_merge); 734 735 // Fuses the given instruction in this fusion instruction. instruction_to_fuse 736 // is cloned and the clone is placed in the fusion 737 // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather 738 // than moved to cleanly handle the case where the instruction has a use 739 // outside the fusion instruction. Moving such an instruction into a fusion 740 // instruction would violate the single-result invariant of HLO instructions 741 // and significantly complicate code generation. FuseInstruction(HloInstruction * instruction_to_fuse)742 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { 743 return FuseInstructionInternal(instruction_to_fuse); 744 } 745 746 // Fuses the given instruction in this fusion instruction and generate 747 // multioutput fusion instruction. A clone of the instruction_to_fuse will 748 // be part of the output of fusion instructions. The users of 749 // instruction_to_fuse will be redirected to this fusion instructions. 750 // instruction_to_fuse will be removed from its parent computation. FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)751 HloInstruction* FuseInstructionIntoMultiOutput( 752 HloInstruction* instruction_to_fuse) { 753 return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); 754 } 755 756 // Returns the computation for this fused instruction. 757 HloComputation* fused_instructions_computation() const; 758 759 // Returns the root instruction of the fused expression contained within this 760 // fusion instruction. 761 HloInstruction* fused_expression_root() const; 762 763 // Returns the list of fused instructions inside this fusion instruction. The 764 // returned type is a range of HloInstruction*s. 765 const tensorflow::gtl::iterator_range<UnwrappingIterator< 766 std::list<std::unique_ptr<HloInstruction>>::const_iterator>> 767 fused_instructions() const; 768 769 const tensorflow::gtl::iterator_range< 770 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> 771 fused_instructions(); 772 773 // Gets the number of instructions inside this fusion instruction. 774 int64 fused_instruction_count() const; 775 776 // Returns the fused parameter instruction in this fusion instruction 777 // corresponding to the given parameter number. 778 HloInstruction* fused_parameter(int64 parameter_number) const; 779 780 // Returns the vector of fused parameters inside this fusion instruction. 781 const std::vector<HloInstruction*>& fused_parameters() const; 782 783 // Returns true if this instruction is a fusion instruction that generates 784 // multiple outputs. IsMultiOutputFusion()785 const bool IsMultiOutputFusion() const { 786 return fused_expression_root()->opcode() == HloOpcode::kTuple; 787 } 788 fusion_kind()789 FusionKind fusion_kind() const { return fusion_kind_; } 790 set_fusion_kind(FusionKind kind)791 void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } 792 793 // If multiple operands are the same instruction, keeps only one of them. 794 Status DeduplicateFusionOperands(); 795 796 private: 797 // Fuses the given instruction into this fusion instruction. When add_output 798 // is false (which is the default), instruction_to_fuse is cloned and the 799 // clone is placed in the fusion instruction. instruction_to_fuse is 800 // unchanged. 801 // 802 // When add_output is true, a clone of the instruction_to_fuse will be part 803 // of the output of fusion instructions. The users of instruction_to_fuse 804 // will be redirected to this fusion instructions. instruction_to_fuse will 805 // be removed from its parent computation. 806 HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, 807 bool add_output = false); 808 // Clones the given instruction_to_fuse and insert the clone into this fusion 809 // instruction. If add_output is true, a clone of instruction_to_fuse will 810 // be in the output of the this fusion instruction (part of the tuple of the 811 // fusion root). 812 HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, 813 bool add_output = false); 814 815 bool IsElementwiseImpl( 816 const absl::optional<int64>& operand_idx) const override; 817 std::vector<string> ExtraAttributesToStringImpl( 818 const HloPrintOptions& options) const override; 819 bool IdenticalSlowPath( 820 const HloInstruction& other, 821 const std::function<bool(const HloComputation*, const HloComputation*)>& 822 eq_computations) const override; 823 uint64 InnerHash() const override; 824 825 // Implementation for non-common logic of CloneWithNewOperands. 826 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 827 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 828 HloCloneContext* context) const override; 829 830 // The type of the fusion. Used by kFusion only. 831 FusionKind fusion_kind_; 832 }; 833 834 class HloRngInstruction : public HloInstruction { 835 public: 836 explicit HloRngInstruction(const Shape& shape, 837 RandomDistribution distribution, 838 absl::Span<HloInstruction* const> parameters); 839 // Returns the random distribution for this rng node. random_distribution()840 RandomDistribution random_distribution() const { return distribution_; } 841 // Returns a serialized representation of this instruction. 842 HloInstructionProto ToProto() const override; 843 844 private: 845 bool IsElementwiseImpl( 846 const absl::optional<int64>& operand_idx) const override; 847 std::vector<string> ExtraAttributesToStringImpl( 848 const HloPrintOptions& options) const override; 849 bool IdenticalSlowPath( 850 const HloInstruction& other, 851 const std::function<bool(const HloComputation*, const HloComputation*)>& 852 eq_computations) const override; 853 // Implementation for non-common logic of CloneWithNewOperands. 854 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 855 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 856 HloCloneContext* context) const override; 857 858 // The distribution requested for random number generation. 859 RandomDistribution distribution_; 860 }; 861 862 class HloParameterInstruction : public HloInstruction { 863 public: 864 explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, 865 const string& name); parameter_number()866 int64 parameter_number() const { return parameter_number_; } 867 868 // Sets and gets the whether all replicas will receive the same parameter data 869 // for each leaf buffer in data parallelism. set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)870 void set_parameter_replicated_at_leaf_buffers( 871 absl::Span<const bool> parameter_replicated_at_leaf_buffers) { 872 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 873 parameter_replicated_at_leaf_buffers.size()); 874 parameter_replicated_at_leaf_buffers_.emplace( 875 parameter_replicated_at_leaf_buffers.begin(), 876 parameter_replicated_at_leaf_buffers.end()); 877 } 878 const absl::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()879 parameter_replicated_at_leaf_buffers() const { 880 return parameter_replicated_at_leaf_buffers_; 881 } 882 883 // Returns a serialized representation of this instruction. 884 HloInstructionProto ToProto() const override; 885 886 private: 887 std::vector<string> ExtraAttributesToStringImpl( 888 const HloPrintOptions& options) const override; 889 bool IdenticalSlowPath( 890 const HloInstruction& other, 891 const std::function<bool(const HloComputation*, const HloComputation*)>& 892 eq_computations) const override; 893 string OperandsToStringWithCanonicalNameMap( 894 const HloPrintOptions& options, 895 CanonicalNameMap* canonical_name_map) const override; 896 // Implementation for non-common logic of CloneWithNewOperands. 897 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 898 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 899 HloCloneContext* context) const override; 900 901 int64 parameter_number_ = 0; 902 903 // Specifies whether each buffer has the same parameter value on all replicas 904 // in data parallelism. 905 absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_; 906 }; 907 908 class HloGetTupleElementInstruction : public HloInstruction { 909 public: 910 explicit HloGetTupleElementInstruction(const Shape& shape, 911 HloInstruction* operand, int64 index); 912 // Returns the tuple index associated with this instruction. tuple_index()913 int64 tuple_index() const { return tuple_index_; } 914 // Returns a serialized representation of this instruction. 915 HloInstructionProto ToProto() const override; 916 917 private: 918 std::vector<string> ExtraAttributesToStringImpl( 919 const HloPrintOptions& options) const override; 920 bool IdenticalSlowPath( 921 const HloInstruction& other, 922 const std::function<bool(const HloComputation*, const HloComputation*)>& 923 eq_computations) const override; 924 // Implementation for non-common logic of CloneWithNewOperands. 925 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 926 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 927 HloCloneContext* context) const override; 928 929 int64 tuple_index_ = -1; 930 }; 931 932 class HloReducePrecisionInstruction : public HloInstruction { 933 public: 934 explicit HloReducePrecisionInstruction(const Shape& shape, 935 HloInstruction* operand, 936 const int exponent_bits, 937 const int mantissa_bits); 938 // Returns the number of exponent bits for a reduce-precision node. exponent_bits()939 int32 exponent_bits() const { return exponent_bits_; } 940 // Returns the number of mantissa bits for a reduce-precision node. mantissa_bits()941 int32 mantissa_bits() const { return mantissa_bits_; } 942 // Returns a serialized representation of this instruction. 943 HloInstructionProto ToProto() const override; 944 945 private: 946 std::vector<string> ExtraAttributesToStringImpl( 947 const HloPrintOptions& options) const override; 948 bool IdenticalSlowPath( 949 const HloInstruction& other, 950 const std::function<bool(const HloComputation*, const HloComputation*)>& 951 eq_computations) const override; 952 // Implementation for non-common logic of CloneWithNewOperands. 953 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 954 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 955 HloCloneContext* context) const override; 956 957 // The bit sizes for a reduce-precision operation. 958 int32 exponent_bits_ = 0; 959 int32 mantissa_bits_ = 0; 960 }; 961 962 class HloInfeedInstruction : public HloInstruction { 963 public: 964 explicit HloInfeedInstruction(const Shape& infeed_shape, 965 HloInstruction* token_operand, 966 const string& config); 967 // Returns the infeed configuration string. The infeed configuration includes 968 // any metadata needed for the backend compiler (e.g., infeed buffer address) 969 // and is target-dependent. infeed_config()970 string infeed_config() const { return infeed_config_; } set_infeed_config(const string & config)971 void set_infeed_config(const string& config) { infeed_config_ = config; } 972 // Returns the shape of the data received by the infeed. This is not the same 973 // as the shape of the infeed instruction which produces a tuple containing 974 // the infeed data shape and a TOKEN. infeed_shape()975 const Shape& infeed_shape() const { 976 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); 977 return ShapeUtil::GetSubshape(shape(), {0}); 978 } 979 // Returns a serialized representation of this instruction. 980 HloInstructionProto ToProto() const override; 981 982 private: 983 std::vector<string> ExtraAttributesToStringImpl( 984 const HloPrintOptions& options) const override; 985 bool IdenticalSlowPath( 986 const HloInstruction& other, 987 const std::function<bool(const HloComputation*, const HloComputation*)>& 988 eq_computations) const override; 989 // Implementation for non-common logic of CloneWithNewOperands. 990 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 991 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 992 HloCloneContext* context) const override; 993 994 // The string representation of the infeed configuration. 995 string infeed_config_; 996 }; 997 998 class HloOutfeedInstruction : public HloInstruction { 999 public: 1000 explicit HloOutfeedInstruction(const Shape& outfeed_shape, 1001 HloInstruction* operand, 1002 HloInstruction* token_operand, 1003 absl::string_view outfeed_config); 1004 // Returns the shape for the Outfeed instruction. outfeed_shape()1005 const Shape& outfeed_shape() const { return outfeed_shape_; } 1006 // Returns the config for the Outfeed instruction. outfeed_config()1007 const string& outfeed_config() const { return outfeed_config_; } 1008 // Returns a serialized representation of this instruction. 1009 HloInstructionProto ToProto() const override; 1010 1011 private: 1012 std::vector<string> ExtraAttributesToStringImpl( 1013 const HloPrintOptions& options) const override; 1014 bool IdenticalSlowPath( 1015 const HloInstruction& other, 1016 const std::function<bool(const HloComputation*, const HloComputation*)>& 1017 eq_computations) const override; 1018 // Implementation for non-common logic of CloneWithNewOperands. 1019 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1020 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1021 HloCloneContext* context) const override; 1022 1023 // Shape of outfeed request. 1024 Shape outfeed_shape_; 1025 // Outfeed configuration information, only present for kOutfeed. 1026 string outfeed_config_; 1027 }; 1028 1029 class HloConvolutionInstruction : public HloInstruction { 1030 public: 1031 explicit HloConvolutionInstruction( 1032 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 1033 int64 feature_group_count, int64 batch_group_count, const Window& window, 1034 const ConvolutionDimensionNumbers& dimension_numbers, 1035 const PrecisionConfig& precision_config); window()1036 const Window& window() const override { return window_; } set_window(const Window & window)1037 void set_window(const Window& window) override { window_ = window; } convolution_dimension_numbers()1038 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1039 return convolution_dimension_numbers_; 1040 } set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1041 void set_convolution_dimension_numbers( 1042 const ConvolutionDimensionNumbers& dnums) { 1043 convolution_dimension_numbers_ = dnums; 1044 } 1045 // The number of feature groups. Must be a divisor of the input feature 1046 // dimension and output feature dimension. feature_group_count()1047 int64 feature_group_count() const { return feature_group_count_; } 1048 1049 // The number of feature groups. Must be a divisor of the input batch 1050 // dimension. batch_group_count()1051 int64 batch_group_count() const { return batch_group_count_; } 1052 1053 // Returns the information used to tell the implementation information about 1054 // what sort of precision is requested. The meaning of the field is backend 1055 // specific. At the moment, it is only supported for kConvolution and kDot. 1056 // Transformations on one kDot or kConvolution to another will preserve this 1057 // information. Transformations to other HLOs will not preserve this 1058 // information but it is presumed that the alternate lowering is strictly 1059 // superior. precision_config()1060 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1061 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1062 1063 string ToCategory() const override; 1064 // Returns a serialized representation of this instruction. 1065 HloInstructionProto ToProto() const override; 1066 1067 private: 1068 std::vector<string> ExtraAttributesToStringImpl( 1069 const HloPrintOptions& options) const override; 1070 bool IdenticalSlowPath( 1071 const HloInstruction& other, 1072 const std::function<bool(const HloComputation*, const HloComputation*)>& 1073 eq_computations) const override; 1074 // Implementation for non-common logic of CloneWithNewOperands. 1075 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1076 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1077 HloCloneContext* context) const override; 1078 // The number of feature groups. Must be a divisor of the input feature 1079 // dimension and output feature dimension. 1080 int64 feature_group_count_; 1081 // The number of feature groups. Must be a divisor of the input batch 1082 // dimension. 1083 int64 batch_group_count_; 1084 // Describes the window used for a convolution. 1085 Window window_; 1086 // Describes the dimension numbers used for a convolution. 1087 ConvolutionDimensionNumbers convolution_dimension_numbers_; 1088 // Information used to communicate to the implementation about the algorithm 1089 // used to produce results. See the documentation on precision_config(). 1090 PrecisionConfig precision_config_; 1091 }; 1092 1093 class HloReduceWindowInstruction : public HloInstruction { 1094 public: 1095 explicit HloReduceWindowInstruction(const Shape& shape, 1096 HloInstruction* operand, 1097 HloInstruction* init_value, 1098 const Window& window, 1099 HloComputation* reduce_computation); window()1100 const Window& window() const override { return window_; } set_window(const Window & window)1101 void set_window(const Window& window) override { window_ = window; } 1102 // Returns a serialized representation of this instruction. 1103 HloInstructionProto ToProto() const override; 1104 1105 private: 1106 std::vector<string> ExtraAttributesToStringImpl( 1107 const HloPrintOptions& options) const override; 1108 bool IdenticalSlowPath( 1109 const HloInstruction& other, 1110 const std::function<bool(const HloComputation*, const HloComputation*)>& 1111 eq_computations) const override; 1112 // Implementation for non-common logic of CloneWithNewOperands. 1113 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1114 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1115 HloCloneContext* context) const override; 1116 Window window_; 1117 }; 1118 1119 class HloSelectAndScatterInstruction : public HloInstruction { 1120 public: 1121 explicit HloSelectAndScatterInstruction( 1122 const Shape& shape, HloInstruction* operand, HloComputation* select, 1123 const Window& window, HloInstruction* source, HloInstruction* init_value, 1124 HloComputation* scatter); window()1125 const Window& window() const override { return window_; } set_window(const Window & window)1126 void set_window(const Window& window) override { window_ = window; } 1127 // Gets/sets the select or scatter HloComputation for SelectAndScatter. The 1128 // setters should only be called by HloModule or HloComputation methods. select()1129 HloComputation* select() const { 1130 return called_computations()[kSelectComputationIndex]; 1131 } 1132 scatter()1133 HloComputation* scatter() const { 1134 return called_computations()[kScatterComputationIndex]; 1135 } 1136 set_select(HloComputation * computation)1137 void set_select(HloComputation* computation) { 1138 // Don't allow changing the computation for fused instructions so we don't 1139 // have to recompute called_instructions for the entire fusion instruction. 1140 CHECK(!IsFused()); 1141 set_called_computation(kSelectComputationIndex, computation); 1142 } 1143 set_scatter(HloComputation * computation)1144 void set_scatter(HloComputation* computation) { 1145 // Don't allow changing the computation for fused instructions so we don't 1146 // have to recompute called_instructions for the entire fusion instruction. 1147 CHECK(!IsFused()); 1148 set_called_computation(kScatterComputationIndex, computation); 1149 } 1150 // Returns a serialized representation of this instruction. 1151 HloInstructionProto ToProto() const override; 1152 1153 private: 1154 std::vector<string> ExtraAttributesToStringImpl( 1155 const HloPrintOptions& options) const override; 1156 bool IdenticalSlowPath( 1157 const HloInstruction& other, 1158 const std::function<bool(const HloComputation*, const HloComputation*)>& 1159 eq_computations) const override; 1160 // Implementation for non-common logic of CloneWithNewOperands. 1161 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1162 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1163 HloCloneContext* context) const override; 1164 Window window_; 1165 }; 1166 1167 class HloCustomCallInstruction : public HloInstruction { 1168 public: 1169 HloCustomCallInstruction(const Shape& shape, 1170 absl::Span<HloInstruction* const> operands, 1171 absl::string_view custom_call_target, 1172 absl::string_view opaque); 1173 1174 // Constructor for a custom call with constrained layout. 'shape' and 1175 // 'operands_with_layout' must all have layouts. 1176 HloCustomCallInstruction(const Shape& shape, 1177 absl::Span<HloInstruction* const> operands, 1178 absl::string_view custom_call_target, 1179 absl::string_view opaque, 1180 absl::Span<const Shape> operand_shapes_with_layout); 1181 window()1182 const Window& window() const override { 1183 CHECK(window_ != nullptr); 1184 return *window_; 1185 } 1186 set_window(const Window & window)1187 void set_window(const Window& window) override { 1188 window_ = absl::make_unique<Window>(window); 1189 } 1190 convolution_dimension_numbers()1191 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1192 CHECK(convolution_dimension_numbers_ != nullptr); 1193 return *convolution_dimension_numbers_; 1194 } 1195 set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1196 void set_convolution_dimension_numbers( 1197 const ConvolutionDimensionNumbers& dnums) { 1198 convolution_dimension_numbers_ = 1199 absl::make_unique<ConvolutionDimensionNumbers>(dnums); 1200 } opaque()1201 const string& opaque() const { return opaque_; } custom_call_target()1202 const string& custom_call_target() const { return custom_call_target_; } set_feature_group_count(int64 feature_group_count)1203 void set_feature_group_count(int64 feature_group_count) { 1204 feature_group_count_ = feature_group_count; 1205 } set_batch_group_count(int64 batch_group_count)1206 void set_batch_group_count(int64 batch_group_count) { 1207 batch_group_count_ = batch_group_count; 1208 } feature_group_count()1209 int64 feature_group_count() const { return feature_group_count_; } batch_group_count()1210 int64 batch_group_count() const { return batch_group_count_; } 1211 // Returns a serialized representation of this instruction. 1212 HloInstructionProto ToProto() const override; 1213 1214 // Returns whether the result and operand layouts are constrained. layout_constrained()1215 bool layout_constrained() const { return layout_constrained_; } 1216 1217 // Returns the shapes (with layout) of the operands. CHECKs if this custom 1218 // call does not have constrained layouts. operand_shapes_with_layout()1219 const std::vector<Shape>& operand_shapes_with_layout() const { 1220 CHECK(layout_constrained()); 1221 return operand_shapes_with_layout_; 1222 } 1223 1224 private: 1225 std::vector<string> ExtraAttributesToStringImpl( 1226 const HloPrintOptions& options) const override; 1227 bool IdenticalSlowPath( 1228 const HloInstruction& other, 1229 const std::function<bool(const HloComputation*, const HloComputation*)>& 1230 eq_computations) const override; 1231 // Implementation for non-common logic of CloneWithNewOperands. 1232 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1233 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1234 HloCloneContext* context) const override; 1235 // Name of a global symbol to call. 1236 string custom_call_target_; 1237 // Opaque string interpreted by the backend. 1238 string opaque_; 1239 // Describes the window in a windowed operation such as convolution. 1240 std::unique_ptr<Window> window_; 1241 // Describes the dimension numbers used for a convolution. 1242 std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; 1243 // The number of feature groups. This is used for grouped convolutions. 1244 int64 feature_group_count_; 1245 int64 batch_group_count_; 1246 // Whether the result and operand layouts are constrained. 1247 bool layout_constrained_; 1248 // For layout-constrained custom calls, this vector holds the shape with 1249 // layout for each operand. 1250 std::vector<Shape> operand_shapes_with_layout_; 1251 }; 1252 1253 class HloPadInstruction : public HloInstruction { 1254 public: 1255 explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, 1256 HloInstruction* padding_value, 1257 const PaddingConfig& padding_config); 1258 // Returns the padding configuration for a pad node. padding_config()1259 const PaddingConfig& padding_config() const { return padding_config_; } 1260 // Returns the padding value. padding_value()1261 const HloInstruction* padding_value() const { return operand(1); } mutable_padding_value()1262 HloInstruction* mutable_padding_value() { return mutable_operand(1); } 1263 // Returns a serialized representation of this instruction. 1264 HloInstructionProto ToProto() const override; 1265 1266 private: 1267 std::vector<string> ExtraAttributesToStringImpl( 1268 const HloPrintOptions& options) const override; 1269 bool IdenticalSlowPath( 1270 const HloInstruction& other, 1271 const std::function<bool(const HloComputation*, const HloComputation*)>& 1272 eq_computations) const override; 1273 // Implementation for non-common logic of CloneWithNewOperands. 1274 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1275 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1276 HloCloneContext* context) const override; 1277 1278 // The padding configuration that describes the edge padding and interior 1279 // padding of this pad instruction. 1280 PaddingConfig padding_config_; 1281 }; 1282 1283 class HloDynamicIndexInstruction : public HloInstruction { 1284 public: HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1285 explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) 1286 : HloInstruction(opcode, shape) {} 1287 virtual int64 first_index_operand_number() const = 0; 1288 1289 // Returns a subspan of operands which represent the start indices. index_operands()1290 absl::Span<HloInstruction* const> index_operands() const { 1291 return absl::MakeSpan(operands()).subspan(first_index_operand_number()); 1292 } 1293 1294 // Returns the shapes of the index operands. index_shapes()1295 std::vector<Shape> index_shapes() const { 1296 std::vector<Shape> shapes; 1297 auto indices = index_operands(); 1298 for (const HloInstruction* index : indices) { 1299 shapes.push_back(index->shape()); 1300 } 1301 return shapes; 1302 } 1303 }; 1304 1305 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { 1306 public: 1307 explicit HloDynamicSliceInstruction(const Shape& shape, 1308 HloInstruction* operand, 1309 HloInstruction* start_indices, 1310 absl::Span<const int64> slice_sizes); 1311 explicit HloDynamicSliceInstruction( 1312 const Shape& shape, HloInstruction* operand, 1313 absl::Span<HloInstruction* const> start_indices, 1314 absl::Span<const int64> slice_sizes); 1315 // Old methods kept for smooth subclassing transition END. 1316 // Returns the size of the slice in the given dimension for a dynamic 1317 // slice node. slice_sizes(int64 dimension)1318 int64 slice_sizes(int64 dimension) const { 1319 return dynamic_slice_sizes_[dimension]; 1320 } dynamic_slice_sizes()1321 const std::vector<int64>& dynamic_slice_sizes() const { 1322 return dynamic_slice_sizes_; 1323 } 1324 // Returns a serialized representation of this instruction. 1325 HloInstructionProto ToProto() const override; 1326 first_index_operand_number()1327 int64 first_index_operand_number() const override { return 1; } 1328 1329 private: 1330 std::vector<string> ExtraAttributesToStringImpl( 1331 const HloPrintOptions& options) const override; 1332 bool IdenticalSlowPath( 1333 const HloInstruction& other, 1334 const std::function<bool(const HloComputation*, const HloComputation*)>& 1335 eq_computations) const override; 1336 // Implementation for non-common logic of CloneWithNewOperands. 1337 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1338 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1339 HloCloneContext* context) const override; 1340 1341 // Describes the [start, start + size) range size for a dynamic slice 1342 // ('start' is specified dynamically in the second operand of the operation). 1343 std::vector<int64> dynamic_slice_sizes_; 1344 }; 1345 1346 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { 1347 public: 1348 explicit HloDynamicUpdateSliceInstruction(const Shape& shape, 1349 HloInstruction* operand, 1350 HloInstruction* update, 1351 HloInstruction* start_indices); 1352 explicit HloDynamicUpdateSliceInstruction( 1353 const Shape& shape, HloInstruction* operand, HloInstruction* update, 1354 absl::Span<HloInstruction* const> start_indices); 1355 first_index_operand_number()1356 int64 first_index_operand_number() const override { return 2; } 1357 }; 1358 1359 class HloGatherInstruction : public HloInstruction { 1360 public: 1361 explicit HloGatherInstruction( 1362 const Shape& shape, HloInstruction* operand, 1363 HloInstruction* start_indices, 1364 const GatherDimensionNumbers& gather_dim_numbers, 1365 absl::Span<const int64> slice_sizes); gather_dimension_numbers()1366 const GatherDimensionNumbers& gather_dimension_numbers() const { 1367 CHECK(gather_dimension_numbers_ != nullptr); 1368 return *gather_dimension_numbers_; 1369 } gather_slice_sizes()1370 absl::Span<const int64> gather_slice_sizes() const { 1371 return gather_slice_sizes_; 1372 } 1373 // Returns the dump string of the gather dimension numbers. 1374 string GatherDimensionNumbersToString() const; 1375 // Returns a serialized representation of this instruction. 1376 HloInstructionProto ToProto() const override; 1377 1378 // Creates an instance of GatherDimensionNumbers. 1379 static GatherDimensionNumbers MakeGatherDimNumbers( 1380 absl::Span<const int64> offset_dims, 1381 absl::Span<const int64> collapsed_slice_dims, 1382 absl::Span<const int64> start_index_map, int64 index_vector_dim); 1383 1384 private: 1385 std::vector<string> ExtraAttributesToStringImpl( 1386 const HloPrintOptions& options) const override; 1387 bool IdenticalSlowPath( 1388 const HloInstruction& other, 1389 const std::function<bool(const HloComputation*, const HloComputation*)>& 1390 eq_computations) const override; 1391 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1392 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1393 HloCloneContext* context) const override; 1394 1395 std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; 1396 std::vector<int64> gather_slice_sizes_; 1397 }; 1398 1399 class HloScatterInstruction : public HloInstruction { 1400 public: 1401 explicit HloScatterInstruction( 1402 const Shape& shape, HloInstruction* operand, 1403 HloInstruction* scatter_indices, HloInstruction* updates, 1404 HloComputation* update_computation, 1405 const ScatterDimensionNumbers& scatter_dim_numbers); scatter_dimension_numbers()1406 const ScatterDimensionNumbers& scatter_dimension_numbers() const { 1407 CHECK(scatter_dimension_numbers_ != nullptr); 1408 return *scatter_dimension_numbers_; 1409 } 1410 // Returns the dump string of the scatter dimension numbers. 1411 string ScatterDimensionNumbersToString() const; 1412 // Returns a serialized representation of this instruction. 1413 HloInstructionProto ToProto() const override; 1414 1415 // Creates an instance of ScatterDimensionNumbers. 1416 static ScatterDimensionNumbers MakeScatterDimNumbers( 1417 absl::Span<const int64> update_window_dims, 1418 absl::Span<const int64> inserted_window_dims, 1419 absl::Span<const int64> scatter_dims_to_operand_dims, 1420 int64 index_vector_dim); 1421 1422 private: 1423 std::vector<string> ExtraAttributesToStringImpl( 1424 const HloPrintOptions& options) const override; 1425 bool IdenticalSlowPath( 1426 const HloInstruction& other, 1427 const std::function<bool(const HloComputation*, const HloComputation*)>& 1428 eq_computations) const override; 1429 // Implementation for non-common logic of CloneWithNewOperands. 1430 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1431 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1432 HloCloneContext* context) const override; 1433 1434 std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; 1435 }; 1436 1437 class HloIotaInstruction : public HloInstruction { 1438 public: 1439 explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); 1440 // Returns the dimension sizes or numbers associated with this instruction. iota_dimension()1441 int64 iota_dimension() const { return iota_dimension_; } 1442 // Returns a serialized representation of this instruction. 1443 HloInstructionProto ToProto() const override; 1444 1445 private: 1446 std::vector<string> ExtraAttributesToStringImpl( 1447 const HloPrintOptions& options) const override; 1448 bool IdenticalSlowPath( 1449 const HloInstruction& other, 1450 const std::function<bool(const HloComputation*, const HloComputation*)>& 1451 eq_computations) const override; 1452 // Implementation for non-common logic of CloneWithNewOperands. 1453 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1454 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1455 HloCloneContext* context) const override; 1456 1457 const int64 iota_dimension_; 1458 }; 1459 1460 class HloDotInstruction : public HloInstruction { 1461 public: 1462 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch 1463 // dimensions specified in 'dimension_numbers'. 1464 explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, 1465 HloInstruction* rhs, 1466 const DotDimensionNumbers& dimension_numbers, 1467 const PrecisionConfig& precision_config); 1468 1469 // Returns data on the dimension numbers used for a dot operation. dot_dimension_numbers()1470 const DotDimensionNumbers& dot_dimension_numbers() const { 1471 return dot_dimension_numbers_; 1472 } 1473 1474 // Returns the information used to tell the implementation information about 1475 // what sort of precision is requested. The meaning of the field is backend 1476 // specific. At the moment, it is only supported for kConvolution and kDot. 1477 // Transformations on one kDot or kConvolution to another will preserve this 1478 // information. Transformations to other HLOs will not preserve this 1479 // information but it is presumed that the alternate lowering is strictly 1480 // superior. precision_config()1481 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1482 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1483 1484 // Returns a serialized representation of this instruction. 1485 HloInstructionProto ToProto() const override; 1486 1487 private: 1488 std::vector<string> ExtraAttributesToStringImpl( 1489 const HloPrintOptions& options) const override; 1490 bool IdenticalSlowPath( 1491 const HloInstruction& other, 1492 const std::function<bool(const HloComputation*, const HloComputation*)>& 1493 eq_computations) const override; 1494 // Implementation for non-common logic of CloneWithNewOperands. 1495 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1496 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1497 HloCloneContext* context) const override; 1498 // Returns the dump string of the dot dimension numbers. 1499 string DotDimensionNumbersToString() const; 1500 1501 // Describes the dimension numbers used for a dot. 1502 DotDimensionNumbers dot_dimension_numbers_; 1503 1504 // Information used to communicate to the implementation about the algorithm 1505 // used to produce results. See the documentation on precision_config(). 1506 PrecisionConfig precision_config_; 1507 }; 1508 1509 class HloDomainInstruction : public HloInstruction { 1510 public: 1511 explicit HloDomainInstruction( 1512 const Shape& shape, HloInstruction* operand, 1513 std::unique_ptr<DomainMetadata> operand_side_metadata, 1514 std::unique_ptr<DomainMetadata> user_side_metadata); 1515 1516 // Returns a serialized representation of this instruction. 1517 HloInstructionProto ToProto() const override; 1518 1519 // Retrieves the operand side metadata of a kDomain instruction. operand_side_metadata()1520 const DomainMetadata& operand_side_metadata() const { 1521 return *operand_side_metadata_; 1522 } 1523 // Retrieves the user side metadata of a kDomain instruction. user_side_metadata()1524 const DomainMetadata& user_side_metadata() const { 1525 return *user_side_metadata_; 1526 } 1527 1528 private: 1529 std::vector<string> ExtraAttributesToStringImpl( 1530 const HloPrintOptions& options) const override; 1531 bool IdenticalSlowPath( 1532 const HloInstruction& other, 1533 const std::function<bool(const HloComputation*, const HloComputation*)>& 1534 eq_computations) const override; 1535 // Implementation for non-common logic of CloneWithNewOperands. 1536 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1537 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1538 HloCloneContext* context) const override; 1539 1540 std::unique_ptr<DomainMetadata> operand_side_metadata_; 1541 std::unique_ptr<DomainMetadata> user_side_metadata_; 1542 }; 1543 1544 class HloGetDimensionSizeInstruction : public HloInstruction { 1545 public: 1546 explicit HloGetDimensionSizeInstruction(const Shape& shape, 1547 HloInstruction* operand, 1548 int64 dimension); 1549 1550 // Returns the dimension sizes or numbers associated with this instruction. dimension()1551 int64 dimension() const { return dimension_; } 1552 // Returns a serialized representation of this instruction. 1553 HloInstructionProto ToProto() const override; 1554 1555 private: 1556 std::vector<string> ExtraAttributesToStringImpl( 1557 const HloPrintOptions& options) const override; 1558 bool IdenticalSlowPath( 1559 const HloInstruction& other, 1560 const std::function<bool(const HloComputation*, const HloComputation*)>& 1561 eq_computations) const override; 1562 // Implementation for non-common logic of CloneWithNewOperands. 1563 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1564 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1565 HloCloneContext* context) const override; 1566 1567 int64 dimension_; 1568 }; 1569 1570 } // namespace xla 1571 1572 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 1573