1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // All HloInstruction subclasses are put in this file. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 24 namespace xla { 25 26 class HloBatchNormInstruction : public HloInstruction { 27 public: 28 // Returns feature_index field associated with the instruction. The index 29 // represents the index of the feature dimension. feature_index()30 int64 feature_index() const { return feature_index_; } 31 32 // Returns a epsilon value associated with the instruction. The is a small 33 // number added to the variance to avoid divide-by-zero error. epsilon()34 float epsilon() const { return epsilon_; } 35 36 // Returns a serialized representation of this instruction. 37 HloInstructionProto ToProto() const override; 38 39 protected: 40 explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, 41 HloInstruction* operand, 42 HloInstruction* scale, float epsilon, 43 int64 feature_index); 44 45 private: 46 std::vector<string> ExtraAttributesToStringImpl( 47 const HloPrintOptions& options) const override; 48 bool IdenticalSlowPath( 49 const HloInstruction& other, 50 const std::function<bool(const HloComputation*, const HloComputation*)>& 51 eq_computations) const override; 52 // A small float number added to the variance to avoid divide-by-zero error. 53 float epsilon_ = 0.0f; 54 55 // An integer value representing the index of the feature dimension. 56 int64 feature_index_ = -1; 57 }; 58 59 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { 60 public: 61 explicit HloBatchNormTrainingInstruction(const Shape& shape, 62 HloInstruction* operand, 63 HloInstruction* scale, 64 HloInstruction* offset, 65 float epsilon, int64 feature_index); 66 67 private: 68 // Implementation for non-common logic of CloneWithNewOperands. 69 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 70 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 71 HloCloneContext* context) const override; 72 }; 73 74 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { 75 public: 76 explicit HloBatchNormInferenceInstruction( 77 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 78 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, 79 float epsilon, int64 feature_index); 80 81 private: 82 // Implementation for non-common logic of CloneWithNewOperands. 83 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 84 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 85 HloCloneContext* context) const override; 86 }; 87 88 class HloBatchNormGradInstruction : public HloBatchNormInstruction { 89 public: 90 explicit HloBatchNormGradInstruction( 91 const Shape& shape, HloInstruction* operand, HloInstruction* scale, 92 HloInstruction* mean, HloInstruction* variance, 93 HloInstruction* grad_output, float epsilon, int64 feature_index); 94 95 private: 96 // Implementation for non-common logic of CloneWithNewOperands. 97 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 98 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 99 HloCloneContext* context) const override; 100 }; 101 102 class HloFftInstruction : public HloInstruction { 103 public: 104 explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, 105 FftType fft_type, 106 absl::Span<const int64> fft_length); fft_type()107 FftType fft_type() const { return fft_type_; } 108 fft_length()109 const std::vector<int64>& fft_length() const { return fft_length_; } 110 111 // Returns a serialized representation of this instruction. 112 HloInstructionProto ToProto() const override; 113 114 private: 115 std::vector<string> ExtraAttributesToStringImpl( 116 const HloPrintOptions& options) const override; 117 bool IdenticalSlowPath( 118 const HloInstruction& other, 119 const std::function<bool(const HloComputation*, const HloComputation*)>& 120 eq_computations) const override; 121 122 // Implementation for non-common logic of CloneWithNewOperands. 123 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 124 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 125 HloCloneContext* context) const override; 126 127 // Describes FFT type for an FFT instruction. 128 FftType fft_type_ = FftType::FFT; 129 130 // Indicates the FFT length for an FFT instruction. 131 std::vector<int64> fft_length_; 132 }; 133 134 class HloCompareInstruction : public HloInstruction { 135 public: 136 explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, 137 HloInstruction* rhs, 138 ComparisonDirection direction); direction()139 ComparisonDirection direction() const { return direction_; } 140 HloInstructionProto ToProto() const override; 141 142 private: 143 std::vector<string> ExtraAttributesToStringImpl( 144 const HloPrintOptions& options) const override; 145 bool IdenticalSlowPath( 146 const HloInstruction& other, 147 const std::function<bool(const HloComputation*, const HloComputation*)>& 148 eq_computations) const override; 149 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 150 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 151 HloCloneContext* context) const override; 152 153 ComparisonDirection direction_; 154 }; 155 156 class HloTriangularSolveInstruction : public HloInstruction { 157 public: 158 explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, 159 HloInstruction* b, 160 const TriangularSolveOptions& options); triangular_solve_options()161 const TriangularSolveOptions& triangular_solve_options() const { 162 return triangular_solve_options_; 163 } 164 165 // Returns a serialized representation of this instruction. 166 HloInstructionProto ToProto() const override; 167 168 private: 169 std::vector<string> ExtraAttributesToStringImpl( 170 const HloPrintOptions& options) const override; 171 bool IdenticalSlowPath( 172 const HloInstruction& other, 173 const std::function<bool(const HloComputation*, const HloComputation*)>& 174 eq_computations) const override; 175 176 // Implementation for non-common logic of CloneWithNewOperands. 177 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 178 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 179 HloCloneContext* context) const override; 180 181 TriangularSolveOptions triangular_solve_options_; 182 }; 183 184 class HloCholeskyInstruction : public HloInstruction { 185 public: 186 explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, 187 const CholeskyOptions& options); cholesky_options()188 const CholeskyOptions& cholesky_options() const { return cholesky_options_; } 189 190 // Returns a serialized representation of this instruction. 191 HloInstructionProto ToProto() const override; 192 193 private: 194 std::vector<string> ExtraAttributesToStringImpl( 195 const HloPrintOptions& options) const override; 196 bool IdenticalSlowPath( 197 const HloInstruction& other, 198 const std::function<bool(const HloComputation*, const HloComputation*)>& 199 eq_computations) const override; 200 201 // Implementation for non-common logic of CloneWithNewOperands. 202 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 203 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 204 HloCloneContext* context) const override; 205 206 CholeskyOptions cholesky_options_; 207 }; 208 209 // Class that represents instructions that synchronize and transfer data between 210 // partitioned devices. Send/Recv and collective instructions (AllReduce, 211 // AllToAll, CollectivePermute) belong to this instruction type. A group of 212 // instructions (of the same opcode) with the same channel_id communicate during 213 // execution. 214 class HloChannelInstruction : public HloInstruction { 215 public: 216 // Returns the channel id associated with the instruction. The id is 217 // shared between each Send/Recv pair or a group of collective instructions 218 // and is globally unique to identify each channel. channel_id()219 absl::optional<int64> channel_id() const { return channel_id_; } 220 void set_channel_id(const absl::optional<int64>& channel_id); 221 222 protected: 223 explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, 224 const absl::optional<int64>& channel_id); 225 226 HloInstructionProto ToProto() const override; 227 228 std::vector<string> ExtraAttributesToStringImpl( 229 const HloPrintOptions& options) const override; 230 bool IdenticalSlowPath( 231 const HloInstruction& other, 232 const std::function<bool(const HloComputation*, const HloComputation*)>& 233 eq_computations) const override; 234 235 absl::optional<int64> channel_id_; 236 }; 237 238 class HloSendRecvInstruction : public HloChannelInstruction { 239 public: 240 // Returns whether this send/recv instruction sends data to/from the host. is_host_transfer()241 bool is_host_transfer() const { return is_host_transfer_; } 242 243 // Returns a serialized representation of this instruction. 244 HloInstructionProto ToProto() const override; 245 246 protected: 247 explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, 248 int64 channel_id, bool is_host_transfer); 249 250 private: 251 std::vector<string> ExtraAttributesToStringImpl( 252 const HloPrintOptions& options) const override; 253 bool IdenticalSlowPath( 254 const HloInstruction& other, 255 const std::function<bool(const HloComputation*, const HloComputation*)>& 256 eq_computations) const override; 257 // Whether this send/recv instruction sends data to/from the host. 258 bool is_host_transfer_; 259 }; 260 261 class HloSendInstruction : public HloSendRecvInstruction { 262 public: 263 explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, 264 int64 channel_id, bool is_host_transfer); 265 266 private: 267 // Implementation for non-common logic of CloneWithNewOperands. 268 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 269 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 270 HloCloneContext* context) const override; 271 }; 272 273 class HloSendDoneInstruction : public HloSendRecvInstruction { 274 public: 275 explicit HloSendDoneInstruction(HloSendInstruction* operand, 276 bool is_host_transfer); 277 278 private: 279 // Implementation for non-common logic of CloneWithNewOperands. 280 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 281 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 282 HloCloneContext* context) const override; 283 }; 284 285 class HloRecvInstruction : public HloSendRecvInstruction { 286 public: 287 explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, 288 int64 channel_id, bool is_host_transfer); 289 290 private: 291 // Implementation for non-common logic of CloneWithNewOperands. 292 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 293 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 294 HloCloneContext* context) const override; 295 }; 296 297 class HloRecvDoneInstruction : public HloSendRecvInstruction { 298 public: 299 explicit HloRecvDoneInstruction(HloRecvInstruction* operand, 300 bool is_host_transfer); 301 302 private: 303 // Implementation for non-common logic of CloneWithNewOperands. 304 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 305 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 306 HloCloneContext* context) const override; 307 }; 308 309 class HloCollectiveInstruction : public HloChannelInstruction { 310 public: replica_groups()311 const std::vector<ReplicaGroup>& replica_groups() const { 312 return replica_groups_; 313 } 314 315 protected: 316 explicit HloCollectiveInstruction( 317 HloOpcode opcode, const Shape& shape, 318 absl::Span<HloInstruction* const> operands, 319 const std::vector<ReplicaGroup>& replica_groups, 320 const absl::optional<int64>& channel_id); 321 322 HloInstructionProto ToProto() const override; 323 324 std::vector<string> ExtraAttributesToStringImpl( 325 const HloPrintOptions& options) const override; 326 bool IdenticalSlowPath( 327 const HloInstruction& other, 328 const std::function<bool(const HloComputation*, const HloComputation*)>& 329 eq_computations) const override; 330 331 std::vector<ReplicaGroup> replica_groups_; 332 }; 333 334 class HloAllReduceInstruction : public HloCollectiveInstruction { 335 public: 336 explicit HloAllReduceInstruction( 337 const Shape& shape, absl::Span<HloInstruction* const> operands, 338 HloComputation* reduce_computation, 339 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, 340 const absl::optional<int64>& channel_id); 341 342 // Returns true if the AllReduce does no communication, so it's equivalent 343 // to a mem copy. 344 bool IsNoop() const; 345 346 // Returns true if the layout of the AllReduce is enforced by XLA client (as 347 // the layout set in the shape). The only reason for the client to set the 348 // layout is to separately compile computations that communicate with 349 // AllReduce. Since this field is only set `true` by the client, the compiler 350 // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set 351 // `false` for all other cases. 352 // 353 // When this is `true`, there may be communication endpoints outside the 354 // current compilation unit, so the compiler considers this AllReduce as 355 // side-effecting to disable compiler transformations. The compiler is free to 356 // transform unconstrained AllReduces differently across compilation units. 357 // It is an error for an HloModule to have a mix of constrained and 358 // unconstrained AllReduce instructions (checked by HloVerifier). constrain_layout()359 bool constrain_layout() const { return constrain_layout_; } 360 361 protected: 362 std::vector<string> ExtraAttributesToStringImpl( 363 const HloPrintOptions& options) const override; 364 HloInstructionProto ToProto() const override; 365 366 private: 367 bool IdenticalSlowPath( 368 const HloInstruction& other, 369 const std::function<bool(const HloComputation*, const HloComputation*)>& 370 eq_computations) const override; 371 372 // Implementation for non-common logic of CloneWithNewOperands. 373 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 374 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 375 HloCloneContext* context) const override; 376 377 bool constrain_layout_; 378 }; 379 380 class HloAllToAllInstruction : public HloCollectiveInstruction { 381 public: 382 explicit HloAllToAllInstruction( 383 const Shape& shape, absl::Span<HloInstruction* const> operands, 384 const std::vector<ReplicaGroup>& replica_groups, 385 const absl::optional<int64>& channel_id, 386 const absl::optional<int64>& split_dimension); 387 388 // AllToAll can optionally take a split dimension, which means that this 389 // AllToAll takes a single (flattened) array operand and produces an array 390 // output (instead of taking a list of operands and producing a tuple). 391 // 392 // split_dimension specifies which dimension in the operand is split across 393 // devices in each replica_group, and also means the concatenated dimension 394 // on the output (i.e., input and the output shapes are the same). split_dimension()395 absl::optional<int64> split_dimension() const { return split_dimension_; } 396 397 protected: 398 std::vector<string> ExtraAttributesToStringImpl( 399 const HloPrintOptions& options) const override; 400 HloInstructionProto ToProto() const override; 401 402 private: 403 bool IdenticalSlowPath( 404 const HloInstruction& other, 405 const std::function<bool(const HloComputation*, const HloComputation*)>& 406 eq_computations) const override; 407 408 // Implementation for non-common logic of CloneWithNewOperands. 409 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 410 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 411 HloCloneContext* context) const override; 412 413 absl::optional<int64> split_dimension_; 414 }; 415 416 class HloCollectivePermuteInstruction : public HloChannelInstruction { 417 public: 418 explicit HloCollectivePermuteInstruction( 419 const Shape& shape, HloInstruction* operand, 420 const std::vector<std::pair<int64, int64>>& source_target_pairs, 421 const absl::optional<int64>& channel_id); 422 source_target_pairs()423 const std::vector<std::pair<int64, int64>>& source_target_pairs() const { 424 return source_target_pairs_; 425 } 426 427 // Returns a serialized representation of this instruction. 428 HloInstructionProto ToProto() const override; 429 430 private: 431 std::vector<string> ExtraAttributesToStringImpl( 432 const HloPrintOptions& options) const override; 433 bool IdenticalSlowPath( 434 const HloInstruction& other, 435 const std::function<bool(const HloComputation*, const HloComputation*)>& 436 eq_computations) const override; 437 438 // Implementation for non-common logic of CloneWithNewOperands. 439 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 440 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 441 HloCloneContext* context) const override; 442 443 const std::vector<std::pair<int64, int64>> source_target_pairs_; 444 }; 445 446 class HloReverseInstruction : public HloInstruction { 447 public: 448 explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, 449 absl::Span<const int64> dimensions); 450 // Returns the dimension sizes or numbers associated with this instruction. dimensions()451 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)452 int64 dimensions(int64 index) const override { return dimensions()[index]; } 453 // Returns a serialized representation of this instruction. 454 HloInstructionProto ToProto() const override; 455 456 private: 457 std::vector<string> ExtraAttributesToStringImpl( 458 const HloPrintOptions& options) const override; 459 bool IdenticalSlowPath( 460 const HloInstruction& other, 461 const std::function<bool(const HloComputation*, const HloComputation*)>& 462 eq_computations) const override; 463 // Implementation for non-common logic of CloneWithNewOperands. 464 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 465 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 466 HloCloneContext* context) const override; 467 468 std::vector<int64> dimensions_; 469 }; 470 471 class HloConcatenateInstruction : public HloInstruction { 472 public: 473 explicit HloConcatenateInstruction(const Shape& shape, 474 absl::Span<HloInstruction* const> operands, 475 int64 dimension); 476 // Returns the dimension sizes or numbers associated with this instruction. dimensions()477 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)478 int64 dimensions(int64 index) const override { return dimensions()[index]; } 479 // Accessor for the dimension in which a concatenate HLO should occur. concatenate_dimension()480 int64 concatenate_dimension() const { return dimensions(0); } 481 // Returns a serialized representation of this instruction. 482 HloInstructionProto ToProto() const override; 483 484 private: 485 std::vector<string> ExtraAttributesToStringImpl( 486 const HloPrintOptions& options) const override; 487 bool IdenticalSlowPath( 488 const HloInstruction& other, 489 const std::function<bool(const HloComputation*, const HloComputation*)>& 490 eq_computations) const override; 491 // Implementation for non-common logic of CloneWithNewOperands. 492 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 493 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 494 HloCloneContext* context) const override; 495 496 std::vector<int64> dimensions_; 497 }; 498 499 class HloReduceInstruction : public HloInstruction { 500 public: 501 explicit HloReduceInstruction(const Shape& shape, 502 absl::Span<HloInstruction* const> args, 503 absl::Span<const int64> dimensions_to_reduce, 504 HloComputation* reduce_computation); 505 // Returns the dimension sizes or numbers associated with this instruction. dimensions()506 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)507 int64 dimensions(int64 index) const override { return dimensions()[index]; } 508 // Returns a serialized representation of this instruction. 509 HloInstructionProto ToProto() const override; 510 511 // Returns the number of input arrays (and, consequentially, the number of 512 // init values) this reduce has. input_count()513 int64 input_count() const { return operand_count() / 2; } 514 515 // Returns the input tensors to be reduced. inputs()516 absl::Span<HloInstruction* const> inputs() const { 517 return absl::MakeSpan(operands()).subspan(0, input_count()); 518 } 519 520 // Returns the init values of the reduction. init_values()521 absl::Span<HloInstruction* const> init_values() const { 522 return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); 523 } 524 525 private: 526 std::vector<string> ExtraAttributesToStringImpl( 527 const HloPrintOptions& options) const override; 528 bool IdenticalSlowPath( 529 const HloInstruction& other, 530 const std::function<bool(const HloComputation*, const HloComputation*)>& 531 eq_computations) const override; 532 // Implementation for non-common logic of CloneWithNewOperands. 533 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 534 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 535 HloCloneContext* context) const override; 536 537 std::vector<int64> dimensions_; 538 }; 539 540 class HloSortInstruction : public HloInstruction { 541 public: 542 explicit HloSortInstruction(const Shape& shape, int64 dimension, 543 absl::Span<HloInstruction* const> operands, 544 HloComputation* compare, bool is_stable); 545 // Returns the dimension sizes or numbers associated with this instruction. dimensions()546 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)547 int64 dimensions(int64 index) const override { return dimensions()[index]; } 548 // Returns the sort dimension for this instruction sort_dimension()549 int64 sort_dimension() const { return dimensions(0); } 550 // Returns a serialized representation of this instruction. 551 HloInstructionProto ToProto() const override; 552 // Returns the key operand to this instruction. keys()553 const HloInstruction* keys() const { return operand(0); } mutable_keys()554 HloInstruction* mutable_keys() { return mutable_operand(0); } 555 // Returns the number of value operands. values_count()556 int64 values_count() const { return operand_count() - 1; } is_stable()557 bool is_stable() const { return is_stable_; } 558 559 private: 560 std::vector<string> ExtraAttributesToStringImpl( 561 const HloPrintOptions& options) const override; 562 bool IdenticalSlowPath( 563 const HloInstruction& other, 564 const std::function<bool(const HloComputation*, const HloComputation*)>& 565 eq_computations) const override; 566 // Implementation for non-common logic of CloneWithNewOperands. 567 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 568 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 569 HloCloneContext* context) const override; 570 571 std::vector<int64> dimensions_; 572 bool is_stable_; 573 }; 574 575 class HloTransposeInstruction : public HloInstruction { 576 public: 577 explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, 578 absl::Span<const int64> dimensions); 579 // Returns the dimension sizes or numbers associated with this instruction. dimensions()580 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)581 int64 dimensions(int64 index) const override { return dimensions()[index]; } 582 // Returns whether this instruction does a rank-2 transposition. 583 bool IsRank2Transpose() const; 584 // Returns a serialized representation of this instruction. 585 HloInstructionProto ToProto() const override; 586 587 private: 588 std::vector<string> ExtraAttributesToStringImpl( 589 const HloPrintOptions& options) const override; 590 bool IdenticalSlowPath( 591 const HloInstruction& other, 592 const std::function<bool(const HloComputation*, const HloComputation*)>& 593 eq_computations) const override; 594 // Implementation for non-common logic of CloneWithNewOperands. 595 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 596 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 597 HloCloneContext* context) const override; 598 599 std::vector<int64> dimensions_; 600 }; 601 602 class HloBroadcastInstruction : public HloInstruction { 603 public: 604 explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, 605 absl::Span<const int64> broadcast_dimension); 606 // Returns the dimension sizes or numbers associated with this instruction. dimensions()607 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)608 int64 dimensions(int64 index) const override { return dimensions()[index]; } 609 // Returns a serialized representation of this instruction. 610 HloInstructionProto ToProto() const override; 611 612 private: 613 std::vector<string> ExtraAttributesToStringImpl( 614 const HloPrintOptions& options) const override; 615 bool IdenticalSlowPath( 616 const HloInstruction& other, 617 const std::function<bool(const HloComputation*, const HloComputation*)>& 618 eq_computations) const override; 619 // Implementation for non-common logic of CloneWithNewOperands. 620 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 621 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 622 HloCloneContext* context) const override; 623 624 std::vector<int64> dimensions_; 625 }; 626 627 class HloReshapeInstruction : public HloInstruction { 628 public: 629 explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand, 630 int64 inferred_dimension); inferred_dimension()631 int64 inferred_dimension() const { return inferred_dimension_; } 632 HloInstructionProto ToProto() const override; 633 634 private: 635 std::vector<string> ExtraAttributesToStringImpl( 636 const HloPrintOptions& options) const override; 637 bool IdenticalSlowPath( 638 const HloInstruction& other, 639 const std::function<bool(const HloComputation*, const HloComputation*)>& 640 eq_computations) const override; 641 // Implementation for non-common logic of CloneWithNewOperands. 642 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 643 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 644 HloCloneContext* context) const override; 645 int64 inferred_dimension_; 646 }; 647 648 class HloMapInstruction : public HloInstruction { 649 public: 650 explicit HloMapInstruction(const Shape& shape, 651 absl::Span<HloInstruction* const> operands, 652 HloComputation* map_computation); 653 // Returns the dimension sizes or numbers associated with this instruction. dimensions()654 const std::vector<int64>& dimensions() const override { return dimensions_; } dimensions(int64 index)655 int64 dimensions(int64 index) const override { return dimensions()[index]; } 656 // Returns a serialized representation of this instruction. 657 HloInstructionProto ToProto() const override; 658 659 private: 660 bool IsElementwiseImpl( 661 const absl::optional<int64>& operand_idx) const override; 662 std::vector<string> ExtraAttributesToStringImpl( 663 const HloPrintOptions& options) const override; 664 bool IdenticalSlowPath( 665 const HloInstruction& other, 666 const std::function<bool(const HloComputation*, const HloComputation*)>& 667 eq_computations) const override; 668 // Implementation for non-common logic of CloneWithNewOperands. 669 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 670 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 671 HloCloneContext* context) const override; 672 673 std::vector<int64> dimensions_; 674 }; 675 676 class HloSliceInstruction : public HloInstruction { 677 public: 678 explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, 679 absl::Span<const int64> start_indices, 680 absl::Span<const int64> limit_indices, 681 absl::Span<const int64> strides); 682 683 HloInstructionProto ToProto() const override; 684 685 // Returns the start index in the given dimension for a slice node. slice_starts(int64 dimension)686 int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } slice_starts()687 const std::vector<int64>& slice_starts() const { return slice_starts_; } 688 689 // Returns the (exclusive) limit index in the given dimension for a slice 690 // node. slice_limits(int64 dimension)691 int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } slice_limits()692 const std::vector<int64>& slice_limits() const { return slice_limits_; } 693 694 // Returns the stride in the given dimension for a slice node. slice_strides(int64 dimension)695 int64 slice_strides(int64 dimension) const { 696 return slice_strides_[dimension]; 697 } slice_strides()698 const std::vector<int64>& slice_strides() const { return slice_strides_; } 699 700 private: 701 std::vector<string> ExtraAttributesToStringImpl( 702 const HloPrintOptions& options) const override; 703 bool IdenticalSlowPath( 704 const HloInstruction& other, 705 const std::function<bool(const HloComputation*, const HloComputation*)>& 706 eq_computations) const override; 707 // Implementation for non-common logic of CloneWithNewOperands. 708 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 709 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 710 HloCloneContext* context) const override; 711 712 // Describes the [begin, end) index range for a slice. 713 std::vector<int64> slice_starts_; 714 std::vector<int64> slice_limits_; 715 std::vector<int64> slice_strides_; 716 }; 717 718 class HloConstantInstruction : public HloInstruction { 719 public: 720 explicit HloConstantInstruction(Literal literal); 721 explicit HloConstantInstruction(Literal literal, const Shape& shape); 722 // Used when the literal is too large and dropped. 723 explicit HloConstantInstruction(const Shape& shape); 724 // Returns the literal associated with this instruction. literal()725 const Literal& literal() const { return *literal_; } 726 // Returns whether there is literal associated with this instruction. HasLiteral()727 bool HasLiteral() const { return literal_.has_value(); } 728 // Returns a serialized representation of this instruction. 729 HloInstructionProto ToProto() const override; 730 731 // Change the layout for an Constant Hlo instruction to match new_layout. For 732 // tuple shaped constants shape_index is the path to the internal array 733 // subshape whose layout needs to be changed. 734 void RelayoutConstant(const Layout& new_layout, 735 const ShapeIndex& shape_index = {}); 736 737 private: 738 bool IsElementwiseImpl( 739 const absl::optional<int64>& operand_idx) const override; 740 bool IdenticalSlowPath( 741 const HloInstruction& other, 742 const std::function<bool(const HloComputation*, const HloComputation*)>& 743 eq_computations) const override; 744 string OperandsToStringWithCanonicalNameMap( 745 const HloPrintOptions& options, 746 CanonicalNameMap* canonical_name_map) const override; 747 // Implementation for non-common logic of CloneWithNewOperands. 748 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 749 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 750 HloCloneContext* context) const override; 751 absl::optional<Literal> literal_; 752 }; 753 754 class HloTraceInstruction : public HloInstruction { 755 public: 756 explicit HloTraceInstruction(const string& tag, HloInstruction* operand); 757 // Returns a tag to be used in tracing. TracingTag()758 string TracingTag() const { return literal_.GetR1U8AsString(); } 759 // Returns a serialized representation of this instruction. 760 HloInstructionProto ToProto() const override; 761 762 private: 763 bool IdenticalSlowPath( 764 const HloInstruction& other, 765 const std::function<bool(const HloComputation*, const HloComputation*)>& 766 eq_computations) const override; 767 // Implementation for non-common logic of CloneWithNewOperands. 768 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 769 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 770 HloCloneContext* context) const override; 771 Literal literal_; 772 }; 773 774 class HloFusionInstruction : public HloInstruction { 775 public: 776 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 777 HloInstruction* fused_root); 778 779 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, 780 absl::Span<HloInstruction* const> operands, 781 HloComputation* fusion_computation); 782 783 string ToCategory() const override; 784 // Returns a serialized representation of this instruction. 785 HloInstructionProto ToProto() const override; 786 787 // Adds a new operand the fusion instruction. 788 HloInstruction* AddFusionOperand(HloInstruction* new_operand); 789 790 // Merges the fused instructions from 'instruction_to_merge' into the 791 // fused instruction set of 'this', updating operands as necessary. 792 // 793 // Precondition: 'instruction_to_merge' must be an operand of 'this'. 794 void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); 795 796 // Merges the fused instructions from instruction_to_merge into the fused 797 // instruction set of 'this' and generates multioutput fusion instructions. 798 // All the users of instruction_to_merge will be redirected to 'this' 799 // instruction. instruction_to_merge will be removed from its parent 800 // computation. 801 void MergeFusionInstructionIntoMultiOutput( 802 HloFusionInstruction* instruction_to_merge); 803 804 // Fuses the given instruction in this fusion instruction. instruction_to_fuse 805 // is cloned and the clone is placed in the fusion 806 // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather 807 // than moved to cleanly handle the case where the instruction has a use 808 // outside the fusion instruction. Moving such an instruction into a fusion 809 // instruction would violate the single-result invariant of HLO instructions 810 // and significantly complicate code generation. FuseInstruction(HloInstruction * instruction_to_fuse)811 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { 812 return FuseInstructionInternal(instruction_to_fuse); 813 } 814 815 // Fuses the given instruction in this fusion instruction and generates a 816 // multioutput fusion instruction. A clone of the instruction_to_fuse will 817 // be part of the output of fusion instructions. The users of 818 // instruction_to_fuse will be redirected to this fusion instructions. 819 // instruction_to_fuse is unchanged otherwise. FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)820 HloInstruction* FuseInstructionIntoMultiOutput( 821 HloInstruction* instruction_to_fuse) { 822 return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); 823 } 824 825 // Returns the computation for this fused instruction. 826 HloComputation* fused_instructions_computation() const; 827 828 // Returns the root instruction of the fused expression contained within this 829 // fusion instruction. 830 HloInstruction* fused_expression_root() const; 831 832 // Returns the list of fused instructions inside this fusion instruction. The 833 // returned type is a range of HloInstruction*s. 834 const tensorflow::gtl::iterator_range<UnwrappingIterator< 835 std::list<std::unique_ptr<HloInstruction>>::const_iterator>> 836 fused_instructions() const; 837 838 const tensorflow::gtl::iterator_range< 839 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> 840 fused_instructions(); 841 842 // Gets the number of instructions inside this fusion instruction. 843 int64 fused_instruction_count() const; 844 845 // Returns the fused parameter instruction in this fusion instruction 846 // corresponding to the given parameter number. 847 HloInstruction* fused_parameter(int64 parameter_number) const; 848 849 // Returns the vector of fused parameters inside this fusion instruction. 850 const std::vector<HloInstruction*>& fused_parameters() const; 851 852 // Returns true if this instruction is a fusion instruction that generates 853 // multiple outputs. IsMultiOutputFusion()854 const bool IsMultiOutputFusion() const { 855 return fused_expression_root()->opcode() == HloOpcode::kTuple; 856 } 857 fusion_kind()858 FusionKind fusion_kind() const { return fusion_kind_; } 859 set_fusion_kind(FusionKind kind)860 void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } 861 862 // If multiple operands are the same instruction, keeps only one of them. 863 Status DeduplicateFusionOperands(); 864 865 private: 866 // Fuses the given instruction into this fusion instruction. 867 // instruction_to_fuse is cloned and the clone is placed in the fusion 868 // instruction. The users of instruction_to_fuse will be redirected to this 869 // fusion instruction. instruction_to_fuse is unchanged otherwise. When 870 // add_output is true, a clone of the instruction_to_fuse will be added as 871 // additional output resulting in a multi-output fusion. 872 HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, 873 bool add_output = false); 874 // Clones the given instruction_to_fuse and insert the clone into this fusion 875 // instruction. If add_output is true, a clone of instruction_to_fuse will 876 // be in the output of the this fusion instruction (part of the tuple of the 877 // fusion root). 878 HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, 879 bool add_output = false); 880 881 bool IsElementwiseImpl( 882 const absl::optional<int64>& operand_idx) const override; 883 std::vector<string> ExtraAttributesToStringImpl( 884 const HloPrintOptions& options) const override; 885 bool IdenticalSlowPath( 886 const HloInstruction& other, 887 const std::function<bool(const HloComputation*, const HloComputation*)>& 888 eq_computations) const override; 889 uint64 InnerHash() const override; 890 891 // Implementation for non-common logic of CloneWithNewOperands. 892 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 893 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 894 HloCloneContext* context) const override; 895 896 // The type of the fusion. Used by kFusion only. 897 FusionKind fusion_kind_; 898 }; 899 900 class HloRngInstruction : public HloInstruction { 901 public: 902 explicit HloRngInstruction(const Shape& shape, 903 RandomDistribution distribution, 904 absl::Span<HloInstruction* const> parameters); 905 // Returns the random distribution for this rng node. random_distribution()906 RandomDistribution random_distribution() const { return distribution_; } 907 // Returns a serialized representation of this instruction. 908 HloInstructionProto ToProto() const override; 909 910 private: 911 bool IsElementwiseImpl( 912 const absl::optional<int64>& operand_idx) const override; 913 std::vector<string> ExtraAttributesToStringImpl( 914 const HloPrintOptions& options) const override; 915 bool IdenticalSlowPath( 916 const HloInstruction& other, 917 const std::function<bool(const HloComputation*, const HloComputation*)>& 918 eq_computations) const override; 919 // Implementation for non-common logic of CloneWithNewOperands. 920 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 921 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 922 HloCloneContext* context) const override; 923 924 // The distribution requested for random number generation. 925 RandomDistribution distribution_; 926 }; 927 928 class HloParameterInstruction : public HloInstruction { 929 public: 930 explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, 931 const string& name); parameter_number()932 int64 parameter_number() const { return parameter_number_; } 933 934 // Sets and gets the whether all replicas will receive the same parameter data 935 // for each leaf buffer in data parallelism. set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)936 void set_parameter_replicated_at_leaf_buffers( 937 absl::Span<const bool> parameter_replicated_at_leaf_buffers) { 938 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 939 parameter_replicated_at_leaf_buffers.size()); 940 parameter_replicated_at_leaf_buffers_.emplace( 941 parameter_replicated_at_leaf_buffers.begin(), 942 parameter_replicated_at_leaf_buffers.end()); 943 } set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)944 void set_parameter_replicated_at_leaf_buffers( 945 const std::vector<bool>& parameter_replicated_at_leaf_buffers) { 946 CHECK_EQ(ShapeUtil::GetLeafCount(shape()), 947 parameter_replicated_at_leaf_buffers.size()); 948 parameter_replicated_at_leaf_buffers_ = 949 parameter_replicated_at_leaf_buffers; 950 } 951 const absl::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()952 parameter_replicated_at_leaf_buffers() const { 953 return parameter_replicated_at_leaf_buffers_; 954 } 955 956 // Returns a serialized representation of this instruction. 957 HloInstructionProto ToProto() const override; 958 959 private: 960 std::vector<string> ExtraAttributesToStringImpl( 961 const HloPrintOptions& options) const override; 962 bool IdenticalSlowPath( 963 const HloInstruction& other, 964 const std::function<bool(const HloComputation*, const HloComputation*)>& 965 eq_computations) const override; 966 string OperandsToStringWithCanonicalNameMap( 967 const HloPrintOptions& options, 968 CanonicalNameMap* canonical_name_map) const override; 969 // Implementation for non-common logic of CloneWithNewOperands. 970 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 971 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 972 HloCloneContext* context) const override; 973 974 int64 parameter_number_ = 0; 975 976 // Specifies whether each buffer has the same parameter value on all replicas 977 // in data parallelism. 978 absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_; 979 }; 980 981 class HloGetTupleElementInstruction : public HloInstruction { 982 public: 983 explicit HloGetTupleElementInstruction(const Shape& shape, 984 HloInstruction* operand, int64 index); 985 // Returns the tuple index associated with this instruction. tuple_index()986 int64 tuple_index() const { return tuple_index_; } 987 // Sets the tuple index associated with this instruction. set_tuple_index(int64 new_tuple_index)988 void set_tuple_index(int64 new_tuple_index) { 989 tuple_index_ = new_tuple_index; 990 } 991 // Returns a serialized representation of this instruction. 992 HloInstructionProto ToProto() const override; 993 994 private: 995 std::vector<string> ExtraAttributesToStringImpl( 996 const HloPrintOptions& options) const override; 997 bool IdenticalSlowPath( 998 const HloInstruction& other, 999 const std::function<bool(const HloComputation*, const HloComputation*)>& 1000 eq_computations) const override; 1001 // Implementation for non-common logic of CloneWithNewOperands. 1002 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1003 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1004 HloCloneContext* context) const override; 1005 1006 int64 tuple_index_ = -1; 1007 }; 1008 1009 class HloReducePrecisionInstruction : public HloInstruction { 1010 public: 1011 explicit HloReducePrecisionInstruction(const Shape& shape, 1012 HloInstruction* operand, 1013 const int exponent_bits, 1014 const int mantissa_bits); 1015 // Returns the number of exponent bits for a reduce-precision node. exponent_bits()1016 int32 exponent_bits() const { return exponent_bits_; } 1017 // Returns the number of mantissa bits for a reduce-precision node. mantissa_bits()1018 int32 mantissa_bits() const { return mantissa_bits_; } 1019 // Returns a serialized representation of this instruction. 1020 HloInstructionProto ToProto() const override; 1021 1022 private: 1023 std::vector<string> ExtraAttributesToStringImpl( 1024 const HloPrintOptions& options) const override; 1025 bool IdenticalSlowPath( 1026 const HloInstruction& other, 1027 const std::function<bool(const HloComputation*, const HloComputation*)>& 1028 eq_computations) const override; 1029 // Implementation for non-common logic of CloneWithNewOperands. 1030 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1031 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1032 HloCloneContext* context) const override; 1033 1034 // The bit sizes for a reduce-precision operation. 1035 int32 exponent_bits_ = 0; 1036 int32 mantissa_bits_ = 0; 1037 }; 1038 1039 class HloInfeedInstruction : public HloInstruction { 1040 public: 1041 explicit HloInfeedInstruction(const Shape& infeed_shape, 1042 HloInstruction* token_operand, 1043 const string& config); 1044 // Returns the infeed configuration string. The infeed configuration includes 1045 // any metadata needed for the backend compiler (e.g., infeed buffer address) 1046 // and is target-dependent. infeed_config()1047 string infeed_config() const { return infeed_config_; } set_infeed_config(const string & config)1048 void set_infeed_config(const string& config) { infeed_config_ = config; } 1049 // Returns the shape of the data received by the infeed. This is not the same 1050 // as the shape of the infeed instruction which produces a tuple containing 1051 // the infeed data shape and a TOKEN. infeed_shape()1052 const Shape& infeed_shape() const { 1053 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); 1054 return ShapeUtil::GetSubshape(shape(), {0}); 1055 } 1056 // Returns a serialized representation of this instruction. 1057 HloInstructionProto ToProto() const override; 1058 1059 private: 1060 std::vector<string> ExtraAttributesToStringImpl( 1061 const HloPrintOptions& options) const override; 1062 bool IdenticalSlowPath( 1063 const HloInstruction& other, 1064 const std::function<bool(const HloComputation*, const HloComputation*)>& 1065 eq_computations) const override; 1066 // Implementation for non-common logic of CloneWithNewOperands. 1067 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1068 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1069 HloCloneContext* context) const override; 1070 1071 // The string representation of the infeed configuration. 1072 string infeed_config_; 1073 }; 1074 1075 class HloOutfeedInstruction : public HloInstruction { 1076 public: 1077 explicit HloOutfeedInstruction(const Shape& outfeed_shape, 1078 HloInstruction* operand, 1079 HloInstruction* token_operand, 1080 absl::string_view outfeed_config); 1081 // Returns the shape for the Outfeed instruction. outfeed_shape()1082 const Shape& outfeed_shape() const { return outfeed_shape_; } 1083 // Returns the config for the Outfeed instruction. outfeed_config()1084 const string& outfeed_config() const { return outfeed_config_; } 1085 // Returns a serialized representation of this instruction. 1086 HloInstructionProto ToProto() const override; 1087 1088 private: 1089 std::vector<string> ExtraAttributesToStringImpl( 1090 const HloPrintOptions& options) const override; 1091 bool IdenticalSlowPath( 1092 const HloInstruction& other, 1093 const std::function<bool(const HloComputation*, const HloComputation*)>& 1094 eq_computations) const override; 1095 // Implementation for non-common logic of CloneWithNewOperands. 1096 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1097 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1098 HloCloneContext* context) const override; 1099 1100 // Shape of outfeed request. 1101 Shape outfeed_shape_; 1102 // Outfeed configuration information, only present for kOutfeed. 1103 string outfeed_config_; 1104 }; 1105 1106 class HloConvolutionInstruction : public HloInstruction { 1107 public: 1108 explicit HloConvolutionInstruction( 1109 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, 1110 int64 feature_group_count, int64 batch_group_count, const Window& window, 1111 const ConvolutionDimensionNumbers& dimension_numbers, 1112 const PrecisionConfig& precision_config); window()1113 const Window& window() const override { return window_; } set_window(const Window & window)1114 void set_window(const Window& window) override { window_ = window; } convolution_dimension_numbers()1115 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1116 return convolution_dimension_numbers_; 1117 } set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1118 void set_convolution_dimension_numbers( 1119 const ConvolutionDimensionNumbers& dnums) { 1120 convolution_dimension_numbers_ = dnums; 1121 } 1122 // The number of feature groups. Must be a divisor of the input feature 1123 // dimension and output feature dimension. feature_group_count()1124 int64 feature_group_count() const { return feature_group_count_; } set_feature_group_count(int64 num_feature_groups)1125 void set_feature_group_count(int64 num_feature_groups) { 1126 feature_group_count_ = num_feature_groups; 1127 } 1128 // The number of batch groups. Must be a divisor of the input batch dimension. batch_group_count()1129 int64 batch_group_count() const { return batch_group_count_; } set_batch_group_count(int64 num_batch_groups)1130 void set_batch_group_count(int64 num_batch_groups) { 1131 batch_group_count_ = num_batch_groups; 1132 } 1133 1134 // Returns the information used to tell the implementation information about 1135 // what sort of precision is requested. The meaning of the field is backend 1136 // specific. At the moment, it is only supported for kConvolution and kDot. 1137 // Transformations on one kDot or kConvolution to another will preserve this 1138 // information. Transformations to other HLOs will not preserve this 1139 // information but it is presumed that the alternate lowering is strictly 1140 // superior. precision_config()1141 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1142 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1143 1144 string ToCategory() const override; 1145 // Returns a serialized representation of this instruction. 1146 HloInstructionProto ToProto() const override; 1147 1148 private: 1149 std::vector<string> ExtraAttributesToStringImpl( 1150 const HloPrintOptions& options) const override; 1151 bool IdenticalSlowPath( 1152 const HloInstruction& other, 1153 const std::function<bool(const HloComputation*, const HloComputation*)>& 1154 eq_computations) const override; 1155 // Implementation for non-common logic of CloneWithNewOperands. 1156 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1157 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1158 HloCloneContext* context) const override; 1159 // The number of feature groups. Must be a divisor of the input feature 1160 // dimension and output feature dimension. 1161 int64 feature_group_count_; 1162 // The number of batch groups. Must be a divisor of the input batch dimension. 1163 int64 batch_group_count_; 1164 // Describes the window used for a convolution. 1165 Window window_; 1166 // Describes the dimension numbers used for a convolution. 1167 ConvolutionDimensionNumbers convolution_dimension_numbers_; 1168 // Information used to communicate to the implementation about the algorithm 1169 // used to produce results. See the documentation on precision_config(). 1170 PrecisionConfig precision_config_; 1171 }; 1172 1173 class HloReduceWindowInstruction : public HloInstruction { 1174 public: 1175 explicit HloReduceWindowInstruction(const Shape& shape, 1176 HloInstruction* operand, 1177 HloInstruction* init_value, 1178 const Window& window, 1179 HloComputation* reduce_computation); window()1180 const Window& window() const override { return window_; } set_window(const Window & window)1181 void set_window(const Window& window) override { window_ = window; } 1182 // Returns a serialized representation of this instruction. 1183 HloInstructionProto ToProto() const override; 1184 1185 private: 1186 std::vector<string> ExtraAttributesToStringImpl( 1187 const HloPrintOptions& options) const override; 1188 bool IdenticalSlowPath( 1189 const HloInstruction& other, 1190 const std::function<bool(const HloComputation*, const HloComputation*)>& 1191 eq_computations) const override; 1192 // Implementation for non-common logic of CloneWithNewOperands. 1193 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1194 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1195 HloCloneContext* context) const override; 1196 Window window_; 1197 }; 1198 1199 class HloSelectAndScatterInstruction : public HloInstruction { 1200 public: 1201 explicit HloSelectAndScatterInstruction( 1202 const Shape& shape, HloInstruction* operand, HloComputation* select, 1203 const Window& window, HloInstruction* source, HloInstruction* init_value, 1204 HloComputation* scatter); window()1205 const Window& window() const override { return window_; } set_window(const Window & window)1206 void set_window(const Window& window) override { window_ = window; } 1207 // Gets/sets the select or scatter HloComputation for SelectAndScatter. The 1208 // setters should only be called by HloModule or HloComputation methods. select()1209 HloComputation* select() const { 1210 return called_computations()[kSelectComputationIndex]; 1211 } 1212 scatter()1213 HloComputation* scatter() const { 1214 return called_computations()[kScatterComputationIndex]; 1215 } 1216 set_select(HloComputation * computation)1217 void set_select(HloComputation* computation) { 1218 // Don't allow changing the computation for fused instructions so we don't 1219 // have to recompute called_instructions for the entire fusion instruction. 1220 CHECK(!IsFused()); 1221 set_called_computation(kSelectComputationIndex, computation); 1222 } 1223 set_scatter(HloComputation * computation)1224 void set_scatter(HloComputation* computation) { 1225 // Don't allow changing the computation for fused instructions so we don't 1226 // have to recompute called_instructions for the entire fusion instruction. 1227 CHECK(!IsFused()); 1228 set_called_computation(kScatterComputationIndex, computation); 1229 } 1230 // Returns a serialized representation of this instruction. 1231 HloInstructionProto ToProto() const override; 1232 1233 private: 1234 std::vector<string> ExtraAttributesToStringImpl( 1235 const HloPrintOptions& options) const override; 1236 bool IdenticalSlowPath( 1237 const HloInstruction& other, 1238 const std::function<bool(const HloComputation*, const HloComputation*)>& 1239 eq_computations) const override; 1240 // Implementation for non-common logic of CloneWithNewOperands. 1241 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1242 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1243 HloCloneContext* context) const override; 1244 Window window_; 1245 }; 1246 1247 class HloCustomCallInstruction : public HloInstruction { 1248 public: 1249 HloCustomCallInstruction(const Shape& shape, 1250 absl::Span<HloInstruction* const> operands, 1251 absl::string_view custom_call_target, string opaque); 1252 1253 // Constructor for a custom call with constrained layout. 'shape' and 1254 // 'operands_with_layout' must all have layouts. 1255 HloCustomCallInstruction(const Shape& shape, 1256 absl::Span<HloInstruction* const> operands, 1257 absl::string_view custom_call_target, string opaque, 1258 absl::Span<const Shape> operand_shapes_with_layout); 1259 window()1260 const Window& window() const override { 1261 CHECK(window_ != nullptr); 1262 return *window_; 1263 } 1264 set_window(const Window & window)1265 void set_window(const Window& window) override { 1266 window_ = absl::make_unique<Window>(window); 1267 } 1268 convolution_dimension_numbers()1269 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1270 CHECK(convolution_dimension_numbers_ != nullptr); 1271 return *convolution_dimension_numbers_; 1272 } 1273 set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1274 void set_convolution_dimension_numbers( 1275 const ConvolutionDimensionNumbers& dnums) { 1276 convolution_dimension_numbers_ = 1277 absl::make_unique<ConvolutionDimensionNumbers>(dnums); 1278 } 1279 // TODO(jpienaar): Remove this accessor in the follow up. opaque()1280 const string& opaque() const { return raw_backend_config_string(); } custom_call_target()1281 const string& custom_call_target() const { return custom_call_target_; } set_feature_group_count(int64 feature_group_count)1282 void set_feature_group_count(int64 feature_group_count) { 1283 feature_group_count_ = feature_group_count; 1284 } set_batch_group_count(int64 batch_group_count)1285 void set_batch_group_count(int64 batch_group_count) { 1286 batch_group_count_ = batch_group_count; 1287 } 1288 // Sets whether this custom call has a side-effect - by default a custom call 1289 // has no side-effects. set_custom_call_has_side_effect(bool custom_call_has_side_effect)1290 void set_custom_call_has_side_effect(bool custom_call_has_side_effect) { 1291 custom_call_has_side_effect_ = custom_call_has_side_effect; 1292 } feature_group_count()1293 int64 feature_group_count() const { return feature_group_count_; } batch_group_count()1294 int64 batch_group_count() const { return batch_group_count_; } custom_call_has_side_effect()1295 bool custom_call_has_side_effect() const { 1296 return custom_call_has_side_effect_; 1297 } 1298 // Returns a serialized representation of this instruction. 1299 HloInstructionProto ToProto() const override; 1300 1301 // Returns whether the result and operand layouts are constrained. layout_constrained()1302 bool layout_constrained() const { return layout_constrained_; } 1303 1304 // Returns the shapes (with layout) of the operands. CHECKs if this custom 1305 // call does not have constrained layouts. operand_shapes_with_layout()1306 const std::vector<Shape>& operand_shapes_with_layout() const { 1307 CHECK(layout_constrained()); 1308 return operand_shapes_with_layout_; 1309 } 1310 1311 private: 1312 std::vector<string> ExtraAttributesToStringImpl( 1313 const HloPrintOptions& options) const override; 1314 bool IdenticalSlowPath( 1315 const HloInstruction& other, 1316 const std::function<bool(const HloComputation*, const HloComputation*)>& 1317 eq_computations) const override; 1318 // Implementation for non-common logic of CloneWithNewOperands. 1319 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1320 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1321 HloCloneContext* context) const override; 1322 // Name of a global symbol to call. 1323 string custom_call_target_; 1324 // Describes the window in a windowed operation such as convolution. 1325 std::unique_ptr<Window> window_; 1326 // Describes the dimension numbers used for a convolution. 1327 std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; 1328 // The number of feature groups. This is used for grouped convolutions. 1329 int64 feature_group_count_; 1330 int64 batch_group_count_; 1331 // Whether the result and operand layouts are constrained. 1332 bool layout_constrained_; 1333 // For layout-constrained custom calls, this vector holds the shape with 1334 // layout for each operand. 1335 std::vector<Shape> operand_shapes_with_layout_; 1336 // Whether this custom call has a side-effect. 1337 bool custom_call_has_side_effect_; 1338 }; 1339 1340 class HloPadInstruction : public HloInstruction { 1341 public: 1342 explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, 1343 HloInstruction* padding_value, 1344 const PaddingConfig& padding_config); 1345 // Returns the padding configuration for a pad node. padding_config()1346 const PaddingConfig& padding_config() const { return padding_config_; } 1347 // Returns the padding value. padding_value()1348 const HloInstruction* padding_value() const { return operand(1); } mutable_padding_value()1349 HloInstruction* mutable_padding_value() { return mutable_operand(1); } 1350 // Returns a serialized representation of this instruction. 1351 HloInstructionProto ToProto() const override; 1352 1353 private: 1354 std::vector<string> ExtraAttributesToStringImpl( 1355 const HloPrintOptions& options) const override; 1356 bool IdenticalSlowPath( 1357 const HloInstruction& other, 1358 const std::function<bool(const HloComputation*, const HloComputation*)>& 1359 eq_computations) const override; 1360 // Implementation for non-common logic of CloneWithNewOperands. 1361 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1362 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1363 HloCloneContext* context) const override; 1364 1365 // The padding configuration that describes the edge padding and interior 1366 // padding of this pad instruction. 1367 PaddingConfig padding_config_; 1368 }; 1369 1370 class HloDynamicIndexInstruction : public HloInstruction { 1371 public: HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1372 explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) 1373 : HloInstruction(opcode, shape) {} 1374 virtual int64 first_index_operand_number() const = 0; 1375 1376 // Returns a subspan of operands which represent the start indices. index_operands()1377 absl::Span<HloInstruction* const> index_operands() const { 1378 return absl::MakeSpan(operands()).subspan(first_index_operand_number()); 1379 } 1380 1381 // Returns the shapes of the index operands. index_shapes()1382 std::vector<Shape> index_shapes() const { 1383 std::vector<Shape> shapes; 1384 auto indices = index_operands(); 1385 for (const HloInstruction* index : indices) { 1386 shapes.push_back(index->shape()); 1387 } 1388 return shapes; 1389 } 1390 }; 1391 1392 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { 1393 public: 1394 explicit HloDynamicSliceInstruction(const Shape& shape, 1395 HloInstruction* operand, 1396 HloInstruction* start_indices, 1397 absl::Span<const int64> slice_sizes); 1398 explicit HloDynamicSliceInstruction( 1399 const Shape& shape, HloInstruction* operand, 1400 absl::Span<HloInstruction* const> start_indices, 1401 absl::Span<const int64> slice_sizes); 1402 // Old methods kept for smooth subclassing transition END. 1403 // Returns the size of the slice in the given dimension for a dynamic 1404 // slice node. slice_sizes(int64 dimension)1405 int64 slice_sizes(int64 dimension) const { 1406 return dynamic_slice_sizes_[dimension]; 1407 } dynamic_slice_sizes()1408 const std::vector<int64>& dynamic_slice_sizes() const { 1409 return dynamic_slice_sizes_; 1410 } 1411 // Returns a serialized representation of this instruction. 1412 HloInstructionProto ToProto() const override; 1413 first_index_operand_number()1414 int64 first_index_operand_number() const override { return 1; } 1415 1416 private: 1417 std::vector<string> ExtraAttributesToStringImpl( 1418 const HloPrintOptions& options) const override; 1419 bool IdenticalSlowPath( 1420 const HloInstruction& other, 1421 const std::function<bool(const HloComputation*, const HloComputation*)>& 1422 eq_computations) const override; 1423 // Implementation for non-common logic of CloneWithNewOperands. 1424 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1425 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1426 HloCloneContext* context) const override; 1427 1428 // Describes the [start, start + size) range size for a dynamic slice 1429 // ('start' is specified dynamically in the second operand of the operation). 1430 std::vector<int64> dynamic_slice_sizes_; 1431 }; 1432 1433 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { 1434 public: 1435 explicit HloDynamicUpdateSliceInstruction(const Shape& shape, 1436 HloInstruction* operand, 1437 HloInstruction* update, 1438 HloInstruction* start_indices); 1439 explicit HloDynamicUpdateSliceInstruction( 1440 const Shape& shape, HloInstruction* operand, HloInstruction* update, 1441 absl::Span<HloInstruction* const> start_indices); 1442 first_index_operand_number()1443 int64 first_index_operand_number() const override { return 2; } 1444 }; 1445 1446 class HloGatherInstruction : public HloInstruction { 1447 public: 1448 explicit HloGatherInstruction( 1449 const Shape& shape, HloInstruction* operand, 1450 HloInstruction* start_indices, 1451 const GatherDimensionNumbers& gather_dim_numbers, 1452 absl::Span<const int64> slice_sizes, bool indices_are_sorted); gather_dimension_numbers()1453 const GatherDimensionNumbers& gather_dimension_numbers() const { 1454 CHECK(gather_dimension_numbers_ != nullptr); 1455 return *gather_dimension_numbers_; 1456 } gather_slice_sizes()1457 absl::Span<const int64> gather_slice_sizes() const { 1458 return gather_slice_sizes_; 1459 } indices_are_sorted()1460 bool indices_are_sorted() const { return indices_are_sorted_; } set_indices_are_sorted(bool indices_are_sorted)1461 void set_indices_are_sorted(bool indices_are_sorted) { 1462 indices_are_sorted_ = indices_are_sorted; 1463 } 1464 // Returns a serialized representation of this instruction. 1465 HloInstructionProto ToProto() const override; 1466 1467 // Creates an instance of GatherDimensionNumbers. 1468 static GatherDimensionNumbers MakeGatherDimNumbers( 1469 absl::Span<const int64> offset_dims, 1470 absl::Span<const int64> collapsed_slice_dims, 1471 absl::Span<const int64> start_index_map, int64 index_vector_dim); 1472 // Returns the dump string of the given gather dimension numbers. 1473 static string GatherDimensionNumbersToString( 1474 const GatherDimensionNumbers& gather_dimension_numbers); 1475 1476 private: 1477 std::vector<string> ExtraAttributesToStringImpl( 1478 const HloPrintOptions& options) const override; 1479 bool IdenticalSlowPath( 1480 const HloInstruction& other, 1481 const std::function<bool(const HloComputation*, const HloComputation*)>& 1482 eq_computations) const override; 1483 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1484 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1485 HloCloneContext* context) const override; 1486 1487 std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; 1488 std::vector<int64> gather_slice_sizes_; 1489 bool indices_are_sorted_; 1490 }; 1491 1492 class HloScatterInstruction : public HloInstruction { 1493 public: 1494 explicit HloScatterInstruction( 1495 const Shape& shape, HloInstruction* operand, 1496 HloInstruction* scatter_indices, HloInstruction* updates, 1497 HloComputation* update_computation, 1498 const ScatterDimensionNumbers& scatter_dim_numbers, 1499 bool indices_are_sorted, bool unique_indices); scatter_dimension_numbers()1500 const ScatterDimensionNumbers& scatter_dimension_numbers() const { 1501 CHECK(scatter_dimension_numbers_ != nullptr); 1502 return *scatter_dimension_numbers_; 1503 } indices_are_sorted()1504 bool indices_are_sorted() const { return indices_are_sorted_; } set_indices_are_sorted(bool indices_are_sorted)1505 void set_indices_are_sorted(bool indices_are_sorted) { 1506 indices_are_sorted_ = indices_are_sorted; 1507 } unique_indices()1508 bool unique_indices() const override { return unique_indices_; } 1509 // Returns a serialized representation of this instruction. 1510 HloInstructionProto ToProto() const override; 1511 1512 // Creates an instance of ScatterDimensionNumbers. 1513 static ScatterDimensionNumbers MakeScatterDimNumbers( 1514 absl::Span<const int64> update_window_dims, 1515 absl::Span<const int64> inserted_window_dims, 1516 absl::Span<const int64> scatter_dims_to_operand_dims, 1517 int64 index_vector_dim); 1518 // Returns the dump string of the given scatter dimension numbers. 1519 static string ScatterDimensionNumbersToString( 1520 const ScatterDimensionNumbers& scatter_dimension_numbers); 1521 1522 private: 1523 std::vector<string> ExtraAttributesToStringImpl( 1524 const HloPrintOptions& options) const override; 1525 bool IdenticalSlowPath( 1526 const HloInstruction& other, 1527 const std::function<bool(const HloComputation*, const HloComputation*)>& 1528 eq_computations) const override; 1529 // Implementation for non-common logic of CloneWithNewOperands. 1530 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1531 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1532 HloCloneContext* context) const override; 1533 1534 std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; 1535 bool indices_are_sorted_; 1536 bool unique_indices_; 1537 }; 1538 1539 class HloIotaInstruction : public HloInstruction { 1540 public: 1541 explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); 1542 // Returns the dimension sizes or numbers associated with this instruction. iota_dimension()1543 int64 iota_dimension() const { return iota_dimension_; } 1544 // Returns a serialized representation of this instruction. 1545 HloInstructionProto ToProto() const override; 1546 1547 private: 1548 std::vector<string> ExtraAttributesToStringImpl( 1549 const HloPrintOptions& options) const override; 1550 bool IdenticalSlowPath( 1551 const HloInstruction& other, 1552 const std::function<bool(const HloComputation*, const HloComputation*)>& 1553 eq_computations) const override; 1554 // Implementation for non-common logic of CloneWithNewOperands. 1555 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1556 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1557 HloCloneContext* context) const override; 1558 1559 const int64 iota_dimension_; 1560 }; 1561 1562 class HloDotInstruction : public HloInstruction { 1563 public: 1564 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch 1565 // dimensions specified in 'dimension_numbers'. 1566 explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, 1567 HloInstruction* rhs, 1568 const DotDimensionNumbers& dimension_numbers, 1569 const PrecisionConfig& precision_config); 1570 1571 // Returns data on the dimension numbers used for a dot operation. dot_dimension_numbers()1572 const DotDimensionNumbers& dot_dimension_numbers() const { 1573 return dot_dimension_numbers_; 1574 } 1575 1576 // Returns the information used to tell the implementation information about 1577 // what sort of precision is requested. The meaning of the field is backend 1578 // specific. At the moment, it is only supported for kConvolution and kDot. 1579 // Transformations on one kDot or kConvolution to another will preserve this 1580 // information. Transformations to other HLOs will not preserve this 1581 // information but it is presumed that the alternate lowering is strictly 1582 // superior. precision_config()1583 const PrecisionConfig& precision_config() const { return precision_config_; } mutable_precision_config()1584 PrecisionConfig* mutable_precision_config() { return &precision_config_; } 1585 1586 // Returns a serialized representation of this instruction. 1587 HloInstructionProto ToProto() const override; 1588 1589 private: 1590 std::vector<string> ExtraAttributesToStringImpl( 1591 const HloPrintOptions& options) const override; 1592 bool IdenticalSlowPath( 1593 const HloInstruction& other, 1594 const std::function<bool(const HloComputation*, const HloComputation*)>& 1595 eq_computations) const override; 1596 // Implementation for non-common logic of CloneWithNewOperands. 1597 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1598 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1599 HloCloneContext* context) const override; 1600 // Returns the dump string of the dot dimension numbers. 1601 string DotDimensionNumbersToString() const; 1602 1603 // Describes the dimension numbers used for a dot. 1604 DotDimensionNumbers dot_dimension_numbers_; 1605 1606 // Information used to communicate to the implementation about the algorithm 1607 // used to produce results. See the documentation on precision_config(). 1608 PrecisionConfig precision_config_; 1609 }; 1610 1611 class HloDomainInstruction : public HloInstruction { 1612 public: 1613 explicit HloDomainInstruction( 1614 const Shape& shape, HloInstruction* operand, 1615 std::unique_ptr<DomainMetadata> operand_side_metadata, 1616 std::unique_ptr<DomainMetadata> user_side_metadata); 1617 1618 // Returns a serialized representation of this instruction. 1619 HloInstructionProto ToProto() const override; 1620 1621 // Retrieves the operand side metadata of a kDomain instruction. operand_side_metadata()1622 const DomainMetadata& operand_side_metadata() const { 1623 return *operand_side_metadata_; 1624 } 1625 // Retrieves the user side metadata of a kDomain instruction. user_side_metadata()1626 const DomainMetadata& user_side_metadata() const { 1627 return *user_side_metadata_; 1628 } 1629 1630 private: 1631 std::vector<string> ExtraAttributesToStringImpl( 1632 const HloPrintOptions& options) const override; 1633 bool IdenticalSlowPath( 1634 const HloInstruction& other, 1635 const std::function<bool(const HloComputation*, const HloComputation*)>& 1636 eq_computations) const override; 1637 // Implementation for non-common logic of CloneWithNewOperands. 1638 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1639 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1640 HloCloneContext* context) const override; 1641 1642 std::unique_ptr<DomainMetadata> operand_side_metadata_; 1643 std::unique_ptr<DomainMetadata> user_side_metadata_; 1644 }; 1645 1646 class HloGetDimensionSizeInstruction : public HloInstruction { 1647 public: 1648 explicit HloGetDimensionSizeInstruction(const Shape& shape, 1649 HloInstruction* operand, 1650 int64 dimension); 1651 1652 // Returns the dimension sizes or numbers associated with this instruction. dimension()1653 int64 dimension() const { return dimension_; } 1654 // Returns a serialized representation of this instruction. 1655 HloInstructionProto ToProto() const override; 1656 1657 private: 1658 std::vector<string> ExtraAttributesToStringImpl( 1659 const HloPrintOptions& options) const override; 1660 bool IdenticalSlowPath( 1661 const HloInstruction& other, 1662 const std::function<bool(const HloComputation*, const HloComputation*)>& 1663 eq_computations) const override; 1664 // Implementation for non-common logic of CloneWithNewOperands. 1665 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1666 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1667 HloCloneContext* context) const override; 1668 1669 int64 dimension_; 1670 }; 1671 1672 class HloSetDimensionSizeInstruction : public HloInstruction { 1673 public: 1674 explicit HloSetDimensionSizeInstruction(const Shape& shape, 1675 HloInstruction* operand, 1676 HloInstruction* val, int64 dimension); 1677 1678 // Returns the dimension sizes or numbers associated with this instruction. dimension()1679 int64 dimension() const { return dimension_; } 1680 // Returns a serialized representation of this instruction. 1681 HloInstructionProto ToProto() const override; 1682 1683 private: 1684 std::vector<string> ExtraAttributesToStringImpl( 1685 const HloPrintOptions& options) const override; 1686 bool IdenticalSlowPath( 1687 const HloInstruction& other, 1688 const std::function<bool(const HloComputation*, const HloComputation*)>& 1689 eq_computations) const override; 1690 // Implementation for non-common logic of CloneWithNewOperands. 1691 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1692 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1693 HloCloneContext* context) const override; 1694 1695 int64 dimension_; 1696 }; 1697 1698 class HloRngGetAndUpdateStateInstruction : public HloInstruction { 1699 public: 1700 explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, int64 delta); 1701 1702 // Returns the delta value. delta()1703 int64 delta() const { return delta_; } set_delta(int64 delta)1704 void set_delta(int64 delta) { delta_ = delta; } 1705 // Returns a serialized representation of this instruction. 1706 HloInstructionProto ToProto() const override; 1707 1708 private: 1709 std::vector<string> ExtraAttributesToStringImpl( 1710 const HloPrintOptions& options) const override; 1711 bool IdenticalSlowPath( 1712 const HloInstruction& other, 1713 const std::function<bool(const HloComputation*, const HloComputation*)>& 1714 eq_computations) const override; 1715 // Implementation for non-common logic of CloneWithNewOperands. 1716 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( 1717 const Shape& shape, absl::Span<HloInstruction* const> new_operands, 1718 HloCloneContext* context) const override; 1719 1720 int64 delta_; 1721 }; 1722 1723 } // namespace xla 1724 1725 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ 1726