• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // All HloInstruction subclasses are put in this file.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 
24 namespace xla {
25 
26 class HloBatchNormInstruction : public HloInstruction {
27  public:
28   // Returns feature_index field associated with the instruction. The index
29   // represents the index of the feature dimension.
feature_index()30   int64 feature_index() const { return feature_index_; }
31 
32   // Returns a epsilon value associated with the instruction. The is a small
33   // number added to the variance to avoid divide-by-zero error.
epsilon()34   float epsilon() const { return epsilon_; }
35 
36   // Returns a serialized representation of this instruction.
37   HloInstructionProto ToProto() const override;
38 
39  protected:
40   explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
41                                    HloInstruction* operand,
42                                    HloInstruction* scale, float epsilon,
43                                    int64 feature_index);
44 
45  private:
46   std::vector<string> ExtraAttributesToStringImpl(
47       const HloPrintOptions& options) const override;
48   bool IdenticalSlowPath(
49       const HloInstruction& other,
50       const std::function<bool(const HloComputation*, const HloComputation*)>&
51           eq_computations) const override;
52   // A small float number added to the variance to avoid divide-by-zero error.
53   float epsilon_ = 0.0f;
54 
55   // An integer value representing the index of the feature dimension.
56   int64 feature_index_ = -1;
57 };
58 
59 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
60  public:
61   explicit HloBatchNormTrainingInstruction(const Shape& shape,
62                                            HloInstruction* operand,
63                                            HloInstruction* scale,
64                                            HloInstruction* offset,
65                                            float epsilon, int64 feature_index);
66 
67  private:
68   // Implementation for non-common logic of CloneWithNewOperands.
69   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
70       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
71       HloCloneContext* context) const override;
72 };
73 
74 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
75  public:
76   explicit HloBatchNormInferenceInstruction(
77       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
78       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
79       float epsilon, int64 feature_index);
80 
81  private:
82   // Implementation for non-common logic of CloneWithNewOperands.
83   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
84       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
85       HloCloneContext* context) const override;
86 };
87 
88 class HloBatchNormGradInstruction : public HloBatchNormInstruction {
89  public:
90   explicit HloBatchNormGradInstruction(
91       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
92       HloInstruction* mean, HloInstruction* variance,
93       HloInstruction* grad_output, float epsilon, int64 feature_index);
94 
95  private:
96   // Implementation for non-common logic of CloneWithNewOperands.
97   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
98       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
99       HloCloneContext* context) const override;
100 };
101 
102 class HloFftInstruction : public HloInstruction {
103  public:
104   explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
105                              FftType fft_type,
106                              absl::Span<const int64> fft_length);
fft_type()107   FftType fft_type() const { return fft_type_; }
108 
fft_length()109   const std::vector<int64>& fft_length() const { return fft_length_; }
110 
111   // Returns a serialized representation of this instruction.
112   HloInstructionProto ToProto() const override;
113 
114  private:
115   std::vector<string> ExtraAttributesToStringImpl(
116       const HloPrintOptions& options) const override;
117   bool IdenticalSlowPath(
118       const HloInstruction& other,
119       const std::function<bool(const HloComputation*, const HloComputation*)>&
120           eq_computations) const override;
121 
122   // Implementation for non-common logic of CloneWithNewOperands.
123   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
124       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
125       HloCloneContext* context) const override;
126 
127   // Describes FFT type for an FFT instruction.
128   FftType fft_type_ = FftType::FFT;
129 
130   // Indicates the FFT length for an FFT instruction.
131   std::vector<int64> fft_length_;
132 };
133 
134 class HloCompareInstruction : public HloInstruction {
135  public:
136   explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
137                                  HloInstruction* rhs,
138                                  ComparisonDirection direction);
direction()139   ComparisonDirection direction() const { return direction_; }
140   HloInstructionProto ToProto() const override;
141 
142  private:
143   std::vector<string> ExtraAttributesToStringImpl(
144       const HloPrintOptions& options) const override;
145   bool IdenticalSlowPath(
146       const HloInstruction& other,
147       const std::function<bool(const HloComputation*, const HloComputation*)>&
148           eq_computations) const override;
149   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
150       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
151       HloCloneContext* context) const override;
152 
153   ComparisonDirection direction_;
154 };
155 
156 class HloTriangularSolveInstruction : public HloInstruction {
157  public:
158   explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
159                                          HloInstruction* b,
160                                          const TriangularSolveOptions& options);
triangular_solve_options()161   const TriangularSolveOptions& triangular_solve_options() const {
162     return triangular_solve_options_;
163   }
164 
165   // Returns a serialized representation of this instruction.
166   HloInstructionProto ToProto() const override;
167 
168  private:
169   std::vector<string> ExtraAttributesToStringImpl(
170       const HloPrintOptions& options) const override;
171   bool IdenticalSlowPath(
172       const HloInstruction& other,
173       const std::function<bool(const HloComputation*, const HloComputation*)>&
174           eq_computations) const override;
175 
176   // Implementation for non-common logic of CloneWithNewOperands.
177   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
178       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
179       HloCloneContext* context) const override;
180 
181   TriangularSolveOptions triangular_solve_options_;
182 };
183 
184 class HloCholeskyInstruction : public HloInstruction {
185  public:
186   explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
187                                   const CholeskyOptions& options);
cholesky_options()188   const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
189 
190   // Returns a serialized representation of this instruction.
191   HloInstructionProto ToProto() const override;
192 
193  private:
194   std::vector<string> ExtraAttributesToStringImpl(
195       const HloPrintOptions& options) const override;
196   bool IdenticalSlowPath(
197       const HloInstruction& other,
198       const std::function<bool(const HloComputation*, const HloComputation*)>&
199           eq_computations) const override;
200 
201   // Implementation for non-common logic of CloneWithNewOperands.
202   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
203       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
204       HloCloneContext* context) const override;
205 
206   CholeskyOptions cholesky_options_;
207 };
208 
209 // Class that represents instructions that synchronize and transfer data between
210 // partitioned devices. Send/Recv and collective instructions (AllReduce,
211 // AllToAll, CollectivePermute) belong to this instruction type. A group of
212 // instructions (of the same opcode) with the same channel_id communicate during
213 // execution.
214 class HloChannelInstruction : public HloInstruction {
215  public:
216   // Returns the channel id associated with the instruction. The id is
217   // shared between each Send/Recv pair or a group of collective instructions
218   // and is globally unique to identify each channel.
channel_id()219   absl::optional<int64> channel_id() const { return channel_id_; }
220   void set_channel_id(const absl::optional<int64>& channel_id);
221 
222  protected:
223   explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
224                                  const absl::optional<int64>& channel_id);
225 
226   HloInstructionProto ToProto() const override;
227 
228   std::vector<string> ExtraAttributesToStringImpl(
229       const HloPrintOptions& options) const override;
230   bool IdenticalSlowPath(
231       const HloInstruction& other,
232       const std::function<bool(const HloComputation*, const HloComputation*)>&
233           eq_computations) const override;
234 
235   absl::optional<int64> channel_id_;
236 };
237 
238 class HloSendRecvInstruction : public HloChannelInstruction {
239  public:
240   // Returns whether this send/recv instruction sends data to/from the host.
is_host_transfer()241   bool is_host_transfer() const { return is_host_transfer_; }
242 
243   // Returns a serialized representation of this instruction.
244   HloInstructionProto ToProto() const override;
245 
246  protected:
247   explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
248                                   int64 channel_id, bool is_host_transfer);
249 
250  private:
251   std::vector<string> ExtraAttributesToStringImpl(
252       const HloPrintOptions& options) const override;
253   bool IdenticalSlowPath(
254       const HloInstruction& other,
255       const std::function<bool(const HloComputation*, const HloComputation*)>&
256           eq_computations) const override;
257   // Whether this send/recv instruction sends data to/from the host.
258   bool is_host_transfer_;
259 };
260 
261 class HloSendInstruction : public HloSendRecvInstruction {
262  public:
263   explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
264                               int64 channel_id, bool is_host_transfer);
265 
266  private:
267   // Implementation for non-common logic of CloneWithNewOperands.
268   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
269       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
270       HloCloneContext* context) const override;
271 };
272 
273 class HloSendDoneInstruction : public HloSendRecvInstruction {
274  public:
275   explicit HloSendDoneInstruction(HloSendInstruction* operand,
276                                   bool is_host_transfer);
277 
278  private:
279   // Implementation for non-common logic of CloneWithNewOperands.
280   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
281       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
282       HloCloneContext* context) const override;
283 };
284 
285 class HloRecvInstruction : public HloSendRecvInstruction {
286  public:
287   explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
288                               int64 channel_id, bool is_host_transfer);
289 
290  private:
291   // Implementation for non-common logic of CloneWithNewOperands.
292   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
293       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
294       HloCloneContext* context) const override;
295 };
296 
297 class HloRecvDoneInstruction : public HloSendRecvInstruction {
298  public:
299   explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
300                                   bool is_host_transfer);
301 
302  private:
303   // Implementation for non-common logic of CloneWithNewOperands.
304   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
305       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
306       HloCloneContext* context) const override;
307 };
308 
309 class HloCollectiveInstruction : public HloChannelInstruction {
310  public:
replica_groups()311   const std::vector<ReplicaGroup>& replica_groups() const {
312     return replica_groups_;
313   }
314 
315  protected:
316   explicit HloCollectiveInstruction(
317       HloOpcode opcode, const Shape& shape,
318       absl::Span<HloInstruction* const> operands,
319       const std::vector<ReplicaGroup>& replica_groups,
320       const absl::optional<int64>& channel_id);
321 
322   HloInstructionProto ToProto() const override;
323 
324   std::vector<string> ExtraAttributesToStringImpl(
325       const HloPrintOptions& options) const override;
326   bool IdenticalSlowPath(
327       const HloInstruction& other,
328       const std::function<bool(const HloComputation*, const HloComputation*)>&
329           eq_computations) const override;
330 
331   std::vector<ReplicaGroup> replica_groups_;
332 };
333 
334 class HloAllReduceInstruction : public HloCollectiveInstruction {
335  public:
336   explicit HloAllReduceInstruction(
337       const Shape& shape, absl::Span<HloInstruction* const> operands,
338       HloComputation* reduce_computation,
339       const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
340       const absl::optional<int64>& channel_id);
341 
342   // Returns true if the AllReduce does no communication, so it's equivalent
343   // to a mem copy.
344   bool IsNoop() const;
345 
346   // Returns true if the layout of the AllReduce is enforced by XLA client (as
347   // the layout set in the shape). The only reason for the client to set the
348   // layout is to separately compile computations that communicate with
349   // AllReduce. Since this field is only set `true` by the client, the compiler
350   // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
351   // `false` for all other cases.
352   //
353   // When this is `true`, there may be communication endpoints outside the
354   // current compilation unit, so the compiler considers this AllReduce as
355   // side-effecting to disable compiler transformations. The compiler is free to
356   // transform unconstrained AllReduces differently across compilation units.
357   // It is an error for an HloModule to have a mix of constrained and
358   // unconstrained AllReduce instructions (checked by HloVerifier).
constrain_layout()359   bool constrain_layout() const { return constrain_layout_; }
360 
361  protected:
362   std::vector<string> ExtraAttributesToStringImpl(
363       const HloPrintOptions& options) const override;
364   HloInstructionProto ToProto() const override;
365 
366  private:
367   bool IdenticalSlowPath(
368       const HloInstruction& other,
369       const std::function<bool(const HloComputation*, const HloComputation*)>&
370           eq_computations) const override;
371 
372   // Implementation for non-common logic of CloneWithNewOperands.
373   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
374       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
375       HloCloneContext* context) const override;
376 
377   bool constrain_layout_;
378 };
379 
380 class HloAllToAllInstruction : public HloCollectiveInstruction {
381  public:
382   explicit HloAllToAllInstruction(
383       const Shape& shape, absl::Span<HloInstruction* const> operands,
384       const std::vector<ReplicaGroup>& replica_groups,
385       const absl::optional<int64>& channel_id,
386       const absl::optional<int64>& split_dimension);
387 
388   // AllToAll can optionally take a split dimension, which means that this
389   // AllToAll takes a single (flattened) array operand and produces an array
390   // output (instead of taking a list of operands and producing a tuple).
391   //
392   // split_dimension specifies which dimension in the operand is split across
393   // devices in each replica_group, and also means the concatenated dimension
394   // on the output (i.e., input and the output shapes are the same).
split_dimension()395   absl::optional<int64> split_dimension() const { return split_dimension_; }
396 
397  protected:
398   std::vector<string> ExtraAttributesToStringImpl(
399       const HloPrintOptions& options) const override;
400   HloInstructionProto ToProto() const override;
401 
402  private:
403   bool IdenticalSlowPath(
404       const HloInstruction& other,
405       const std::function<bool(const HloComputation*, const HloComputation*)>&
406           eq_computations) const override;
407 
408   // Implementation for non-common logic of CloneWithNewOperands.
409   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
410       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
411       HloCloneContext* context) const override;
412 
413   absl::optional<int64> split_dimension_;
414 };
415 
416 class HloCollectivePermuteInstruction : public HloChannelInstruction {
417  public:
418   explicit HloCollectivePermuteInstruction(
419       const Shape& shape, HloInstruction* operand,
420       const std::vector<std::pair<int64, int64>>& source_target_pairs,
421       const absl::optional<int64>& channel_id);
422 
source_target_pairs()423   const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
424     return source_target_pairs_;
425   }
426 
427   // Returns a serialized representation of this instruction.
428   HloInstructionProto ToProto() const override;
429 
430  private:
431   std::vector<string> ExtraAttributesToStringImpl(
432       const HloPrintOptions& options) const override;
433   bool IdenticalSlowPath(
434       const HloInstruction& other,
435       const std::function<bool(const HloComputation*, const HloComputation*)>&
436           eq_computations) const override;
437 
438   // Implementation for non-common logic of CloneWithNewOperands.
439   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
440       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
441       HloCloneContext* context) const override;
442 
443   const std::vector<std::pair<int64, int64>> source_target_pairs_;
444 };
445 
446 class HloReverseInstruction : public HloInstruction {
447  public:
448   explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
449                                  absl::Span<const int64> dimensions);
450   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()451   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)452   int64 dimensions(int64 index) const override { return dimensions()[index]; }
453   // Returns a serialized representation of this instruction.
454   HloInstructionProto ToProto() const override;
455 
456  private:
457   std::vector<string> ExtraAttributesToStringImpl(
458       const HloPrintOptions& options) const override;
459   bool IdenticalSlowPath(
460       const HloInstruction& other,
461       const std::function<bool(const HloComputation*, const HloComputation*)>&
462           eq_computations) const override;
463   // Implementation for non-common logic of CloneWithNewOperands.
464   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
465       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
466       HloCloneContext* context) const override;
467 
468   std::vector<int64> dimensions_;
469 };
470 
471 class HloConcatenateInstruction : public HloInstruction {
472  public:
473   explicit HloConcatenateInstruction(const Shape& shape,
474                                      absl::Span<HloInstruction* const> operands,
475                                      int64 dimension);
476   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()477   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)478   int64 dimensions(int64 index) const override { return dimensions()[index]; }
479   // Accessor for the dimension in which a concatenate HLO should occur.
concatenate_dimension()480   int64 concatenate_dimension() const { return dimensions(0); }
481   // Returns a serialized representation of this instruction.
482   HloInstructionProto ToProto() const override;
483 
484  private:
485   std::vector<string> ExtraAttributesToStringImpl(
486       const HloPrintOptions& options) const override;
487   bool IdenticalSlowPath(
488       const HloInstruction& other,
489       const std::function<bool(const HloComputation*, const HloComputation*)>&
490           eq_computations) const override;
491   // Implementation for non-common logic of CloneWithNewOperands.
492   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
493       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
494       HloCloneContext* context) const override;
495 
496   std::vector<int64> dimensions_;
497 };
498 
499 class HloReduceInstruction : public HloInstruction {
500  public:
501   explicit HloReduceInstruction(const Shape& shape,
502                                 absl::Span<HloInstruction* const> args,
503                                 absl::Span<const int64> dimensions_to_reduce,
504                                 HloComputation* reduce_computation);
505   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()506   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)507   int64 dimensions(int64 index) const override { return dimensions()[index]; }
508   // Returns a serialized representation of this instruction.
509   HloInstructionProto ToProto() const override;
510 
511   // Returns the number of input arrays (and, consequentially, the number of
512   // init values) this reduce has.
input_count()513   int64 input_count() const { return operand_count() / 2; }
514 
515   // Returns the input tensors to be reduced.
inputs()516   absl::Span<HloInstruction* const> inputs() const {
517     return absl::MakeSpan(operands()).subspan(0, input_count());
518   }
519 
520   // Returns the init values of the reduction.
init_values()521   absl::Span<HloInstruction* const> init_values() const {
522     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
523   }
524 
525  private:
526   std::vector<string> ExtraAttributesToStringImpl(
527       const HloPrintOptions& options) const override;
528   bool IdenticalSlowPath(
529       const HloInstruction& other,
530       const std::function<bool(const HloComputation*, const HloComputation*)>&
531           eq_computations) const override;
532   // Implementation for non-common logic of CloneWithNewOperands.
533   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
534       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
535       HloCloneContext* context) const override;
536 
537   std::vector<int64> dimensions_;
538 };
539 
540 class HloSortInstruction : public HloInstruction {
541  public:
542   explicit HloSortInstruction(const Shape& shape, int64 dimension,
543                               absl::Span<HloInstruction* const> operands,
544                               HloComputation* compare, bool is_stable);
545   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()546   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)547   int64 dimensions(int64 index) const override { return dimensions()[index]; }
548   // Returns the sort dimension for this instruction
sort_dimension()549   int64 sort_dimension() const { return dimensions(0); }
550   // Returns a serialized representation of this instruction.
551   HloInstructionProto ToProto() const override;
552   // Returns the key operand to this instruction.
keys()553   const HloInstruction* keys() const { return operand(0); }
mutable_keys()554   HloInstruction* mutable_keys() { return mutable_operand(0); }
555   // Returns the number of value operands.
values_count()556   int64 values_count() const { return operand_count() - 1; }
is_stable()557   bool is_stable() const { return is_stable_; }
558 
559  private:
560   std::vector<string> ExtraAttributesToStringImpl(
561       const HloPrintOptions& options) const override;
562   bool IdenticalSlowPath(
563       const HloInstruction& other,
564       const std::function<bool(const HloComputation*, const HloComputation*)>&
565           eq_computations) const override;
566   // Implementation for non-common logic of CloneWithNewOperands.
567   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
568       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
569       HloCloneContext* context) const override;
570 
571   std::vector<int64> dimensions_;
572   bool is_stable_;
573 };
574 
575 class HloTransposeInstruction : public HloInstruction {
576  public:
577   explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
578                                    absl::Span<const int64> dimensions);
579   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()580   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)581   int64 dimensions(int64 index) const override { return dimensions()[index]; }
582   // Returns whether this instruction does a rank-2 transposition.
583   bool IsRank2Transpose() const;
584   // Returns a serialized representation of this instruction.
585   HloInstructionProto ToProto() const override;
586 
587  private:
588   std::vector<string> ExtraAttributesToStringImpl(
589       const HloPrintOptions& options) const override;
590   bool IdenticalSlowPath(
591       const HloInstruction& other,
592       const std::function<bool(const HloComputation*, const HloComputation*)>&
593           eq_computations) const override;
594   // Implementation for non-common logic of CloneWithNewOperands.
595   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
596       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
597       HloCloneContext* context) const override;
598 
599   std::vector<int64> dimensions_;
600 };
601 
602 class HloBroadcastInstruction : public HloInstruction {
603  public:
604   explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
605                                    absl::Span<const int64> broadcast_dimension);
606   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()607   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)608   int64 dimensions(int64 index) const override { return dimensions()[index]; }
609   // Returns a serialized representation of this instruction.
610   HloInstructionProto ToProto() const override;
611 
612  private:
613   std::vector<string> ExtraAttributesToStringImpl(
614       const HloPrintOptions& options) const override;
615   bool IdenticalSlowPath(
616       const HloInstruction& other,
617       const std::function<bool(const HloComputation*, const HloComputation*)>&
618           eq_computations) const override;
619   // Implementation for non-common logic of CloneWithNewOperands.
620   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
621       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
622       HloCloneContext* context) const override;
623 
624   std::vector<int64> dimensions_;
625 };
626 
627 class HloReshapeInstruction : public HloInstruction {
628  public:
629   explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
630                                  int64 inferred_dimension);
inferred_dimension()631   int64 inferred_dimension() const { return inferred_dimension_; }
632   HloInstructionProto ToProto() const override;
633 
634  private:
635   std::vector<string> ExtraAttributesToStringImpl(
636       const HloPrintOptions& options) const override;
637   bool IdenticalSlowPath(
638       const HloInstruction& other,
639       const std::function<bool(const HloComputation*, const HloComputation*)>&
640           eq_computations) const override;
641   // Implementation for non-common logic of CloneWithNewOperands.
642   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
643       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
644       HloCloneContext* context) const override;
645   int64 inferred_dimension_;
646 };
647 
648 class HloMapInstruction : public HloInstruction {
649  public:
650   explicit HloMapInstruction(const Shape& shape,
651                              absl::Span<HloInstruction* const> operands,
652                              HloComputation* map_computation);
653   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()654   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)655   int64 dimensions(int64 index) const override { return dimensions()[index]; }
656   // Returns a serialized representation of this instruction.
657   HloInstructionProto ToProto() const override;
658 
659  private:
660   bool IsElementwiseImpl(
661       const absl::optional<int64>& operand_idx) const override;
662   std::vector<string> ExtraAttributesToStringImpl(
663       const HloPrintOptions& options) const override;
664   bool IdenticalSlowPath(
665       const HloInstruction& other,
666       const std::function<bool(const HloComputation*, const HloComputation*)>&
667           eq_computations) const override;
668   // Implementation for non-common logic of CloneWithNewOperands.
669   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
670       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
671       HloCloneContext* context) const override;
672 
673   std::vector<int64> dimensions_;
674 };
675 
676 class HloSliceInstruction : public HloInstruction {
677  public:
678   explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
679                                absl::Span<const int64> start_indices,
680                                absl::Span<const int64> limit_indices,
681                                absl::Span<const int64> strides);
682 
683   HloInstructionProto ToProto() const override;
684 
685   // Returns the start index in the given dimension for a slice node.
slice_starts(int64 dimension)686   int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; }
slice_starts()687   const std::vector<int64>& slice_starts() const { return slice_starts_; }
688 
689   // Returns the (exclusive) limit index in the given dimension for a slice
690   // node.
slice_limits(int64 dimension)691   int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; }
slice_limits()692   const std::vector<int64>& slice_limits() const { return slice_limits_; }
693 
694   // Returns the stride in the given dimension for a slice node.
slice_strides(int64 dimension)695   int64 slice_strides(int64 dimension) const {
696     return slice_strides_[dimension];
697   }
slice_strides()698   const std::vector<int64>& slice_strides() const { return slice_strides_; }
699 
700  private:
701   std::vector<string> ExtraAttributesToStringImpl(
702       const HloPrintOptions& options) const override;
703   bool IdenticalSlowPath(
704       const HloInstruction& other,
705       const std::function<bool(const HloComputation*, const HloComputation*)>&
706           eq_computations) const override;
707   // Implementation for non-common logic of CloneWithNewOperands.
708   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
709       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
710       HloCloneContext* context) const override;
711 
712   // Describes the [begin, end) index range for a slice.
713   std::vector<int64> slice_starts_;
714   std::vector<int64> slice_limits_;
715   std::vector<int64> slice_strides_;
716 };
717 
718 class HloConstantInstruction : public HloInstruction {
719  public:
720   explicit HloConstantInstruction(Literal literal);
721   explicit HloConstantInstruction(Literal literal, const Shape& shape);
722   // Used when the literal is too large and dropped.
723   explicit HloConstantInstruction(const Shape& shape);
724   // Returns the literal associated with this instruction.
literal()725   const Literal& literal() const { return *literal_; }
726   // Returns whether there is literal associated with this instruction.
HasLiteral()727   bool HasLiteral() const { return literal_.has_value(); }
728   // Returns a serialized representation of this instruction.
729   HloInstructionProto ToProto() const override;
730 
731   // Change the layout for an Constant Hlo instruction to match new_layout.  For
732   // tuple shaped constants shape_index is the path to the internal array
733   // subshape whose layout needs to be changed.
734   void RelayoutConstant(const Layout& new_layout,
735                         const ShapeIndex& shape_index = {});
736 
737  private:
738   bool IsElementwiseImpl(
739       const absl::optional<int64>& operand_idx) const override;
740   bool IdenticalSlowPath(
741       const HloInstruction& other,
742       const std::function<bool(const HloComputation*, const HloComputation*)>&
743           eq_computations) const override;
744   string OperandsToStringWithCanonicalNameMap(
745       const HloPrintOptions& options,
746       CanonicalNameMap* canonical_name_map) const override;
747   // Implementation for non-common logic of CloneWithNewOperands.
748   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
749       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
750       HloCloneContext* context) const override;
751   absl::optional<Literal> literal_;
752 };
753 
754 class HloTraceInstruction : public HloInstruction {
755  public:
756   explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
757   // Returns a tag to be used in tracing.
TracingTag()758   string TracingTag() const { return literal_.GetR1U8AsString(); }
759   // Returns a serialized representation of this instruction.
760   HloInstructionProto ToProto() const override;
761 
762  private:
763   bool IdenticalSlowPath(
764       const HloInstruction& other,
765       const std::function<bool(const HloComputation*, const HloComputation*)>&
766           eq_computations) const override;
767   // Implementation for non-common logic of CloneWithNewOperands.
768   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
769       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
770       HloCloneContext* context) const override;
771   Literal literal_;
772 };
773 
774 class HloFusionInstruction : public HloInstruction {
775  public:
776   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
777                                 HloInstruction* fused_root);
778 
779   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
780                                 absl::Span<HloInstruction* const> operands,
781                                 HloComputation* fusion_computation);
782 
783   string ToCategory() const override;
784   // Returns a serialized representation of this instruction.
785   HloInstructionProto ToProto() const override;
786 
787   // Adds a new operand the fusion instruction.
788   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
789 
790   // Merges the fused instructions from 'instruction_to_merge' into the
791   // fused instruction set of 'this', updating operands as necessary.
792   //
793   // Precondition: 'instruction_to_merge' must be an operand of 'this'.
794   void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
795 
796   // Merges the fused instructions from instruction_to_merge into the fused
797   // instruction set of 'this' and generates multioutput fusion instructions.
798   // All the users of instruction_to_merge will be redirected to 'this'
799   // instruction. instruction_to_merge will be removed from its parent
800   // computation.
801   void MergeFusionInstructionIntoMultiOutput(
802       HloFusionInstruction* instruction_to_merge);
803 
804   // Fuses the given instruction in this fusion instruction. instruction_to_fuse
805   // is cloned and the clone is placed in the fusion
806   // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
807   // than moved to cleanly handle the case where the instruction has a use
808   // outside the fusion instruction. Moving such an instruction into a fusion
809   // instruction would violate the single-result invariant of HLO instructions
810   // and significantly complicate code generation.
FuseInstruction(HloInstruction * instruction_to_fuse)811   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
812     return FuseInstructionInternal(instruction_to_fuse);
813   }
814 
815   // Fuses the given instruction in this fusion instruction and generates a
816   // multioutput fusion instruction. A clone of the instruction_to_fuse will
817   // be part of the output of fusion instructions. The users of
818   // instruction_to_fuse will be redirected to this fusion instructions.
819   // instruction_to_fuse is unchanged otherwise.
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)820   HloInstruction* FuseInstructionIntoMultiOutput(
821       HloInstruction* instruction_to_fuse) {
822     return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
823   }
824 
825   // Returns the computation for this fused instruction.
826   HloComputation* fused_instructions_computation() const;
827 
828   // Returns the root instruction of the fused expression contained within this
829   // fusion instruction.
830   HloInstruction* fused_expression_root() const;
831 
832   // Returns the list of fused instructions inside this fusion instruction.  The
833   // returned type is a range of HloInstruction*s.
834   const tensorflow::gtl::iterator_range<UnwrappingIterator<
835       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
836   fused_instructions() const;
837 
838   const tensorflow::gtl::iterator_range<
839       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
840   fused_instructions();
841 
842   // Gets the number of instructions inside this fusion instruction.
843   int64 fused_instruction_count() const;
844 
845   // Returns the fused parameter instruction in this fusion instruction
846   // corresponding to the given parameter number.
847   HloInstruction* fused_parameter(int64 parameter_number) const;
848 
849   // Returns the vector of fused parameters inside this fusion instruction.
850   const std::vector<HloInstruction*>& fused_parameters() const;
851 
852   // Returns true if this instruction is a fusion instruction that generates
853   // multiple outputs.
IsMultiOutputFusion()854   const bool IsMultiOutputFusion() const {
855     return fused_expression_root()->opcode() == HloOpcode::kTuple;
856   }
857 
fusion_kind()858   FusionKind fusion_kind() const { return fusion_kind_; }
859 
set_fusion_kind(FusionKind kind)860   void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
861 
862   // If multiple operands are the same instruction, keeps only one of them.
863   Status DeduplicateFusionOperands();
864 
865  private:
866   // Fuses the given instruction into this fusion instruction.
867   // instruction_to_fuse is cloned and the clone is placed in the fusion
868   // instruction.  The users of instruction_to_fuse will be redirected to this
869   // fusion instruction. instruction_to_fuse is unchanged otherwise. When
870   // add_output is true, a clone of the instruction_to_fuse will be added as
871   // additional output resulting in a multi-output fusion.
872   HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
873                                           bool add_output = false);
874   // Clones the given instruction_to_fuse and insert the clone into this fusion
875   // instruction. If add_output is true, a clone of instruction_to_fuse will
876   // be in the output of the this fusion instruction (part of the tuple of the
877   // fusion root).
878   HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
879                                        bool add_output = false);
880 
881   bool IsElementwiseImpl(
882       const absl::optional<int64>& operand_idx) const override;
883   std::vector<string> ExtraAttributesToStringImpl(
884       const HloPrintOptions& options) const override;
885   bool IdenticalSlowPath(
886       const HloInstruction& other,
887       const std::function<bool(const HloComputation*, const HloComputation*)>&
888           eq_computations) const override;
889   uint64 InnerHash() const override;
890 
891   // Implementation for non-common logic of CloneWithNewOperands.
892   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
893       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
894       HloCloneContext* context) const override;
895 
896   // The type of the fusion. Used by kFusion only.
897   FusionKind fusion_kind_;
898 };
899 
900 class HloRngInstruction : public HloInstruction {
901  public:
902   explicit HloRngInstruction(const Shape& shape,
903                              RandomDistribution distribution,
904                              absl::Span<HloInstruction* const> parameters);
905   // Returns the random distribution for this rng node.
random_distribution()906   RandomDistribution random_distribution() const { return distribution_; }
907   // Returns a serialized representation of this instruction.
908   HloInstructionProto ToProto() const override;
909 
910  private:
911   bool IsElementwiseImpl(
912       const absl::optional<int64>& operand_idx) const override;
913   std::vector<string> ExtraAttributesToStringImpl(
914       const HloPrintOptions& options) const override;
915   bool IdenticalSlowPath(
916       const HloInstruction& other,
917       const std::function<bool(const HloComputation*, const HloComputation*)>&
918           eq_computations) const override;
919   // Implementation for non-common logic of CloneWithNewOperands.
920   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
921       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
922       HloCloneContext* context) const override;
923 
924   // The distribution requested for random number generation.
925   RandomDistribution distribution_;
926 };
927 
928 class HloParameterInstruction : public HloInstruction {
929  public:
930   explicit HloParameterInstruction(int64 parameter_number, const Shape& shape,
931                                    const string& name);
parameter_number()932   int64 parameter_number() const { return parameter_number_; }
933 
934   // Sets and gets the whether all replicas will receive the same parameter data
935   // for each leaf buffer in data parallelism.
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)936   void set_parameter_replicated_at_leaf_buffers(
937       absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
938     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
939              parameter_replicated_at_leaf_buffers.size());
940     parameter_replicated_at_leaf_buffers_.emplace(
941         parameter_replicated_at_leaf_buffers.begin(),
942         parameter_replicated_at_leaf_buffers.end());
943   }
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)944   void set_parameter_replicated_at_leaf_buffers(
945       const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
946     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
947              parameter_replicated_at_leaf_buffers.size());
948     parameter_replicated_at_leaf_buffers_ =
949         parameter_replicated_at_leaf_buffers;
950   }
951   const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers()952   parameter_replicated_at_leaf_buffers() const {
953     return parameter_replicated_at_leaf_buffers_;
954   }
955 
956   // Returns a serialized representation of this instruction.
957   HloInstructionProto ToProto() const override;
958 
959  private:
960   std::vector<string> ExtraAttributesToStringImpl(
961       const HloPrintOptions& options) const override;
962   bool IdenticalSlowPath(
963       const HloInstruction& other,
964       const std::function<bool(const HloComputation*, const HloComputation*)>&
965           eq_computations) const override;
966   string OperandsToStringWithCanonicalNameMap(
967       const HloPrintOptions& options,
968       CanonicalNameMap* canonical_name_map) const override;
969   // Implementation for non-common logic of CloneWithNewOperands.
970   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
971       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
972       HloCloneContext* context) const override;
973 
974   int64 parameter_number_ = 0;
975 
976   // Specifies whether each buffer has the same parameter value on all replicas
977   // in data parallelism.
978   absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
979 };
980 
981 class HloGetTupleElementInstruction : public HloInstruction {
982  public:
983   explicit HloGetTupleElementInstruction(const Shape& shape,
984                                          HloInstruction* operand, int64 index);
985   // Returns the tuple index associated with this instruction.
tuple_index()986   int64 tuple_index() const { return tuple_index_; }
987   // Sets the tuple index associated with this instruction.
set_tuple_index(int64 new_tuple_index)988   void set_tuple_index(int64 new_tuple_index) {
989     tuple_index_ = new_tuple_index;
990   }
991   // Returns a serialized representation of this instruction.
992   HloInstructionProto ToProto() const override;
993 
994  private:
995   std::vector<string> ExtraAttributesToStringImpl(
996       const HloPrintOptions& options) const override;
997   bool IdenticalSlowPath(
998       const HloInstruction& other,
999       const std::function<bool(const HloComputation*, const HloComputation*)>&
1000           eq_computations) const override;
1001   // Implementation for non-common logic of CloneWithNewOperands.
1002   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1003       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1004       HloCloneContext* context) const override;
1005 
1006   int64 tuple_index_ = -1;
1007 };
1008 
1009 class HloReducePrecisionInstruction : public HloInstruction {
1010  public:
1011   explicit HloReducePrecisionInstruction(const Shape& shape,
1012                                          HloInstruction* operand,
1013                                          const int exponent_bits,
1014                                          const int mantissa_bits);
1015   // Returns the number of exponent bits for a reduce-precision node.
exponent_bits()1016   int32 exponent_bits() const { return exponent_bits_; }
1017   // Returns the number of mantissa bits for a reduce-precision node.
mantissa_bits()1018   int32 mantissa_bits() const { return mantissa_bits_; }
1019   // Returns a serialized representation of this instruction.
1020   HloInstructionProto ToProto() const override;
1021 
1022  private:
1023   std::vector<string> ExtraAttributesToStringImpl(
1024       const HloPrintOptions& options) const override;
1025   bool IdenticalSlowPath(
1026       const HloInstruction& other,
1027       const std::function<bool(const HloComputation*, const HloComputation*)>&
1028           eq_computations) const override;
1029   // Implementation for non-common logic of CloneWithNewOperands.
1030   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1031       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1032       HloCloneContext* context) const override;
1033 
1034   // The bit sizes for a reduce-precision operation.
1035   int32 exponent_bits_ = 0;
1036   int32 mantissa_bits_ = 0;
1037 };
1038 
1039 class HloInfeedInstruction : public HloInstruction {
1040  public:
1041   explicit HloInfeedInstruction(const Shape& infeed_shape,
1042                                 HloInstruction* token_operand,
1043                                 const string& config);
1044   // Returns the infeed configuration string. The infeed configuration includes
1045   // any metadata needed for the backend compiler (e.g., infeed buffer address)
1046   // and is target-dependent.
infeed_config()1047   string infeed_config() const { return infeed_config_; }
set_infeed_config(const string & config)1048   void set_infeed_config(const string& config) { infeed_config_ = config; }
1049   // Returns the shape of the data received by the infeed. This is not the same
1050   // as the shape of the infeed instruction which produces a tuple containing
1051   // the infeed data shape and a TOKEN.
infeed_shape()1052   const Shape& infeed_shape() const {
1053     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
1054     return ShapeUtil::GetSubshape(shape(), {0});
1055   }
1056   // Returns a serialized representation of this instruction.
1057   HloInstructionProto ToProto() const override;
1058 
1059  private:
1060   std::vector<string> ExtraAttributesToStringImpl(
1061       const HloPrintOptions& options) const override;
1062   bool IdenticalSlowPath(
1063       const HloInstruction& other,
1064       const std::function<bool(const HloComputation*, const HloComputation*)>&
1065           eq_computations) const override;
1066   // Implementation for non-common logic of CloneWithNewOperands.
1067   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1068       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1069       HloCloneContext* context) const override;
1070 
1071   // The string representation of the infeed configuration.
1072   string infeed_config_;
1073 };
1074 
1075 class HloOutfeedInstruction : public HloInstruction {
1076  public:
1077   explicit HloOutfeedInstruction(const Shape& outfeed_shape,
1078                                  HloInstruction* operand,
1079                                  HloInstruction* token_operand,
1080                                  absl::string_view outfeed_config);
1081   // Returns the shape for the Outfeed instruction.
outfeed_shape()1082   const Shape& outfeed_shape() const { return outfeed_shape_; }
1083   // Returns the config for the Outfeed instruction.
outfeed_config()1084   const string& outfeed_config() const { return outfeed_config_; }
1085   // Returns a serialized representation of this instruction.
1086   HloInstructionProto ToProto() const override;
1087 
1088  private:
1089   std::vector<string> ExtraAttributesToStringImpl(
1090       const HloPrintOptions& options) const override;
1091   bool IdenticalSlowPath(
1092       const HloInstruction& other,
1093       const std::function<bool(const HloComputation*, const HloComputation*)>&
1094           eq_computations) const override;
1095   // Implementation for non-common logic of CloneWithNewOperands.
1096   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1097       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1098       HloCloneContext* context) const override;
1099 
1100   // Shape of outfeed request.
1101   Shape outfeed_shape_;
1102   // Outfeed configuration information, only present for kOutfeed.
1103   string outfeed_config_;
1104 };
1105 
1106 class HloConvolutionInstruction : public HloInstruction {
1107  public:
1108   explicit HloConvolutionInstruction(
1109       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1110       int64 feature_group_count, int64 batch_group_count, const Window& window,
1111       const ConvolutionDimensionNumbers& dimension_numbers,
1112       const PrecisionConfig& precision_config);
window()1113   const Window& window() const override { return window_; }
set_window(const Window & window)1114   void set_window(const Window& window) override { window_ = window; }
convolution_dimension_numbers()1115   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1116     return convolution_dimension_numbers_;
1117   }
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1118   void set_convolution_dimension_numbers(
1119       const ConvolutionDimensionNumbers& dnums) {
1120     convolution_dimension_numbers_ = dnums;
1121   }
1122   // The number of feature groups. Must be a divisor of the input feature
1123   // dimension and output feature dimension.
feature_group_count()1124   int64 feature_group_count() const { return feature_group_count_; }
set_feature_group_count(int64 num_feature_groups)1125   void set_feature_group_count(int64 num_feature_groups) {
1126     feature_group_count_ = num_feature_groups;
1127   }
1128   // The number of batch groups. Must be a divisor of the input batch dimension.
batch_group_count()1129   int64 batch_group_count() const { return batch_group_count_; }
set_batch_group_count(int64 num_batch_groups)1130   void set_batch_group_count(int64 num_batch_groups) {
1131     batch_group_count_ = num_batch_groups;
1132   }
1133 
1134   // Returns the information used to tell the implementation information about
1135   // what sort of precision is requested. The meaning of the field is backend
1136   // specific. At the moment, it is only supported for kConvolution and kDot.
1137   // Transformations on one kDot or kConvolution to another will preserve this
1138   // information. Transformations to other HLOs will not preserve this
1139   // information but it is presumed that the alternate lowering is strictly
1140   // superior.
precision_config()1141   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1142   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1143 
1144   string ToCategory() const override;
1145   // Returns a serialized representation of this instruction.
1146   HloInstructionProto ToProto() const override;
1147 
1148  private:
1149   std::vector<string> ExtraAttributesToStringImpl(
1150       const HloPrintOptions& options) const override;
1151   bool IdenticalSlowPath(
1152       const HloInstruction& other,
1153       const std::function<bool(const HloComputation*, const HloComputation*)>&
1154           eq_computations) const override;
1155   // Implementation for non-common logic of CloneWithNewOperands.
1156   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1157       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1158       HloCloneContext* context) const override;
1159   // The number of feature groups. Must be a divisor of the input feature
1160   // dimension and output feature dimension.
1161   int64 feature_group_count_;
1162   // The number of batch groups. Must be a divisor of the input batch dimension.
1163   int64 batch_group_count_;
1164   // Describes the window used for a convolution.
1165   Window window_;
1166   // Describes the dimension numbers used for a convolution.
1167   ConvolutionDimensionNumbers convolution_dimension_numbers_;
1168   // Information used to communicate to the implementation about the algorithm
1169   // used to produce results. See the documentation on precision_config().
1170   PrecisionConfig precision_config_;
1171 };
1172 
1173 class HloReduceWindowInstruction : public HloInstruction {
1174  public:
1175   explicit HloReduceWindowInstruction(const Shape& shape,
1176                                       HloInstruction* operand,
1177                                       HloInstruction* init_value,
1178                                       const Window& window,
1179                                       HloComputation* reduce_computation);
window()1180   const Window& window() const override { return window_; }
set_window(const Window & window)1181   void set_window(const Window& window) override { window_ = window; }
1182   // Returns a serialized representation of this instruction.
1183   HloInstructionProto ToProto() const override;
1184 
1185  private:
1186   std::vector<string> ExtraAttributesToStringImpl(
1187       const HloPrintOptions& options) const override;
1188   bool IdenticalSlowPath(
1189       const HloInstruction& other,
1190       const std::function<bool(const HloComputation*, const HloComputation*)>&
1191           eq_computations) const override;
1192   // Implementation for non-common logic of CloneWithNewOperands.
1193   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1194       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1195       HloCloneContext* context) const override;
1196   Window window_;
1197 };
1198 
1199 class HloSelectAndScatterInstruction : public HloInstruction {
1200  public:
1201   explicit HloSelectAndScatterInstruction(
1202       const Shape& shape, HloInstruction* operand, HloComputation* select,
1203       const Window& window, HloInstruction* source, HloInstruction* init_value,
1204       HloComputation* scatter);
window()1205   const Window& window() const override { return window_; }
set_window(const Window & window)1206   void set_window(const Window& window) override { window_ = window; }
1207   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
1208   // setters should only be called by HloModule or HloComputation methods.
select()1209   HloComputation* select() const {
1210     return called_computations()[kSelectComputationIndex];
1211   }
1212 
scatter()1213   HloComputation* scatter() const {
1214     return called_computations()[kScatterComputationIndex];
1215   }
1216 
set_select(HloComputation * computation)1217   void set_select(HloComputation* computation) {
1218     // Don't allow changing the computation for fused instructions so we don't
1219     // have to recompute called_instructions for the entire fusion instruction.
1220     CHECK(!IsFused());
1221     set_called_computation(kSelectComputationIndex, computation);
1222   }
1223 
set_scatter(HloComputation * computation)1224   void set_scatter(HloComputation* computation) {
1225     // Don't allow changing the computation for fused instructions so we don't
1226     // have to recompute called_instructions for the entire fusion instruction.
1227     CHECK(!IsFused());
1228     set_called_computation(kScatterComputationIndex, computation);
1229   }
1230   // Returns a serialized representation of this instruction.
1231   HloInstructionProto ToProto() const override;
1232 
1233  private:
1234   std::vector<string> ExtraAttributesToStringImpl(
1235       const HloPrintOptions& options) const override;
1236   bool IdenticalSlowPath(
1237       const HloInstruction& other,
1238       const std::function<bool(const HloComputation*, const HloComputation*)>&
1239           eq_computations) const override;
1240   // Implementation for non-common logic of CloneWithNewOperands.
1241   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1242       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1243       HloCloneContext* context) const override;
1244   Window window_;
1245 };
1246 
1247 class HloCustomCallInstruction : public HloInstruction {
1248  public:
1249   HloCustomCallInstruction(const Shape& shape,
1250                            absl::Span<HloInstruction* const> operands,
1251                            absl::string_view custom_call_target, string opaque);
1252 
1253   // Constructor for a custom call with constrained layout. 'shape' and
1254   // 'operands_with_layout' must all have layouts.
1255   HloCustomCallInstruction(const Shape& shape,
1256                            absl::Span<HloInstruction* const> operands,
1257                            absl::string_view custom_call_target, string opaque,
1258                            absl::Span<const Shape> operand_shapes_with_layout);
1259 
window()1260   const Window& window() const override {
1261     CHECK(window_ != nullptr);
1262     return *window_;
1263   }
1264 
set_window(const Window & window)1265   void set_window(const Window& window) override {
1266     window_ = absl::make_unique<Window>(window);
1267   }
1268 
convolution_dimension_numbers()1269   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1270     CHECK(convolution_dimension_numbers_ != nullptr);
1271     return *convolution_dimension_numbers_;
1272   }
1273 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1274   void set_convolution_dimension_numbers(
1275       const ConvolutionDimensionNumbers& dnums) {
1276     convolution_dimension_numbers_ =
1277         absl::make_unique<ConvolutionDimensionNumbers>(dnums);
1278   }
1279   // TODO(jpienaar): Remove this accessor in the follow up.
opaque()1280   const string& opaque() const { return raw_backend_config_string(); }
custom_call_target()1281   const string& custom_call_target() const { return custom_call_target_; }
set_feature_group_count(int64 feature_group_count)1282   void set_feature_group_count(int64 feature_group_count) {
1283     feature_group_count_ = feature_group_count;
1284   }
set_batch_group_count(int64 batch_group_count)1285   void set_batch_group_count(int64 batch_group_count) {
1286     batch_group_count_ = batch_group_count;
1287   }
1288   // Sets whether this custom call has a side-effect - by default a custom call
1289   // has no side-effects.
set_custom_call_has_side_effect(bool custom_call_has_side_effect)1290   void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
1291     custom_call_has_side_effect_ = custom_call_has_side_effect;
1292   }
feature_group_count()1293   int64 feature_group_count() const { return feature_group_count_; }
batch_group_count()1294   int64 batch_group_count() const { return batch_group_count_; }
custom_call_has_side_effect()1295   bool custom_call_has_side_effect() const {
1296     return custom_call_has_side_effect_;
1297   }
1298   // Returns a serialized representation of this instruction.
1299   HloInstructionProto ToProto() const override;
1300 
1301   // Returns whether the result and operand layouts are constrained.
layout_constrained()1302   bool layout_constrained() const { return layout_constrained_; }
1303 
1304   // Returns the shapes (with layout) of the operands. CHECKs if this custom
1305   // call does not have constrained layouts.
operand_shapes_with_layout()1306   const std::vector<Shape>& operand_shapes_with_layout() const {
1307     CHECK(layout_constrained());
1308     return operand_shapes_with_layout_;
1309   }
1310 
1311  private:
1312   std::vector<string> ExtraAttributesToStringImpl(
1313       const HloPrintOptions& options) const override;
1314   bool IdenticalSlowPath(
1315       const HloInstruction& other,
1316       const std::function<bool(const HloComputation*, const HloComputation*)>&
1317           eq_computations) const override;
1318   // Implementation for non-common logic of CloneWithNewOperands.
1319   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1320       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1321       HloCloneContext* context) const override;
1322   // Name of a global symbol to call.
1323   string custom_call_target_;
1324   // Describes the window in a windowed operation such as convolution.
1325   std::unique_ptr<Window> window_;
1326   // Describes the dimension numbers used for a convolution.
1327   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1328   // The number of feature groups. This is used for grouped convolutions.
1329   int64 feature_group_count_;
1330   int64 batch_group_count_;
1331   // Whether the result and operand layouts are constrained.
1332   bool layout_constrained_;
1333   // For layout-constrained custom calls, this vector holds the shape with
1334   // layout for each operand.
1335   std::vector<Shape> operand_shapes_with_layout_;
1336   // Whether this custom call has a side-effect.
1337   bool custom_call_has_side_effect_;
1338 };
1339 
1340 class HloPadInstruction : public HloInstruction {
1341  public:
1342   explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
1343                              HloInstruction* padding_value,
1344                              const PaddingConfig& padding_config);
1345   // Returns the padding configuration for a pad node.
padding_config()1346   const PaddingConfig& padding_config() const { return padding_config_; }
1347   // Returns the padding value.
padding_value()1348   const HloInstruction* padding_value() const { return operand(1); }
mutable_padding_value()1349   HloInstruction* mutable_padding_value() { return mutable_operand(1); }
1350   // Returns a serialized representation of this instruction.
1351   HloInstructionProto ToProto() const override;
1352 
1353  private:
1354   std::vector<string> ExtraAttributesToStringImpl(
1355       const HloPrintOptions& options) const override;
1356   bool IdenticalSlowPath(
1357       const HloInstruction& other,
1358       const std::function<bool(const HloComputation*, const HloComputation*)>&
1359           eq_computations) const override;
1360   // Implementation for non-common logic of CloneWithNewOperands.
1361   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1362       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1363       HloCloneContext* context) const override;
1364 
1365   // The padding configuration that describes the edge padding and interior
1366   // padding of this pad instruction.
1367   PaddingConfig padding_config_;
1368 };
1369 
1370 class HloDynamicIndexInstruction : public HloInstruction {
1371  public:
HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1372   explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
1373       : HloInstruction(opcode, shape) {}
1374   virtual int64 first_index_operand_number() const = 0;
1375 
1376   // Returns a subspan of operands which represent the start indices.
index_operands()1377   absl::Span<HloInstruction* const> index_operands() const {
1378     return absl::MakeSpan(operands()).subspan(first_index_operand_number());
1379   }
1380 
1381   // Returns the shapes of the index operands.
index_shapes()1382   std::vector<Shape> index_shapes() const {
1383     std::vector<Shape> shapes;
1384     auto indices = index_operands();
1385     for (const HloInstruction* index : indices) {
1386       shapes.push_back(index->shape());
1387     }
1388     return shapes;
1389   }
1390 };
1391 
1392 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
1393  public:
1394   explicit HloDynamicSliceInstruction(const Shape& shape,
1395                                       HloInstruction* operand,
1396                                       HloInstruction* start_indices,
1397                                       absl::Span<const int64> slice_sizes);
1398   explicit HloDynamicSliceInstruction(
1399       const Shape& shape, HloInstruction* operand,
1400       absl::Span<HloInstruction* const> start_indices,
1401       absl::Span<const int64> slice_sizes);
1402   // Old methods kept for smooth subclassing transition END.
1403   // Returns the size of the slice in the given dimension for a dynamic
1404   // slice node.
slice_sizes(int64 dimension)1405   int64 slice_sizes(int64 dimension) const {
1406     return dynamic_slice_sizes_[dimension];
1407   }
dynamic_slice_sizes()1408   const std::vector<int64>& dynamic_slice_sizes() const {
1409     return dynamic_slice_sizes_;
1410   }
1411   // Returns a serialized representation of this instruction.
1412   HloInstructionProto ToProto() const override;
1413 
first_index_operand_number()1414   int64 first_index_operand_number() const override { return 1; }
1415 
1416  private:
1417   std::vector<string> ExtraAttributesToStringImpl(
1418       const HloPrintOptions& options) const override;
1419   bool IdenticalSlowPath(
1420       const HloInstruction& other,
1421       const std::function<bool(const HloComputation*, const HloComputation*)>&
1422           eq_computations) const override;
1423   // Implementation for non-common logic of CloneWithNewOperands.
1424   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1425       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1426       HloCloneContext* context) const override;
1427 
1428   // Describes the [start, start + size) range size for a dynamic slice
1429   // ('start' is specified dynamically in the second operand of the operation).
1430   std::vector<int64> dynamic_slice_sizes_;
1431 };
1432 
1433 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
1434  public:
1435   explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
1436                                             HloInstruction* operand,
1437                                             HloInstruction* update,
1438                                             HloInstruction* start_indices);
1439   explicit HloDynamicUpdateSliceInstruction(
1440       const Shape& shape, HloInstruction* operand, HloInstruction* update,
1441       absl::Span<HloInstruction* const> start_indices);
1442 
first_index_operand_number()1443   int64 first_index_operand_number() const override { return 2; }
1444 };
1445 
1446 class HloGatherInstruction : public HloInstruction {
1447  public:
1448   explicit HloGatherInstruction(
1449       const Shape& shape, HloInstruction* operand,
1450       HloInstruction* start_indices,
1451       const GatherDimensionNumbers& gather_dim_numbers,
1452       absl::Span<const int64> slice_sizes, bool indices_are_sorted);
gather_dimension_numbers()1453   const GatherDimensionNumbers& gather_dimension_numbers() const {
1454     CHECK(gather_dimension_numbers_ != nullptr);
1455     return *gather_dimension_numbers_;
1456   }
gather_slice_sizes()1457   absl::Span<const int64> gather_slice_sizes() const {
1458     return gather_slice_sizes_;
1459   }
indices_are_sorted()1460   bool indices_are_sorted() const { return indices_are_sorted_; }
set_indices_are_sorted(bool indices_are_sorted)1461   void set_indices_are_sorted(bool indices_are_sorted) {
1462     indices_are_sorted_ = indices_are_sorted;
1463   }
1464   // Returns a serialized representation of this instruction.
1465   HloInstructionProto ToProto() const override;
1466 
1467   // Creates an instance of GatherDimensionNumbers.
1468   static GatherDimensionNumbers MakeGatherDimNumbers(
1469       absl::Span<const int64> offset_dims,
1470       absl::Span<const int64> collapsed_slice_dims,
1471       absl::Span<const int64> start_index_map, int64 index_vector_dim);
1472   // Returns the dump string of the given gather dimension numbers.
1473   static string GatherDimensionNumbersToString(
1474       const GatherDimensionNumbers& gather_dimension_numbers);
1475 
1476  private:
1477   std::vector<string> ExtraAttributesToStringImpl(
1478       const HloPrintOptions& options) const override;
1479   bool IdenticalSlowPath(
1480       const HloInstruction& other,
1481       const std::function<bool(const HloComputation*, const HloComputation*)>&
1482           eq_computations) const override;
1483   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1484       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1485       HloCloneContext* context) const override;
1486 
1487   std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1488   std::vector<int64> gather_slice_sizes_;
1489   bool indices_are_sorted_;
1490 };
1491 
1492 class HloScatterInstruction : public HloInstruction {
1493  public:
1494   explicit HloScatterInstruction(
1495       const Shape& shape, HloInstruction* operand,
1496       HloInstruction* scatter_indices, HloInstruction* updates,
1497       HloComputation* update_computation,
1498       const ScatterDimensionNumbers& scatter_dim_numbers,
1499       bool indices_are_sorted, bool unique_indices);
scatter_dimension_numbers()1500   const ScatterDimensionNumbers& scatter_dimension_numbers() const {
1501     CHECK(scatter_dimension_numbers_ != nullptr);
1502     return *scatter_dimension_numbers_;
1503   }
indices_are_sorted()1504   bool indices_are_sorted() const { return indices_are_sorted_; }
set_indices_are_sorted(bool indices_are_sorted)1505   void set_indices_are_sorted(bool indices_are_sorted) {
1506     indices_are_sorted_ = indices_are_sorted;
1507   }
unique_indices()1508   bool unique_indices() const override { return unique_indices_; }
1509   // Returns a serialized representation of this instruction.
1510   HloInstructionProto ToProto() const override;
1511 
1512   // Creates an instance of ScatterDimensionNumbers.
1513   static ScatterDimensionNumbers MakeScatterDimNumbers(
1514       absl::Span<const int64> update_window_dims,
1515       absl::Span<const int64> inserted_window_dims,
1516       absl::Span<const int64> scatter_dims_to_operand_dims,
1517       int64 index_vector_dim);
1518   // Returns the dump string of the given scatter dimension numbers.
1519   static string ScatterDimensionNumbersToString(
1520       const ScatterDimensionNumbers& scatter_dimension_numbers);
1521 
1522  private:
1523   std::vector<string> ExtraAttributesToStringImpl(
1524       const HloPrintOptions& options) const override;
1525   bool IdenticalSlowPath(
1526       const HloInstruction& other,
1527       const std::function<bool(const HloComputation*, const HloComputation*)>&
1528           eq_computations) const override;
1529   // Implementation for non-common logic of CloneWithNewOperands.
1530   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1531       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1532       HloCloneContext* context) const override;
1533 
1534   std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
1535   bool indices_are_sorted_;
1536   bool unique_indices_;
1537 };
1538 
1539 class HloIotaInstruction : public HloInstruction {
1540  public:
1541   explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
1542   // Returns the dimension sizes or numbers associated with this instruction.
iota_dimension()1543   int64 iota_dimension() const { return iota_dimension_; }
1544   // Returns a serialized representation of this instruction.
1545   HloInstructionProto ToProto() const override;
1546 
1547  private:
1548   std::vector<string> ExtraAttributesToStringImpl(
1549       const HloPrintOptions& options) const override;
1550   bool IdenticalSlowPath(
1551       const HloInstruction& other,
1552       const std::function<bool(const HloComputation*, const HloComputation*)>&
1553           eq_computations) const override;
1554   // Implementation for non-common logic of CloneWithNewOperands.
1555   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1556       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1557       HloCloneContext* context) const override;
1558 
1559   const int64 iota_dimension_;
1560 };
1561 
1562 class HloDotInstruction : public HloInstruction {
1563  public:
1564   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
1565   // dimensions specified in 'dimension_numbers'.
1566   explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
1567                              HloInstruction* rhs,
1568                              const DotDimensionNumbers& dimension_numbers,
1569                              const PrecisionConfig& precision_config);
1570 
1571   // Returns data on the dimension numbers used for a dot operation.
dot_dimension_numbers()1572   const DotDimensionNumbers& dot_dimension_numbers() const {
1573     return dot_dimension_numbers_;
1574   }
1575 
1576   // Returns the information used to tell the implementation information about
1577   // what sort of precision is requested. The meaning of the field is backend
1578   // specific. At the moment, it is only supported for kConvolution and kDot.
1579   // Transformations on one kDot or kConvolution to another will preserve this
1580   // information. Transformations to other HLOs will not preserve this
1581   // information but it is presumed that the alternate lowering is strictly
1582   // superior.
precision_config()1583   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1584   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1585 
1586   // Returns a serialized representation of this instruction.
1587   HloInstructionProto ToProto() const override;
1588 
1589  private:
1590   std::vector<string> ExtraAttributesToStringImpl(
1591       const HloPrintOptions& options) const override;
1592   bool IdenticalSlowPath(
1593       const HloInstruction& other,
1594       const std::function<bool(const HloComputation*, const HloComputation*)>&
1595           eq_computations) const override;
1596   // Implementation for non-common logic of CloneWithNewOperands.
1597   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1598       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1599       HloCloneContext* context) const override;
1600   // Returns the dump string of the dot dimension numbers.
1601   string DotDimensionNumbersToString() const;
1602 
1603   // Describes the dimension numbers used for a dot.
1604   DotDimensionNumbers dot_dimension_numbers_;
1605 
1606   // Information used to communicate to the implementation about the algorithm
1607   // used to produce results. See the documentation on precision_config().
1608   PrecisionConfig precision_config_;
1609 };
1610 
1611 class HloDomainInstruction : public HloInstruction {
1612  public:
1613   explicit HloDomainInstruction(
1614       const Shape& shape, HloInstruction* operand,
1615       std::unique_ptr<DomainMetadata> operand_side_metadata,
1616       std::unique_ptr<DomainMetadata> user_side_metadata);
1617 
1618   // Returns a serialized representation of this instruction.
1619   HloInstructionProto ToProto() const override;
1620 
1621   // Retrieves the operand side metadata of a kDomain instruction.
operand_side_metadata()1622   const DomainMetadata& operand_side_metadata() const {
1623     return *operand_side_metadata_;
1624   }
1625   // Retrieves the user side metadata of a kDomain instruction.
user_side_metadata()1626   const DomainMetadata& user_side_metadata() const {
1627     return *user_side_metadata_;
1628   }
1629 
1630  private:
1631   std::vector<string> ExtraAttributesToStringImpl(
1632       const HloPrintOptions& options) const override;
1633   bool IdenticalSlowPath(
1634       const HloInstruction& other,
1635       const std::function<bool(const HloComputation*, const HloComputation*)>&
1636           eq_computations) const override;
1637   // Implementation for non-common logic of CloneWithNewOperands.
1638   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1639       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1640       HloCloneContext* context) const override;
1641 
1642   std::unique_ptr<DomainMetadata> operand_side_metadata_;
1643   std::unique_ptr<DomainMetadata> user_side_metadata_;
1644 };
1645 
1646 class HloGetDimensionSizeInstruction : public HloInstruction {
1647  public:
1648   explicit HloGetDimensionSizeInstruction(const Shape& shape,
1649                                           HloInstruction* operand,
1650                                           int64 dimension);
1651 
1652   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1653   int64 dimension() const { return dimension_; }
1654   // Returns a serialized representation of this instruction.
1655   HloInstructionProto ToProto() const override;
1656 
1657  private:
1658   std::vector<string> ExtraAttributesToStringImpl(
1659       const HloPrintOptions& options) const override;
1660   bool IdenticalSlowPath(
1661       const HloInstruction& other,
1662       const std::function<bool(const HloComputation*, const HloComputation*)>&
1663           eq_computations) const override;
1664   // Implementation for non-common logic of CloneWithNewOperands.
1665   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1666       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1667       HloCloneContext* context) const override;
1668 
1669   int64 dimension_;
1670 };
1671 
1672 class HloSetDimensionSizeInstruction : public HloInstruction {
1673  public:
1674   explicit HloSetDimensionSizeInstruction(const Shape& shape,
1675                                           HloInstruction* operand,
1676                                           HloInstruction* val, int64 dimension);
1677 
1678   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1679   int64 dimension() const { return dimension_; }
1680   // Returns a serialized representation of this instruction.
1681   HloInstructionProto ToProto() const override;
1682 
1683  private:
1684   std::vector<string> ExtraAttributesToStringImpl(
1685       const HloPrintOptions& options) const override;
1686   bool IdenticalSlowPath(
1687       const HloInstruction& other,
1688       const std::function<bool(const HloComputation*, const HloComputation*)>&
1689           eq_computations) const override;
1690   // Implementation for non-common logic of CloneWithNewOperands.
1691   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1692       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1693       HloCloneContext* context) const override;
1694 
1695   int64 dimension_;
1696 };
1697 
1698 class HloRngGetAndUpdateStateInstruction : public HloInstruction {
1699  public:
1700   explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, int64 delta);
1701 
1702   // Returns the delta value.
delta()1703   int64 delta() const { return delta_; }
set_delta(int64 delta)1704   void set_delta(int64 delta) { delta_ = delta; }
1705   // Returns a serialized representation of this instruction.
1706   HloInstructionProto ToProto() const override;
1707 
1708  private:
1709   std::vector<string> ExtraAttributesToStringImpl(
1710       const HloPrintOptions& options) const override;
1711   bool IdenticalSlowPath(
1712       const HloInstruction& other,
1713       const std::function<bool(const HloComputation*, const HloComputation*)>&
1714           eq_computations) const override;
1715   // Implementation for non-common logic of CloneWithNewOperands.
1716   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1717       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1718       HloCloneContext* context) const override;
1719 
1720   int64 delta_;
1721 };
1722 
1723 }  // namespace xla
1724 
1725 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
1726