1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // All HloInstruction subclasses are put in this file.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
20
21 #include <functional>
22 #include <memory>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/container/inlined_vector.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36
37 namespace xla {
38
39 // Base class for instructions with a dimensions vector.
// Base class for instructions with a dimensions vector.
//
// The meaning of the dimension numbers depends on the concrete opcode
// (e.g. broadcast dimensions for kBroadcast, the sort dimension for kSort);
// this class only stores and serializes them.
class HloDimensionsInstruction : public HloInstruction {
 public:
  // Returns the dimension numbers this instruction operates on.
  absl::Span<const int64_t> dimensions() const override { return dimensions_; }

  // Returns a mutable pointer to the dimensions vector so passes can
  // rewrite the dimension numbers in place.
  std::vector<int64_t>* mutable_dimensions() override { return &dimensions_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check: true iff `hlo` is one of the opcodes that
  // carry a dimensions vector.
  static bool ClassOf(const HloInstruction* hlo) {
    switch (hlo->opcode()) {
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kReduce:
      case HloOpcode::kReverse:
      case HloOpcode::kSort:
      case HloOpcode::kTranspose:
        return true;
      default:
        return false;
    }
  }

 protected:
  // Constructor is protected: only concrete subclasses may be created.
  HloDimensionsInstruction(HloOpcode opcode, const Shape& shape,
                           absl::Span<const int64_t> dimensions)
      : HloInstruction(opcode, shape),
        dimensions_(dimensions.begin(), dimensions.end()) {}
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // The dimension numbers; interpretation is opcode-specific.
  std::vector<int64_t> dimensions_;
};
77
// Abstract base class for the batch-normalization instructions
// (kBatchNormTraining, kBatchNormInference, kBatchNormGrad). Holds the
// attributes shared by all of them: the feature dimension index and epsilon.
class HloBatchNormInstruction : public HloInstruction {
 public:
  // Returns the feature_index field associated with the instruction. The
  // index represents the index of the feature dimension.
  int64_t feature_index() const { return feature_index_; }

  // Returns the epsilon value associated with the instruction. This is a
  // small number added to the variance to avoid divide-by-zero error.
  float epsilon() const { return epsilon_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check: true iff `hlo` is any batch-norm instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    switch (hlo->opcode()) {
      case HloOpcode::kBatchNormGrad:
      case HloOpcode::kBatchNormInference:
      case HloOpcode::kBatchNormTraining:
        return true;
      default:
        return false;
    }
  }

 protected:
  // Constructor is protected: instantiate one of the concrete subclasses.
  explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
                                   HloInstruction* operand,
                                   HloInstruction* scale, float epsilon,
                                   int64_t feature_index);

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // A small float number added to the variance to avoid divide-by-zero error.
  float epsilon_ = 0.0f;

  // An integer value representing the index of the feature dimension.
  int64_t feature_index_ = -1;
};
121
// Concrete batch-norm instruction for kBatchNormTraining. Operands are the
// input, scale, and offset; epsilon and feature_index live in the base class.
class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormTrainingInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64_t feature_index);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBatchNormTraining;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
138
// Concrete batch-norm instruction for kBatchNormInference. Takes precomputed
// mean and variance operands in addition to input, scale, and offset.
class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormInferenceInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64_t feature_index);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBatchNormInference;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
156
// Concrete batch-norm instruction for kBatchNormGrad (the gradient/backward
// pass). Takes the input, scale, mean, variance, and incoming gradient.
class HloBatchNormGradInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormGradInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64_t feature_index);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBatchNormGrad;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
174
// Instruction for the kFft opcode, carrying the FFT type (e.g. forward or
// inverse transform variant) and the transform lengths.
class HloFftInstruction : public HloInstruction {
 public:
  explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
                             FftType fft_type,
                             absl::Span<const int64_t> fft_length);
  // Returns the FFT variant performed by this instruction.
  FftType fft_type() const { return fft_type_; }

  // Returns the FFT lengths, one per transformed dimension.
  const std::vector<int64_t>& fft_length() const { return fft_length_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kFft;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes FFT type for an FFT instruction.
  FftType fft_type_ = FftType::FFT;

  // Indicates the FFT length for an FFT instruction.
  std::vector<int64_t> fft_length_;
};
210
211 class HloAsyncInstruction : public HloInstruction {
212 public:
213 HloAsyncInstruction(
214 HloOpcode opcode, const Shape& shape,
215 absl::Span<HloInstruction* const> operands,
216 HloComputation* async_computation,
217 std::optional<int64_t> async_group_id = std::nullopt,
218 absl::string_view async_execution_thread = kMainExecutionThread);
219 HloAsyncInstruction(
220 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
221 HloComputation* async_computation,
222 std::optional<int64_t> async_group_id = std::nullopt,
223 absl::string_view async_execution_thread = kMainExecutionThread);
224
225 ~HloAsyncInstruction() override;
226 // When an async instruction is being destructed, remove it from the vector of
227 // pointers of its called computation, to avoid referencing freed memory.
228 void ClearAsyncComputationInstruction();
229
230 HloInstruction* async_wrapped_instruction() const;
231 HloOpcode async_wrapped_opcode() const;
232
233 // Async group id is a unique id given to a group of async operations that
234 // consist of one async start, one async done, and zero or more async update
235 // operations. The async group participates in a single async operation. The
236 // async operation canonicalizer pass assigns async group ids.
async_group_id()237 std::optional<int64_t> async_group_id() const { return async_group_id_; }
238
239 // Async thread name is a unique thread name for one or more async groups.
240 // Typically one HLO module contains a main thread as well as one or more
241 // parallel threads.
async_execution_thread()242 absl::string_view async_execution_thread() const {
243 return async_execution_thread_;
244 }
245 void set_async_group_id(std::optional<int64_t> async_group_id);
246 void set_async_execution_thread(absl::string_view async_execution_thread);
247 HloInstructionProto ToProto() const override;
248
ClassOf(const HloInstruction * hlo)249 static bool ClassOf(const HloInstruction* hlo) {
250 switch (hlo->opcode()) {
251 case HloOpcode::kAsyncStart:
252 case HloOpcode::kAsyncUpdate:
253 case HloOpcode::kAsyncDone:
254 return true;
255 default:
256 return false;
257 }
258 }
259
260 private:
261 std::vector<std::string> ExtraAttributesToStringImpl(
262 const HloPrintOptions& options) const override;
263 bool IdenticalSlowPath(
264 const HloInstruction& other,
265 const std::function<bool(const HloComputation*, const HloComputation*)>&
266 eq_computations) const override;
267 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
268 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
269 HloCloneContext* context) const override;
270 std::optional<int64_t> async_group_id_;
271 std::string async_execution_thread_ = kMainExecutionThread;
272 };
273
// Instruction for the kCopyStart opcode. Tracks whether this copy is a
// cross-program prefetch.
class HloCopyStartInstruction : public HloInstruction {
 public:
  explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
                                   bool is_cross_program_prefetch);

  // Whether this copy-start is a cross-program prefetch.
  bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCopyStart;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  bool is_cross_program_prefetch_;
};
299
// Instruction for the kCompare opcode. The direction (EQ/NE/LT/...), ordering,
// and comparison type are all stored in the wrapped Comparison object.
class HloCompareInstruction : public HloInstruction {
 public:
  // `type` may be nullopt, in which case the comparison type is derived
  // elsewhere (e.g. from the operand element type) — see the .cc definition.
  explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
                                 HloInstruction* rhs,
                                 ComparisonDirection direction,
                                 std::optional<Comparison::Type> type);
  // Accessors that forward to the underlying Comparison.
  ComparisonDirection direction() const { return compare_.GetDirection(); }
  ComparisonOrder order() const { return compare_.GetOrder(); }
  Comparison::Type type() const { return compare_.GetType(); }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCompare;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  Comparison compare_;
};
328
// Instruction for the kTriangularSolve opcode. Solves systems with the
// triangular matrix `a` and right-hand side `b`, configured by
// TriangularSolveOptions.
class HloTriangularSolveInstruction : public HloInstruction {
 public:
  explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
                                         HloInstruction* b,
                                         const TriangularSolveOptions& options);
  // Returns the options (e.g. side/transpose settings) for this solve.
  const TriangularSolveOptions& triangular_solve_options() const {
    return triangular_solve_options_;
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kTriangularSolve;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  TriangularSolveOptions triangular_solve_options_;
};
360
// Instruction for the kCholesky opcode: Cholesky decomposition of operand
// `a`, configured by CholeskyOptions.
class HloCholeskyInstruction : public HloInstruction {
 public:
  explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
                                  const CholeskyOptions& options);
  // Returns the decomposition options for this instruction.
  const CholeskyOptions& cholesky_options() const { return cholesky_options_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCholesky;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  CholeskyOptions cholesky_options_;
};
389
390 // Class that represents instructions that synchronize and transfer data between
391 // partitioned devices. Send/Recv and collective instructions (AllReduce,
392 // AllToAll, CollectivePermute) belong to this instruction type. A group of
393 // instructions (of the same opcode) with the same channel_id communicate during
394 // execution.
class HloChannelInstruction : public HloInstruction {
 public:
  // Returns the channel id associated with the instruction. The id is
  // shared between each Send/Recv pair or a group of collective instructions
  // and is globally unique to identify each channel.
  std::optional<int64_t> channel_id() const { return channel_id_; }
  void set_channel_id(const std::optional<int64_t>& channel_id);

  // Whether this instruction is identical to `other` except for the values of
  // channel IDs, as long as both have channel IDs or neither has a channel ID.
  // The base implementation only compares channel-id presence; subclasses
  // extend it with their own attribute comparisons.
  virtual bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const {
    return channel_id_.has_value() == other.channel_id().has_value();
  }

  // Opcode-based type check; defined out of line (below) because it depends
  // on the ClassOf of subclasses declared later in this file.
  static bool ClassOf(const HloInstruction* hlo);

 protected:
  explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
                                 const std::optional<int64_t>& channel_id);

  HloInstructionProto ToProto() const override;

  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Do not override IdenticalSlowPath(). Override
  // IdenticalSlowPathIgnoringChannelIdValues() instead.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const final;

  // Nullopt when no channel has been assigned.
  std::optional<int64_t> channel_id_;
};
432
// Common base class for Send/SendDone/Recv/RecvDone instructions, which
// transfer data over a channel and may target the host.
class HloSendRecvInstruction : public HloChannelInstruction {
 public:
  // Returns whether this send/recv instruction sends data to/from the host.
  bool is_host_transfer() const { return is_host_transfer_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check: true iff `hlo` is any send/recv instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    switch (hlo->opcode()) {
      case HloOpcode::kSend:
      case HloOpcode::kSendDone:
      case HloOpcode::kRecv:
      case HloOpcode::kRecvDone:
        return true;
      default:
        return false;
    }
  }

 protected:
  // Constructor is protected: instantiate one of the concrete subclasses.
  // Note the channel id is required here (non-optional), unlike collectives.
  explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
                                  int64_t channel_id, bool is_host_transfer);

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Whether this send/recv instruction sends data to/from the host.
  bool is_host_transfer_;
};
467
// Concrete instruction for the kSend opcode. Operands: the data to send and
// a token for ordering.
class HloSendInstruction : public HloSendRecvInstruction {
 public:
  explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
                              int64_t channel_id, bool is_host_transfer);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSend;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
483
// Concrete instruction for the kSendDone opcode, completing a matching
// kSend (its single operand).
class HloSendDoneInstruction : public HloSendRecvInstruction {
 public:
  explicit HloSendDoneInstruction(HloSendInstruction* operand,
                                  bool is_host_transfer);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSendDone;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
499
// Concrete instruction for the kRecv opcode. Takes a token operand for
// ordering; the received data shape is `shape`.
class HloRecvInstruction : public HloSendRecvInstruction {
 public:
  explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
                              int64_t channel_id, bool is_host_transfer);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRecv;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
515
// Concrete instruction for the kRecvDone opcode, completing a matching
// kRecv (its single operand).
class HloRecvDoneInstruction : public HloSendRecvInstruction {
 public:
  explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
                                  bool is_host_transfer);

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRecvDone;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
531
// Common base class for collective instructions (all-reduce family,
// all-gather, all-to-all). Holds the replica groups and layout constraint
// shared by all of them.
class HloCollectiveInstruction : public HloChannelInstruction {
 public:
  // Returns the groups of participating replicas for this collective.
  const std::vector<ReplicaGroup>& replica_groups() const {
    return replica_groups_;
  }

  // Returns true if the layout of the AllReduce is enforced by XLA client (as
  // the layout set in the shape). The only reason for the client to set the
  // layout is to separately compile computations that communicate with
  // AllReduce. Since this field is only set `true` by the client, the compiler
  // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
  // `false` for all other cases.
  //
  // When this is `true`, there may be communication endpoints outside the
  // current compilation unit, so the compiler considers this AllReduce as
  // side-effecting to disable compiler transformations. The compiler is free to
  // transform unconstrained AllReduces differently across compilation units.
  // It is an error for an HloModule to have a mix of constrained and
  // unconstrained AllReduce instructions (checked by HloVerifier).
  bool constrain_layout() const { return constrain_layout_; }

  // Opcode-based type check; defined out of line (below) because it depends
  // on the ClassOf of subclasses declared later in this file.
  static bool ClassOf(const HloInstruction* hlo);

 protected:
  explicit HloCollectiveInstruction(
      HloOpcode opcode, const Shape& shape,
      absl::Span<HloInstruction* const> operands,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id);

  HloInstructionProto ToProto() const override;

  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  std::vector<ReplicaGroup> replica_groups_;
  bool constrain_layout_;
};
574
// Instruction for kAllGather / kAllGatherStart: concatenates data from
// participants along a single dimension.
class HloAllGatherInstruction : public HloCollectiveInstruction {
 public:
  explicit HloAllGatherInstruction(
      HloOpcode opcode, const Shape& shape,
      absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);
  // Same as HloAllReduceInstruction::use_global_device_ids.
  bool use_global_device_ids() const { return use_global_device_ids_; }

  // The dimension on which data from different participants are concatenated.
  int64_t all_gather_dimension() const { return all_gather_dimension_; }
  // Exposes the single gather dimension through the generic dimensions()
  // interface as a one-element span over the stored member.
  absl::Span<const int64_t> dimensions() const override {
    return absl::MakeConstSpan(&all_gather_dimension_, 1);
  }

  void set_all_gather_dimension(int64_t dim) { all_gather_dimension_ = dim; }

  // Opcode-based type check: matches both sync and async-start variants.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kAllGather ||
           hlo->opcode() == HloOpcode::kAllGatherStart;
  }

 protected:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  HloInstructionProto ToProto() const override;

 private:
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t all_gather_dimension_;
  bool use_global_device_ids_;
};
617
618 // Base class for all-reduce and all-reduce scatter instructions.
// Base class for all-reduce and all-reduce scatter instructions.
class HloAllReduceInstructionBase : public HloCollectiveInstruction {
 public:
  explicit HloAllReduceInstructionBase(
      HloOpcode opcode, const Shape& shape,
      absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids);

  // Returns true if the ids in the ReplicaGroup config represent a global id of
  // (replica_id * partition_count + partition_id) instead of a replica id.
  // This enables more flexible grouping of devices if this all-reduce is both
  // cross-partition and cross-replica.
  //
  // For example with 2 replicas and 4 partitions,
  // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that
  // group[0] = (0,0), (0,1), (1,0), (1,1)
  // group[1] = (0,2), (0,3), (1,2), (1,3)
  // where each pair is (replica_id, partition_id).
  bool use_global_device_ids() const { return use_global_device_ids_; }
  void set_use_global_device_ids(bool value) { use_global_device_ids_ = value; }

  // Opcode-based type check; defined out of line (below) because it depends
  // on the ClassOf of subclasses declared later in this file.
  static bool ClassOf(const HloInstruction* hlo);

 protected:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  HloInstructionProto ToProto() const override;

  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

 private:
  bool use_global_device_ids_;
};
656
// Concrete instruction for kAllReduce / kAllReduceStart. All state and
// constructors come from HloAllReduceInstructionBase.
class HloAllReduceInstruction : public HloAllReduceInstructionBase {
 public:
  using HloAllReduceInstructionBase::HloAllReduceInstructionBase;

  // Opcode-based type check: matches both sync and async-start variants.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kAllReduce ||
           hlo->opcode() == HloOpcode::kAllReduceStart;
  }

  // Returns true if the AllReduce does no communication, so it's equivalent
  // to a mem copy.
  bool IsNoop() const;

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
676
// Instruction for kReduceScatter: reduces across participants, then scatters
// the result along `scatter_dimension`.
class HloReduceScatterInstruction : public HloAllReduceInstructionBase {
 public:
  explicit HloReduceScatterInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id, bool use_global_device_ids,
      int64_t scatter_dimension);

  // The dimension on which reduced data is scattered to different participants.
  int64_t scatter_dimension() const { return scatter_dimension_; }
  // Exposes the single scatter dimension through the generic dimensions()
  // interface as a one-element span over the stored member.
  absl::Span<const int64_t> dimensions() const override {
    return absl::MakeConstSpan(&scatter_dimension_, 1);
  }

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kReduceScatter;
  }

 protected:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  HloInstructionProto ToProto() const override;

 private:
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t scatter_dimension_;
};
714
// Instruction for the kAllToAll opcode: exchanges data between participants,
// either tuple-style (list of operands) or array-style via a split dimension.
class HloAllToAllInstruction : public HloCollectiveInstruction {
 public:
  explicit HloAllToAllInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
      const std::optional<int64_t>& channel_id,
      const std::optional<int64_t>& split_dimension);

  // AllToAll can optionally take a split dimension, which means that this
  // AllToAll takes a single (flattened) array operand and produces an array
  // output (instead of taking a list of operands and producing a tuple).
  //
  // split_dimension specifies which dimension in the operand is split across
  // devices in each replica_group, and also means the concatenated dimension
  // on the output (i.e., input and the output shapes are the same).
  std::optional<int64_t> split_dimension() const { return split_dimension_; }
  void set_split_dimension(int64_t dim) { split_dimension_ = dim; }

  // Opcode-based type check for this concrete class.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kAllToAll;
  }

 protected:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  HloInstructionProto ToProto() const override;

 private:
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Nullopt for the tuple (multi-operand) form.
  std::optional<int64_t> split_dimension_;
};
755
// Instruction for kCollectivePermute / kCollectivePermuteStart: forwards data
// along explicit (source, target) device pairs. The second constructor is the
// in-place variant with explicit input/output buffers, start indices, and
// per-pair slice sizes.
class HloCollectivePermuteInstruction : public HloChannelInstruction {
 public:
  explicit HloCollectivePermuteInstruction(
      HloOpcode opcode, const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      const std::optional<int64_t>& channel_id);

  explicit HloCollectivePermuteInstruction(
      HloOpcode opcode, const Shape& shape, HloInstruction* input,
      HloInstruction* output, HloInstruction* input_start_indices,
      HloInstruction* output_start_indices,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
      absl::Span<const std::vector<int64_t>> slice_sizes,
      const std::optional<int64_t>& channel_id);

  // Returns the (source, target) device pairs describing the permutation.
  const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const {
    return source_target_pairs_;
  }

  // Slice sizes for the in-place variant; empty for the simple form.
  const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const {
    return slice_sizes_;
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Opcode-based type check: matches both sync and async-start variants.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCollectivePermute ||
           hlo->opcode() == HloOpcode::kCollectivePermuteStart;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  const std::vector<std::pair<int64_t, int64_t>> source_target_pairs_;
  const std::vector<std::vector<int64_t>> slice_sizes_;
};
803
ClassOf(const HloInstruction * hlo)804 inline bool HloAllReduceInstructionBase::ClassOf(const HloInstruction* hlo) {
805 return HloAllReduceInstruction::ClassOf(hlo) ||
806 hlo->opcode() == HloOpcode::kReduceScatter;
807 }
808
ClassOf(const HloInstruction * hlo)809 inline bool HloCollectiveInstruction::ClassOf(const HloInstruction* hlo) {
810 return HloAllReduceInstructionBase::ClassOf(hlo) ||
811 HloAllGatherInstruction::ClassOf(hlo) ||
812 HloAllToAllInstruction::ClassOf(hlo);
813 }
814
ClassOf(const HloInstruction * hlo)815 inline bool HloChannelInstruction::ClassOf(const HloInstruction* hlo) {
816 return HloCollectiveInstruction::ClassOf(hlo) ||
817 HloCollectivePermuteInstruction::ClassOf(hlo) ||
818 HloSendRecvInstruction::ClassOf(hlo);
819 }
820
// Reverse instruction (kReverse): dimensions() holds the dimensions along
// which the operand is reversed.
class HloReverseInstruction : public HloDimensionsInstruction {
 public:
  explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
                                 absl::Span<const int64_t> dimensions);

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kReverse;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
836
// Concatenate instruction (kConcatenate): joins its operands along a single
// dimension.
class HloConcatenateInstruction : public HloDimensionsInstruction {
 public:
  explicit HloConcatenateInstruction(const Shape& shape,
                                     absl::Span<HloInstruction* const> operands,
                                     int64_t dimension);
  // Accessor for the dimension in which a concatenate HLO should occur.
  // The dimension is stored as the sole entry of the dimensions() vector.
  int64_t concatenate_dimension() const override {
    return HloInstruction::dimensions(0);
  }

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kConcatenate;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
857
858 class HloReduceInstruction : public HloDimensionsInstruction {
859 public:
860 explicit HloReduceInstruction(const Shape& shape,
861 absl::Span<HloInstruction* const> args,
862 absl::Span<const int64_t> dimensions_to_reduce,
863 HloComputation* reduce_computation);
864
865 // Returns the number of input arrays (and, consequentially, the number of
866 // init values) this reduce has.
input_count()867 int64_t input_count() const { return operand_count() / 2; }
868
869 // Returns the input tensors to be reduced.
inputs()870 absl::Span<HloInstruction* const> inputs() const {
871 return absl::MakeSpan(operands()).subspan(0, input_count());
872 }
873
874 // Returns the init values of the reduction.
init_values()875 absl::Span<HloInstruction* const> init_values() const {
876 return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
877 }
878
ClassOf(const HloInstruction * hlo)879 static bool ClassOf(const HloInstruction* hlo) {
880 return hlo->opcode() == HloOpcode::kReduce;
881 }
882
883 private:
884 bool IdenticalSlowPath(
885 const HloInstruction& other,
886 const std::function<bool(const HloComputation*, const HloComputation*)>&
887 eq_computations) const override;
888 // Implementation for non-common logic of CloneWithNewOperands.
889 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
890 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
891 HloCloneContext* context) const override;
892 };
893
// Sort instruction (kSort): sorts the key operand (and any value operands in
// lockstep) along a single dimension using the `compare` computation.
class HloSortInstruction : public HloDimensionsInstruction {
 public:
  explicit HloSortInstruction(const Shape& shape, int64_t dimension,
                              absl::Span<HloInstruction* const> operands,
                              HloComputation* compare, bool is_stable);
  // Returns the sort dimension for this instruction, stored as the sole entry
  // of the dimensions() vector.
  int64_t sort_dimension() const { return HloInstruction::dimensions(0); }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;
  // Returns the key operand to this instruction.
  const HloInstruction* keys() const { return operand(0); }
  HloInstruction* mutable_keys() { return mutable_operand(0); }
  // Returns the number of value operands (every operand except the keys).
  int64_t values_count() const { return operand_count() - 1; }
  // Whether the sort is required to be stable.
  bool is_stable() const { return is_stable_; }

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSort;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  bool is_stable_;
};
928
// Transpose instruction (kTranspose): permutes the operand's dimensions
// according to dimensions().
class HloTransposeInstruction : public HloDimensionsInstruction {
 public:
  explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
                                   absl::Span<const int64_t> dimensions);
  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kTranspose;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
946
// Broadcast instruction (kBroadcast): expands the operand into a larger
// shape; dimensions() holds the broadcast dimensions.
class HloBroadcastInstruction : public HloDimensionsInstruction {
 public:
  explicit HloBroadcastInstruction(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64_t> broadcast_dimension);

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kBroadcast;
  }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
963
964 class HloDynamicReshapeInstruction : public HloInstruction {
965 public:
966 explicit HloDynamicReshapeInstruction(
967 const Shape& shape, HloInstruction* data_operand,
968 absl::Span<HloInstruction* const> dim_sizes);
969
970 // Returns the input dim sizes dimensions, which is operands[1:]
dim_sizes()971 absl::Span<HloInstruction* const> dim_sizes() const {
972 return absl::MakeSpan(operands()).subspan(1, operand_count());
973 }
974
975 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
976 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
977 HloCloneContext* context) const override;
978
979 // Returns the input dim size dimension, which is operands[1+i]
dim_sizes(int64_t i)980 HloInstruction* dim_sizes(int64_t i) const { return operands()[i + 1]; }
981
ClassOf(const HloInstruction * hlo)982 static bool ClassOf(const HloInstruction* hlo) {
983 return hlo->opcode() == HloOpcode::kDynamicReshape;
984 }
985 };
986
// Reshape instruction (kReshape): produces the operand's data in a new static
// shape.
class HloReshapeInstruction : public HloInstruction {
 public:
  explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
                                 int64_t inferred_dimension);
  // Returns the dimension recorded as "inferred" at construction time.
  // NOTE(review): presumably a sentinel (e.g. -1) means "none" — confirm with
  // the constructor's callers.
  int64_t inferred_dimension() const { return inferred_dimension_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kReshape;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
  int64_t inferred_dimension_;
};
1011
// Map instruction (kMap): applies `map_computation` elementwise across the
// operands.
class HloMapInstruction : public HloInstruction {
 public:
  explicit HloMapInstruction(const Shape& shape,
                             absl::Span<HloInstruction* const> operands,
                             HloComputation* map_computation);
  // Returns the dimension sizes or numbers associated with this instruction.
  absl::Span<const int64_t> dimensions() const override { return dimensions_; }

  std::vector<int64_t>* mutable_dimensions() override { return &dimensions_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kMap;
  }

 private:
  bool IsElementwiseImpl(
      const std::optional<int64_t>& operand_idx) const override;
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Unlike HloDimensionsInstruction subclasses, this class stores its own
  // dimensions vector (it derives directly from HloInstruction).
  std::vector<int64_t> dimensions_;
};
1044
// Slice instruction (kSlice): extracts [start, limit) with a stride in every
// dimension of the operand.
class HloSliceInstruction : public HloInstruction {
 public:
  explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
                               absl::Span<const int64_t> start_indices,
                               absl::Span<const int64_t> limit_indices,
                               absl::Span<const int64_t> strides);

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns the start index in the given dimension for a slice node.
  int64_t slice_starts(int64_t dimension) const {
    return slice_starts_[dimension];
  }
  const std::vector<int64_t>& slice_starts() const { return slice_starts_; }
  std::vector<int64_t>* mutable_slice_starts() { return &slice_starts_; }

  // Returns the (exclusive) limit index in the given dimension for a slice
  // node.
  int64_t slice_limits(int64_t dimension) const {
    return slice_limits_[dimension];
  }
  const std::vector<int64_t>& slice_limits() const { return slice_limits_; }
  std::vector<int64_t>* mutable_slice_limits() { return &slice_limits_; }

  // Returns the stride in the given dimension for a slice node.
  int64_t slice_strides(int64_t dimension) const {
    return slice_strides_[dimension];
  }
  const std::vector<int64_t>& slice_strides() const { return slice_strides_; }
  std::vector<int64_t>* mutable_slice_strides() { return &slice_strides_; }

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSlice;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes the [begin, end) index range for a slice, one entry per
  // dimension.
  std::vector<int64_t> slice_starts_;
  std::vector<int64_t> slice_limits_;
  std::vector<int64_t> slice_strides_;
};
1097
// Constant instruction (kConstant): holds a literal value. The literal may be
// absent ("dropped") when it was too large to keep in memory.
class HloConstantInstruction : public HloInstruction {
 public:
  explicit HloConstantInstruction(Literal literal);
  explicit HloConstantInstruction(Literal literal, const Shape& shape);
  // Used when the literal is too large and dropped.
  explicit HloConstantInstruction(const Shape& shape);
  // Returns the literal associated with this instruction.
  // Callers must ensure HasLiteral(); dereferencing an empty optional here is
  // undefined behavior.
  const Literal& literal() const { return *literal_; }
  // Returns the (mutable) literal associated with this instruction.
  // Callers must ensure HasLiteral(); std::optional::value() fails otherwise.
  Literal* mutable_literal() { return &literal_.value(); }
  // Returns whether there is literal associated with this instruction.
  bool HasLiteral() const { return literal_.has_value(); }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Change the layout for an Constant Hlo instruction to match new_layout. For
  // tuple shaped constants shape_index is the path to the internal array
  // subshape whose layout needs to be changed.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kConstant;
  }

 private:
  bool IsElementwiseImpl(
      const std::optional<int64_t>& operand_idx) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
  // Empty when the literal was dropped at construction time.
  std::optional<Literal> literal_;
};
1139
// Abstract class that represents an HLO instruction that "calls" a computation.
// Fusion and Call HLOs inherit from this class.
class HloCallableInstruction : public HloInstruction {
 public:
  HloCallableInstruction(HloOpcode opcode, const Shape& shape);

  HloCallableInstruction(HloOpcode opcode, const Shape& shape,
                         absl::Span<HloInstruction* const> operands);

  HloCallableInstruction(HloOpcode opcode, const Shape& shape,
                         absl::Span<HloInstruction* const> operands,
                         HloComputation* called_computation,
                         absl::string_view prefix = "");

  HloCallableInstruction(HloOpcode opcode, const Shape& shape,
                         absl::Span<HloInstruction* const> operands,
                         absl::Span<HloComputation* const> called_computations);

  ~HloCallableInstruction() override;

  // Adds a new operand to the callable instruction.
  HloInstruction* AddCallOperand(HloInstruction* new_operand);

  // Appends (fuses) the given instruction into this callable instruction.
  // instruction_to_append is cloned and the clone is placed in the callable
  // instruction. The users of instruction_to_append will be redirected to this
  // callable instruction. instruction_to_append is unchanged otherwise. When
  // add_output is true, a clone of the instruction_to_append will be added as
  // additional output resulting in a multi-output callable instruction.
  HloInstruction* AppendInstructionIntoCalledComputation(
      HloInstruction* instruction_to_append, bool add_output = false);
  // Clones the given instruction_to_append and inserts the clone into this
  // callable instruction. If add_output is true, a clone of
  // instruction_to_append will be in the output of this callable
  // instruction (part of the tuple of the callable root).
  HloInstruction* CloneAndAppendInstructionIntoCalledComputation(
      HloInstruction* instruction_to_append, bool add_output = false);

  // Retrieves the called computations of an HloCallableInstruction that is
  // being cloned. If the called computations have not yet been cloned, then
  // they are first cloned and added to the context.
  absl::InlinedVector<HloComputation*, 1> GetOrCloneCalledComputations(
      HloCloneContext* context) const;

  HloComputation* called_computation() const;

  HloInstruction* called_computation_root() const;

  // Recursively sets all nested called computations to have thread name
  // `execution_thread`. If `skip_async_execution_thread_overwrite` is true,
  // skips overwriting the thread name of async instructions and their
  // computations.
  void RecursivelySetComputationsThreadName(
      absl::string_view execution_thread,
      bool skip_async_execution_thread_overwrite);

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kFusion ||
           hlo->opcode() == HloOpcode::kCall ||
           hlo->opcode() == HloOpcode::kCustomCall;
  }

 protected:
  // Returns the default called computation name.
  virtual std::string default_called_computation_name() const = 0;
};
1206
1207 class HloFusionInstruction : public HloCallableInstruction {
1208 public:
1209 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
1210 HloInstruction* fused_root);
1211
1212 explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
1213 absl::Span<HloInstruction* const> operands,
1214 HloComputation* fusion_computation,
1215 absl::string_view prefix = "");
1216
1217 ~HloFusionInstruction() override;
1218
1219 void ClearCalledComputations() override;
1220
1221 // When a fusion instruction is being destructed, clear the back pointer of
1222 // its fusion computation, to avoid referencing freed memory.
1223 void ClearFusionComputationInstruction();
1224
1225 std::string ToCategory() const override;
1226 // Returns a serialized representation of this instruction.
1227 HloInstructionProto ToProto() const override;
1228
1229 // Adds a new operand the fusion instruction.
1230 HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1231
1232 // Merges the fused instructions from 'instruction_to_merge' into the
1233 // fused instruction set of 'this', updating operands as necessary.
1234 //
1235 // Precondition: 'instruction_to_merge' must be an operand of 'this'.
1236 void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
1237
1238 // Merges the fused instructions from instruction_to_merge into the fused
1239 // instruction set of 'this' and generates multioutput fusion instructions.
1240 // All the users of instruction_to_merge will be redirected to 'this'
1241 // instruction. instruction_to_merge will be removed from its parent
1242 // computation.
1243 void MergeFusionInstructionIntoMultiOutput(
1244 HloFusionInstruction* instruction_to_merge);
1245
1246 // Fuses the given instruction in this fusion instruction. instruction_to_fuse
1247 // is cloned and the clone is placed in the fusion
1248 // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
1249 // than moved to cleanly handle the case where the instruction has a use
1250 // outside the fusion instruction. Moving such an instruction into a fusion
1251 // instruction would violate the single-result invariant of HLO instructions
1252 // and significantly complicate code generation.
FuseInstruction(HloInstruction * instruction_to_fuse)1253 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
1254 CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1255 return AppendInstructionIntoCalledComputation(instruction_to_fuse);
1256 }
1257
1258 // Fuses the given instruction in this fusion instruction and generates a
1259 // multioutput fusion instruction. A clone of the instruction_to_fuse will
1260 // be part of the output of fusion instructions. The users of
1261 // instruction_to_fuse will be redirected to this fusion instructions.
1262 // instruction_to_fuse is unchanged otherwise.
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)1263 HloInstruction* FuseInstructionIntoMultiOutput(
1264 HloInstruction* instruction_to_fuse) {
1265 return AppendInstructionIntoCalledComputation(instruction_to_fuse,
1266 /*add_output=*/true);
1267 }
1268
1269 // Returns the computation for this fused instruction.
1270 HloComputation* fused_instructions_computation() const;
1271
1272 // Returns the root instruction of the fused expression contained within this
1273 // fusion instruction.
1274 HloInstruction* fused_expression_root() const;
1275
1276 // Returns the list of fused instructions inside this fusion instruction. The
1277 // returned type is a range of HloInstruction*s.
1278 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1279 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1280 fused_instructions() const;
1281
1282 const tensorflow::gtl::iterator_range<
1283 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1284 fused_instructions();
1285
1286 // Gets the number of instructions inside this fusion instruction.
1287 int64_t fused_instruction_count() const;
1288
1289 // Returns the fused parameter instruction in this fusion instruction
1290 // corresponding to the given parameter number.
1291 HloInstruction* fused_parameter(int64_t parameter_number) const;
1292
1293 // Returns the vector of fused parameters inside this fusion instruction.
1294 const std::vector<HloInstruction*>& fused_parameters() const;
1295
1296 // Returns true if this instruction is a fusion instruction that generates
1297 // multiple outputs.
IsMultiOutputFusion()1298 const bool IsMultiOutputFusion() const {
1299 return fused_expression_root()->opcode() == HloOpcode::kTuple;
1300 }
1301
fusion_kind()1302 FusionKind fusion_kind() const { return fusion_kind_; }
1303
set_fusion_kind(FusionKind kind)1304 void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
1305
1306 // If multiple operands are the same instruction, keeps only one of them.
1307 Status DeduplicateFusionOperands();
1308
ClassOf(const HloInstruction * hlo)1309 static bool ClassOf(const HloInstruction* hlo) {
1310 return hlo->opcode() == HloOpcode::kFusion;
1311 }
1312
1313 protected:
default_called_computation_name()1314 std::string default_called_computation_name() const override {
1315 return "fused_computation";
1316 }
1317
1318 private:
1319 bool IsElementwiseImpl(
1320 const std::optional<int64_t>& operand_idx) const override;
1321 std::vector<std::string> ExtraAttributesToStringImpl(
1322 const HloPrintOptions& options) const override;
1323 bool IdenticalSlowPath(
1324 const HloInstruction& other,
1325 const std::function<bool(const HloComputation*, const HloComputation*)>&
1326 eq_computations) const override;
1327
1328 // Implementation for non-common logic of CloneWithNewOperands.
1329 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1330 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1331 HloCloneContext* context) const override;
1332
1333 // The type of the fusion.
1334 FusionKind fusion_kind_;
1335 };
1336
// Call instruction (kCall): invokes a single called computation with the
// instruction's operands.
class HloCallInstruction : public HloCallableInstruction {
 public:
  HloCallInstruction(const Shape& shape,
                     HloInstruction* called_computation_root);

  HloCallInstruction(const Shape& shape,
                     absl::Span<HloInstruction* const> operands,
                     HloComputation* called_computation);

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCall;
  }

 protected:
  std::string default_called_computation_name() const override {
    return "called_computation";
  }
};
1355
// Rng instruction (kRng): generates random numbers drawn from
// random_distribution(), parameterized by the operands.
class HloRngInstruction : public HloInstruction {
 public:
  explicit HloRngInstruction(const Shape& shape,
                             RandomDistribution distribution,
                             absl::Span<HloInstruction* const> parameters);
  // Returns the random distribution for this rng node.
  RandomDistribution random_distribution() const { return distribution_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRng;
  }

 private:
  bool IsElementwiseImpl(
      const std::optional<int64_t>& operand_idx) const override;
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The distribution requested for random number generation.
  RandomDistribution distribution_;
};
1387
// Parameter instruction (kParameter): an input to the enclosing computation,
// identified by parameter_number().
class HloParameterInstruction : public HloInstruction {
 public:
  explicit HloParameterInstruction(int64_t parameter_number, const Shape& shape,
                                   const std::string& name);
  int64_t parameter_number() const { return parameter_number_; }

  // Sets and gets the whether all replicas will receive the same parameter data
  // for each leaf buffer in data parallelism.
  // The span length must equal the number of leaf buffers in the shape
  // (CHECK-failure otherwise).
  void set_parameter_replicated_at_leaf_buffers(
      absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
    CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
             parameter_replicated_at_leaf_buffers.size());
    parameter_replicated_at_leaf_buffers_.emplace(
        parameter_replicated_at_leaf_buffers.begin(),
        parameter_replicated_at_leaf_buffers.end());
  }
  // Vector overload of the setter above; same length requirement.
  void set_parameter_replicated_at_leaf_buffers(
      const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
    CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
             parameter_replicated_at_leaf_buffers.size());
    parameter_replicated_at_leaf_buffers_ =
        parameter_replicated_at_leaf_buffers;
  }
  // Empty optional means replication info was never set.
  const std::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers()
      const {
    return parameter_replicated_at_leaf_buffers_;
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kParameter;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t parameter_number_ = 0;

  // Specifies whether each buffer has the same parameter value on all replicas
  // in data parallelism.
  std::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
};
1444
// Get-tuple-element instruction (kGetTupleElement): extracts element
// tuple_index() from a tuple-shaped operand.
class HloGetTupleElementInstruction : public HloInstruction {
 public:
  explicit HloGetTupleElementInstruction(const Shape& shape,
                                         HloInstruction* operand,
                                         int64_t index);
  // Returns the tuple index associated with this instruction.
  int64_t tuple_index() const { return tuple_index_; }
  // Sets the tuple index associated with this instruction.
  void set_tuple_index(int64_t new_tuple_index) {
    tuple_index_ = new_tuple_index;
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kGetTupleElement;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // -1 until set by the constructor.
  int64_t tuple_index_ = -1;
};
1477
// Reduce-precision instruction (kReducePrecision): reduces the operand's
// floating-point precision to the given exponent/mantissa bit widths.
class HloReducePrecisionInstruction : public HloInstruction {
 public:
  explicit HloReducePrecisionInstruction(const Shape& shape,
                                         HloInstruction* operand,
                                         const int exponent_bits,
                                         const int mantissa_bits);
  // Returns the number of exponent bits for a reduce-precision node.
  int32_t exponent_bits() const { return exponent_bits_; }
  // Returns the number of mantissa bits for a reduce-precision node.
  int32_t mantissa_bits() const { return mantissa_bits_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kReducePrecision;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The bit sizes for a reduce-precision operation.
  int32_t exponent_bits_ = 0;
  int32_t mantissa_bits_ = 0;
};
1511
// Infeed instruction (HloOpcode::kInfeed). Takes a token operand and produces
// a tuple of (infeed data, token); carries a backend-specific configuration
// string.
class HloInfeedInstruction : public HloInstruction {
 public:
  explicit HloInfeedInstruction(const Shape& infeed_shape,
                                HloInstruction* token_operand,
                                const std::string& config);
  // Returns the infeed configuration string. The infeed configuration includes
  // any metadata needed for the backend compiler (e.g., infeed buffer address)
  // and is target-dependent.
  std::string infeed_config() const { return infeed_config_; }
  void set_infeed_config(const std::string& config) { infeed_config_ = config; }
  // Returns the shape of the data received by the infeed. This is not the same
  // as the shape of the infeed instruction which produces a tuple containing
  // the infeed data shape and a TOKEN.
  const Shape& infeed_shape() const {
    TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
    return ShapeUtil::GetSubshape(shape(), {0});
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` is an infeed instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kInfeed;
  }

 private:
  // Prints the instruction-specific attributes (the config string).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The string representation of the infeed configuration.
  std::string infeed_config_;
};
1551
// Outfeed instruction (HloOpcode::kOutfeed). Takes the value to outfeed plus a
// token operand; carries the requested outfeed shape and a backend-specific
// configuration string.
class HloOutfeedInstruction : public HloInstruction {
 public:
  explicit HloOutfeedInstruction(const Shape& outfeed_shape,
                                 HloInstruction* operand,
                                 HloInstruction* token_operand,
                                 absl::string_view outfeed_config);
  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const { return outfeed_shape_; }
  // Returns the mutable shape for the Outfeed instruction.
  Shape* mutable_outfeed_shape() { return &outfeed_shape_; }
  // Returns the config for the Outfeed instruction.
  const std::string& outfeed_config() const { return outfeed_config_; }
  void set_outfeed_config(const std::string& config) {
    outfeed_config_ = config;
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` is an outfeed instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kOutfeed;
  }

 private:
  // Prints the instruction-specific attributes (shape and config).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Shape of outfeed request.
  Shape outfeed_shape_;
  // Outfeed configuration information, only present for kOutfeed.
  std::string outfeed_config_;
};
1591
// Convolution instruction (HloOpcode::kConvolution). Carries the window,
// dimension numbers, feature/batch group counts, and precision configuration.
class HloConvolutionInstruction : public HloInstruction {
 public:
  explicit HloConvolutionInstruction(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64_t feature_group_count, int64_t batch_group_count,
      const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);
  // Describes the window applied to the rhs (kernel) of the convolution.
  const Window& window() const override { return window_; }
  void set_window(const Window& window) override { window_ = window; }
  // Describes which logical dimensions of the operands/output are batch,
  // feature, and spatial.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
    return convolution_dimension_numbers_;
  }
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums) {
    convolution_dimension_numbers_ = dnums;
  }
  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64_t feature_group_count() const { return feature_group_count_; }
  void set_feature_group_count(int64_t num_feature_groups) {
    feature_group_count_ = num_feature_groups;
  }
  // The number of batch groups. Must be a divisor of the input batch dimension.
  int64_t batch_group_count() const { return batch_group_count_; }
  void set_batch_group_count(int64_t num_batch_groups) {
    batch_group_count_ = num_batch_groups;
  }

  // Returns the information used to tell the implementation information about
  // what sort of precision is requested. The meaning of the field is backend
  // specific. At the moment, it is only supported for kConvolution and kDot.
  // Transformations on one kDot or kConvolution to another will preserve this
  // information. Transformations to other HLOs will not preserve this
  // information but it is presumed that the alternate lowering is strictly
  // superior.
  const PrecisionConfig& precision_config() const { return precision_config_; }
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  // Category string used when grouping instructions for reporting.
  std::string ToCategory() const override;
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` is a convolution instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kConvolution;
  }

 private:
  // Prints the instruction-specific attributes (window, dnums, counts, ...).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64_t feature_group_count_;
  // The number of batch groups. Must be a divisor of the input batch dimension.
  int64_t batch_group_count_;
  // Describes the window used for a convolution.
  Window window_;
  // Describes the dimension numbers used for a convolution.
  ConvolutionDimensionNumbers convolution_dimension_numbers_;
  // Information used to communicate to the implementation about the algorithm
  // used to produce results. See the documentation on precision_config().
  PrecisionConfig precision_config_;
};
1663
1664 class HloReduceWindowInstruction : public HloInstruction {
1665 public:
1666 explicit HloReduceWindowInstruction(const Shape& shape,
1667 HloInstruction* operand,
1668 HloInstruction* init_value,
1669 const Window& window,
1670 HloComputation* reduce_computation);
1671 explicit HloReduceWindowInstruction(
1672 const Shape& shape, absl::Span<HloInstruction* const> operands,
1673 absl::Span<HloInstruction* const> init_values, const Window& window,
1674 HloComputation* reduce_computation);
window()1675 const Window& window() const override { return window_; }
set_window(const Window & window)1676 void set_window(const Window& window) override { window_ = window; }
1677 // Returns a serialized representation of this instruction.
1678 HloInstructionProto ToProto() const override;
1679 // Returns the number of input arrays (and, consequentially, the number of
1680 // init values) this reduce has.
input_count()1681 int64_t input_count() const { return operand_count() / 2; }
1682 // Returns the input tensors to be reduced.
inputs()1683 absl::Span<HloInstruction* const> inputs() const {
1684 return absl::MakeSpan(operands()).subspan(0, input_count());
1685 }
1686 // Returns the init values of the reduction.
init_values()1687 absl::Span<HloInstruction* const> init_values() const {
1688 return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
1689 }
1690 // Returns the shapes of input tensors to be reduced.
input_shapes()1691 absl::InlinedVector<const Shape*, 2> input_shapes() const {
1692 absl::InlinedVector<const Shape*, 2> shapes;
1693 for (const auto* op : inputs()) {
1694 VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
1695 shapes.push_back(&op->shape());
1696 VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
1697 }
1698 return shapes;
1699 }
1700 // Returns the init values of the reduction.
init_value_shapes()1701 absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
1702 absl::InlinedVector<const Shape*, 2> shapes;
1703 for (const auto* op : init_values()) {
1704 shapes.push_back(&op->shape());
1705 }
1706 return shapes;
1707 }
1708 // Returns the shapes of the reduced output tensors.
output_shapes()1709 absl::InlinedVector<const Shape*, 2> output_shapes() const {
1710 absl::InlinedVector<const Shape*, 2> shapes;
1711 if (shape().IsArray()) {
1712 shapes.push_back(&shape());
1713 } else {
1714 for (const Shape& tuple_element_shape : shape().tuple_shapes()) {
1715 shapes.push_back(&tuple_element_shape);
1716 }
1717 }
1718 return shapes;
1719 }
1720
ClassOf(const HloInstruction * hlo)1721 static bool ClassOf(const HloInstruction* hlo) {
1722 return hlo->opcode() == HloOpcode::kReduceWindow;
1723 }
1724
1725 private:
1726 std::vector<std::string> ExtraAttributesToStringImpl(
1727 const HloPrintOptions& options) const override;
1728 bool IdenticalSlowPath(
1729 const HloInstruction& other,
1730 const std::function<bool(const HloComputation*, const HloComputation*)>&
1731 eq_computations) const override;
1732 // Implementation for non-common logic of CloneWithNewOperands.
1733 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1734 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1735 HloCloneContext* context) const override;
1736
1737 Window window_;
1738 };
1739
// Select-and-scatter instruction (HloOpcode::kSelectAndScatter). Carries a
// window plus `select` and `scatter` computations, which are stored in
// called_computations() at fixed indices.
class HloSelectAndScatterInstruction : public HloInstruction {
 public:
  explicit HloSelectAndScatterInstruction(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);
  const Window& window() const override { return window_; }
  void set_window(const Window& window) override { window_ = window; }
  // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
  // setters should only be called by HloModule or HloComputation methods.
  HloComputation* select() const {
    return called_computations()[kSelectComputationIndex];
  }

  HloComputation* scatter() const {
    return called_computations()[kScatterComputationIndex];
  }

  void set_select(HloComputation* computation) {
    // Don't allow changing the computation for fused instructions so we don't
    // have to recompute called_instructions for the entire fusion instruction.
    CHECK(!IsFused());
    set_called_computation(kSelectComputationIndex, computation);
  }

  void set_scatter(HloComputation* computation) {
    // Don't allow changing the computation for fused instructions so we don't
    // have to recompute called_instructions for the entire fusion instruction.
    CHECK(!IsFused());
    set_called_computation(kScatterComputationIndex, computation);
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` is a select-and-scatter instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSelectAndScatter;
  }

 private:
  // Prints the instruction-specific attributes (the window).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
  // Describes the window over which select/scatter operate.
  Window window_;
};
1791
// Custom-call instruction (HloOpcode::kCustomCall): calls a backend-specific
// target identified by name. Optionally carries convolution-style attributes
// (window, dimension numbers, group counts), constrained operand layouts,
// output/operand aliasing, a literal, and scheduling/API-version hints.
class HloCustomCallInstruction : public HloCallableInstruction {
 public:
  HloCustomCallInstruction(const Shape& shape,
                           absl::Span<HloInstruction* const> operands,
                           absl::string_view custom_call_target,
                           std::string opaque,
                           CustomCallApiVersion api_version);

  // Constructor for a custom call with constrained layout. 'shape' and
  // 'operands_with_layout' must all have layouts.
  HloCustomCallInstruction(const Shape& shape,
                           absl::Span<HloInstruction* const> operands,
                           absl::string_view custom_call_target,
                           std::string opaque,
                           absl::Span<const Shape> operand_shapes_with_layout,
                           CustomCallApiVersion api_version);

  // Constructor for a custom call with a to_apply computation.
  HloCustomCallInstruction(const Shape& shape,
                           absl::Span<HloInstruction* const> operands,
                           HloComputation* to_apply,
                           absl::string_view custom_call_target,
                           std::string opaque,
                           CustomCallApiVersion api_version);

  // Constructor for a custom call with multiple computations.
  HloCustomCallInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloComputation* const> called_computations,
      absl::string_view custom_call_target, std::string opaque,
      CustomCallApiVersion api_version);

  // Returns the window; CHECK-fails if no window has been set.
  const Window& window() const override {
    CHECK(window_ != nullptr);
    return *window_;
  }

  void set_window(const Window& window) override {
    window_ = std::make_unique<Window>(window);
  }

  // Returns the convolution dimension numbers; CHECK-fails if unset.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
    CHECK(convolution_dimension_numbers_ != nullptr);
    return *convolution_dimension_numbers_;
  }

  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums) {
    convolution_dimension_numbers_ =
        std::make_unique<ConvolutionDimensionNumbers>(dnums);
  }
  // TODO(jpienaar): Remove this accessor in the follow up.
  // The opaque payload is stored as the raw backend config string.
  const std::string& opaque() const { return raw_backend_config_string(); }
  const std::string& custom_call_target() const { return custom_call_target_; }
  void set_custom_call_target(absl::string_view target) {
    custom_call_target_ = std::string(target);
  }
  void set_feature_group_count(int64_t feature_group_count) {
    feature_group_count_ = feature_group_count;
  }
  void set_batch_group_count(int64_t batch_group_count) {
    batch_group_count_ = batch_group_count;
  }
  // Sets whether this custom call has a side-effect - by default a custom call
  // has no side-effects.
  void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
    custom_call_has_side_effect_ = custom_call_has_side_effect;
  }
  int64_t feature_group_count() const { return feature_group_count_; }
  int64_t batch_group_count() const { return batch_group_count_; }
  bool custom_call_has_side_effect() const {
    return custom_call_has_side_effect_;
  }
  // Returns padding type used for ops like convolution.
  PaddingType padding_type() const { return padding_type_; }

  void set_padding_type(PaddingType padding_type) {
    padding_type_ = padding_type;
  }

  // Returns the literal associated with this instruction.
  // NOTE(review): call HasLiteral() first — dereferences the optional.
  const Literal& literal() const { return *literal_; }
  // Set the value of literal to a new one.
  void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
  // Returns whether there is literal associated with this instruction.
  bool HasLiteral() const { return literal_.has_value(); }

  const PrecisionConfig& precision_config() const { return precision_config_; }
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns whether the result and operand layouts are constrained.
  bool layout_constrained() const { return layout_constrained_; }

  // Returns the shapes (with layout) of the operands. CHECKs if this custom
  // call does not have constrained layouts.
  const std::vector<Shape>& operand_shapes_with_layout() const {
    CHECK(layout_constrained());
    return operand_shapes_with_layout_;
  }
  // Gets a list of output/operand buffer pairs that alias each other, where the
  // output buffer is represented as a ShapeIndex, and the operand buffer is
  // represented as the operand index and the ShapeIndex. By default this list
  // is empty.
  const std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>&
  output_to_operand_aliasing() const {
    return output_to_operand_aliasing_;
  }
  // Sets the list of output/operand buffer pairs that alias each other.
  void set_output_to_operand_aliasing(
      std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          aliasing) {
    output_to_operand_aliasing_ = std::move(aliasing);
  }
  void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) {
    custom_call_schedule_ = custom_call_schedule;
  }
  CustomCallSchedule custom_call_schedule() const {
    return custom_call_schedule_;
  }
  void set_api_version(CustomCallApiVersion api_version) {
    api_version_ = api_version;
  }
  CustomCallApiVersion api_version() const { return api_version_; }

  // Returns true iff `hlo` is a custom-call instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kCustomCall;
  }

 protected:
  // Name given to computations created on behalf of this custom call.
  std::string default_called_computation_name() const override {
    return "custom_call_computation";
  }

 private:
  // Prints the instruction-specific attributes (target, aliasing, ...).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
  // Name of a global symbol to call.
  std::string custom_call_target_;
  // Describes the window in a windowed operation such as convolution.
  std::unique_ptr<Window> window_;
  // Describes the dimension numbers used for a convolution.
  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
  // The number of feature groups. This is used for grouped convolutions.
  int64_t feature_group_count_;
  int64_t batch_group_count_;
  // Whether the result and operand layouts are constrained.
  bool layout_constrained_;
  // Information used to communicate to the implementation about the algorithm
  // used to produce results for convolution instructions.
  PrecisionConfig precision_config_;
  // Describes the padding type for convolution instructions.
  PaddingType padding_type_;
  // For layout-constrained custom calls, this vector holds the shape with
  // layout for each operand.
  std::vector<Shape> operand_shapes_with_layout_;
  // Whether this custom call has a side-effect.
  bool custom_call_has_side_effect_;
  // A list of output/operand buffer pairs that alias each other. See comment of
  // output_to_operand_aliasing().
  std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
      output_to_operand_aliasing_;
  // Optional literal attached to this custom call; see HasLiteral().
  std::optional<Literal> literal_;
  // A custom-call schedule hint.
  CustomCallSchedule custom_call_schedule_;
  // The version of the API used by the custom call function.
  // TODO(b/189822916): Remove this field when all clients are migrated to the
  // status-returning API.
  CustomCallApiVersion api_version_;
};
1972
// Pad instruction (HloOpcode::kPad). Pads operand 0 with the scalar padding
// value (operand 1) according to `padding_config_`.
class HloPadInstruction : public HloInstruction {
 public:
  explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
                             HloInstruction* padding_value,
                             const PaddingConfig& padding_config);
  // Returns the padding configuration for a pad node.
  const PaddingConfig& padding_config() const { return padding_config_; }
  PaddingConfig* mutable_padding_config() { return &padding_config_; }
  // Returns the operand being padded.
  const HloInstruction* padded_operand() const { return operand(0); }
  HloInstruction* mutable_padded_operand() { return mutable_operand(0); }
  // Returns the padding value.
  const HloInstruction* padding_value() const { return operand(1); }
  HloInstruction* mutable_padding_value() { return mutable_operand(1); }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns true iff `hlo` is a pad instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kPad;
  }

 private:
  // Prints the instruction-specific attributes (the padding config).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction.
  PaddingConfig padding_config_;
};
2010
// Common base class for kDynamicSlice and kDynamicUpdateSlice: both carry a
// run of scalar start-index operands beginning at
// first_index_operand_number().
class HloDynamicIndexInstruction : public HloInstruction {
 public:
  explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
      : HloInstruction(opcode, shape) {}
  // Operand position at which the start indices begin; defined by subclasses.
  virtual int64_t first_index_operand_number() const = 0;

  // Returns a subspan of operands which represent the start indices.
  absl::Span<HloInstruction* const> index_operands() const {
    return absl::MakeSpan(operands()).subspan(first_index_operand_number());
  }

  // Returns the shapes of the index operands.
  std::vector<Shape> index_shapes() const {
    std::vector<Shape> shapes;
    auto indices = index_operands();
    for (const HloInstruction* index : indices) {
      shapes.push_back(index->shape());
    }
    return shapes;
  }

  // Returns true iff `hlo` is a dynamic-slice or dynamic-update-slice.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDynamicSlice ||
           hlo->opcode() == HloOpcode::kDynamicUpdateSlice;
  }
};
2037
// Dynamic-slice instruction (HloOpcode::kDynamicSlice). Extracts a slice of
// fixed sizes from the operand at dynamically-provided start indices.
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
 public:
  explicit HloDynamicSliceInstruction(const Shape& shape,
                                      HloInstruction* operand,
                                      HloInstruction* start_indices,
                                      absl::Span<const int64_t> slice_sizes);
  explicit HloDynamicSliceInstruction(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64_t> slice_sizes);
  // Old methods kept for smooth subclassing transition END.
  // Returns the size of the slice in the given dimension for a dynamic
  // slice node.
  int64_t slice_sizes(int64_t dimension) const {
    return dynamic_slice_sizes_[dimension];
  }
  // Returns the per-dimension slice sizes.
  const std::vector<int64_t>& dynamic_slice_sizes() const {
    return dynamic_slice_sizes_;
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Operand 0 is the sliced array; start indices follow at position 1.
  int64_t first_index_operand_number() const override { return 1; }
  // Returns true iff `hlo` is a dynamic-slice instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDynamicSlice;
  }

 private:
  // Prints the instruction-specific attributes (slice sizes).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  std::vector<int64_t> dynamic_slice_sizes_;
};
2081
// Dynamic-update-slice instruction (HloOpcode::kDynamicUpdateSlice). Writes
// `update` into `operand` at dynamically-provided start indices.
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
 public:
  explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
                                            HloInstruction* operand,
                                            HloInstruction* update,
                                            HloInstruction* start_indices);
  explicit HloDynamicUpdateSliceInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Operands 0 and 1 are the target and the update; start indices follow at
  // position 2.
  int64_t first_index_operand_number() const override { return 2; }

  // Returns true iff `hlo` is a dynamic-update-slice instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDynamicUpdateSlice;
  }
};
2098
// Gather instruction (HloOpcode::kGather). Carries the gather dimension
// numbers, per-dimension slice sizes, and the indices-are-sorted hint.
class HloGatherInstruction : public HloInstruction {
 public:
  explicit HloGatherInstruction(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted);
  // Returns the gather dimension numbers; CHECK-fails if unset.
  const GatherDimensionNumbers& gather_dimension_numbers() const {
    CHECK(gather_dimension_numbers_ != nullptr);
    return *gather_dimension_numbers_;
  }
  // Returns the per-dimension sizes of each gathered slice.
  absl::Span<const int64_t> gather_slice_sizes() const {
    return gather_slice_sizes_;
  }
  // Hint that the start indices are sorted; may enable faster lowerings.
  bool indices_are_sorted() const { return indices_are_sorted_; }
  void set_indices_are_sorted(bool indices_are_sorted) {
    indices_are_sorted_ = indices_are_sorted;
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Creates an instance of GatherDimensionNumbers.
  static GatherDimensionNumbers MakeGatherDimNumbers(
      absl::Span<const int64_t> offset_dims,
      absl::Span<const int64_t> collapsed_slice_dims,
      absl::Span<const int64_t> start_index_map, int64_t index_vector_dim);
  // Returns the dump string of the given gather dimension numbers.
  static std::string GatherDimensionNumbersToString(
      const GatherDimensionNumbers& gather_dimension_numbers);

  // Returns true iff `hlo` is a gather instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kGather;
  }

 private:
  // Prints the instruction-specific attributes (dnums, slice sizes, hint).
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  // Instruction-specific equality; the common path compares shapes/operands.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
  std::vector<int64_t> gather_slice_sizes_;
  bool indices_are_sorted_;
};
2148
// Scatter HLO instruction (HloOpcode::kScatter).
//
// Operand layout, as established by the accessors below, for N scattered
// tensors (operand_count() == 2N + 1):
//   operands [0, N)       : the tensors being updated ("scatter operands")
//   operand  N            : the scatter indices
//   operands [N+1, 2N+1)  : the update values
class HloScatterInstruction : public HloInstruction {
 public:
  // `args` must follow the operand layout described above; `update_computation`
  // combines an existing element with an update value.
  explicit HloScatterInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> args,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);
  // Returns the scatter dimension numbers. CHECK-fails if they were never set.
  const ScatterDimensionNumbers& scatter_dimension_numbers() const {
    CHECK(scatter_dimension_numbers_ != nullptr);
    return *scatter_dimension_numbers_;
  }
  // Whether the caller promises the scatter indices are sorted.
  bool indices_are_sorted() const { return indices_are_sorted_; }
  void set_indices_are_sorted(bool indices_are_sorted) {
    indices_are_sorted_ = indices_are_sorted;
  }
  // Whether the caller promises the scatter indices are unique.
  bool unique_indices() const override { return unique_indices_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;
  // N in the operand layout above; derived from the total operand count.
  int64_t scatter_operand_count() const { return operand_count() / 2; }
  // The first N operands: the tensors being updated.
  absl::Span<HloInstruction* const> scatter_operands() const {
    return absl::MakeConstSpan(operands()).first(scatter_operand_count());
  }
  // The last N operands: the update values.
  absl::Span<HloInstruction* const> scatter_updates() const {
    return absl::MakeConstSpan(operands()).last(scatter_operand_count());
  }
  // Operand N: the scatter indices.
  const HloInstruction* scatter_indices() const {
    return operand(scatter_operand_count());
  }
  HloInstruction* scatter_indices() {
    return mutable_operand(scatter_operand_count());
  }

  // Creates an instance of ScatterDimensionNumbers.
  static ScatterDimensionNumbers MakeScatterDimNumbers(
      absl::Span<const int64_t> update_window_dims,
      absl::Span<const int64_t> inserted_window_dims,
      absl::Span<const int64_t> scatter_dims_to_operand_dims,
      int64_t index_vector_dim);
  // Returns the dump string of the given scatter dimension numbers.
  static std::string ScatterDimensionNumbersToString(
      const ScatterDimensionNumbers& scatter_dimension_numbers);

  // LLVM-style RTTI support: true iff `hlo` is a scatter instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kScatter;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Owned dimension-number metadata; may be null until set (accessor CHECKs).
  std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
  bool indices_are_sorted_;
  bool unique_indices_;
};
2211
// Iota HLO instruction (HloOpcode::kIota), parameterized by the single
// dimension along which the iota values run.
class HloIotaInstruction : public HloInstruction {
 public:
  explicit HloIotaInstruction(const Shape& shape, int64_t iota_dimension);

  // Returns the dimension sizes or numbers associated with this instruction.
  int64_t iota_dimension() const { return iota_dimension_; }
  // Exposes the iota dimension through the generic dimensions() interface as a
  // one-element span viewing the member directly (valid for this object's
  // lifetime).
  absl::Span<const int64_t> dimensions() const override {
    return absl::MakeConstSpan(&iota_dimension_, 1);
  }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // LLVM-style RTTI support: true iff `hlo` is an iota instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kIota;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t iota_dimension_;
};
2242
// Dot HLO instruction (HloOpcode::kDot), carrying the dot dimension numbers
// and the requested numeric precision configuration.
class HloDotInstruction : public HloInstruction {
 public:
  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
                             HloInstruction* rhs,
                             const DotDimensionNumbers& dimension_numbers,
                             const PrecisionConfig& precision_config);

  // Returns data on the dimension numbers used for a dot operation.
  const DotDimensionNumbers& dot_dimension_numbers() const {
    return dot_dimension_numbers_;
  }

  // Returns the information used to tell the implementation information about
  // what sort of precision is requested. The meaning of the field is backend
  // specific. At the moment, it is only supported for kConvolution and kDot.
  // Transformations on one kDot or kConvolution to another will preserve this
  // information. Transformations to other HLOs will not preserve this
  // information but it is presumed that the alternate lowering is strictly
  // superior.
  const PrecisionConfig& precision_config() const { return precision_config_; }
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // LLVM-style RTTI support: true iff `hlo` is a dot instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDot;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes the dimension numbers used for a dot.
  DotDimensionNumbers dot_dimension_numbers_;

  // Information used to communicate to the implementation about the algorithm
  // used to produce results. See the documentation on precision_config().
  PrecisionConfig precision_config_;
};
2293
// Domain HLO instruction (HloOpcode::kDomain). Wraps a single operand and
// stores two pieces of DomainMetadata: one describing the operand side and
// one describing the user side of the domain boundary.
class HloDomainInstruction : public HloInstruction {
 public:
  // Takes ownership of both metadata objects.
  explicit HloDomainInstruction(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Retrieves the operand side metadata of a kDomain instruction.
  // NOTE(review): assumes the metadata was set at construction; dereferences
  // without a null check.
  const DomainMetadata& operand_side_metadata() const {
    return *operand_side_metadata_;
  }
  // Retrieves the user side metadata of a kDomain instruction.
  const DomainMetadata& user_side_metadata() const {
    return *user_side_metadata_;
  }

  // LLVM-style RTTI support: true iff `hlo` is a domain instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kDomain;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  std::unique_ptr<DomainMetadata> operand_side_metadata_;
  std::unique_ptr<DomainMetadata> user_side_metadata_;
};
2332
// GetDimensionSize HLO instruction (HloOpcode::kGetDimensionSize),
// parameterized by the dimension of the operand being queried.
class HloGetDimensionSizeInstruction : public HloInstruction {
 public:
  explicit HloGetDimensionSizeInstruction(const Shape& shape,
                                          HloInstruction* operand,
                                          int64_t dimension);

  // Returns the dimension sizes or numbers associated with this instruction.
  int64_t dimension() const { return dimension_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // LLVM-style RTTI support: true iff `hlo` is a get-dimension-size
  // instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kGetDimensionSize;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t dimension_;
};
2362
// SetDimensionSize HLO instruction (HloOpcode::kSetDimensionSize). Takes an
// operand and a value `val`, parameterized by the dimension being set.
class HloSetDimensionSizeInstruction : public HloInstruction {
 public:
  explicit HloSetDimensionSizeInstruction(const Shape& shape,
                                          HloInstruction* operand,
                                          HloInstruction* val,
                                          int64_t dimension);

  // Returns the dimension sizes or numbers associated with this instruction.
  int64_t dimension() const { return dimension_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // LLVM-style RTTI support: true iff `hlo` is a set-dimension-size
  // instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kSetDimensionSize;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t dimension_;
};
2393
// RngGetAndUpdateState HLO instruction (HloOpcode::kRngGetAndUpdateState),
// parameterized by a mutable `delta` value.
class HloRngGetAndUpdateStateInstruction : public HloInstruction {
 public:
  explicit HloRngGetAndUpdateStateInstruction(const Shape& shape,
                                              int64_t delta);

  // Returns the delta value.
  int64_t delta() const { return delta_; }
  // Overwrites the delta value.
  void set_delta(int64_t delta) { delta_ = delta; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // LLVM-style RTTI support: true iff `hlo` is an rng-get-and-update-state
  // instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRngGetAndUpdateState;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64_t delta_;
};
2423
// RngBitGenerator HLO instruction (HloOpcode::kRngBitGenerator). Takes an RNG
// `state` operand and is parameterized by the random algorithm to use.
class HloRngBitGeneratorInstruction : public HloInstruction {
 public:
  HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
                                RandomAlgorithm algorithm);

  // Returns the random algorithm selected for this instruction.
  RandomAlgorithm algorithm() const { return algorithm_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // LLVM-style RTTI support: true iff `hlo` is an rng-bit-generator
  // instruction.
  static bool ClassOf(const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kRngBitGenerator;
  }

 private:
  std::vector<std::string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  RandomAlgorithm algorithm_;
};
2449
2450 } // namespace xla
2451
2452 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
2453