1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // All HloInstruction subclasses are put in this file.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
20 
21 #include <functional>
22 #include <memory>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/container/inlined_vector.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 
37 namespace xla {
38 
39 // Base class for instructions with a dimensions vector.
40 class HloDimensionsInstruction : public HloInstruction {
41  public:
dimensions()42   absl::Span<const int64_t> dimensions() const override { return dimensions_; }
43 
mutable_dimensions()44   std::vector<int64_t>* mutable_dimensions() override { return &dimensions_; }
45 
46   HloInstructionProto ToProto() const override;
47 
ClassOf(const HloInstruction * hlo)48   static bool ClassOf(const HloInstruction* hlo) {
49     switch (hlo->opcode()) {
50       case HloOpcode::kBroadcast:
51       case HloOpcode::kConcatenate:
52       case HloOpcode::kReduce:
53       case HloOpcode::kReverse:
54       case HloOpcode::kSort:
55       case HloOpcode::kTranspose:
56         return true;
57       default:
58         return false;
59     }
60   }
61 
62  protected:
HloDimensionsInstruction(HloOpcode opcode,const Shape & shape,absl::Span<const int64_t> dimensions)63   HloDimensionsInstruction(HloOpcode opcode, const Shape& shape,
64                            absl::Span<const int64_t> dimensions)
65       : HloInstruction(opcode, shape),
66         dimensions_(dimensions.begin(), dimensions.end()) {}
67   std::vector<std::string> ExtraAttributesToStringImpl(
68       const HloPrintOptions& options) const override;
69 
70   bool IdenticalSlowPath(
71       const HloInstruction& other,
72       const std::function<bool(const HloComputation*, const HloComputation*)>&
73           eq_computations) const override;
74 
75   std::vector<int64_t> dimensions_;
76 };
77 
78 class HloBatchNormInstruction : public HloInstruction {
79  public:
80   // Returns feature_index field associated with the instruction. The index
81   // represents the index of the feature dimension.
feature_index()82   int64_t feature_index() const { return feature_index_; }
83 
84   // Returns a epsilon value associated with the instruction. The is a small
85   // number added to the variance to avoid divide-by-zero error.
epsilon()86   float epsilon() const { return epsilon_; }
87 
88   // Returns a serialized representation of this instruction.
89   HloInstructionProto ToProto() const override;
90 
ClassOf(const HloInstruction * hlo)91   static bool ClassOf(const HloInstruction* hlo) {
92     switch (hlo->opcode()) {
93       case HloOpcode::kBatchNormGrad:
94       case HloOpcode::kBatchNormInference:
95       case HloOpcode::kBatchNormTraining:
96         return true;
97       default:
98         return false;
99     }
100   }
101 
102  protected:
103   explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
104                                    HloInstruction* operand,
105                                    HloInstruction* scale, float epsilon,
106                                    int64_t feature_index);
107 
108  private:
109   std::vector<std::string> ExtraAttributesToStringImpl(
110       const HloPrintOptions& options) const override;
111   bool IdenticalSlowPath(
112       const HloInstruction& other,
113       const std::function<bool(const HloComputation*, const HloComputation*)>&
114           eq_computations) const override;
115   // A small float number added to the variance to avoid divide-by-zero error.
116   float epsilon_ = 0.0f;
117 
118   // An integer value representing the index of the feature dimension.
119   int64_t feature_index_ = -1;
120 };
121 
// Batch-norm instruction with opcode kBatchNormTraining. Constructed from an
// operand, scale and offset together with epsilon and feature_index (whose
// accessors are inherited from HloBatchNormInstruction).
class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormTrainingInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64_t feature_index);

  // Returns true iff `hlo` has opcode kBatchNormTraining.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBatchNormTraining;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
138 
// Batch-norm instruction with opcode kBatchNormInference. Constructed from an
// operand, scale, offset, mean and variance together with epsilon and
// feature_index (whose accessors are inherited from HloBatchNormInstruction).
class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormInferenceInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64_t feature_index);

  // Returns true iff `hlo` has opcode kBatchNormInference.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBatchNormInference;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
156 
// Batch-norm instruction with opcode kBatchNormGrad. Constructed from an
// operand, scale, mean, variance and grad_output together with epsilon and
// feature_index (whose accessors are inherited from HloBatchNormInstruction).
class HloBatchNormGradInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormGradInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64_t feature_index);

  // Returns true iff `hlo` has opcode kBatchNormGrad.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBatchNormGrad;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
174 
// Instruction with opcode kFft. Carries an FFT type and a list of FFT lengths.
class HloFftInstruction : public HloInstruction {
 public:
  explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
                             FftType fft_type,
                             absl::Span<const int64_t> fft_length);
  // Returns the FFT type of this instruction.
  FftType fft_type() const { return fft_type_; }

  // Returns the FFT length list of this instruction.
  const std::vector<int64_t>& fft_length() const { return fft_length_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kFft.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kFft;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes FFT type for an FFT instruction.
  FftType fft_type_ = FftType::FFT;

  // Indicates the FFT length for an FFT instruction.
  std::vector<int64_t> fft_length_;
};
210 
211 class HloAsyncInstruction : public HloInstruction {
212  public:
213   HloAsyncInstruction(
214       HloOpcode opcode, const Shape& shape,
215       absl::Span<HloInstruction* const> operands,
216       HloComputation* async_computation,
217       std::optional<int64_t> async_group_id = std::nullopt,
218       absl::string_view async_execution_thread = kMainExecutionThread);
219   HloAsyncInstruction(
220       HloOpcode opcode, const Shape& shape, HloInstruction* operand,
221       HloComputation* async_computation,
222       std::optional<int64_t> async_group_id = std::nullopt,
223       absl::string_view async_execution_thread = kMainExecutionThread);
224 
225   ~HloAsyncInstruction() override;
226   // When an async instruction is being destructed, remove it from the vector of
227   // pointers of its called computation, to avoid referencing freed memory.
228   void ClearAsyncComputationInstruction();
229 
230   HloInstruction* async_wrapped_instruction() const;
231   HloOpcode async_wrapped_opcode() const;
232 
233   // Async group id is a unique id given to a group of async operations that
234   // consist of one async start, one async done, and zero or more async update
235   // operations. The async group participates in a single async operation. The
236   // async operation canonicalizer pass assigns async group ids.
async_group_id()237   std::optional<int64_t> async_group_id() const { return async_group_id_; }
238 
239   // Async thread name is a unique thread name for one or more async groups.
240   // Typically one HLO module contains a main thread as well as one or more
241   // parallel threads.
async_execution_thread()242   absl::string_view async_execution_thread() const {
243     return async_execution_thread_;
244   }
245   void set_async_group_id(std::optional<int64_t> async_group_id);
246   void set_async_execution_thread(absl::string_view async_execution_thread);
247   HloInstructionProto ToProto() const override;
248 
ClassOf(const HloInstruction * hlo)249   static bool ClassOf(const HloInstruction* hlo) {
250     switch (hlo->opcode()) {
251       case HloOpcode::kAsyncStart:
252       case HloOpcode::kAsyncUpdate:
253       case HloOpcode::kAsyncDone:
254         return true;
255       default:
256         return false;
257     }
258   }
259 
260  private:
261   std::vector<std::string> ExtraAttributesToStringImpl(
262       const HloPrintOptions& options) const override;
263   bool IdenticalSlowPath(
264       const HloInstruction& other,
265       const std::function<bool(const HloComputation*, const HloComputation*)>&
266           eq_computations) const override;
267   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
268       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
269       HloCloneContext* context) const override;
270   std::optional<int64_t> async_group_id_;
271   std::string async_execution_thread_ = kMainExecutionThread;
272 };
273 
274 class HloCopyStartInstruction : public HloInstruction {
275  public:
276   explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
277                                    bool is_cross_program_prefetch);
278 
is_cross_program_prefetch()279   bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
280   HloInstructionProto ToProto() const override;
281 
ClassOf(const HloInstruction * hlo)282   static bool ClassOf(const HloInstruction* hlo) {
283     return hlo->opcode() == HloOpcode::kCopyStart;
284   }
285 
286  private:
287   std::vector<std::string> ExtraAttributesToStringImpl(
288       const HloPrintOptions& options) const override;
289   bool IdenticalSlowPath(
290       const HloInstruction& other,
291       const std::function<bool(const HloComputation*, const HloComputation*)>&
292           eq_computations) const override;
293   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
294       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
295       HloCloneContext* context) const override;
296 
297   bool is_cross_program_prefetch_;
298 };
299 
// Instruction with opcode kCompare. Its direction, order and type are all
// read from the stored Comparison object `compare_`.
class HloCompareInstruction : public HloInstruction {
 public:
  explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
                                 HloInstruction* rhs,
                                 ComparisonDirection direction,
                                 std::optional<Comparison::Type> type);
  // Accessors that forward to the stored Comparison.
  ComparisonDirection direction() const { return compare_.GetDirection(); }
  ComparisonOrder order() const { return compare_.GetOrder(); }
  Comparison::Type type() const { return compare_.GetType(); }
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kCompare.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCompare;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The comparison performed by this instruction; source of direction(),
  // order() and type().
  Comparison compare_;
};
328 
// Instruction with opcode kTriangularSolve on operands `a` and `b`. Carries a
// copy of the TriangularSolveOptions passed to the constructor.
class HloTriangularSolveInstruction : public HloInstruction {
 public:
  explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
                                         HloInstruction* b,
                                         const TriangularSolveOptions& options);
  // Returns the triangular-solve options of this instruction.
  const TriangularSolveOptions& triangular_solve_options() const {
    return triangular_solve_options_;
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kTriangularSolve.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kTriangularSolve;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Options held by value (the accessor returns a reference to this copy).
  TriangularSolveOptions triangular_solve_options_;
};
360 
// Instruction with opcode kCholesky on operand `a`. Carries a copy of the
// CholeskyOptions passed to the constructor.
class HloCholeskyInstruction : public HloInstruction {
 public:
  explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
                                  const CholeskyOptions& options);
  // Returns the Cholesky options of this instruction.
  const CholeskyOptions& cholesky_options() const { return cholesky_options_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kCholesky.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCholesky;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Options held by value (the accessor returns a reference to this copy).
  CholeskyOptions cholesky_options_;
};
389 
390 // Class that represents instructions that synchronize and transfer data between
391 // partitioned devices. Send/Recv and collective instructions (AllReduce,
392 // AllToAll, CollectivePermute) belong to this instruction type. A group of
393 // instructions (of the same opcode) with the same channel_id communicate during
394 // execution.
395 class HloChannelInstruction : public HloInstruction {
396  public:
397   // Returns the channel id associated with the instruction. The id is
398   // shared between each Send/Recv pair or a group of collective instructions
399   // and is globally unique to identify each channel.
channel_id()400   std::optional<int64_t> channel_id() const { return channel_id_; }
401   void set_channel_id(const std::optional<int64_t>& channel_id);
402 
403   // Whether this instruction is identical to `other` except for the values of
404   // channel IDs, as long as both have channel IDs or neither has a channel ID.
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations)405   virtual bool IdenticalSlowPathIgnoringChannelIdValues(
406       const HloInstruction& other,
407       const std::function<bool(const HloComputation*, const HloComputation*)>&
408           eq_computations) const {
409     return channel_id_.has_value() == other.channel_id().has_value();
410   }
411 
412   static bool ClassOf(const HloInstruction* hlo);
413 
414  protected:
415   explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
416                                  const std::optional<int64_t>& channel_id);
417 
418   HloInstructionProto ToProto() const override;
419 
420   std::vector<std::string> ExtraAttributesToStringImpl(
421       const HloPrintOptions& options) const override;
422 
423   // Do not override IdenticalSlowPath(). Override
424   // IdenticalSlowPathIgnoringChannelIdValues() instead.
425   bool IdenticalSlowPath(
426       const HloInstruction& other,
427       const std::function<bool(const HloComputation*, const HloComputation*)>&
428           eq_computations) const final;
429 
430   std::optional<int64_t> channel_id_;
431 };
432 
433 class HloSendRecvInstruction : public HloChannelInstruction {
434  public:
435   // Returns whether this send/recv instruction sends data to/from the host.
is_host_transfer()436   bool is_host_transfer() const { return is_host_transfer_; }
437 
438   // Returns a serialized representation of this instruction.
439   HloInstructionProto ToProto() const override;
440 
ClassOf(const HloInstruction * hlo)441   static bool ClassOf(const HloInstruction* hlo) {
442     switch (hlo->opcode()) {
443       case HloOpcode::kSend:
444       case HloOpcode::kSendDone:
445       case HloOpcode::kRecv:
446       case HloOpcode::kRecvDone:
447         return true;
448       default:
449         return false;
450     }
451   }
452 
453  protected:
454   explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
455                                   int64_t channel_id, bool is_host_transfer);
456 
457  private:
458   std::vector<std::string> ExtraAttributesToStringImpl(
459       const HloPrintOptions& options) const override;
460   bool IdenticalSlowPathIgnoringChannelIdValues(
461       const HloInstruction& other,
462       const std::function<bool(const HloComputation*, const HloComputation*)>&
463           eq_computations) const override;
464   // Whether this send/recv instruction sends data to/from the host.
465   bool is_host_transfer_;
466 };
467 
// Send instruction (opcode kSend). Constructed from an operand, a token, a
// channel id and the host-transfer flag.
class HloSendInstruction : public HloSendRecvInstruction {
 public:
  explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
                              int64_t channel_id, bool is_host_transfer);

  // Returns true iff `hlo` has opcode kSend.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSend;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
483 
// SendDone instruction (opcode kSendDone). Its operand is typed as an
// HloSendInstruction.
// NOTE(review): the constructor takes no channel id, so it is presumably
// taken from the Send operand; confirm in the implementation.
class HloSendDoneInstruction : public HloSendRecvInstruction {
 public:
  explicit HloSendDoneInstruction(HloSendInstruction* operand,
                                  bool is_host_transfer);

  // Returns true iff `hlo` has opcode kSendDone.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSendDone;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
499 
// Recv instruction (opcode kRecv). Constructed from a result shape, a token,
// a channel id and the host-transfer flag.
class HloRecvInstruction : public HloSendRecvInstruction {
 public:
  explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
                              int64_t channel_id, bool is_host_transfer);

  // Returns true iff `hlo` has opcode kRecv.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRecv;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
515 
// RecvDone instruction (opcode kRecvDone). Its operand is typed as an
// HloRecvInstruction.
// NOTE(review): the constructor takes no channel id, so it is presumably
// taken from the Recv operand; confirm in the implementation.
class HloRecvDoneInstruction : public HloSendRecvInstruction {
 public:
  explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
                                  bool is_host_transfer);

  // Returns true iff `hlo` has opcode kRecvDone.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRecvDone;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
531 
532 class HloCollectiveInstruction : public HloChannelInstruction {
533  public:
replica_groups()534   const std::vector<ReplicaGroup>& replica_groups() const {
535     return replica_groups_;
536   }
537 
538   // Returns true if the layout of the AllReduce is enforced by XLA client (as
539   // the layout set in the shape). The only reason for the client to set the
540   // layout is to separately compile computations that communicate with
541   // AllReduce. Since this field is only set `true` by the client, the compiler
542   // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
543   // `false` for all other cases.
544   //
545   // When this is `true`, there may be communication endpoints outside the
546   // current compilation unit, so the compiler considers this AllReduce as
547   // side-effecting to disable compiler transformations. The compiler is free to
548   // transform unconstrained AllReduces differently across compilation units.
549   // It is an error for an HloModule to have a mix of constrained and
550   // unconstrained AllReduce instructions (checked by HloVerifier).
constrain_layout()551   bool constrain_layout() const { return constrain_layout_; }
552 
553   static bool ClassOf(const HloInstruction* hlo);
554 
555  protected:
556   explicit HloCollectiveInstruction(
557       HloOpcode opcode, const Shape& shape,
558       absl::Span<HloInstruction* const> operands,
559       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
560       const std::optional<int64_t>& channel_id);
561 
562   HloInstructionProto ToProto() const override;
563 
564   std::vector<std::string> ExtraAttributesToStringImpl(
565       const HloPrintOptions& options) const override;
566   bool IdenticalSlowPathIgnoringChannelIdValues(
567       const HloInstruction& other,
568       const std::function<bool(const HloComputation*, const HloComputation*)>&
569           eq_computations) const override;
570 
571   std::vector<ReplicaGroup> replica_groups_;
572   bool constrain_layout_;
573 };
574 
575 class HloAllGatherInstruction : public HloCollectiveInstruction {
576  public:
577   explicit HloAllGatherInstruction(
578       HloOpcode opcode, const Shape& shape,
579       absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
580       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
581       const std::optional<int64_t>& channel_id, bool use_global_device_ids);
582   // Same as HloAllReduceInstruction::use_global_device_ids.
use_global_device_ids()583   bool use_global_device_ids() const { return use_global_device_ids_; }
584 
585   // The dimension on which data from different participants are concatenated.
all_gather_dimension()586   int64_t all_gather_dimension() const { return all_gather_dimension_; }
dimensions()587   absl::Span<const int64_t> dimensions() const override {
588     return absl::MakeConstSpan(&all_gather_dimension_, 1);
589   }
590 
set_all_gather_dimension(int64_t dim)591   void set_all_gather_dimension(int64_t dim) { all_gather_dimension_ = dim; }
592 
ClassOf(const HloInstruction * hlo)593   static bool ClassOf(const HloInstruction* hlo) {
594     return hlo->opcode() == HloOpcode::kAllGather ||
595            hlo->opcode() == HloOpcode::kAllGatherStart;
596   }
597 
598  protected:
599   std::vector<std::string> ExtraAttributesToStringImpl(
600       const HloPrintOptions& options) const override;
601   HloInstructionProto ToProto() const override;
602 
603  private:
604   bool IdenticalSlowPathIgnoringChannelIdValues(
605       const HloInstruction& other,
606       const std::function<bool(const HloComputation*, const HloComputation*)>&
607           eq_computations) const override;
608 
609   // Implementation for non-common logic of CloneWithNewOperands.
610   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
611       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
612       HloCloneContext* context) const override;
613 
614   int64_t all_gather_dimension_;
615   bool use_global_device_ids_;
616 };
617 
618 // Base class for all-reduce and all-reduce scatter instructions.
619 class HloAllReduceInstructionBase : public HloCollectiveInstruction {
620  public:
621   explicit HloAllReduceInstructionBase(
622       HloOpcode opcode, const Shape& shape,
623       absl::Span<HloInstruction* const> operands,
624       HloComputation* reduce_computation,
625       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
626       const std::optional<int64_t>& channel_id, bool use_global_device_ids);
627 
628   // Returns true if the ids in the ReplicaGroup config represent a global id of
629   // (replica_id * partition_count + partition_id) instead of a replica id.
630   // This enables more flexible grouping of devices if this all-reduce is both
631   // cross-partition and cross-replica.
632   //
633   // For example with 2 replicas and 4 partitions,
634   // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that
635   // group[0] = (0,0), (0,1), (1,0), (1,1)
636   // group[1] = (0,2), (0,3), (1,2), (1,3)
637   // where each pair is (replica_id, partition_id).
use_global_device_ids()638   bool use_global_device_ids() const { return use_global_device_ids_; }
set_use_global_device_ids(bool value)639   void set_use_global_device_ids(bool value) { use_global_device_ids_ = value; }
640 
641   static bool ClassOf(const HloInstruction* hlo);
642 
643  protected:
644   std::vector<std::string> ExtraAttributesToStringImpl(
645       const HloPrintOptions& options) const override;
646   HloInstructionProto ToProto() const override;
647 
648   bool IdenticalSlowPathIgnoringChannelIdValues(
649       const HloInstruction& other,
650       const std::function<bool(const HloComputation*, const HloComputation*)>&
651           eq_computations) const override;
652 
653  private:
654   bool use_global_device_ids_;
655 };
656 
657 class HloAllReduceInstruction : public HloAllReduceInstructionBase {
658  public:
659   using HloAllReduceInstructionBase::HloAllReduceInstructionBase;
660 
ClassOf(const HloInstruction * hlo)661   static bool ClassOf(const HloInstruction* hlo) {
662     return hlo->opcode() == HloOpcode::kAllReduce ||
663            hlo->opcode() == HloOpcode::kAllReduceStart;
664   }
665 
666   // Returns true if the AllReduce does no communication, so it's equivalent
667   // to a mem copy.
668   bool IsNoop() const;
669 
670  private:
671   // Implementation for non-common logic of CloneWithNewOperands.
672   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
673       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
674       HloCloneContext* context) const override;
675 };
676 
// Reduce-scatter collective (opcode kReduceScatter). It shares the all-reduce
// attributes of HloAllReduceInstructionBase and adds a scatter dimension.
class HloReduceScatterInstruction : public HloAllReduceInstructionBase {
 public:
  explicit HloReduceScatterInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids,
      int64_t scatter_dimension);

  // The dimension on which reduced data is scattered to different participants.
  int64_t scatter_dimension() const { return scatter_dimension_; }
  // Exposes the scatter dimension as a one-element span.
  absl::Span<const int64_t> dimensions() const override {
    return absl::MakeConstSpan(&scatter_dimension_, 1);
  }

  // Returns true iff `hlo` has opcode kReduceScatter.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kReduceScatter;
  }

 protected:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  HloInstructionProto ToProto() const override;

 private:
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Backing storage for scatter_dimension() and dimensions().
  int64_t scatter_dimension_;
};
714 
// All-to-all collective (opcode kAllToAll). Adds an optional split dimension
// to HloCollectiveInstruction.
class HloAllToAllInstruction : public HloCollectiveInstruction {
 public:
  explicit HloAllToAllInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id,
      const std::optional<int64_t>& split_dimension);

  // AllToAll can optionally take a split dimension, which means that this
  // AllToAll takes a single (flattened) array operand and produces an array
  // output (instead of taking a list of operands and producing a tuple).
  //
  // split_dimension specifies which dimension in the operand is split across
  // devices in each replica_group, and also means the concatenated dimension
  // on the output (i.e., input and the output shapes are the same).
  std::optional<int64_t> split_dimension() const { return split_dimension_; }
  // Sets the split dimension. This setter can only store a value; it cannot
  // reset the split dimension to empty.
  void set_split_dimension(int64_t dim) { split_dimension_ = dim; }

  // Returns true iff `hlo` has opcode kAllToAll.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kAllToAll;
  }

 protected:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  HloInstructionProto ToProto() const override;

 private:
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Split dimension, or empty when this AllToAll has none.
  std::optional<int64_t> split_dimension_;
};
755 
756 class HloCollectivePermuteInstruction : public HloChannelInstruction {
757  public:
758   explicit HloCollectivePermuteInstruction(
759       HloOpcode opcode, const Shape& shape, HloInstruction* operand,
760       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
761       const std::optional<int64_t>& channel_id);
762 
763   explicit HloCollectivePermuteInstruction(
764       HloOpcode opcode, const Shape& shape, HloInstruction* input,
765       HloInstruction* output, HloInstruction* input_start_indices,
766       HloInstruction* output_start_indices,
767       absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
768       absl::Span<const std::vector<int64_t>> slice_sizes,
769       const std::optional<int64_t>& channel_id);
770 
source_target_pairs()771   const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const {
772     return source_target_pairs_;
773   }
774 
dynamic_slice_sizes_list()775   const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const {
776     return slice_sizes_;
777   }
778 
779   // Returns a serialized representation of this instruction.
780   HloInstructionProto ToProto() const override;
781 
ClassOf(const HloInstruction * hlo)782   static bool ClassOf(const HloInstruction* hlo) {
783     return hlo->opcode() == HloOpcode::kCollectivePermute ||
784            hlo->opcode() == HloOpcode::kCollectivePermuteStart;
785   }
786 
787  private:
788   std::vector<std::string> ExtraAttributesToStringImpl(
789       const HloPrintOptions& options) const override;
790   bool IdenticalSlowPathIgnoringChannelIdValues(
791       const HloInstruction& other,
792       const std::function<bool(const HloComputation*, const HloComputation*)>&
793           eq_computations) const override;
794 
795   // Implementation for non-common logic of CloneWithNewOperands.
796   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
797       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
798       HloCloneContext* context) const override;
799 
800   const std::vector<std::pair<int64_t, int64_t>> source_target_pairs_;
801   const std::vector<std::vector<int64_t>> slice_sizes_;
802 };
803 
ClassOf(const HloInstruction * hlo)804 inline bool HloAllReduceInstructionBase::ClassOf(const HloInstruction* hlo) {
805   return HloAllReduceInstruction::ClassOf(hlo) ||
806          hlo->opcode() == HloOpcode::kReduceScatter;
807 }
808 
ClassOf(const HloInstruction * hlo)809 inline bool HloCollectiveInstruction::ClassOf(const HloInstruction* hlo) {
810   return HloAllReduceInstructionBase::ClassOf(hlo) ||
811          HloAllGatherInstruction::ClassOf(hlo) ||
812          HloAllToAllInstruction::ClassOf(hlo);
813 }
814 
ClassOf(const HloInstruction * hlo)815 inline bool HloChannelInstruction::ClassOf(const HloInstruction* hlo) {
816   return HloCollectiveInstruction::ClassOf(hlo) ||
817          HloCollectivePermuteInstruction::ClassOf(hlo) ||
818          HloSendRecvInstruction::ClassOf(hlo);
819 }
820 
821 class HloReverseInstruction : public HloDimensionsInstruction {
822  public:
823   explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
824                                  absl::Span<const int64_t> dimensions);
825 
ClassOf(const HloInstruction * hlo)826   static bool ClassOf(const HloInstruction* hlo) {
827     return hlo->opcode() == HloOpcode::kReverse;
828   }
829 
830  private:
831   // Implementation for non-common logic of CloneWithNewOperands.
832   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
833       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
834       HloCloneContext* context) const override;
835 };
836 
837 class HloConcatenateInstruction : public HloDimensionsInstruction {
838  public:
839   explicit HloConcatenateInstruction(const Shape& shape,
840                                      absl::Span<HloInstruction* const> operands,
841                                      int64_t dimension);
842   // Accessor for the dimension in which a concatenate HLO should occur.
concatenate_dimension()843   int64_t concatenate_dimension() const override {
844     return HloInstruction::dimensions(0);
845   }
846 
ClassOf(const HloInstruction * hlo)847   static bool ClassOf(const HloInstruction* hlo) {
848     return hlo->opcode() == HloOpcode::kConcatenate;
849   }
850 
851  private:
852   // Implementation for non-common logic of CloneWithNewOperands.
853   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
854       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
855       HloCloneContext* context) const override;
856 };
857 
858 class HloReduceInstruction : public HloDimensionsInstruction {
859  public:
860   explicit HloReduceInstruction(const Shape& shape,
861                                 absl::Span<HloInstruction* const> args,
862                                 absl::Span<const int64_t> dimensions_to_reduce,
863                                 HloComputation* reduce_computation);
864 
865   // Returns the number of input arrays (and, consequentially, the number of
866   // init values) this reduce has.
input_count()867   int64_t input_count() const { return operand_count() / 2; }
868 
869   // Returns the input tensors to be reduced.
inputs()870   absl::Span<HloInstruction* const> inputs() const {
871     return absl::MakeSpan(operands()).subspan(0, input_count());
872   }
873 
874   // Returns the init values of the reduction.
init_values()875   absl::Span<HloInstruction* const> init_values() const {
876     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
877   }
878 
ClassOf(const HloInstruction * hlo)879   static bool ClassOf(const HloInstruction* hlo) {
880     return hlo->opcode() == HloOpcode::kReduce;
881   }
882 
883  private:
884   bool IdenticalSlowPath(
885       const HloInstruction& other,
886       const std::function<bool(const HloComputation*, const HloComputation*)>&
887           eq_computations) const override;
888   // Implementation for non-common logic of CloneWithNewOperands.
889   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
890       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
891       HloCloneContext* context) const override;
892 };
893 
894 class HloSortInstruction : public HloDimensionsInstruction {
895  public:
896   explicit HloSortInstruction(const Shape& shape, int64_t dimension,
897                               absl::Span<HloInstruction* const> operands,
898                               HloComputation* compare, bool is_stable);
899   // Returns the sort dimension for this instruction
sort_dimension()900   int64_t sort_dimension() const { return HloInstruction::dimensions(0); }
901   // Returns a serialized representation of this instruction.
902   HloInstructionProto ToProto() const override;
903   // Returns the key operand to this instruction.
keys()904   const HloInstruction* keys() const { return operand(0); }
mutable_keys()905   HloInstruction* mutable_keys() { return mutable_operand(0); }
906   // Returns the number of value operands.
values_count()907   int64_t values_count() const { return operand_count() - 1; }
is_stable()908   bool is_stable() const { return is_stable_; }
909 
ClassOf(const HloInstruction * hlo)910   static bool ClassOf(const HloInstruction* hlo) {
911     return hlo->opcode() == HloOpcode::kSort;
912   }
913 
914  private:
915   std::vector<std::string> ExtraAttributesToStringImpl(
916       const HloPrintOptions& options) const override;
917   bool IdenticalSlowPath(
918       const HloInstruction& other,
919       const std::function<bool(const HloComputation*, const HloComputation*)>&
920           eq_computations) const override;
921   // Implementation for non-common logic of CloneWithNewOperands.
922   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
923       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
924       HloCloneContext* context) const override;
925 
926   bool is_stable_;
927 };
928 
929 class HloTransposeInstruction : public HloDimensionsInstruction {
930  public:
931   explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
932                                    absl::Span<const int64_t> dimensions);
933   // Returns whether this instruction does a rank-2 transposition.
934   bool IsRank2Transpose() const;
935 
ClassOf(const HloInstruction * hlo)936   static bool ClassOf(const HloInstruction* hlo) {
937     return hlo->opcode() == HloOpcode::kTranspose;
938   }
939 
940  private:
941   // Implementation for non-common logic of CloneWithNewOperands.
942   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
943       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
944       HloCloneContext* context) const override;
945 };
946 
947 class HloBroadcastInstruction : public HloDimensionsInstruction {
948  public:
949   explicit HloBroadcastInstruction(
950       const Shape& shape, HloInstruction* operand,
951       absl::Span<const int64_t> broadcast_dimension);
952 
ClassOf(const HloInstruction * hlo)953   static bool ClassOf(const HloInstruction* hlo) {
954     return hlo->opcode() == HloOpcode::kBroadcast;
955   }
956 
957  private:
958   // Implementation for non-common logic of CloneWithNewOperands.
959   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
960       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
961       HloCloneContext* context) const override;
962 };
963 
964 class HloDynamicReshapeInstruction : public HloInstruction {
965  public:
966   explicit HloDynamicReshapeInstruction(
967       const Shape& shape, HloInstruction* data_operand,
968       absl::Span<HloInstruction* const> dim_sizes);
969 
970   // Returns the input dim sizes dimensions, which is operands[1:]
dim_sizes()971   absl::Span<HloInstruction* const> dim_sizes() const {
972     return absl::MakeSpan(operands()).subspan(1, operand_count());
973   }
974 
975   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
976       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
977       HloCloneContext* context) const override;
978 
979   // Returns the input dim size dimension, which is operands[1+i]
dim_sizes(int64_t i)980   HloInstruction* dim_sizes(int64_t i) const { return operands()[i + 1]; }
981 
ClassOf(const HloInstruction * hlo)982   static bool ClassOf(const HloInstruction* hlo) {
983     return hlo->opcode() == HloOpcode::kDynamicReshape;
984   }
985 };
986 
987 class HloReshapeInstruction : public HloInstruction {
988  public:
989   explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
990                                  int64_t inferred_dimension);
inferred_dimension()991   int64_t inferred_dimension() const { return inferred_dimension_; }
992   HloInstructionProto ToProto() const override;
993 
ClassOf(const HloInstruction * hlo)994   static bool ClassOf(const HloInstruction* hlo) {
995     return hlo->opcode() == HloOpcode::kReshape;
996   }
997 
998  private:
999   std::vector<std::string> ExtraAttributesToStringImpl(
1000       const HloPrintOptions& options) const override;
1001   bool IdenticalSlowPath(
1002       const HloInstruction& other,
1003       const std::function<bool(const HloComputation*, const HloComputation*)>&
1004           eq_computations) const override;
1005   // Implementation for non-common logic of CloneWithNewOperands.
1006   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1007       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1008       HloCloneContext* context) const override;
1009   int64_t inferred_dimension_;
1010 };
1011 
1012 class HloMapInstruction : public HloInstruction {
1013  public:
1014   explicit HloMapInstruction(const Shape& shape,
1015                              absl::Span<HloInstruction* const> operands,
1016                              HloComputation* map_computation);
1017   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()1018   absl::Span<const int64_t> dimensions() const override { return dimensions_; }
1019 
mutable_dimensions()1020   std::vector<int64_t>* mutable_dimensions() override { return &dimensions_; }
1021   // Returns a serialized representation of this instruction.
1022   HloInstructionProto ToProto() const override;
1023 
ClassOf(const HloInstruction * hlo)1024   static bool ClassOf(const HloInstruction* hlo) {
1025     return hlo->opcode() == HloOpcode::kMap;
1026   }
1027 
1028  private:
1029   bool IsElementwiseImpl(
1030       const std::optional<int64_t>& operand_idx) const override;
1031   std::vector<std::string> ExtraAttributesToStringImpl(
1032       const HloPrintOptions& options) const override;
1033   bool IdenticalSlowPath(
1034       const HloInstruction& other,
1035       const std::function<bool(const HloComputation*, const HloComputation*)>&
1036           eq_computations) const override;
1037   // Implementation for non-common logic of CloneWithNewOperands.
1038   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1039       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1040       HloCloneContext* context) const override;
1041 
1042   std::vector<int64_t> dimensions_;
1043 };
1044 
1045 class HloSliceInstruction : public HloInstruction {
1046  public:
1047   explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
1048                                absl::Span<const int64_t> start_indices,
1049                                absl::Span<const int64_t> limit_indices,
1050                                absl::Span<const int64_t> strides);
1051 
1052   HloInstructionProto ToProto() const override;
1053 
1054   // Returns the start index in the given dimension for a slice node.
slice_starts(int64_t dimension)1055   int64_t slice_starts(int64_t dimension) const {
1056     return slice_starts_[dimension];
1057   }
slice_starts()1058   const std::vector<int64_t>& slice_starts() const { return slice_starts_; }
mutable_slice_starts()1059   std::vector<int64_t>* mutable_slice_starts() { return &slice_starts_; }
1060 
1061   // Returns the (exclusive) limit index in the given dimension for a slice
1062   // node.
slice_limits(int64_t dimension)1063   int64_t slice_limits(int64_t dimension) const {
1064     return slice_limits_[dimension];
1065   }
slice_limits()1066   const std::vector<int64_t>& slice_limits() const { return slice_limits_; }
mutable_slice_limits()1067   std::vector<int64_t>* mutable_slice_limits() { return &slice_limits_; }
1068 
1069   // Returns the stride in the given dimension for a slice node.
slice_strides(int64_t dimension)1070   int64_t slice_strides(int64_t dimension) const {
1071     return slice_strides_[dimension];
1072   }
slice_strides()1073   const std::vector<int64_t>& slice_strides() const { return slice_strides_; }
mutable_slice_strides()1074   std::vector<int64_t>* mutable_slice_strides() { return &slice_strides_; }
1075 
ClassOf(const HloInstruction * hlo)1076   static bool ClassOf(const HloInstruction* hlo) {
1077     return hlo->opcode() == HloOpcode::kSlice;
1078   }
1079 
1080  private:
1081   std::vector<std::string> ExtraAttributesToStringImpl(
1082       const HloPrintOptions& options) const override;
1083   bool IdenticalSlowPath(
1084       const HloInstruction& other,
1085       const std::function<bool(const HloComputation*, const HloComputation*)>&
1086           eq_computations) const override;
1087   // Implementation for non-common logic of CloneWithNewOperands.
1088   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1089       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1090       HloCloneContext* context) const override;
1091 
1092   // Describes the [begin, end) index range for a slice.
1093   std::vector<int64_t> slice_starts_;
1094   std::vector<int64_t> slice_limits_;
1095   std::vector<int64_t> slice_strides_;
1096 };
1097 
1098 class HloConstantInstruction : public HloInstruction {
1099  public:
1100   explicit HloConstantInstruction(Literal literal);
1101   explicit HloConstantInstruction(Literal literal, const Shape& shape);
1102   // Used when the literal is too large and dropped.
1103   explicit HloConstantInstruction(const Shape& shape);
1104   // Returns the literal associated with this instruction.
literal()1105   const Literal& literal() const { return *literal_; }
1106   // Returns the (mutable) literal associated with this instruction.
mutable_literal()1107   Literal* mutable_literal() { return &literal_.value(); }
1108   // Returns whether there is literal associated with this instruction.
HasLiteral()1109   bool HasLiteral() const { return literal_.has_value(); }
1110   // Returns a serialized representation of this instruction.
1111   HloInstructionProto ToProto() const override;
1112 
1113   // Change the layout for an Constant Hlo instruction to match new_layout.  For
1114   // tuple shaped constants shape_index is the path to the internal array
1115   // subshape whose layout needs to be changed.
1116   void RelayoutConstant(const Layout& new_layout,
1117                         const ShapeIndex& shape_index = {});
1118 
ClassOf(const HloInstruction * hlo)1119   static bool ClassOf(const HloInstruction* hlo) {
1120     return hlo->opcode() == HloOpcode::kConstant;
1121   }
1122 
1123  private:
1124   bool IsElementwiseImpl(
1125       const std::optional<int64_t>& operand_idx) const override;
1126   bool IdenticalSlowPath(
1127       const HloInstruction& other,
1128       const std::function<bool(const HloComputation*, const HloComputation*)>&
1129           eq_computations) const override;
1130   std::string OperandsToStringWithCanonicalNameMap(
1131       const HloPrintOptions& options,
1132       CanonicalNameMap* canonical_name_map) const override;
1133   // Implementation for non-common logic of CloneWithNewOperands.
1134   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1135       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1136       HloCloneContext* context) const override;
1137   std::optional<Literal> literal_;
1138 };
1139 
1140 // Abstract class that represents an HLO instruction that "calls" a computation.
1141 // Fusion and Call HLOs inherit from this class.
1142 class HloCallableInstruction : public HloInstruction {
1143  public:
1144   HloCallableInstruction(HloOpcode opcode, const Shape& shape);
1145 
1146   HloCallableInstruction(HloOpcode opcode, const Shape& shape,
1147                          absl::Span<HloInstruction* const> operands);
1148 
1149   HloCallableInstruction(HloOpcode opcode, const Shape& shape,
1150                          absl::Span<HloInstruction* const> operands,
1151                          HloComputation* called_computation,
1152                          absl::string_view prefix = "");
1153 
1154   HloCallableInstruction(HloOpcode opcode, const Shape& shape,
1155                          absl::Span<HloInstruction* const> operands,
1156                          absl::Span<HloComputation* const> called_computations);
1157 
1158   ~HloCallableInstruction() override;
1159 
1160   // Adds a new operand to the callable instruction.
1161   HloInstruction* AddCallOperand(HloInstruction* new_operand);
1162 
1163   // Appends (fuses) the given instruction into this callable instruction.
1164   // instruction_to_append is cloned and the clone is placed in the callable
1165   // instruction.  The users of instruction_to_append will be redirected to this
1166   // callable instruction. instruction_to_append is unchanged otherwise. When
1167   // add_output is true, a clone of the instruction_to_append will be added as
1168   // additional output resulting in a multi-output callable instruction.
1169   HloInstruction* AppendInstructionIntoCalledComputation(
1170       HloInstruction* instruction_to_append, bool add_output = false);
1171   // Clones the given instruction_to_append and inserts the clone into this
1172   // callable instruction. If add_output is true, a clone of
1173   // instruction_to_append will be in the output of the this callable
1174   // instruction (part of the tuple of the callable root).
1175   HloInstruction* CloneAndAppendInstructionIntoCalledComputation(
1176       HloInstruction* instruction_to_append, bool add_output = false);
1177 
1178   // Retrieves the called computations of an HloCallableInstruction that is
1179   // being cloned. If the called computations have not yet been cloned, then
1180   // they are first cloned and added to the context.
1181   absl::InlinedVector<HloComputation*, 1> GetOrCloneCalledComputations(
1182       HloCloneContext* context) const;
1183 
1184   HloComputation* called_computation() const;
1185 
1186   HloInstruction* called_computation_root() const;
1187 
1188   // Recursively sets all nested called computation to have thread name as
1189   // `execution_thread`. if `skip_async_execution_thread_overwrite` is true,
1190   // skip overwrite async instruction and its comptuations thread name
1191   // overwriting.
1192   void RecursivelySetComputationsThreadName(
1193       absl::string_view execution_thread,
1194       bool skip_async_execution_thread_overwrite);
1195 
ClassOf(const HloInstruction * hlo)1196   static bool ClassOf(const HloInstruction* hlo) {
1197     return hlo->opcode() == HloOpcode::kFusion ||
1198            hlo->opcode() == HloOpcode::kCall ||
1199            hlo->opcode() == HloOpcode::kCustomCall;
1200   }
1201 
1202  protected:
1203   // Returns the default called computation name.
1204   virtual std::string default_called_computation_name() const = 0;
1205 };
1206 
1207 class HloFusionInstruction : public HloCallableInstruction {
1208  public:
1209   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
1210                                 HloInstruction* fused_root);
1211 
1212   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
1213                                 absl::Span<HloInstruction* const> operands,
1214                                 HloComputation* fusion_computation,
1215                                 absl::string_view prefix = "");
1216 
1217   ~HloFusionInstruction() override;
1218 
1219   void ClearCalledComputations() override;
1220 
1221   // When a fusion instruction is being destructed, clear the back pointer of
1222   // its fusion computation, to avoid referencing freed memory.
1223   void ClearFusionComputationInstruction();
1224 
1225   std::string ToCategory() const override;
1226   // Returns a serialized representation of this instruction.
1227   HloInstructionProto ToProto() const override;
1228 
1229   // Adds a new operand the fusion instruction.
1230   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1231 
1232   // Merges the fused instructions from 'instruction_to_merge' into the
1233   // fused instruction set of 'this', updating operands as necessary.
1234   //
1235   // Precondition: 'instruction_to_merge' must be an operand of 'this'.
1236   void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
1237 
1238   // Merges the fused instructions from instruction_to_merge into the fused
1239   // instruction set of 'this' and generates multioutput fusion instructions.
1240   // All the users of instruction_to_merge will be redirected to 'this'
1241   // instruction. instruction_to_merge will be removed from its parent
1242   // computation.
1243   void MergeFusionInstructionIntoMultiOutput(
1244       HloFusionInstruction* instruction_to_merge);
1245 
1246   // Fuses the given instruction in this fusion instruction. instruction_to_fuse
1247   // is cloned and the clone is placed in the fusion
1248   // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
1249   // than moved to cleanly handle the case where the instruction has a use
1250   // outside the fusion instruction. Moving such an instruction into a fusion
1251   // instruction would violate the single-result invariant of HLO instructions
1252   // and significantly complicate code generation.
FuseInstruction(HloInstruction * instruction_to_fuse)1253   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
1254     CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1255     return AppendInstructionIntoCalledComputation(instruction_to_fuse);
1256   }
1257 
1258   // Fuses the given instruction in this fusion instruction and generates a
1259   // multioutput fusion instruction. A clone of the instruction_to_fuse will
1260   // be part of the output of fusion instructions. The users of
1261   // instruction_to_fuse will be redirected to this fusion instructions.
1262   // instruction_to_fuse is unchanged otherwise.
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)1263   HloInstruction* FuseInstructionIntoMultiOutput(
1264       HloInstruction* instruction_to_fuse) {
1265     return AppendInstructionIntoCalledComputation(instruction_to_fuse,
1266                                                   /*add_output=*/true);
1267   }
1268 
1269   // Returns the computation for this fused instruction.
1270   HloComputation* fused_instructions_computation() const;
1271 
1272   // Returns the root instruction of the fused expression contained within this
1273   // fusion instruction.
1274   HloInstruction* fused_expression_root() const;
1275 
1276   // Returns the list of fused instructions inside this fusion instruction.  The
1277   // returned type is a range of HloInstruction*s.
1278   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1279       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1280   fused_instructions() const;
1281 
1282   const tensorflow::gtl::iterator_range<
1283       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1284   fused_instructions();
1285 
1286   // Gets the number of instructions inside this fusion instruction.
1287   int64_t fused_instruction_count() const;
1288 
1289   // Returns the fused parameter instruction in this fusion instruction
1290   // corresponding to the given parameter number.
1291   HloInstruction* fused_parameter(int64_t parameter_number) const;
1292 
1293   // Returns the vector of fused parameters inside this fusion instruction.
1294   const std::vector<HloInstruction*>& fused_parameters() const;
1295 
1296   // Returns true if this instruction is a fusion instruction that generates
1297   // multiple outputs.
IsMultiOutputFusion()1298   const bool IsMultiOutputFusion() const {
1299     return fused_expression_root()->opcode() == HloOpcode::kTuple;
1300   }
1301 
fusion_kind()1302   FusionKind fusion_kind() const { return fusion_kind_; }
1303 
set_fusion_kind(FusionKind kind)1304   void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
1305 
1306   // If multiple operands are the same instruction, keeps only one of them.
1307   Status DeduplicateFusionOperands();
1308 
ClassOf(const HloInstruction * hlo)1309   static bool ClassOf(const HloInstruction* hlo) {
1310     return hlo->opcode() == HloOpcode::kFusion;
1311   }
1312 
1313  protected:
default_called_computation_name()1314   std::string default_called_computation_name() const override {
1315     return "fused_computation";
1316   }
1317 
1318  private:
1319   bool IsElementwiseImpl(
1320       const std::optional<int64_t>& operand_idx) const override;
1321   std::vector<std::string> ExtraAttributesToStringImpl(
1322       const HloPrintOptions& options) const override;
1323   bool IdenticalSlowPath(
1324       const HloInstruction& other,
1325       const std::function<bool(const HloComputation*, const HloComputation*)>&
1326           eq_computations) const override;
1327 
1328   // Implementation for non-common logic of CloneWithNewOperands.
1329   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1330       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1331       HloCloneContext* context) const override;
1332 
1333   // The type of the fusion.
1334   FusionKind fusion_kind_;
1335 };
1336 
1337 class HloCallInstruction : public HloCallableInstruction {
1338  public:
1339   HloCallInstruction(const Shape& shape,
1340                      HloInstruction* called_computation_root);
1341 
1342   HloCallInstruction(const Shape& shape,
1343                      absl::Span<HloInstruction* const> operands,
1344                      HloComputation* called_computation);
1345 
ClassOf(const HloInstruction * hlo)1346   static bool ClassOf(const HloInstruction* hlo) {
1347     return hlo->opcode() == HloOpcode::kCall;
1348   }
1349 
1350  protected:
default_called_computation_name()1351   std::string default_called_computation_name() const override {
1352     return "called_computation";
1353   }
1354 };
1355 
1356 class HloRngInstruction : public HloInstruction {
1357  public:
1358   explicit HloRngInstruction(const Shape& shape,
1359                              RandomDistribution distribution,
1360                              absl::Span<HloInstruction* const> parameters);
1361   // Returns the random distribution for this rng node.
random_distribution()1362   RandomDistribution random_distribution() const { return distribution_; }
1363   // Returns a serialized representation of this instruction.
1364   HloInstructionProto ToProto() const override;
1365 
ClassOf(const HloInstruction * hlo)1366   static bool ClassOf(const HloInstruction* hlo) {
1367     return hlo->opcode() == HloOpcode::kRng;
1368   }
1369 
1370  private:
1371   bool IsElementwiseImpl(
1372       const std::optional<int64_t>& operand_idx) const override;
1373   std::vector<std::string> ExtraAttributesToStringImpl(
1374       const HloPrintOptions& options) const override;
1375   bool IdenticalSlowPath(
1376       const HloInstruction& other,
1377       const std::function<bool(const HloComputation*, const HloComputation*)>&
1378           eq_computations) const override;
1379   // Implementation for non-common logic of CloneWithNewOperands.
1380   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1381       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1382       HloCloneContext* context) const override;
1383 
1384   // The distribution requested for random number generation.
1385   RandomDistribution distribution_;
1386 };
1387 
1388 class HloParameterInstruction : public HloInstruction {
1389  public:
1390   explicit HloParameterInstruction(int64_t parameter_number, const Shape& shape,
1391                                    const std::string& name);
parameter_number()1392   int64_t parameter_number() const { return parameter_number_; }
1393 
1394   // Sets and gets the whether all replicas will receive the same parameter data
1395   // for each leaf buffer in data parallelism.
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)1396   void set_parameter_replicated_at_leaf_buffers(
1397       absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
1398     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
1399              parameter_replicated_at_leaf_buffers.size());
1400     parameter_replicated_at_leaf_buffers_.emplace(
1401         parameter_replicated_at_leaf_buffers.begin(),
1402         parameter_replicated_at_leaf_buffers.end());
1403   }
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)1404   void set_parameter_replicated_at_leaf_buffers(
1405       const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
1406     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
1407              parameter_replicated_at_leaf_buffers.size());
1408     parameter_replicated_at_leaf_buffers_ =
1409         parameter_replicated_at_leaf_buffers;
1410   }
parameter_replicated_at_leaf_buffers()1411   const std::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()
1412       const {
1413     return parameter_replicated_at_leaf_buffers_;
1414   }
1415 
1416   // Returns a serialized representation of this instruction.
1417   HloInstructionProto ToProto() const override;
1418 
ClassOf(const HloInstruction * hlo)1419   static bool ClassOf(const HloInstruction* hlo) {
1420     return hlo->opcode() == HloOpcode::kParameter;
1421   }
1422 
1423  private:
1424   std::vector<std::string> ExtraAttributesToStringImpl(
1425       const HloPrintOptions& options) const override;
1426   bool IdenticalSlowPath(
1427       const HloInstruction& other,
1428       const std::function<bool(const HloComputation*, const HloComputation*)>&
1429           eq_computations) const override;
1430   std::string OperandsToStringWithCanonicalNameMap(
1431       const HloPrintOptions& options,
1432       CanonicalNameMap* canonical_name_map) const override;
1433   // Implementation for non-common logic of CloneWithNewOperands.
1434   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1435       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1436       HloCloneContext* context) const override;
1437 
1438   int64_t parameter_number_ = 0;
1439 
1440   // Specifies whether each buffer has the same parameter value on all replicas
1441   // in data parallelism.
1442   std::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
1443 };
1444 
1445 class HloGetTupleElementInstruction : public HloInstruction {
1446  public:
1447   explicit HloGetTupleElementInstruction(const Shape& shape,
1448                                          HloInstruction* operand,
1449                                          int64_t index);
1450   // Returns the tuple index associated with this instruction.
tuple_index()1451   int64_t tuple_index() const { return tuple_index_; }
1452   // Sets the tuple index associated with this instruction.
set_tuple_index(int64_t new_tuple_index)1453   void set_tuple_index(int64_t new_tuple_index) {
1454     tuple_index_ = new_tuple_index;
1455   }
1456   // Returns a serialized representation of this instruction.
1457   HloInstructionProto ToProto() const override;
1458 
ClassOf(const HloInstruction * hlo)1459   static bool ClassOf(const HloInstruction* hlo) {
1460     return hlo->opcode() == HloOpcode::kGetTupleElement;
1461   }
1462 
1463  private:
1464   std::vector<std::string> ExtraAttributesToStringImpl(
1465       const HloPrintOptions& options) const override;
1466   bool IdenticalSlowPath(
1467       const HloInstruction& other,
1468       const std::function<bool(const HloComputation*, const HloComputation*)>&
1469           eq_computations) const override;
1470   // Implementation for non-common logic of CloneWithNewOperands.
1471   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1472       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1473       HloCloneContext* context) const override;
1474 
1475   int64_t tuple_index_ = -1;
1476 };
1477 
1478 class HloReducePrecisionInstruction : public HloInstruction {
1479  public:
1480   explicit HloReducePrecisionInstruction(const Shape& shape,
1481                                          HloInstruction* operand,
1482                                          const int exponent_bits,
1483                                          const int mantissa_bits);
1484   // Returns the number of exponent bits for a reduce-precision node.
exponent_bits()1485   int32_t exponent_bits() const { return exponent_bits_; }
1486   // Returns the number of mantissa bits for a reduce-precision node.
mantissa_bits()1487   int32_t mantissa_bits() const { return mantissa_bits_; }
1488   // Returns a serialized representation of this instruction.
1489   HloInstructionProto ToProto() const override;
1490 
ClassOf(const HloInstruction * hlo)1491   static bool ClassOf(const HloInstruction* hlo) {
1492     return hlo->opcode() == HloOpcode::kReducePrecision;
1493   }
1494 
1495  private:
1496   std::vector<std::string> ExtraAttributesToStringImpl(
1497       const HloPrintOptions& options) const override;
1498   bool IdenticalSlowPath(
1499       const HloInstruction& other,
1500       const std::function<bool(const HloComputation*, const HloComputation*)>&
1501           eq_computations) const override;
1502   // Implementation for non-common logic of CloneWithNewOperands.
1503   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1504       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1505       HloCloneContext* context) const override;
1506 
1507   // The bit sizes for a reduce-precision operation.
1508   int32_t exponent_bits_ = 0;
1509   int32_t mantissa_bits_ = 0;
1510 };
1511 
1512 class HloInfeedInstruction : public HloInstruction {
1513  public:
1514   explicit HloInfeedInstruction(const Shape& infeed_shape,
1515                                 HloInstruction* token_operand,
1516                                 const std::string& config);
1517   // Returns the infeed configuration string. The infeed configuration includes
1518   // any metadata needed for the backend compiler (e.g., infeed buffer address)
1519   // and is target-dependent.
infeed_config()1520   std::string infeed_config() const { return infeed_config_; }
set_infeed_config(const std::string & config)1521   void set_infeed_config(const std::string& config) { infeed_config_ = config; }
1522   // Returns the shape of the data received by the infeed. This is not the same
1523   // as the shape of the infeed instruction which produces a tuple containing
1524   // the infeed data shape and a TOKEN.
infeed_shape()1525   const Shape& infeed_shape() const {
1526     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
1527     return ShapeUtil::GetSubshape(shape(), {0});
1528   }
1529   // Returns a serialized representation of this instruction.
1530   HloInstructionProto ToProto() const override;
1531 
ClassOf(const HloInstruction * hlo)1532   static bool ClassOf(const HloInstruction* hlo) {
1533     return hlo->opcode() == HloOpcode::kInfeed;
1534   }
1535 
1536  private:
1537   std::vector<std::string> ExtraAttributesToStringImpl(
1538       const HloPrintOptions& options) const override;
1539   bool IdenticalSlowPath(
1540       const HloInstruction& other,
1541       const std::function<bool(const HloComputation*, const HloComputation*)>&
1542           eq_computations) const override;
1543   // Implementation for non-common logic of CloneWithNewOperands.
1544   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1545       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1546       HloCloneContext* context) const override;
1547 
1548   // The string representation of the infeed configuration.
1549   std::string infeed_config_;
1550 };
1551 
1552 class HloOutfeedInstruction : public HloInstruction {
1553  public:
1554   explicit HloOutfeedInstruction(const Shape& outfeed_shape,
1555                                  HloInstruction* operand,
1556                                  HloInstruction* token_operand,
1557                                  absl::string_view outfeed_config);
1558   // Returns the shape for the Outfeed instruction.
outfeed_shape()1559   const Shape& outfeed_shape() const { return outfeed_shape_; }
1560   // Returns the mutable shape for the Outfeed instruction.
mutable_outfeed_shape()1561   Shape* mutable_outfeed_shape() { return &outfeed_shape_; }
1562   // Returns the config for the Outfeed instruction.
outfeed_config()1563   const std::string& outfeed_config() const { return outfeed_config_; }
set_outfeed_config(const std::string & config)1564   void set_outfeed_config(const std::string& config) {
1565     outfeed_config_ = config;
1566   }
1567   // Returns a serialized representation of this instruction.
1568   HloInstructionProto ToProto() const override;
1569 
ClassOf(const HloInstruction * hlo)1570   static bool ClassOf(const HloInstruction* hlo) {
1571     return hlo->opcode() == HloOpcode::kOutfeed;
1572   }
1573 
1574  private:
1575   std::vector<std::string> ExtraAttributesToStringImpl(
1576       const HloPrintOptions& options) const override;
1577   bool IdenticalSlowPath(
1578       const HloInstruction& other,
1579       const std::function<bool(const HloComputation*, const HloComputation*)>&
1580           eq_computations) const override;
1581   // Implementation for non-common logic of CloneWithNewOperands.
1582   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1583       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1584       HloCloneContext* context) const override;
1585 
1586   // Shape of outfeed request.
1587   Shape outfeed_shape_;
1588   // Outfeed configuration information, only present for kOutfeed.
1589   std::string outfeed_config_;
1590 };
1591 
1592 class HloConvolutionInstruction : public HloInstruction {
1593  public:
1594   explicit HloConvolutionInstruction(
1595       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1596       int64_t feature_group_count, int64_t batch_group_count,
1597       const Window& window,
1598       const ConvolutionDimensionNumbers& dimension_numbers,
1599       const PrecisionConfig& precision_config);
window()1600   const Window& window() const override { return window_; }
set_window(const Window & window)1601   void set_window(const Window& window) override { window_ = window; }
convolution_dimension_numbers()1602   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1603     return convolution_dimension_numbers_;
1604   }
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1605   void set_convolution_dimension_numbers(
1606       const ConvolutionDimensionNumbers& dnums) {
1607     convolution_dimension_numbers_ = dnums;
1608   }
1609   // The number of feature groups. Must be a divisor of the input feature
1610   // dimension and output feature dimension.
feature_group_count()1611   int64_t feature_group_count() const { return feature_group_count_; }
set_feature_group_count(int64_t num_feature_groups)1612   void set_feature_group_count(int64_t num_feature_groups) {
1613     feature_group_count_ = num_feature_groups;
1614   }
1615   // The number of batch groups. Must be a divisor of the input batch dimension.
batch_group_count()1616   int64_t batch_group_count() const { return batch_group_count_; }
set_batch_group_count(int64_t num_batch_groups)1617   void set_batch_group_count(int64_t num_batch_groups) {
1618     batch_group_count_ = num_batch_groups;
1619   }
1620 
1621   // Returns the information used to tell the implementation information about
1622   // what sort of precision is requested. The meaning of the field is backend
1623   // specific. At the moment, it is only supported for kConvolution and kDot.
1624   // Transformations on one kDot or kConvolution to another will preserve this
1625   // information. Transformations to other HLOs will not preserve this
1626   // information but it is presumed that the alternate lowering is strictly
1627   // superior.
precision_config()1628   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1629   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1630 
1631   std::string ToCategory() const override;
1632   // Returns a serialized representation of this instruction.
1633   HloInstructionProto ToProto() const override;
1634 
ClassOf(const HloInstruction * hlo)1635   static bool ClassOf(const HloInstruction* hlo) {
1636     return hlo->opcode() == HloOpcode::kConvolution;
1637   }
1638 
1639  private:
1640   std::vector<std::string> ExtraAttributesToStringImpl(
1641       const HloPrintOptions& options) const override;
1642   bool IdenticalSlowPath(
1643       const HloInstruction& other,
1644       const std::function<bool(const HloComputation*, const HloComputation*)>&
1645           eq_computations) const override;
1646   // Implementation for non-common logic of CloneWithNewOperands.
1647   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1648       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1649       HloCloneContext* context) const override;
1650   // The number of feature groups. Must be a divisor of the input feature
1651   // dimension and output feature dimension.
1652   int64_t feature_group_count_;
1653   // The number of batch groups. Must be a divisor of the input batch dimension.
1654   int64_t batch_group_count_;
1655   // Describes the window used for a convolution.
1656   Window window_;
1657   // Describes the dimension numbers used for a convolution.
1658   ConvolutionDimensionNumbers convolution_dimension_numbers_;
1659   // Information used to communicate to the implementation about the algorithm
1660   // used to produce results. See the documentation on precision_config().
1661   PrecisionConfig precision_config_;
1662 };
1663 
1664 class HloReduceWindowInstruction : public HloInstruction {
1665  public:
1666   explicit HloReduceWindowInstruction(const Shape& shape,
1667                                       HloInstruction* operand,
1668                                       HloInstruction* init_value,
1669                                       const Window& window,
1670                                       HloComputation* reduce_computation);
1671   explicit HloReduceWindowInstruction(
1672       const Shape& shape, absl::Span<HloInstruction* const> operands,
1673       absl::Span<HloInstruction* const> init_values, const Window& window,
1674       HloComputation* reduce_computation);
window()1675   const Window& window() const override { return window_; }
set_window(const Window & window)1676   void set_window(const Window& window) override { window_ = window; }
1677   // Returns a serialized representation of this instruction.
1678   HloInstructionProto ToProto() const override;
1679   // Returns the number of input arrays (and, consequentially, the number of
1680   // init values) this reduce has.
input_count()1681   int64_t input_count() const { return operand_count() / 2; }
1682   // Returns the input tensors to be reduced.
inputs()1683   absl::Span<HloInstruction* const> inputs() const {
1684     return absl::MakeSpan(operands()).subspan(0, input_count());
1685   }
1686   // Returns the init values of the reduction.
init_values()1687   absl::Span<HloInstruction* const> init_values() const {
1688     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
1689   }
1690   // Returns the shapes of input tensors to be reduced.
input_shapes()1691   absl::InlinedVector<const Shape*, 2> input_shapes() const {
1692     absl::InlinedVector<const Shape*, 2> shapes;
1693     for (const auto* op : inputs()) {
1694       VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
1695       shapes.push_back(&op->shape());
1696       VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
1697     }
1698     return shapes;
1699   }
1700   // Returns the init values of the reduction.
init_value_shapes()1701   absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
1702     absl::InlinedVector<const Shape*, 2> shapes;
1703     for (const auto* op : init_values()) {
1704       shapes.push_back(&op->shape());
1705     }
1706     return shapes;
1707   }
1708   // Returns the shapes of the reduced output tensors.
output_shapes()1709   absl::InlinedVector<const Shape*, 2> output_shapes() const {
1710     absl::InlinedVector<const Shape*, 2> shapes;
1711     if (shape().IsArray()) {
1712       shapes.push_back(&shape());
1713     } else {
1714       for (const Shape& tuple_element_shape : shape().tuple_shapes()) {
1715         shapes.push_back(&tuple_element_shape);
1716       }
1717     }
1718     return shapes;
1719   }
1720 
ClassOf(const HloInstruction * hlo)1721   static bool ClassOf(const HloInstruction* hlo) {
1722     return hlo->opcode() == HloOpcode::kReduceWindow;
1723   }
1724 
1725  private:
1726   std::vector<std::string> ExtraAttributesToStringImpl(
1727       const HloPrintOptions& options) const override;
1728   bool IdenticalSlowPath(
1729       const HloInstruction& other,
1730       const std::function<bool(const HloComputation*, const HloComputation*)>&
1731           eq_computations) const override;
1732   // Implementation for non-common logic of CloneWithNewOperands.
1733   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1734       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1735       HloCloneContext* context) const override;
1736 
1737   Window window_;
1738 };
1739 
1740 class HloSelectAndScatterInstruction : public HloInstruction {
1741  public:
1742   explicit HloSelectAndScatterInstruction(
1743       const Shape& shape, HloInstruction* operand, HloComputation* select,
1744       const Window& window, HloInstruction* source, HloInstruction* init_value,
1745       HloComputation* scatter);
window()1746   const Window& window() const override { return window_; }
set_window(const Window & window)1747   void set_window(const Window& window) override { window_ = window; }
1748   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
1749   // setters should only be called by HloModule or HloComputation methods.
select()1750   HloComputation* select() const {
1751     return called_computations()[kSelectComputationIndex];
1752   }
1753 
scatter()1754   HloComputation* scatter() const {
1755     return called_computations()[kScatterComputationIndex];
1756   }
1757 
set_select(HloComputation * computation)1758   void set_select(HloComputation* computation) {
1759     // Don't allow changing the computation for fused instructions so we don't
1760     // have to recompute called_instructions for the entire fusion instruction.
1761     CHECK(!IsFused());
1762     set_called_computation(kSelectComputationIndex, computation);
1763   }
1764 
set_scatter(HloComputation * computation)1765   void set_scatter(HloComputation* computation) {
1766     // Don't allow changing the computation for fused instructions so we don't
1767     // have to recompute called_instructions for the entire fusion instruction.
1768     CHECK(!IsFused());
1769     set_called_computation(kScatterComputationIndex, computation);
1770   }
1771   // Returns a serialized representation of this instruction.
1772   HloInstructionProto ToProto() const override;
1773 
ClassOf(const HloInstruction * hlo)1774   static bool ClassOf(const HloInstruction* hlo) {
1775     return hlo->opcode() == HloOpcode::kSelectAndScatter;
1776   }
1777 
1778  private:
1779   std::vector<std::string> ExtraAttributesToStringImpl(
1780       const HloPrintOptions& options) const override;
1781   bool IdenticalSlowPath(
1782       const HloInstruction& other,
1783       const std::function<bool(const HloComputation*, const HloComputation*)>&
1784           eq_computations) const override;
1785   // Implementation for non-common logic of CloneWithNewOperands.
1786   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1787       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1788       HloCloneContext* context) const override;
1789   Window window_;
1790 };
1791 
1792 class HloCustomCallInstruction : public HloCallableInstruction {
1793  public:
1794   HloCustomCallInstruction(const Shape& shape,
1795                            absl::Span<HloInstruction* const> operands,
1796                            absl::string_view custom_call_target,
1797                            std::string opaque,
1798                            CustomCallApiVersion api_version);
1799 
1800   // Constructor for a custom call with constrained layout. 'shape' and
1801   // 'operands_with_layout' must all have layouts.
1802   HloCustomCallInstruction(const Shape& shape,
1803                            absl::Span<HloInstruction* const> operands,
1804                            absl::string_view custom_call_target,
1805                            std::string opaque,
1806                            absl::Span<const Shape> operand_shapes_with_layout,
1807                            CustomCallApiVersion api_version);
1808 
1809   // Constructor for a custom call with a to_apply computation.
1810   HloCustomCallInstruction(const Shape& shape,
1811                            absl::Span<HloInstruction* const> operands,
1812                            HloComputation* to_apply,
1813                            absl::string_view custom_call_target,
1814                            std::string opaque,
1815                            CustomCallApiVersion api_version);
1816 
1817   // Constructor for a custom call with multiple computations.
1818   HloCustomCallInstruction(
1819       const Shape& shape, absl::Span<HloInstruction* const> operands,
1820       absl::Span<HloComputation* const> called_computations,
1821       absl::string_view custom_call_target, std::string opaque,
1822       CustomCallApiVersion api_version);
1823 
window()1824   const Window& window() const override {
1825     CHECK(window_ != nullptr);
1826     return *window_;
1827   }
1828 
set_window(const Window & window)1829   void set_window(const Window& window) override {
1830     window_ = std::make_unique<Window>(window);
1831   }
1832 
convolution_dimension_numbers()1833   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1834     CHECK(convolution_dimension_numbers_ != nullptr);
1835     return *convolution_dimension_numbers_;
1836   }
1837 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1838   void set_convolution_dimension_numbers(
1839       const ConvolutionDimensionNumbers& dnums) {
1840     convolution_dimension_numbers_ =
1841         std::make_unique<ConvolutionDimensionNumbers>(dnums);
1842   }
1843   // TODO(jpienaar): Remove this accessor in the follow up.
opaque()1844   const std::string& opaque() const { return raw_backend_config_string(); }
custom_call_target()1845   const std::string& custom_call_target() const { return custom_call_target_; }
set_custom_call_target(absl::string_view target)1846   void set_custom_call_target(absl::string_view target) {
1847     custom_call_target_ = std::string(target);
1848   }
set_feature_group_count(int64_t feature_group_count)1849   void set_feature_group_count(int64_t feature_group_count) {
1850     feature_group_count_ = feature_group_count;
1851   }
set_batch_group_count(int64_t batch_group_count)1852   void set_batch_group_count(int64_t batch_group_count) {
1853     batch_group_count_ = batch_group_count;
1854   }
1855   // Sets whether this custom call has a side-effect - by default a custom call
1856   // has no side-effects.
set_custom_call_has_side_effect(bool custom_call_has_side_effect)1857   void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
1858     custom_call_has_side_effect_ = custom_call_has_side_effect;
1859   }
feature_group_count()1860   int64_t feature_group_count() const { return feature_group_count_; }
batch_group_count()1861   int64_t batch_group_count() const { return batch_group_count_; }
custom_call_has_side_effect()1862   bool custom_call_has_side_effect() const {
1863     return custom_call_has_side_effect_;
1864   }
1865   // Returns padding type used for ops like convolution.
padding_type()1866   PaddingType padding_type() const { return padding_type_; }
1867 
set_padding_type(PaddingType padding_type)1868   void set_padding_type(PaddingType padding_type) {
1869     padding_type_ = padding_type;
1870   }
1871 
1872   // Returns the literal associated with this instruction.
literal()1873   const Literal& literal() const { return *literal_; }
1874   // Set the value of literal to a new one.
set_literal(Literal && literal)1875   void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
1876   // Returns whether there is literal associated with this instruction.
HasLiteral()1877   bool HasLiteral() const { return literal_.has_value(); }
1878 
precision_config()1879   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1880   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1881 
1882   // Returns a serialized representation of this instruction.
1883   HloInstructionProto ToProto() const override;
1884 
1885   // Returns whether the result and operand layouts are constrained.
layout_constrained()1886   bool layout_constrained() const { return layout_constrained_; }
1887 
1888   // Returns the shapes (with layout) of the operands. CHECKs if this custom
1889   // call does not have constrained layouts.
operand_shapes_with_layout()1890   const std::vector<Shape>& operand_shapes_with_layout() const {
1891     CHECK(layout_constrained());
1892     return operand_shapes_with_layout_;
1893   }
1894   // Gets a list of output/operand buffer pairs that alias each other, where the
1895   // output buffer is represented as a ShapeIndex, and the operand buffer is
1896   // represented as the operand index and the ShapeIndex. By default this list
1897   // is empty.
1898   const std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>&
output_to_operand_aliasing()1899   output_to_operand_aliasing() const {
1900     return output_to_operand_aliasing_;
1901   }
1902   // Sets the list of output/operand buffer pairs that alias each other.
set_output_to_operand_aliasing(std::vector<std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> aliasing)1903   void set_output_to_operand_aliasing(
1904       std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1905           aliasing) {
1906     output_to_operand_aliasing_ = std::move(aliasing);
1907   }
set_custom_call_schedule(CustomCallSchedule custom_call_schedule)1908   void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) {
1909     custom_call_schedule_ = custom_call_schedule;
1910   }
custom_call_schedule()1911   CustomCallSchedule custom_call_schedule() const {
1912     return custom_call_schedule_;
1913   }
set_api_version(CustomCallApiVersion api_version)1914   void set_api_version(CustomCallApiVersion api_version) {
1915     api_version_ = api_version;
1916   }
api_version()1917   CustomCallApiVersion api_version() const { return api_version_; }
1918 
ClassOf(const HloInstruction * hlo)1919   static bool ClassOf(const HloInstruction* hlo) {
1920     return hlo->opcode() == HloOpcode::kCustomCall;
1921   }
1922 
1923  protected:
default_called_computation_name()1924   std::string default_called_computation_name() const override {
1925     return "custom_call_computation";
1926   }
1927 
1928  private:
1929   std::vector<std::string> ExtraAttributesToStringImpl(
1930       const HloPrintOptions& options) const override;
1931   bool IdenticalSlowPath(
1932       const HloInstruction& other,
1933       const std::function<bool(const HloComputation*, const HloComputation*)>&
1934           eq_computations) const override;
1935   // Implementation for non-common logic of CloneWithNewOperands.
1936   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1937       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1938       HloCloneContext* context) const override;
1939   // Name of a global symbol to call.
1940   std::string custom_call_target_;
1941   // Describes the window in a windowed operation such as convolution.
1942   std::unique_ptr<Window> window_;
1943   // Describes the dimension numbers used for a convolution.
1944   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1945   // The number of feature groups. This is used for grouped convolutions.
1946   int64_t feature_group_count_;
1947   int64_t batch_group_count_;
1948   // Whether the result and operand layouts are constrained.
1949   bool layout_constrained_;
1950   // Information used to communicate to the implementation about the algorithm
1951   // used to produce results for convolution instructions.
1952   PrecisionConfig precision_config_;
1953   // Describes the padding type for convolution instructions.
1954   PaddingType padding_type_;
1955   // For layout-constrained custom calls, this vector holds the shape with
1956   // layout for each operand.
1957   std::vector<Shape> operand_shapes_with_layout_;
1958   // Whether this custom call has a side-effect.
1959   bool custom_call_has_side_effect_;
1960   // A list of output/operand buffer pairs that alias each other. See comment of
1961   // output_to_operand_aliasing().
1962   std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1963       output_to_operand_aliasing_;
1964   std::optional<Literal> literal_;
1965   // A custom-call schedule hint.
1966   CustomCallSchedule custom_call_schedule_;
1967   // The version of the API used by the custom call function.
1968   // TODO(b/189822916): Remove this field when all clients are migrated to the
1969   // status-returning API.
1970   CustomCallApiVersion api_version_;
1971 };
1972 
// A kPad instruction. The inline accessors below read the array being padded
// from operand 0 and the padding value from operand 1; the per-dimension edge
// and interior padding amounts are stored in a PaddingConfig.
class HloPadInstruction : public HloInstruction {
 public:
  // Constructs a pad instruction with result `shape`. The accessors of this
  // class expect `operand` at operand index 0 and `padding_value` at operand
  // index 1 (wiring is done in the out-of-line definition).
  explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
                             HloInstruction* padding_value,
                             const PaddingConfig& padding_config);
  // Returns the padding configuration for a pad node.
  const PaddingConfig& padding_config() const { return padding_config_; }
  // Returns a pointer to the stored padding configuration so callers can edit
  // it in place.
  PaddingConfig* mutable_padding_config() { return &padding_config_; }
  // Returns the operand being padded (operand 0).
  const HloInstruction* padded_operand() const { return operand(0); }
  HloInstruction* mutable_padded_operand() { return mutable_operand(0); }
  // Returns the padding value (operand 1).
  const HloInstruction* padding_value() const { return operand(1); }
  HloInstruction* mutable_padding_value() { return mutable_operand(1); }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kPad, i.e. it may be treated as an
  // HloPadInstruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kPad;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction.
  PaddingConfig padding_config_;
};
2010 
2011 class HloDynamicIndexInstruction : public HloInstruction {
2012  public:
HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)2013   explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
2014       : HloInstruction(opcode, shape) {}
2015   virtual int64_t first_index_operand_number() const = 0;
2016 
2017   // Returns a subspan of operands which represent the start indices.
index_operands()2018   absl::Span<HloInstruction* const> index_operands() const {
2019     return absl::MakeSpan(operands()).subspan(first_index_operand_number());
2020   }
2021 
2022   // Returns the shapes of the index operands.
index_shapes()2023   std::vector<Shape> index_shapes() const {
2024     std::vector<Shape> shapes;
2025     auto indices = index_operands();
2026     for (const HloInstruction* index : indices) {
2027       shapes.push_back(index->shape());
2028     }
2029     return shapes;
2030   }
2031 
ClassOf(const HloInstruction * hlo)2032   static bool ClassOf(const HloInstruction* hlo) {
2033     return hlo->opcode() == HloOpcode::kDynamicSlice ||
2034            hlo->opcode() == HloOpcode::kDynamicUpdateSlice;
2035   }
2036 };
2037 
// A kDynamicSlice instruction. Operand 0 is the sliced `operand`; the
// start-index operands begin at operand 1 (see first_index_operand_number()).
// The extent of the slice in each dimension is stored in dynamic_slice_sizes_.
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
 public:
  // Constructs a dynamic slice whose start indices are supplied by a single
  // instruction `start_indices`.
  explicit HloDynamicSliceInstruction(const Shape& shape,
                                      HloInstruction* operand,
                                      HloInstruction* start_indices,
                                      absl::Span<const int64_t> slice_sizes);
  // Constructs a dynamic slice whose start indices are supplied as a list of
  // instructions; presumably one per sliced dimension (TODO confirm against
  // the out-of-line definition).
  explicit HloDynamicSliceInstruction(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64_t> slice_sizes);
  // Returns the size of the slice in the given dimension for a dynamic
  // slice node. `dimension` is not bounds-checked.
  int64_t slice_sizes(int64_t dimension) const {
    return dynamic_slice_sizes_[dimension];
  }
  // Returns the slice sizes for all dimensions.
  const std::vector<int64_t>& dynamic_slice_sizes() const {
    return dynamic_slice_sizes_;
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Start indices follow the sliced operand, so they begin at operand 1.
  int64_t first_index_operand_number() const override { return 1; }
  // Returns true iff `hlo` has opcode kDynamicSlice.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDynamicSlice;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically by the index operands, which begin at
  // operand 1).
  std::vector<int64_t> dynamic_slice_sizes_;
};
2081 
// A kDynamicUpdateSlice instruction. The constructors take `operand` and
// `update` ahead of the start indices, and the start-index operands begin at
// operand 2 (see first_index_operand_number()).
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
 public:
  // Constructs the instruction with start indices supplied by a single
  // instruction `start_indices`.
  explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
                                            HloInstruction* operand,
                                            HloInstruction* update,
                                            HloInstruction* start_indices);
  // Constructs the instruction with start indices supplied as a list of
  // instructions.
  explicit HloDynamicUpdateSliceInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Start indices follow the two leading operands, so they begin at operand 2.
  int64_t first_index_operand_number() const override { return 2; }

  // Returns true iff `hlo` has opcode kDynamicUpdateSlice.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDynamicUpdateSlice;
  }
};
2098 
// A kGather instruction. Stores the gather dimension numbers, the slice size
// for each operand dimension, and a flag recording whether the start indices
// are sorted.
class HloGatherInstruction : public HloInstruction {
 public:
  explicit HloGatherInstruction(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted);
  // Returns the gather dimension numbers. CHECK-fails if none are stored.
  const GatherDimensionNumbers& gather_dimension_numbers() const {
    CHECK(gather_dimension_numbers_ != nullptr);
    return *gather_dimension_numbers_;
  }
  // Returns the slice sizes passed at construction.
  absl::Span<const int64_t> gather_slice_sizes() const {
    return gather_slice_sizes_;
  }
  // Returns whether the start indices are flagged as sorted.
  bool indices_are_sorted() const { return indices_are_sorted_; }
  void set_indices_are_sorted(bool indices_are_sorted) {
    indices_are_sorted_ = indices_are_sorted;
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Creates an instance of GatherDimensionNumbers.
  static GatherDimensionNumbers MakeGatherDimNumbers(
      absl::Span<const int64_t> offset_dims,
      absl::Span<const int64_t> collapsed_slice_dims,
      absl::Span<const int64_t> start_index_map, int64_t index_vector_dim);
  // Returns the dump string of the given gather dimension numbers.
  static std::string GatherDimensionNumbersToString(
      const GatherDimensionNumbers& gather_dimension_numbers);

  // Returns true iff `hlo` has opcode kGather.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kGather;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Held by pointer; the accessor above CHECKs that it is non-null.
  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
  std::vector<int64_t> gather_slice_sizes_;
  bool indices_are_sorted_;
};
2148 
// A kScatter instruction. The inline accessors below assume the operand list
// is laid out as N scatter operands, then one indices operand, then N updates,
// where N == operand_count() / 2 (i.e. operand_count() == 2 * N + 1 -- this
// layout is assumed, not checked here).
class HloScatterInstruction : public HloInstruction {
 public:
  // `args` is the full operand list in the layout described above.
  explicit HloScatterInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> args,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);
  // Returns the scatter dimension numbers. CHECK-fails if none are stored.
  const ScatterDimensionNumbers& scatter_dimension_numbers() const {
    CHECK(scatter_dimension_numbers_ != nullptr);
    return *scatter_dimension_numbers_;
  }
  // Returns whether the indices are flagged as sorted.
  bool indices_are_sorted() const { return indices_are_sorted_; }
  void set_indices_are_sorted(bool indices_are_sorted) {
    indices_are_sorted_ = indices_are_sorted;
  }
  // Returns whether the indices are flagged as unique.
  bool unique_indices() const override { return unique_indices_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;
  // Returns N, the number of scatter operands (and of updates); integer
  // division discards the single indices operand.
  int64_t scatter_operand_count() const { return operand_count() / 2; }
  // Returns the first N operands: the arrays being scattered into.
  absl::Span<HloInstruction* const> scatter_operands() const {
    return absl::MakeConstSpan(operands()).first(scatter_operand_count());
  }
  // Returns the last N operands: the update arrays.
  absl::Span<HloInstruction* const> scatter_updates() const {
    return absl::MakeConstSpan(operands()).last(scatter_operand_count());
  }
  // Returns the indices operand, located at operand index N.
  const HloInstruction* scatter_indices() const {
    return operand(scatter_operand_count());
  }
  HloInstruction* scatter_indices() {
    return mutable_operand(scatter_operand_count());
  }

  // Creates an instance of ScatterDimensionNumbers.
  static ScatterDimensionNumbers MakeScatterDimNumbers(
      absl::Span<const int64_t> update_window_dims,
      absl::Span<const int64_t> inserted_window_dims,
      absl::Span<const int64_t> scatter_dims_to_operand_dims,
      int64_t index_vector_dim);
  // Returns the dump string of the given scatter dimension numbers.
  static std::string ScatterDimensionNumbersToString(
      const ScatterDimensionNumbers& scatter_dimension_numbers);

  // Returns true iff `hlo` has opcode kScatter.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kScatter;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Held by pointer; the accessor above CHECKs that it is non-null.
  std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
  bool indices_are_sorted_;
  bool unique_indices_;
};
2211 
2212 class HloIotaInstruction : public HloInstruction {
2213  public:
2214   explicit HloIotaInstruction(const Shape& shape, int64_t iota_dimension);
2215 
2216   // Returns the dimension sizes or numbers associated with this instruction.
iota_dimension()2217   int64_t iota_dimension() const { return iota_dimension_; }
dimensions()2218   absl::Span<const int64_t> dimensions() const override {
2219     return absl::MakeConstSpan(&iota_dimension_, 1);
2220   }
2221   // Returns a serialized representation of this instruction.
2222   HloInstructionProto ToProto() const override;
2223 
ClassOf(const HloInstruction * hlo)2224   static bool ClassOf(const HloInstruction* hlo) {
2225     return hlo->opcode() == HloOpcode::kIota;
2226   }
2227 
2228  private:
2229   std::vector<std::string> ExtraAttributesToStringImpl(
2230       const HloPrintOptions& options) const override;
2231   bool IdenticalSlowPath(
2232       const HloInstruction& other,
2233       const std::function<bool(const HloComputation*, const HloComputation*)>&
2234           eq_computations) const override;
2235   // Implementation for non-common logic of CloneWithNewOperands.
2236   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2237       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2238       HloCloneContext* context) const override;
2239 
2240   int64_t iota_dimension_;
2241 };
2242 
// A kDot instruction over two operands, `lhs` and `rhs`. It stores the
// contracting/batch dimension numbers and a backend-interpreted precision
// configuration.
class HloDotInstruction : public HloInstruction {
 public:
  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
                             HloInstruction* rhs,
                             const DotDimensionNumbers& dimension_numbers,
                             const PrecisionConfig& precision_config);

  // Returns data on the dimension numbers used for a dot operation.
  const DotDimensionNumbers& dot_dimension_numbers() const {
    return dot_dimension_numbers_;
  }

  // Returns the information used to tell the implementation information about
  // what sort of precision is requested. The meaning of the field is backend
  // specific. At the moment, it is only supported for kConvolution and kDot.
  // Transformations on one kDot or kConvolution to another will preserve this
  // information. Transformations to other HLOs will not preserve this
  // information but it is presumed that the alternate lowering is strictly
  // superior.
  const PrecisionConfig& precision_config() const { return precision_config_; }
  // Returns a pointer to the stored precision config for in-place edits.
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kDot.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDot;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes the dimension numbers used for a dot.
  DotDimensionNumbers dot_dimension_numbers_;

  // Information used to communicate to the implementation about the algorithm
  // used to produce results. See the documentation on precision_config().
  PrecisionConfig precision_config_;
};
2293 
2294 class HloDomainInstruction : public HloInstruction {
2295  public:
2296   explicit HloDomainInstruction(
2297       const Shape& shape, HloInstruction* operand,
2298       std::unique_ptr<DomainMetadata> operand_side_metadata,
2299       std::unique_ptr<DomainMetadata> user_side_metadata);
2300 
2301   // Returns a serialized representation of this instruction.
2302   HloInstructionProto ToProto() const override;
2303 
2304   // Retrieves the operand side metadata of a kDomain instruction.
operand_side_metadata()2305   const DomainMetadata& operand_side_metadata() const {
2306     return *operand_side_metadata_;
2307   }
2308   // Retrieves the user side metadata of a kDomain instruction.
user_side_metadata()2309   const DomainMetadata& user_side_metadata() const {
2310     return *user_side_metadata_;
2311   }
2312 
ClassOf(const HloInstruction * hlo)2313   static bool ClassOf(const HloInstruction* hlo) {
2314     return hlo->opcode() == HloOpcode::kDomain;
2315   }
2316 
2317  private:
2318   std::vector<std::string> ExtraAttributesToStringImpl(
2319       const HloPrintOptions& options) const override;
2320   bool IdenticalSlowPath(
2321       const HloInstruction& other,
2322       const std::function<bool(const HloComputation*, const HloComputation*)>&
2323           eq_computations) const override;
2324   // Implementation for non-common logic of CloneWithNewOperands.
2325   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2326       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2327       HloCloneContext* context) const override;
2328 
2329   std::unique_ptr<DomainMetadata> operand_side_metadata_;
2330   std::unique_ptr<DomainMetadata> user_side_metadata_;
2331 };
2332 
// A kGetDimensionSize instruction over one operand, parameterized by a single
// dimension index.
class HloGetDimensionSizeInstruction : public HloInstruction {
 public:
  explicit HloGetDimensionSizeInstruction(const Shape& shape,
                                          HloInstruction* operand,
                                          int64_t dimension);

  // Returns the single dimension index attribute of this instruction
  // (presumably the operand dimension whose size is queried).
  int64_t dimension() const { return dimension_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kGetDimensionSize.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kGetDimensionSize;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t dimension_;
};
2362 
// A kSetDimensionSize instruction taking an `operand` and a value `val`,
// parameterized by a single dimension index.
class HloSetDimensionSizeInstruction : public HloInstruction {
 public:
  explicit HloSetDimensionSizeInstruction(const Shape& shape,
                                          HloInstruction* operand,
                                          HloInstruction* val,
                                          int64_t dimension);

  // Returns the single dimension index attribute of this instruction
  // (presumably the operand dimension whose size is set from `val`).
  int64_t dimension() const { return dimension_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kSetDimensionSize.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSetDimensionSize;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t dimension_;
};
2393 
// A kRngGetAndUpdateState instruction. It takes no operands in its constructor
// and carries a single integer `delta` attribute, which can be changed after
// construction via set_delta().
class HloRngGetAndUpdateStateInstruction : public HloInstruction {
 public:
  explicit HloRngGetAndUpdateStateInstruction(const Shape& shape,
                                              int64_t delta);

  // Returns the delta value.
  int64_t delta() const { return delta_; }
  // Overwrites the delta value.
  void set_delta(int64_t delta) { delta_ = delta; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kRngGetAndUpdateState.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRngGetAndUpdateState;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t delta_;
};
2423 
// A kRngBitGenerator instruction over a single `state` operand, parameterized
// by the RandomAlgorithm to use.
class HloRngBitGeneratorInstruction : public HloInstruction {
 public:
  HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
                                RandomAlgorithm algorithm);

  // Returns the random algorithm attribute.
  RandomAlgorithm algorithm() const { return algorithm_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` has opcode kRngBitGenerator.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRngBitGenerator;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  RandomAlgorithm algorithm_;
};
2449 
2450 }  // namespace xla
2451 
2452 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
2453