• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // All HloInstruction subclasses are put in this file.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/shape.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 
27 namespace xla {
28 
// Shared base of the batch-norm instruction classes declared in this header.
// It stores the two attributes they have in common: epsilon and feature_index.
29 class HloBatchNormInstruction : public HloInstruction {
30  public:
31   // Returns feature_index field associated with the instruction. The index
32   // represents the index of the feature dimension.
feature_index()33   int64 feature_index() const { return feature_index_; }
34 
35   // Returns the epsilon value associated with the instruction. This is a small
36   // number added to the variance to avoid divide-by-zero error.
epsilon()37   float epsilon() const { return epsilon_; }
38 
39   // Returns a serialized representation of this instruction.
40   HloInstructionProto ToProto() const override;
41 
42  protected:
     // Protected constructor: callable only from derived classes (or friends).
     // The concrete opcode is supplied by the caller through `opcode`.
43   explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
44                                    HloInstruction* operand,
45                                    HloInstruction* scale, float epsilon,
46                                    int64_t feature_index);
47 
48  private:
     // HloInstruction hooks overridden here; their definitions are not part of
     // this header.
49   std::vector<string> ExtraAttributesToStringImpl(
50       const HloPrintOptions& options) const override;
51   bool IdenticalSlowPath(
52       const HloInstruction& other,
53       const std::function<bool(const HloComputation*, const HloComputation*)>&
54           eq_computations) const override;
55   // A small float number added to the variance to avoid divide-by-zero error.
56   float epsilon_ = 0.0f;
57 
58   // An integer value representing the index of the feature dimension.
     // Defaults to -1 until a constructor assigns it.
59   int64 feature_index_ = -1;
60 };
61 
// Batch-norm training instruction. Its constructor takes the operand, scale
// and offset instructions plus the epsilon / feature_index attributes that
// HloBatchNormInstruction exposes.
62 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
63  public:
64   explicit HloBatchNormTrainingInstruction(
65       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
66       HloInstruction* offset, float epsilon, int64_t feature_index);
67 
68  private:
69   // Implementation for non-common logic of CloneWithNewOperands.
     // Overrides the base-class hook; the clone is returned as a unique_ptr.
70   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
71       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
72       HloCloneContext* context) const override;
73 };
74 
// Batch-norm inference instruction. Its constructor takes the operand, scale,
// offset, mean and variance instructions plus the epsilon / feature_index
// attributes that HloBatchNormInstruction exposes.
75 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
76  public:
77   explicit HloBatchNormInferenceInstruction(
78       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
79       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
80       float epsilon, int64_t feature_index);
81 
82  private:
83   // Implementation for non-common logic of CloneWithNewOperands.
     // Overrides the base-class hook; the clone is returned as a unique_ptr.
84   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
85       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
86       HloCloneContext* context) const override;
87 };
88 
// Batch-norm gradient instruction. Its constructor takes the operand, scale,
// mean, variance and grad_output instructions plus the epsilon /
// feature_index attributes that HloBatchNormInstruction exposes.
89 class HloBatchNormGradInstruction : public HloBatchNormInstruction {
90  public:
91   explicit HloBatchNormGradInstruction(
92       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
93       HloInstruction* mean, HloInstruction* variance,
94       HloInstruction* grad_output, float epsilon, int64_t feature_index);
95 
96  private:
97   // Implementation for non-common logic of CloneWithNewOperands.
     // Overrides the base-class hook; the clone is returned as a unique_ptr.
98   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
99       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
100       HloCloneContext* context) const override;
101 };
102 
// FFT instruction: a single operand plus an FftType and a list of FFT lengths.
103 class HloFftInstruction : public HloInstruction {
104  public:
105   explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
106                              FftType fft_type,
107                              absl::Span<const int64> fft_length);
      // Returns the FFT type stored on this instruction.
fft_type()108   FftType fft_type() const { return fft_type_; }
109 
      // Returns the stored FFT lengths by const reference; the reference is to
      // a member of this instruction.
fft_length()110   const std::vector<int64>& fft_length() const { return fft_length_; }
111 
112   // Returns a serialized representation of this instruction.
113   HloInstructionProto ToProto() const override;
114 
115  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
116   std::vector<string> ExtraAttributesToStringImpl(
117       const HloPrintOptions& options) const override;
118   bool IdenticalSlowPath(
119       const HloInstruction& other,
120       const std::function<bool(const HloComputation*, const HloComputation*)>&
121           eq_computations) const override;
122 
123   // Implementation for non-common logic of CloneWithNewOperands.
124   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
125       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
126       HloCloneContext* context) const override;
127 
128   // Describes FFT type for an FFT instruction.
129   FftType fft_type_ = FftType::FFT;
130 
131   // Indicates the FFT length for an FFT instruction.
132   std::vector<int64> fft_length_;
133 };
134 
// Copy-start instruction: one operand plus a flag recording whether the copy
// is a cross-program prefetch.
135 class HloCopyStartInstruction : public HloInstruction {
136  public:
137   explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
138                                    bool is_cross_program_prefetch);
139 
      // Returns the cross-program-prefetch flag stored on this instruction.
is_cross_program_prefetch()140   bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
      // Returns a serialized representation of this instruction.
141   HloInstructionProto ToProto() const override;
142 
143  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
144   std::vector<string> ExtraAttributesToStringImpl(
145       const HloPrintOptions& options) const override;
146   bool IdenticalSlowPath(
147       const HloInstruction& other,
148       const std::function<bool(const HloComputation*, const HloComputation*)>&
149           eq_computations) const override;
150   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
151       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
152       HloCloneContext* context) const override;
153 
      // No in-class initializer here.
      // NOTE(review): presumably assigned by the constructor — confirm in the
      // .cc file.
154   bool is_cross_program_prefetch_;
155 };
156 
// Comparison instruction over two operands (lhs, rhs). The comparison
// direction and type are kept in a single Comparison member.
157 class HloCompareInstruction : public HloInstruction {
158  public:
      // `type` is optional.
      // NOTE(review): how an absent `type` is resolved is decided in the
      // constructor definition, which is not visible here.
159   explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
160                                  HloInstruction* rhs,
161                                  ComparisonDirection direction,
162                                  absl::optional<Comparison::Type> type);
      // Both accessors forward to the stored Comparison object.
direction()163   ComparisonDirection direction() const { return compare_.GetDirection(); }
type()164   Comparison::Type type() const { return compare_.GetType(); }
      // Returns a serialized representation of this instruction.
165   HloInstructionProto ToProto() const override;
166 
167  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
168   std::vector<string> ExtraAttributesToStringImpl(
169       const HloPrintOptions& options) const override;
170   bool IdenticalSlowPath(
171       const HloInstruction& other,
172       const std::function<bool(const HloComputation*, const HloComputation*)>&
173           eq_computations) const override;
174   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
175       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
176       HloCloneContext* context) const override;
177 
      // Backing store for direction() and type().
178   Comparison compare_;
179 };
180 
// Triangular-solve instruction over two operands (a, b), carrying a copy of
// the TriangularSolveOptions it was constructed with.
181 class HloTriangularSolveInstruction : public HloInstruction {
182  public:
183   explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
184                                          HloInstruction* b,
185                                          const TriangularSolveOptions& options);
      // Returns the stored options by const reference; the reference is to a
      // member of this instruction.
triangular_solve_options()186   const TriangularSolveOptions& triangular_solve_options() const {
187     return triangular_solve_options_;
188   }
189 
190   // Returns a serialized representation of this instruction.
191   HloInstructionProto ToProto() const override;
192 
193  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
194   std::vector<string> ExtraAttributesToStringImpl(
195       const HloPrintOptions& options) const override;
196   bool IdenticalSlowPath(
197       const HloInstruction& other,
198       const std::function<bool(const HloComputation*, const HloComputation*)>&
199           eq_computations) const override;
200 
201   // Implementation for non-common logic of CloneWithNewOperands.
202   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
203       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
204       HloCloneContext* context) const override;
205 
      // Options held by value (not a reference to the constructor argument).
206   TriangularSolveOptions triangular_solve_options_;
207 };
208 
// Cholesky instruction over a single operand (a), carrying a copy of the
// CholeskyOptions it was constructed with.
209 class HloCholeskyInstruction : public HloInstruction {
210  public:
211   explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
212                                   const CholeskyOptions& options);
      // Returns the stored options by const reference; the reference is to a
      // member of this instruction.
cholesky_options()213   const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
214 
215   // Returns a serialized representation of this instruction.
216   HloInstructionProto ToProto() const override;
217 
218  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
219   std::vector<string> ExtraAttributesToStringImpl(
220       const HloPrintOptions& options) const override;
221   bool IdenticalSlowPath(
222       const HloInstruction& other,
223       const std::function<bool(const HloComputation*, const HloComputation*)>&
224           eq_computations) const override;
225 
226   // Implementation for non-common logic of CloneWithNewOperands.
227   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
228       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
229       HloCloneContext* context) const override;
230 
      // Options held by value (not a reference to the constructor argument).
231   CholeskyOptions cholesky_options_;
232 };
233 
234 // Class that represents instructions that synchronize and transfer data between
235 // partitioned devices. Send/Recv and collective instructions (AllReduce,
236 // AllToAll, CollectivePermute) belong to this instruction type. A group of
237 // instructions (of the same opcode) with the same channel_id communicate during
238 // execution.
239 class HloChannelInstruction : public HloInstruction {
240  public:
241   // Returns the channel id associated with the instruction. The id is
242   // shared between each Send/Recv pair or a group of collective instructions
243   // and is globally unique to identify each channel.
      // The value is optional: an instruction may have no channel id.
channel_id()244   absl::optional<int64> channel_id() const { return channel_id_; }
245   void set_channel_id(const absl::optional<int64>& channel_id);
246 
247   // Whether this instruction is identical to `other` except for the values of
248   // channel IDs, as long as both have channel IDs or neither has a channel ID.
      // This base implementation only checks that the two instructions agree on
      // whether a channel id is present; `eq_computations` is not used here.
      // Subclasses override it to compare their own attributes.
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations)249   virtual bool IdenticalSlowPathIgnoringChannelIdValues(
250       const HloInstruction& other,
251       const std::function<bool(const HloComputation*, const HloComputation*)>&
252           eq_computations) const {
253     return channel_id_.has_value() == other.channel_id().has_value();
254   }
255 
256  protected:
      // Protected constructor: callable only from derived classes (or friends).
257   explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
258                                  const absl::optional<int64>& channel_id);
259 
      // Returns a serialized representation of this instruction.
260   HloInstructionProto ToProto() const override;
261 
262   std::vector<string> ExtraAttributesToStringImpl(
263       const HloPrintOptions& options) const override;
264 
265   // Do not override IdenticalSlowPath(). Override
266   // IdenticalSlowPathIgnoringChannelIdValues() instead.
      // Declared `final`, so the compiler rejects overrides in subclasses.
267   bool IdenticalSlowPath(
268       const HloInstruction& other,
269       const std::function<bool(const HloComputation*, const HloComputation*)>&
270           eq_computations) const final;
271 
      // Backing store for channel_id(); empty when no channel id is set.
272   absl::optional<int64> channel_id_;
273 };
274 
// Shared base of the Send, SendDone, Recv and RecvDone instruction classes
// declared in this header. Adds the is_host_transfer flag on top of
// HloChannelInstruction.
275 class HloSendRecvInstruction : public HloChannelInstruction {
276  public:
277   // Returns whether this send/recv instruction sends data to/from the host.
is_host_transfer()278   bool is_host_transfer() const { return is_host_transfer_; }
279 
280   // Returns a serialized representation of this instruction.
281   HloInstructionProto ToProto() const override;
282 
283  protected:
      // Protected constructor: callable only from derived classes (or friends).
      // Takes a plain int64_t channel id, unlike the optional id accepted by
      // HloChannelInstruction's constructor.
284   explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
285                                   int64_t channel_id, bool is_host_transfer);
286 
287  private:
      // Hooks overridden here; their definitions are not part of this header.
288   std::vector<string> ExtraAttributesToStringImpl(
289       const HloPrintOptions& options) const override;
290   bool IdenticalSlowPathIgnoringChannelIdValues(
291       const HloInstruction& other,
292       const std::function<bool(const HloComputation*, const HloComputation*)>&
293           eq_computations) const override;
294   // Whether this send/recv instruction sends data to/from the host.
295   bool is_host_transfer_;
296 };
297 
// Send instruction. Constructed from the operand to send, a token instruction,
// a channel id and the host-transfer flag. The constructor takes no Shape
// argument.
298 class HloSendInstruction : public HloSendRecvInstruction {
299  public:
300   explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
301                               int64_t channel_id, bool is_host_transfer);
302 
303  private:
304   // Implementation for non-common logic of CloneWithNewOperands.
305   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
306       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
307       HloCloneContext* context) const override;
308 };
309 
// SendDone instruction. Its only operand is the HloSendInstruction it
// completes; the constructor takes no channel id or Shape argument.
// NOTE(review): the channel id is presumably taken from `operand` — confirm in
// the .cc file.
310 class HloSendDoneInstruction : public HloSendRecvInstruction {
311  public:
312   explicit HloSendDoneInstruction(HloSendInstruction* operand,
313                                   bool is_host_transfer);
314 
315  private:
316   // Implementation for non-common logic of CloneWithNewOperands.
317   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
318       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
319       HloCloneContext* context) const override;
320 };
321 
// Recv instruction. Constructed from a shape, a token instruction, a channel
// id and the host-transfer flag.
322 class HloRecvInstruction : public HloSendRecvInstruction {
323  public:
324   explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
325                               int64_t channel_id, bool is_host_transfer);
326 
327  private:
328   // Implementation for non-common logic of CloneWithNewOperands.
329   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
330       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
331       HloCloneContext* context) const override;
332 };
333 
// RecvDone instruction. Its only operand is the HloRecvInstruction it
// completes; the constructor takes no channel id or Shape argument.
// NOTE(review): the channel id is presumably taken from `operand` — confirm in
// the .cc file.
334 class HloRecvDoneInstruction : public HloSendRecvInstruction {
335  public:
336   explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
337                                   bool is_host_transfer);
338 
339  private:
340   // Implementation for non-common logic of CloneWithNewOperands.
341   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
342       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
343       HloCloneContext* context) const override;
344 };
345 
// Shared base of the collective instruction classes in this header that carry
// replica groups (all-gather, all-reduce / reduce-scatter, all-to-all). Adds
// replica_groups and constrain_layout on top of HloChannelInstruction.
346 class HloCollectiveInstruction : public HloChannelInstruction {
347  public:
      // Returns the stored replica groups by const reference; the reference is
      // to a member of this instruction.
replica_groups()348   const std::vector<ReplicaGroup>& replica_groups() const {
349     return replica_groups_;
350   }
351 
352   // Returns true if the layout of the AllReduce is enforced by XLA client (as
353   // the layout set in the shape). The only reason for the client to set the
354   // layout is to separately compile computations that communicate with
355   // AllReduce. Since this field is only set `true` by the client, the compiler
356   // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
357   // `false` for all other cases.
358   //
359   // When this is `true`, there may be communication endpoints outside the
360   // current compilation unit, so the compiler considers this AllReduce as
361   // side-effecting to disable compiler transformations. The compiler is free to
362   // transform unconstrained AllReduces differently across compilation units.
363   // It is an error for an HloModule to have a mix of constrained and
364   // unconstrained AllReduce instructions (checked by HloVerifier).
constrain_layout()365   bool constrain_layout() const { return constrain_layout_; }
366 
367  protected:
      // Protected constructor: callable only from derived classes (or friends).
368   explicit HloCollectiveInstruction(
369       HloOpcode opcode, const Shape& shape,
370       absl::Span<HloInstruction* const> operands,
371       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
372       const absl::optional<int64>& channel_id);
373 
      // Returns a serialized representation of this instruction.
374   HloInstructionProto ToProto() const override;
375 
      // Hooks overridden here; their definitions are not part of this header.
376   std::vector<string> ExtraAttributesToStringImpl(
377       const HloPrintOptions& options) const override;
378   bool IdenticalSlowPathIgnoringChannelIdValues(
379       const HloInstruction& other,
380       const std::function<bool(const HloComputation*, const HloComputation*)>&
381           eq_computations) const override;
382 
      // Data members are protected, so derived classes can access them
      // directly.
383   std::vector<ReplicaGroup> replica_groups_;
384   bool constrain_layout_;
385 };
386 
// All-gather collective. Adds the gather dimension and the
// use_global_device_ids flag on top of HloCollectiveInstruction. The opcode is
// a constructor parameter rather than being fixed by the class.
387 class HloAllGatherInstruction : public HloCollectiveInstruction {
388  public:
389   explicit HloAllGatherInstruction(
390       HloOpcode opcode, const Shape& shape,
391       absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
392       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
393       const absl::optional<int64>& channel_id, bool use_global_device_ids);
394   // Same as HloAllReduceInstructionBase::use_global_device_ids.
use_global_device_ids()395   bool use_global_device_ids() const { return use_global_device_ids_; }
396 
397   // The dimension on which data from different participants are concatenated.
all_gather_dimension()398   int64 all_gather_dimension() const { return all_gather_dimension_; }
399 
      // Overwrites the stored gather dimension; no validation is done here.
set_all_gather_dimension(int64_t dim)400   void set_all_gather_dimension(int64_t dim) { all_gather_dimension_ = dim; }
401 
402  protected:
      // Hooks overridden here; their definitions are not part of this header.
403   std::vector<string> ExtraAttributesToStringImpl(
404       const HloPrintOptions& options) const override;
405   HloInstructionProto ToProto() const override;
406 
407  private:
408   bool IdenticalSlowPathIgnoringChannelIdValues(
409       const HloInstruction& other,
410       const std::function<bool(const HloComputation*, const HloComputation*)>&
411           eq_computations) const override;
412 
413   // Implementation for non-common logic of CloneWithNewOperands.
414   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
415       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
416       HloCloneContext* context) const override;
417 
      // Backing stores for all_gather_dimension() and use_global_device_ids().
418   int64 all_gather_dimension_;
419   bool use_global_device_ids_;
420 };
421 
422 // Base class for all-reduce and reduce-scatter instructions.
423 class HloAllReduceInstructionBase : public HloCollectiveInstruction {
424  public:
      // Public (not protected) constructor; HloAllReduceInstruction inherits it
      // with a using-declaration.
425   explicit HloAllReduceInstructionBase(
426       HloOpcode opcode, const Shape& shape,
427       absl::Span<HloInstruction* const> operands,
428       HloComputation* reduce_computation,
429       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
430       const absl::optional<int64>& channel_id, bool use_global_device_ids);
431 
432   // Returns true if the ids in the ReplicaGroup config represent a global id of
433   // (replica_id * partition_count + partition_id) instead of a replica id.
434   // This enables more flexible grouping of devices if this all-reduce is both
435   // cross-partition and cross-replica.
436   //
437   // For example with 2 replicas and 4 partitions,
438   // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that
439   // group[0] = (0,0), (0,1), (1,0), (1,1)
440   // group[1] = (0,2), (0,3), (1,2), (1,3)
441   // where each pair is (replica_id, partition_id).
use_global_device_ids()442   bool use_global_device_ids() const { return use_global_device_ids_; }
443 
444  protected:
      // Hooks overridden here; their definitions are not part of this header.
445   std::vector<string> ExtraAttributesToStringImpl(
446       const HloPrintOptions& options) const override;
447   HloInstructionProto ToProto() const override;
448 
449   bool IdenticalSlowPathIgnoringChannelIdValues(
450       const HloInstruction& other,
451       const std::function<bool(const HloComputation*, const HloComputation*)>&
452           eq_computations) const override;
453 
454  private:
      // Backing store for use_global_device_ids().
455   bool use_global_device_ids_;
456 };
457 
// All-reduce instruction. Adds IsNoop() and the clone hook on top of
// HloAllReduceInstructionBase; it declares no data members of its own.
458 class HloAllReduceInstruction : public HloAllReduceInstructionBase {
459  public:
      // Inherits the base-class constructors unchanged.
460   using HloAllReduceInstructionBase::HloAllReduceInstructionBase;
461 
462   // Returns true if the AllReduce does no communication, so it's equivalent
463   // to a mem copy.
464   bool IsNoop() const;
465 
466  private:
467   // Implementation for non-common logic of CloneWithNewOperands.
468   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
469       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
470       HloCloneContext* context) const override;
471 };
472 
// Reduce-scatter instruction. Adds the scatter dimension on top of
// HloAllReduceInstructionBase. Unlike the base class, its constructor takes no
// opcode argument.
473 class HloReduceScatterInstruction : public HloAllReduceInstructionBase {
474  public:
475   explicit HloReduceScatterInstruction(
476       const Shape& shape, absl::Span<HloInstruction* const> operands,
477       HloComputation* reduce_computation,
478       absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
479       const absl::optional<int64>& channel_id, bool use_global_device_ids,
480       int64_t scatter_dimension);
481 
482   // The dimension on which reduced data is scattered to different participants.
scatter_dimension()483   int64 scatter_dimension() const { return scatter_dimension_; }
484 
485  protected:
      // Hooks overridden here; their definitions are not part of this header.
486   std::vector<string> ExtraAttributesToStringImpl(
487       const HloPrintOptions& options) const override;
488   HloInstructionProto ToProto() const override;
489 
490  private:
491   bool IdenticalSlowPathIgnoringChannelIdValues(
492       const HloInstruction& other,
493       const std::function<bool(const HloComputation*, const HloComputation*)>&
494           eq_computations) const override;
495 
496   // Implementation for non-common logic of CloneWithNewOperands.
497   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
498       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
499       HloCloneContext* context) const override;
500 
      // Backing store for scatter_dimension(); there is no setter in this class.
501   int64 scatter_dimension_;
502 };
503 
// All-to-all collective. Adds an optional split dimension on top of
// HloCollectiveInstruction.
504 class HloAllToAllInstruction : public HloCollectiveInstruction {
505  public:
506   explicit HloAllToAllInstruction(const Shape& shape,
507                                   absl::Span<HloInstruction* const> operands,
508                                   absl::Span<const ReplicaGroup> replica_groups,
509                                   bool constrain_layout,
510                                   const absl::optional<int64>& channel_id,
511                                   const absl::optional<int64>& split_dimension);
512 
513   // AllToAll can optionally take a split dimension, which means that this
514   // AllToAll takes a single (flattened) array operand and produces an array
515   // output (instead of taking a list of operands and producing a tuple).
516   //
517   // split_dimension specifies which dimension in the operand is split across
518   // devices in each replica_group, and also means the concatenated dimension
519   // on the output (i.e., input and the output shapes are the same).
split_dimension()520   absl::optional<int64> split_dimension() const { return split_dimension_; }
      // Stores `dim` as the split dimension. The setter takes a plain integer,
      // so it cannot reset the optional back to empty.
set_split_dimension(int64_t dim)521   void set_split_dimension(int64_t dim) { split_dimension_ = dim; }
522 
523  protected:
      // Hooks overridden here; their definitions are not part of this header.
524   std::vector<string> ExtraAttributesToStringImpl(
525       const HloPrintOptions& options) const override;
526   HloInstructionProto ToProto() const override;
527 
528  private:
529   bool IdenticalSlowPathIgnoringChannelIdValues(
530       const HloInstruction& other,
531       const std::function<bool(const HloComputation*, const HloComputation*)>&
532           eq_computations) const override;
533 
534   // Implementation for non-common logic of CloneWithNewOperands.
535   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
536       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
537       HloCloneContext* context) const override;
538 
      // Backing store for split_dimension(); empty when no split dimension is
      // set.
539   absl::optional<int64> split_dimension_;
540 };
541 
// Collective-permute instruction. Derives directly from HloChannelInstruction
// and stores a list of (source, target) pairs plus a list of slice sizes.
542 class HloCollectivePermuteInstruction : public HloChannelInstruction {
543  public:
      // Single-operand form: takes only the source/target pairs.
544   explicit HloCollectivePermuteInstruction(
545       HloOpcode opcode, const Shape& shape, HloInstruction* operand,
546       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
547       const absl::optional<int64_t>& channel_id);
548 
      // Four-operand form: input, output and their start indices, plus the
      // slice sizes in addition to the source/target pairs.
549   explicit HloCollectivePermuteInstruction(
550       HloOpcode opcode, const Shape& shape, HloInstruction* input,
551       HloInstruction* output, HloInstruction* input_start_indices,
552       HloInstruction* output_start_indices,
553       absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
554       absl::Span<const std::vector<int64_t>> slice_sizes,
555       const absl::optional<int64_t>& channel_id);
556 
      // Returns the stored (source, target) pairs by const reference.
source_target_pairs()557   const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const {
558     return source_target_pairs_;
559   }
560 
      // Returns the stored slice sizes by const reference.
      // NOTE(review): presumably empty for the single-operand constructor, which
      // takes no slice sizes — confirm in the .cc file.
dynamic_slice_sizes_list()561   const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const {
562     return slice_sizes_;
563   }
564 
565   // Returns a serialized representation of this instruction.
566   HloInstructionProto ToProto() const override;
567 
568  private:
      // Hooks overridden here; their definitions are not part of this header.
569   std::vector<string> ExtraAttributesToStringImpl(
570       const HloPrintOptions& options) const override;
571   bool IdenticalSlowPathIgnoringChannelIdValues(
572       const HloInstruction& other,
573       const std::function<bool(const HloComputation*, const HloComputation*)>&
574           eq_computations) const override;
575 
576   // Implementation for non-common logic of CloneWithNewOperands.
577   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
578       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
579       HloCloneContext* context) const override;
580 
      // Both members are const: they are fixed at construction and cannot be
      // reassigned afterwards.
581   const std::vector<std::pair<int64_t, int64_t>> source_target_pairs_;
582   const std::vector<std::vector<int64_t>> slice_sizes_;
583 };
584 
// Reverse instruction: one operand plus the list of dimensions it was
// constructed with.
585 class HloReverseInstruction : public HloInstruction {
586  public:
587   explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
588                                  absl::Span<const int64> dimensions);
589   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()590   const std::vector<int64>& dimensions() const override { return dimensions_; }
      // Returns dimensions()[index]; `index` is not bounds-checked here.
dimensions(int64_t index)591   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
      // Returns a pointer to the stored dimension list so callers can modify it
      // in place.
mutable_dimensions()592   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
593   // Returns a serialized representation of this instruction.
594   HloInstructionProto ToProto() const override;
595 
596  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
597   std::vector<string> ExtraAttributesToStringImpl(
598       const HloPrintOptions& options) const override;
599   bool IdenticalSlowPath(
600       const HloInstruction& other,
601       const std::function<bool(const HloComputation*, const HloComputation*)>&
602           eq_computations) const override;
603   // Implementation for non-common logic of CloneWithNewOperands.
604   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
605       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
606       HloCloneContext* context) const override;
607 
      // Backing store for the dimensions accessors.
608   std::vector<int64> dimensions_;
609 };
610 
// Concatenate instruction over a list of operands. The constructor takes a
// single dimension; the accessors expose it through the generic dimensions
// list.
611 class HloConcatenateInstruction : public HloInstruction {
612  public:
613   explicit HloConcatenateInstruction(const Shape& shape,
614                                      absl::Span<HloInstruction* const> operands,
615                                      int64_t dimension);
616   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()617   const std::vector<int64>& dimensions() const override { return dimensions_; }
      // Returns dimensions()[index]; `index` is not bounds-checked here.
dimensions(int64_t index)618   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
      // Returns a pointer to the stored dimension list so callers can modify it
      // in place.
mutable_dimensions()619   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
620   // Accessor for the dimension in which a concatenate HLO should occur.
      // Reads element 0 of the dimension list, so the list must be non-empty.
concatenate_dimension()621   int64 concatenate_dimension() const { return dimensions(0); }
622   // Returns a serialized representation of this instruction.
623   HloInstructionProto ToProto() const override;
624 
625  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
626   std::vector<string> ExtraAttributesToStringImpl(
627       const HloPrintOptions& options) const override;
628   bool IdenticalSlowPath(
629       const HloInstruction& other,
630       const std::function<bool(const HloComputation*, const HloComputation*)>&
631           eq_computations) const override;
632   // Implementation for non-common logic of CloneWithNewOperands.
633   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
634       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
635       HloCloneContext* context) const override;
636 
      // Backing store for the dimensions accessors.
637   std::vector<int64> dimensions_;
638 };
639 
// Reduce instruction. Its operand list is split in two halves: the first
// input_count() operands are the arrays to reduce, the remaining operands are
// the init values. It also stores the dimensions to reduce.
640 class HloReduceInstruction : public HloInstruction {
641  public:
642   explicit HloReduceInstruction(const Shape& shape,
643                                 absl::Span<HloInstruction* const> args,
644                                 absl::Span<const int64> dimensions_to_reduce,
645                                 HloComputation* reduce_computation);
646   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()647   const std::vector<int64>& dimensions() const override { return dimensions_; }
      // Returns dimensions()[index]; `index` is not bounds-checked here.
dimensions(int64_t index)648   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
      // Returns a pointer to the stored dimension list so callers can modify it
      // in place.
mutable_dimensions()649   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
650   // Returns a serialized representation of this instruction.
651   HloInstructionProto ToProto() const override;
652 
653   // Returns the number of input arrays (and, consequentially, the number of
654   // init values) this reduce has.
      // Computed as half the operand count (integer division).
input_count()655   int64 input_count() const { return operand_count() / 2; }
656 
657   // Returns the input tensors to be reduced.
      // These are the first input_count() operands.
inputs()658   absl::Span<HloInstruction* const> inputs() const {
659     return absl::MakeSpan(operands()).subspan(0, input_count());
660   }
661 
662   // Returns the init values of the reduction.
      // These are all operands from index input_count() to the end. The second
      // argument of subspan() is a length, not an end index, so the
      // single-argument form is used to take "everything after the inputs".
init_values()663   absl::Span<HloInstruction* const> init_values() const {
664     return absl::MakeSpan(operands()).subspan(input_count());
665   }
666 
667  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
668   std::vector<string> ExtraAttributesToStringImpl(
669       const HloPrintOptions& options) const override;
670   bool IdenticalSlowPath(
671       const HloInstruction& other,
672       const std::function<bool(const HloComputation*, const HloComputation*)>&
673           eq_computations) const override;
674   // Implementation for non-common logic of CloneWithNewOperands.
675   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
676       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
677       HloCloneContext* context) const override;
678 
      // Backing store for the dimensions accessors.
679   std::vector<int64> dimensions_;
680 };
681 
// Sort instruction over one or more operands, with a comparison computation
// and a stability flag. Operand 0 holds the keys; the remaining operands are
// counted as values.
682 class HloSortInstruction : public HloInstruction {
683  public:
684   explicit HloSortInstruction(const Shape& shape, int64_t dimension,
685                               absl::Span<HloInstruction* const> operands,
686                               HloComputation* compare, bool is_stable);
687   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()688   const std::vector<int64>& dimensions() const override { return dimensions_; }
      // Returns dimensions()[index]; `index` is not bounds-checked here.
dimensions(int64_t index)689   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
      // Returns a pointer to the stored dimension list so callers can modify it
      // in place.
mutable_dimensions()690   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
691   // Returns the sort dimension for this instruction
      // Reads element 0 of the dimension list, so the list must be non-empty.
sort_dimension()692   int64 sort_dimension() const { return dimensions(0); }
693   // Returns a serialized representation of this instruction.
694   HloInstructionProto ToProto() const override;
695   // Returns the key operand to this instruction.
keys()696   const HloInstruction* keys() const { return operand(0); }
mutable_keys()697   HloInstruction* mutable_keys() { return mutable_operand(0); }
698   // Returns the number of value operands.
      // Every operand except the key operand counts as a value.
values_count()699   int64 values_count() const { return operand_count() - 1; }
      // Returns the stability flag stored on this instruction.
is_stable()700   bool is_stable() const { return is_stable_; }
701 
702  private:
      // HloInstruction hooks overridden here; their definitions are not part of
      // this header.
703   std::vector<string> ExtraAttributesToStringImpl(
704       const HloPrintOptions& options) const override;
705   bool IdenticalSlowPath(
706       const HloInstruction& other,
707       const std::function<bool(const HloComputation*, const HloComputation*)>&
708           eq_computations) const override;
709   // Implementation for non-common logic of CloneWithNewOperands.
710   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
711       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
712       HloCloneContext* context) const override;
713 
      // Backing stores for the dimensions accessors and is_stable().
714   std::vector<int64> dimensions_;
715   bool is_stable_;
716 };
717 
718 class HloTransposeInstruction : public HloInstruction {
719  public:
720   explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
721                                    absl::Span<const int64> dimensions);
722   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()723   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64_t index)724   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
mutable_dimensions()725   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
726   // Returns whether this instruction does a rank-2 transposition.
727   bool IsRank2Transpose() const;
728   // Returns a serialized representation of this instruction.
729   HloInstructionProto ToProto() const override;
730 
731  private:
732   std::vector<string> ExtraAttributesToStringImpl(
733       const HloPrintOptions& options) const override;
734   bool IdenticalSlowPath(
735       const HloInstruction& other,
736       const std::function<bool(const HloComputation*, const HloComputation*)>&
737           eq_computations) const override;
738   // Implementation for non-common logic of CloneWithNewOperands.
739   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
740       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
741       HloCloneContext* context) const override;
742 
743   std::vector<int64> dimensions_;
744 };
745 
746 class HloBroadcastInstruction : public HloInstruction {
747  public:
748   explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
749                                    absl::Span<const int64> broadcast_dimension);
750   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()751   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64_t index)752   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
mutable_dimensions()753   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
754   // Returns a serialized representation of this instruction.
755   HloInstructionProto ToProto() const override;
756 
757  private:
758   std::vector<string> ExtraAttributesToStringImpl(
759       const HloPrintOptions& options) const override;
760   bool IdenticalSlowPath(
761       const HloInstruction& other,
762       const std::function<bool(const HloComputation*, const HloComputation*)>&
763           eq_computations) const override;
764   // Implementation for non-common logic of CloneWithNewOperands.
765   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
766       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
767       HloCloneContext* context) const override;
768 
769   std::vector<int64> dimensions_;
770 };
771 
772 class HloDynamicReshapeInstruction : public HloInstruction {
773  public:
774   explicit HloDynamicReshapeInstruction(
775       const Shape& shape, HloInstruction* data_operand,
776       absl::Span<HloInstruction* const> dim_sizes);
777 
778   // Returns the input dim sizes dimensions, which is operands[1:]
dim_sizes()779   absl::Span<HloInstruction* const> dim_sizes() const {
780     return absl::MakeSpan(operands()).subspan(1, operand_count());
781   }
782 
783   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
784       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
785       HloCloneContext* context) const override;
786 
787   // Returns the input dim size dimension, which is operands[1+i]
dim_sizes(int64_t i)788   HloInstruction* dim_sizes(int64_t i) const { return operands()[i + 1]; }
789 };
790 
791 class HloReshapeInstruction : public HloInstruction {
792  public:
793   explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
794                                  int64_t inferred_dimension);
inferred_dimension()795   int64 inferred_dimension() const { return inferred_dimension_; }
796   HloInstructionProto ToProto() const override;
797 
798  private:
799   std::vector<string> ExtraAttributesToStringImpl(
800       const HloPrintOptions& options) const override;
801   bool IdenticalSlowPath(
802       const HloInstruction& other,
803       const std::function<bool(const HloComputation*, const HloComputation*)>&
804           eq_computations) const override;
805   // Implementation for non-common logic of CloneWithNewOperands.
806   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
807       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
808       HloCloneContext* context) const override;
809   int64 inferred_dimension_;
810 };
811 
812 class HloMapInstruction : public HloInstruction {
813  public:
814   explicit HloMapInstruction(const Shape& shape,
815                              absl::Span<HloInstruction* const> operands,
816                              HloComputation* map_computation);
817   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()818   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64_t index)819   int64 dimensions(int64_t index) const override { return dimensions()[index]; }
mutable_dimensions()820   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
821   // Returns a serialized representation of this instruction.
822   HloInstructionProto ToProto() const override;
823 
824  private:
825   bool IsElementwiseImpl(
826       const absl::optional<int64>& operand_idx) const override;
827   std::vector<string> ExtraAttributesToStringImpl(
828       const HloPrintOptions& options) const override;
829   bool IdenticalSlowPath(
830       const HloInstruction& other,
831       const std::function<bool(const HloComputation*, const HloComputation*)>&
832           eq_computations) const override;
833   // Implementation for non-common logic of CloneWithNewOperands.
834   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
835       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
836       HloCloneContext* context) const override;
837 
838   std::vector<int64> dimensions_;
839 };
840 
841 class HloSliceInstruction : public HloInstruction {
842  public:
843   explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
844                                absl::Span<const int64> start_indices,
845                                absl::Span<const int64> limit_indices,
846                                absl::Span<const int64> strides);
847 
848   HloInstructionProto ToProto() const override;
849 
850   // Returns the start index in the given dimension for a slice node.
slice_starts(int64_t dimension)851   int64 slice_starts(int64_t dimension) const {
852     return slice_starts_[dimension];
853   }
slice_starts()854   const std::vector<int64>& slice_starts() const { return slice_starts_; }
mutable_slice_starts()855   std::vector<int64>* mutable_slice_starts() { return &slice_starts_; }
856 
857   // Returns the (exclusive) limit index in the given dimension for a slice
858   // node.
slice_limits(int64_t dimension)859   int64 slice_limits(int64_t dimension) const {
860     return slice_limits_[dimension];
861   }
slice_limits()862   const std::vector<int64>& slice_limits() const { return slice_limits_; }
mutable_slice_limits()863   std::vector<int64>* mutable_slice_limits() { return &slice_limits_; }
864 
865   // Returns the stride in the given dimension for a slice node.
slice_strides(int64_t dimension)866   int64 slice_strides(int64_t dimension) const {
867     return slice_strides_[dimension];
868   }
slice_strides()869   const std::vector<int64>& slice_strides() const { return slice_strides_; }
mutable_slice_strides()870   std::vector<int64>* mutable_slice_strides() { return &slice_strides_; }
871 
872  private:
873   std::vector<string> ExtraAttributesToStringImpl(
874       const HloPrintOptions& options) const override;
875   bool IdenticalSlowPath(
876       const HloInstruction& other,
877       const std::function<bool(const HloComputation*, const HloComputation*)>&
878           eq_computations) const override;
879   // Implementation for non-common logic of CloneWithNewOperands.
880   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
881       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
882       HloCloneContext* context) const override;
883 
884   // Describes the [begin, end) index range for a slice.
885   std::vector<int64> slice_starts_;
886   std::vector<int64> slice_limits_;
887   std::vector<int64> slice_strides_;
888 };
889 
890 class HloConstantInstruction : public HloInstruction {
891  public:
892   explicit HloConstantInstruction(Literal literal);
893   explicit HloConstantInstruction(Literal literal, const Shape& shape);
894   // Used when the literal is too large and dropped.
895   explicit HloConstantInstruction(const Shape& shape);
896   // Returns the literal associated with this instruction.
literal()897   const Literal& literal() const { return *literal_; }
898   // Returns the (mutable) literal associated with this instruction.
mutable_literal()899   Literal* mutable_literal() { return &literal_.value(); }
900   // Returns whether there is literal associated with this instruction.
HasLiteral()901   bool HasLiteral() const { return literal_.has_value(); }
902   // Returns a serialized representation of this instruction.
903   HloInstructionProto ToProto() const override;
904 
905   // Change the layout for an Constant Hlo instruction to match new_layout.  For
906   // tuple shaped constants shape_index is the path to the internal array
907   // subshape whose layout needs to be changed.
908   void RelayoutConstant(const Layout& new_layout,
909                         const ShapeIndex& shape_index = {});
910 
911  private:
912   bool IsElementwiseImpl(
913       const absl::optional<int64>& operand_idx) const override;
914   bool IdenticalSlowPath(
915       const HloInstruction& other,
916       const std::function<bool(const HloComputation*, const HloComputation*)>&
917           eq_computations) const override;
918   string OperandsToStringWithCanonicalNameMap(
919       const HloPrintOptions& options,
920       CanonicalNameMap* canonical_name_map) const override;
921   // Implementation for non-common logic of CloneWithNewOperands.
922   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
923       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
924       HloCloneContext* context) const override;
925   absl::optional<Literal> literal_;
926 };
927 
928 class HloTraceInstruction : public HloInstruction {
929  public:
930   explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
931   // Returns a tag to be used in tracing.
TracingTag()932   string TracingTag() const { return literal_.GetR1U8AsString(); }
933   // Returns a serialized representation of this instruction.
934   HloInstructionProto ToProto() const override;
935 
936  private:
937   bool IdenticalSlowPath(
938       const HloInstruction& other,
939       const std::function<bool(const HloComputation*, const HloComputation*)>&
940           eq_computations) const override;
941   // Implementation for non-common logic of CloneWithNewOperands.
942   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
943       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
944       HloCloneContext* context) const override;
945   Literal literal_;
946 };
947 
948 class HloFusionInstruction : public HloInstruction {
949  public:
950   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
951                                 HloInstruction* fused_root);
952 
953   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
954                                 absl::Span<HloInstruction* const> operands,
955                                 HloComputation* fusion_computation);
956 
957   ~HloFusionInstruction() override;
958 
959   void ClearCalledComputations() override;
960 
961   // When a fusion instruction is being destructed, clear the back pointer of
962   // its fusion computation, to avoid referencing freed memory.
963   void ClearFusionComputationInstruction();
964 
965   string ToCategory() const override;
966   // Returns a serialized representation of this instruction.
967   HloInstructionProto ToProto() const override;
968 
969   // Adds a new operand the fusion instruction.
970   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
971 
972   // Merges the fused instructions from 'instruction_to_merge' into the
973   // fused instruction set of 'this', updating operands as necessary.
974   //
975   // Precondition: 'instruction_to_merge' must be an operand of 'this'.
976   void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
977 
978   // Merges the fused instructions from instruction_to_merge into the fused
979   // instruction set of 'this' and generates multioutput fusion instructions.
980   // All the users of instruction_to_merge will be redirected to 'this'
981   // instruction. instruction_to_merge will be removed from its parent
982   // computation.
983   void MergeFusionInstructionIntoMultiOutput(
984       HloFusionInstruction* instruction_to_merge);
985 
986   // Fuses the given instruction in this fusion instruction. instruction_to_fuse
987   // is cloned and the clone is placed in the fusion
988   // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
989   // than moved to cleanly handle the case where the instruction has a use
990   // outside the fusion instruction. Moving such an instruction into a fusion
991   // instruction would violate the single-result invariant of HLO instructions
992   // and significantly complicate code generation.
FuseInstruction(HloInstruction * instruction_to_fuse)993   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
994     return FuseInstructionInternal(instruction_to_fuse);
995   }
996 
997   // Fuses the given instruction in this fusion instruction and generates a
998   // multioutput fusion instruction. A clone of the instruction_to_fuse will
999   // be part of the output of fusion instructions. The users of
1000   // instruction_to_fuse will be redirected to this fusion instructions.
1001   // instruction_to_fuse is unchanged otherwise.
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)1002   HloInstruction* FuseInstructionIntoMultiOutput(
1003       HloInstruction* instruction_to_fuse) {
1004     return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
1005   }
1006 
1007   // Returns the computation for this fused instruction.
1008   HloComputation* fused_instructions_computation() const;
1009 
1010   // Returns the root instruction of the fused expression contained within this
1011   // fusion instruction.
1012   HloInstruction* fused_expression_root() const;
1013 
1014   // Returns the list of fused instructions inside this fusion instruction.  The
1015   // returned type is a range of HloInstruction*s.
1016   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1017       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1018   fused_instructions() const;
1019 
1020   const tensorflow::gtl::iterator_range<
1021       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1022   fused_instructions();
1023 
1024   // Gets the number of instructions inside this fusion instruction.
1025   int64 fused_instruction_count() const;
1026 
1027   // Returns the fused parameter instruction in this fusion instruction
1028   // corresponding to the given parameter number.
1029   HloInstruction* fused_parameter(int64_t parameter_number) const;
1030 
1031   // Returns the vector of fused parameters inside this fusion instruction.
1032   const std::vector<HloInstruction*>& fused_parameters() const;
1033 
1034   // Returns true if this instruction is a fusion instruction that generates
1035   // multiple outputs.
IsMultiOutputFusion()1036   const bool IsMultiOutputFusion() const {
1037     return fused_expression_root()->opcode() == HloOpcode::kTuple;
1038   }
1039 
fusion_kind()1040   FusionKind fusion_kind() const { return fusion_kind_; }
1041 
set_fusion_kind(FusionKind kind)1042   void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
1043 
1044   // If multiple operands are the same instruction, keeps only one of them.
1045   Status DeduplicateFusionOperands();
1046 
1047  private:
1048   // Fuses the given instruction into this fusion instruction.
1049   // instruction_to_fuse is cloned and the clone is placed in the fusion
1050   // instruction.  The users of instruction_to_fuse will be redirected to this
1051   // fusion instruction. instruction_to_fuse is unchanged otherwise. When
1052   // add_output is true, a clone of the instruction_to_fuse will be added as
1053   // additional output resulting in a multi-output fusion.
1054   HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
1055                                           bool add_output = false);
1056   // Clones the given instruction_to_fuse and insert the clone into this fusion
1057   // instruction. If add_output is true, a clone of instruction_to_fuse will
1058   // be in the output of the this fusion instruction (part of the tuple of the
1059   // fusion root).
1060   HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
1061                                        bool add_output = false);
1062 
1063   bool IsElementwiseImpl(
1064       const absl::optional<int64>& operand_idx) const override;
1065   std::vector<string> ExtraAttributesToStringImpl(
1066       const HloPrintOptions& options) const override;
1067   bool IdenticalSlowPath(
1068       const HloInstruction& other,
1069       const std::function<bool(const HloComputation*, const HloComputation*)>&
1070           eq_computations) const override;
1071   uint64 InnerHash() const override;
1072 
1073   // Implementation for non-common logic of CloneWithNewOperands.
1074   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1075       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1076       HloCloneContext* context) const override;
1077 
1078   // The type of the fusion. Used by kFusion only.
1079   FusionKind fusion_kind_;
1080 };
1081 
1082 class HloRngInstruction : public HloInstruction {
1083  public:
1084   explicit HloRngInstruction(const Shape& shape,
1085                              RandomDistribution distribution,
1086                              absl::Span<HloInstruction* const> parameters);
1087   // Returns the random distribution for this rng node.
random_distribution()1088   RandomDistribution random_distribution() const { return distribution_; }
1089   // Returns a serialized representation of this instruction.
1090   HloInstructionProto ToProto() const override;
1091 
1092  private:
1093   bool IsElementwiseImpl(
1094       const absl::optional<int64>& operand_idx) const override;
1095   std::vector<string> ExtraAttributesToStringImpl(
1096       const HloPrintOptions& options) const override;
1097   bool IdenticalSlowPath(
1098       const HloInstruction& other,
1099       const std::function<bool(const HloComputation*, const HloComputation*)>&
1100           eq_computations) const override;
1101   // Implementation for non-common logic of CloneWithNewOperands.
1102   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1103       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1104       HloCloneContext* context) const override;
1105 
1106   // The distribution requested for random number generation.
1107   RandomDistribution distribution_;
1108 };
1109 
1110 class HloParameterInstruction : public HloInstruction {
1111  public:
1112   explicit HloParameterInstruction(int64_t parameter_number, const Shape& shape,
1113                                    const string& name);
parameter_number()1114   int64 parameter_number() const { return parameter_number_; }
1115 
1116   // Sets and gets the whether all replicas will receive the same parameter data
1117   // for each leaf buffer in data parallelism.
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)1118   void set_parameter_replicated_at_leaf_buffers(
1119       absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
1120     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
1121              parameter_replicated_at_leaf_buffers.size());
1122     parameter_replicated_at_leaf_buffers_.emplace(
1123         parameter_replicated_at_leaf_buffers.begin(),
1124         parameter_replicated_at_leaf_buffers.end());
1125   }
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)1126   void set_parameter_replicated_at_leaf_buffers(
1127       const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
1128     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
1129              parameter_replicated_at_leaf_buffers.size());
1130     parameter_replicated_at_leaf_buffers_ =
1131         parameter_replicated_at_leaf_buffers;
1132   }
1133   const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers()1134   parameter_replicated_at_leaf_buffers() const {
1135     return parameter_replicated_at_leaf_buffers_;
1136   }
1137 
1138   // Returns a serialized representation of this instruction.
1139   HloInstructionProto ToProto() const override;
1140 
1141  private:
1142   std::vector<string> ExtraAttributesToStringImpl(
1143       const HloPrintOptions& options) const override;
1144   bool IdenticalSlowPath(
1145       const HloInstruction& other,
1146       const std::function<bool(const HloComputation*, const HloComputation*)>&
1147           eq_computations) const override;
1148   string OperandsToStringWithCanonicalNameMap(
1149       const HloPrintOptions& options,
1150       CanonicalNameMap* canonical_name_map) const override;
1151   // Implementation for non-common logic of CloneWithNewOperands.
1152   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1153       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1154       HloCloneContext* context) const override;
1155 
1156   int64 parameter_number_ = 0;
1157 
1158   // Specifies whether each buffer has the same parameter value on all replicas
1159   // in data parallelism.
1160   absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
1161 };
1162 
1163 class HloGetTupleElementInstruction : public HloInstruction {
1164  public:
1165   explicit HloGetTupleElementInstruction(const Shape& shape,
1166                                          HloInstruction* operand,
1167                                          int64_t index);
1168   // Returns the tuple index associated with this instruction.
tuple_index()1169   int64 tuple_index() const { return tuple_index_; }
1170   // Sets the tuple index associated with this instruction.
set_tuple_index(int64_t new_tuple_index)1171   void set_tuple_index(int64_t new_tuple_index) {
1172     tuple_index_ = new_tuple_index;
1173   }
1174   // Returns a serialized representation of this instruction.
1175   HloInstructionProto ToProto() const override;
1176 
1177  private:
1178   std::vector<string> ExtraAttributesToStringImpl(
1179       const HloPrintOptions& options) const override;
1180   bool IdenticalSlowPath(
1181       const HloInstruction& other,
1182       const std::function<bool(const HloComputation*, const HloComputation*)>&
1183           eq_computations) const override;
1184   // Implementation for non-common logic of CloneWithNewOperands.
1185   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1186       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1187       HloCloneContext* context) const override;
1188 
1189   int64 tuple_index_ = -1;
1190 };
1191 
1192 class HloReducePrecisionInstruction : public HloInstruction {
1193  public:
1194   explicit HloReducePrecisionInstruction(const Shape& shape,
1195                                          HloInstruction* operand,
1196                                          const int exponent_bits,
1197                                          const int mantissa_bits);
1198   // Returns the number of exponent bits for a reduce-precision node.
exponent_bits()1199   int32 exponent_bits() const { return exponent_bits_; }
1200   // Returns the number of mantissa bits for a reduce-precision node.
mantissa_bits()1201   int32 mantissa_bits() const { return mantissa_bits_; }
1202   // Returns a serialized representation of this instruction.
1203   HloInstructionProto ToProto() const override;
1204 
1205  private:
1206   std::vector<string> ExtraAttributesToStringImpl(
1207       const HloPrintOptions& options) const override;
1208   bool IdenticalSlowPath(
1209       const HloInstruction& other,
1210       const std::function<bool(const HloComputation*, const HloComputation*)>&
1211           eq_computations) const override;
1212   // Implementation for non-common logic of CloneWithNewOperands.
1213   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1214       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1215       HloCloneContext* context) const override;
1216 
1217   // The bit sizes for a reduce-precision operation.
1218   int32 exponent_bits_ = 0;
1219   int32 mantissa_bits_ = 0;
1220 };
1221 
1222 class HloInfeedInstruction : public HloInstruction {
1223  public:
1224   explicit HloInfeedInstruction(const Shape& infeed_shape,
1225                                 HloInstruction* token_operand,
1226                                 const string& config);
1227   // Returns the infeed configuration string. The infeed configuration includes
1228   // any metadata needed for the backend compiler (e.g., infeed buffer address)
1229   // and is target-dependent.
infeed_config()1230   string infeed_config() const { return infeed_config_; }
set_infeed_config(const string & config)1231   void set_infeed_config(const string& config) { infeed_config_ = config; }
1232   // Returns the shape of the data received by the infeed. This is not the same
1233   // as the shape of the infeed instruction which produces a tuple containing
1234   // the infeed data shape and a TOKEN.
infeed_shape()1235   const Shape& infeed_shape() const {
1236     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
1237     return ShapeUtil::GetSubshape(shape(), {0});
1238   }
1239   // Returns a serialized representation of this instruction.
1240   HloInstructionProto ToProto() const override;
1241 
1242  private:
1243   std::vector<string> ExtraAttributesToStringImpl(
1244       const HloPrintOptions& options) const override;
1245   bool IdenticalSlowPath(
1246       const HloInstruction& other,
1247       const std::function<bool(const HloComputation*, const HloComputation*)>&
1248           eq_computations) const override;
1249   // Implementation for non-common logic of CloneWithNewOperands.
1250   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1251       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1252       HloCloneContext* context) const override;
1253 
1254   // The string representation of the infeed configuration.
1255   string infeed_config_;
1256 };
1257 
1258 class HloOutfeedInstruction : public HloInstruction {
1259  public:
1260   explicit HloOutfeedInstruction(const Shape& outfeed_shape,
1261                                  HloInstruction* operand,
1262                                  HloInstruction* token_operand,
1263                                  absl::string_view outfeed_config);
1264   // Returns the shape for the Outfeed instruction.
outfeed_shape()1265   const Shape& outfeed_shape() const { return outfeed_shape_; }
1266   // Returns the mutable shape for the Outfeed instruction.
mutable_outfeed_shape()1267   Shape* mutable_outfeed_shape() { return &outfeed_shape_; }
1268   // Returns the config for the Outfeed instruction.
outfeed_config()1269   const string& outfeed_config() const { return outfeed_config_; }
set_outfeed_config(const string & config)1270   void set_outfeed_config(const string& config) { outfeed_config_ = config; }
1271   // Returns a serialized representation of this instruction.
1272   HloInstructionProto ToProto() const override;
1273 
1274  private:
1275   std::vector<string> ExtraAttributesToStringImpl(
1276       const HloPrintOptions& options) const override;
1277   bool IdenticalSlowPath(
1278       const HloInstruction& other,
1279       const std::function<bool(const HloComputation*, const HloComputation*)>&
1280           eq_computations) const override;
1281   // Implementation for non-common logic of CloneWithNewOperands.
1282   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1283       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1284       HloCloneContext* context) const override;
1285 
1286   // Shape of outfeed request.
1287   Shape outfeed_shape_;
1288   // Outfeed configuration information, only present for kOutfeed.
1289   string outfeed_config_;
1290 };
1291 
1292 class HloConvolutionInstruction : public HloInstruction {
1293  public:
1294   explicit HloConvolutionInstruction(
1295       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1296       int64_t feature_group_count, int64_t batch_group_count,
1297       const Window& window,
1298       const ConvolutionDimensionNumbers& dimension_numbers,
1299       const PrecisionConfig& precision_config);
window()1300   const Window& window() const override { return window_; }
set_window(const Window & window)1301   void set_window(const Window& window) override { window_ = window; }
convolution_dimension_numbers()1302   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1303     return convolution_dimension_numbers_;
1304   }
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1305   void set_convolution_dimension_numbers(
1306       const ConvolutionDimensionNumbers& dnums) {
1307     convolution_dimension_numbers_ = dnums;
1308   }
1309   // The number of feature groups. Must be a divisor of the input feature
1310   // dimension and output feature dimension.
feature_group_count()1311   int64 feature_group_count() const { return feature_group_count_; }
set_feature_group_count(int64_t num_feature_groups)1312   void set_feature_group_count(int64_t num_feature_groups) {
1313     feature_group_count_ = num_feature_groups;
1314   }
1315   // The number of batch groups. Must be a divisor of the input batch dimension.
batch_group_count()1316   int64 batch_group_count() const { return batch_group_count_; }
set_batch_group_count(int64_t num_batch_groups)1317   void set_batch_group_count(int64_t num_batch_groups) {
1318     batch_group_count_ = num_batch_groups;
1319   }
1320 
1321   // Returns the information used to tell the implementation information about
1322   // what sort of precision is requested. The meaning of the field is backend
1323   // specific. At the moment, it is only supported for kConvolution and kDot.
1324   // Transformations on one kDot or kConvolution to another will preserve this
1325   // information. Transformations to other HLOs will not preserve this
1326   // information but it is presumed that the alternate lowering is strictly
1327   // superior.
precision_config()1328   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1329   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1330 
1331   string ToCategory() const override;
1332   // Returns a serialized representation of this instruction.
1333   HloInstructionProto ToProto() const override;
1334 
1335  private:
1336   std::vector<string> ExtraAttributesToStringImpl(
1337       const HloPrintOptions& options) const override;
1338   bool IdenticalSlowPath(
1339       const HloInstruction& other,
1340       const std::function<bool(const HloComputation*, const HloComputation*)>&
1341           eq_computations) const override;
1342   // Implementation for non-common logic of CloneWithNewOperands.
1343   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1344       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1345       HloCloneContext* context) const override;
1346   // The number of feature groups. Must be a divisor of the input feature
1347   // dimension and output feature dimension.
1348   int64 feature_group_count_;
1349   // The number of batch groups. Must be a divisor of the input batch dimension.
1350   int64 batch_group_count_;
1351   // Describes the window used for a convolution.
1352   Window window_;
1353   // Describes the dimension numbers used for a convolution.
1354   ConvolutionDimensionNumbers convolution_dimension_numbers_;
1355   // Information used to communicate to the implementation about the algorithm
1356   // used to produce results. See the documentation on precision_config().
1357   PrecisionConfig precision_config_;
1358 };
1359 
1360 class HloReduceWindowInstruction : public HloInstruction {
1361  public:
1362   explicit HloReduceWindowInstruction(const Shape& shape,
1363                                       HloInstruction* operand,
1364                                       HloInstruction* init_value,
1365                                       const Window& window,
1366                                       HloComputation* reduce_computation);
1367   explicit HloReduceWindowInstruction(
1368       const Shape& shape, absl::Span<HloInstruction* const> operands,
1369       absl::Span<HloInstruction* const> init_values, const Window& window,
1370       HloComputation* reduce_computation);
window()1371   const Window& window() const override { return window_; }
set_window(const Window & window)1372   void set_window(const Window& window) override { window_ = window; }
1373   // Returns a serialized representation of this instruction.
1374   HloInstructionProto ToProto() const override;
1375   // Returns the number of input arrays (and, consequentially, the number of
1376   // init values) this reduce has.
input_count()1377   int64 input_count() const { return operand_count() / 2; }
1378   // Returns the input tensors to be reduced.
inputs()1379   absl::Span<HloInstruction* const> inputs() const {
1380     return absl::MakeSpan(operands()).subspan(0, input_count());
1381   }
1382   // Returns the init values of the reduction.
init_values()1383   absl::Span<HloInstruction* const> init_values() const {
1384     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
1385   }
1386   // Returns the shapes of input tensors to be reduced.
input_shapes()1387   absl::InlinedVector<const Shape*, 2> input_shapes() const {
1388     absl::InlinedVector<const Shape*, 2> shapes;
1389     for (const auto* op : inputs()) {
1390       VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
1391       shapes.push_back(&op->shape());
1392       VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
1393     }
1394     return shapes;
1395   }
1396   // Returns the init values of the reduction.
init_value_shapes()1397   absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
1398     absl::InlinedVector<const Shape*, 2> shapes;
1399     for (const auto* op : init_values()) {
1400       shapes.push_back(&op->shape());
1401     }
1402     return shapes;
1403   }
1404   // Returns the shapes of the reduced output tensors.
output_shapes()1405   absl::InlinedVector<const Shape*, 2> output_shapes() const {
1406     absl::InlinedVector<const Shape*, 2> shapes;
1407     if (shape().IsArray()) {
1408       shapes.push_back(&shape());
1409     } else {
1410       for (const Shape& tuple_element_shape : shape().tuple_shapes()) {
1411         shapes.push_back(&tuple_element_shape);
1412       }
1413     }
1414     return shapes;
1415   }
1416 
1417  private:
1418   std::vector<string> ExtraAttributesToStringImpl(
1419       const HloPrintOptions& options) const override;
1420   bool IdenticalSlowPath(
1421       const HloInstruction& other,
1422       const std::function<bool(const HloComputation*, const HloComputation*)>&
1423           eq_computations) const override;
1424   // Implementation for non-common logic of CloneWithNewOperands.
1425   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1426       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1427       HloCloneContext* context) const override;
1428 
1429   Window window_;
1430 };
1431 
1432 class HloSelectAndScatterInstruction : public HloInstruction {
1433  public:
1434   explicit HloSelectAndScatterInstruction(
1435       const Shape& shape, HloInstruction* operand, HloComputation* select,
1436       const Window& window, HloInstruction* source, HloInstruction* init_value,
1437       HloComputation* scatter);
window()1438   const Window& window() const override { return window_; }
set_window(const Window & window)1439   void set_window(const Window& window) override { window_ = window; }
1440   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
1441   // setters should only be called by HloModule or HloComputation methods.
select()1442   HloComputation* select() const {
1443     return called_computations()[kSelectComputationIndex];
1444   }
1445 
scatter()1446   HloComputation* scatter() const {
1447     return called_computations()[kScatterComputationIndex];
1448   }
1449 
set_select(HloComputation * computation)1450   void set_select(HloComputation* computation) {
1451     // Don't allow changing the computation for fused instructions so we don't
1452     // have to recompute called_instructions for the entire fusion instruction.
1453     CHECK(!IsFused());
1454     set_called_computation(kSelectComputationIndex, computation);
1455   }
1456 
set_scatter(HloComputation * computation)1457   void set_scatter(HloComputation* computation) {
1458     // Don't allow changing the computation for fused instructions so we don't
1459     // have to recompute called_instructions for the entire fusion instruction.
1460     CHECK(!IsFused());
1461     set_called_computation(kScatterComputationIndex, computation);
1462   }
1463   // Returns a serialized representation of this instruction.
1464   HloInstructionProto ToProto() const override;
1465 
1466  private:
1467   std::vector<string> ExtraAttributesToStringImpl(
1468       const HloPrintOptions& options) const override;
1469   bool IdenticalSlowPath(
1470       const HloInstruction& other,
1471       const std::function<bool(const HloComputation*, const HloComputation*)>&
1472           eq_computations) const override;
1473   // Implementation for non-common logic of CloneWithNewOperands.
1474   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1475       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1476       HloCloneContext* context) const override;
1477   Window window_;
1478 };
1479 
1480 class HloCustomCallInstruction : public HloInstruction {
1481  public:
1482   HloCustomCallInstruction(const Shape& shape,
1483                            absl::Span<HloInstruction* const> operands,
1484                            absl::string_view custom_call_target, string opaque,
1485                            CustomCallApiVersion api_version);
1486 
1487   // Constructor for a custom call with constrained layout. 'shape' and
1488   // 'operands_with_layout' must all have layouts.
1489   HloCustomCallInstruction(const Shape& shape,
1490                            absl::Span<HloInstruction* const> operands,
1491                            absl::string_view custom_call_target, string opaque,
1492                            absl::Span<const Shape> operand_shapes_with_layout,
1493                            CustomCallApiVersion api_version);
1494 
1495   // Constructor for a custom call with a to_apply computation.
1496   HloCustomCallInstruction(const Shape& shape,
1497                            absl::Span<HloInstruction* const> operands,
1498                            HloComputation* to_apply,
1499                            absl::string_view custom_call_target, string opaque,
1500                            CustomCallApiVersion api_version);
1501 
1502   // Constructor for a custom call with multiple computations.
1503   HloCustomCallInstruction(
1504       const Shape& shape, absl::Span<HloInstruction* const> operands,
1505       absl::Span<HloComputation* const> called_computations,
1506       absl::string_view custom_call_target, string opaque,
1507       CustomCallApiVersion api_version);
1508 
window()1509   const Window& window() const override {
1510     CHECK(window_ != nullptr);
1511     return *window_;
1512   }
1513 
set_window(const Window & window)1514   void set_window(const Window& window) override {
1515     window_ = absl::make_unique<Window>(window);
1516   }
1517 
convolution_dimension_numbers()1518   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1519     CHECK(convolution_dimension_numbers_ != nullptr);
1520     return *convolution_dimension_numbers_;
1521   }
1522 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1523   void set_convolution_dimension_numbers(
1524       const ConvolutionDimensionNumbers& dnums) {
1525     convolution_dimension_numbers_ =
1526         absl::make_unique<ConvolutionDimensionNumbers>(dnums);
1527   }
1528   // TODO(jpienaar): Remove this accessor in the follow up.
opaque()1529   const string& opaque() const { return raw_backend_config_string(); }
custom_call_target()1530   const string& custom_call_target() const { return custom_call_target_; }
set_feature_group_count(int64_t feature_group_count)1531   void set_feature_group_count(int64_t feature_group_count) {
1532     feature_group_count_ = feature_group_count;
1533   }
set_batch_group_count(int64_t batch_group_count)1534   void set_batch_group_count(int64_t batch_group_count) {
1535     batch_group_count_ = batch_group_count;
1536   }
1537   // Sets whether this custom call has a side-effect - by default a custom call
1538   // has no side-effects.
set_custom_call_has_side_effect(bool custom_call_has_side_effect)1539   void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
1540     custom_call_has_side_effect_ = custom_call_has_side_effect;
1541   }
feature_group_count()1542   int64 feature_group_count() const { return feature_group_count_; }
batch_group_count()1543   int64 batch_group_count() const { return batch_group_count_; }
custom_call_has_side_effect()1544   bool custom_call_has_side_effect() const {
1545     return custom_call_has_side_effect_;
1546   }
1547   // Returns padding type used for ops like convolution.
padding_type()1548   PaddingType padding_type() const { return padding_type_; }
1549 
set_padding_type(PaddingType padding_type)1550   void set_padding_type(PaddingType padding_type) {
1551     padding_type_ = padding_type;
1552   }
1553 
1554   // Returns the literal associated with this instruction.
literal()1555   const Literal& literal() const { return *literal_; }
1556   // Set the value of literal to a new one.
set_literal(Literal && literal)1557   void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
1558   // Returns whether there is literal associated with this instruction.
HasLiteral()1559   bool HasLiteral() const { return literal_.has_value(); }
1560 
precision_config()1561   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1562   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1563 
1564   // Returns a serialized representation of this instruction.
1565   HloInstructionProto ToProto() const override;
1566 
1567   // Returns whether the result and operand layouts are constrained.
layout_constrained()1568   bool layout_constrained() const { return layout_constrained_; }
1569 
1570   // Returns the shapes (with layout) of the operands. CHECKs if this custom
1571   // call does not have constrained layouts.
operand_shapes_with_layout()1572   const std::vector<Shape>& operand_shapes_with_layout() const {
1573     CHECK(layout_constrained());
1574     return operand_shapes_with_layout_;
1575   }
1576   // Gets a list of output/operand buffer pairs that alias each other, where the
1577   // output buffer is represented as a ShapeIndex, and the operand buffer is
1578   // represented as the operand index and the ShapeIndex. By default this list
1579   // is empty.
1580   const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
output_to_operand_aliasing()1581   output_to_operand_aliasing() const {
1582     return output_to_operand_aliasing_;
1583   }
1584   // Sets the list of output/operand buffer pairs that alias each other.
set_output_to_operand_aliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> aliasing)1585   void set_output_to_operand_aliasing(
1586       std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1587           aliasing) {
1588     output_to_operand_aliasing_ = std::move(aliasing);
1589   }
set_custom_call_schedule(CustomCallSchedule custom_call_schedule)1590   void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) {
1591     custom_call_schedule_ = custom_call_schedule;
1592   }
custom_call_schedule()1593   CustomCallSchedule custom_call_schedule() const {
1594     return custom_call_schedule_;
1595   }
set_api_version(CustomCallApiVersion api_version)1596   void set_api_version(CustomCallApiVersion api_version) {
1597     api_version_ = api_version;
1598   }
api_version()1599   CustomCallApiVersion api_version() const { return api_version_; }
1600 
1601  private:
1602   std::vector<string> ExtraAttributesToStringImpl(
1603       const HloPrintOptions& options) const override;
1604   bool IdenticalSlowPath(
1605       const HloInstruction& other,
1606       const std::function<bool(const HloComputation*, const HloComputation*)>&
1607           eq_computations) const override;
1608   // Implementation for non-common logic of CloneWithNewOperands.
1609   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1610       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1611       HloCloneContext* context) const override;
1612   // Name of a global symbol to call.
1613   string custom_call_target_;
1614   // Describes the window in a windowed operation such as convolution.
1615   std::unique_ptr<Window> window_;
1616   // Describes the dimension numbers used for a convolution.
1617   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1618   // The number of feature groups. This is used for grouped convolutions.
1619   int64 feature_group_count_;
1620   int64 batch_group_count_;
1621   // Whether the result and operand layouts are constrained.
1622   bool layout_constrained_;
1623   // Information used to communicate to the implementation about the algorithm
1624   // used to produce results for convolution instructions.
1625   PrecisionConfig precision_config_;
1626   // Describes the padding type for convolution instructions.
1627   PaddingType padding_type_;
1628   // For layout-constrained custom calls, this vector holds the shape with
1629   // layout for each operand.
1630   std::vector<Shape> operand_shapes_with_layout_;
1631   // Whether this custom call has a side-effect.
1632   bool custom_call_has_side_effect_;
1633   // A list of output/operand buffer pairs that alias each other. See comment of
1634   // output_to_operand_aliasing().
1635   std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1636       output_to_operand_aliasing_;
1637   absl::optional<Literal> literal_;
1638   // A custom-call schedule hint.
1639   CustomCallSchedule custom_call_schedule_;
1640   // The version of the API used by the custom call function.
1641   // TODO(b/189822916): Remove this field when all clients are migrated to the
1642   // status-returning API.
1643   CustomCallApiVersion api_version_;
1644 };
1645 
1646 class HloPadInstruction : public HloInstruction {
1647  public:
1648   explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
1649                              HloInstruction* padding_value,
1650                              const PaddingConfig& padding_config);
1651   // Returns the padding configuration for a pad node.
padding_config()1652   const PaddingConfig& padding_config() const { return padding_config_; }
mutable_padding_config()1653   PaddingConfig* mutable_padding_config() { return &padding_config_; }
1654   // Returns the padding value.
padding_value()1655   const HloInstruction* padding_value() const { return operand(1); }
mutable_padding_value()1656   HloInstruction* mutable_padding_value() { return mutable_operand(1); }
1657   // Returns a serialized representation of this instruction.
1658   HloInstructionProto ToProto() const override;
1659 
1660  private:
1661   std::vector<string> ExtraAttributesToStringImpl(
1662       const HloPrintOptions& options) const override;
1663   bool IdenticalSlowPath(
1664       const HloInstruction& other,
1665       const std::function<bool(const HloComputation*, const HloComputation*)>&
1666           eq_computations) const override;
1667   // Implementation for non-common logic of CloneWithNewOperands.
1668   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1669       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1670       HloCloneContext* context) const override;
1671 
1672   // The padding configuration that describes the edge padding and interior
1673   // padding of this pad instruction.
1674   PaddingConfig padding_config_;
1675 };
1676 
1677 class HloDynamicIndexInstruction : public HloInstruction {
1678  public:
HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1679   explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
1680       : HloInstruction(opcode, shape) {}
1681   virtual int64 first_index_operand_number() const = 0;
1682 
1683   // Returns a subspan of operands which represent the start indices.
index_operands()1684   absl::Span<HloInstruction* const> index_operands() const {
1685     return absl::MakeSpan(operands()).subspan(first_index_operand_number());
1686   }
1687 
1688   // Returns the shapes of the index operands.
index_shapes()1689   std::vector<Shape> index_shapes() const {
1690     std::vector<Shape> shapes;
1691     auto indices = index_operands();
1692     for (const HloInstruction* index : indices) {
1693       shapes.push_back(index->shape());
1694     }
1695     return shapes;
1696   }
1697 };
1698 
1699 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
1700  public:
1701   explicit HloDynamicSliceInstruction(const Shape& shape,
1702                                       HloInstruction* operand,
1703                                       HloInstruction* start_indices,
1704                                       absl::Span<const int64> slice_sizes);
1705   explicit HloDynamicSliceInstruction(
1706       const Shape& shape, HloInstruction* operand,
1707       absl::Span<HloInstruction* const> start_indices,
1708       absl::Span<const int64> slice_sizes);
1709   // Old methods kept for smooth subclassing transition END.
1710   // Returns the size of the slice in the given dimension for a dynamic
1711   // slice node.
slice_sizes(int64_t dimension)1712   int64 slice_sizes(int64_t dimension) const {
1713     return dynamic_slice_sizes_[dimension];
1714   }
dynamic_slice_sizes()1715   const std::vector<int64>& dynamic_slice_sizes() const {
1716     return dynamic_slice_sizes_;
1717   }
1718   // Returns a serialized representation of this instruction.
1719   HloInstructionProto ToProto() const override;
1720 
first_index_operand_number()1721   int64 first_index_operand_number() const override { return 1; }
1722 
1723  private:
1724   std::vector<string> ExtraAttributesToStringImpl(
1725       const HloPrintOptions& options) const override;
1726   bool IdenticalSlowPath(
1727       const HloInstruction& other,
1728       const std::function<bool(const HloComputation*, const HloComputation*)>&
1729           eq_computations) const override;
1730   // Implementation for non-common logic of CloneWithNewOperands.
1731   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1732       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1733       HloCloneContext* context) const override;
1734 
1735   // Describes the [start, start + size) range size for a dynamic slice
1736   // ('start' is specified dynamically in the second operand of the operation).
1737   std::vector<int64> dynamic_slice_sizes_;
1738 };
1739 
1740 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
1741  public:
1742   explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
1743                                             HloInstruction* operand,
1744                                             HloInstruction* update,
1745                                             HloInstruction* start_indices);
1746   explicit HloDynamicUpdateSliceInstruction(
1747       const Shape& shape, HloInstruction* operand, HloInstruction* update,
1748       absl::Span<HloInstruction* const> start_indices);
1749 
first_index_operand_number()1750   int64 first_index_operand_number() const override { return 2; }
1751 };
1752 
1753 class HloGatherInstruction : public HloInstruction {
1754  public:
1755   explicit HloGatherInstruction(
1756       const Shape& shape, HloInstruction* operand,
1757       HloInstruction* start_indices,
1758       const GatherDimensionNumbers& gather_dim_numbers,
1759       absl::Span<const int64> slice_sizes, bool indices_are_sorted);
gather_dimension_numbers()1760   const GatherDimensionNumbers& gather_dimension_numbers() const {
1761     CHECK(gather_dimension_numbers_ != nullptr);
1762     return *gather_dimension_numbers_;
1763   }
gather_slice_sizes()1764   absl::Span<const int64> gather_slice_sizes() const {
1765     return gather_slice_sizes_;
1766   }
indices_are_sorted()1767   bool indices_are_sorted() const { return indices_are_sorted_; }
set_indices_are_sorted(bool indices_are_sorted)1768   void set_indices_are_sorted(bool indices_are_sorted) {
1769     indices_are_sorted_ = indices_are_sorted;
1770   }
1771   // Returns a serialized representation of this instruction.
1772   HloInstructionProto ToProto() const override;
1773 
1774   // Creates an instance of GatherDimensionNumbers.
1775   static GatherDimensionNumbers MakeGatherDimNumbers(
1776       absl::Span<const int64> offset_dims,
1777       absl::Span<const int64> collapsed_slice_dims,
1778       absl::Span<const int64> start_index_map, int64_t index_vector_dim);
1779   // Returns the dump string of the given gather dimension numbers.
1780   static string GatherDimensionNumbersToString(
1781       const GatherDimensionNumbers& gather_dimension_numbers);
1782 
1783  private:
1784   std::vector<string> ExtraAttributesToStringImpl(
1785       const HloPrintOptions& options) const override;
1786   bool IdenticalSlowPath(
1787       const HloInstruction& other,
1788       const std::function<bool(const HloComputation*, const HloComputation*)>&
1789           eq_computations) const override;
1790   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1791       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1792       HloCloneContext* context) const override;
1793 
1794   std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1795   std::vector<int64> gather_slice_sizes_;
1796   bool indices_are_sorted_;
1797 };
1798 
1799 class HloScatterInstruction : public HloInstruction {
1800  public:
1801   explicit HloScatterInstruction(
1802       const Shape& shape, HloInstruction* operand,
1803       HloInstruction* scatter_indices, HloInstruction* updates,
1804       HloComputation* update_computation,
1805       const ScatterDimensionNumbers& scatter_dim_numbers,
1806       bool indices_are_sorted, bool unique_indices);
scatter_dimension_numbers()1807   const ScatterDimensionNumbers& scatter_dimension_numbers() const {
1808     CHECK(scatter_dimension_numbers_ != nullptr);
1809     return *scatter_dimension_numbers_;
1810   }
indices_are_sorted()1811   bool indices_are_sorted() const { return indices_are_sorted_; }
set_indices_are_sorted(bool indices_are_sorted)1812   void set_indices_are_sorted(bool indices_are_sorted) {
1813     indices_are_sorted_ = indices_are_sorted;
1814   }
unique_indices()1815   bool unique_indices() const override { return unique_indices_; }
1816   // Returns a serialized representation of this instruction.
1817   HloInstructionProto ToProto() const override;
1818 
1819   // Creates an instance of ScatterDimensionNumbers.
1820   static ScatterDimensionNumbers MakeScatterDimNumbers(
1821       absl::Span<const int64> update_window_dims,
1822       absl::Span<const int64> inserted_window_dims,
1823       absl::Span<const int64> scatter_dims_to_operand_dims,
1824       int64_t index_vector_dim);
1825   // Returns the dump string of the given scatter dimension numbers.
1826   static string ScatterDimensionNumbersToString(
1827       const ScatterDimensionNumbers& scatter_dimension_numbers);
1828 
1829  private:
1830   std::vector<string> ExtraAttributesToStringImpl(
1831       const HloPrintOptions& options) const override;
1832   bool IdenticalSlowPath(
1833       const HloInstruction& other,
1834       const std::function<bool(const HloComputation*, const HloComputation*)>&
1835           eq_computations) const override;
1836   // Implementation for non-common logic of CloneWithNewOperands.
1837   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1838       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1839       HloCloneContext* context) const override;
1840 
1841   std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
1842   bool indices_are_sorted_;
1843   bool unique_indices_;
1844 };
1845 
1846 class HloIotaInstruction : public HloInstruction {
1847  public:
1848   explicit HloIotaInstruction(const Shape& shape, int64_t iota_dimension);
1849   // Returns the dimension sizes or numbers associated with this instruction.
iota_dimension()1850   int64 iota_dimension() const { return iota_dimension_; }
1851   // Returns a serialized representation of this instruction.
1852   HloInstructionProto ToProto() const override;
1853 
1854  private:
1855   std::vector<string> ExtraAttributesToStringImpl(
1856       const HloPrintOptions& options) const override;
1857   bool IdenticalSlowPath(
1858       const HloInstruction& other,
1859       const std::function<bool(const HloComputation*, const HloComputation*)>&
1860           eq_computations) const override;
1861   // Implementation for non-common logic of CloneWithNewOperands.
1862   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1863       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1864       HloCloneContext* context) const override;
1865 
1866   const int64 iota_dimension_;
1867 };
1868 
1869 class HloDotInstruction : public HloInstruction {
1870  public:
1871   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
1872   // dimensions specified in 'dimension_numbers'.
1873   explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
1874                              HloInstruction* rhs,
1875                              const DotDimensionNumbers& dimension_numbers,
1876                              const PrecisionConfig& precision_config);
1877 
1878   // Returns data on the dimension numbers used for a dot operation.
dot_dimension_numbers()1879   const DotDimensionNumbers& dot_dimension_numbers() const {
1880     return dot_dimension_numbers_;
1881   }
1882 
1883   // Returns the information used to tell the implementation information about
1884   // what sort of precision is requested. The meaning of the field is backend
1885   // specific. At the moment, it is only supported for kConvolution and kDot.
1886   // Transformations on one kDot or kConvolution to another will preserve this
1887   // information. Transformations to other HLOs will not preserve this
1888   // information but it is presumed that the alternate lowering is strictly
1889   // superior.
precision_config()1890   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1891   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1892 
1893   // Returns a serialized representation of this instruction.
1894   HloInstructionProto ToProto() const override;
1895 
1896  private:
1897   std::vector<string> ExtraAttributesToStringImpl(
1898       const HloPrintOptions& options) const override;
1899   bool IdenticalSlowPath(
1900       const HloInstruction& other,
1901       const std::function<bool(const HloComputation*, const HloComputation*)>&
1902           eq_computations) const override;
1903   // Implementation for non-common logic of CloneWithNewOperands.
1904   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1905       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1906       HloCloneContext* context) const override;
1907   // Returns the dump string of the dot dimension numbers.
1908   string DotDimensionNumbersToString() const;
1909 
1910   // Describes the dimension numbers used for a dot.
1911   DotDimensionNumbers dot_dimension_numbers_;
1912 
1913   // Information used to communicate to the implementation about the algorithm
1914   // used to produce results. See the documentation on precision_config().
1915   PrecisionConfig precision_config_;
1916 };
1917 
1918 class HloDomainInstruction : public HloInstruction {
1919  public:
1920   explicit HloDomainInstruction(
1921       const Shape& shape, HloInstruction* operand,
1922       std::unique_ptr<DomainMetadata> operand_side_metadata,
1923       std::unique_ptr<DomainMetadata> user_side_metadata);
1924 
1925   // Returns a serialized representation of this instruction.
1926   HloInstructionProto ToProto() const override;
1927 
1928   // Retrieves the operand side metadata of a kDomain instruction.
operand_side_metadata()1929   const DomainMetadata& operand_side_metadata() const {
1930     return *operand_side_metadata_;
1931   }
1932   // Retrieves the user side metadata of a kDomain instruction.
user_side_metadata()1933   const DomainMetadata& user_side_metadata() const {
1934     return *user_side_metadata_;
1935   }
1936 
1937  private:
1938   std::vector<string> ExtraAttributesToStringImpl(
1939       const HloPrintOptions& options) const override;
1940   bool IdenticalSlowPath(
1941       const HloInstruction& other,
1942       const std::function<bool(const HloComputation*, const HloComputation*)>&
1943           eq_computations) const override;
1944   // Implementation for non-common logic of CloneWithNewOperands.
1945   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1946       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1947       HloCloneContext* context) const override;
1948 
1949   std::unique_ptr<DomainMetadata> operand_side_metadata_;
1950   std::unique_ptr<DomainMetadata> user_side_metadata_;
1951 };
1952 
1953 class HloGetDimensionSizeInstruction : public HloInstruction {
1954  public:
1955   explicit HloGetDimensionSizeInstruction(const Shape& shape,
1956                                           HloInstruction* operand,
1957                                           int64_t dimension);
1958 
1959   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1960   int64 dimension() const { return dimension_; }
1961   // Returns a serialized representation of this instruction.
1962   HloInstructionProto ToProto() const override;
1963 
1964  private:
1965   std::vector<string> ExtraAttributesToStringImpl(
1966       const HloPrintOptions& options) const override;
1967   bool IdenticalSlowPath(
1968       const HloInstruction& other,
1969       const std::function<bool(const HloComputation*, const HloComputation*)>&
1970           eq_computations) const override;
1971   // Implementation for non-common logic of CloneWithNewOperands.
1972   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1973       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1974       HloCloneContext* context) const override;
1975 
1976   int64 dimension_;
1977 };
1978 
1979 class HloSetDimensionSizeInstruction : public HloInstruction {
1980  public:
1981   explicit HloSetDimensionSizeInstruction(const Shape& shape,
1982                                           HloInstruction* operand,
1983                                           HloInstruction* val,
1984                                           int64_t dimension);
1985 
1986   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1987   int64 dimension() const { return dimension_; }
1988   // Returns a serialized representation of this instruction.
1989   HloInstructionProto ToProto() const override;
1990 
1991  private:
1992   std::vector<string> ExtraAttributesToStringImpl(
1993       const HloPrintOptions& options) const override;
1994   bool IdenticalSlowPath(
1995       const HloInstruction& other,
1996       const std::function<bool(const HloComputation*, const HloComputation*)>&
1997           eq_computations) const override;
1998   // Implementation for non-common logic of CloneWithNewOperands.
1999   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2000       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2001       HloCloneContext* context) const override;
2002 
2003   int64 dimension_;
2004 };
2005 
2006 class HloRngGetAndUpdateStateInstruction : public HloInstruction {
2007  public:
2008   explicit HloRngGetAndUpdateStateInstruction(const Shape& shape,
2009                                               int64_t delta);
2010 
2011   // Returns the delta value.
delta()2012   int64 delta() const { return delta_; }
set_delta(int64_t delta)2013   void set_delta(int64_t delta) { delta_ = delta; }
2014   // Returns a serialized representation of this instruction.
2015   HloInstructionProto ToProto() const override;
2016 
2017  private:
2018   std::vector<string> ExtraAttributesToStringImpl(
2019       const HloPrintOptions& options) const override;
2020   bool IdenticalSlowPath(
2021       const HloInstruction& other,
2022       const std::function<bool(const HloComputation*, const HloComputation*)>&
2023           eq_computations) const override;
2024   // Implementation for non-common logic of CloneWithNewOperands.
2025   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2026       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2027       HloCloneContext* context) const override;
2028 
2029   int64 delta_;
2030 };
2031 
2032 class HloRngBitGeneratorInstruction : public HloInstruction {
2033  public:
2034   HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
2035                                 RandomAlgorithm algorithm);
2036 
algorithm()2037   RandomAlgorithm algorithm() const { return algorithm_; }
2038   HloInstructionProto ToProto() const override;
2039 
2040  private:
2041   std::vector<string> ExtraAttributesToStringImpl(
2042       const HloPrintOptions& options) const override;
2043   bool IdenticalSlowPath(
2044       const HloInstruction& other,
2045       const std::function<bool(const HloComputation*, const HloComputation*)>&
2046           eq_computations) const override;
2047   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2048       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2049       HloCloneContext* context) const override;
2050 
2051   RandomAlgorithm algorithm_;
2052 };
2053 
2054 }  // namespace xla
2055 
2056 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
2057