• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17 
18 #include <deque>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
33 #include "tensorflow/compiler/xla/window_util.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 
36 namespace xla {
37 namespace {
38 
39 using absl::CEscape;
40 using absl::StrAppend;
41 using absl::StrCat;
42 using absl::StrJoin;
43 
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)44 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
45                                        const HloInstruction* operand) {
46   const auto operand_indices = instruction->OperandIndices(operand);
47   return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
48     return instruction->IsElementwiseOnOperand(operand_index);
49   });
50 }
51 
PrecisionConfigToString(const PrecisionConfig & precision_config)52 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
53   if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
54         return static_cast<PrecisionConfig::Precision>(precision) ==
55                PrecisionConfig::DEFAULT;
56       })) {
57     return "";
58   }
59 
60   return StrCat(
61       "operand_precision={",
62       StrJoin(
63           precision_config.operand_precision(), ",",
64           [](string* out, int32 precision) {
65             CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
66             StrAppend(out,
67                       PrecisionToString(
68                           static_cast<PrecisionConfig::Precision>(precision)));
69           }),
70       "}");
71 }
72 }  // namespace
73 
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64 feature_index)74 HloBatchNormInstruction::HloBatchNormInstruction(
75     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
76     HloInstruction* scale, float epsilon, int64 feature_index)
77     : HloInstruction(opcode, shape),
78       epsilon_(epsilon),
79       feature_index_(feature_index) {
80   AppendOperand(operand);
81   AppendOperand(scale);
82 }
83 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const84 bool HloBatchNormInstruction::IdenticalSlowPath(
85     const HloInstruction& other,
86     const std::function<bool(const HloComputation*, const HloComputation*)>&
87         eq_computations) const {
88   const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
89   return feature_index() == casted_other.feature_index() &&
90          epsilon() == casted_other.epsilon();
91 }
92 
ToProto() const93 HloInstructionProto HloBatchNormInstruction::ToProto() const {
94   HloInstructionProto proto = HloInstruction::ToProto();
95   proto.set_epsilon(epsilon_);
96   proto.set_feature_index(feature_index_);
97   return proto;
98 }
99 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const100 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
101     const HloPrintOptions& options) const {
102   return {StrCat("epsilon=", epsilon()),
103           StrCat("feature_index=", feature_index())};
104 }
105 
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)106 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
107     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
108     HloInstruction* offset, float epsilon, int64 feature_index)
109     : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
110                               scale, epsilon, feature_index) {
111   AppendOperand(offset);
112 }
113 
114 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const115 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
116     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
117     HloCloneContext* context) const {
118   CHECK_EQ(new_operands.size(), 3);
119   return absl::make_unique<HloBatchNormTrainingInstruction>(
120       shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
121       feature_index());
122 }
123 
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)124 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
125     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
126     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
127     float epsilon, int64 feature_index)
128     : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
129                               scale, epsilon, feature_index) {
130   AppendOperand(offset);
131   AppendOperand(mean);
132   AppendOperand(variance);
133 }
134 
135 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const136 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
137     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
138     HloCloneContext* context) const {
139   CHECK_EQ(new_operands.size(), 5);
140   return absl::make_unique<HloBatchNormInferenceInstruction>(
141       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
142       new_operands[4], epsilon(), feature_index());
143 }
144 
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)145 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
146     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
147     HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
148     float epsilon, int64 feature_index)
149     : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
150                               epsilon, feature_index) {
151   AppendOperand(mean);
152   AppendOperand(variance);
153   AppendOperand(grad_output);
154 }
155 
156 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const157 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
158     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
159     HloCloneContext* context) const {
160   CHECK_EQ(new_operands.size(), 5);
161   return absl::make_unique<HloBatchNormGradInstruction>(
162       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
163       new_operands[4], epsilon(), feature_index());
164 }
165 
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)166 HloFftInstruction::HloFftInstruction(const Shape& shape,
167                                      HloInstruction* operand, FftType fft_type,
168                                      absl::Span<const int64> fft_length)
169     : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
170   fft_length_.assign(fft_length.begin(), fft_length.end());
171   AppendOperand(operand);
172 }
173 
ToProto() const174 HloInstructionProto HloFftInstruction::ToProto() const {
175   HloInstructionProto proto = HloInstruction::ToProto();
176   proto.set_fft_type(fft_type_);
177   for (int64 fft_len : fft_length_) {
178     proto.add_fft_length(fft_len);
179   }
180   return proto;
181 }
182 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const183 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
184     const HloPrintOptions& options) const {
185   return {StrCat("fft_type=", FftType_Name(fft_type())),
186           StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
187 }
188 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const189 bool HloFftInstruction::IdenticalSlowPath(
190     const HloInstruction& other,
191     const std::function<bool(const HloComputation*, const HloComputation*)>&
192         eq_computations) const {
193   const auto& casted_other = static_cast<const HloFftInstruction&>(other);
194   return fft_type() == casted_other.fft_type() &&
195          fft_length() == casted_other.fft_length();
196 }
197 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const198 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
199     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
200     HloCloneContext* context) const {
201   CHECK_EQ(new_operands.size(), 1);
202   return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
203                                               fft_length_);
204 }
205 
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction)206 HloCompareInstruction::HloCompareInstruction(const Shape& shape,
207                                              HloInstruction* lhs,
208                                              HloInstruction* rhs,
209                                              ComparisonDirection direction)
210     : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
211   AppendOperand(lhs);
212   AppendOperand(rhs);
213 }
214 
ToProto() const215 HloInstructionProto HloCompareInstruction::ToProto() const {
216   HloInstructionProto proto = HloInstruction::ToProto();
217   proto.set_comparison_direction(ComparisonDirectionToString(direction_));
218   return proto;
219 }
220 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const221 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
222     const HloPrintOptions& options) const {
223   return {StrCat("direction=", ComparisonDirectionToString(direction()))};
224 }
225 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const226 bool HloCompareInstruction::IdenticalSlowPath(
227     const HloInstruction& other,
228     const std::function<bool(const HloComputation*, const HloComputation*)>&
229         eq_computations) const {
230   const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
231   return direction() == casted_other.direction();
232 }
233 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const234 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
235     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
236     HloCloneContext* context) const {
237   CHECK_EQ(new_operands.size(), 2);
238   return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
239                                                   new_operands[1], direction());
240 }
241 
242 namespace {
243 
244 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
245 // of "key=value" attribute strings generically, using protocol buffer
246 // reflection.
247 //
248 // Currently implements a small subset of cases; feel free to add more as
249 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)250 std::vector<string> AttributeProtoToStringVector(
251     const tensorflow::protobuf::Message& message) {
252   const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
253   std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
254   reflection->ListFields(message, &fields);
255 
256   std::vector<string> output;
257   for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
258     string s = absl::StrCat(field->name(), "=");
259     CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
260     switch (field->type()) {
261       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
262         bool val = reflection->GetBool(message, field);
263         absl::StrAppend(&s, val ? "true" : "false");
264         break;
265       }
266       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
267         const tensorflow::protobuf::EnumValueDescriptor* evd =
268             reflection->GetEnum(message, field);
269         absl::StrAppend(&s, evd->name());
270         break;
271       }
272       default:
273         LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
274     }
275     output.push_back(std::move(s));
276   }
277   return output;
278 }
279 
280 }  // namespace
281 
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)282 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
283     const Shape& shape, HloInstruction* a, HloInstruction* b,
284     const TriangularSolveOptions& options)
285     : HloInstruction(HloOpcode::kTriangularSolve, shape),
286       triangular_solve_options_(options) {
287   AppendOperand(a);
288   AppendOperand(b);
289 }
290 
ToProto() const291 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
292   HloInstructionProto proto = HloInstruction::ToProto();
293   *proto.mutable_triangular_solve_options() = triangular_solve_options_;
294   return proto;
295 }
296 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const297 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
298     const HloPrintOptions& options) const {
299   return AttributeProtoToStringVector(triangular_solve_options_);
300 }
301 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const302 bool HloTriangularSolveInstruction::IdenticalSlowPath(
303     const HloInstruction& other,
304     const std::function<bool(const HloComputation*, const HloComputation*)>&
305         eq_computations) const {
306   const auto& casted_other =
307       static_cast<const HloTriangularSolveInstruction&>(other);
308   const auto& options = triangular_solve_options();
309   const auto& other_options = casted_other.triangular_solve_options();
310 
311   return options.left_side() == other_options.left_side() &&
312          options.lower() == other_options.lower() &&
313          options.unit_diagonal() == other_options.unit_diagonal() &&
314          options.transpose_a() == other_options.transpose_a();
315 }
316 
317 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const318 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
319     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
320     HloCloneContext* context) const {
321   CHECK_EQ(new_operands.size(), 2);
322   return absl::make_unique<HloTriangularSolveInstruction>(
323       shape, new_operands[0], new_operands[1], triangular_solve_options());
324 }
325 
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)326 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
327                                                HloInstruction* a,
328                                                const CholeskyOptions& options)
329     : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
330   AppendOperand(a);
331 }
332 
ToProto() const333 HloInstructionProto HloCholeskyInstruction::ToProto() const {
334   HloInstructionProto proto = HloInstruction::ToProto();
335   *proto.mutable_cholesky_options() = cholesky_options_;
336   return proto;
337 }
338 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const339 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
340     const HloPrintOptions& options) const {
341   return AttributeProtoToStringVector(cholesky_options_);
342 }
343 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const344 bool HloCholeskyInstruction::IdenticalSlowPath(
345     const HloInstruction& other,
346     const std::function<bool(const HloComputation*, const HloComputation*)>&
347         eq_computations) const {
348   const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
349   const auto& options = cholesky_options();
350   const auto& other_options = casted_other.cholesky_options();
351 
352   return options.lower() == other_options.lower();
353 }
354 
355 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const356 HloCholeskyInstruction::CloneWithNewOperandsImpl(
357     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
358     HloCloneContext* context) const {
359   CHECK_EQ(new_operands.size(), 1);
360   return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
361                                                    cholesky_options());
362 }
363 
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const absl::optional<int64> & channel_id)364 HloChannelInstruction::HloChannelInstruction(
365     HloOpcode opcode, const Shape& shape,
366     const absl::optional<int64>& channel_id)
367     : HloInstruction(opcode, shape), channel_id_(channel_id) {}
368 
set_channel_id(const absl::optional<int64> & channel_id)369 void HloChannelInstruction::set_channel_id(
370     const absl::optional<int64>& channel_id) {
371   channel_id_ = channel_id;
372 }
373 
ToProto() const374 HloInstructionProto HloChannelInstruction::ToProto() const {
375   HloInstructionProto proto = HloInstruction::ToProto();
376   if (channel_id_) {
377     CHECK_GT(channel_id_.value(), 0)
378         << "Non-positive channel id is equivalent to no channel id";
379     proto.set_channel_id(*channel_id_);
380   }
381   return proto;
382 }
383 
ExtraAttributesToStringImpl(const HloPrintOptions &) const384 std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
385     const HloPrintOptions& /*options*/) const {
386   std::vector<string> result;
387   if (channel_id_) {
388     result.push_back(StrCat("channel_id=", *channel_id_));
389   }
390   return result;
391 }
392 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const393 bool HloChannelInstruction::IdenticalSlowPath(
394     const HloInstruction& other,
395     const std::function<bool(const HloComputation*, const HloComputation*)>&
396     /*eq_computations*/) const {
397   const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
398   return channel_id() == casted_other.channel_id();
399 }
400 
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64 channel_id,bool is_host_transfer)401 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
402                                                const Shape& shape,
403                                                int64 channel_id,
404                                                bool is_host_transfer)
405     : HloChannelInstruction(opcode, shape, channel_id),
406       is_host_transfer_(is_host_transfer) {}
407 
ToProto() const408 HloInstructionProto HloSendRecvInstruction::ToProto() const {
409   HloInstructionProto proto = HloChannelInstruction::ToProto();
410   proto.set_is_host_transfer(is_host_transfer_);
411   return proto;
412 }
413 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const414 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
415     const HloPrintOptions& options) const {
416   std::vector<string> attrs =
417       HloChannelInstruction::ExtraAttributesToStringImpl(options);
418   if (is_host_transfer()) {
419     attrs.push_back("is_host_transfer=true");
420   }
421   return attrs;
422 }
423 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const424 bool HloSendRecvInstruction::IdenticalSlowPath(
425     const HloInstruction& other,
426     const std::function<bool(const HloComputation*, const HloComputation*)>&
427         eq_computations) const {
428   // Not yet supported.
429   return false;
430 }
431 
432 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)433 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
434                                        HloInstruction* token, int64 channel_id,
435                                        bool is_host_transfer)
436     : HloSendRecvInstruction(
437           HloOpcode::kSend,
438           ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
439                                      ShapeUtil::MakeShape(U32, {}),
440                                      ShapeUtil::MakeTokenShape()}),
441           channel_id, is_host_transfer) {
442   AppendOperand(operand);
443   AppendOperand(token);
444 }
445 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const446 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
447     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
448     HloCloneContext* context) const {
449   CHECK_EQ(new_operands.size(), 2);
450   return absl::make_unique<HloSendInstruction>(
451       new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
452 }
453 
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)454 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
455                                                bool is_host_transfer)
456     : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
457                              CHECK_NOTNULL(operand)->channel_id().value(),
458                              is_host_transfer) {
459   AppendOperand(operand);
460 }
461 
462 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const463 HloSendDoneInstruction::CloneWithNewOperandsImpl(
464     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
465     HloCloneContext* context) const {
466   CHECK_EQ(new_operands.size(), 1);
467   return absl::make_unique<HloSendDoneInstruction>(
468       Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
469 }
470 
471 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)472 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
473                                        HloInstruction* token, int64 channel_id,
474                                        bool is_host_transfer)
475     : HloSendRecvInstruction(
476           HloOpcode::kRecv,
477           ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
478                                      ShapeUtil::MakeTokenShape()}),
479           channel_id, is_host_transfer) {
480   AppendOperand(token);
481 }
482 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const483 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
484     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
485     HloCloneContext* context) const {
486   CHECK_EQ(new_operands.size(), 1);
487   return absl::make_unique<HloRecvInstruction>(
488       ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
489       is_host_transfer());
490 }
491 
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)492 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
493                                                bool is_host_transfer)
494     : HloSendRecvInstruction(
495           HloOpcode::kRecvDone,
496           ShapeUtil::MakeTupleShape(
497               {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
498                ShapeUtil::MakeTokenShape()}),
499           CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
500   AppendOperand(operand);
501 }
502 
503 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const504 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
505     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
506     HloCloneContext* context) const {
507   CHECK_EQ(new_operands.size(), 1);
508   return absl::make_unique<HloRecvDoneInstruction>(
509       Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
510 }
511 
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,const absl::optional<int64> & channel_id)512 HloCollectiveInstruction::HloCollectiveInstruction(
513     HloOpcode opcode, const Shape& shape,
514     absl::Span<HloInstruction* const> operands,
515     const std::vector<ReplicaGroup>& replica_groups,
516     const absl::optional<int64>& channel_id)
517     : HloChannelInstruction(opcode, shape, channel_id),
518       replica_groups_(replica_groups) {
519   for (auto operand : operands) {
520     AppendOperand(operand);
521   }
522 }
523 
ToProto() const524 HloInstructionProto HloCollectiveInstruction::ToProto() const {
525   HloInstructionProto proto = HloChannelInstruction::ToProto();
526   *proto.mutable_replica_groups() = {replica_groups_.begin(),
527                                      replica_groups_.end()};
528   return proto;
529 }
530 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const531 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
532     const HloPrintOptions& options) const {
533   std::vector<string> result =
534       HloChannelInstruction::ExtraAttributesToStringImpl(options);
535   result.push_back(
536       StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
537   return result;
538 }
539 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const540 bool HloCollectiveInstruction::IdenticalSlowPath(
541     const HloInstruction& other,
542     const std::function<bool(const HloComputation*, const HloComputation*)>&
543         eq_computations) const {
544   const auto& casted_other =
545       static_cast<const HloCollectiveInstruction&>(other);
546   return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
547          absl::c_equal(replica_groups(), casted_other.replica_groups(),
548                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
549                          return absl::c_equal(a.replica_ids(), b.replica_ids());
550                        });
551 }
552 
HloAllReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)553 HloAllReduceInstruction::HloAllReduceInstruction(
554     const Shape& shape, absl::Span<HloInstruction* const> operands,
555     HloComputation* reduce_computation,
556     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
557     const absl::optional<int64>& channel_id)
558     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
559                                replica_groups, channel_id),
560       constrain_layout_(constrain_layout) {
561   AppendComputation(reduce_computation);
562 }
563 
IsNoop() const564 bool HloAllReduceInstruction::IsNoop() const {
565   for (auto replica_group : replica_groups()) {
566     if (replica_group.replica_ids().size() != 1) {
567       return false;
568     }
569   }
570   return !channel_id();
571 }
572 
ToProto() const573 HloInstructionProto HloAllReduceInstruction::ToProto() const {
574   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
575   proto.set_constrain_layout(constrain_layout_);
576   return proto;
577 }
578 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const579 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
580     const HloPrintOptions& options) const {
581   std::vector<string> result =
582       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
583   if (constrain_layout_) {
584     result.push_back("constrain_layout=true");
585   }
586   return result;
587 }
588 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const589 bool HloAllReduceInstruction::IdenticalSlowPath(
590     const HloInstruction& other,
591     const std::function<bool(const HloComputation*, const HloComputation*)>&
592         eq_computations) const {
593   const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
594   return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
595          constrain_layout() == casted_other.constrain_layout() &&
596          eq_computations(to_apply(), casted_other.to_apply());
597 }
598 
599 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const600 HloAllReduceInstruction::CloneWithNewOperandsImpl(
601     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
602     HloCloneContext* /*context*/) const {
603   return absl::make_unique<HloAllReduceInstruction>(
604       shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
605       channel_id());
606 }
607 
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)608 HloAllToAllInstruction::HloAllToAllInstruction(
609     const Shape& shape, absl::Span<HloInstruction* const> operands,
610     const std::vector<ReplicaGroup>& replica_groups,
611     const absl::optional<int64>& channel_id,
612     const absl::optional<int64>& split_dimension)
613     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
614                                replica_groups, channel_id),
615       split_dimension_(split_dimension) {}
616 
617 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const618 HloAllToAllInstruction::CloneWithNewOperandsImpl(
619     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
620     HloCloneContext* /*context*/) const {
621   return absl::make_unique<HloAllToAllInstruction>(
622       shape, new_operands, replica_groups(), channel_id(), split_dimension());
623 }
624 
ToProto() const625 HloInstructionProto HloAllToAllInstruction::ToProto() const {
626   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
627   if (split_dimension_) {
628     proto.add_dimensions(*split_dimension_);
629   }
630   return proto;
631 }
632 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const633 std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
634     const HloPrintOptions& options) const {
635   std::vector<string> result =
636       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
637   if (split_dimension_) {
638     result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
639   }
640   return result;
641 }
642 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const643 bool HloAllToAllInstruction::IdenticalSlowPath(
644     const HloInstruction& other,
645     const std::function<bool(const HloComputation*, const HloComputation*)>&
646         eq_computations) const {
647   const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
648   return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
649          split_dimension_ == casted_other.split_dimension();
650 }
651 
HloCollectivePermuteInstruction(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)652 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
653     const Shape& shape, HloInstruction* operand,
654     const std::vector<std::pair<int64, int64>>& source_target_pairs,
655     const absl::optional<int64>& channel_id)
656     : HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id),
657       source_target_pairs_(source_target_pairs) {
658   AppendOperand(operand);
659 }
660 
ToProto() const661 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
662   HloInstructionProto proto = HloChannelInstruction::ToProto();
663   for (const auto& pair : source_target_pairs()) {
664     auto* proto_pair = proto.add_source_target_pairs();
665     proto_pair->set_source(pair.first);
666     proto_pair->set_target(pair.second);
667   }
668   return proto;
669 }
670 
671 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const672 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
673     const HloPrintOptions& options) const {
674   std::vector<string> result =
675       HloChannelInstruction::ExtraAttributesToStringImpl(options);
676   std::vector<string> strs;
677   for (const auto& pair : source_target_pairs()) {
678     strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
679   }
680   result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
681   return result;
682 }
683 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const684 bool HloCollectivePermuteInstruction::IdenticalSlowPath(
685     const HloInstruction& other,
686     const std::function<bool(const HloComputation*, const HloComputation*)>&
687         eq_computations) const {
688   const auto& casted_other =
689       static_cast<const HloCollectivePermuteInstruction&>(other);
690   return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
691          absl::c_equal(source_target_pairs(),
692                        casted_other.source_target_pairs(),
693                        [](const std::pair<int64, int64>& a,
694                           const std::pair<int64, int64>& b) { return a == b; });
695 }
696 
697 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const698 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
699     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
700     HloCloneContext* /*context*/) const {
701   return absl::make_unique<HloCollectivePermuteInstruction>(
702       shape, new_operands[0], source_target_pairs(), channel_id());
703 }
704 
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)705 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
706                                              HloInstruction* operand,
707                                              absl::Span<const int64> dimensions)
708     : HloInstruction(HloOpcode::kReverse, shape),
709       dimensions_(dimensions.begin(), dimensions.end()) {
710   AppendOperand(operand);
711 }
712 
ToProto() const713 HloInstructionProto HloReverseInstruction::ToProto() const {
714   HloInstructionProto proto = HloInstruction::ToProto();
715   for (int64 dimension : dimensions_) {
716     proto.add_dimensions(dimension);
717   }
718   return proto;
719 }
720 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const721 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
722     const HloPrintOptions& options) const {
723   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
724 }
725 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const726 bool HloReverseInstruction::IdenticalSlowPath(
727     const HloInstruction& other,
728     const std::function<bool(const HloComputation*, const HloComputation*)>&
729         eq_computations) const {
730   const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
731   return dimensions() == casted_other.dimensions();
732 }
733 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const734 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
735     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
736     HloCloneContext* context) const {
737   CHECK_EQ(new_operands.size(), 1);
738   return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
739                                                   dimensions());
740 }
741 
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)742 HloConcatenateInstruction::HloConcatenateInstruction(
743     const Shape& shape, absl::Span<HloInstruction* const> operands,
744     int64 dimension)
745     : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
746   for (auto operand : operands) {
747     AppendOperand(operand);
748   }
749 }
750 
ToProto() const751 HloInstructionProto HloConcatenateInstruction::ToProto() const {
752   HloInstructionProto proto = HloInstruction::ToProto();
753   for (int64 dimension : dimensions_) {
754     proto.add_dimensions(dimension);
755   }
756   return proto;
757 }
758 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const759 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
760     const HloPrintOptions& options) const {
761   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
762 }
763 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const764 bool HloConcatenateInstruction::IdenticalSlowPath(
765     const HloInstruction& other,
766     const std::function<bool(const HloComputation*, const HloComputation*)>&
767         eq_computations) const {
768   const auto& casted_other =
769       static_cast<const HloConcatenateInstruction&>(other);
770   return dimensions() == casted_other.dimensions();
771 }
772 
773 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const774 HloConcatenateInstruction::CloneWithNewOperandsImpl(
775     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
776     HloCloneContext* context) const {
777   return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
778                                                       dimensions(0));
779 }
780 
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)781 HloReduceInstruction::HloReduceInstruction(
782     const Shape& shape, absl::Span<HloInstruction* const> args,
783     absl::Span<const int64> dimensions_to_reduce,
784     HloComputation* reduce_computation)
785     : HloInstruction(HloOpcode::kReduce, shape),
786       dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
787   for (HloInstruction* arg : args) {
788     AppendOperand(arg);
789   }
790   AppendComputation(reduce_computation);
791 }
792 
ToProto() const793 HloInstructionProto HloReduceInstruction::ToProto() const {
794   HloInstructionProto proto = HloInstruction::ToProto();
795   for (int64 dimension : dimensions_) {
796     proto.add_dimensions(dimension);
797   }
798   return proto;
799 }
800 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const801 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
802     const HloPrintOptions& options) const {
803   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
804 }
805 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const806 bool HloReduceInstruction::IdenticalSlowPath(
807     const HloInstruction& other,
808     const std::function<bool(const HloComputation*, const HloComputation*)>&
809         eq_computations) const {
810   const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
811   // Reduction results are determined by the reduction dimension and the
812   // reduction computation.
813   return dimensions() == casted_other.dimensions() &&
814          eq_computations(to_apply(), casted_other.to_apply());
815 }
816 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const817 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
818     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
819     HloCloneContext* context) const {
820   CHECK_EQ(new_operands.size() % 2, 0);
821   return absl::make_unique<HloReduceInstruction>(shape, new_operands,
822                                                  dimensions(), to_apply());
823 }
824 
HloSortInstruction(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)825 HloSortInstruction::HloSortInstruction(
826     const Shape& shape, int64 dimension,
827     absl::Span<HloInstruction* const> operands, HloComputation* compare,
828     bool is_stable)
829     : HloInstruction(HloOpcode::kSort, shape),
830       dimensions_({dimension}),
831       is_stable_(is_stable) {
832   for (auto* value : operands) {
833     AppendOperand(value);
834   }
835   AppendComputation(compare);
836 }
837 
ToProto() const838 HloInstructionProto HloSortInstruction::ToProto() const {
839   HloInstructionProto proto = HloInstruction::ToProto();
840   for (int64 dimension : dimensions_) {
841     proto.add_dimensions(dimension);
842   }
843   proto.set_is_stable(is_stable());
844   return proto;
845 }
846 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const847 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
848     const HloPrintOptions& options) const {
849   std::vector<string> attrs;
850   attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
851   if (is_stable()) {
852     attrs.push_back("is_stable=true");
853   }
854   return attrs;
855 }
856 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const857 bool HloSortInstruction::IdenticalSlowPath(
858     const HloInstruction& other,
859     const std::function<bool(const HloComputation*, const HloComputation*)>&
860         eq_computations) const {
861   const auto& casted_other = static_cast<const HloSortInstruction&>(other);
862   if (dimensions() != casted_other.dimensions()) {
863     return false;
864   }
865   if (is_stable() != casted_other.is_stable()) {
866     return false;
867   }
868   return eq_computations(to_apply(), other.to_apply());
869 }
870 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const871 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
872     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
873     HloCloneContext* context) const {
874   return absl::make_unique<HloSortInstruction>(
875       shape, dimensions(0), new_operands, to_apply(), is_stable());
876 }
877 
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)878 HloTransposeInstruction::HloTransposeInstruction(
879     const Shape& shape, HloInstruction* operand,
880     absl::Span<const int64> dimensions)
881     : HloInstruction(HloOpcode::kTranspose, shape),
882       dimensions_(dimensions.begin(), dimensions.end()) {
883   AppendOperand(operand);
884 }
885 
IsRank2Transpose() const886 bool HloTransposeInstruction::IsRank2Transpose() const {
887   return dimensions() == std::vector<int64>({1, 0}) &&
888          shape().dimensions_size() == 2 &&
889          std::equal(shape().dimensions().begin(), shape().dimensions().end(),
890                     operand(0)->shape().dimensions().rbegin());
891 }
892 
ToProto() const893 HloInstructionProto HloTransposeInstruction::ToProto() const {
894   HloInstructionProto proto = HloInstruction::ToProto();
895   for (int64 dimension : dimensions_) {
896     proto.add_dimensions(dimension);
897   }
898   return proto;
899 }
900 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const901 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
902     const HloPrintOptions& options) const {
903   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
904 }
905 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const906 bool HloTransposeInstruction::IdenticalSlowPath(
907     const HloInstruction& other,
908     const std::function<bool(const HloComputation*, const HloComputation*)>&
909         eq_computations) const {
910   const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
911   return dimensions() == casted_other.dimensions();
912 }
913 
914 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const915 HloTransposeInstruction::CloneWithNewOperandsImpl(
916     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
917     HloCloneContext* context) const {
918   CHECK_EQ(new_operands.size(), 1);
919   return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
920                                                     dimensions());
921 }
922 
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)923 HloBroadcastInstruction::HloBroadcastInstruction(
924     const Shape& shape, HloInstruction* operand,
925     absl::Span<const int64> broadcast_dimension)
926     : HloInstruction(HloOpcode::kBroadcast, shape),
927       dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
928   AppendOperand(operand);
929 }
930 
ToProto() const931 HloInstructionProto HloBroadcastInstruction::ToProto() const {
932   HloInstructionProto proto = HloInstruction::ToProto();
933   for (int64 dimension : dimensions_) {
934     proto.add_dimensions(dimension);
935   }
936   return proto;
937 }
938 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const939 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
940     const HloPrintOptions& options) const {
941   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
942 }
943 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const944 bool HloBroadcastInstruction::IdenticalSlowPath(
945     const HloInstruction& other,
946     const std::function<bool(const HloComputation*, const HloComputation*)>&
947         eq_computations) const {
948   const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
949   return dimensions() == casted_other.dimensions();
950 }
951 
952 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const953 HloBroadcastInstruction::CloneWithNewOperandsImpl(
954     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
955     HloCloneContext* context) const {
956   CHECK_EQ(new_operands.size(), 1);
957   return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
958                                                     dimensions());
959 }
960 
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64 inferred_dimension)961 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
962                                              HloInstruction* operand,
963                                              int64 inferred_dimension)
964     : HloInstruction(HloOpcode::kReshape, shape),
965       inferred_dimension_(inferred_dimension) {
966   AppendOperand(operand);
967 }
968 
ToProto() const969 HloInstructionProto HloReshapeInstruction::ToProto() const {
970   HloInstructionProto proto = HloInstruction::ToProto();
971   if (inferred_dimension_ != -1) {
972     proto.add_dimensions(inferred_dimension_);
973   }
974   return proto;
975 }
976 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const977 std::vector<string> HloReshapeInstruction::ExtraAttributesToStringImpl(
978     const HloPrintOptions& options) const {
979   if (inferred_dimension() == -1) {
980     return {};
981   }
982   return {StrCat("inferred_dimension=", inferred_dimension())};
983 }
984 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const985 bool HloReshapeInstruction::IdenticalSlowPath(
986     const HloInstruction& other,
987     const std::function<bool(const HloComputation*, const HloComputation*)>&
988         eq_computations) const {
989   const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
990   return inferred_dimension() == casted_other.inferred_dimension();
991 }
992 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const993 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
994     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
995     HloCloneContext* context) const {
996   CHECK_EQ(new_operands.size(), 1);
997   return absl::make_unique<HloReshapeInstruction>(shape, new_operands[0],
998                                                   inferred_dimension());
999 }
1000 
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1001 HloMapInstruction::HloMapInstruction(const Shape& shape,
1002                                      absl::Span<HloInstruction* const> operands,
1003                                      HloComputation* map_computation)
1004     : HloInstruction(HloOpcode::kMap, shape) {
1005   for (auto operand : operands) {
1006     AppendOperand(operand);
1007   }
1008   AppendComputation(map_computation);
1009   // TODO(b/65689298) Remove code below once Map is generalized to accept
1010   // arbitrary map dimensions.
1011   dimensions_.resize(shape.rank());
1012   std::iota(dimensions_.begin(), dimensions_.end(), 0);
1013 }
1014 
ToProto() const1015 HloInstructionProto HloMapInstruction::ToProto() const {
1016   HloInstructionProto proto = HloInstruction::ToProto();
1017   for (int64 dimension : dimensions_) {
1018     proto.add_dimensions(dimension);
1019   }
1020   return proto;
1021 }
1022 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1023 bool HloMapInstruction::IsElementwiseImpl(
1024     const absl::optional<int64>& operand_idx) const {
1025   if (!dimensions().empty()) {
1026     // Check that the map is executed in elementwise compatible dimensions.
1027     if (dimensions().size() != shape().dimensions_size()) {
1028       return false;
1029     }
1030     for (int i = 0; i < dimensions().size(); ++i) {
1031       if (dimensions()[i] != i) {
1032         return false;
1033       }
1034     }
1035   }
1036   return true;
1037 }
1038 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1039 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
1040     const HloPrintOptions& options) const {
1041   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1042 }
1043 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1044 bool HloMapInstruction::IdenticalSlowPath(
1045     const HloInstruction& other,
1046     const std::function<bool(const HloComputation*, const HloComputation*)>&
1047         eq_computations) const {
1048   return eq_computations(to_apply(), other.to_apply());
1049 }
1050 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1051 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1052     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1053     HloCloneContext* context) const {
1054   return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1055 }
1056 
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1057 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
1058                                          HloInstruction* operand,
1059                                          absl::Span<const int64> start_indices,
1060                                          absl::Span<const int64> limit_indices,
1061                                          absl::Span<const int64> strides)
1062     : HloInstruction(HloOpcode::kSlice, shape),
1063       slice_starts_(start_indices.begin(), start_indices.end()),
1064       slice_limits_(limit_indices.begin(), limit_indices.end()),
1065       slice_strides_(strides.begin(), strides.end()) {
1066   AppendOperand(operand);
1067   // For backward compatibility with old serialized computations: if there are
1068   // no strides, assume all strides are 1.
1069   // TODO(b/63317920): remove this code.
1070   if (slice_strides_.empty()) {
1071     slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
1072   }
1073 }
1074 
ToProto() const1075 HloInstructionProto HloSliceInstruction::ToProto() const {
1076   HloInstructionProto proto = HloInstruction::ToProto();
1077   for (int i = 0; i < slice_starts_.size(); ++i) {
1078     auto* slice_dimension = proto.add_slice_dimensions();
1079     slice_dimension->set_start(slice_starts_[i]);
1080     slice_dimension->set_limit(slice_limits_[i]);
1081     slice_dimension->set_stride(slice_strides_[i]);
1082   }
1083   return proto;
1084 }
1085 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1086 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
1087     const HloPrintOptions& options) const {
1088   std::vector<string> bounds;
1089   bounds.reserve(slice_starts_.size());
1090   const bool omit_stride =
1091       absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
1092   for (int i = 0; i < slice_starts_.size(); ++i) {
1093     string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1094     bounds.push_back(
1095         StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1096   }
1097   return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1098 }
1099 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1100 bool HloSliceInstruction::IdenticalSlowPath(
1101     const HloInstruction& other,
1102     const std::function<bool(const HloComputation*, const HloComputation*)>&
1103         eq_computations) const {
1104   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1105   return slice_starts_ == other_slice.slice_starts_ &&
1106          slice_limits_ == other_slice.slice_limits_ &&
1107          slice_strides_ == other_slice.slice_strides_;
1108 }
1109 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1110 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1111     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1112     HloCloneContext* context) const {
1113   CHECK_EQ(new_operands.size(), 1);
1114   return absl::make_unique<HloSliceInstruction>(
1115       shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1116 }
1117 
HloConstantInstruction(Literal literal)1118 HloConstantInstruction::HloConstantInstruction(Literal literal)
1119     : HloInstruction(HloOpcode::kConstant, literal.shape()),
1120       literal_(std::move(literal)) {}
1121 
HloConstantInstruction(Literal literal,const Shape & shape)1122 HloConstantInstruction::HloConstantInstruction(Literal literal,
1123                                                const Shape& shape)
1124     : HloInstruction(HloOpcode::kConstant, shape),
1125       literal_(std::move(literal)) {}
1126 
HloConstantInstruction(const Shape & shape)1127 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1128     : HloInstruction(HloOpcode::kConstant, shape) {}
1129 
ToProto() const1130 HloInstructionProto HloConstantInstruction::ToProto() const {
1131   HloInstructionProto proto = HloInstruction::ToProto();
1132   if (literal_.has_value()) {
1133     *proto.mutable_literal() = literal_->ToProto();
1134   }
1135   return proto;
1136 }
1137 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1138 bool HloConstantInstruction::IsElementwiseImpl(
1139     const absl::optional<int64>& operand_idx) const {
1140   return true;
1141 }
1142 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1143 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1144                                               const ShapeIndex& shape_index) {
1145   Shape* mutable_array_subshape =
1146       ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1147   CHECK(mutable_array_subshape->IsArray());
1148 
1149   // Normally array_subshape will always have a layout, but this invariant is
1150   // temporarily broken in LayoutAssignment::AssignLayouts.
1151 
1152   if (!mutable_array_subshape->has_layout() ||
1153       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1154     *literal_ = literal_->Relayout(new_layout, shape_index);
1155     *mutable_array_subshape->mutable_layout() = new_layout;
1156   }
1157 }
1158 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1159 bool HloConstantInstruction::IdenticalSlowPath(
1160     const HloInstruction& other,
1161     const std::function<bool(const HloComputation*, const HloComputation*)>&
1162         eq_computations) const {
1163   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1164   return literal() == other_slice.literal();
1165 }
1166 
1167 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1168 HloConstantInstruction::CloneWithNewOperandsImpl(
1169     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1170     HloCloneContext* context) const {
1171   CHECK(literal_.has_value());
1172   // Literal's shape may have no/different tiling info. Use this instruction's
1173   // shape instead.
1174   CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1175                                                   this->shape()));
1176   return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
1177                                                    this->shape());
1178 }
1179 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1180 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1181     const HloPrintOptions& options,
1182     CanonicalNameMap* canonical_name_map) const {
1183   string operands;
1184   // For constants, show the actual value in place of an empty operand list.
1185   if (literal_.has_value() &&
1186       ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1187        options.print_large_constants())) {
1188     // Literal::ToString emits multidimensional arrays over multiple
1189     // lines. Compact this into one line by stripping out white space.
1190     string tmp = literal().ToStringWithoutShape();
1191     std::replace(tmp.begin(), tmp.end(), '\n', ' ');
1192     std::vector<string> v = absl::StrSplit(tmp, ' ');
1193     bool first = true;
1194     // Concatenate elements in "v" with spaces separating them, but ignoring
1195     // empty entries.
1196     for (const auto& s : v) {
1197       if (s.empty()) {
1198         continue;
1199       }
1200       StrAppend(&operands, (first ? "" : " "), s);
1201       first = false;
1202     }
1203   } else {
1204     // Do not show large constants or tuples.
1205     operands = "{...}";
1206   }
1207   return operands;
1208 }
1209 
HloTraceInstruction(const string & tag,HloInstruction * operand)1210 HloTraceInstruction::HloTraceInstruction(const string& tag,
1211                                          HloInstruction* operand)
1212     : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1213       literal_(LiteralUtil::CreateR1U8(tag)) {
1214   AppendOperand(operand);
1215   operand->set_tracing(this);
1216 }
1217 
ToProto() const1218 HloInstructionProto HloTraceInstruction::ToProto() const {
1219   HloInstructionProto proto = HloInstruction::ToProto();
1220   *proto.mutable_literal() = literal_.ToProto();
1221   return proto;
1222 }
1223 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1224 bool HloTraceInstruction::IdenticalSlowPath(
1225     const HloInstruction& other,
1226     const std::function<bool(const HloComputation*, const HloComputation*)>&
1227         eq_computations) const {
1228   return false;
1229 }
1230 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1231 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1232     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1233     HloCloneContext* context) const {
1234   LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1235 }
1236 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1237 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1238                                            FusionKind fusion_kind,
1239                                            HloInstruction* fused_root)
1240     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1241   CHECK(fused_root != nullptr);
1242   SetAndSanitizeName("fusion");
1243   set_parent(fused_root->parent());
1244   set_metadata(fused_root->metadata());
1245   CloneAndFuseInternal(fused_root);
1246 }
1247 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1248 HloFusionInstruction::HloFusionInstruction(
1249     const Shape& shape, FusionKind fusion_kind,
1250     absl::Span<HloInstruction* const> operands,
1251     HloComputation* fusion_computation)
1252     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1253   for (auto operand : operands) {
1254     AppendOperand(operand);
1255   }
1256   SetAndSanitizeName("fusion");
1257   AppendComputation(fusion_computation);
1258   fusion_computation->SetFusionInstruction(this);
1259 }
1260 
ToCategory() const1261 string HloFusionInstruction::ToCategory() const {
1262   switch (fusion_kind()) {
1263     case FusionKind::kLoop:
1264       return "loop fusion";
1265     case FusionKind::kInput:
1266       return "input fusion";
1267     case FusionKind::kOutput:
1268       return "output fusion";
1269     case FusionKind::kCustom:
1270       return "custom fusion";
1271   }
1272 }
1273 
ToProto() const1274 HloInstructionProto HloFusionInstruction::ToProto() const {
1275   HloInstructionProto proto = HloInstruction::ToProto();
1276   proto.set_fusion_kind(xla::ToString(fusion_kind()));
1277   proto.add_called_computation_ids(
1278       fused_instructions_computation()->unique_id());
1279   return proto;
1280 }
1281 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1282 bool HloFusionInstruction::IsElementwiseImpl(
1283     const absl::optional<int64>& operand_idx) const {
1284   if (!operand_idx.has_value()) {
1285     for (auto* fused : fused_instructions()) {
1286       if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1287         return false;
1288       }
1289     }
1290     return true;
1291   }
1292   // A loop-fusion is elementwise on an operand if all operations (computed
1293   // using BFS) between the operand and the fused root are elementwise.
1294   std::deque<HloInstruction*> worklist;
1295   std::unordered_set<const HloInstruction*> visited;
1296   worklist.push_back(fused_parameter(operand_idx.value()));
1297   visited.insert(fused_parameter(operand_idx.value()));
1298   while (!worklist.empty()) {
1299     HloInstruction* operand = worklist.front();
1300     worklist.pop_front();
1301     for (HloInstruction* user : operand->users()) {
1302       CHECK_GE(user->unique_id(), 0);
1303       if (ContainsKey(visited, user)) {
1304         continue;
1305       }
1306       if (user->IsElementwise() ||
1307           IsInstructionElementwiseOnOperand(user, operand)) {
1308         worklist.push_back(user);
1309         visited.insert(user);
1310       } else {
1311         return false;
1312       }
1313     }
1314   }
1315   return true;
1316 }
1317 
AddFusionOperand(HloInstruction * new_operand)1318 HloInstruction* HloFusionInstruction::AddFusionOperand(
1319     HloInstruction* new_operand) {
1320   CHECK_EQ(operand_count(),
1321            fused_instructions_computation()->parameter_instructions().size());
1322   const int64 param_no = operand_count();
1323   string param_name = StrCat("param_", param_no);
1324   HloInstruction* fused_parameter =
1325       fused_instructions_computation()->AddParameter(
1326           HloInstruction::CreateParameter(param_no, new_operand->shape(),
1327                                           param_name));
1328   AppendOperand(new_operand);
1329   return fused_parameter;
1330 }
1331 
// Merges the fused computation of `instruction_to_merge`, which must be an
// operand of this fusion, into this fusion instruction.
//
// A clone of instruction_to_merge is made. Inside the clone's fused
// computation, each fused parameter is replaced by the clone's corresponding
// operand. This "unfuses" the remaining instructions so they can be fused one
// by one into 'this' via FuseInstruction. Finally the clone's fused
// computation is removed from the module.
void HloFusionInstruction::MergeFusionInstruction(
    HloFusionInstruction* instruction_to_merge) {
  CHECK(absl::c_linear_search(operands(), instruction_to_merge));
  // Clone the instruction from which to merge fused instructions.
  std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
  HloFusionInstruction* cloned_fusion =
      static_cast<HloFusionInstruction*>(cloned.get());
  // Replace uses of fused parameters with the corresponding operand of the
  // fusion.  Add all non-parameter fused instructions to
  // 'unfused_instructions' to be merged into 'this'.  This is done in reverse
  // post order, so users come before their operands; the CHECK below relies on
  // the front element being the fused root.
  std::vector<HloInstruction*> unfused_instructions;
  auto fused_instructions = cloned_fusion->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
  for (auto fused_it = fused_instructions.rbegin();
       fused_it != fused_instructions.rend(); ++fused_it) {
    auto fused_instruction = *fused_it;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      TF_CHECK_OK(
          fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
              fused_instruction->parameter_number())));
    } else {
      unfused_instructions.push_back(fused_instruction);
    }
  }

  // If there are no unfused instructions, the fused computation must consist
  // only of kParameter instructions. Make the operand of the corresponding
  // parameter number the new root.
  HloInstruction* unfused_root =
      unfused_instructions.empty()
          ? instruction_to_merge->mutable_operand(
                instruction_to_merge->fused_instructions_computation()
                    ->root_instruction()
                    ->parameter_number())
          : unfused_instructions.front();
  CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
        unfused_instructions.empty());
  // Redirect the use of instruction_to_merge by 'this' so that 'this' uses
  // unfused_root instead.
  TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
  // Fuse 'unfused_instructions' into 'this'. Each is fused after its users,
  // because the list is in reverse post order.
  for (auto& instruction : unfused_instructions) {
    FuseInstruction(instruction);
  }
  CHECK_EQ(0, cloned_fusion->user_count());
  // The clone's fused computation is no longer needed.
  TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
      cloned_fusion->fused_instructions_computation()));
}
1380 
// Merges `instruction_to_merge` (another fusion) into this fusion, producing a
// multi-output fusion.
//
// Steps:
// 1. The non-parameter instructions of instruction_to_merge's fused
//    computation are cloned into this fusion's parent computation, with their
//    operands rewired to the clones or to instruction_to_merge's operands.
// 2. instruction_to_merge is replaced by the unfused root, then removed along
//    with its fused computation.
// 3. The unfused root is fused as an extra output via
//    FuseInstructionIntoMultiOutput.
// 4. The remaining instructions are fused normally.
void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
    HloFusionInstruction* instruction_to_merge) {
  // Add all non-parameter fused instructions to 'unfused_instructions' to be
  // merged into 'this'. `old_to_new' maps the instructions in the fused node
  // to the disassembled fusion instructions.
  // Note that we add the unfused instructions to this->parent_ computation.
  // This is necessary because an instruction needs a unique_id, which is only
  // assigned when the instruction is inserted into a computation.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
  std::vector<HloInstruction*> unfused_instructions;
  auto computation_to_merge =
      instruction_to_merge->fused_instructions_computation();
  auto post_order = computation_to_merge->MakeInstructionPostOrder();
  // Walk in reverse post order; the code below treats the first collected
  // instruction (unfused_instructions.front()) as the root.
  for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
    auto fused_instruction = *rit;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // A fused parameter maps to the corresponding operand of
      // instruction_to_merge.
      InsertOrDie(&old_to_new, fused_instruction,
                  instruction_to_merge->mutable_operand(
                      fused_instruction->parameter_number()));
      continue;
    }

    // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
    // which clones again. This can be improved.
    auto cloned_instruction =
        parent()->AddInstruction(fused_instruction->Clone());
    unfused_instructions.push_back(cloned_instruction);
    InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
  }
  // Point each clone's operands, which still refer to instructions inside
  // computation_to_merge, at their counterparts recorded in old_to_new.
  for (auto unfused_instruction : unfused_instructions) {
    for (int64 index = 0; index < unfused_instruction->operand_count();
         index++) {
      auto new_operand =
          FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
      TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
    }
  }

  // If there are no unfused instructions, the fused computation must consist
  // only of kParameter instructions. Make the operand of the corresponding
  // parameter number the new root.
  HloInstruction* unfused_root =
      unfused_instructions.empty()
          ? instruction_to_merge->mutable_operand(
                instruction_to_merge->fused_instructions_computation()
                    ->root_instruction()
                    ->parameter_number())
          : unfused_instructions.front();
  TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));

  // instruction_to_merge is now dead; drop it and its fused computation.
  TF_CHECK_OK(
      instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
  if (GetModule()) {
    TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
  }

  // Fuse the root instruction and generate multiple outputs.
  if (unfused_instructions.empty()) {
    return;
  }
  FuseInstructionIntoMultiOutput(unfused_root);
  TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
  // The remaining instructions are fused normally (no extra outputs).
  for (int64 i = 1; i < unfused_instructions.size(); i++) {
    auto instruction = unfused_instructions[i];
    FuseInstruction(instruction);
    TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
  }
}
1450 
fused_instructions_computation() const1451 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1452   CHECK(!called_computations().empty());
1453   auto* fused_instructions_computation = called_computations().front();
1454   CHECK(fused_instructions_computation->IsFusionComputation())
1455       << "Computation " << fused_instructions_computation->name()
1456       << " is not a fusion kind";
1457   return fused_instructions_computation;
1458 }
1459 
fused_expression_root() const1460 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1461   return fused_instructions_computation()->root_instruction();
1462 }
1463 
fused_parameter(int64 parameter_number) const1464 HloInstruction* HloFusionInstruction::fused_parameter(
1465     int64 parameter_number) const {
1466   return fused_instructions_computation()->parameter_instruction(
1467       parameter_number);
1468 }
1469 
fused_parameters() const1470 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1471     const {
1472   return fused_instructions_computation()->parameter_instructions();
1473 }
1474 
1475 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1476     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1477 HloFusionInstruction::fused_instructions() const {
1478   const HloComputation* subcomp = fused_instructions_computation();
1479   return subcomp->instructions();
1480 }
1481 
1482 const tensorflow::gtl::iterator_range<
1483     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1484 HloFusionInstruction::fused_instructions() {
1485   return fused_instructions_computation()->instructions();
1486 }
1487 
fused_instruction_count() const1488 int64 HloFusionInstruction::fused_instruction_count() const {
1489   return fused_instructions_computation()->instruction_count();
1490 }
1491 
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1492 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1493     HloInstruction* instruction_to_fuse, bool add_output) {
1494   // When add_output is false, this fusion instruction must be a user of
1495   // instruction_to_fuse.
1496   if (!add_output) {
1497     CHECK(IsUserOf(instruction_to_fuse));
1498   }
1499   HloInstruction* fused_instruction =
1500       CloneAndFuseInternal(instruction_to_fuse, add_output);
1501   return fused_instruction;
1502 }
1503 
// Clones `instruction_to_fuse` into this fusion's fused computation, wires it
// up, and returns the fused clone.
//
// - If this fusion has no fused computation yet, a new "fused_computation" is
//   built with the clone as its root and added to the module as an embedded
//   computation. add_output must be false in this case.
// - Otherwise the clone is added to the existing fused computation; a kTuple
//   is used as-is rather than cloned. If instruction_to_fuse was an operand of
//   this fusion, the matching fused parameter's uses are redirected to the
//   clone, and that parameter/operand pair is removed.
// - Each operand of the clone is then rewired to a fused parameter. An
//   existing parameter is reused when the operand is already a fusion operand;
//   otherwise a new one is added via AddFusionOperand.
// - If add_output is still true, the fused root is turned into (or extended
//   as) a tuple that also yields the clone's value. The outside users of
//   instruction_to_fuse are rewritten to read that value through
//   get-tuple-element.
HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
    HloInstruction* instruction_to_fuse, bool add_output) {
  CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
  VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
  HloInstruction* clone = nullptr;
  if (called_computations().empty()) {
    // New fusion instruction. It should not be a multioutput instruction.
    CHECK(!add_output);
    auto builder = HloComputation::Builder("fused_computation", this);
    builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
    AppendComputation(
        CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
    clone = fused_expression_root();
  } else {
    // When add_output is false, instruction_to_fuse is necessarily an operand
    // of the fusion instruction. After fusion this will no longer be the
    // case. Remove the operand from the operand list and remove its
    // corresponding fused parameter instruction. Renumber parameters as
    // necessary to make parameter numbers consistent with their index in the
    // fused computation's parameter list.
    bool in_operand_list =
        absl::c_linear_search(operands(), instruction_to_fuse);
    CHECK(add_output || in_operand_list);
    if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      // We assume all uses of a kTuple operation are GTE ops, not another
      // fusion node. In this case, we don't need to clone
      // 'instruction_to_fuse'.
      CHECK(!in_operand_list);
      clone = instruction_to_fuse;
    } else {
      clone = fused_instructions_computation()->AddInstruction(
          instruction_to_fuse->Clone(/*suffix=*/""));
    }
    const std::vector<HloInstruction*>& fused_parameters =
        fused_instructions_computation()->parameter_instructions();
    for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
      if (instruction_to_fuse == operand(operand_num)) {
        // replace the fused parameter instruction's uses with the clone.
        HloInstruction* fused_parameter = fused_parameters[operand_num];
        TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));

        // Remove the corresponding fused parameter and operand from their
        // respective vectors.
        TF_CHECK_OK(
            fused_instructions_computation()->RemoveParameter(operand_num));
        RemoveOperandAt(operand_num);
        break;
      }
    }
    // We've cloned instruction_to_fuse into this fusion instruction, so this
    // fusion instruction is no longer a use of instruction_to_fuse.
    if (in_operand_list) {
      DetachFrom(instruction_to_fuse);
      // When the instruction_to_fuse does not have other users, we don't need
      // to generate a multioutput fusion instruction.
      if (instruction_to_fuse->user_count() == 0) {
        add_output = false;
      }
    }
  }

  // Reread the parameters in the computation.
  const std::vector<HloInstruction*>& fused_parameters =
      fused_instructions_computation()->parameter_instructions();

  // Add each operand of the clone as an operand of the fusion instruction. A
  // complication is that some clone operands may already be operands of the
  // fusion instruction.
  for (int64 operand_num = 0; operand_num < clone->operand_count();
       ++operand_num) {
    HloInstruction* operand = clone->mutable_operand(operand_num);

    // See if this operand is already an operand of the fusion node.
    CHECK_EQ(operands().size(), fused_parameters.size());
    HloInstruction* fused_param = nullptr;
    for (int64 i = 0; i < operands().size(); ++i) {
      if (this->operand(i) == operand) {
        fused_param = fused_parameters[i];
        break;
      }
    }

    if (fused_param == nullptr) {
      // Clone's operand was not already an operand of the fusion
      // instruction. Add it as an operand and add a corresponding fused
      // parameter instruction.
      fused_param = AddFusionOperand(operand);
    }
    TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
  }

  if (add_output) {
    CHECK_GT(instruction_to_fuse->user_count(), 0);
    // If this is already a multioutput fusion instruction, expand the root
    // tuple by 1.
    HloInstruction* fused_root = fused_expression_root();
    HloInstruction::InstructionVector tuple_elements;
    bool newly_created_tuple_instr = false;
    if (fused_root->opcode() == HloOpcode::kTuple) {
      tuple_elements = fused_root->operands();
    } else {
      tuple_elements.push_back(fused_root);
      newly_created_tuple_instr = true;
    }
    // A fused kTuple contributes its elements individually (flattened).
    if (clone->opcode() == HloOpcode::kTuple) {
      for (auto inst : clone->operands()) {
        tuple_elements.push_back(inst);
      }
    } else {
      tuple_elements.push_back(clone);
    }
    HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
        HloInstruction::CreateTuple(tuple_elements));
    fused_instructions_computation()->set_root_instruction(new_root);
    *mutable_shape() = new_root->shape();
    if (fused_root->opcode() == HloOpcode::kTuple) {
      TF_CHECK_OK(
          fused_instructions_computation()->RemoveInstruction(fused_root));
    }

    // If this is a newly created multioutput instruction, we need to update
    // the use of the original fusion instruction.
    if (newly_created_tuple_instr) {
      HloInstruction* new_instr = parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
      TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
    }
    // 'index' starts one past the last output; the newly added output(s) sit
    // at the end of the root tuple.
    int64 index = tuple_elements.size();
    if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      CHECK_EQ(clone, instruction_to_fuse);
      index -= clone->operand_count();
      // Re-point each GTE user of the fused tuple at the matching output of
      // this fusion. Removal is deferred so clone->users() is not mutated
      // while being iterated.
      std::vector<HloInstruction*> to_be_removed;
      for (auto old_gte : clone->users()) {
        CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
        int64 old_tuple_index = old_gte->tuple_index();
        HloInstruction* new_gte =
            parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_gte->shape(), this, index + old_tuple_index));
        TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
        to_be_removed.push_back(old_gte);
      }
      for (auto old_gte : to_be_removed) {
        TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
      }
    } else {
      HloInstruction* new_gte =
          parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
              clone->shape(), this, index - 1));
      TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
    }
  }

  if (clone != instruction_to_fuse) {
    VLOG(2) << "New clone:\n" << clone->ToString();
  }
  return clone;
}
1661 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1662 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1663     const HloPrintOptions& options) const {
1664   return {StrCat("kind=", xla::ToString(fusion_kind()))};
1665 }
1666 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1667 bool HloFusionInstruction::IdenticalSlowPath(
1668     const HloInstruction& other,
1669     const std::function<bool(const HloComputation*, const HloComputation*)>&
1670         eq_computations) const {
1671   return fusion_kind() == other.fusion_kind() &&
1672          eq_computations(fused_instructions_computation(),
1673                          other.fused_instructions_computation());
1674 }
1675 
InnerHash() const1676 uint64 HloFusionInstruction::InnerHash() const {
1677   return fused_instructions_computation()->root_instruction()->Hash();
1678 }
1679 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1680 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1681     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1682     HloCloneContext* context) const {
1683   HloModule* module = context != nullptr ? context->module() : GetModule();
1684   HloComputation* new_fused_computation = nullptr;
1685   if (context != nullptr) {
1686     new_fused_computation =
1687         context->FindComputation(fused_instructions_computation());
1688   }
1689   if (new_fused_computation == nullptr) {
1690     new_fused_computation = module->AddEmbeddedComputation(
1691         fused_instructions_computation()->Clone("clone", context));
1692   }
1693   return absl::make_unique<HloFusionInstruction>(
1694       shape, fusion_kind(), new_operands, new_fused_computation);
1695 }
1696 
DeduplicateFusionOperands()1697 Status HloFusionInstruction::DeduplicateFusionOperands() {
1698   if (IsCustomFusion()) {
1699     return Status::OK();
1700   }
1701   absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1702   std::vector<int> operands_to_remove;
1703   for (int i = 0; i < operand_count(); ++i) {
1704     auto emplace_result = operand_indices.emplace(operand(i), i);
1705     if (!emplace_result.second) {
1706       TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1707           fused_parameter(emplace_result.first->second)));
1708       operands_to_remove.push_back(i);
1709     }
1710   }
1711   if (operands_to_remove.empty()) {
1712     return Status::OK();
1713   }
1714   TF_RETURN_IF_ERROR(fused_instructions_computation()
1715                          ->RemoveUnusedParametersFromFusedComputation());
1716   RemoveOperandsAtAscendingIndices(operands_to_remove);
1717   return Status::OK();
1718 }
1719 
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1720 HloRngInstruction::HloRngInstruction(
1721     const Shape& shape, RandomDistribution distribution,
1722     absl::Span<HloInstruction* const> parameters)
1723     : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
1724   for (HloInstruction* param : parameters) {
1725     AppendOperand(param);
1726   }
1727 }
1728 
ToProto() const1729 HloInstructionProto HloRngInstruction::ToProto() const {
1730   HloInstructionProto proto = HloInstruction::ToProto();
1731   proto.set_distribution(distribution_);
1732   return proto;
1733 }
1734 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1735 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
1736     const HloPrintOptions& options) const {
1737   return {StrCat("distribution=", RandomDistributionToString(distribution_))};
1738 }
1739 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1740 bool HloRngInstruction::IsElementwiseImpl(
1741     const absl::optional<int64>& operand_idx) const {
1742   return true;
1743 }
1744 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1745 bool HloRngInstruction::IdenticalSlowPath(
1746     const HloInstruction& other,
1747     const std::function<bool(const HloComputation*, const HloComputation*)>&
1748         eq_computations) const {
1749   return true;
1750 }
1751 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1752 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
1753     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1754     HloCloneContext* context) const {
1755   return absl::make_unique<HloRngInstruction>(shape, distribution_,
1756                                               new_operands);
1757 }
1758 
HloParameterInstruction(int64 parameter_number,const Shape & shape,const string & name)1759 HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
1760                                                  const Shape& shape,
1761                                                  const string& name)
1762     : HloInstruction(HloOpcode::kParameter, shape),
1763       parameter_number_(parameter_number) {
1764   SetAndSanitizeName(name);
1765 }
1766 
ToProto() const1767 HloInstructionProto HloParameterInstruction::ToProto() const {
1768   HloInstructionProto proto = HloInstruction::ToProto();
1769   proto.set_parameter_number(parameter_number_);
1770   if (parameter_replicated_at_leaf_buffers_) {
1771     for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1772       proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
1773           replicated);
1774     }
1775   }
1776   return proto;
1777 }
1778 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1779 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
1780     const HloPrintOptions& options) const {
1781   std::vector<string> result;
1782   if (!parameter_replicated_at_leaf_buffers_) {
1783     return result;
1784   }
1785   std::vector<string> buffers_replicated_strs;
1786   for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1787     buffers_replicated_strs.push_back(replicated ? "true" : "false");
1788   }
1789   if (options.print_ids()) {
1790     result.push_back(StrCat("parameter_replication={",
1791                             StrJoin(buffers_replicated_strs, ","), "}"));
1792   }
1793   return result;
1794 }
1795 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1796 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
1797     const HloPrintOptions& options,
1798     CanonicalNameMap* canonical_name_map) const {
1799   return StrCat(parameter_number_);
1800 }
1801 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1802 bool HloParameterInstruction::IdenticalSlowPath(
1803     const HloInstruction& other,
1804     const std::function<bool(const HloComputation*, const HloComputation*)>&
1805         eq_computations) const {
1806   const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
1807   return parameter_number() == casted_other.parameter_number();
1808 }
1809 
1810 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1811 HloParameterInstruction::CloneWithNewOperandsImpl(
1812     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1813     HloCloneContext* context) const {
1814   return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
1815                                                     name());
1816 }
1817 
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64 index)1818 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
1819     const Shape& shape, HloInstruction* operand, int64 index)
1820     : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
1821   AppendOperand(operand);
1822 }
1823 
ToProto() const1824 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
1825   HloInstructionProto proto = HloInstruction::ToProto();
1826   proto.set_tuple_index(tuple_index_);
1827   return proto;
1828 }
1829 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1830 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
1831     const HloPrintOptions& options) const {
1832   return {StrCat("index=", tuple_index())};
1833 }
1834 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1835 bool HloGetTupleElementInstruction::IdenticalSlowPath(
1836     const HloInstruction& other,
1837     const std::function<bool(const HloComputation*, const HloComputation*)>&
1838         eq_computations) const {
1839   const auto& casted_other =
1840       static_cast<const HloGetTupleElementInstruction&>(other);
1841   return tuple_index() == casted_other.tuple_index();
1842 }
1843 
1844 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1845 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
1846     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1847     HloCloneContext* context) const {
1848   CHECK_EQ(new_operands.size(), 1);
1849   return absl::make_unique<HloGetTupleElementInstruction>(
1850       shape, new_operands[0], tuple_index());
1851 }
1852 
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1853 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
1854     const Shape& shape, HloInstruction* operand, const int exponent_bits,
1855     const int mantissa_bits)
1856     : HloInstruction(HloOpcode::kReducePrecision, shape),
1857       exponent_bits_(exponent_bits),
1858       mantissa_bits_(mantissa_bits) {
1859   AppendOperand(operand);
1860 }
1861 
ToProto() const1862 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
1863   HloInstructionProto proto = HloInstruction::ToProto();
1864   proto.set_exponent_bits(exponent_bits_);
1865   proto.set_mantissa_bits(mantissa_bits_);
1866   return proto;
1867 }
1868 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1869 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
1870     const HloPrintOptions& options) const {
1871   return {StrCat("exponent_bits=", exponent_bits_),
1872           StrCat("mantissa_bits=", mantissa_bits_)};
1873 }
1874 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1875 bool HloReducePrecisionInstruction::IdenticalSlowPath(
1876     const HloInstruction& other,
1877     const std::function<bool(const HloComputation*, const HloComputation*)>&
1878         eq_computations) const {
1879   const auto& casted_other =
1880       static_cast<const HloReducePrecisionInstruction&>(other);
1881   // A reduce-precision operation is determined by the bit sizes.
1882   return exponent_bits() == casted_other.exponent_bits() &&
1883          mantissa_bits() == casted_other.mantissa_bits();
1884 }
1885 
1886 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1887 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
1888     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1889     HloCloneContext* context) const {
1890   CHECK_EQ(new_operands.size(), 1);
1891   return absl::make_unique<HloReducePrecisionInstruction>(
1892       shape, new_operands[0], exponent_bits(), mantissa_bits());
1893 }
1894 
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1895 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
1896                                            HloInstruction* token_operand,
1897                                            const string& config)
1898     : HloInstruction(HloOpcode::kInfeed,
1899                      ShapeUtil::MakeTupleShape(
1900                          {infeed_shape, ShapeUtil::MakeTokenShape()})),
1901       infeed_config_(config) {
1902   AppendOperand(token_operand);
1903 }
1904 
ToProto() const1905 HloInstructionProto HloInfeedInstruction::ToProto() const {
1906   HloInstructionProto proto = HloInstruction::ToProto();
1907   proto.set_infeed_config(infeed_config_);
1908   return proto;
1909 }
1910 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1911 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
1912     const HloPrintOptions& options) const {
1913   if (infeed_config_.empty()) {
1914     return {};
1915   }
1916   return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
1917 }
1918 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1919 bool HloInfeedInstruction::IdenticalSlowPath(
1920     const HloInstruction& other,
1921     const std::function<bool(const HloComputation*, const HloComputation*)>&
1922         eq_computations) const {
1923   // Not yet supported.
1924   return false;
1925 }
1926 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1927 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
1928     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1929     HloCloneContext* context) const {
1930   CHECK_EQ(new_operands.size(), 1);
1931   return absl::make_unique<HloInfeedInstruction>(
1932       infeed_shape(), new_operands[0], infeed_config());
1933 }
1934 
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1935 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
1936                                              HloInstruction* operand,
1937                                              HloInstruction* token_operand,
1938                                              absl::string_view outfeed_config)
1939     : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
1940       outfeed_shape_(outfeed_shape),
1941       outfeed_config_(outfeed_config) {
1942   AppendOperand(operand);
1943   AppendOperand(token_operand);
1944 }
1945 
ToProto() const1946 HloInstructionProto HloOutfeedInstruction::ToProto() const {
1947   HloInstructionProto proto = HloInstruction::ToProto();
1948   proto.set_outfeed_config(outfeed_config());
1949   *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
1950   return proto;
1951 }
1952 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1953 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
1954     const HloPrintOptions& options) const {
1955   if (outfeed_config_.empty()) {
1956     return {};
1957   }
1958   return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")};
1959 }
1960 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1961 bool HloOutfeedInstruction::IdenticalSlowPath(
1962     const HloInstruction& other,
1963     const std::function<bool(const HloComputation*, const HloComputation*)>&
1964         eq_computations) const {
1965   // Not yet supported.
1966   return false;
1967 }
1968 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1969 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
1970     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1971     HloCloneContext* context) const {
1972   CHECK_EQ(new_operands.size(), 2);
1973   return absl::make_unique<HloOutfeedInstruction>(
1974       outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
1975 }
1976 
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1977 HloConvolutionInstruction::HloConvolutionInstruction(
1978     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1979     int64 feature_group_count, int64 batch_group_count, const Window& window,
1980     const ConvolutionDimensionNumbers& dimension_numbers,
1981     const PrecisionConfig& precision_config)
1982     : HloInstruction(HloOpcode::kConvolution, shape),
1983       feature_group_count_(feature_group_count),
1984       batch_group_count_(batch_group_count),
1985       window_(window),
1986       convolution_dimension_numbers_(dimension_numbers),
1987       precision_config_(precision_config) {
1988   if (window_util::HasBaseDilation(window)) {
1989     SetAndSanitizeName(StrCat(name(), "-base-dilated"));
1990   }
1991   if (window_util::HasWindowDilation(window)) {
1992     SetAndSanitizeName(StrCat(name(), "-window-dilated"));
1993   }
1994   AppendOperand(lhs);
1995   AppendOperand(rhs);
1996 }
1997 
ToCategory() const1998 string HloConvolutionInstruction::ToCategory() const {
1999   string category = "convolution";
2000   if (window_util::HasBaseDilation(window())) {
2001     category += " base-dilated";
2002   }
2003   if (window_util::HasWindowDilation(window())) {
2004     category += " window-dilated";
2005   }
2006   return category;
2007 }
2008 
ToProto() const2009 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2010   HloInstructionProto proto = HloInstruction::ToProto();
2011   *proto.mutable_window() = window_;
2012   *proto.mutable_convolution_dimension_numbers() =
2013       convolution_dimension_numbers_;
2014   proto.set_feature_group_count(feature_group_count_);
2015   proto.set_batch_group_count(batch_group_count_);
2016   *proto.mutable_precision_config() = precision_config_;
2017   return proto;
2018 }
2019 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2020 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2021     const HloPrintOptions& options) const {
2022   std::vector<string> extra;
2023   if (window_.dimensions_size() != 0) {
2024     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2025   }
2026   extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2027                                             convolution_dimension_numbers_)));
2028   if (feature_group_count_ != 1) {
2029     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2030   }
2031 
2032   if (batch_group_count_ != 1) {
2033     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2034   }
2035 
2036   string precision_config_string = PrecisionConfigToString(precision_config_);
2037   if (!precision_config_string.empty()) {
2038     extra.push_back(precision_config_string);
2039   }
2040 
2041   return extra;
2042 }
2043 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2044 bool HloConvolutionInstruction::IdenticalSlowPath(
2045     const HloInstruction& other,
2046     const std::function<bool(const HloComputation*, const HloComputation*)>&
2047         eq_computations) const {
2048   const auto& casted_other =
2049       static_cast<const HloConvolutionInstruction&>(other);
2050   if (feature_group_count_ != other.feature_group_count()) {
2051     return false;
2052   }
2053   if (batch_group_count_ != other.batch_group_count()) {
2054     return false;
2055   }
2056   return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2057          protobuf_util::ProtobufEquals(
2058              convolution_dimension_numbers(),
2059              casted_other.convolution_dimension_numbers()) &&
2060          protobuf_util::ProtobufEquals(precision_config(),
2061                                        casted_other.precision_config());
2062 }
2063 
2064 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2065 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2066     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2067     HloCloneContext* context) const {
2068   CHECK_EQ(new_operands.size(), 2);
2069   return absl::make_unique<HloConvolutionInstruction>(
2070       shape, new_operands[0], new_operands[1], feature_group_count_,
2071       batch_group_count_, window(), convolution_dimension_numbers_,
2072       precision_config_);
2073 }
2074 
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2075 HloReduceWindowInstruction::HloReduceWindowInstruction(
2076     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2077     const Window& window, HloComputation* reduce_computation)
2078     : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2079   AppendOperand(operand);
2080   AppendOperand(init_value);
2081   AppendComputation(reduce_computation);
2082 }
2083 
ToProto() const2084 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2085   HloInstructionProto proto = HloInstruction::ToProto();
2086   *proto.mutable_window() = window_;
2087   return proto;
2088 }
2089 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2090 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2091     const HloPrintOptions& options) const {
2092   std::vector<string> extra;
2093   if (window_.dimensions_size() != 0) {
2094     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2095   }
2096   return extra;
2097 }
2098 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2099 bool HloReduceWindowInstruction::IdenticalSlowPath(
2100     const HloInstruction& other,
2101     const std::function<bool(const HloComputation*, const HloComputation*)>&
2102         eq_computations) const {
2103   const auto& casted_other =
2104       static_cast<const HloReduceWindowInstruction&>(other);
2105   return eq_computations(to_apply(), casted_other.to_apply()) &&
2106          protobuf_util::ProtobufEquals(window(), casted_other.window());
2107 }
2108 
2109 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2110 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2111     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2112     HloCloneContext* context) const {
2113   CHECK_EQ(new_operands.size(), 2);
2114   return absl::make_unique<HloReduceWindowInstruction>(
2115       shape, new_operands[0], new_operands[1], window(), to_apply());
2116 }
2117 
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2118 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2119     const Shape& shape, HloInstruction* operand, HloComputation* select,
2120     const Window& window, HloInstruction* source, HloInstruction* init_value,
2121     HloComputation* scatter)
2122     : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2123   AppendOperand(operand);
2124   AppendOperand(source);
2125   AppendOperand(init_value);
2126   // Select comes before scatter in the vector.
2127   AppendComputation(select);
2128   AppendComputation(scatter);
2129 }
2130 
ToProto() const2131 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2132   HloInstructionProto proto = HloInstruction::ToProto();
2133   *proto.mutable_window() = window_;
2134   return proto;
2135 }
2136 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2137 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2138     const HloPrintOptions& options) const {
2139   std::vector<string> extra;
2140   if (window_.dimensions_size() != 0) {
2141     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2142   }
2143   return extra;
2144 }
2145 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2146 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2147     const HloInstruction& other,
2148     const std::function<bool(const HloComputation*, const HloComputation*)>&
2149         eq_computations) const {
2150   const auto& casted_other =
2151       static_cast<const HloSelectAndScatterInstruction&>(other);
2152   return eq_computations(select(), casted_other.select()) &&
2153          eq_computations(scatter(), casted_other.scatter()) &&
2154          protobuf_util::ProtobufEquals(window(), casted_other.window());
2155 }
2156 
2157 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2158 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2159     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2160     HloCloneContext* context) const {
2161   CHECK_EQ(new_operands.size(), 3);
2162   return absl::make_unique<HloSelectAndScatterInstruction>(
2163       shape, new_operands[0], select(), window(), new_operands[1],
2164       new_operands[2], scatter());
2165 }
2166 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque)2167 HloCustomCallInstruction::HloCustomCallInstruction(
2168     const Shape& shape, absl::Span<HloInstruction* const> operands,
2169     absl::string_view custom_call_target, string opaque)
2170     : HloInstruction(HloOpcode::kCustomCall, shape),
2171       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2172       feature_group_count_(1),
2173       batch_group_count_(1),
2174       layout_constrained_(false),
2175       custom_call_has_side_effect_(false) {
2176   set_raw_backend_config_string(std::move(opaque));
2177   for (auto operand : operands) {
2178     AppendOperand(operand);
2179   }
2180 }
2181 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,absl::Span<const Shape> operand_shapes_with_layout)2182 HloCustomCallInstruction::HloCustomCallInstruction(
2183     const Shape& shape, absl::Span<HloInstruction* const> operands,
2184     absl::string_view custom_call_target, string opaque,
2185     absl::Span<const Shape> operand_shapes_with_layout)
2186     : HloInstruction(HloOpcode::kCustomCall, shape),
2187       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2188       feature_group_count_(1),
2189       batch_group_count_(1),
2190       layout_constrained_(true),
2191       operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2192                                   operand_shapes_with_layout.end()),
2193       custom_call_has_side_effect_(false) {
2194   set_raw_backend_config_string(std::move(opaque));
2195   for (auto operand : operands) {
2196     AppendOperand(operand);
2197   }
2198 }
2199 
ToProto() const2200 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2201   HloInstructionProto proto = HloInstruction::ToProto();
2202   if (window_ != nullptr) {
2203     *proto.mutable_window() = *window_;
2204   }
2205   if (convolution_dimension_numbers_ != nullptr) {
2206     *proto.mutable_convolution_dimension_numbers() =
2207         *convolution_dimension_numbers_;
2208   }
2209   proto.set_custom_call_target(custom_call_target_);
2210   proto.set_feature_group_count(feature_group_count_);
2211   proto.set_batch_group_count(batch_group_count_);
2212   if (layout_constrained()) {
2213     proto.set_constrain_layout(true);
2214     for (const Shape& shape : operand_shapes_with_layout_) {
2215       *proto.add_operand_shapes_with_layout() = shape.ToProto();
2216     }
2217   }
2218   proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2219   return proto;
2220 }
2221 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2222 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2223     const HloPrintOptions& options) const {
2224   std::vector<string> extra;
2225   if (window_ != nullptr) {
2226     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2227   }
2228   if (convolution_dimension_numbers_ != nullptr) {
2229     extra.push_back(StrCat(
2230         "dim_labels=",
2231         ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2232   }
2233   if (feature_group_count_ != 1) {
2234     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2235   }
2236   if (batch_group_count_ != 1) {
2237     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2238   }
2239   // By contract, we print the custom call target even if
2240   // options.print_subcomputation_mode() == kOff, because the call target is not
2241   // an HloComputation.
2242   extra.push_back(
2243       StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2244 
2245   if (layout_constrained()) {
2246     std::vector<string> shape_strings;
2247     for (const Shape& shape : operand_shapes_with_layout_) {
2248       shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2249     }
2250     extra.push_back(StrCat("operand_layout_constraints={",
2251                            StrJoin(shape_strings, ", "), "}"));
2252   }
2253   if (custom_call_has_side_effect_) {
2254     extra.push_back("custom_call_has_side_effect=true");
2255   }
2256   return extra;
2257 }
2258 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2259 bool HloCustomCallInstruction::IdenticalSlowPath(
2260     const HloInstruction& other,
2261     const std::function<bool(const HloComputation*, const HloComputation*)>&
2262         eq_computations) const {
2263   const auto& casted_other =
2264       static_cast<const HloCustomCallInstruction&>(other);
2265   if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2266       (window_ != nullptr &&
2267        !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2268     return false;
2269   }
2270   if ((convolution_dimension_numbers_ == nullptr) !=
2271           (casted_other.convolution_dimension_numbers_ == nullptr) ||
2272       (convolution_dimension_numbers_ != nullptr &&
2273        !protobuf_util::ProtobufEquals(
2274            convolution_dimension_numbers(),
2275            casted_other.convolution_dimension_numbers()))) {
2276     return false;
2277   }
2278   if (feature_group_count_ != casted_other.feature_group_count_) {
2279     return false;
2280   }
2281   if (batch_group_count_ != casted_other.batch_group_count_) {
2282     return false;
2283   }
2284   if (layout_constrained() != casted_other.layout_constrained()) {
2285     return false;
2286   }
2287   if (layout_constrained()) {
2288     for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2289       if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2290                             casted_other.operand_shapes_with_layout_[i])) {
2291         return false;
2292       }
2293     }
2294   }
2295   if (custom_call_has_side_effect_ !=
2296       casted_other.custom_call_has_side_effect()) {
2297     return false;
2298   }
2299   // Note: backend_config comparison is done in Identical, which is the
2300   // intended/exposed way to compare computations, and so not repeated here.
2301   return custom_call_target_ == casted_other.custom_call_target_;
2302 }
2303 
2304 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2305 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2306     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2307     HloCloneContext* context) const {
2308   auto cloned = absl::make_unique<HloCustomCallInstruction>(
2309       shape, new_operands, custom_call_target(), opaque());
2310   if (layout_constrained()) {
2311     cloned->layout_constrained_ = true;
2312     cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2313   }
2314   if (window_ != nullptr) {
2315     cloned->set_window(*window_);
2316   }
2317   if (convolution_dimension_numbers_ != nullptr) {
2318     cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2319   }
2320   cloned->set_feature_group_count(feature_group_count_);
2321   cloned->set_batch_group_count(batch_group_count_);
2322   cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2323   return std::move(cloned);
2324 }
2325 
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2326 HloPadInstruction::HloPadInstruction(const Shape& shape,
2327                                      HloInstruction* operand,
2328                                      HloInstruction* padding_value,
2329                                      const PaddingConfig& padding_config)
2330     : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2331   AppendOperand(operand);
2332   AppendOperand(padding_value);
2333 }
2334 
ToProto() const2335 HloInstructionProto HloPadInstruction::ToProto() const {
2336   HloInstructionProto proto = HloInstruction::ToProto();
2337   *proto.mutable_padding_config() = padding_config_;
2338   return proto;
2339 }
2340 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2341 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2342     const HloPrintOptions& options) const {
2343   return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2344 }
2345 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2346 bool HloPadInstruction::IdenticalSlowPath(
2347     const HloInstruction& other,
2348     const std::function<bool(const HloComputation*, const HloComputation*)>&
2349         eq_computations) const {
2350   const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2351   return protobuf_util::ProtobufEquals(padding_config(),
2352                                        casted_other.padding_config());
2353 }
2354 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2355 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2356     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2357     HloCloneContext* context) const {
2358   CHECK_EQ(new_operands.size(), 2);
2359   return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2360                                               new_operands[1], padding_config_);
2361 }
2362 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2363 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2364     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2365     absl::Span<const int64> slice_sizes)
2366     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2367       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2368   AppendOperand(operand);
2369   AppendOperand(start_indices);
2370 }
2371 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2372 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2373     const Shape& shape, HloInstruction* operand,
2374     absl::Span<HloInstruction* const> start_indices,
2375     absl::Span<const int64> slice_sizes)
2376     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2377       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2378   AppendOperand(operand);
2379   for (HloInstruction* index : start_indices) {
2380     AppendOperand(index);
2381   }
2382 }
2383 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2384 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2385     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2386     HloInstruction* start_indices)
2387     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2388   AppendOperand(operand);
2389   AppendOperand(update);
2390   AppendOperand(start_indices);
2391 }
2392 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2393 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2394     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2395     absl::Span<HloInstruction* const> start_indices)
2396     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2397   AppendOperand(operand);
2398   AppendOperand(update);
2399   for (HloInstruction* index : start_indices) {
2400     AppendOperand(index);
2401   }
2402 }
2403 
ToProto() const2404 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2405   HloInstructionProto proto = HloInstruction::ToProto();
2406   for (int64 slice_size : dynamic_slice_sizes_) {
2407     proto.add_dynamic_slice_sizes(slice_size);
2408   }
2409   return proto;
2410 }
2411 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2412 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2413     const HloPrintOptions& options) const {
2414   return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2415                  "}")};
2416 }
2417 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2418 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2419     const HloInstruction& other,
2420     const std::function<bool(const HloComputation*, const HloComputation*)>&
2421         eq_computations) const {
2422   return true;
2423 }
2424 
2425 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2426 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2427     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2428     HloCloneContext* context) const {
2429   if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2430     // TODO(b/118437727): Old form, remove this path.
2431     return absl::make_unique<HloDynamicSliceInstruction>(
2432         shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2433   } else {
2434     return absl::make_unique<HloDynamicSliceInstruction>(
2435         shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2436   }
2437 }
2438 
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)2439 HloGatherInstruction::HloGatherInstruction(
2440     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2441     const GatherDimensionNumbers& gather_dim_numbers,
2442     absl::Span<const int64> slice_sizes, bool indices_are_sorted)
2443     : HloInstruction(HloOpcode::kGather, shape),
2444       indices_are_sorted_(indices_are_sorted) {
2445   AppendOperand(operand);
2446   AppendOperand(start_indices);
2447   gather_dimension_numbers_ =
2448       absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2449   absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2450 }
2451 
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)2452 /*static*/ string HloGatherInstruction::GatherDimensionNumbersToString(
2453     const GatherDimensionNumbers& gather_dimension_numbers) {
2454   string offset_dims =
2455       StrCat("offset_dims={",
2456              StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
2457   string collapsed_slice_dims = StrCat(
2458       "collapsed_slice_dims={",
2459       StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
2460   string start_index_map =
2461       StrCat("start_index_map={",
2462              StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
2463   string index_vector_dim =
2464       StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
2465 
2466   return StrJoin<std::initializer_list<string>>(
2467       {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2468       ", ");
2469 }
2470 
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64 index_vector_dim)2471 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2472     absl::Span<const int64> offset_dims,
2473     absl::Span<const int64> collapsed_slice_dims,
2474     absl::Span<const int64> start_index_map, int64 index_vector_dim) {
2475   GatherDimensionNumbers gather_dim_numbers;
2476   for (int64 output_window_dim : offset_dims) {
2477     gather_dim_numbers.add_offset_dims(output_window_dim);
2478   }
2479   for (int64 elided_window_dim : collapsed_slice_dims) {
2480     gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2481   }
2482   for (int64 gather_dim_to_input_dim : start_index_map) {
2483     gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2484   }
2485 
2486   gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2487   return gather_dim_numbers;
2488 }
2489 
ToProto() const2490 HloInstructionProto HloGatherInstruction::ToProto() const {
2491   HloInstructionProto proto = HloInstruction::ToProto();
2492   *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2493   for (int64 bound : gather_slice_sizes()) {
2494     proto.add_gather_slice_sizes(bound);
2495   }
2496   proto.set_indices_are_sorted(indices_are_sorted());
2497   return proto;
2498 }
2499 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2500 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2501     const HloPrintOptions& options) const {
2502   std::vector<string> attrs{
2503       GatherDimensionNumbersToString(gather_dimension_numbers()),
2504       StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2505   if (indices_are_sorted()) {
2506     attrs.push_back("indices_are_sorted=true");
2507   }
2508   return attrs;
2509 }
2510 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2511 bool HloGatherInstruction::IdenticalSlowPath(
2512     const HloInstruction& other,
2513     const std::function<bool(const HloComputation*, const HloComputation*)>&
2514         eq_computations) const {
2515   const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2516   return protobuf_util::ProtobufEquals(
2517              gather_dimension_numbers(),
2518              casted_other.gather_dimension_numbers()) &&
2519          gather_slice_sizes() == casted_other.gather_slice_sizes() &&
2520          indices_are_sorted() == casted_other.indices_are_sorted();
2521 }
2522 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2523 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2524     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2525     HloCloneContext* context) const {
2526   CHECK_EQ(new_operands.size(), 2);
2527   return absl::make_unique<HloGatherInstruction>(
2528       shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2529       gather_slice_sizes(), indices_are_sorted());
2530 }
2531 
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)2532 HloScatterInstruction::HloScatterInstruction(
2533     const Shape& shape, HloInstruction* operand,
2534     HloInstruction* scatter_indices, HloInstruction* updates,
2535     HloComputation* update_computation,
2536     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
2537     bool unique_indices)
2538     : HloInstruction(HloOpcode::kScatter, shape),
2539       indices_are_sorted_(indices_are_sorted),
2540       unique_indices_(unique_indices) {
2541   AppendOperand(operand);
2542   AppendOperand(scatter_indices);
2543   AppendOperand(updates);
2544   AppendComputation(update_computation);
2545   scatter_dimension_numbers_ =
2546       absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2547 }
2548 
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)2549 /*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString(
2550     const ScatterDimensionNumbers& scatter_dimension_numbers) {
2551   string update_window_dims =
2552       StrCat("update_window_dims={",
2553              StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
2554   string inserted_window_dims = StrCat(
2555       "inserted_window_dims={",
2556       StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
2557   string scatter_dims_to_operand_dims = StrCat(
2558       "scatter_dims_to_operand_dims={",
2559       StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
2560       "}");
2561   string index_vector_dim =
2562       StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
2563 
2564   return StrJoin<std::initializer_list<string>>(
2565       {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
2566        index_vector_dim},
2567       ", ");
2568 }
2569 
2570 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64 index_vector_dim)2571 HloScatterInstruction::MakeScatterDimNumbers(
2572     absl::Span<const int64> update_window_dims,
2573     absl::Span<const int64> inserted_window_dims,
2574     absl::Span<const int64> scatter_dims_to_operand_dims,
2575     int64 index_vector_dim) {
2576   ScatterDimensionNumbers scatter_dim_numbers;
2577   for (int64 update_window_dim : update_window_dims) {
2578     scatter_dim_numbers.add_update_window_dims(update_window_dim);
2579   }
2580   for (int64 inserted_window_dim : inserted_window_dims) {
2581     scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
2582   }
2583   for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
2584     scatter_dim_numbers.add_scatter_dims_to_operand_dims(
2585         scatter_dim_to_operand_dim);
2586   }
2587   scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
2588   return scatter_dim_numbers;
2589 }
2590 
ToProto() const2591 HloInstructionProto HloScatterInstruction::ToProto() const {
2592   HloInstructionProto proto = HloInstruction::ToProto();
2593   *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
2594   proto.set_indices_are_sorted(indices_are_sorted());
2595   proto.set_unique_indices(unique_indices());
2596   return proto;
2597 }
2598 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2599 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
2600     const HloPrintOptions& options) const {
2601   std::vector<string> attrs{
2602       ScatterDimensionNumbersToString(scatter_dimension_numbers())};
2603   if (indices_are_sorted()) {
2604     attrs.push_back("indices_are_sorted=true");
2605   }
2606   if (unique_indices()) {
2607     attrs.push_back("unique_indices=true");
2608   }
2609   return attrs;
2610 }
2611 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2612 bool HloScatterInstruction::IdenticalSlowPath(
2613     const HloInstruction& other,
2614     const std::function<bool(const HloComputation*, const HloComputation*)>&
2615         eq_computations) const {
2616   const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
2617   return protobuf_util::ProtobufEquals(
2618              scatter_dimension_numbers(),
2619              casted_other.scatter_dimension_numbers()) &&
2620          eq_computations(to_apply(), casted_other.to_apply()) &&
2621          indices_are_sorted() == casted_other.indices_are_sorted() &&
2622          unique_indices() == casted_other.unique_indices();
2623 }
2624 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2625 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
2626     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2627     HloCloneContext* context) const {
2628   CHECK_EQ(new_operands.size(), 3);
2629   return absl::make_unique<HloScatterInstruction>(
2630       shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
2631       scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
2632 }
2633 
HloIotaInstruction(const Shape & shape,int64 iota_dimension)2634 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
2635     : HloInstruction(HloOpcode::kIota, shape),
2636       iota_dimension_(iota_dimension) {}
2637 
ToProto() const2638 HloInstructionProto HloIotaInstruction::ToProto() const {
2639   HloInstructionProto proto = HloInstruction::ToProto();
2640   proto.add_dimensions(iota_dimension());
2641   return proto;
2642 }
2643 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2644 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
2645     const HloPrintOptions& options) const {
2646   return {StrCat("iota_dimension=", iota_dimension())};
2647 }
2648 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2649 bool HloIotaInstruction::IdenticalSlowPath(
2650     const HloInstruction& other,
2651     const std::function<bool(const HloComputation*, const HloComputation*)>&
2652         eq_computations) const {
2653   const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
2654   return iota_dimension() == casted_other.iota_dimension();
2655 }
2656 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2657 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
2658     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2659     HloCloneContext* context) const {
2660   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
2661 }
2662 
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2663 HloDotInstruction::HloDotInstruction(
2664     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2665     const DotDimensionNumbers& dimension_numbers,
2666     const PrecisionConfig& precision_config)
2667     : HloInstruction(HloOpcode::kDot, shape),
2668       dot_dimension_numbers_(dimension_numbers),
2669       precision_config_(precision_config) {
2670   AppendOperand(lhs);
2671   AppendOperand(rhs);
2672 }
2673 
ToProto() const2674 HloInstructionProto HloDotInstruction::ToProto() const {
2675   HloInstructionProto proto = HloInstruction::ToProto();
2676   *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
2677   *proto.mutable_precision_config() = precision_config_;
2678   return proto;
2679 }
2680 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2681 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
2682     const HloPrintOptions& options) const {
2683   std::vector<string> extra = {DotDimensionNumbersToString()};
2684 
2685   string precision_config_string = PrecisionConfigToString(precision_config_);
2686   if (!precision_config_string.empty()) {
2687     extra.push_back(precision_config_string);
2688   }
2689   return extra;
2690 }
2691 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2692 bool HloDotInstruction::IdenticalSlowPath(
2693     const HloInstruction& other,
2694     const std::function<bool(const HloComputation*, const HloComputation*)>&
2695         eq_computations) const {
2696   const auto& casted_other = static_cast<const HloDotInstruction&>(other);
2697   return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
2698                                        casted_other.dot_dimension_numbers()) &&
2699          protobuf_util::ProtobufEquals(precision_config(),
2700                                        casted_other.precision_config());
2701 }
2702 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2703 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
2704     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2705     HloCloneContext* context) const {
2706   CHECK_EQ(new_operands.size(), 2);
2707   return absl::make_unique<HloDotInstruction>(
2708       shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
2709       precision_config_);
2710 }
2711 
DotDimensionNumbersToString() const2712 string HloDotInstruction::DotDimensionNumbersToString() const {
2713   std::vector<string> result;
2714   const DotDimensionNumbers& dnums = dot_dimension_numbers_;
2715   if (!dnums.lhs_batch_dimensions().empty()) {
2716     result.push_back(StrCat("lhs_batch_dims={",
2717                             StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
2718   }
2719   result.push_back(StrCat("lhs_contracting_dims={",
2720                           StrJoin(dnums.lhs_contracting_dimensions(), ","),
2721                           "}"));
2722 
2723   if (!dnums.rhs_batch_dimensions().empty()) {
2724     result.push_back(StrCat("rhs_batch_dims={",
2725                             StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
2726   }
2727   result.push_back(StrCat("rhs_contracting_dims={",
2728                           StrJoin(dnums.rhs_contracting_dimensions(), ","),
2729                           "}"));
2730 
2731   return StrJoin(result, ", ");
2732 }
2733 
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)2734 HloDomainInstruction::HloDomainInstruction(
2735     const Shape& shape, HloInstruction* operand,
2736     std::unique_ptr<DomainMetadata> operand_side_metadata,
2737     std::unique_ptr<DomainMetadata> user_side_metadata)
2738     : HloInstruction(HloOpcode::kDomain, shape),
2739       operand_side_metadata_(std::move(operand_side_metadata)),
2740       user_side_metadata_(std::move(user_side_metadata)) {
2741   AppendOperand(operand);
2742 }
2743 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2744 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
2745     const HloPrintOptions& options) const {
2746   if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
2747     return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
2748                    "\", entry=", user_side_metadata_->ToString(),
2749                    ", exit=", operand_side_metadata_->ToString(), "}")};
2750   }
2751   return {};
2752 }
2753 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2754 bool HloDomainInstruction::IdenticalSlowPath(
2755     const HloInstruction& other,
2756     const std::function<bool(const HloComputation*, const HloComputation*)>&
2757         eq_computations) const {
2758   const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
2759   return operand_side_metadata().Matches(
2760              casted_other.operand_side_metadata()) &&
2761          user_side_metadata().Matches(casted_other.user_side_metadata());
2762 }
2763 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2764 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
2765     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2766     HloCloneContext* context) const {
2767   CHECK_EQ(new_operands.size(), 1);
2768   return absl::make_unique<HloDomainInstruction>(
2769       shape, new_operands[0], operand_side_metadata_->Clone(),
2770       user_side_metadata_->Clone());
2771 }
2772 
ToProto() const2773 HloInstructionProto HloDomainInstruction::ToProto() const {
2774   HloInstructionProto proto = HloInstruction::ToProto();
2775   auto operand_side_sharding =
2776       dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
2777   if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
2778     *proto.mutable_domain_entry_sharding() =
2779         operand_side_sharding->sharding()->ToProto();
2780   }
2781 
2782   auto user_side_sharding =
2783       dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
2784   if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
2785     *proto.mutable_domain_exit_sharding() =
2786         user_side_sharding->sharding()->ToProto();
2787   }
2788 
2789   return proto;
2790 }
2791 
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64 dimension)2792 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
2793     const Shape& shape, HloInstruction* operand, int64 dimension)
2794     : HloInstruction(HloOpcode::kGetDimensionSize, shape),
2795       dimension_(dimension) {
2796   AppendOperand(operand);
2797 }
2798 
ToProto() const2799 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
2800   HloInstructionProto proto = HloInstruction::ToProto();
2801   proto.add_dimensions(dimension());
2802   return proto;
2803 }
2804 
ExtraAttributesToStringImpl(const HloPrintOptions &) const2805 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
2806     const HloPrintOptions& /*options*/) const {
2807   return {StrCat("dimensions={", dimension(), "}")};
2808 }
2809 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2810 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
2811     const HloInstruction& other,
2812     const std::function<bool(const HloComputation*, const HloComputation*)>&
2813     /*eq_computations*/) const {
2814   const auto& casted_other =
2815       static_cast<const HloGetDimensionSizeInstruction&>(other);
2816   return dimension() == casted_other.dimension();
2817 }
2818 
2819 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2820 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
2821     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2822     HloCloneContext* /*context*/) const {
2823   if (new_operands.size() != 1) {
2824     LOG(FATAL) << "expects 1 operand";
2825   }
2826   return absl::make_unique<HloGetDimensionSizeInstruction>(
2827       shape, new_operands[0], dimension());
2828 }
2829 
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64 dimension)2830 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
2831     const Shape& shape, HloInstruction* operand, HloInstruction* val,
2832     int64 dimension)
2833     : HloInstruction(HloOpcode::kSetDimensionSize, shape),
2834       dimension_(dimension) {
2835   AppendOperand(operand);
2836   AppendOperand(val);
2837 }
2838 
ExtraAttributesToStringImpl(const HloPrintOptions &) const2839 std::vector<string> HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
2840     const HloPrintOptions& /*options*/) const {
2841   return {StrCat("dimensions={", dimension(), "}")};
2842 }
2843 
ToProto() const2844 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
2845   HloInstructionProto proto = HloInstruction::ToProto();
2846   proto.add_dimensions(dimension());
2847   return proto;
2848 }
2849 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2850 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
2851     const HloInstruction& other,
2852     const std::function<bool(const HloComputation*, const HloComputation*)>&
2853     /*eq_computations*/) const {
2854   const auto& casted_other =
2855       static_cast<const HloSetDimensionSizeInstruction&>(other);
2856   return dimension() == casted_other.dimension();
2857 }
2858 
2859 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2860 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
2861     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2862     HloCloneContext* /*context*/) const {
2863   if (new_operands.size() != 2) {
2864     LOG(FATAL) << "expects 2 operand";
2865   }
2866   return absl::make_unique<HloSetDimensionSizeInstruction>(
2867       shape, new_operands[0], new_operands[1], dimension());
2868 }
2869 
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64 delta)2870 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
2871     const Shape& shape, int64 delta)
2872     : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
2873 
ToProto() const2874 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
2875   HloInstructionProto proto = HloInstruction::ToProto();
2876   proto.set_delta(delta_);
2877   return proto;
2878 }
2879 
2880 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const2881 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
2882     const HloPrintOptions& /*options*/) const {
2883   return {StrCat("delta=", delta())};
2884 }
2885 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2886 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
2887     const HloInstruction& other,
2888     const std::function<bool(const HloComputation*, const HloComputation*)>&
2889     /*eq_computations*/) const {
2890   const auto& casted_other =
2891       static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
2892   return delta() == casted_other.delta();
2893 }
2894 
2895 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2896 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
2897     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2898     HloCloneContext* /*context*/) const {
2899   if (!new_operands.empty()) {
2900     LOG(FATAL) << "expects 0 operand";
2901   }
2902   return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
2903 }
2904 
2905 }  // namespace xla
2906