• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17 
18 #include <deque>
19 #include <string>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_split.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
31 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_module.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/protobuf.h"
39 
40 namespace xla {
41 namespace {
42 
43 using absl::CEscape;
44 using absl::StrAppend;
45 using absl::StrCat;
46 using absl::StrJoin;
47 
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)48 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
49                                        const HloInstruction* operand) {
50   const auto operand_indices = instruction->OperandIndices(operand);
51   return absl::c_all_of(operand_indices, [instruction](int64_t operand_index) {
52     return instruction->IsElementwiseOnOperand(operand_index);
53   });
54 }
55 
PrecisionConfigToString(const PrecisionConfig & precision_config)56 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
57   if (absl::c_all_of(
58           precision_config.operand_precision(), [](int32_t precision) {
59             return static_cast<PrecisionConfig::Precision>(precision) ==
60                    PrecisionConfig::DEFAULT;
61           })) {
62     return "";
63   }
64 
65   return StrCat(
66       "operand_precision={",
67       StrJoin(
68           precision_config.operand_precision(), ",",
69           [](string* out, int32_t precision) {
70             CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
71             StrAppend(out,
72                       PrecisionToString(
73                           static_cast<PrecisionConfig::Precision>(precision)));
74           }),
75       "}");
76 }
77 }  // namespace
78 
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64_t feature_index)79 HloBatchNormInstruction::HloBatchNormInstruction(
80     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
81     HloInstruction* scale, float epsilon, int64_t feature_index)
82     : HloInstruction(opcode, shape),
83       epsilon_(epsilon),
84       feature_index_(feature_index) {
85   AppendOperand(operand);
86   AppendOperand(scale);
87 }
88 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const89 bool HloBatchNormInstruction::IdenticalSlowPath(
90     const HloInstruction& other,
91     const std::function<bool(const HloComputation*, const HloComputation*)>&
92         eq_computations) const {
93   const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
94   return feature_index() == casted_other.feature_index() &&
95          epsilon() == casted_other.epsilon();
96 }
97 
ToProto() const98 HloInstructionProto HloBatchNormInstruction::ToProto() const {
99   HloInstructionProto proto = HloInstruction::ToProto();
100   proto.set_epsilon(epsilon_);
101   proto.set_feature_index(feature_index_);
102   return proto;
103 }
104 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const105 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
106     const HloPrintOptions& options) const {
107   return {StrCat("epsilon=", epsilon()),
108           StrCat("feature_index=", feature_index())};
109 }
110 
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)111 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
112     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
113     HloInstruction* offset, float epsilon, int64_t feature_index)
114     : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
115                               scale, epsilon, feature_index) {
116   AppendOperand(offset);
117 }
118 
119 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const120 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
121     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
122     HloCloneContext* context) const {
123   CHECK_EQ(new_operands.size(), 3);
124   return absl::make_unique<HloBatchNormTrainingInstruction>(
125       shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
126       feature_index());
127 }
128 
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)129 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
130     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
131     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
132     float epsilon, int64_t feature_index)
133     : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
134                               scale, epsilon, feature_index) {
135   AppendOperand(offset);
136   AppendOperand(mean);
137   AppendOperand(variance);
138 }
139 
140 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const141 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
142     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
143     HloCloneContext* context) const {
144   CHECK_EQ(new_operands.size(), 5);
145   return absl::make_unique<HloBatchNormInferenceInstruction>(
146       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
147       new_operands[4], epsilon(), feature_index());
148 }
149 
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)150 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
151     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
152     HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
153     float epsilon, int64_t feature_index)
154     : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
155                               epsilon, feature_index) {
156   AppendOperand(mean);
157   AppendOperand(variance);
158   AppendOperand(grad_output);
159 }
160 
161 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const162 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
163     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
164     HloCloneContext* context) const {
165   CHECK_EQ(new_operands.size(), 5);
166   return absl::make_unique<HloBatchNormGradInstruction>(
167       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
168       new_operands[4], epsilon(), feature_index());
169 }
170 
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)171 HloFftInstruction::HloFftInstruction(const Shape& shape,
172                                      HloInstruction* operand, FftType fft_type,
173                                      absl::Span<const int64> fft_length)
174     : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
175   fft_length_.assign(fft_length.begin(), fft_length.end());
176   AppendOperand(operand);
177 }
178 
ToProto() const179 HloInstructionProto HloFftInstruction::ToProto() const {
180   HloInstructionProto proto = HloInstruction::ToProto();
181   proto.set_fft_type(fft_type_);
182   for (int64_t fft_len : fft_length_) {
183     proto.add_fft_length(fft_len);
184   }
185   return proto;
186 }
187 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const188 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
189     const HloPrintOptions& options) const {
190   return {StrCat("fft_type=", FftType_Name(fft_type())),
191           StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
192 }
193 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const194 bool HloFftInstruction::IdenticalSlowPath(
195     const HloInstruction& other,
196     const std::function<bool(const HloComputation*, const HloComputation*)>&
197         eq_computations) const {
198   const auto& casted_other = static_cast<const HloFftInstruction&>(other);
199   return fft_type() == casted_other.fft_type() &&
200          fft_length() == casted_other.fft_length();
201 }
202 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const203 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
204     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
205     HloCloneContext* context) const {
206   CHECK_EQ(new_operands.size(), 1);
207   return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
208                                               fft_length_);
209 }
210 
HloCopyStartInstruction(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)211 HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
212                                                  HloInstruction* operand,
213                                                  bool is_cross_program_prefetch)
214     : HloInstruction(HloOpcode::kCopyStart, shape),
215       is_cross_program_prefetch_(is_cross_program_prefetch) {
216   AppendOperand(operand);
217 }
218 
ToProto() const219 HloInstructionProto HloCopyStartInstruction::ToProto() const {
220   HloInstructionProto proto = HloInstruction::ToProto();
221   proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
222   return proto;
223 }
224 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const225 std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
226     const HloPrintOptions& options) const {
227   std::vector<string> result;
228   if (is_cross_program_prefetch()) {
229     result.push_back("is_cross_program_prefetch=true");
230   }
231   return result;
232 }
233 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const234 bool HloCopyStartInstruction::IdenticalSlowPath(
235     const HloInstruction& other,
236     const std::function<bool(const HloComputation*, const HloComputation*)>&
237         eq_computations) const {
238   const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
239   return is_cross_program_prefetch() ==
240          casted_other.is_cross_program_prefetch();
241 }
242 
243 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const244 HloCopyStartInstruction::CloneWithNewOperandsImpl(
245     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
246     HloCloneContext* context) const {
247   CHECK_EQ(new_operands.size(), 1);
248   return absl::make_unique<HloCopyStartInstruction>(
249       shape, new_operands[0], is_cross_program_prefetch());
250 }
251 
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)252 HloCompareInstruction::HloCompareInstruction(
253     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
254     ComparisonDirection direction, absl::optional<Comparison::Type> type)
255     : HloInstruction(HloOpcode::kCompare, shape),
256       compare_(direction, type ? (*type)
257                                : Comparison::DefaultComparisonType(
258                                      lhs->shape().element_type())) {
259   AppendOperand(lhs);
260   AppendOperand(rhs);
261 }
262 
ToProto() const263 HloInstructionProto HloCompareInstruction::ToProto() const {
264   HloInstructionProto proto = HloInstruction::ToProto();
265   proto.set_comparison_direction(
266       ComparisonDirectionToString(compare_.GetDirection()));
267   proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
268   return proto;
269 }
270 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const271 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
272     const HloPrintOptions& options) const {
273   std::vector<string> result;
274   result.push_back(
275       StrCat("direction=", ComparisonDirectionToString(direction())));
276   if (compare_.GetType() !=
277       Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
278     result.push_back(
279         StrCat("type=", ComparisonTypeToString(compare_.GetType())));
280   }
281   return result;
282 }
283 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const284 bool HloCompareInstruction::IdenticalSlowPath(
285     const HloInstruction& other,
286     const std::function<bool(const HloComputation*, const HloComputation*)>&
287         eq_computations) const {
288   const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
289   return direction() == casted_other.direction();
290 }
291 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const292 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
293     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
294     HloCloneContext* context) const {
295   CHECK_EQ(new_operands.size(), 2);
296   return absl::make_unique<HloCompareInstruction>(
297       shape, new_operands[0], new_operands[1], direction(), type());
298 }
299 
300 namespace {
301 
302 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
303 // of "key=value" attribute strings generically, using protocol buffer
304 // reflection.
305 //
306 // Currently implements a small subset of cases; feel free to add more as
307 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)308 std::vector<string> AttributeProtoToStringVector(
309     const tensorflow::protobuf::Message& message) {
310   const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
311   std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
312   reflection->ListFields(message, &fields);
313 
314   std::vector<string> output;
315   for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
316     string s = absl::StrCat(field->name(), "=");
317     CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
318     switch (field->type()) {
319       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
320         bool val = reflection->GetBool(message, field);
321         absl::StrAppend(&s, val ? "true" : "false");
322         break;
323       }
324       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
325         const tensorflow::protobuf::EnumValueDescriptor* evd =
326             reflection->GetEnum(message, field);
327         absl::StrAppend(&s, evd->name());
328         break;
329       }
330       default:
331         LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
332     }
333     output.push_back(std::move(s));
334   }
335   return output;
336 }
337 
338 }  // namespace
339 
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)340 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
341     const Shape& shape, HloInstruction* a, HloInstruction* b,
342     const TriangularSolveOptions& options)
343     : HloInstruction(HloOpcode::kTriangularSolve, shape),
344       triangular_solve_options_(options) {
345   AppendOperand(a);
346   AppendOperand(b);
347 }
348 
ToProto() const349 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
350   HloInstructionProto proto = HloInstruction::ToProto();
351   *proto.mutable_triangular_solve_options() = triangular_solve_options_;
352   return proto;
353 }
354 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const355 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
356     const HloPrintOptions& options) const {
357   return AttributeProtoToStringVector(triangular_solve_options_);
358 }
359 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const360 bool HloTriangularSolveInstruction::IdenticalSlowPath(
361     const HloInstruction& other,
362     const std::function<bool(const HloComputation*, const HloComputation*)>&
363         eq_computations) const {
364   const auto& casted_other =
365       static_cast<const HloTriangularSolveInstruction&>(other);
366   const auto& options = triangular_solve_options();
367   const auto& other_options = casted_other.triangular_solve_options();
368 
369   return options.left_side() == other_options.left_side() &&
370          options.lower() == other_options.lower() &&
371          options.unit_diagonal() == other_options.unit_diagonal() &&
372          options.transpose_a() == other_options.transpose_a();
373 }
374 
375 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const376 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
377     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
378     HloCloneContext* context) const {
379   CHECK_EQ(new_operands.size(), 2);
380   return absl::make_unique<HloTriangularSolveInstruction>(
381       shape, new_operands[0], new_operands[1], triangular_solve_options());
382 }
383 
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)384 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
385                                                HloInstruction* a,
386                                                const CholeskyOptions& options)
387     : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
388   AppendOperand(a);
389 }
390 
ToProto() const391 HloInstructionProto HloCholeskyInstruction::ToProto() const {
392   HloInstructionProto proto = HloInstruction::ToProto();
393   *proto.mutable_cholesky_options() = cholesky_options_;
394   return proto;
395 }
396 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const397 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
398     const HloPrintOptions& options) const {
399   return AttributeProtoToStringVector(cholesky_options_);
400 }
401 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const402 bool HloCholeskyInstruction::IdenticalSlowPath(
403     const HloInstruction& other,
404     const std::function<bool(const HloComputation*, const HloComputation*)>&
405         eq_computations) const {
406   const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
407   const auto& options = cholesky_options();
408   const auto& other_options = casted_other.cholesky_options();
409 
410   return options.lower() == other_options.lower();
411 }
412 
413 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const414 HloCholeskyInstruction::CloneWithNewOperandsImpl(
415     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
416     HloCloneContext* context) const {
417   CHECK_EQ(new_operands.size(), 1);
418   return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
419                                                    cholesky_options());
420 }
421 
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const absl::optional<int64> & channel_id)422 HloChannelInstruction::HloChannelInstruction(
423     HloOpcode opcode, const Shape& shape,
424     const absl::optional<int64>& channel_id)
425     : HloInstruction(opcode, shape), channel_id_(channel_id) {}
426 
set_channel_id(const absl::optional<int64> & channel_id)427 void HloChannelInstruction::set_channel_id(
428     const absl::optional<int64>& channel_id) {
429   channel_id_ = channel_id;
430 }
431 
ToProto() const432 HloInstructionProto HloChannelInstruction::ToProto() const {
433   HloInstructionProto proto = HloInstruction::ToProto();
434   if (channel_id_) {
435     CHECK_GT(channel_id_.value(), 0)
436         << "Non-positive channel id is equivalent to no channel id";
437     proto.set_channel_id(*channel_id_);
438   }
439   return proto;
440 }
441 
ExtraAttributesToStringImpl(const HloPrintOptions &) const442 std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
443     const HloPrintOptions& /*options*/) const {
444   std::vector<string> result;
445   if (channel_id_) {
446     result.push_back(StrCat("channel_id=", *channel_id_));
447   }
448   return result;
449 }
450 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const451 bool HloChannelInstruction::IdenticalSlowPath(
452     const HloInstruction& other,
453     const std::function<bool(const HloComputation*, const HloComputation*)>&
454         eq_computations) const {
455   if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) {
456     return false;
457   }
458   const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
459   return channel_id() == casted_other.channel_id();
460 }
461 
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64_t channel_id,bool is_host_transfer)462 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
463                                                const Shape& shape,
464                                                int64_t channel_id,
465                                                bool is_host_transfer)
466     : HloChannelInstruction(opcode, shape, channel_id),
467       is_host_transfer_(is_host_transfer) {}
468 
ToProto() const469 HloInstructionProto HloSendRecvInstruction::ToProto() const {
470   HloInstructionProto proto = HloChannelInstruction::ToProto();
471   proto.set_is_host_transfer(is_host_transfer_);
472   return proto;
473 }
474 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const475 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
476     const HloPrintOptions& options) const {
477   std::vector<string> attrs =
478       HloChannelInstruction::ExtraAttributesToStringImpl(options);
479   if (is_host_transfer()) {
480     attrs.push_back("is_host_transfer=true");
481   }
482   return attrs;
483 }
484 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const485 bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
486     const HloInstruction& other,
487     const std::function<bool(const HloComputation*, const HloComputation*)>&
488         eq_computations) const {
489   // Not yet supported.
490   return false;
491 }
492 
493 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)494 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
495                                        HloInstruction* token,
496                                        int64_t channel_id,
497                                        bool is_host_transfer)
498     : HloSendRecvInstruction(
499           HloOpcode::kSend,
500           ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
501                                      ShapeUtil::MakeShape(U32, {}),
502                                      ShapeUtil::MakeTokenShape()}),
503           channel_id, is_host_transfer) {
504   AppendOperand(operand);
505   AppendOperand(token);
506 }
507 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const508 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
509     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
510     HloCloneContext* context) const {
511   CHECK_EQ(new_operands.size(), 2);
512   return absl::make_unique<HloSendInstruction>(
513       new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
514 }
515 
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)516 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
517                                                bool is_host_transfer)
518     : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
519                              CHECK_NOTNULL(operand)->channel_id().value(),
520                              is_host_transfer) {
521   AppendOperand(operand);
522 }
523 
524 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const525 HloSendDoneInstruction::CloneWithNewOperandsImpl(
526     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
527     HloCloneContext* context) const {
528   CHECK_EQ(new_operands.size(), 1);
529   return absl::make_unique<HloSendDoneInstruction>(
530       Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
531 }
532 
533 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)534 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
535                                        HloInstruction* token,
536                                        int64_t channel_id,
537                                        bool is_host_transfer)
538     : HloSendRecvInstruction(
539           HloOpcode::kRecv,
540           ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
541                                      ShapeUtil::MakeTokenShape()}),
542           channel_id, is_host_transfer) {
543   AppendOperand(token);
544 }
545 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const546 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
547     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
548     HloCloneContext* context) const {
549   CHECK_EQ(new_operands.size(), 1);
550   return absl::make_unique<HloRecvInstruction>(
551       ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
552       is_host_transfer());
553 }
554 
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)555 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
556                                                bool is_host_transfer)
557     : HloSendRecvInstruction(
558           HloOpcode::kRecvDone,
559           ShapeUtil::MakeTupleShape(
560               {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
561                ShapeUtil::MakeTokenShape()}),
562           CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
563   AppendOperand(operand);
564 }
565 
566 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const567 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
568     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
569     HloCloneContext* context) const {
570   CHECK_EQ(new_operands.size(), 1);
571   return absl::make_unique<HloRecvDoneInstruction>(
572       Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
573 }
574 
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)575 HloCollectiveInstruction::HloCollectiveInstruction(
576     HloOpcode opcode, const Shape& shape,
577     absl::Span<HloInstruction* const> operands,
578     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
579     const absl::optional<int64>& channel_id)
580     : HloChannelInstruction(opcode, shape, channel_id),
581       replica_groups_(SpanToVector(replica_groups)),
582       constrain_layout_(constrain_layout) {
583   for (auto operand : operands) {
584     AppendOperand(operand);
585   }
586 }
587 
ToProto() const588 HloInstructionProto HloCollectiveInstruction::ToProto() const {
589   HloInstructionProto proto = HloChannelInstruction::ToProto();
590   *proto.mutable_replica_groups() = {replica_groups_.begin(),
591                                      replica_groups_.end()};
592   proto.set_constrain_layout(constrain_layout_);
593   return proto;
594 }
595 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const596 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
597     const HloPrintOptions& options) const {
598   std::vector<string> result =
599       HloChannelInstruction::ExtraAttributesToStringImpl(options);
600   result.push_back(
601       StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
602   if (constrain_layout_) {
603     result.push_back("constrain_layout=true");
604   }
605   return result;
606 }
607 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const608 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
609     const HloInstruction& other,
610     const std::function<bool(const HloComputation*, const HloComputation*)>&
611         eq_computations) const {
612   const auto& casted_other =
613       static_cast<const HloCollectiveInstruction&>(other);
614   return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
615              other, eq_computations) &&
616          constrain_layout() == casted_other.constrain_layout() &&
617          absl::c_equal(replica_groups(), casted_other.replica_groups(),
618                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
619                          return absl::c_equal(a.replica_ids(), b.replica_ids());
620                        });
621 }
622 
HloAllGatherInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)623 HloAllGatherInstruction::HloAllGatherInstruction(
624     HloOpcode opcode, const Shape& shape,
625     absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
626     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
627     const absl::optional<int64>& channel_id, bool use_global_device_ids)
628     : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
629                                constrain_layout, channel_id),
630       all_gather_dimension_(all_gather_dimension),
631       use_global_device_ids_(use_global_device_ids) {}
632 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const633 std::vector<string> HloAllGatherInstruction::ExtraAttributesToStringImpl(
634     const HloPrintOptions& options) const {
635   std::vector<string> result =
636       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
637   result.push_back(StrCat("dimensions={", all_gather_dimension_, "}"));
638   if (use_global_device_ids_) {
639     result.push_back("use_global_device_ids=true");
640   }
641   return result;
642 }
643 
644 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const645 HloAllGatherInstruction::CloneWithNewOperandsImpl(
646     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
647     HloCloneContext* /*context*/) const {
648   return absl::make_unique<HloAllGatherInstruction>(
649       opcode(), shape, new_operands, all_gather_dimension(), replica_groups(),
650       constrain_layout(), channel_id(), use_global_device_ids());
651 }
652 
ToProto() const653 HloInstructionProto HloAllGatherInstruction::ToProto() const {
654   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
655   proto.add_dimensions(all_gather_dimension_);
656   proto.set_use_global_device_ids(use_global_device_ids_);
657   return proto;
658 }
659 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const660 bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues(
661     const HloInstruction& other,
662     const std::function<bool(const HloComputation*, const HloComputation*)>&
663         eq_computations) const {
664   const auto& casted_other = static_cast<const HloAllGatherInstruction&>(other);
665   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
666              other, eq_computations) &&
667          all_gather_dimension_ == casted_other.all_gather_dimension() &&
668          use_global_device_ids() == casted_other.use_global_device_ids();
669 }
670 
HloAllReduceInstructionBase(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)671 HloAllReduceInstructionBase::HloAllReduceInstructionBase(
672     HloOpcode opcode, const Shape& shape,
673     absl::Span<HloInstruction* const> operands,
674     HloComputation* reduce_computation,
675     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
676     const absl::optional<int64>& channel_id, bool use_global_device_ids)
677     : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
678                                constrain_layout, channel_id),
679       use_global_device_ids_(use_global_device_ids) {
680   AppendComputation(reduce_computation);
681 }
682 
ToProto() const683 HloInstructionProto HloAllReduceInstructionBase::ToProto() const {
684   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
685   proto.set_use_global_device_ids(use_global_device_ids_);
686   return proto;
687 }
688 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const689 std::vector<string> HloAllReduceInstructionBase::ExtraAttributesToStringImpl(
690     const HloPrintOptions& options) const {
691   std::vector<string> result =
692       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
693   if (use_global_device_ids_) {
694     result.push_back("use_global_device_ids=true");
695   }
696   return result;
697 }
698 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const699 bool HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
700     const HloInstruction& other,
701     const std::function<bool(const HloComputation*, const HloComputation*)>&
702         eq_computations) const {
703   if (opcode() != other.opcode()) {
704     return false;
705   }
706   const auto& casted_other =
707       static_cast<const HloAllReduceInstructionBase&>(other);
708   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
709              other, eq_computations) &&
710          constrain_layout() == casted_other.constrain_layout() &&
711          use_global_device_ids() == casted_other.use_global_device_ids() &&
712          eq_computations(to_apply(), casted_other.to_apply());
713 }
714 
IsNoop() const715 bool HloAllReduceInstruction::IsNoop() const {
716   for (const auto& replica_group : replica_groups()) {
717     if (replica_group.replica_ids().size() != 1) {
718       return false;
719     }
720   }
721   return !channel_id();
722 }
723 
724 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const725 HloAllReduceInstruction::CloneWithNewOperandsImpl(
726     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
727     HloCloneContext* /*context*/) const {
728   return absl::make_unique<HloAllReduceInstruction>(
729       opcode(), shape, new_operands, to_apply(), replica_groups(),
730       constrain_layout(), channel_id(), use_global_device_ids());
731 }
732 
HloReduceScatterInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)733 HloReduceScatterInstruction::HloReduceScatterInstruction(
734     const Shape& shape, absl::Span<HloInstruction* const> operands,
735     HloComputation* reduce_computation,
736     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
737     const absl::optional<int64>& channel_id, bool use_global_device_ids,
738     int64_t scatter_dimension)
739     : HloAllReduceInstructionBase(
740           HloOpcode::kReduceScatter, shape, operands, reduce_computation,
741           replica_groups, constrain_layout, channel_id, use_global_device_ids),
742       scatter_dimension_(scatter_dimension) {}
743 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const744 std::vector<string> HloReduceScatterInstruction::ExtraAttributesToStringImpl(
745     const HloPrintOptions& options) const {
746   std::vector<string> result =
747       HloAllReduceInstructionBase::ExtraAttributesToStringImpl(options);
748   result.push_back(StrCat("dimensions={", scatter_dimension_, "}"));
749   return result;
750 }
751 
ToProto() const752 HloInstructionProto HloReduceScatterInstruction::ToProto() const {
753   HloInstructionProto proto = HloAllReduceInstructionBase::ToProto();
754   proto.add_dimensions(scatter_dimension_);
755   return proto;
756 }
757 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const758 bool HloReduceScatterInstruction::IdenticalSlowPathIgnoringChannelIdValues(
759     const HloInstruction& other,
760     const std::function<bool(const HloComputation*, const HloComputation*)>&
761         eq_computations) const {
762   const auto& casted_other =
763       static_cast<const HloReduceScatterInstruction&>(other);
764   return HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
765              other, eq_computations) &&
766          scatter_dimension_ == casted_other.scatter_dimension();
767 }
768 
769 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const770 HloReduceScatterInstruction::CloneWithNewOperandsImpl(
771     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
772     HloCloneContext* /*context*/) const {
773   return absl::make_unique<HloReduceScatterInstruction>(
774       shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
775       channel_id(), use_global_device_ids(), scatter_dimension());
776 }
777 
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)778 HloAllToAllInstruction::HloAllToAllInstruction(
779     const Shape& shape, absl::Span<HloInstruction* const> operands,
780     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
781     const absl::optional<int64>& channel_id,
782     const absl::optional<int64>& split_dimension)
783     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
784                                replica_groups, constrain_layout, channel_id),
785       split_dimension_(split_dimension) {}
786 
787 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const788 HloAllToAllInstruction::CloneWithNewOperandsImpl(
789     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
790     HloCloneContext* /*context*/) const {
791   return absl::make_unique<HloAllToAllInstruction>(
792       shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
793       split_dimension());
794 }
795 
ToProto() const796 HloInstructionProto HloAllToAllInstruction::ToProto() const {
797   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
798   if (split_dimension_) {
799     proto.add_dimensions(*split_dimension_);
800   }
801   return proto;
802 }
803 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const804 std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
805     const HloPrintOptions& options) const {
806   std::vector<string> result =
807       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
808   if (split_dimension_) {
809     result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
810   }
811   return result;
812 }
813 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const814 bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
815     const HloInstruction& other,
816     const std::function<bool(const HloComputation*, const HloComputation*)>&
817         eq_computations) const {
818   const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
819   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
820              other, eq_computations) &&
821          split_dimension_ == casted_other.split_dimension();
822 }
823 
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)824 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
825     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
826     const std::vector<std::pair<int64, int64>>& source_target_pairs,
827     const absl::optional<int64>& channel_id)
828     : HloChannelInstruction(opcode, shape, channel_id),
829       source_target_pairs_(source_target_pairs) {
830   AppendOperand(operand);
831 }
832 
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const absl::optional<int64_t> & channel_id)833 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
834     HloOpcode opcode, const Shape& shape, HloInstruction* input,
835     HloInstruction* output, HloInstruction* input_start_indices,
836     HloInstruction* output_start_indices,
837     absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
838     absl::Span<const std::vector<int64_t>> slice_sizes,
839     const absl::optional<int64_t>& channel_id)
840     : HloChannelInstruction(opcode, shape, channel_id),
841       source_target_pairs_(source_target_pairs.begin(),
842                            source_target_pairs.end()),
843       slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
844   AppendOperand(input);
845   AppendOperand(output);
846   AppendOperand(input_start_indices);
847   AppendOperand(output_start_indices);
848 }
849 
ToProto() const850 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
851   HloInstructionProto proto = HloChannelInstruction::ToProto();
852   for (const auto& pair : source_target_pairs()) {
853     auto* proto_pair = proto.add_source_target_pairs();
854     proto_pair->set_source(pair.first);
855     proto_pair->set_target(pair.second);
856   }
857   for (const auto& slice_size : dynamic_slice_sizes_list()) {
858     for (const auto& dimension_slice_size : slice_size) {
859       proto.add_dynamic_slice_sizes(dimension_slice_size);
860     }
861   }
862   return proto;
863 }
864 
865 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const866 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
867     const HloPrintOptions& options) const {
868   std::vector<string> result =
869       HloChannelInstruction::ExtraAttributesToStringImpl(options);
870   {
871     std::vector<string> strs;
872     for (const auto& pair : source_target_pairs()) {
873       strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
874     }
875     result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
876   }
877   if (!dynamic_slice_sizes_list().empty()) {
878     std::vector<string> strs;
879     for (const auto& slice_sizes : dynamic_slice_sizes_list()) {
880       strs.push_back(StrCat("{", StrJoin(slice_sizes, ","), "}"));
881     }
882     result.push_back(StrCat("slice_sizes={", StrJoin(strs, ","), "}"));
883   }
884   return result;
885 }
886 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const887 bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues(
888     const HloInstruction& other,
889     const std::function<bool(const HloComputation*, const HloComputation*)>&
890         eq_computations) const {
891   if (opcode() != other.opcode()) {
892     return false;
893   }
894   const auto& casted_other =
895       static_cast<const HloCollectivePermuteInstruction&>(other);
896   return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
897              other, eq_computations) &&
898          absl::c_equal(
899              source_target_pairs(), casted_other.source_target_pairs(),
900              [](const std::pair<int64, int64>& a,
901                 const std::pair<int64, int64>& b) { return a == b; }) &&
902          absl::c_equal(
903              dynamic_slice_sizes_list(),
904              casted_other.dynamic_slice_sizes_list(),
905              [](const std::vector<int64>& a, const std::vector<int64>& b) {
906                return absl::c_equal(a, b);
907              });
908 }
909 
910 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const911 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
912     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
913     HloCloneContext* /*context*/) const {
914   if (dynamic_slice_sizes_list().empty()) {
915     return absl::make_unique<HloCollectivePermuteInstruction>(
916         opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
917   } else {
918     return absl::make_unique<HloCollectivePermuteInstruction>(
919         opcode(), shape, new_operands[0], new_operands[1], new_operands[2],
920         new_operands[3], source_target_pairs(), dynamic_slice_sizes_list(),
921         channel_id());
922   }
923 }
924 
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)925 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
926                                              HloInstruction* operand,
927                                              absl::Span<const int64> dimensions)
928     : HloInstruction(HloOpcode::kReverse, shape),
929       dimensions_(dimensions.begin(), dimensions.end()) {
930   AppendOperand(operand);
931 }
932 
ToProto() const933 HloInstructionProto HloReverseInstruction::ToProto() const {
934   HloInstructionProto proto = HloInstruction::ToProto();
935   for (int64_t dimension : dimensions_) {
936     proto.add_dimensions(dimension);
937   }
938   return proto;
939 }
940 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const941 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
942     const HloPrintOptions& options) const {
943   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
944 }
945 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const946 bool HloReverseInstruction::IdenticalSlowPath(
947     const HloInstruction& other,
948     const std::function<bool(const HloComputation*, const HloComputation*)>&
949         eq_computations) const {
950   const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
951   return dimensions() == casted_other.dimensions();
952 }
953 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const954 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
955     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
956     HloCloneContext* context) const {
957   CHECK_EQ(new_operands.size(), 1);
958   return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
959                                                   dimensions());
960 }
961 
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)962 HloConcatenateInstruction::HloConcatenateInstruction(
963     const Shape& shape, absl::Span<HloInstruction* const> operands,
964     int64_t dimension)
965     : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
966   for (auto operand : operands) {
967     AppendOperand(operand);
968   }
969 }
970 
ToProto() const971 HloInstructionProto HloConcatenateInstruction::ToProto() const {
972   HloInstructionProto proto = HloInstruction::ToProto();
973   for (int64_t dimension : dimensions_) {
974     proto.add_dimensions(dimension);
975   }
976   return proto;
977 }
978 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const979 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
980     const HloPrintOptions& options) const {
981   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
982 }
983 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const984 bool HloConcatenateInstruction::IdenticalSlowPath(
985     const HloInstruction& other,
986     const std::function<bool(const HloComputation*, const HloComputation*)>&
987         eq_computations) const {
988   const auto& casted_other =
989       static_cast<const HloConcatenateInstruction&>(other);
990   return dimensions() == casted_other.dimensions();
991 }
992 
993 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const994 HloConcatenateInstruction::CloneWithNewOperandsImpl(
995     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
996     HloCloneContext* context) const {
997   return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
998                                                       dimensions(0));
999 }
1000 
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1001 HloReduceInstruction::HloReduceInstruction(
1002     const Shape& shape, absl::Span<HloInstruction* const> args,
1003     absl::Span<const int64> dimensions_to_reduce,
1004     HloComputation* reduce_computation)
1005     : HloInstruction(HloOpcode::kReduce, shape),
1006       dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
1007   for (HloInstruction* arg : args) {
1008     AppendOperand(arg);
1009   }
1010   AppendComputation(reduce_computation);
1011 }
1012 
ToProto() const1013 HloInstructionProto HloReduceInstruction::ToProto() const {
1014   HloInstructionProto proto = HloInstruction::ToProto();
1015   for (int64_t dimension : dimensions_) {
1016     proto.add_dimensions(dimension);
1017   }
1018   return proto;
1019 }
1020 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1021 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
1022     const HloPrintOptions& options) const {
1023   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1024 }
1025 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1026 bool HloReduceInstruction::IdenticalSlowPath(
1027     const HloInstruction& other,
1028     const std::function<bool(const HloComputation*, const HloComputation*)>&
1029         eq_computations) const {
1030   const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
1031   // Reduction results are determined by the reduction dimension and the
1032   // reduction computation.
1033   return dimensions() == casted_other.dimensions() &&
1034          eq_computations(to_apply(), casted_other.to_apply());
1035 }
1036 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1037 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
1038     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1039     HloCloneContext* context) const {
1040   CHECK_EQ(new_operands.size() % 2, 0);
1041   return absl::make_unique<HloReduceInstruction>(shape, new_operands,
1042                                                  dimensions(), to_apply());
1043 }
1044 
HloSortInstruction(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1045 HloSortInstruction::HloSortInstruction(
1046     const Shape& shape, int64_t dimension,
1047     absl::Span<HloInstruction* const> operands, HloComputation* compare,
1048     bool is_stable)
1049     : HloInstruction(HloOpcode::kSort, shape),
1050       dimensions_({dimension}),
1051       is_stable_(is_stable) {
1052   for (auto* value : operands) {
1053     AppendOperand(value);
1054   }
1055   AppendComputation(compare);
1056 }
1057 
ToProto() const1058 HloInstructionProto HloSortInstruction::ToProto() const {
1059   HloInstructionProto proto = HloInstruction::ToProto();
1060   for (int64_t dimension : dimensions_) {
1061     proto.add_dimensions(dimension);
1062   }
1063   proto.set_is_stable(is_stable());
1064   return proto;
1065 }
1066 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1067 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
1068     const HloPrintOptions& options) const {
1069   std::vector<string> attrs;
1070   attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
1071   if (is_stable()) {
1072     attrs.push_back("is_stable=true");
1073   }
1074   return attrs;
1075 }
1076 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1077 bool HloSortInstruction::IdenticalSlowPath(
1078     const HloInstruction& other,
1079     const std::function<bool(const HloComputation*, const HloComputation*)>&
1080         eq_computations) const {
1081   const auto& casted_other = static_cast<const HloSortInstruction&>(other);
1082   if (dimensions() != casted_other.dimensions()) {
1083     return false;
1084   }
1085   if (is_stable() != casted_other.is_stable()) {
1086     return false;
1087   }
1088   return eq_computations(to_apply(), other.to_apply());
1089 }
1090 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1091 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
1092     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1093     HloCloneContext* context) const {
1094   return absl::make_unique<HloSortInstruction>(
1095       shape, dimensions(0), new_operands, to_apply(), is_stable());
1096 }
1097 
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1098 HloTransposeInstruction::HloTransposeInstruction(
1099     const Shape& shape, HloInstruction* operand,
1100     absl::Span<const int64> dimensions)
1101     : HloInstruction(HloOpcode::kTranspose, shape),
1102       dimensions_(dimensions.begin(), dimensions.end()) {
1103   AppendOperand(operand);
1104 }
1105 
IsRank2Transpose() const1106 bool HloTransposeInstruction::IsRank2Transpose() const {
1107   return dimensions() == std::vector<int64>({1, 0}) &&
1108          shape().dimensions_size() == 2 &&
1109          std::equal(shape().dimensions().begin(), shape().dimensions().end(),
1110                     operand(0)->shape().dimensions().rbegin());
1111 }
1112 
ToProto() const1113 HloInstructionProto HloTransposeInstruction::ToProto() const {
1114   HloInstructionProto proto = HloInstruction::ToProto();
1115   for (int64_t dimension : dimensions_) {
1116     proto.add_dimensions(dimension);
1117   }
1118   return proto;
1119 }
1120 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1121 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
1122     const HloPrintOptions& options) const {
1123   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1124 }
1125 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1126 bool HloTransposeInstruction::IdenticalSlowPath(
1127     const HloInstruction& other,
1128     const std::function<bool(const HloComputation*, const HloComputation*)>&
1129         eq_computations) const {
1130   const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
1131   return dimensions() == casted_other.dimensions();
1132 }
1133 
1134 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1135 HloTransposeInstruction::CloneWithNewOperandsImpl(
1136     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1137     HloCloneContext* context) const {
1138   CHECK_EQ(new_operands.size(), 1);
1139   return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
1140                                                     dimensions());
1141 }
1142 
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)1143 HloBroadcastInstruction::HloBroadcastInstruction(
1144     const Shape& shape, HloInstruction* operand,
1145     absl::Span<const int64> broadcast_dimension)
1146     : HloInstruction(HloOpcode::kBroadcast, shape),
1147       dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
1148   AppendOperand(operand);
1149 }
1150 
ToProto() const1151 HloInstructionProto HloBroadcastInstruction::ToProto() const {
1152   HloInstructionProto proto = HloInstruction::ToProto();
1153   for (int64_t dimension : dimensions_) {
1154     proto.add_dimensions(dimension);
1155   }
1156   return proto;
1157 }
1158 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1159 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
1160     const HloPrintOptions& options) const {
1161   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1162 }
1163 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1164 bool HloBroadcastInstruction::IdenticalSlowPath(
1165     const HloInstruction& other,
1166     const std::function<bool(const HloComputation*, const HloComputation*)>&
1167         eq_computations) const {
1168   const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
1169   return dimensions() == casted_other.dimensions();
1170 }
1171 
1172 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1173 HloBroadcastInstruction::CloneWithNewOperandsImpl(
1174     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1175     HloCloneContext* context) const {
1176   CHECK_EQ(new_operands.size(), 1);
1177   return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
1178                                                     dimensions());
1179 }
1180 
HloDynamicReshapeInstruction(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1181 HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
1182     const Shape& shape, HloInstruction* data_operand,
1183     absl::Span<HloInstruction* const> dim_sizes)
1184     : HloInstruction(HloOpcode::kDynamicReshape, shape) {
1185   AppendOperand(data_operand);
1186   for (auto operand : dim_sizes) {
1187     AppendOperand(operand);
1188   }
1189 }
1190 
1191 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1192 HloDynamicReshapeInstruction::CloneWithNewOperandsImpl(
1193     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1194     HloCloneContext* context) const {
1195   CHECK_GE(new_operands.size(), 1);
1196   return absl::make_unique<HloDynamicReshapeInstruction>(
1197       shape, new_operands[0], new_operands.subspan(1));
1198 }
1199 
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1200 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
1201                                              HloInstruction* operand,
1202                                              int64_t inferred_dimension)
1203     : HloInstruction(HloOpcode::kReshape, shape),
1204       inferred_dimension_(inferred_dimension) {
1205   AppendOperand(operand);
1206 }
1207 
ToProto() const1208 HloInstructionProto HloReshapeInstruction::ToProto() const {
1209   HloInstructionProto proto = HloInstruction::ToProto();
1210   if (inferred_dimension_ != -1) {
1211     proto.add_dimensions(inferred_dimension_);
1212   }
1213   return proto;
1214 }
1215 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1216 std::vector<string> HloReshapeInstruction::ExtraAttributesToStringImpl(
1217     const HloPrintOptions& options) const {
1218   if (inferred_dimension() == -1) {
1219     return {};
1220   }
1221   return {StrCat("inferred_dimension=", inferred_dimension())};
1222 }
1223 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1224 bool HloReshapeInstruction::IdenticalSlowPath(
1225     const HloInstruction& other,
1226     const std::function<bool(const HloComputation*, const HloComputation*)>&
1227         eq_computations) const {
1228   const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
1229   return inferred_dimension() == casted_other.inferred_dimension();
1230 }
1231 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1232 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
1233     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1234     HloCloneContext* context) const {
1235   CHECK_EQ(new_operands.size(), 1);
1236   return absl::make_unique<HloReshapeInstruction>(shape, new_operands[0],
1237                                                   inferred_dimension());
1238 }
1239 
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1240 HloMapInstruction::HloMapInstruction(const Shape& shape,
1241                                      absl::Span<HloInstruction* const> operands,
1242                                      HloComputation* map_computation)
1243     : HloInstruction(HloOpcode::kMap, shape) {
1244   for (auto operand : operands) {
1245     AppendOperand(operand);
1246   }
1247   AppendComputation(map_computation);
1248   // TODO(b/65689298) Remove code below once Map is generalized to accept
1249   // arbitrary map dimensions.
1250   dimensions_.resize(shape.rank());
1251   std::iota(dimensions_.begin(), dimensions_.end(), 0);
1252 }
1253 
ToProto() const1254 HloInstructionProto HloMapInstruction::ToProto() const {
1255   HloInstructionProto proto = HloInstruction::ToProto();
1256   for (int64_t dimension : dimensions_) {
1257     proto.add_dimensions(dimension);
1258   }
1259   return proto;
1260 }
1261 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1262 bool HloMapInstruction::IsElementwiseImpl(
1263     const absl::optional<int64>& operand_idx) const {
1264   if (!dimensions().empty()) {
1265     // Check that the map is executed in elementwise compatible dimensions.
1266     if (dimensions().size() != shape().dimensions_size()) {
1267       return false;
1268     }
1269     for (int i = 0; i < dimensions().size(); ++i) {
1270       if (dimensions()[i] != i) {
1271         return false;
1272       }
1273     }
1274   }
1275   return true;
1276 }
1277 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1278 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
1279     const HloPrintOptions& options) const {
1280   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1281 }
1282 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1283 bool HloMapInstruction::IdenticalSlowPath(
1284     const HloInstruction& other,
1285     const std::function<bool(const HloComputation*, const HloComputation*)>&
1286         eq_computations) const {
1287   const auto& casted_other = static_cast<const HloMapInstruction&>(other);
1288   return eq_computations(to_apply(), casted_other.to_apply()) &&
1289          dimensions() == casted_other.dimensions();
1290 }
1291 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1292 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1293     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1294     HloCloneContext* context) const {
1295   return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1296 }
1297 
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1298 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
1299                                          HloInstruction* operand,
1300                                          absl::Span<const int64> start_indices,
1301                                          absl::Span<const int64> limit_indices,
1302                                          absl::Span<const int64> strides)
1303     : HloInstruction(HloOpcode::kSlice, shape),
1304       slice_starts_(start_indices.begin(), start_indices.end()),
1305       slice_limits_(limit_indices.begin(), limit_indices.end()),
1306       slice_strides_(strides.begin(), strides.end()) {
1307   AppendOperand(operand);
1308   // For backward compatibility with old serialized computations: if there are
1309   // no strides, assume all strides are 1.
1310   // TODO(b/63317920): remove this code.
1311   if (slice_strides_.empty()) {
1312     slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
1313   }
1314 }
1315 
ToProto() const1316 HloInstructionProto HloSliceInstruction::ToProto() const {
1317   HloInstructionProto proto = HloInstruction::ToProto();
1318   for (int i = 0; i < slice_starts_.size(); ++i) {
1319     auto* slice_dimension = proto.add_slice_dimensions();
1320     slice_dimension->set_start(slice_starts_[i]);
1321     slice_dimension->set_limit(slice_limits_[i]);
1322     slice_dimension->set_stride(slice_strides_[i]);
1323   }
1324   return proto;
1325 }
1326 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1327 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
1328     const HloPrintOptions& options) const {
1329   std::vector<string> bounds;
1330   bounds.reserve(slice_starts_.size());
1331   const bool omit_stride = absl::c_all_of(
1332       slice_strides_, [](int64_t stride) { return stride == 1; });
1333   for (int i = 0; i < slice_starts_.size(); ++i) {
1334     string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1335     bounds.push_back(
1336         StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1337   }
1338   return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1339 }
1340 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1341 bool HloSliceInstruction::IdenticalSlowPath(
1342     const HloInstruction& other,
1343     const std::function<bool(const HloComputation*, const HloComputation*)>&
1344         eq_computations) const {
1345   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1346   return slice_starts_ == other_slice.slice_starts_ &&
1347          slice_limits_ == other_slice.slice_limits_ &&
1348          slice_strides_ == other_slice.slice_strides_;
1349 }
1350 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1351 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1352     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1353     HloCloneContext* context) const {
1354   CHECK_EQ(new_operands.size(), 1);
1355   return absl::make_unique<HloSliceInstruction>(
1356       shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1357 }
1358 
HloConstantInstruction(Literal literal)1359 HloConstantInstruction::HloConstantInstruction(Literal literal)
1360     : HloInstruction(HloOpcode::kConstant, literal.shape()),
1361       literal_(std::move(literal)) {}
1362 
HloConstantInstruction(Literal literal,const Shape & shape)1363 HloConstantInstruction::HloConstantInstruction(Literal literal,
1364                                                const Shape& shape)
1365     : HloInstruction(HloOpcode::kConstant, shape),
1366       literal_(std::move(literal)) {}
1367 
HloConstantInstruction(const Shape & shape)1368 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1369     : HloInstruction(HloOpcode::kConstant, shape) {}
1370 
ToProto() const1371 HloInstructionProto HloConstantInstruction::ToProto() const {
1372   HloInstructionProto proto = HloInstruction::ToProto();
1373   if (literal_.has_value()) {
1374     *proto.mutable_literal() = literal_->ToProto();
1375   }
1376   return proto;
1377 }
1378 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1379 bool HloConstantInstruction::IsElementwiseImpl(
1380     const absl::optional<int64>& operand_idx) const {
1381   return true;
1382 }
1383 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1384 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1385                                               const ShapeIndex& shape_index) {
1386   Shape* mutable_array_subshape =
1387       ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1388   CHECK(mutable_array_subshape->IsArray());
1389 
1390   // Normally array_subshape will always have a layout, but this invariant is
1391   // temporarily broken in LayoutAssignment::AssignLayouts.
1392 
1393   if (!mutable_array_subshape->has_layout() ||
1394       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1395     *literal_ = literal_->Relayout(new_layout, shape_index);
1396     *mutable_array_subshape->mutable_layout() = new_layout;
1397   }
1398 }
1399 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1400 bool HloConstantInstruction::IdenticalSlowPath(
1401     const HloInstruction& other,
1402     const std::function<bool(const HloComputation*, const HloComputation*)>&
1403         eq_computations) const {
1404   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1405   return literal() == other_slice.literal();
1406 }
1407 
1408 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1409 HloConstantInstruction::CloneWithNewOperandsImpl(
1410     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1411     HloCloneContext* context) const {
1412   CHECK(literal_.has_value());
1413   // Literal's shape may have no/different tiling info. Use this instruction's
1414   // shape instead.
1415   CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1416                                                   this->shape()));
1417   return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
1418                                                    this->shape());
1419 }
1420 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1421 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1422     const HloPrintOptions& options,
1423     CanonicalNameMap* canonical_name_map) const {
1424   if (options.print_only_essential_constants()) {
1425     if (!literal_.has_value()) {
1426       return "{...}";
1427     }
1428     if (literal().IsAll(0)) {
1429       return "0";
1430     }
1431     if (literal().IsAll(1)) {
1432       return "1";
1433     }
1434     if (shape().IsInteger()) {
1435       return literal_->ToStringWithoutShapeOneline();
1436     }
1437     return "{...}";
1438   }
1439 
1440   // For constants, show the actual value in place of an empty operand list.
1441   if (literal_.has_value() &&
1442       ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1443        options.print_large_constants())) {
1444     // Literal::ToString emits multidimensional arrays over multiple
1445     // lines. Compact this into one line by stripping out white space.
1446     return literal_->ToStringWithoutShapeOneline();
1447   } else {
1448     // Do not show large constants or tuples.
1449     return "{...}";
1450   }
1451 }
1452 
HloTraceInstruction(const string & tag,HloInstruction * operand)1453 HloTraceInstruction::HloTraceInstruction(const string& tag,
1454                                          HloInstruction* operand)
1455     : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1456       literal_(LiteralUtil::CreateR1U8(tag)) {
1457   AppendOperand(operand);
1458   operand->set_tracing(this);
1459 }
1460 
ToProto() const1461 HloInstructionProto HloTraceInstruction::ToProto() const {
1462   HloInstructionProto proto = HloInstruction::ToProto();
1463   *proto.mutable_literal() = literal_.ToProto();
1464   return proto;
1465 }
1466 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1467 bool HloTraceInstruction::IdenticalSlowPath(
1468     const HloInstruction& other,
1469     const std::function<bool(const HloComputation*, const HloComputation*)>&
1470         eq_computations) const {
1471   return false;
1472 }
1473 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1474 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1475     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1476     HloCloneContext* context) const {
1477   LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1478 }
1479 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1480 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1481                                            FusionKind fusion_kind,
1482                                            HloInstruction* fused_root)
1483     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1484   CHECK(fused_root != nullptr);
1485   SetAndSanitizeName("fusion");
1486   set_parent(fused_root->parent());
1487   set_metadata(fused_root->metadata());
1488   CloneAndFuseInternal(fused_root);
1489 }
1490 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1491 HloFusionInstruction::HloFusionInstruction(
1492     const Shape& shape, FusionKind fusion_kind,
1493     absl::Span<HloInstruction* const> operands,
1494     HloComputation* fusion_computation)
1495     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1496   for (auto operand : operands) {
1497     AppendOperand(operand);
1498   }
1499   SetAndSanitizeName("fusion");
1500   AppendComputation(fusion_computation);
1501   fusion_computation->SetFusionInstruction(this);
1502 }
1503 
~HloFusionInstruction()1504 HloFusionInstruction::~HloFusionInstruction() {
1505   ClearFusionComputationInstruction();
1506 }
1507 
ClearFusionComputationInstruction()1508 void HloFusionInstruction::ClearFusionComputationInstruction() {
1509   // Each fusion calls a single computation, but we use called_computations()
1510   // instead of fused_instructions_computation(), because the order in which
1511   // things get destructed can vary; the fusion computation's back-pointer may
1512   // already be null, which violates a check in fused_instructions_computation.
1513   for (HloComputation* computation : called_computations()) {
1514     // Some passes that rewrite fusions may reassign a fusion computation to a
1515     // different fusion instruction as this instruction gets destructed.
1516     if (computation->FusionInstruction() == this) {
1517       computation->SetFusionInstruction(nullptr);
1518     }
1519   }
1520 }
1521 
ClearCalledComputations()1522 void HloFusionInstruction::ClearCalledComputations() {
1523   ClearFusionComputationInstruction();
1524   HloInstruction::ClearCalledComputations();
1525 }
1526 
ToCategory() const1527 string HloFusionInstruction::ToCategory() const {
1528   switch (fusion_kind()) {
1529     case FusionKind::kLoop:
1530       return "loop fusion";
1531     case FusionKind::kInput:
1532       return "input fusion";
1533     case FusionKind::kOutput:
1534       return "output fusion";
1535     case FusionKind::kCustom:
1536       return "custom fusion";
1537   }
1538 }
1539 
ToProto() const1540 HloInstructionProto HloFusionInstruction::ToProto() const {
1541   HloInstructionProto proto = HloInstruction::ToProto();
1542   proto.set_fusion_kind(xla::ToString(fusion_kind()));
1543   proto.add_called_computation_ids(
1544       fused_instructions_computation()->unique_id());
1545   return proto;
1546 }
1547 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1548 bool HloFusionInstruction::IsElementwiseImpl(
1549     const absl::optional<int64>& operand_idx) const {
1550   if (!operand_idx.has_value()) {
1551     for (auto* fused : fused_instructions()) {
1552       if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1553         return false;
1554       }
1555     }
1556     return true;
1557   }
1558   // A loop-fusion is elementwise on an operand if all operations (computed
1559   // using BFS) between the operand and the fused root are elementwise.
1560   std::deque<HloInstruction*> worklist;
1561   std::unordered_set<const HloInstruction*> visited;
1562   worklist.push_back(fused_parameter(operand_idx.value()));
1563   visited.insert(fused_parameter(operand_idx.value()));
1564   while (!worklist.empty()) {
1565     HloInstruction* operand = worklist.front();
1566     worklist.pop_front();
1567     for (HloInstruction* user : operand->users()) {
1568       CHECK_GE(user->unique_id(), 0);
1569       if (ContainsKey(visited, user)) {
1570         continue;
1571       }
1572       if (user->IsElementwise() ||
1573           IsInstructionElementwiseOnOperand(user, operand)) {
1574         worklist.push_back(user);
1575         visited.insert(user);
1576       } else {
1577         return false;
1578       }
1579     }
1580   }
1581   return true;
1582 }
1583 
AddFusionOperand(HloInstruction * new_operand)1584 HloInstruction* HloFusionInstruction::AddFusionOperand(
1585     HloInstruction* new_operand) {
1586   CHECK_EQ(operand_count(),
1587            fused_instructions_computation()->parameter_instructions().size());
1588   const int64_t param_no = operand_count();
1589   string param_name = StrCat("param_", param_no);
1590   HloInstruction* fused_parameter =
1591       fused_instructions_computation()->AddParameter(
1592           HloInstruction::CreateParameter(param_no, new_operand->shape(),
1593                                           param_name));
1594   AppendOperand(new_operand);
1595   return fused_parameter;
1596 }
1597 
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1598 void HloFusionInstruction::MergeFusionInstruction(
1599     HloFusionInstruction* instruction_to_merge) {
1600   CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1601   // Clone the instruction from which to merge fused instructions.
1602   std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1603   HloFusionInstruction* cloned_fusion =
1604       static_cast<HloFusionInstruction*>(cloned.get());
1605   // Replace uses of fused parameters with the corresponding operand of the
1606   // fusion.  Add all non-parameter fused instructions to
1607   // 'unfused_instructions' to be merged into 'this'.  This is done in reverse
1608   // post order.
1609   std::vector<HloInstruction*> unfused_instructions;
1610   auto fused_instructions = cloned_fusion->fused_instructions_computation()
1611                                 ->MakeInstructionPostOrder();
1612   for (auto fused_it = fused_instructions.rbegin();
1613        fused_it != fused_instructions.rend(); ++fused_it) {
1614     auto fused_instruction = *fused_it;
1615     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1616       TF_CHECK_OK(
1617           fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1618               fused_instruction->parameter_number())));
1619     } else {
1620       unfused_instructions.push_back(fused_instruction);
1621     }
1622   }
1623 
1624   // If there are no unfused instructions, the fused computation must consist
1625   // only of kParameter instructions. Make the operand of the corresponding
1626   // parameter number the new root.
1627   HloInstruction* unfused_root =
1628       unfused_instructions.empty()
1629           ? instruction_to_merge->mutable_operand(
1630                 instruction_to_merge->fused_instructions_computation()
1631                     ->root_instruction()
1632                     ->parameter_number())
1633           : unfused_instructions.front();
1634   CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
1635         unfused_instructions.empty());
1636   // Replace instruction_to_merge use of 'this' with unfused_root.
1637   TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
1638 
1639   // Build a dummy root for the cloned fusion as we may remove the original root
1640   // in the fusion process.
1641   if (!unfused_instructions.empty()) {
1642     HloComputation* computation = unfused_root->parent();
1643     auto* dummy_root = computation->AddInstruction(
1644         HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
1645     computation->set_root_instruction(dummy_root,
1646                                       /*accept_different_shape=*/true);
1647   }
1648 
1649   // Fuse 'unfused_instructions' into 'this'. Everytime we fuse an instruction
1650   // we remove it from the closed fusion node. This is so that we don't add
1651   // extra users to the producer of that instruction (we use user count to
1652   // decide if a side-effectful instruction is fusible).
1653   for (auto& instruction : unfused_instructions) {
1654     auto* fused = FuseInstruction(instruction);
1655     TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
1656     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1657   }
1658   CHECK_EQ(0, cloned_fusion->user_count());
1659   TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1660       cloned_fusion->fused_instructions_computation()));
1661 }
1662 
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1663 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1664     HloFusionInstruction* instruction_to_merge) {
1665   // Add all non-parameter fused instructions to 'unfused_instructions' to be
1666   // merged into 'this'. `old_to_new' maps the instructions in the fused node
1667   // to the disassembled fusion instructions.
1668   // Note that we add the unfused instructions to this->parent_ computation.
1669   // This is necessary because the unique_id needs for an instruction and
1670   // it's only added when inserting to the computation.
1671   absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1672   std::vector<HloInstruction*> unfused_instructions;
1673   auto computation_to_merge =
1674       instruction_to_merge->fused_instructions_computation();
1675   auto post_order = computation_to_merge->MakeInstructionPostOrder();
1676   for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1677     auto fused_instruction = *rit;
1678     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1679       InsertOrDie(&old_to_new, fused_instruction,
1680                   instruction_to_merge->mutable_operand(
1681                       fused_instruction->parameter_number()));
1682       continue;
1683     }
1684 
1685     // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
1686     // which clones again. This can be improved.
1687     auto cloned_instruction =
1688         parent()->AddInstruction(fused_instruction->Clone());
1689     unfused_instructions.push_back(cloned_instruction);
1690     InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
1691   }
1692   for (auto unfused_instruction : unfused_instructions) {
1693     for (int64_t index = 0; index < unfused_instruction->operand_count();
1694          index++) {
1695       auto new_operand =
1696           FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
1697       TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
1698     }
1699   }
1700 
1701   // If there are no unfused instructions, the fused computation must consist
1702   // only of kParameter instructions. Make the operand of the corresponding
1703   // parameter number the new root.
1704   HloInstruction* unfused_root =
1705       unfused_instructions.empty()
1706           ? instruction_to_merge->mutable_operand(
1707                 instruction_to_merge->fused_instructions_computation()
1708                     ->root_instruction()
1709                     ->parameter_number())
1710           : unfused_instructions.front();
1711   TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
1712 
1713   TF_CHECK_OK(
1714       instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
1715   if (GetModule()) {
1716     TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
1717   }
1718 
1719   // Fuse the root instruction and generate multiple outputs.
1720   if (unfused_instructions.empty()) {
1721     return;
1722   }
1723   FuseInstructionIntoMultiOutput(unfused_root);
1724   TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
1725   // The rest instructions are of normal fusing.
1726   for (int64_t i = 1; i < unfused_instructions.size(); i++) {
1727     auto instruction = unfused_instructions[i];
1728     FuseInstruction(instruction);
1729     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1730   }
1731 }
1732 
fused_instructions_computation() const1733 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1734   CHECK(!called_computations().empty());
1735   auto* fused_instructions_computation = called_computations().front();
1736   CHECK(fused_instructions_computation->IsFusionComputation())
1737       << "Computation " << fused_instructions_computation->name()
1738       << " is not a fusion kind";
1739   return fused_instructions_computation;
1740 }
1741 
fused_expression_root() const1742 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1743   return fused_instructions_computation()->root_instruction();
1744 }
1745 
fused_parameter(int64_t parameter_number) const1746 HloInstruction* HloFusionInstruction::fused_parameter(
1747     int64_t parameter_number) const {
1748   return fused_instructions_computation()->parameter_instruction(
1749       parameter_number);
1750 }
1751 
fused_parameters() const1752 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1753     const {
1754   return fused_instructions_computation()->parameter_instructions();
1755 }
1756 
1757 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1758     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1759 HloFusionInstruction::fused_instructions() const {
1760   const HloComputation* subcomp = fused_instructions_computation();
1761   return subcomp->instructions();
1762 }
1763 
1764 const tensorflow::gtl::iterator_range<
1765     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1766 HloFusionInstruction::fused_instructions() {
1767   return fused_instructions_computation()->instructions();
1768 }
1769 
fused_instruction_count() const1770 int64 HloFusionInstruction::fused_instruction_count() const {
1771   return fused_instructions_computation()->instruction_count();
1772 }
1773 
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1774 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1775     HloInstruction* instruction_to_fuse, bool add_output) {
1776   // When add_output is false, this fusion instruction must be a user of
1777   // instruction_to_fuse.
1778   if (!add_output) {
1779     CHECK(IsUserOf(instruction_to_fuse));
1780   }
1781   HloInstruction* fused_instruction =
1782       CloneAndFuseInternal(instruction_to_fuse, add_output);
1783   return fused_instruction;
1784 }
1785 
CloneAndFuseInternal(HloInstruction * instruction_to_fuse,bool add_output)1786 HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
1787     HloInstruction* instruction_to_fuse, bool add_output) {
1788   CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1789   VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
1790   HloInstruction* clone = nullptr;
1791   if (called_computations().empty()) {
1792     // New fusion instruction. It should not be a multioutput instruction.
1793     CHECK(!add_output);
1794     auto builder = HloComputation::Builder("fused_computation", this);
1795     builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
1796     AppendComputation(
1797         CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1798     clone = fused_expression_root();
1799   } else {
1800     // When add_output is false, instruction_to_fuse is necessarily an operand
1801     // of the fusion instruction. After fusion this will no longer be the
1802     // case. Remove the operand from the operand list and remove its
1803     // corresponding fused parameter instruction. Renumber parameters as
1804     // necessary to make parameter numbers consistent with their index in the
1805     // fused_parameter_ vector.
1806     bool in_operand_list =
1807         absl::c_linear_search(operands(), instruction_to_fuse);
1808     CHECK(add_output || in_operand_list);
1809     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1810       // We assume all uses of a kTuple operation are GTE ops, not another
1811       // fusion node. In this case, we don't need to clone
1812       // 'instruction_to_fuse'.
1813       CHECK(!in_operand_list);
1814       clone = instruction_to_fuse;
1815     } else {
1816       clone = fused_instructions_computation()->AddInstruction(
1817           instruction_to_fuse->Clone(/*suffix=*/""));
1818     }
1819     const std::vector<HloInstruction*>& fused_parameters =
1820         fused_instructions_computation()->parameter_instructions();
1821     for (int64_t operand_num = 0; operand_num < operand_count();
1822          ++operand_num) {
1823       if (instruction_to_fuse == operand(operand_num)) {
1824         // replace the fused parameter instruction's uses with the clone.
1825         HloInstruction* fused_parameter = fused_parameters[operand_num];
1826         TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
1827 
1828         // Remove the corresponding fused parameter and operand from their
1829         // respective vectors.
1830         TF_CHECK_OK(
1831             fused_instructions_computation()->RemoveParameter(operand_num));
1832         RemoveOperandAt(operand_num);
1833         break;
1834       }
1835     }
1836     // We've cloned instruction_to_fuse into this fusion instruction, so this
1837     // fusion instruction is no longer a use of instruction_to_fuse.
1838     if (in_operand_list) {
1839       DetachFrom(instruction_to_fuse);
1840       // When the instruction_to_fuse does not have other users, we don't need
1841       // to generate a multioutput fusion instruction.
1842       if (instruction_to_fuse->user_count() == 0) {
1843         add_output = false;
1844       }
1845     }
1846   }
1847 
1848   // Reread the parameters in the computation.
1849   const std::vector<HloInstruction*>& fused_parameters =
1850       fused_instructions_computation()->parameter_instructions();
1851 
1852   // Add each operand of the clone as an operand of the fusion instruction. A
1853   // complication is that some clone operands may already be operands of the
1854   // fusion instruction.
1855   for (int64_t operand_num = 0; operand_num < clone->operand_count();
1856        ++operand_num) {
1857     HloInstruction* operand = clone->mutable_operand(operand_num);
1858 
1859     // See if this operand is already an operand of the fusion node.
1860     CHECK_EQ(operands().size(), fused_parameters.size());
1861     HloInstruction* fused_param = nullptr;
1862     for (int64_t i = 0; i < operands().size(); ++i) {
1863       if (this->operand(i) == operand) {
1864         fused_param = fused_parameters[i];
1865         break;
1866       }
1867     }
1868 
1869     if (fused_param == nullptr) {
1870       // Clone's operand was not already an operand of the fusion
1871       // instruction. Add it as an operand and add a corresponding fused
1872       // parameter instruction.
1873       fused_param = AddFusionOperand(operand);
1874     }
1875     TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
1876   }
1877 
1878   if (add_output) {
1879     CHECK_GT(instruction_to_fuse->user_count(), 0);
1880     // If this is already a multioutput fusion instruction, expand the root
1881     // tuple by 1.
1882     HloInstruction* fused_root = fused_expression_root();
1883     HloInstruction::InstructionVector tuple_elements;
1884     bool newly_created_tuple_instr = false;
1885     if (fused_root->opcode() == HloOpcode::kTuple) {
1886       tuple_elements = fused_root->operands();
1887     } else {
1888       tuple_elements.push_back(fused_root);
1889       newly_created_tuple_instr = true;
1890     }
1891     if (clone->opcode() == HloOpcode::kTuple) {
1892       for (auto inst : clone->operands()) {
1893         tuple_elements.push_back(inst);
1894       }
1895     } else {
1896       tuple_elements.push_back(clone);
1897     }
1898     HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
1899         HloInstruction::CreateTuple(tuple_elements));
1900     fused_instructions_computation()->set_root_instruction(new_root);
1901     *mutable_shape() = new_root->shape();
1902     if (fused_root->opcode() == HloOpcode::kTuple) {
1903       TF_CHECK_OK(
1904           fused_instructions_computation()->RemoveInstruction(fused_root));
1905     }
1906 
1907     // If this is a newly created multioutput instruction, we need to update
1908     // the use of the original fusion instruction.
1909     if (newly_created_tuple_instr) {
1910       HloInstruction* new_instr = parent()->AddInstruction(
1911           HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
1912       TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1913     }
1914     int64_t index = tuple_elements.size();
1915     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1916       CHECK_EQ(clone, instruction_to_fuse);
1917       index -= clone->operand_count();
1918       std::vector<HloInstruction*> to_be_removed;
1919       for (auto old_gte : clone->users()) {
1920         CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1921         int64_t old_tuple_index = old_gte->tuple_index();
1922         HloInstruction* new_gte =
1923             parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1924                 old_gte->shape(), this, index + old_tuple_index));
1925         TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1926         to_be_removed.push_back(old_gte);
1927       }
1928       for (auto old_gte : to_be_removed) {
1929         TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1930       }
1931     } else {
1932       HloInstruction* new_gte =
1933           parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1934               clone->shape(), this, index - 1));
1935       TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
1936     }
1937   }
1938 
1939   if (clone != instruction_to_fuse) {
1940     VLOG(2) << "New clone:\n" << clone->ToString();
1941   }
1942   return clone;
1943 }
1944 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1945 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1946     const HloPrintOptions& options) const {
1947   return {StrCat("kind=", xla::ToString(fusion_kind()))};
1948 }
1949 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1950 bool HloFusionInstruction::IdenticalSlowPath(
1951     const HloInstruction& other,
1952     const std::function<bool(const HloComputation*, const HloComputation*)>&
1953         eq_computations) const {
1954   return fusion_kind() == other.fusion_kind() &&
1955          eq_computations(fused_instructions_computation(),
1956                          other.fused_instructions_computation());
1957 }
1958 
InnerHash() const1959 uint64 HloFusionInstruction::InnerHash() const {
1960   return fused_instructions_computation()->root_instruction()->Hash();
1961 }
1962 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1963 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1964     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1965     HloCloneContext* context) const {
1966   HloModule* module = context != nullptr ? context->module() : GetModule();
1967   HloComputation* new_fused_computation = nullptr;
1968   if (context != nullptr) {
1969     new_fused_computation =
1970         context->FindComputation(fused_instructions_computation());
1971   }
1972   if (new_fused_computation == nullptr) {
1973     new_fused_computation = module->AddEmbeddedComputation(
1974         fused_instructions_computation()->Clone("clone", context));
1975   }
1976   return absl::make_unique<HloFusionInstruction>(
1977       shape, fusion_kind(), new_operands, new_fused_computation);
1978 }
1979 
DeduplicateFusionOperands()1980 Status HloFusionInstruction::DeduplicateFusionOperands() {
1981   if (IsCustomFusion()) {
1982     return Status::OK();
1983   }
1984   absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1985   std::vector<int> operands_to_remove;
1986   for (int i = 0; i < operand_count(); ++i) {
1987     auto emplace_result = operand_indices.emplace(operand(i), i);
1988     if (!emplace_result.second) {
1989       TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1990           fused_parameter(emplace_result.first->second)));
1991       operands_to_remove.push_back(i);
1992     }
1993   }
1994   if (operands_to_remove.empty()) {
1995     return Status::OK();
1996   }
1997   TF_RETURN_IF_ERROR(fused_instructions_computation()
1998                          ->RemoveUnusedParametersFromFusedComputation());
1999   RemoveOperandsAtAscendingIndices(operands_to_remove);
2000   return Status::OK();
2001 }
2002 
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)2003 HloRngInstruction::HloRngInstruction(
2004     const Shape& shape, RandomDistribution distribution,
2005     absl::Span<HloInstruction* const> parameters)
2006     : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
2007   for (HloInstruction* param : parameters) {
2008     AppendOperand(param);
2009   }
2010 }
2011 
ToProto() const2012 HloInstructionProto HloRngInstruction::ToProto() const {
2013   HloInstructionProto proto = HloInstruction::ToProto();
2014   proto.set_distribution(distribution_);
2015   return proto;
2016 }
2017 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2018 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
2019     const HloPrintOptions& options) const {
2020   return {StrCat("distribution=", RandomDistributionToString(distribution_))};
2021 }
2022 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const2023 bool HloRngInstruction::IsElementwiseImpl(
2024     const absl::optional<int64>& operand_idx) const {
2025   return true;
2026 }
2027 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2028 bool HloRngInstruction::IdenticalSlowPath(
2029     const HloInstruction& other,
2030     const std::function<bool(const HloComputation*, const HloComputation*)>&
2031         eq_computations) const {
2032   const auto& casted_other = static_cast<const HloRngInstruction&>(other);
2033   return distribution_ == casted_other.distribution_;
2034 }
2035 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2036 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
2037     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2038     HloCloneContext* context) const {
2039   return absl::make_unique<HloRngInstruction>(shape, distribution_,
2040                                               new_operands);
2041 }
2042 
HloParameterInstruction(int64_t parameter_number,const Shape & shape,const string & name)2043 HloParameterInstruction::HloParameterInstruction(int64_t parameter_number,
2044                                                  const Shape& shape,
2045                                                  const string& name)
2046     : HloInstruction(HloOpcode::kParameter, shape),
2047       parameter_number_(parameter_number) {
2048   SetAndSanitizeName(name);
2049 }
2050 
ToProto() const2051 HloInstructionProto HloParameterInstruction::ToProto() const {
2052   HloInstructionProto proto = HloInstruction::ToProto();
2053   proto.set_parameter_number(parameter_number_);
2054   if (parameter_replicated_at_leaf_buffers_) {
2055     for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2056       proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
2057           replicated);
2058     }
2059   }
2060   return proto;
2061 }
2062 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2063 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
2064     const HloPrintOptions& options) const {
2065   std::vector<string> result;
2066   if (!parameter_replicated_at_leaf_buffers_) {
2067     return result;
2068   }
2069   std::vector<string> buffers_replicated_strs;
2070   for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2071     buffers_replicated_strs.push_back(replicated ? "true" : "false");
2072   }
2073   if (options.print_ids()) {
2074     result.push_back(StrCat("parameter_replication={",
2075                             StrJoin(buffers_replicated_strs, ","), "}"));
2076   }
2077   return result;
2078 }
2079 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2080 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
2081     const HloPrintOptions& options,
2082     CanonicalNameMap* canonical_name_map) const {
2083   return StrCat(parameter_number_);
2084 }
2085 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2086 bool HloParameterInstruction::IdenticalSlowPath(
2087     const HloInstruction& other,
2088     const std::function<bool(const HloComputation*, const HloComputation*)>&
2089         eq_computations) const {
2090   const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
2091   return parameter_number() == casted_other.parameter_number();
2092 }
2093 
2094 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2095 HloParameterInstruction::CloneWithNewOperandsImpl(
2096     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2097     HloCloneContext* context) const {
2098   auto clone = absl::make_unique<HloParameterInstruction>(parameter_number_,
2099                                                           shape, name());
2100   if (parameter_replicated_at_leaf_buffers_ &&
2101       ShapeUtil::Equal(shape, this->shape())) {
2102     clone->set_parameter_replicated_at_leaf_buffers(
2103         *parameter_replicated_at_leaf_buffers_);
2104   }
2105   return clone;
2106 }
2107 
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64_t index)2108 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
2109     const Shape& shape, HloInstruction* operand, int64_t index)
2110     : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
2111   AppendOperand(operand);
2112 }
2113 
ToProto() const2114 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
2115   HloInstructionProto proto = HloInstruction::ToProto();
2116   proto.set_tuple_index(tuple_index_);
2117   return proto;
2118 }
2119 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2120 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
2121     const HloPrintOptions& options) const {
2122   return {StrCat("index=", tuple_index())};
2123 }
2124 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2125 bool HloGetTupleElementInstruction::IdenticalSlowPath(
2126     const HloInstruction& other,
2127     const std::function<bool(const HloComputation*, const HloComputation*)>&
2128         eq_computations) const {
2129   const auto& casted_other =
2130       static_cast<const HloGetTupleElementInstruction&>(other);
2131   return tuple_index() == casted_other.tuple_index();
2132 }
2133 
2134 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2135 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
2136     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2137     HloCloneContext* context) const {
2138   CHECK_EQ(new_operands.size(), 1);
2139   return absl::make_unique<HloGetTupleElementInstruction>(
2140       shape, new_operands[0], tuple_index());
2141 }
2142 
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)2143 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
2144     const Shape& shape, HloInstruction* operand, const int exponent_bits,
2145     const int mantissa_bits)
2146     : HloInstruction(HloOpcode::kReducePrecision, shape),
2147       exponent_bits_(exponent_bits),
2148       mantissa_bits_(mantissa_bits) {
2149   AppendOperand(operand);
2150 }
2151 
ToProto() const2152 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
2153   HloInstructionProto proto = HloInstruction::ToProto();
2154   proto.set_exponent_bits(exponent_bits_);
2155   proto.set_mantissa_bits(mantissa_bits_);
2156   return proto;
2157 }
2158 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2159 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
2160     const HloPrintOptions& options) const {
2161   return {StrCat("exponent_bits=", exponent_bits_),
2162           StrCat("mantissa_bits=", mantissa_bits_)};
2163 }
2164 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2165 bool HloReducePrecisionInstruction::IdenticalSlowPath(
2166     const HloInstruction& other,
2167     const std::function<bool(const HloComputation*, const HloComputation*)>&
2168         eq_computations) const {
2169   const auto& casted_other =
2170       static_cast<const HloReducePrecisionInstruction&>(other);
2171   // A reduce-precision operation is determined by the bit sizes.
2172   return exponent_bits() == casted_other.exponent_bits() &&
2173          mantissa_bits() == casted_other.mantissa_bits();
2174 }
2175 
2176 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2177 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
2178     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2179     HloCloneContext* context) const {
2180   CHECK_EQ(new_operands.size(), 1);
2181   return absl::make_unique<HloReducePrecisionInstruction>(
2182       shape, new_operands[0], exponent_bits(), mantissa_bits());
2183 }
2184 
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)2185 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
2186                                            HloInstruction* token_operand,
2187                                            const string& config)
2188     : HloInstruction(HloOpcode::kInfeed,
2189                      ShapeUtil::MakeTupleShape(
2190                          {infeed_shape, ShapeUtil::MakeTokenShape()})),
2191       infeed_config_(config) {
2192   AppendOperand(token_operand);
2193 }
2194 
ToProto() const2195 HloInstructionProto HloInfeedInstruction::ToProto() const {
2196   HloInstructionProto proto = HloInstruction::ToProto();
2197   proto.set_infeed_config(infeed_config_);
2198   return proto;
2199 }
2200 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2201 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
2202     const HloPrintOptions& options) const {
2203   if (!options.print_infeed_outfeed_config() || infeed_config_.empty()) {
2204     return {};
2205   }
2206   return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
2207 }
2208 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2209 bool HloInfeedInstruction::IdenticalSlowPath(
2210     const HloInstruction& other,
2211     const std::function<bool(const HloComputation*, const HloComputation*)>&
2212         eq_computations) const {
2213   // Not yet supported.
2214   return false;
2215 }
2216 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2217 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
2218     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2219     HloCloneContext* context) const {
2220   CHECK_EQ(new_operands.size(), 1);
2221   return absl::make_unique<HloInfeedInstruction>(
2222       infeed_shape(), new_operands[0], infeed_config());
2223 }
2224 
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)2225 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
2226                                              HloInstruction* operand,
2227                                              HloInstruction* token_operand,
2228                                              absl::string_view outfeed_config)
2229     : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
2230       outfeed_shape_(outfeed_shape),
2231       outfeed_config_(outfeed_config) {
2232   AppendOperand(operand);
2233   AppendOperand(token_operand);
2234 }
2235 
ToProto() const2236 HloInstructionProto HloOutfeedInstruction::ToProto() const {
2237   HloInstructionProto proto = HloInstruction::ToProto();
2238   proto.set_outfeed_config(outfeed_config());
2239   *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
2240   return proto;
2241 }
2242 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2243 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
2244     const HloPrintOptions& options) const {
2245   std::vector<string> extra;
2246   extra.push_back(StrCat("outfeed_shape=",
2247                          ShapeUtil::HumanStringWithLayout(outfeed_shape_)));
2248   if (options.print_infeed_outfeed_config() && !outfeed_config_.empty()) {
2249     extra.push_back(
2250         StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
2251   }
2252   return extra;
2253 }
2254 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2255 bool HloOutfeedInstruction::IdenticalSlowPath(
2256     const HloInstruction& other,
2257     const std::function<bool(const HloComputation*, const HloComputation*)>&
2258         eq_computations) const {
2259   // Not yet supported.
2260   return false;
2261 }
2262 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2263 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
2264     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2265     HloCloneContext* context) const {
2266   CHECK_EQ(new_operands.size(), 2);
2267   return absl::make_unique<HloOutfeedInstruction>(
2268       outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
2269 }
2270 
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2271 HloConvolutionInstruction::HloConvolutionInstruction(
2272     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2273     int64_t feature_group_count, int64_t batch_group_count,
2274     const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
2275     const PrecisionConfig& precision_config)
2276     : HloInstruction(HloOpcode::kConvolution, shape),
2277       feature_group_count_(feature_group_count),
2278       batch_group_count_(batch_group_count),
2279       window_(window),
2280       convolution_dimension_numbers_(dimension_numbers),
2281       precision_config_(precision_config) {
2282   if (window_util::HasBaseDilation(window)) {
2283     SetAndSanitizeName(StrCat(name(), "-base-dilated"));
2284   }
2285   if (window_util::HasWindowDilation(window)) {
2286     SetAndSanitizeName(StrCat(name(), "-window-dilated"));
2287   }
2288   AppendOperand(lhs);
2289   AppendOperand(rhs);
2290 }
2291 
ToCategory() const2292 string HloConvolutionInstruction::ToCategory() const {
2293   string category = "convolution";
2294   if (window_util::HasBaseDilation(window())) {
2295     category += " base-dilated";
2296   }
2297   if (window_util::HasWindowDilation(window())) {
2298     category += " window-dilated";
2299   }
2300   return category;
2301 }
2302 
ToProto() const2303 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2304   HloInstructionProto proto = HloInstruction::ToProto();
2305   *proto.mutable_window() = window_;
2306   *proto.mutable_convolution_dimension_numbers() =
2307       convolution_dimension_numbers_;
2308   proto.set_feature_group_count(feature_group_count_);
2309   proto.set_batch_group_count(batch_group_count_);
2310   *proto.mutable_precision_config() = precision_config_;
2311   return proto;
2312 }
2313 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2314 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2315     const HloPrintOptions& options) const {
2316   std::vector<string> extra;
2317   if (window_.dimensions_size() != 0) {
2318     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2319   }
2320   extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2321                                             convolution_dimension_numbers_)));
2322   if (feature_group_count_ != 1) {
2323     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2324   }
2325 
2326   if (batch_group_count_ != 1) {
2327     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2328   }
2329 
2330   string precision_config_string = PrecisionConfigToString(precision_config_);
2331   if (!precision_config_string.empty()) {
2332     extra.push_back(precision_config_string);
2333   }
2334   return extra;
2335 }
2336 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2337 bool HloConvolutionInstruction::IdenticalSlowPath(
2338     const HloInstruction& other,
2339     const std::function<bool(const HloComputation*, const HloComputation*)>&
2340         eq_computations) const {
2341   const auto& casted_other =
2342       static_cast<const HloConvolutionInstruction&>(other);
2343   if (feature_group_count_ != other.feature_group_count()) {
2344     return false;
2345   }
2346   if (batch_group_count_ != other.batch_group_count()) {
2347     return false;
2348   }
2349   return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2350          protobuf_util::ProtobufEquals(
2351              convolution_dimension_numbers(),
2352              casted_other.convolution_dimension_numbers()) &&
2353          protobuf_util::ProtobufEquals(precision_config(),
2354                                        casted_other.precision_config());
2355 }
2356 
2357 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2358 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2359     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2360     HloCloneContext* context) const {
2361   CHECK_EQ(new_operands.size(), 2);
2362   return absl::make_unique<HloConvolutionInstruction>(
2363       shape, new_operands[0], new_operands[1], feature_group_count_,
2364       batch_group_count_, window(), convolution_dimension_numbers_,
2365       precision_config_);
2366 }
2367 
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2368 HloReduceWindowInstruction::HloReduceWindowInstruction(
2369     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2370     const Window& window, HloComputation* reduce_computation)
2371     : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
2372                                  absl::MakeSpan(&init_value, 1), window,
2373                                  reduce_computation) {}
2374 
HloReduceWindowInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)2375 HloReduceWindowInstruction::HloReduceWindowInstruction(
2376     const Shape& shape, absl::Span<HloInstruction* const> operands,
2377     absl::Span<HloInstruction* const> init_values, const Window& window,
2378     HloComputation* reduce_computation)
2379     : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2380   for (auto* operand : operands) {
2381     AppendOperand(operand);
2382   }
2383   for (auto* init_value : init_values) {
2384     AppendOperand(init_value);
2385   }
2386   AppendComputation(reduce_computation);
2387 }
2388 
ToProto() const2389 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2390   HloInstructionProto proto = HloInstruction::ToProto();
2391   *proto.mutable_window() = window_;
2392   return proto;
2393 }
2394 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2395 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2396     const HloPrintOptions& options) const {
2397   std::vector<string> extra;
2398   if (window_.dimensions_size() != 0) {
2399     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2400   }
2401   return extra;
2402 }
2403 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2404 bool HloReduceWindowInstruction::IdenticalSlowPath(
2405     const HloInstruction& other,
2406     const std::function<bool(const HloComputation*, const HloComputation*)>&
2407         eq_computations) const {
2408   const auto& casted_other =
2409       static_cast<const HloReduceWindowInstruction&>(other);
2410   return eq_computations(to_apply(), casted_other.to_apply()) &&
2411          protobuf_util::ProtobufEquals(window(), casted_other.window());
2412 }
2413 
2414 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2415 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2416     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2417     HloCloneContext* context) const {
2418   CHECK_EQ(new_operands.size() % 2, 0);
2419   int64_t num_operands = new_operands.size() / 2;
2420   return absl::make_unique<HloReduceWindowInstruction>(
2421       shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
2422       absl::MakeSpan(new_operands)
2423           .subspan(num_operands, new_operands.size() / 2),
2424       window(), to_apply());
2425 }
2426 
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2427 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2428     const Shape& shape, HloInstruction* operand, HloComputation* select,
2429     const Window& window, HloInstruction* source, HloInstruction* init_value,
2430     HloComputation* scatter)
2431     : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2432   AppendOperand(operand);
2433   AppendOperand(source);
2434   AppendOperand(init_value);
2435   // Select comes before scatter in the vector.
2436   AppendComputation(select);
2437   AppendComputation(scatter);
2438 }
2439 
ToProto() const2440 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2441   HloInstructionProto proto = HloInstruction::ToProto();
2442   *proto.mutable_window() = window_;
2443   return proto;
2444 }
2445 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2446 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2447     const HloPrintOptions& options) const {
2448   std::vector<string> extra;
2449   if (window_.dimensions_size() != 0) {
2450     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2451   }
2452   return extra;
2453 }
2454 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2455 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2456     const HloInstruction& other,
2457     const std::function<bool(const HloComputation*, const HloComputation*)>&
2458         eq_computations) const {
2459   const auto& casted_other =
2460       static_cast<const HloSelectAndScatterInstruction&>(other);
2461   return eq_computations(select(), casted_other.select()) &&
2462          eq_computations(scatter(), casted_other.scatter()) &&
2463          protobuf_util::ProtobufEquals(window(), casted_other.window());
2464 }
2465 
2466 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2467 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2468     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2469     HloCloneContext* context) const {
2470   CHECK_EQ(new_operands.size(), 3);
2471   return absl::make_unique<HloSelectAndScatterInstruction>(
2472       shape, new_operands[0], select(), window(), new_operands[1],
2473       new_operands[2], scatter());
2474 }
2475 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)2476 HloCustomCallInstruction::HloCustomCallInstruction(
2477     const Shape& shape, absl::Span<HloInstruction* const> operands,
2478     absl::string_view custom_call_target, string opaque,
2479     CustomCallApiVersion api_version)
2480     : HloInstruction(HloOpcode::kCustomCall, shape),
2481       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2482       feature_group_count_(1),
2483       batch_group_count_(1),
2484       layout_constrained_(false),
2485       padding_type_(PaddingType::PADDING_INVALID),
2486       custom_call_has_side_effect_(false),
2487       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2488       api_version_(api_version) {
2489   set_raw_backend_config_string(std::move(opaque));
2490   for (auto operand : operands) {
2491     AppendOperand(operand);
2492   }
2493 }
2494 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)2495 HloCustomCallInstruction::HloCustomCallInstruction(
2496     const Shape& shape, absl::Span<HloInstruction* const> operands,
2497     HloComputation* to_apply, absl::string_view custom_call_target,
2498     string opaque, CustomCallApiVersion api_version)
2499     : HloInstruction(HloOpcode::kCustomCall, shape),
2500       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2501       feature_group_count_(1),
2502       batch_group_count_(1),
2503       layout_constrained_(false),
2504       padding_type_(PaddingType::PADDING_INVALID),
2505       custom_call_has_side_effect_(false),
2506       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2507       api_version_(api_version) {
2508   set_raw_backend_config_string(std::move(opaque));
2509   for (auto operand : operands) {
2510     AppendOperand(operand);
2511   }
2512   AppendComputation(to_apply);
2513   to_apply->SetCustomCallInstruction(this);
2514 }
2515 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)2516 HloCustomCallInstruction::HloCustomCallInstruction(
2517     const Shape& shape, absl::Span<HloInstruction* const> operands,
2518     absl::Span<HloComputation* const> called_computations,
2519     absl::string_view custom_call_target, string opaque,
2520     CustomCallApiVersion api_version)
2521     : HloInstruction(HloOpcode::kCustomCall, shape),
2522       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2523       feature_group_count_(1),
2524       batch_group_count_(1),
2525       layout_constrained_(false),
2526       padding_type_(PaddingType::PADDING_INVALID),
2527       custom_call_has_side_effect_(false),
2528       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2529       api_version_(api_version) {
2530   set_raw_backend_config_string(std::move(opaque));
2531   for (auto operand : operands) {
2532     AppendOperand(operand);
2533   }
2534   for (auto comp : called_computations) {
2535     AppendComputation(comp);
2536   }
2537 }
2538 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,absl::Span<const Shape> operand_shapes_with_layout,CustomCallApiVersion api_version)2539 HloCustomCallInstruction::HloCustomCallInstruction(
2540     const Shape& shape, absl::Span<HloInstruction* const> operands,
2541     absl::string_view custom_call_target, string opaque,
2542     absl::Span<const Shape> operand_shapes_with_layout,
2543     CustomCallApiVersion api_version)
2544     : HloInstruction(HloOpcode::kCustomCall, shape),
2545       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2546       feature_group_count_(1),
2547       batch_group_count_(1),
2548       layout_constrained_(true),
2549       padding_type_(PaddingType::PADDING_INVALID),
2550       operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2551                                   operand_shapes_with_layout.end()),
2552       custom_call_has_side_effect_(false),
2553       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2554       api_version_(api_version) {
2555   set_raw_backend_config_string(std::move(opaque));
2556   for (auto operand : operands) {
2557     AppendOperand(operand);
2558   }
2559 }
2560 
ToProto() const2561 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2562   HloInstructionProto proto = HloInstruction::ToProto();
2563   if (window_ != nullptr) {
2564     *proto.mutable_window() = *window_;
2565   }
2566   if (convolution_dimension_numbers_ != nullptr) {
2567     *proto.mutable_convolution_dimension_numbers() =
2568         *convolution_dimension_numbers_;
2569   }
2570   proto.set_custom_call_target(custom_call_target_);
2571   proto.set_feature_group_count(feature_group_count_);
2572   proto.set_batch_group_count(batch_group_count_);
2573   *proto.mutable_precision_config() = precision_config_;
2574   proto.set_padding_type(padding_type_);
2575   if (layout_constrained()) {
2576     proto.set_constrain_layout(true);
2577     for (const Shape& shape : operand_shapes_with_layout_) {
2578       *proto.add_operand_shapes_with_layout() = shape.ToProto();
2579     }
2580   }
2581   proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2582   if (literal_.has_value()) {
2583     *proto.mutable_literal() = literal_->ToProto();
2584   }
2585   for (const auto& pair : output_to_operand_aliasing_) {
2586     auto aliasing = proto.add_custom_call_output_operand_aliasing();
2587     aliasing->set_operand_index(pair.second.first);
2588     for (int64_t index : pair.first) {
2589       aliasing->add_output_shape_index(index);
2590     }
2591     for (int64_t index : pair.second.second) {
2592       aliasing->add_operand_shape_index(index);
2593     }
2594   }
2595   proto.set_custom_call_schedule(custom_call_schedule_);
2596   proto.set_custom_call_api_version(api_version_);
2597   return proto;
2598 }
2599 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2600 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2601     const HloPrintOptions& options) const {
2602   std::vector<string> extra;
2603   if (window_ != nullptr) {
2604     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2605   }
2606   if (convolution_dimension_numbers_ != nullptr) {
2607     extra.push_back(StrCat(
2608         "dim_labels=",
2609         ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2610   }
2611   if (feature_group_count_ != 1) {
2612     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2613   }
2614   if (batch_group_count_ != 1) {
2615     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2616   }
2617   string precision_config_string = PrecisionConfigToString(precision_config_);
2618   if (!precision_config_string.empty()) {
2619     extra.push_back(precision_config_string);
2620   }
2621   if (padding_type_ != PaddingType::PADDING_INVALID) {
2622     extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
2623   }
2624   // By contract, we print the custom call target even if
2625   // options.print_subcomputation_mode() == kOff, because the call target is not
2626   // an HloComputation.
2627   extra.push_back(
2628       StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2629 
2630   if (layout_constrained()) {
2631     std::vector<string> shape_strings;
2632     for (const Shape& shape : operand_shapes_with_layout_) {
2633       shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2634     }
2635     extra.push_back(StrCat("operand_layout_constraints={",
2636                            StrJoin(shape_strings, ", "), "}"));
2637   }
2638   if (custom_call_has_side_effect_) {
2639     extra.push_back("custom_call_has_side_effect=true");
2640   }
2641   if (literal_.has_value()) {
2642     extra.push_back(StrCat("literal=", literal_->ToStringWithLayoutOneline()));
2643   }
2644   if (!output_to_operand_aliasing_.empty()) {
2645     std::vector<string> pair_strings;
2646     for (const auto& pair : output_to_operand_aliasing_) {
2647       pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
2648                                     pair.second.first, ", ",
2649                                     pair.second.second.ToString(), ")"));
2650     }
2651     extra.push_back(StrCat("output_to_operand_aliasing={",
2652                            StrJoin(pair_strings, ", "), "}"));
2653   }
2654   if (custom_call_schedule_ != CustomCallSchedule::SCHEDULE_NONE) {
2655     extra.push_back(
2656         StrCat("schedule=", CustomCallSchedule_Name(custom_call_schedule_)));
2657   }
2658   if (api_version_ != CustomCallApiVersion::API_VERSION_ORIGINAL) {
2659     extra.push_back(
2660         StrCat("api_version=", CustomCallApiVersion_Name(api_version_)));
2661   }
2662   return extra;
2663 }
2664 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2665 bool HloCustomCallInstruction::IdenticalSlowPath(
2666     const HloInstruction& other,
2667     const std::function<bool(const HloComputation*, const HloComputation*)>&
2668         eq_computations) const {
2669   const auto& casted_other =
2670       static_cast<const HloCustomCallInstruction&>(other);
2671   if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2672       (window_ != nullptr &&
2673        !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2674     return false;
2675   }
2676   if ((convolution_dimension_numbers_ == nullptr) !=
2677           (casted_other.convolution_dimension_numbers_ == nullptr) ||
2678       (convolution_dimension_numbers_ != nullptr &&
2679        !protobuf_util::ProtobufEquals(
2680            convolution_dimension_numbers(),
2681            casted_other.convolution_dimension_numbers()))) {
2682     return false;
2683   }
2684   if (feature_group_count_ != casted_other.feature_group_count_) {
2685     return false;
2686   }
2687   if (batch_group_count_ != casted_other.batch_group_count_) {
2688     return false;
2689   }
2690 
2691   if (padding_type_ != casted_other.padding_type()) {
2692     return false;
2693   }
2694 
2695   if (layout_constrained() != casted_other.layout_constrained()) {
2696     return false;
2697   }
2698   if (layout_constrained()) {
2699     for (int64_t i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2700       if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2701                             casted_other.operand_shapes_with_layout_[i])) {
2702         return false;
2703       }
2704     }
2705   }
2706   if (custom_call_has_side_effect_ !=
2707       casted_other.custom_call_has_side_effect()) {
2708     return false;
2709   }
2710   if (output_to_operand_aliasing_ !=
2711       casted_other.output_to_operand_aliasing()) {
2712     return false;
2713   }
2714   if (!protobuf_util::ProtobufEquals(precision_config(),
2715                                      casted_other.precision_config())) {
2716     return false;
2717   }
2718 
2719   if (called_computations().size() != other.called_computations().size()) {
2720     return false;
2721   }
2722   for (int64_t i = 0; i < called_computations().size(); ++i) {
2723     if (!eq_computations(called_computations()[i],
2724                          other.called_computations()[i])) {
2725       return false;
2726     }
2727   }
2728   if (custom_call_schedule_ != casted_other.custom_call_schedule()) {
2729     return false;
2730   }
2731   if (HasLiteral() == casted_other.HasLiteral()) {
2732     if (HasLiteral() && literal() == casted_other.literal()) {
2733       return false;
2734     }
2735   } else {
2736     return true;
2737   }
2738   // Note: backend_config comparison is done in Identical, which is the
2739   // intended/exposed way to compare computations, and so not repeated here.
2740   return custom_call_target_ == casted_other.custom_call_target_;
2741 }
2742 
2743 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2744 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2745     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2746     HloCloneContext* context) const {
2747   auto cloned = absl::make_unique<HloCustomCallInstruction>(
2748       shape, new_operands, called_computations(), custom_call_target(),
2749       opaque(), api_version_);
2750   if (layout_constrained()) {
2751     cloned->layout_constrained_ = true;
2752     cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2753   }
2754   if (window_ != nullptr) {
2755     cloned->set_window(*window_);
2756   }
2757   if (convolution_dimension_numbers_ != nullptr) {
2758     cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2759   }
2760   if (HasLiteral()) {
2761     cloned->set_literal(literal().Clone());
2762   }
2763   cloned->set_feature_group_count(feature_group_count_);
2764   cloned->set_batch_group_count(batch_group_count_);
2765   cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2766   cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
2767   cloned->set_padding_type(padding_type_);
2768   *cloned->mutable_precision_config() = precision_config();
2769   cloned->set_custom_call_schedule(custom_call_schedule_);
2770   return std::move(cloned);
2771 }
2772 
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2773 HloPadInstruction::HloPadInstruction(const Shape& shape,
2774                                      HloInstruction* operand,
2775                                      HloInstruction* padding_value,
2776                                      const PaddingConfig& padding_config)
2777     : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2778   AppendOperand(operand);
2779   AppendOperand(padding_value);
2780 }
2781 
ToProto() const2782 HloInstructionProto HloPadInstruction::ToProto() const {
2783   HloInstructionProto proto = HloInstruction::ToProto();
2784   *proto.mutable_padding_config() = padding_config_;
2785   return proto;
2786 }
2787 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2788 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2789     const HloPrintOptions& options) const {
2790   return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2791 }
2792 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2793 bool HloPadInstruction::IdenticalSlowPath(
2794     const HloInstruction& other,
2795     const std::function<bool(const HloComputation*, const HloComputation*)>&
2796         eq_computations) const {
2797   const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2798   return protobuf_util::ProtobufEquals(padding_config(),
2799                                        casted_other.padding_config());
2800 }
2801 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2802 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2803     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2804     HloCloneContext* context) const {
2805   CHECK_EQ(new_operands.size(), 2);
2806   return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2807                                               new_operands[1], padding_config_);
2808 }
2809 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2810 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2811     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2812     absl::Span<const int64> slice_sizes)
2813     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2814       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2815   AppendOperand(operand);
2816   AppendOperand(start_indices);
2817 }
2818 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2819 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2820     const Shape& shape, HloInstruction* operand,
2821     absl::Span<HloInstruction* const> start_indices,
2822     absl::Span<const int64> slice_sizes)
2823     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2824       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2825   AppendOperand(operand);
2826   for (HloInstruction* index : start_indices) {
2827     AppendOperand(index);
2828   }
2829 }
2830 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2831 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2832     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2833     HloInstruction* start_indices)
2834     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2835   AppendOperand(operand);
2836   AppendOperand(update);
2837   AppendOperand(start_indices);
2838 }
2839 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2840 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2841     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2842     absl::Span<HloInstruction* const> start_indices)
2843     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2844   AppendOperand(operand);
2845   AppendOperand(update);
2846   for (HloInstruction* index : start_indices) {
2847     AppendOperand(index);
2848   }
2849 }
2850 
ToProto() const2851 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2852   HloInstructionProto proto = HloInstruction::ToProto();
2853   for (int64_t slice_size : dynamic_slice_sizes_) {
2854     proto.add_dynamic_slice_sizes(slice_size);
2855   }
2856   return proto;
2857 }
2858 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2859 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2860     const HloPrintOptions& options) const {
2861   return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2862                  "}")};
2863 }
2864 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2865 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2866     const HloInstruction& other,
2867     const std::function<bool(const HloComputation*, const HloComputation*)>&
2868         eq_computations) const {
2869   const auto& casted_other = static_cast<const HloMapInstruction&>(other);
2870   return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
2871 }
2872 
2873 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2874 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2875     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2876     HloCloneContext* context) const {
2877   if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2878     // TODO(b/118437727): Old form, remove this path.
2879     return absl::make_unique<HloDynamicSliceInstruction>(
2880         shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2881   } else {
2882     return absl::make_unique<HloDynamicSliceInstruction>(
2883         shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2884   }
2885 }
2886 
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)2887 HloGatherInstruction::HloGatherInstruction(
2888     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2889     const GatherDimensionNumbers& gather_dim_numbers,
2890     absl::Span<const int64> slice_sizes, bool indices_are_sorted)
2891     : HloInstruction(HloOpcode::kGather, shape),
2892       indices_are_sorted_(indices_are_sorted) {
2893   AppendOperand(operand);
2894   AppendOperand(start_indices);
2895   gather_dimension_numbers_ =
2896       absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2897   absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2898 }
2899 
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)2900 /*static*/ string HloGatherInstruction::GatherDimensionNumbersToString(
2901     const GatherDimensionNumbers& gather_dimension_numbers) {
2902   string offset_dims =
2903       StrCat("offset_dims={",
2904              StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
2905   string collapsed_slice_dims = StrCat(
2906       "collapsed_slice_dims={",
2907       StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
2908   string start_index_map =
2909       StrCat("start_index_map={",
2910              StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
2911   string index_vector_dim =
2912       StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
2913 
2914   return StrJoin<std::initializer_list<string>>(
2915       {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2916       ", ");
2917 }
2918 
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64_t index_vector_dim)2919 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2920     absl::Span<const int64> offset_dims,
2921     absl::Span<const int64> collapsed_slice_dims,
2922     absl::Span<const int64> start_index_map, int64_t index_vector_dim) {
2923   GatherDimensionNumbers gather_dim_numbers;
2924   for (int64_t output_window_dim : offset_dims) {
2925     gather_dim_numbers.add_offset_dims(output_window_dim);
2926   }
2927   for (int64_t elided_window_dim : collapsed_slice_dims) {
2928     gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2929   }
2930   for (int64_t gather_dim_to_input_dim : start_index_map) {
2931     gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2932   }
2933 
2934   gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2935   return gather_dim_numbers;
2936 }
2937 
ToProto() const2938 HloInstructionProto HloGatherInstruction::ToProto() const {
2939   HloInstructionProto proto = HloInstruction::ToProto();
2940   *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2941   for (int64_t bound : gather_slice_sizes()) {
2942     proto.add_gather_slice_sizes(bound);
2943   }
2944   proto.set_indices_are_sorted(indices_are_sorted());
2945   return proto;
2946 }
2947 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2948 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2949     const HloPrintOptions& options) const {
2950   std::vector<string> attrs{
2951       GatherDimensionNumbersToString(gather_dimension_numbers()),
2952       StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2953   if (indices_are_sorted()) {
2954     attrs.push_back("indices_are_sorted=true");
2955   }
2956   return attrs;
2957 }
2958 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2959 bool HloGatherInstruction::IdenticalSlowPath(
2960     const HloInstruction& other,
2961     const std::function<bool(const HloComputation*, const HloComputation*)>&
2962         eq_computations) const {
2963   const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2964   return protobuf_util::ProtobufEquals(
2965              gather_dimension_numbers(),
2966              casted_other.gather_dimension_numbers()) &&
2967          gather_slice_sizes() == casted_other.gather_slice_sizes() &&
2968          indices_are_sorted() == casted_other.indices_are_sorted();
2969 }
2970 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2971 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2972     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2973     HloCloneContext* context) const {
2974   CHECK_EQ(new_operands.size(), 2);
2975   return absl::make_unique<HloGatherInstruction>(
2976       shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2977       gather_slice_sizes(), indices_are_sorted());
2978 }
2979 
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)2980 HloScatterInstruction::HloScatterInstruction(
2981     const Shape& shape, HloInstruction* operand,
2982     HloInstruction* scatter_indices, HloInstruction* updates,
2983     HloComputation* update_computation,
2984     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
2985     bool unique_indices)
2986     : HloInstruction(HloOpcode::kScatter, shape),
2987       indices_are_sorted_(indices_are_sorted),
2988       unique_indices_(unique_indices) {
2989   AppendOperand(operand);
2990   AppendOperand(scatter_indices);
2991   AppendOperand(updates);
2992   AppendComputation(update_computation);
2993   scatter_dimension_numbers_ =
2994       absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2995 }
2996 
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)2997 /*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString(
2998     const ScatterDimensionNumbers& scatter_dimension_numbers) {
2999   string update_window_dims =
3000       StrCat("update_window_dims={",
3001              StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
3002   string inserted_window_dims = StrCat(
3003       "inserted_window_dims={",
3004       StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
3005   string scatter_dims_to_operand_dims = StrCat(
3006       "scatter_dims_to_operand_dims={",
3007       StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
3008       "}");
3009   string index_vector_dim =
3010       StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
3011 
3012   return StrJoin<std::initializer_list<string>>(
3013       {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
3014        index_vector_dim},
3015       ", ");
3016 }
3017 
3018 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64_t index_vector_dim)3019 HloScatterInstruction::MakeScatterDimNumbers(
3020     absl::Span<const int64> update_window_dims,
3021     absl::Span<const int64> inserted_window_dims,
3022     absl::Span<const int64> scatter_dims_to_operand_dims,
3023     int64_t index_vector_dim) {
3024   ScatterDimensionNumbers scatter_dim_numbers;
3025   for (int64_t update_window_dim : update_window_dims) {
3026     scatter_dim_numbers.add_update_window_dims(update_window_dim);
3027   }
3028   for (int64_t inserted_window_dim : inserted_window_dims) {
3029     scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
3030   }
3031   for (int64_t scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
3032     scatter_dim_numbers.add_scatter_dims_to_operand_dims(
3033         scatter_dim_to_operand_dim);
3034   }
3035   scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
3036   return scatter_dim_numbers;
3037 }
3038 
ToProto() const3039 HloInstructionProto HloScatterInstruction::ToProto() const {
3040   HloInstructionProto proto = HloInstruction::ToProto();
3041   *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
3042   proto.set_indices_are_sorted(indices_are_sorted());
3043   proto.set_unique_indices(unique_indices());
3044   return proto;
3045 }
3046 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3047 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
3048     const HloPrintOptions& options) const {
3049   std::vector<string> attrs{
3050       ScatterDimensionNumbersToString(scatter_dimension_numbers())};
3051   if (indices_are_sorted()) {
3052     attrs.push_back("indices_are_sorted=true");
3053   }
3054   if (unique_indices()) {
3055     attrs.push_back("unique_indices=true");
3056   }
3057   return attrs;
3058 }
3059 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3060 bool HloScatterInstruction::IdenticalSlowPath(
3061     const HloInstruction& other,
3062     const std::function<bool(const HloComputation*, const HloComputation*)>&
3063         eq_computations) const {
3064   const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
3065   return protobuf_util::ProtobufEquals(
3066              scatter_dimension_numbers(),
3067              casted_other.scatter_dimension_numbers()) &&
3068          eq_computations(to_apply(), casted_other.to_apply()) &&
3069          indices_are_sorted() == casted_other.indices_are_sorted() &&
3070          unique_indices() == casted_other.unique_indices();
3071 }
3072 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3073 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
3074     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3075     HloCloneContext* context) const {
3076   CHECK_EQ(new_operands.size(), 3);
3077   return absl::make_unique<HloScatterInstruction>(
3078       shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
3079       scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
3080 }
3081 
HloIotaInstruction(const Shape & shape,int64_t iota_dimension)3082 HloIotaInstruction::HloIotaInstruction(const Shape& shape,
3083                                        int64_t iota_dimension)
3084     : HloInstruction(HloOpcode::kIota, shape),
3085       iota_dimension_(iota_dimension) {}
3086 
ToProto() const3087 HloInstructionProto HloIotaInstruction::ToProto() const {
3088   HloInstructionProto proto = HloInstruction::ToProto();
3089   proto.add_dimensions(iota_dimension());
3090   return proto;
3091 }
3092 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3093 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
3094     const HloPrintOptions& options) const {
3095   return {StrCat("iota_dimension=", iota_dimension())};
3096 }
3097 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3098 bool HloIotaInstruction::IdenticalSlowPath(
3099     const HloInstruction& other,
3100     const std::function<bool(const HloComputation*, const HloComputation*)>&
3101         eq_computations) const {
3102   const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
3103   return iota_dimension() == casted_other.iota_dimension();
3104 }
3105 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3106 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
3107     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3108     HloCloneContext* context) const {
3109   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
3110 }
3111 
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)3112 HloDotInstruction::HloDotInstruction(
3113     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
3114     const DotDimensionNumbers& dimension_numbers,
3115     const PrecisionConfig& precision_config)
3116     : HloInstruction(HloOpcode::kDot, shape),
3117       dot_dimension_numbers_(dimension_numbers),
3118       precision_config_(precision_config) {
3119   AppendOperand(lhs);
3120   AppendOperand(rhs);
3121 }
3122 
ToProto() const3123 HloInstructionProto HloDotInstruction::ToProto() const {
3124   HloInstructionProto proto = HloInstruction::ToProto();
3125   *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
3126   *proto.mutable_precision_config() = precision_config_;
3127   return proto;
3128 }
3129 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3130 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
3131     const HloPrintOptions& options) const {
3132   std::vector<string> extra = {DotDimensionNumbersToString()};
3133 
3134   string precision_config_string = PrecisionConfigToString(precision_config_);
3135   if (!precision_config_string.empty()) {
3136     extra.push_back(precision_config_string);
3137   }
3138   return extra;
3139 }
3140 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3141 bool HloDotInstruction::IdenticalSlowPath(
3142     const HloInstruction& other,
3143     const std::function<bool(const HloComputation*, const HloComputation*)>&
3144         eq_computations) const {
3145   const auto& casted_other = static_cast<const HloDotInstruction&>(other);
3146   return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
3147                                        casted_other.dot_dimension_numbers()) &&
3148          protobuf_util::ProtobufEquals(precision_config(),
3149                                        casted_other.precision_config());
3150 }
3151 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3152 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
3153     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3154     HloCloneContext* context) const {
3155   CHECK_EQ(new_operands.size(), 2);
3156   return absl::make_unique<HloDotInstruction>(
3157       shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
3158       precision_config_);
3159 }
3160 
DotDimensionNumbersToString() const3161 string HloDotInstruction::DotDimensionNumbersToString() const {
3162   std::vector<string> result;
3163   const DotDimensionNumbers& dnums = dot_dimension_numbers_;
3164   if (!dnums.lhs_batch_dimensions().empty()) {
3165     result.push_back(StrCat("lhs_batch_dims={",
3166                             StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
3167   }
3168   result.push_back(StrCat("lhs_contracting_dims={",
3169                           StrJoin(dnums.lhs_contracting_dimensions(), ","),
3170                           "}"));
3171 
3172   if (!dnums.rhs_batch_dimensions().empty()) {
3173     result.push_back(StrCat("rhs_batch_dims={",
3174                             StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
3175   }
3176   result.push_back(StrCat("rhs_contracting_dims={",
3177                           StrJoin(dnums.rhs_contracting_dimensions(), ","),
3178                           "}"));
3179 
3180   return StrJoin(result, ", ");
3181 }
3182 
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)3183 HloDomainInstruction::HloDomainInstruction(
3184     const Shape& shape, HloInstruction* operand,
3185     std::unique_ptr<DomainMetadata> operand_side_metadata,
3186     std::unique_ptr<DomainMetadata> user_side_metadata)
3187     : HloInstruction(HloOpcode::kDomain, shape),
3188       operand_side_metadata_(std::move(operand_side_metadata)),
3189       user_side_metadata_(std::move(user_side_metadata)) {
3190   AppendOperand(operand);
3191 }
3192 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3193 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
3194     const HloPrintOptions& options) const {
3195   if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
3196     return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
3197                    "\", entry=", user_side_metadata_->ToString(),
3198                    ", exit=", operand_side_metadata_->ToString(), "}")};
3199   }
3200   return {};
3201 }
3202 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3203 bool HloDomainInstruction::IdenticalSlowPath(
3204     const HloInstruction& other,
3205     const std::function<bool(const HloComputation*, const HloComputation*)>&
3206         eq_computations) const {
3207   const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
3208   return operand_side_metadata().Matches(
3209              casted_other.operand_side_metadata()) &&
3210          user_side_metadata().Matches(casted_other.user_side_metadata());
3211 }
3212 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3213 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
3214     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3215     HloCloneContext* context) const {
3216   CHECK_EQ(new_operands.size(), 1);
3217   return absl::make_unique<HloDomainInstruction>(
3218       shape, new_operands[0], operand_side_metadata_->Clone(),
3219       user_side_metadata_->Clone());
3220 }
3221 
ToProto() const3222 HloInstructionProto HloDomainInstruction::ToProto() const {
3223   HloInstructionProto proto = HloInstruction::ToProto();
3224   auto operand_side_sharding =
3225       dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
3226   if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
3227     *proto.mutable_domain_entry_sharding() =
3228         operand_side_sharding->sharding()->ToProto();
3229   }
3230 
3231   auto user_side_sharding =
3232       dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
3233   if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
3234     *proto.mutable_domain_exit_sharding() =
3235         user_side_sharding->sharding()->ToProto();
3236   }
3237 
3238   return proto;
3239 }
3240 
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64_t dimension)3241 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
3242     const Shape& shape, HloInstruction* operand, int64_t dimension)
3243     : HloInstruction(HloOpcode::kGetDimensionSize, shape),
3244       dimension_(dimension) {
3245   AppendOperand(operand);
3246 }
3247 
ToProto() const3248 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
3249   HloInstructionProto proto = HloInstruction::ToProto();
3250   proto.add_dimensions(dimension());
3251   return proto;
3252 }
3253 
ExtraAttributesToStringImpl(const HloPrintOptions &) const3254 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3255     const HloPrintOptions& /*options*/) const {
3256   return {StrCat("dimensions={", dimension(), "}")};
3257 }
3258 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3259 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
3260     const HloInstruction& other,
3261     const std::function<bool(const HloComputation*, const HloComputation*)>&
3262     /*eq_computations*/) const {
3263   const auto& casted_other =
3264       static_cast<const HloGetDimensionSizeInstruction&>(other);
3265   return dimension() == casted_other.dimension();
3266 }
3267 
3268 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3269 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3270     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3271     HloCloneContext* /*context*/) const {
3272   if (new_operands.size() != 1) {
3273     LOG(FATAL) << "expects 1 operand";
3274   }
3275   return absl::make_unique<HloGetDimensionSizeInstruction>(
3276       shape, new_operands[0], dimension());
3277 }
3278 
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)3279 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
3280     const Shape& shape, HloInstruction* operand, HloInstruction* val,
3281     int64_t dimension)
3282     : HloInstruction(HloOpcode::kSetDimensionSize, shape),
3283       dimension_(dimension) {
3284   AppendOperand(operand);
3285   AppendOperand(val);
3286 }
3287 
ExtraAttributesToStringImpl(const HloPrintOptions &) const3288 std::vector<string> HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3289     const HloPrintOptions& /*options*/) const {
3290   return {StrCat("dimensions={", dimension(), "}")};
3291 }
3292 
ToProto() const3293 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
3294   HloInstructionProto proto = HloInstruction::ToProto();
3295   proto.add_dimensions(dimension());
3296   return proto;
3297 }
3298 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3299 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
3300     const HloInstruction& other,
3301     const std::function<bool(const HloComputation*, const HloComputation*)>&
3302     /*eq_computations*/) const {
3303   const auto& casted_other =
3304       static_cast<const HloSetDimensionSizeInstruction&>(other);
3305   return dimension() == casted_other.dimension();
3306 }
3307 
3308 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3309 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3310     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3311     HloCloneContext* /*context*/) const {
3312   if (new_operands.size() != 2) {
3313     LOG(FATAL) << "expects 2 operand";
3314   }
3315   return absl::make_unique<HloSetDimensionSizeInstruction>(
3316       shape, new_operands[0], new_operands[1], dimension());
3317 }
3318 
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64_t delta)3319 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
3320     const Shape& shape, int64_t delta)
3321     : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
3322 
ToProto() const3323 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
3324   HloInstructionProto proto = HloInstruction::ToProto();
3325   proto.set_delta(delta_);
3326   return proto;
3327 }
3328 
3329 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3330 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
3331     const HloPrintOptions& /*options*/) const {
3332   return {StrCat("delta=", delta())};
3333 }
3334 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3335 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
3336     const HloInstruction& other,
3337     const std::function<bool(const HloComputation*, const HloComputation*)>&
3338     /*eq_computations*/) const {
3339   const auto& casted_other =
3340       static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
3341   return delta() == casted_other.delta();
3342 }
3343 
3344 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3345 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
3346     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3347     HloCloneContext* /*context*/) const {
3348   if (!new_operands.empty()) {
3349     LOG(FATAL) << "expects 0 operand";
3350   }
3351   return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
3352 }
3353 
HloRngBitGeneratorInstruction(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)3354 HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
3355     const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
3356     : HloInstruction(HloOpcode::kRngBitGenerator, shape),
3357       algorithm_(algorithm) {
3358   AppendOperand(state);
3359 }
3360 
ToProto() const3361 HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
3362   HloInstructionProto proto = HloInstruction::ToProto();
3363   proto.set_rng_algorithm(algorithm_);
3364   return proto;
3365 }
3366 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3367 std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
3368     const HloPrintOptions& options) const {
3369   return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
3370 }
3371 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3372 bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
3373     const HloInstruction& other,
3374     const std::function<bool(const HloComputation*, const HloComputation*)>&
3375         eq_computations) const {
3376   const auto& casted_other =
3377       static_cast<const HloRngBitGeneratorInstruction&>(other);
3378   return algorithm() == casted_other.algorithm();
3379 }
3380 
3381 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3382 HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
3383     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3384     HloCloneContext* /*context*/) const {
3385   CHECK_EQ(new_operands.size(), 1);
3386   return absl::make_unique<HloRngBitGeneratorInstruction>(
3387       shape, new_operands[0], algorithm());
3388 }
3389 
3390 }  // namespace xla
3391