• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17 
18 #include <deque>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
32 #include "tensorflow/compiler/xla/window_util.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 
35 namespace xla {
36 namespace {
37 
38 using absl::CEscape;
39 using absl::StrAppend;
40 using absl::StrCat;
41 using absl::StrJoin;
42 
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)43 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
44                                        const HloInstruction* operand) {
45   std::vector<int64> operand_indices = instruction->OperandIndices(operand);
46   return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
47     return instruction->IsElementwiseOnOperand(operand_index);
48   });
49 }
50 
PrecisionConfigToString(const PrecisionConfig & precision_config)51 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
52   if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
53         return static_cast<PrecisionConfig::Precision>(precision) ==
54                PrecisionConfig::DEFAULT;
55       })) {
56     return "";
57   }
58 
59   return StrCat(
60       "operand_precision={",
61       StrJoin(
62           precision_config.operand_precision(), ",",
63           [](string* out, int32 precision) {
64             CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
65             StrAppend(out,
66                       PrecisionToString(
67                           static_cast<PrecisionConfig::Precision>(precision)));
68           }),
69       "}");
70 }
71 }  // namespace
72 
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64 feature_index)73 HloBatchNormInstruction::HloBatchNormInstruction(
74     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
75     HloInstruction* scale, float epsilon, int64 feature_index)
76     : HloInstruction(opcode, shape),
77       epsilon_(epsilon),
78       feature_index_(feature_index) {
79   AppendOperand(operand);
80   AppendOperand(scale);
81 }
82 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const83 bool HloBatchNormInstruction::IdenticalSlowPath(
84     const HloInstruction& other,
85     const std::function<bool(const HloComputation*, const HloComputation*)>&
86         eq_computations) const {
87   const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
88   return feature_index() == casted_other.feature_index() &&
89          epsilon() == casted_other.epsilon();
90 }
91 
ToProto() const92 HloInstructionProto HloBatchNormInstruction::ToProto() const {
93   HloInstructionProto proto = HloInstruction::ToProto();
94   proto.set_epsilon(epsilon_);
95   proto.set_feature_index(feature_index_);
96   return proto;
97 }
98 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const99 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
100     const HloPrintOptions& options) const {
101   return {StrCat("epsilon=", epsilon()),
102           StrCat("feature_index=", feature_index())};
103 }
104 
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)105 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
106     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
107     HloInstruction* offset, float epsilon, int64 feature_index)
108     : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
109                               scale, epsilon, feature_index) {
110   AppendOperand(offset);
111 }
112 
113 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const114 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
115     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
116     HloCloneContext* context) const {
117   CHECK_EQ(new_operands.size(), 3);
118   return absl::make_unique<HloBatchNormTrainingInstruction>(
119       shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
120       feature_index());
121 }
122 
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)123 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
124     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
125     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
126     float epsilon, int64 feature_index)
127     : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
128                               scale, epsilon, feature_index) {
129   AppendOperand(offset);
130   AppendOperand(mean);
131   AppendOperand(variance);
132 }
133 
134 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const135 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
136     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
137     HloCloneContext* context) const {
138   CHECK_EQ(new_operands.size(), 5);
139   return absl::make_unique<HloBatchNormInferenceInstruction>(
140       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
141       new_operands[4], epsilon(), feature_index());
142 }
143 
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)144 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
145     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
146     HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
147     float epsilon, int64 feature_index)
148     : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
149                               epsilon, feature_index) {
150   AppendOperand(mean);
151   AppendOperand(variance);
152   AppendOperand(grad_output);
153 }
154 
155 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const156 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
157     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
158     HloCloneContext* context) const {
159   CHECK_EQ(new_operands.size(), 5);
160   return absl::make_unique<HloBatchNormGradInstruction>(
161       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
162       new_operands[4], epsilon(), feature_index());
163 }
164 
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)165 HloFftInstruction::HloFftInstruction(const Shape& shape,
166                                      HloInstruction* operand, FftType fft_type,
167                                      absl::Span<const int64> fft_length)
168     : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
169   fft_length_.assign(fft_length.begin(), fft_length.end());
170   AppendOperand(operand);
171 }
172 
ToProto() const173 HloInstructionProto HloFftInstruction::ToProto() const {
174   HloInstructionProto proto = HloInstruction::ToProto();
175   proto.set_fft_type(fft_type_);
176   for (int64 fft_len : fft_length_) {
177     proto.add_fft_length(fft_len);
178   }
179   return proto;
180 }
181 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const182 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
183     const HloPrintOptions& options) const {
184   return {StrCat("fft_type=", FftType_Name(fft_type())),
185           StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
186 }
187 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const188 bool HloFftInstruction::IdenticalSlowPath(
189     const HloInstruction& other,
190     const std::function<bool(const HloComputation*, const HloComputation*)>&
191         eq_computations) const {
192   const auto& casted_other = static_cast<const HloFftInstruction&>(other);
193   return fft_type() == casted_other.fft_type() &&
194          fft_length() == casted_other.fft_length();
195 }
196 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const197 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
198     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
199     HloCloneContext* context) const {
200   CHECK_EQ(new_operands.size(), 1);
201   return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
202                                               fft_length_);
203 }
204 
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction)205 HloCompareInstruction::HloCompareInstruction(const Shape& shape,
206                                              HloInstruction* lhs,
207                                              HloInstruction* rhs,
208                                              ComparisonDirection direction)
209     : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
210   AppendOperand(lhs);
211   AppendOperand(rhs);
212 }
213 
ToProto() const214 HloInstructionProto HloCompareInstruction::ToProto() const {
215   HloInstructionProto proto = HloInstruction::ToProto();
216   proto.set_comparison_direction(ComparisonDirectionToString(direction_));
217   return proto;
218 }
219 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const220 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
221     const HloPrintOptions& options) const {
222   return {StrCat("direction=", ComparisonDirectionToString(direction()))};
223 }
224 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const225 bool HloCompareInstruction::IdenticalSlowPath(
226     const HloInstruction& other,
227     const std::function<bool(const HloComputation*, const HloComputation*)>&
228         eq_computations) const {
229   const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
230   return direction() == casted_other.direction();
231 }
232 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const233 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
234     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
235     HloCloneContext* context) const {
236   CHECK_EQ(new_operands.size(), 2);
237   return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
238                                                   new_operands[1], direction());
239 }
240 
241 namespace {
242 
243 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
244 // of "key=value" attribute strings generically, using protocol buffer
245 // reflection.
246 //
247 // Currently implements a small subset of cases; feel free to add more as
248 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)249 std::vector<string> AttributeProtoToStringVector(
250     const tensorflow::protobuf::Message& message) {
251   const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
252   std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
253   reflection->ListFields(message, &fields);
254 
255   std::vector<string> output;
256   for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
257     string s = absl::StrCat(field->name(), "=");
258     CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
259     switch (field->type()) {
260       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
261         bool val = reflection->GetBool(message, field);
262         absl::StrAppend(&s, val ? "true" : "false");
263         break;
264       }
265       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
266         const tensorflow::protobuf::EnumValueDescriptor* evd =
267             reflection->GetEnum(message, field);
268         absl::StrAppend(&s, evd->name());
269         break;
270       }
271       default:
272         LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
273     }
274     output.push_back(std::move(s));
275   }
276   return output;
277 }
278 
279 }  // namespace
280 
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)281 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
282     const Shape& shape, HloInstruction* a, HloInstruction* b,
283     const TriangularSolveOptions& options)
284     : HloInstruction(HloOpcode::kTriangularSolve, shape),
285       triangular_solve_options_(options) {
286   AppendOperand(a);
287   AppendOperand(b);
288 }
289 
ToProto() const290 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
291   HloInstructionProto proto = HloInstruction::ToProto();
292   *proto.mutable_triangular_solve_options() = triangular_solve_options_;
293   return proto;
294 }
295 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const296 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
297     const HloPrintOptions& options) const {
298   return AttributeProtoToStringVector(triangular_solve_options_);
299 }
300 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const301 bool HloTriangularSolveInstruction::IdenticalSlowPath(
302     const HloInstruction& other,
303     const std::function<bool(const HloComputation*, const HloComputation*)>&
304         eq_computations) const {
305   const auto& casted_other =
306       static_cast<const HloTriangularSolveInstruction&>(other);
307   const auto& options = triangular_solve_options();
308   const auto& other_options = casted_other.triangular_solve_options();
309 
310   return options.left_side() == other_options.left_side() &&
311          options.lower() == other_options.lower() &&
312          options.unit_diagonal() == other_options.unit_diagonal() &&
313          options.transpose_a() == other_options.transpose_a();
314 }
315 
316 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const317 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
318     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
319     HloCloneContext* context) const {
320   CHECK_EQ(new_operands.size(), 2);
321   return absl::make_unique<HloTriangularSolveInstruction>(
322       shape, new_operands[0], new_operands[1], triangular_solve_options());
323 }
324 
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)325 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
326                                                HloInstruction* a,
327                                                const CholeskyOptions& options)
328     : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
329   AppendOperand(a);
330 }
331 
ToProto() const332 HloInstructionProto HloCholeskyInstruction::ToProto() const {
333   HloInstructionProto proto = HloInstruction::ToProto();
334   *proto.mutable_cholesky_options() = cholesky_options_;
335   return proto;
336 }
337 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const338 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
339     const HloPrintOptions& options) const {
340   return AttributeProtoToStringVector(cholesky_options_);
341 }
342 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const343 bool HloCholeskyInstruction::IdenticalSlowPath(
344     const HloInstruction& other,
345     const std::function<bool(const HloComputation*, const HloComputation*)>&
346         eq_computations) const {
347   const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
348   const auto& options = cholesky_options();
349   const auto& other_options = casted_other.cholesky_options();
350 
351   return options.lower() == other_options.lower();
352 }
353 
354 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const355 HloCholeskyInstruction::CloneWithNewOperandsImpl(
356     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
357     HloCloneContext* context) const {
358   CHECK_EQ(new_operands.size(), 1);
359   return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
360                                                    cholesky_options());
361 }
362 
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64 channel_id,bool is_host_transfer)363 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
364                                                const Shape& shape,
365                                                int64 channel_id,
366                                                bool is_host_transfer)
367     : HloInstruction(opcode, shape),
368       channel_id_(channel_id),
369       is_host_transfer_(is_host_transfer) {}
370 
ToProto() const371 HloInstructionProto HloSendRecvInstruction::ToProto() const {
372   HloInstructionProto proto = HloInstruction::ToProto();
373   proto.set_channel_id(channel_id_);
374   proto.set_is_host_transfer(is_host_transfer_);
375   return proto;
376 }
377 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const378 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
379     const HloPrintOptions& options) const {
380   std::vector<string> attrs;
381   attrs.push_back(StrCat("channel_id=", channel_id_));
382   if (is_host_transfer()) {
383     attrs.push_back("is_host_transfer=true");
384   }
385   return attrs;
386 }
387 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const388 bool HloSendRecvInstruction::IdenticalSlowPath(
389     const HloInstruction& other,
390     const std::function<bool(const HloComputation*, const HloComputation*)>&
391         eq_computations) const {
392   // Not yet supported.
393   return false;
394 }
395 
396 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)397 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
398                                        HloInstruction* token, int64 channel_id,
399                                        bool is_host_transfer)
400     : HloSendRecvInstruction(
401           HloOpcode::kSend,
402           ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
403                                      ShapeUtil::MakeShape(U32, {}),
404                                      ShapeUtil::MakeTokenShape()}),
405           channel_id, is_host_transfer) {
406   AppendOperand(operand);
407   AppendOperand(token);
408 }
409 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const410 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
411     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
412     HloCloneContext* context) const {
413   CHECK_EQ(new_operands.size(), 2);
414   return absl::make_unique<HloSendInstruction>(
415       new_operands[0], new_operands[1], channel_id(), is_host_transfer());
416 }
417 
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)418 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
419                                                bool is_host_transfer)
420     : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
421                              CHECK_NOTNULL(operand)->channel_id(),
422                              is_host_transfer) {
423   AppendOperand(operand);
424 }
425 
426 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const427 HloSendDoneInstruction::CloneWithNewOperandsImpl(
428     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
429     HloCloneContext* context) const {
430   CHECK_EQ(new_operands.size(), 1);
431   return absl::make_unique<HloSendDoneInstruction>(
432       Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
433 }
434 
435 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)436 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
437                                        HloInstruction* token, int64 channel_id,
438                                        bool is_host_transfer)
439     : HloSendRecvInstruction(
440           HloOpcode::kRecv,
441           ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
442                                      ShapeUtil::MakeTokenShape()}),
443           channel_id, is_host_transfer) {
444   AppendOperand(token);
445 }
446 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const447 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
448     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
449     HloCloneContext* context) const {
450   CHECK_EQ(new_operands.size(), 1);
451   return absl::make_unique<HloRecvInstruction>(
452       ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
453       is_host_transfer());
454 }
455 
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)456 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
457                                                bool is_host_transfer)
458     : HloSendRecvInstruction(
459           HloOpcode::kRecvDone,
460           ShapeUtil::MakeTupleShape(
461               {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
462                ShapeUtil::MakeTokenShape()}),
463           CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
464   AppendOperand(operand);
465 }
466 
467 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const468 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
469     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
470     HloCloneContext* context) const {
471   CHECK_EQ(new_operands.size(), 1);
472   return absl::make_unique<HloRecvDoneInstruction>(
473       Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
474 }
475 
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups)476 HloCollectiveInstruction::HloCollectiveInstruction(
477     HloOpcode opcode, const Shape& shape,
478     absl::Span<HloInstruction* const> operands,
479     const std::vector<ReplicaGroup>& replica_groups)
480     : HloInstruction(opcode, shape), replica_groups_(replica_groups) {
481   for (auto operand : operands) {
482     AppendOperand(operand);
483   }
484 }
485 
ToProto() const486 HloInstructionProto HloCollectiveInstruction::ToProto() const {
487   HloInstructionProto proto = HloInstruction::ToProto();
488   *proto.mutable_replica_groups() = {replica_groups_.begin(),
489                                      replica_groups_.end()};
490   return proto;
491 }
492 
ExtraAttributesToStringImpl(const HloPrintOptions &) const493 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
494     const HloPrintOptions& /*options*/) const {
495   std::vector<string> result;
496   std::vector<string> replica_group_str;
497   for (const ReplicaGroup& group : replica_groups()) {
498     replica_group_str.push_back(
499         StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
500   }
501   result.push_back(
502       StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}"));
503   return result;
504 }
505 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const506 bool HloCollectiveInstruction::IdenticalSlowPath(
507     const HloInstruction& other,
508     const std::function<bool(const HloComputation*, const HloComputation*)>&
509     /*eq_computations*/) const {
510   const auto& casted_other =
511       static_cast<const HloCollectiveInstruction&>(other);
512   return absl::c_equal(replica_groups(), casted_other.replica_groups(),
513                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
514                          return absl::c_equal(a.replica_ids(), b.replica_ids());
515                        });
516 }
517 
HloAllReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,absl::string_view barrier,const absl::optional<int64> & all_reduce_id)518 HloAllReduceInstruction::HloAllReduceInstruction(
519     const Shape& shape, absl::Span<HloInstruction* const> operands,
520     HloComputation* reduce_computation,
521     const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
522     const absl::optional<int64>& all_reduce_id)
523     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
524                                replica_groups),
525       all_reduce_barrier_(barrier),
526       all_reduce_id_(all_reduce_id) {
527   AppendComputation(reduce_computation);
528 }
529 
set_all_reduce_id(const absl::optional<int64> & all_reduce_id)530 void HloAllReduceInstruction::set_all_reduce_id(
531     const absl::optional<int64>& all_reduce_id) {
532   all_reduce_id_ = all_reduce_id;
533 }
534 
ToProto() const535 HloInstructionProto HloAllReduceInstruction::ToProto() const {
536   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
537   // Proto3 is so sad.
538   if (all_reduce_id_) {
539     proto.set_all_reduce_id(*all_reduce_id_);
540   }
541   proto.set_all_reduce_barrier(all_reduce_barrier_);
542   return proto;
543 }
544 
IsNoop() const545 bool HloAllReduceInstruction::IsNoop() const {
546   for (auto replica_group : replica_groups()) {
547     if (replica_group.replica_ids().size() != 1) {
548       return false;
549     }
550   }
551   return !all_reduce_id();
552 }
553 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const554 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
555     const HloPrintOptions& options) const {
556   std::vector<string> result =
557       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
558   if (!all_reduce_barrier().empty()) {
559     result.push_back(StrCat("barrier=\"", all_reduce_barrier(), "\""));
560   }
561   if (all_reduce_id_) {
562     result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
563   }
564   return result;
565 }
566 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const567 bool HloAllReduceInstruction::IdenticalSlowPath(
568     const HloInstruction& other,
569     const std::function<bool(const HloComputation*, const HloComputation*)>&
570         eq_computations) const {
571   const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
572   return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
573          eq_computations(to_apply(), casted_other.to_apply()) &&
574          all_reduce_barrier() == casted_other.all_reduce_barrier() &&
575          all_reduce_id() == casted_other.all_reduce_id();
576 }
577 
578 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const579 HloAllReduceInstruction::CloneWithNewOperandsImpl(
580     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
581     HloCloneContext* /*context*/) const {
582   return absl::make_unique<HloAllReduceInstruction>(
583       shape, new_operands, to_apply(), replica_groups(), all_reduce_barrier(),
584       all_reduce_id());
585 }
586 
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups)587 HloAllToAllInstruction::HloAllToAllInstruction(
588     const Shape& shape, absl::Span<HloInstruction* const> operands,
589     const std::vector<ReplicaGroup>& replica_groups)
590     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
591                                replica_groups) {}
592 
593 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const594 HloAllToAllInstruction::CloneWithNewOperandsImpl(
595     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
596     HloCloneContext* /*context*/) const {
597   return absl::make_unique<HloAllToAllInstruction>(shape, new_operands,
598                                                    replica_groups());
599 }
600 
HloCollectivePermuteInstruction(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs)601 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
602     const Shape& shape, HloInstruction* operand,
603     const std::vector<std::pair<int64, int64>>& source_target_pairs)
604     : HloInstruction(HloOpcode::kCollectivePermute, shape),
605       source_target_pairs_(source_target_pairs) {
606   AppendOperand(operand);
607 }
608 
ToProto() const609 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
610   HloInstructionProto proto = HloInstruction::ToProto();
611   for (const auto& pair : source_target_pairs()) {
612     auto* proto_pair = proto.add_source_target_pairs();
613     proto_pair->set_source(pair.first);
614     proto_pair->set_target(pair.second);
615   }
616   return proto;
617 }
618 
619 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const620 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
621     const HloPrintOptions& /*options*/) const {
622   std::vector<string> result;
623   std::vector<string> strs;
624   for (const auto& pair : source_target_pairs()) {
625     strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
626   }
627   result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
628   return result;
629 }
630 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const631 bool HloCollectivePermuteInstruction::IdenticalSlowPath(
632     const HloInstruction& other,
633     const std::function<bool(const HloComputation*, const HloComputation*)>&
634     /*eq_computations*/) const {
635   const auto& casted_other =
636       static_cast<const HloCollectivePermuteInstruction&>(other);
637   return absl::c_equal(source_target_pairs(),
638                        casted_other.source_target_pairs(),
639                        [](const std::pair<int64, int64>& a,
640                           const std::pair<int64, int64>& b) { return a == b; });
641 }
642 
643 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const644 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
645     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
646     HloCloneContext* /*context*/) const {
647   return absl::make_unique<HloCollectivePermuteInstruction>(
648       shape, new_operands[0], source_target_pairs());
649 }
650 
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)651 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
652                                              HloInstruction* operand,
653                                              absl::Span<const int64> dimensions)
654     : HloInstruction(HloOpcode::kReverse, shape),
655       dimensions_(dimensions.begin(), dimensions.end()) {
656   AppendOperand(operand);
657 }
658 
ToProto() const659 HloInstructionProto HloReverseInstruction::ToProto() const {
660   HloInstructionProto proto = HloInstruction::ToProto();
661   for (int64 dimension : dimensions_) {
662     proto.add_dimensions(dimension);
663   }
664   return proto;
665 }
666 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const667 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
668     const HloPrintOptions& options) const {
669   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
670 }
671 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const672 bool HloReverseInstruction::IdenticalSlowPath(
673     const HloInstruction& other,
674     const std::function<bool(const HloComputation*, const HloComputation*)>&
675         eq_computations) const {
676   const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
677   return dimensions() == casted_other.dimensions();
678 }
679 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const680 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
681     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
682     HloCloneContext* context) const {
683   CHECK_EQ(new_operands.size(), 1);
684   return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
685                                                   dimensions());
686 }
687 
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)688 HloConcatenateInstruction::HloConcatenateInstruction(
689     const Shape& shape, absl::Span<HloInstruction* const> operands,
690     int64 dimension)
691     : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
692   for (auto operand : operands) {
693     AppendOperand(operand);
694   }
695 }
696 
ToProto() const697 HloInstructionProto HloConcatenateInstruction::ToProto() const {
698   HloInstructionProto proto = HloInstruction::ToProto();
699   for (int64 dimension : dimensions_) {
700     proto.add_dimensions(dimension);
701   }
702   return proto;
703 }
704 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const705 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
706     const HloPrintOptions& options) const {
707   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
708 }
709 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const710 bool HloConcatenateInstruction::IdenticalSlowPath(
711     const HloInstruction& other,
712     const std::function<bool(const HloComputation*, const HloComputation*)>&
713         eq_computations) const {
714   const auto& casted_other =
715       static_cast<const HloConcatenateInstruction&>(other);
716   return dimensions() == casted_other.dimensions();
717 }
718 
719 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const720 HloConcatenateInstruction::CloneWithNewOperandsImpl(
721     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
722     HloCloneContext* context) const {
723   return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
724                                                       dimensions(0));
725 }
726 
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)727 HloReduceInstruction::HloReduceInstruction(
728     const Shape& shape, absl::Span<HloInstruction* const> args,
729     absl::Span<const int64> dimensions_to_reduce,
730     HloComputation* reduce_computation)
731     : HloInstruction(HloOpcode::kReduce, shape),
732       dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
733   for (HloInstruction* arg : args) {
734     AppendOperand(arg);
735   }
736   AppendComputation(reduce_computation);
737 }
738 
ToProto() const739 HloInstructionProto HloReduceInstruction::ToProto() const {
740   HloInstructionProto proto = HloInstruction::ToProto();
741   for (int64 dimension : dimensions_) {
742     proto.add_dimensions(dimension);
743   }
744   return proto;
745 }
746 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const747 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
748     const HloPrintOptions& options) const {
749   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
750 }
751 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const752 bool HloReduceInstruction::IdenticalSlowPath(
753     const HloInstruction& other,
754     const std::function<bool(const HloComputation*, const HloComputation*)>&
755         eq_computations) const {
756   const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
757   // Reduction results are determined by the reduction dimension and the
758   // reduction computation.
759   return dimensions() == casted_other.dimensions() &&
760          eq_computations(to_apply(), casted_other.to_apply());
761 }
762 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const763 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
764     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
765     HloCloneContext* context) const {
766   CHECK_EQ(new_operands.size() % 2, 0);
767   return absl::make_unique<HloReduceInstruction>(shape, new_operands,
768                                                  dimensions(), to_apply());
769 }
770 
HloSortInstruction(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)771 HloSortInstruction::HloSortInstruction(
772     const Shape& shape, int64 dimension,
773     absl::Span<HloInstruction* const> operands, HloComputation* compare,
774     bool is_stable)
775     : HloInstruction(HloOpcode::kSort, shape),
776       dimensions_({dimension}),
777       is_stable_(is_stable) {
778   for (auto* value : operands) {
779     AppendOperand(value);
780   }
781   AppendComputation(compare);
782 }
783 
ToProto() const784 HloInstructionProto HloSortInstruction::ToProto() const {
785   HloInstructionProto proto = HloInstruction::ToProto();
786   for (int64 dimension : dimensions_) {
787     proto.add_dimensions(dimension);
788   }
789   proto.set_is_stable(is_stable());
790   return proto;
791 }
792 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const793 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
794     const HloPrintOptions& options) const {
795   std::vector<string> attrs;
796   attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
797   if (is_stable()) {
798     attrs.push_back("is_stable=true");
799   }
800   return attrs;
801 }
802 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const803 bool HloSortInstruction::IdenticalSlowPath(
804     const HloInstruction& other,
805     const std::function<bool(const HloComputation*, const HloComputation*)>&
806         eq_computations) const {
807   const auto& casted_other = static_cast<const HloSortInstruction&>(other);
808   if (dimensions() != casted_other.dimensions()) {
809     return false;
810   }
811   if (is_stable() != casted_other.is_stable()) {
812     return false;
813   }
814   return eq_computations(to_apply(), other.to_apply());
815 }
816 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const817 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
818     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
819     HloCloneContext* context) const {
820   return absl::make_unique<HloSortInstruction>(
821       shape, dimensions(0), new_operands, to_apply(), is_stable());
822 }
823 
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)824 HloTransposeInstruction::HloTransposeInstruction(
825     const Shape& shape, HloInstruction* operand,
826     absl::Span<const int64> dimensions)
827     : HloInstruction(HloOpcode::kTranspose, shape),
828       dimensions_(dimensions.begin(), dimensions.end()) {
829   AppendOperand(operand);
830 }
831 
IsRank2Transpose() const832 bool HloTransposeInstruction::IsRank2Transpose() const {
833   return dimensions() == std::vector<int64>({1, 0}) &&
834          shape().dimensions_size() == 2 &&
835          std::equal(shape().dimensions().begin(), shape().dimensions().end(),
836                     operand(0)->shape().dimensions().rbegin());
837 }
838 
ToProto() const839 HloInstructionProto HloTransposeInstruction::ToProto() const {
840   HloInstructionProto proto = HloInstruction::ToProto();
841   for (int64 dimension : dimensions_) {
842     proto.add_dimensions(dimension);
843   }
844   return proto;
845 }
846 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const847 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
848     const HloPrintOptions& options) const {
849   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
850 }
851 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const852 bool HloTransposeInstruction::IdenticalSlowPath(
853     const HloInstruction& other,
854     const std::function<bool(const HloComputation*, const HloComputation*)>&
855         eq_computations) const {
856   const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
857   return dimensions() == casted_other.dimensions();
858 }
859 
860 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const861 HloTransposeInstruction::CloneWithNewOperandsImpl(
862     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
863     HloCloneContext* context) const {
864   CHECK_EQ(new_operands.size(), 1);
865   return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
866                                                     dimensions());
867 }
868 
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)869 HloBroadcastInstruction::HloBroadcastInstruction(
870     const Shape& shape, HloInstruction* operand,
871     absl::Span<const int64> broadcast_dimension)
872     : HloInstruction(HloOpcode::kBroadcast, shape),
873       dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
874   AppendOperand(operand);
875 }
876 
ToProto() const877 HloInstructionProto HloBroadcastInstruction::ToProto() const {
878   HloInstructionProto proto = HloInstruction::ToProto();
879   for (int64 dimension : dimensions_) {
880     proto.add_dimensions(dimension);
881   }
882   return proto;
883 }
884 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const885 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
886     const HloPrintOptions& options) const {
887   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
888 }
889 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const890 bool HloBroadcastInstruction::IdenticalSlowPath(
891     const HloInstruction& other,
892     const std::function<bool(const HloComputation*, const HloComputation*)>&
893         eq_computations) const {
894   const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
895   return dimensions() == casted_other.dimensions();
896 }
897 
898 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const899 HloBroadcastInstruction::CloneWithNewOperandsImpl(
900     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
901     HloCloneContext* context) const {
902   CHECK_EQ(new_operands.size(), 1);
903   return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
904                                                     dimensions());
905 }
906 
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)907 HloMapInstruction::HloMapInstruction(const Shape& shape,
908                                      absl::Span<HloInstruction* const> operands,
909                                      HloComputation* map_computation)
910     : HloInstruction(HloOpcode::kMap, shape) {
911   for (auto operand : operands) {
912     AppendOperand(operand);
913   }
914   AppendComputation(map_computation);
915   // TODO(b/65689298) Remove code below once Map is generalized to accept
916   // arbitrary map dimensions.
917   dimensions_.resize(shape.rank());
918   std::iota(dimensions_.begin(), dimensions_.end(), 0);
919 }
920 
ToProto() const921 HloInstructionProto HloMapInstruction::ToProto() const {
922   HloInstructionProto proto = HloInstruction::ToProto();
923   for (int64 dimension : dimensions_) {
924     proto.add_dimensions(dimension);
925   }
926   return proto;
927 }
928 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const929 bool HloMapInstruction::IsElementwiseImpl(
930     const absl::optional<int64>& operand_idx) const {
931   if (!dimensions().empty()) {
932     // Check that the map is executed in elementwise compatible dimensions.
933     if (dimensions().size() != shape().dimensions_size()) {
934       return false;
935     }
936     for (int i = 0; i < dimensions().size(); ++i) {
937       if (dimensions()[i] != i) {
938         return false;
939       }
940     }
941   }
942   return true;
943 }
944 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const945 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
946     const HloPrintOptions& options) const {
947   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
948 }
949 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const950 bool HloMapInstruction::IdenticalSlowPath(
951     const HloInstruction& other,
952     const std::function<bool(const HloComputation*, const HloComputation*)>&
953         eq_computations) const {
954   return eq_computations(to_apply(), other.to_apply());
955 }
956 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const957 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
958     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
959     HloCloneContext* context) const {
960   return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
961 }
962 
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)963 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
964                                          HloInstruction* operand,
965                                          absl::Span<const int64> start_indices,
966                                          absl::Span<const int64> limit_indices,
967                                          absl::Span<const int64> strides)
968     : HloInstruction(HloOpcode::kSlice, shape),
969       slice_starts_(start_indices.begin(), start_indices.end()),
970       slice_limits_(limit_indices.begin(), limit_indices.end()),
971       slice_strides_(strides.begin(), strides.end()) {
972   AppendOperand(operand);
973   // For backward compatibility with old serialized computations: if there are
974   // no strides, assume all strides are 1.
975   // TODO(b/63317920): remove this code.
976   if (slice_strides_.empty()) {
977     slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
978   }
979 }
980 
ToProto() const981 HloInstructionProto HloSliceInstruction::ToProto() const {
982   HloInstructionProto proto = HloInstruction::ToProto();
983   for (int i = 0; i < slice_starts_.size(); ++i) {
984     auto* slice_dimension = proto.add_slice_dimensions();
985     slice_dimension->set_start(slice_starts_[i]);
986     slice_dimension->set_limit(slice_limits_[i]);
987     slice_dimension->set_stride(slice_strides_[i]);
988   }
989   return proto;
990 }
991 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const992 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
993     const HloPrintOptions& options) const {
994   std::vector<string> bounds;
995   bounds.reserve(slice_starts_.size());
996   const bool omit_stride =
997       absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
998   for (int i = 0; i < slice_starts_.size(); ++i) {
999     string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1000     bounds.push_back(
1001         StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1002   }
1003   return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1004 }
1005 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1006 bool HloSliceInstruction::IdenticalSlowPath(
1007     const HloInstruction& other,
1008     const std::function<bool(const HloComputation*, const HloComputation*)>&
1009         eq_computations) const {
1010   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1011   return slice_starts_ == other_slice.slice_starts_ &&
1012          slice_limits_ == other_slice.slice_limits_ &&
1013          slice_strides_ == other_slice.slice_strides_;
1014 }
1015 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1016 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1017     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1018     HloCloneContext* context) const {
1019   CHECK_EQ(new_operands.size(), 1);
1020   return absl::make_unique<HloSliceInstruction>(
1021       shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1022 }
1023 
HloConstantInstruction(Literal literal)1024 HloConstantInstruction::HloConstantInstruction(Literal literal)
1025     : HloInstruction(HloOpcode::kConstant, literal.shape()),
1026       literal_(std::move(literal)) {}
1027 
HloConstantInstruction(const Shape & shape)1028 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1029     : HloInstruction(HloOpcode::kConstant, shape) {}
1030 
ToProto() const1031 HloInstructionProto HloConstantInstruction::ToProto() const {
1032   HloInstructionProto proto = HloInstruction::ToProto();
1033   if (literal_.has_value()) {
1034     *proto.mutable_literal() = literal_->ToProto();
1035   }
1036   return proto;
1037 }
1038 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1039 bool HloConstantInstruction::IsElementwiseImpl(
1040     const absl::optional<int64>& operand_idx) const {
1041   return true;
1042 }
1043 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1044 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1045                                               const ShapeIndex& shape_index) {
1046   Shape* mutable_array_subshape =
1047       ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1048   CHECK(mutable_array_subshape->IsArray());
1049 
1050   // Normally array_subshape will always have a layout, but this invariant is
1051   // temporarily broken in LayoutAssignment::AssignLayouts.
1052 
1053   if (!mutable_array_subshape->has_layout() ||
1054       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1055     *literal_ = literal_->Relayout(new_layout, shape_index);
1056     *mutable_array_subshape->mutable_layout() = new_layout;
1057   }
1058 }
1059 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1060 bool HloConstantInstruction::IdenticalSlowPath(
1061     const HloInstruction& other,
1062     const std::function<bool(const HloComputation*, const HloComputation*)>&
1063         eq_computations) const {
1064   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1065   return literal() == other_slice.literal();
1066 }
1067 
1068 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1069 HloConstantInstruction::CloneWithNewOperandsImpl(
1070     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1071     HloCloneContext* context) const {
1072   CHECK(literal_.has_value());
1073   return absl::make_unique<HloConstantInstruction>(literal_->Clone());
1074 }
1075 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1076 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1077     const HloPrintOptions& options,
1078     CanonicalNameMap* canonical_name_map) const {
1079   string operands;
1080   // For constants, show the actual value in place of an empty operand list.
1081   if (literal_.has_value() &&
1082       ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1083        options.print_large_constants())) {
1084     // Literal::ToString emits multidimensional arrays over multiple
1085     // lines. Compact this into one line by stripping out white space.
1086     string tmp = literal().ToStringWithoutShape();
1087     std::replace(tmp.begin(), tmp.end(), '\n', ' ');
1088     std::vector<string> v = absl::StrSplit(tmp, ' ');
1089     bool first = true;
1090     // Concatenate elements in "v" with spaces separating them, but ignoring
1091     // empty entries.
1092     for (const auto& s : v) {
1093       if (s.empty()) {
1094         continue;
1095       }
1096       StrAppend(&operands, (first ? "" : " "), s);
1097       first = false;
1098     }
1099   } else {
1100     // Do not show large constants or tuples.
1101     operands = "{...}";
1102   }
1103   return operands;
1104 }
1105 
HloTraceInstruction(const string & tag,HloInstruction * operand)1106 HloTraceInstruction::HloTraceInstruction(const string& tag,
1107                                          HloInstruction* operand)
1108     : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1109       literal_(LiteralUtil::CreateR1U8(tag)) {
1110   AppendOperand(operand);
1111   operand->set_tracing(this);
1112 }
1113 
ToProto() const1114 HloInstructionProto HloTraceInstruction::ToProto() const {
1115   HloInstructionProto proto = HloInstruction::ToProto();
1116   *proto.mutable_literal() = literal_.ToProto();
1117   return proto;
1118 }
1119 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1120 bool HloTraceInstruction::IdenticalSlowPath(
1121     const HloInstruction& other,
1122     const std::function<bool(const HloComputation*, const HloComputation*)>&
1123         eq_computations) const {
1124   return false;
1125 }
1126 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1127 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1128     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1129     HloCloneContext* context) const {
1130   LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1131 }
1132 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1133 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1134                                            FusionKind fusion_kind,
1135                                            HloInstruction* fused_root)
1136     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1137   CHECK(fused_root != nullptr);
1138   SetAndSanitizeName("fusion");
1139   set_parent(fused_root->parent());
1140   set_metadata(fused_root->metadata());
1141   CloneAndFuseInternal(fused_root);
1142 }
1143 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1144 HloFusionInstruction::HloFusionInstruction(
1145     const Shape& shape, FusionKind fusion_kind,
1146     absl::Span<HloInstruction* const> operands,
1147     HloComputation* fusion_computation)
1148     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1149   for (auto operand : operands) {
1150     AppendOperand(operand);
1151   }
1152   SetAndSanitizeName("fusion");
1153   AppendComputation(fusion_computation);
1154   fusion_computation->SetFusionInstruction(this);
1155 }
1156 
ToCategory() const1157 string HloFusionInstruction::ToCategory() const {
1158   switch (fusion_kind()) {
1159     case FusionKind::kLoop:
1160       return "loop fusion";
1161     case FusionKind::kInput:
1162       return "input fusion";
1163     case FusionKind::kOutput:
1164       return "output fusion";
1165     case FusionKind::kCustom:
1166       return "custom fusion";
1167   }
1168 }
1169 
ToProto() const1170 HloInstructionProto HloFusionInstruction::ToProto() const {
1171   HloInstructionProto proto = HloInstruction::ToProto();
1172   proto.set_fusion_kind(xla::ToString(fusion_kind()));
1173   proto.add_called_computation_ids(
1174       fused_instructions_computation()->unique_id());
1175   return proto;
1176 }
1177 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1178 bool HloFusionInstruction::IsElementwiseImpl(
1179     const absl::optional<int64>& operand_idx) const {
1180   if (!operand_idx.has_value()) {
1181     for (auto* fused : fused_instructions()) {
1182       if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1183         return false;
1184       }
1185     }
1186     return true;
1187   }
1188   // A loop-fusion is elementwise on an operand if all operations (computed
1189   // using BFS) between the operand and the fused root are elementwise.
1190   std::deque<HloInstruction*> worklist;
1191   std::unordered_set<const HloInstruction*> visited;
1192   worklist.push_back(fused_parameter(operand_idx.value()));
1193   visited.insert(fused_parameter(operand_idx.value()));
1194   while (!worklist.empty()) {
1195     HloInstruction* operand = worklist.front();
1196     worklist.pop_front();
1197     for (HloInstruction* user : operand->users()) {
1198       CHECK_GE(user->unique_id(), 0);
1199       if (ContainsKey(visited, user)) {
1200         continue;
1201       }
1202       if (user->IsElementwise() ||
1203           IsInstructionElementwiseOnOperand(user, operand)) {
1204         worklist.push_back(user);
1205         visited.insert(user);
1206       } else {
1207         return false;
1208       }
1209     }
1210   }
1211   return true;
1212 }
1213 
AddFusionOperand(HloInstruction * new_operand)1214 HloInstruction* HloFusionInstruction::AddFusionOperand(
1215     HloInstruction* new_operand) {
1216   CHECK_EQ(operand_count(),
1217            fused_instructions_computation()->parameter_instructions().size());
1218   const int64 param_no = operand_count();
1219   // Name the parameter after the instruction it represents in the outer
1220   // (non-fusion) computation.
1221   // string param_name = StrCat(new_operand->name(), ".param_", param_no);
1222   string param_name = StrCat("param_", param_no);
1223   HloInstruction* fused_parameter =
1224       fused_instructions_computation()->AddParameter(
1225           HloInstruction::CreateParameter(param_no, new_operand->shape(),
1226                                           param_name));
1227   AppendOperand(new_operand);
1228   return fused_parameter;
1229 }
1230 
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1231 void HloFusionInstruction::MergeFusionInstruction(
1232     HloFusionInstruction* instruction_to_merge) {
1233   CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1234   // Clone the instruction from which to merge fused instructions.
1235   std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1236   HloFusionInstruction* cloned_fusion =
1237       static_cast<HloFusionInstruction*>(cloned.get());
1238   // Replace uses of fused parameters with the corresponding operand of the
1239   // fusion.  Add all non-parameter fused instructions to
1240   // 'unfused_instructions' to be merged into 'this'.  This is done in reverse
1241   // post order.
1242   std::vector<HloInstruction*> unfused_instructions;
1243   auto fused_instructions = cloned_fusion->fused_instructions_computation()
1244                                 ->MakeInstructionPostOrder();
1245   for (auto fused_it = fused_instructions.rbegin();
1246        fused_it != fused_instructions.rend(); ++fused_it) {
1247     auto fused_instruction = *fused_it;
1248     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1249       TF_CHECK_OK(
1250           fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1251               fused_instruction->parameter_number())));
1252     } else {
1253       unfused_instructions.push_back(fused_instruction);
1254     }
1255   }
1256   CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root());
1257   // Replace instruction_to_merge use of 'this' with unfused_root.
1258   TF_CHECK_OK(
1259       instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
1260   // Fuse 'unfused_instructions' into 'this'.
1261   for (auto& instruction : unfused_instructions) {
1262     FuseInstruction(instruction);
1263   }
1264   CHECK_EQ(0, cloned_fusion->user_count());
1265   TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1266       cloned_fusion->fused_instructions_computation()));
1267 }
1268 
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1269 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1270     HloFusionInstruction* instruction_to_merge) {
1271   // Add all non-parameter fused instructions to 'unfused_instructions' to be
1272   // merged into 'this'. `old_to_new' maps the instructions in the fused node
1273   // to the disaseembled fusion instructions.
1274   // Note that we add the unfused instructions to this->parent_ computation.
1275   // This is necessary because the unique_id needs for an instruction and
1276   // it's only added when inserting to the computation.
1277   absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1278   std::vector<HloInstruction*> unfused_instructions;
1279   auto computation_to_merge =
1280       instruction_to_merge->fused_instructions_computation();
1281   auto post_order = computation_to_merge->MakeInstructionPostOrder();
1282   for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1283     auto fused_instruction = *rit;
1284     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1285       InsertOrDie(&old_to_new, fused_instruction,
1286                   instruction_to_merge->mutable_operand(
1287                       fused_instruction->parameter_number()));
1288       continue;
1289     }
1290 
1291     // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
1292     // which clones again. This can be improved.
1293     auto cloned_instruction =
1294         parent()->AddInstruction(fused_instruction->Clone());
1295     unfused_instructions.push_back(cloned_instruction);
1296     InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
1297   }
1298   for (auto unfused_instruction : unfused_instructions) {
1299     for (int64 index = 0; index < unfused_instruction->operand_count();
1300          index++) {
1301       auto new_operand =
1302           FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
1303       TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
1304     }
1305   }
1306 
1307   HloInstruction* unfused_root = unfused_instructions.front();
1308   TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
1309 
1310   TF_CHECK_OK(
1311       instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
1312   if (GetModule()) {
1313     TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
1314   }
1315 
1316   // Fuse the root instruction and generate multiple outputs.
1317   FuseInstructionIntoMultiOutput(unfused_root);
1318   TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
1319   // The rest instructions are of normal fusing.
1320   for (int64 i = 1; i < unfused_instructions.size(); i++) {
1321     auto instruction = unfused_instructions[i];
1322     FuseInstruction(instruction);
1323     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1324   }
1325 }
1326 
fused_instructions_computation() const1327 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1328   CHECK(!called_computations().empty());
1329   auto* fused_instructions_computation = called_computations().front();
1330   CHECK(fused_instructions_computation->IsFusionComputation())
1331       << "Computation " << fused_instructions_computation->name()
1332       << " is not a fusion kind";
1333   return fused_instructions_computation;
1334 }
1335 
fused_expression_root() const1336 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1337   return fused_instructions_computation()->root_instruction();
1338 }
1339 
fused_parameter(int64 parameter_number) const1340 HloInstruction* HloFusionInstruction::fused_parameter(
1341     int64 parameter_number) const {
1342   return fused_instructions_computation()->parameter_instruction(
1343       parameter_number);
1344 }
1345 
fused_parameters() const1346 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1347     const {
1348   return fused_instructions_computation()->parameter_instructions();
1349 }
1350 
1351 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1352     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1353 HloFusionInstruction::fused_instructions() const {
1354   const HloComputation* subcomp = fused_instructions_computation();
1355   return subcomp->instructions();
1356 }
1357 
1358 const tensorflow::gtl::iterator_range<
1359     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1360 HloFusionInstruction::fused_instructions() {
1361   return fused_instructions_computation()->instructions();
1362 }
1363 
fused_instruction_count() const1364 int64 HloFusionInstruction::fused_instruction_count() const {
1365   return fused_instructions_computation()->instruction_count();
1366 }
1367 
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1368 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1369     HloInstruction* instruction_to_fuse, bool add_output) {
1370   // When add_output is false, this fusion instruction must be a user of
1371   // instruction_to_fuse.
1372   if (!add_output) {
1373     CHECK(IsUserOf(instruction_to_fuse));
1374   }
1375   HloInstruction* fused_instruction =
1376       CloneAndFuseInternal(instruction_to_fuse, add_output);
1377   return fused_instruction;
1378 }
1379 
CloneAndFuseInternal(HloInstruction * instruction_to_fuse,bool add_output)1380 HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
1381     HloInstruction* instruction_to_fuse, bool add_output) {
1382   CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1383   VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
1384   HloInstruction* clone = nullptr;
1385   if (called_computations().empty()) {
1386     // New fusion instruction. It should not be a multioutput instruction.
1387     CHECK(!add_output);
1388     auto builder = HloComputation::Builder("fused_computation", this);
1389     builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
1390     AppendComputation(
1391         CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1392     clone = fused_expression_root();
1393   } else {
1394     // When add_output is false, instruction_to_fuse is necessarily an operand
1395     // of the fusion instruction. After fusion this will no longer be the
1396     // case. Remove the operand from the operand list and remove its
1397     // corresponding fused parameter instruction. Renumber parameters as
1398     // necessary to make parameter numbers consistent with their index in the
1399     // fused_parameter_ vector.
1400     bool in_operand_list =
1401         absl::c_linear_search(operands(), instruction_to_fuse);
1402     CHECK(add_output || in_operand_list);
1403     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1404       // We assume all uses of a kTuple operation are GTE ops, not another
1405       // fusion node. In this case, we don't need to clone
1406       // 'instruction_to_fuse'.
1407       CHECK(!in_operand_list);
1408       clone = instruction_to_fuse;
1409     } else {
1410       clone = fused_instructions_computation()->AddInstruction(
1411           instruction_to_fuse->Clone(/*suffix=*/""));
1412     }
1413     const std::vector<HloInstruction*>& fused_parameters =
1414         fused_instructions_computation()->parameter_instructions();
1415     for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1416       if (instruction_to_fuse == operand(operand_num)) {
1417         // replace the fused parameter instruction's uses with the clone.
1418         HloInstruction* fused_parameter = fused_parameters[operand_num];
1419         TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
1420 
1421         // Remove the corresponding fused parameter and operand from their
1422         // respective vectors.
1423         TF_CHECK_OK(
1424             fused_instructions_computation()->RemoveParameter(operand_num));
1425         RemoveOperandAt(operand_num);
1426         break;
1427       }
1428     }
1429     // We've cloned instruction_to_fuse into this fusion instruction, so this
1430     // fusion instruction is no longer a use of instruction_to_fuse.
1431     if (in_operand_list) {
1432       DetachFrom(instruction_to_fuse);
1433       // When the instruction_to_fuse does not have other users, we don't need
1434       // to generate a multioutput fusion instruction.
1435       if (instruction_to_fuse->user_count() == 0) {
1436         add_output = false;
1437       }
1438     }
1439   }
1440 
1441   // Reread the parameters in the computation.
1442   const std::vector<HloInstruction*>& fused_parameters =
1443       fused_instructions_computation()->parameter_instructions();
1444 
1445   // Add each operand of the clone as an operand of the fusion instruction. A
1446   // complication is that some clone operands may already be operands of the
1447   // fusion instruction.
1448   for (int64 operand_num = 0; operand_num < clone->operand_count();
1449        ++operand_num) {
1450     HloInstruction* operand = clone->mutable_operand(operand_num);
1451 
1452     // See if this operand is already an operand of the fusion node.
1453     CHECK_EQ(operands().size(), fused_parameters.size());
1454     HloInstruction* fused_param = nullptr;
1455     for (int64 i = 0; i < operands().size(); ++i) {
1456       if (this->operand(i) == operand) {
1457         fused_param = fused_parameters[i];
1458         break;
1459       }
1460     }
1461 
1462     if (fused_param == nullptr) {
1463       // Clone's operand was not already an operand of the fusion
1464       // instruction. Add it as an operand and add a corresponding fused
1465       // parameter instruction.
1466       fused_param = AddFusionOperand(operand);
1467     }
1468     TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
1469   }
1470 
1471   if (add_output) {
1472     CHECK_GT(instruction_to_fuse->user_count(), 0);
1473     // If this is already a multioutput fusion instruction, expand the root
1474     // tuple by 1.
1475     HloInstruction* fused_root = fused_expression_root();
1476     HloInstruction::InstructionVector tuple_elements;
1477     bool newly_created_tuple_instr = false;
1478     if (fused_root->opcode() == HloOpcode::kTuple) {
1479       tuple_elements = fused_root->operands();
1480     } else {
1481       tuple_elements.push_back(fused_root);
1482       newly_created_tuple_instr = true;
1483     }
1484     if (clone->opcode() == HloOpcode::kTuple) {
1485       for (auto inst : clone->operands()) {
1486         tuple_elements.push_back(inst);
1487       }
1488     } else {
1489       tuple_elements.push_back(clone);
1490     }
1491     HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
1492         HloInstruction::CreateTuple(tuple_elements));
1493     fused_instructions_computation()->set_root_instruction(new_root);
1494     *mutable_shape() = new_root->shape();
1495     if (fused_root->opcode() == HloOpcode::kTuple) {
1496       TF_CHECK_OK(
1497           fused_instructions_computation()->RemoveInstruction(fused_root));
1498     }
1499 
1500     // If this is a newly created multioutput instruction, we need to update
1501     // the use of the original fusion instruction.
1502     if (newly_created_tuple_instr) {
1503       HloInstruction* new_instr = parent()->AddInstruction(
1504           HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
1505       TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1506     }
1507     int64 index = tuple_elements.size();
1508     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1509       CHECK_EQ(clone, instruction_to_fuse);
1510       index -= clone->operand_count();
1511       std::vector<HloInstruction*> to_be_removed;
1512       for (auto old_gte : clone->users()) {
1513         CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1514         int64 old_tuple_index = old_gte->tuple_index();
1515         HloInstruction* new_gte =
1516             parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1517                 old_gte->shape(), this, index + old_tuple_index));
1518         TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1519         to_be_removed.push_back(old_gte);
1520       }
1521       for (auto old_gte : to_be_removed) {
1522         TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1523       }
1524     } else {
1525       HloInstruction* new_gte =
1526           parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1527               clone->shape(), this, index - 1));
1528       TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
1529     }
1530   }
1531 
1532   if (clone != instruction_to_fuse) {
1533     VLOG(2) << "New clone:\n" << clone->ToString();
1534   }
1535   return clone;
1536 }
1537 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1538 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1539     const HloPrintOptions& options) const {
1540   return {StrCat("kind=", xla::ToString(fusion_kind()))};
1541 }
1542 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1543 bool HloFusionInstruction::IdenticalSlowPath(
1544     const HloInstruction& other,
1545     const std::function<bool(const HloComputation*, const HloComputation*)>&
1546         eq_computations) const {
1547   return fusion_kind() == other.fusion_kind() &&
1548          eq_computations(fused_instructions_computation(),
1549                          other.fused_instructions_computation());
1550 }
1551 
HashOperandRecursive(const HloInstruction * hlo)1552 static uint64 HashOperandRecursive(const HloInstruction* hlo) {
1553   return hlo->Hash(HashOperandRecursive);
1554 }
1555 
InnerHash() const1556 uint64 HloFusionInstruction::InnerHash() const {
1557   // Use HashOperandRecursive to recursively compute hash on inner operands.
1558   return fused_instructions_computation()->root_instruction()->Hash(
1559       HashOperandRecursive);
1560 }
1561 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1562 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1563     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1564     HloCloneContext* context) const {
1565   HloModule* module = context != nullptr ? context->module() : GetModule();
1566   HloComputation* new_fused_computation = nullptr;
1567   if (context != nullptr) {
1568     new_fused_computation =
1569         context->FindComputation(fused_instructions_computation());
1570   }
1571   if (new_fused_computation == nullptr) {
1572     new_fused_computation = module->AddEmbeddedComputation(
1573         fused_instructions_computation()->Clone("clone", context));
1574   }
1575   return absl::make_unique<HloFusionInstruction>(
1576       shape, fusion_kind(), new_operands, new_fused_computation);
1577 }
1578 
DeduplicateFusionOperands()1579 Status HloFusionInstruction::DeduplicateFusionOperands() {
1580   absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1581   std::vector<int> operands_to_remove;
1582   for (int i = 0; i < operand_count(); ++i) {
1583     auto emplace_result = operand_indices.emplace(operand(i), i);
1584     if (!emplace_result.second) {
1585       TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1586           fused_parameter(emplace_result.first->second)));
1587       operands_to_remove.push_back(i);
1588     }
1589   }
1590   if (operands_to_remove.empty()) {
1591     return Status::OK();
1592   }
1593   TF_RETURN_IF_ERROR(
1594       fused_instructions_computation()->RemoveUnusedParameters());
1595   RemoveOperandsAtAscendingIndices(operands_to_remove);
1596   return Status::OK();
1597 }
1598 
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1599 HloRngInstruction::HloRngInstruction(
1600     const Shape& shape, RandomDistribution distribution,
1601     absl::Span<HloInstruction* const> parameters)
1602     : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
1603   for (HloInstruction* param : parameters) {
1604     AppendOperand(param);
1605   }
1606 }
1607 
ToProto() const1608 HloInstructionProto HloRngInstruction::ToProto() const {
1609   HloInstructionProto proto = HloInstruction::ToProto();
1610   proto.set_distribution(distribution_);
1611   return proto;
1612 }
1613 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1614 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
1615     const HloPrintOptions& options) const {
1616   return {StrCat("distribution=", RandomDistributionToString(distribution_))};
1617 }
1618 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1619 bool HloRngInstruction::IsElementwiseImpl(
1620     const absl::optional<int64>& operand_idx) const {
1621   return true;
1622 }
1623 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1624 bool HloRngInstruction::IdenticalSlowPath(
1625     const HloInstruction& other,
1626     const std::function<bool(const HloComputation*, const HloComputation*)>&
1627         eq_computations) const {
1628   return false;
1629 }
1630 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1631 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
1632     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1633     HloCloneContext* context) const {
1634   return absl::make_unique<HloRngInstruction>(shape, distribution_,
1635                                               new_operands);
1636 }
1637 
HloParameterInstruction(int64 parameter_number,const Shape & shape,const string & name)1638 HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
1639                                                  const Shape& shape,
1640                                                  const string& name)
1641     : HloInstruction(HloOpcode::kParameter, shape),
1642       parameter_number_(parameter_number) {
1643   SetAndSanitizeName(name);
1644 }
1645 
ToProto() const1646 HloInstructionProto HloParameterInstruction::ToProto() const {
1647   HloInstructionProto proto = HloInstruction::ToProto();
1648   proto.set_parameter_number(parameter_number_);
1649   if (parameter_replicated_at_leaf_buffers_) {
1650     for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1651       proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
1652           replicated);
1653     }
1654   }
1655   return proto;
1656 }
1657 
ExtraAttributesToStringImpl(const HloPrintOptions &) const1658 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
1659     const HloPrintOptions& /*options*/) const {
1660   std::vector<string> result;
1661   if (!parameter_replicated_at_leaf_buffers_) {
1662     return result;
1663   }
1664   std::vector<string> buffers_replicated_strs;
1665   for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1666     buffers_replicated_strs.push_back(replicated ? "true" : "false");
1667   }
1668   result.push_back(StrCat("parameter_replication={",
1669                           StrJoin(buffers_replicated_strs, ","), "}"));
1670   return result;
1671 }
1672 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1673 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
1674     const HloPrintOptions& options,
1675     CanonicalNameMap* canonical_name_map) const {
1676   return StrCat(parameter_number_);
1677 }
1678 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1679 bool HloParameterInstruction::IdenticalSlowPath(
1680     const HloInstruction& other,
1681     const std::function<bool(const HloComputation*, const HloComputation*)>&
1682         eq_computations) const {
1683   const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
1684   return parameter_number() == casted_other.parameter_number();
1685 }
1686 
1687 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1688 HloParameterInstruction::CloneWithNewOperandsImpl(
1689     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1690     HloCloneContext* context) const {
1691   return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
1692                                                     name());
1693 }
1694 
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64 index)1695 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
1696     const Shape& shape, HloInstruction* operand, int64 index)
1697     : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
1698   AppendOperand(operand);
1699 }
1700 
ToProto() const1701 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
1702   HloInstructionProto proto = HloInstruction::ToProto();
1703   proto.set_tuple_index(tuple_index_);
1704   return proto;
1705 }
1706 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1707 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
1708     const HloPrintOptions& options) const {
1709   return {StrCat("index=", tuple_index())};
1710 }
1711 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1712 bool HloGetTupleElementInstruction::IdenticalSlowPath(
1713     const HloInstruction& other,
1714     const std::function<bool(const HloComputation*, const HloComputation*)>&
1715         eq_computations) const {
1716   const auto& casted_other =
1717       static_cast<const HloGetTupleElementInstruction&>(other);
1718   return tuple_index() == casted_other.tuple_index();
1719 }
1720 
1721 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1722 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
1723     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1724     HloCloneContext* context) const {
1725   CHECK_EQ(new_operands.size(), 1);
1726   return absl::make_unique<HloGetTupleElementInstruction>(
1727       shape, new_operands[0], tuple_index());
1728 }
1729 
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1730 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
1731     const Shape& shape, HloInstruction* operand, const int exponent_bits,
1732     const int mantissa_bits)
1733     : HloInstruction(HloOpcode::kReducePrecision, shape),
1734       exponent_bits_(exponent_bits),
1735       mantissa_bits_(mantissa_bits) {
1736   AppendOperand(operand);
1737 }
1738 
ToProto() const1739 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
1740   HloInstructionProto proto = HloInstruction::ToProto();
1741   proto.set_exponent_bits(exponent_bits_);
1742   proto.set_mantissa_bits(mantissa_bits_);
1743   return proto;
1744 }
1745 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1746 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
1747     const HloPrintOptions& options) const {
1748   return {StrCat("exponent_bits=", exponent_bits_),
1749           StrCat("mantissa_bits=", mantissa_bits_)};
1750 }
1751 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1752 bool HloReducePrecisionInstruction::IdenticalSlowPath(
1753     const HloInstruction& other,
1754     const std::function<bool(const HloComputation*, const HloComputation*)>&
1755         eq_computations) const {
1756   const auto& casted_other =
1757       static_cast<const HloReducePrecisionInstruction&>(other);
1758   // A reduce-precision operation is determined by the bit sizes.
1759   return exponent_bits() == casted_other.exponent_bits() &&
1760          mantissa_bits() == casted_other.mantissa_bits();
1761 }
1762 
1763 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1764 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
1765     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1766     HloCloneContext* context) const {
1767   CHECK_EQ(new_operands.size(), 1);
1768   return absl::make_unique<HloReducePrecisionInstruction>(
1769       shape, new_operands[0], exponent_bits(), mantissa_bits());
1770 }
1771 
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1772 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
1773                                            HloInstruction* token_operand,
1774                                            const string& config)
1775     : HloInstruction(HloOpcode::kInfeed,
1776                      ShapeUtil::MakeTupleShape(
1777                          {infeed_shape, ShapeUtil::MakeTokenShape()})),
1778       infeed_config_(config) {
1779   AppendOperand(token_operand);
1780 }
1781 
ToProto() const1782 HloInstructionProto HloInfeedInstruction::ToProto() const {
1783   HloInstructionProto proto = HloInstruction::ToProto();
1784   proto.set_infeed_config(infeed_config_);
1785   return proto;
1786 }
1787 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1788 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
1789     const HloPrintOptions& options) const {
1790   if (infeed_config_.empty()) {
1791     return {};
1792   }
1793   return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
1794 }
1795 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1796 bool HloInfeedInstruction::IdenticalSlowPath(
1797     const HloInstruction& other,
1798     const std::function<bool(const HloComputation*, const HloComputation*)>&
1799         eq_computations) const {
1800   // Not yet supported.
1801   return false;
1802 }
1803 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1804 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
1805     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1806     HloCloneContext* context) const {
1807   CHECK_EQ(new_operands.size(), 1);
1808   return absl::make_unique<HloInfeedInstruction>(
1809       infeed_shape(), new_operands[0], infeed_config());
1810 }
1811 
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1812 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
1813                                              HloInstruction* operand,
1814                                              HloInstruction* token_operand,
1815                                              absl::string_view outfeed_config)
1816     : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
1817       outfeed_shape_(outfeed_shape),
1818       outfeed_config_(outfeed_config) {
1819   AppendOperand(operand);
1820   AppendOperand(token_operand);
1821 }
1822 
ToProto() const1823 HloInstructionProto HloOutfeedInstruction::ToProto() const {
1824   HloInstructionProto proto = HloInstruction::ToProto();
1825   proto.set_outfeed_config(outfeed_config());
1826   *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
1827   return proto;
1828 }
1829 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1830 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
1831     const HloPrintOptions& options) const {
1832   if (outfeed_config_.empty()) {
1833     return {};
1834   }
1835   return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")};
1836 }
1837 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1838 bool HloOutfeedInstruction::IdenticalSlowPath(
1839     const HloInstruction& other,
1840     const std::function<bool(const HloComputation*, const HloComputation*)>&
1841         eq_computations) const {
1842   // Not yet supported.
1843   return false;
1844 }
1845 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1846 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
1847     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1848     HloCloneContext* context) const {
1849   CHECK_EQ(new_operands.size(), 2);
1850   return absl::make_unique<HloOutfeedInstruction>(
1851       outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
1852 }
1853 
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1854 HloConvolutionInstruction::HloConvolutionInstruction(
1855     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1856     int64 feature_group_count, int64 batch_group_count, const Window& window,
1857     const ConvolutionDimensionNumbers& dimension_numbers,
1858     const PrecisionConfig& precision_config)
1859     : HloInstruction(HloOpcode::kConvolution, shape),
1860       feature_group_count_(feature_group_count),
1861       batch_group_count_(batch_group_count),
1862       window_(window),
1863       convolution_dimension_numbers_(dimension_numbers),
1864       precision_config_(precision_config) {
1865   if (window_util::HasBaseDilation(window)) {
1866     SetAndSanitizeName(StrCat(name(), "-base-dilated"));
1867   }
1868   if (window_util::HasWindowDilation(window)) {
1869     SetAndSanitizeName(StrCat(name(), "-window-dilated"));
1870   }
1871   AppendOperand(lhs);
1872   AppendOperand(rhs);
1873 }
1874 
ToCategory() const1875 string HloConvolutionInstruction::ToCategory() const {
1876   string category = "convolution";
1877   if (window_util::HasBaseDilation(window())) {
1878     category += " base-dilated";
1879   }
1880   if (window_util::HasWindowDilation(window())) {
1881     category += " window-dilated";
1882   }
1883   return category;
1884 }
1885 
ToProto() const1886 HloInstructionProto HloConvolutionInstruction::ToProto() const {
1887   HloInstructionProto proto = HloInstruction::ToProto();
1888   *proto.mutable_window() = window_;
1889   *proto.mutable_convolution_dimension_numbers() =
1890       convolution_dimension_numbers_;
1891   proto.set_feature_group_count(feature_group_count_);
1892   proto.set_batch_group_count(batch_group_count_);
1893   *proto.mutable_precision_config() = precision_config_;
1894   return proto;
1895 }
1896 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1897 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
1898     const HloPrintOptions& options) const {
1899   std::vector<string> extra;
1900   if (window_.dimensions_size() != 0) {
1901     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
1902   }
1903   extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
1904                                             convolution_dimension_numbers_)));
1905   if (feature_group_count_ != 1) {
1906     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
1907   }
1908 
1909   if (batch_group_count_ != 1) {
1910     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
1911   }
1912 
1913   string precision_config_string = PrecisionConfigToString(precision_config_);
1914   if (!precision_config_string.empty()) {
1915     extra.push_back(precision_config_string);
1916   }
1917 
1918   return extra;
1919 }
1920 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1921 bool HloConvolutionInstruction::IdenticalSlowPath(
1922     const HloInstruction& other,
1923     const std::function<bool(const HloComputation*, const HloComputation*)>&
1924         eq_computations) const {
1925   const auto& casted_other =
1926       static_cast<const HloConvolutionInstruction&>(other);
1927   if (feature_group_count_ != other.feature_group_count()) {
1928     return false;
1929   }
1930   if (batch_group_count_ != other.batch_group_count()) {
1931     return false;
1932   }
1933   return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
1934          protobuf_util::ProtobufEquals(
1935              convolution_dimension_numbers(),
1936              casted_other.convolution_dimension_numbers()) &&
1937          protobuf_util::ProtobufEquals(precision_config(),
1938                                        casted_other.precision_config());
1939 }
1940 
1941 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1942 HloConvolutionInstruction::CloneWithNewOperandsImpl(
1943     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1944     HloCloneContext* context) const {
1945   CHECK_EQ(new_operands.size(), 2);
1946   return absl::make_unique<HloConvolutionInstruction>(
1947       shape, new_operands[0], new_operands[1], feature_group_count_,
1948       batch_group_count_, window(), convolution_dimension_numbers_,
1949       precision_config_);
1950 }
1951 
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1952 HloReduceWindowInstruction::HloReduceWindowInstruction(
1953     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1954     const Window& window, HloComputation* reduce_computation)
1955     : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
1956   AppendOperand(operand);
1957   AppendOperand(init_value);
1958   AppendComputation(reduce_computation);
1959 }
1960 
ToProto() const1961 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
1962   HloInstructionProto proto = HloInstruction::ToProto();
1963   *proto.mutable_window() = window_;
1964   return proto;
1965 }
1966 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1967 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
1968     const HloPrintOptions& options) const {
1969   std::vector<string> extra;
1970   if (window_.dimensions_size() != 0) {
1971     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
1972   }
1973   return extra;
1974 }
1975 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1976 bool HloReduceWindowInstruction::IdenticalSlowPath(
1977     const HloInstruction& other,
1978     const std::function<bool(const HloComputation*, const HloComputation*)>&
1979         eq_computations) const {
1980   const auto& casted_other =
1981       static_cast<const HloReduceWindowInstruction&>(other);
1982   return eq_computations(to_apply(), casted_other.to_apply()) &&
1983          protobuf_util::ProtobufEquals(window(), casted_other.window());
1984 }
1985 
1986 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1987 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
1988     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1989     HloCloneContext* context) const {
1990   CHECK_EQ(new_operands.size(), 2);
1991   return absl::make_unique<HloReduceWindowInstruction>(
1992       shape, new_operands[0], new_operands[1], window(), to_apply());
1993 }
1994 
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1995 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
1996     const Shape& shape, HloInstruction* operand, HloComputation* select,
1997     const Window& window, HloInstruction* source, HloInstruction* init_value,
1998     HloComputation* scatter)
1999     : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2000   AppendOperand(operand);
2001   AppendOperand(source);
2002   AppendOperand(init_value);
2003   // Select comes before scatter in the vector.
2004   AppendComputation(select);
2005   AppendComputation(scatter);
2006 }
2007 
ToProto() const2008 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2009   HloInstructionProto proto = HloInstruction::ToProto();
2010   *proto.mutable_window() = window_;
2011   return proto;
2012 }
2013 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2014 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2015     const HloPrintOptions& options) const {
2016   std::vector<string> extra;
2017   if (window_.dimensions_size() != 0) {
2018     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2019   }
2020   return extra;
2021 }
2022 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2023 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2024     const HloInstruction& other,
2025     const std::function<bool(const HloComputation*, const HloComputation*)>&
2026         eq_computations) const {
2027   const auto& casted_other =
2028       static_cast<const HloSelectAndScatterInstruction&>(other);
2029   return eq_computations(select(), casted_other.select()) &&
2030          eq_computations(scatter(), casted_other.scatter()) &&
2031          protobuf_util::ProtobufEquals(window(), casted_other.window());
2032 }
2033 
2034 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2035 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2036     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2037     HloCloneContext* context) const {
2038   CHECK_EQ(new_operands.size(), 3);
2039   return absl::make_unique<HloSelectAndScatterInstruction>(
2040       shape, new_operands[0], select(), window(), new_operands[1],
2041       new_operands[2], scatter());
2042 }
2043 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::string_view opaque)2044 HloCustomCallInstruction::HloCustomCallInstruction(
2045     const Shape& shape, absl::Span<HloInstruction* const> operands,
2046     absl::string_view custom_call_target, absl::string_view opaque)
2047     : HloInstruction(HloOpcode::kCustomCall, shape),
2048       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2049       opaque_(opaque.begin(), opaque.end()),
2050       feature_group_count_(1),
2051       batch_group_count_(1),
2052       layout_constrained_(false) {
2053   for (auto operand : operands) {
2054     AppendOperand(operand);
2055   }
2056 }
2057 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::string_view opaque,absl::Span<const Shape> operand_shapes_with_layout)2058 HloCustomCallInstruction::HloCustomCallInstruction(
2059     const Shape& shape, absl::Span<HloInstruction* const> operands,
2060     absl::string_view custom_call_target, absl::string_view opaque,
2061     absl::Span<const Shape> operand_shapes_with_layout)
2062     : HloInstruction(HloOpcode::kCustomCall, shape),
2063       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2064       opaque_(opaque.begin(), opaque.end()),
2065       feature_group_count_(1),
2066       batch_group_count_(1),
2067       layout_constrained_(true),
2068       operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2069                                   operand_shapes_with_layout.end()) {
2070   for (auto operand : operands) {
2071     AppendOperand(operand);
2072   }
2073 }
2074 
ToProto() const2075 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2076   HloInstructionProto proto = HloInstruction::ToProto();
2077   if (window_ != nullptr) {
2078     *proto.mutable_window() = *window_;
2079   }
2080   if (convolution_dimension_numbers_ != nullptr) {
2081     *proto.mutable_convolution_dimension_numbers() =
2082         *convolution_dimension_numbers_;
2083   }
2084   proto.set_custom_call_target(custom_call_target_);
2085   proto.set_custom_call_opaque(opaque_);
2086   proto.set_feature_group_count(feature_group_count_);
2087   proto.set_batch_group_count(batch_group_count_);
2088   if (layout_constrained()) {
2089     proto.set_constrain_layout(true);
2090     for (const Shape& shape : operand_shapes_with_layout_) {
2091       *proto.add_operand_shapes_with_layout() = shape.ToProto();
2092     }
2093   }
2094   return proto;
2095 }
2096 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2097 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2098     const HloPrintOptions& options) const {
2099   std::vector<string> extra;
2100   if (window_ != nullptr && window_->dimensions_size() != 0) {
2101     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2102   }
2103   if (convolution_dimension_numbers_ != nullptr) {
2104     extra.push_back(StrCat(
2105         "dim_labels=",
2106         ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2107   }
2108   if (feature_group_count_ != 1) {
2109     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2110   }
2111   if (batch_group_count_ != 1) {
2112     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2113   }
2114   // By contract, we print the custom call target even if
2115   // options.print_subcomputation_mode() == kOff, because the call target is not
2116   // an HloComputation.
2117   extra.push_back(
2118       StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2119   // If the opaque string becomes enormous we may want to reconsider printing
2120   // this inline and consider other options.
2121   if (!opaque_.empty()) {
2122     extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
2123   }
2124   if (layout_constrained()) {
2125     std::vector<string> shape_strings;
2126     for (const Shape& shape : operand_shapes_with_layout_) {
2127       shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2128     }
2129     extra.push_back(StrCat("operand_layout_constraints={",
2130                            StrJoin(shape_strings, ", "), "}"));
2131   }
2132   return extra;
2133 }
2134 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2135 bool HloCustomCallInstruction::IdenticalSlowPath(
2136     const HloInstruction& other,
2137     const std::function<bool(const HloComputation*, const HloComputation*)>&
2138         eq_computations) const {
2139   const auto& casted_other =
2140       static_cast<const HloCustomCallInstruction&>(other);
2141   if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2142       (window_ != nullptr &&
2143        !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2144     return false;
2145   }
2146   if ((convolution_dimension_numbers_ == nullptr) !=
2147           (casted_other.convolution_dimension_numbers_ == nullptr) ||
2148       (convolution_dimension_numbers_ != nullptr &&
2149        !protobuf_util::ProtobufEquals(
2150            convolution_dimension_numbers(),
2151            casted_other.convolution_dimension_numbers()))) {
2152     return false;
2153   }
2154   if (feature_group_count_ != casted_other.feature_group_count_) {
2155     return false;
2156   }
2157   if (batch_group_count_ != casted_other.batch_group_count_) {
2158     return false;
2159   }
2160   if (layout_constrained() != casted_other.layout_constrained()) {
2161     return false;
2162   }
2163   if (layout_constrained()) {
2164     for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2165       if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2166                             casted_other.operand_shapes_with_layout_[i])) {
2167         return false;
2168       }
2169     }
2170   }
2171   return custom_call_target_ == casted_other.custom_call_target_ &&
2172          opaque_ == casted_other.opaque_;
2173 }
2174 
2175 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2176 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2177     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2178     HloCloneContext* context) const {
2179   auto cloned = absl::make_unique<HloCustomCallInstruction>(
2180       shape, new_operands, custom_call_target(), opaque());
2181   if (layout_constrained()) {
2182     cloned->layout_constrained_ = true;
2183     cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2184   }
2185   if (window_ != nullptr) {
2186     cloned->set_window(*window_);
2187   }
2188   if (convolution_dimension_numbers_ != nullptr) {
2189     cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2190   }
2191   cloned->set_feature_group_count(feature_group_count_);
2192   cloned->set_batch_group_count(batch_group_count_);
2193   return std::move(cloned);
2194 }
2195 
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2196 HloPadInstruction::HloPadInstruction(const Shape& shape,
2197                                      HloInstruction* operand,
2198                                      HloInstruction* padding_value,
2199                                      const PaddingConfig& padding_config)
2200     : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2201   AppendOperand(operand);
2202   AppendOperand(padding_value);
2203 }
2204 
ToProto() const2205 HloInstructionProto HloPadInstruction::ToProto() const {
2206   HloInstructionProto proto = HloInstruction::ToProto();
2207   *proto.mutable_padding_config() = padding_config_;
2208   return proto;
2209 }
2210 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2211 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2212     const HloPrintOptions& options) const {
2213   return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2214 }
2215 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2216 bool HloPadInstruction::IdenticalSlowPath(
2217     const HloInstruction& other,
2218     const std::function<bool(const HloComputation*, const HloComputation*)>&
2219         eq_computations) const {
2220   const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2221   return protobuf_util::ProtobufEquals(padding_config(),
2222                                        casted_other.padding_config());
2223 }
2224 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2225 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2226     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2227     HloCloneContext* context) const {
2228   CHECK_EQ(new_operands.size(), 2);
2229   return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2230                                               new_operands[1], padding_config_);
2231 }
2232 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2233 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2234     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2235     absl::Span<const int64> slice_sizes)
2236     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2237       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2238   AppendOperand(operand);
2239   AppendOperand(start_indices);
2240 }
2241 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2242 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2243     const Shape& shape, HloInstruction* operand,
2244     absl::Span<HloInstruction* const> start_indices,
2245     absl::Span<const int64> slice_sizes)
2246     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2247       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2248   AppendOperand(operand);
2249   for (HloInstruction* index : start_indices) {
2250     AppendOperand(index);
2251   }
2252 }
2253 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2254 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2255     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2256     HloInstruction* start_indices)
2257     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2258   AppendOperand(operand);
2259   AppendOperand(update);
2260   AppendOperand(start_indices);
2261 }
2262 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2263 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2264     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2265     absl::Span<HloInstruction* const> start_indices)
2266     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2267   AppendOperand(operand);
2268   AppendOperand(update);
2269   for (HloInstruction* index : start_indices) {
2270     AppendOperand(index);
2271   }
2272 }
2273 
ToProto() const2274 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2275   HloInstructionProto proto = HloInstruction::ToProto();
2276   for (int64 slice_size : dynamic_slice_sizes_) {
2277     proto.add_dynamic_slice_sizes(slice_size);
2278   }
2279   return proto;
2280 }
2281 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2282 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2283     const HloPrintOptions& options) const {
2284   return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2285                  "}")};
2286 }
2287 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2288 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2289     const HloInstruction& other,
2290     const std::function<bool(const HloComputation*, const HloComputation*)>&
2291         eq_computations) const {
2292   return true;
2293 }
2294 
2295 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2296 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2297     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2298     HloCloneContext* context) const {
2299   if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2300     // TODO(b/118437727): Old form, remove this path.
2301     return absl::make_unique<HloDynamicSliceInstruction>(
2302         shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2303   } else {
2304     return absl::make_unique<HloDynamicSliceInstruction>(
2305         shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2306   }
2307 }
2308 
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes)2309 HloGatherInstruction::HloGatherInstruction(
2310     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2311     const GatherDimensionNumbers& gather_dim_numbers,
2312     absl::Span<const int64> slice_sizes)
2313     : HloInstruction(HloOpcode::kGather, shape) {
2314   AppendOperand(operand);
2315   AppendOperand(start_indices);
2316   gather_dimension_numbers_ =
2317       absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2318   absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2319 }
2320 
GatherDimensionNumbersToString() const2321 string HloGatherInstruction::GatherDimensionNumbersToString() const {
2322   CHECK(gather_dimension_numbers_ != nullptr);
2323   string offset_dims =
2324       StrCat("offset_dims={",
2325              StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}");
2326   string collapsed_slice_dims = StrCat(
2327       "collapsed_slice_dims={",
2328       StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
2329   string start_index_map =
2330       StrCat("start_index_map={",
2331              StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}");
2332   string index_vector_dim = StrCat(
2333       "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
2334 
2335   return StrJoin<std::initializer_list<string>>(
2336       {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2337       ", ");
2338 }
2339 
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64 index_vector_dim)2340 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2341     absl::Span<const int64> offset_dims,
2342     absl::Span<const int64> collapsed_slice_dims,
2343     absl::Span<const int64> start_index_map, int64 index_vector_dim) {
2344   GatherDimensionNumbers gather_dim_numbers;
2345   for (int64 output_window_dim : offset_dims) {
2346     gather_dim_numbers.add_offset_dims(output_window_dim);
2347   }
2348   for (int64 elided_window_dim : collapsed_slice_dims) {
2349     gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2350   }
2351   for (int64 gather_dim_to_input_dim : start_index_map) {
2352     gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2353   }
2354 
2355   gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2356   return gather_dim_numbers;
2357 }
2358 
ToProto() const2359 HloInstructionProto HloGatherInstruction::ToProto() const {
2360   HloInstructionProto proto = HloInstruction::ToProto();
2361   *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2362   for (int64 bound : gather_slice_sizes()) {
2363     proto.add_gather_slice_sizes(bound);
2364   }
2365   return proto;
2366 }
2367 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2368 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2369     const HloPrintOptions& options) const {
2370   return {GatherDimensionNumbersToString(),
2371           StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2372 }
2373 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2374 bool HloGatherInstruction::IdenticalSlowPath(
2375     const HloInstruction& other,
2376     const std::function<bool(const HloComputation*, const HloComputation*)>&
2377         eq_computations) const {
2378   const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2379   return protobuf_util::ProtobufEquals(
2380              gather_dimension_numbers(),
2381              casted_other.gather_dimension_numbers()) &&
2382          gather_slice_sizes() == casted_other.gather_slice_sizes();
2383 }
2384 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2385 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2386     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2387     HloCloneContext* context) const {
2388   CHECK_EQ(new_operands.size(), 2);
2389   return absl::make_unique<HloGatherInstruction>(
2390       shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2391       gather_slice_sizes());
2392 }
2393 
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers)2394 HloScatterInstruction::HloScatterInstruction(
2395     const Shape& shape, HloInstruction* operand,
2396     HloInstruction* scatter_indices, HloInstruction* updates,
2397     HloComputation* update_computation,
2398     const ScatterDimensionNumbers& scatter_dim_numbers)
2399     : HloInstruction(HloOpcode::kScatter, shape) {
2400   AppendOperand(operand);
2401   AppendOperand(scatter_indices);
2402   AppendOperand(updates);
2403   AppendComputation(update_computation);
2404   scatter_dimension_numbers_ =
2405       absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2406 }
2407 
ScatterDimensionNumbersToString() const2408 string HloScatterInstruction::ScatterDimensionNumbersToString() const {
2409   string update_window_dims = StrCat(
2410       "update_window_dims={",
2411       StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}");
2412   string inserted_window_dims = StrCat(
2413       "inserted_window_dims={",
2414       StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
2415   string scatter_dims_to_operand_dims = StrCat(
2416       "scatter_dims_to_operand_dims={",
2417       StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
2418       "}");
2419   string index_vector_dim = StrCat(
2420       "index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
2421 
2422   return StrJoin<std::initializer_list<string>>(
2423       {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
2424        index_vector_dim},
2425       ", ");
2426 }
2427 
2428 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64 index_vector_dim)2429 HloScatterInstruction::MakeScatterDimNumbers(
2430     absl::Span<const int64> update_window_dims,
2431     absl::Span<const int64> inserted_window_dims,
2432     absl::Span<const int64> scatter_dims_to_operand_dims,
2433     int64 index_vector_dim) {
2434   ScatterDimensionNumbers scatter_dim_numbers;
2435   for (int64 update_window_dim : update_window_dims) {
2436     scatter_dim_numbers.add_update_window_dims(update_window_dim);
2437   }
2438   for (int64 inserted_window_dim : inserted_window_dims) {
2439     scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
2440   }
2441   for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
2442     scatter_dim_numbers.add_scatter_dims_to_operand_dims(
2443         scatter_dim_to_operand_dim);
2444   }
2445   scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
2446   return scatter_dim_numbers;
2447 }
2448 
ToProto() const2449 HloInstructionProto HloScatterInstruction::ToProto() const {
2450   HloInstructionProto proto = HloInstruction::ToProto();
2451   *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
2452   return proto;
2453 }
2454 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2455 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
2456     const HloPrintOptions& options) const {
2457   return {ScatterDimensionNumbersToString()};
2458 }
2459 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2460 bool HloScatterInstruction::IdenticalSlowPath(
2461     const HloInstruction& other,
2462     const std::function<bool(const HloComputation*, const HloComputation*)>&
2463         eq_computations) const {
2464   const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
2465   return protobuf_util::ProtobufEquals(
2466              scatter_dimension_numbers(),
2467              casted_other.scatter_dimension_numbers()) &&
2468          eq_computations(to_apply(), casted_other.to_apply());
2469 }
2470 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2471 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
2472     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2473     HloCloneContext* context) const {
2474   CHECK_EQ(new_operands.size(), 3);
2475   return absl::make_unique<HloScatterInstruction>(
2476       shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
2477       scatter_dimension_numbers());
2478 }
2479 
HloIotaInstruction(const Shape & shape,int64 iota_dimension)2480 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
2481     : HloInstruction(HloOpcode::kIota, shape),
2482       iota_dimension_(iota_dimension) {}
2483 
ToProto() const2484 HloInstructionProto HloIotaInstruction::ToProto() const {
2485   HloInstructionProto proto = HloInstruction::ToProto();
2486   proto.add_dimensions(iota_dimension());
2487   return proto;
2488 }
2489 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2490 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
2491     const HloPrintOptions& options) const {
2492   return {StrCat("iota_dimension=", iota_dimension())};
2493 }
2494 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2495 bool HloIotaInstruction::IdenticalSlowPath(
2496     const HloInstruction& other,
2497     const std::function<bool(const HloComputation*, const HloComputation*)>&
2498         eq_computations) const {
2499   const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
2500   return iota_dimension() == casted_other.iota_dimension();
2501 }
2502 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2503 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
2504     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2505     HloCloneContext* context) const {
2506   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
2507 }
2508 
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2509 HloDotInstruction::HloDotInstruction(
2510     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2511     const DotDimensionNumbers& dimension_numbers,
2512     const PrecisionConfig& precision_config)
2513     : HloInstruction(HloOpcode::kDot, shape),
2514       dot_dimension_numbers_(dimension_numbers),
2515       precision_config_(precision_config) {
2516   AppendOperand(lhs);
2517   AppendOperand(rhs);
2518 }
2519 
ToProto() const2520 HloInstructionProto HloDotInstruction::ToProto() const {
2521   HloInstructionProto proto = HloInstruction::ToProto();
2522   *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
2523   *proto.mutable_precision_config() = precision_config_;
2524   return proto;
2525 }
2526 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2527 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
2528     const HloPrintOptions& options) const {
2529   std::vector<string> extra = {DotDimensionNumbersToString()};
2530 
2531   string precision_config_string = PrecisionConfigToString(precision_config_);
2532   if (!precision_config_string.empty()) {
2533     extra.push_back(precision_config_string);
2534   }
2535   return extra;
2536 }
2537 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2538 bool HloDotInstruction::IdenticalSlowPath(
2539     const HloInstruction& other,
2540     const std::function<bool(const HloComputation*, const HloComputation*)>&
2541         eq_computations) const {
2542   const auto& casted_other = static_cast<const HloDotInstruction&>(other);
2543   return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
2544                                        casted_other.dot_dimension_numbers()) &&
2545          protobuf_util::ProtobufEquals(precision_config(),
2546                                        casted_other.precision_config());
2547 }
2548 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2549 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
2550     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2551     HloCloneContext* context) const {
2552   CHECK_EQ(new_operands.size(), 2);
2553   return absl::make_unique<HloDotInstruction>(
2554       shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
2555       precision_config_);
2556 }
2557 
DotDimensionNumbersToString() const2558 string HloDotInstruction::DotDimensionNumbersToString() const {
2559   std::vector<string> result;
2560   const DotDimensionNumbers& dnums = dot_dimension_numbers_;
2561   if (!dnums.lhs_batch_dimensions().empty()) {
2562     result.push_back(StrCat("lhs_batch_dims={",
2563                             StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
2564   }
2565   result.push_back(StrCat("lhs_contracting_dims={",
2566                           StrJoin(dnums.lhs_contracting_dimensions(), ","),
2567                           "}"));
2568 
2569   if (!dnums.rhs_batch_dimensions().empty()) {
2570     result.push_back(StrCat("rhs_batch_dims={",
2571                             StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
2572   }
2573   result.push_back(StrCat("rhs_contracting_dims={",
2574                           StrJoin(dnums.rhs_contracting_dimensions(), ","),
2575                           "}"));
2576 
2577   return StrJoin(result, ", ");
2578 }
2579 
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)2580 HloDomainInstruction::HloDomainInstruction(
2581     const Shape& shape, HloInstruction* operand,
2582     std::unique_ptr<DomainMetadata> operand_side_metadata,
2583     std::unique_ptr<DomainMetadata> user_side_metadata)
2584     : HloInstruction(HloOpcode::kDomain, shape),
2585       operand_side_metadata_(std::move(operand_side_metadata)),
2586       user_side_metadata_(std::move(user_side_metadata)) {
2587   AppendOperand(operand);
2588 }
2589 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2590 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
2591     const HloPrintOptions& options) const {
2592   if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
2593     return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
2594                    "\", entry=", user_side_metadata_->ToString(),
2595                    ", exit=", operand_side_metadata_->ToString(), "}")};
2596   }
2597   return {};
2598 }
2599 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2600 bool HloDomainInstruction::IdenticalSlowPath(
2601     const HloInstruction& other,
2602     const std::function<bool(const HloComputation*, const HloComputation*)>&
2603         eq_computations) const {
2604   const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
2605   return operand_side_metadata().Matches(
2606              casted_other.operand_side_metadata()) &&
2607          user_side_metadata().Matches(casted_other.user_side_metadata());
2608 }
2609 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2610 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
2611     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2612     HloCloneContext* context) const {
2613   CHECK_EQ(new_operands.size(), 1);
2614   return absl::make_unique<HloDomainInstruction>(
2615       shape, new_operands[0], operand_side_metadata_->Clone(),
2616       user_side_metadata_->Clone());
2617 }
2618 
ToProto() const2619 HloInstructionProto HloDomainInstruction::ToProto() const {
2620   HloInstructionProto proto = HloInstruction::ToProto();
2621   auto operand_side_sharding =
2622       dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
2623   if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
2624     *proto.mutable_domain_entry_sharding() =
2625         operand_side_sharding->sharding()->ToProto();
2626   }
2627 
2628   auto user_side_sharding =
2629       dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
2630   if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
2631     *proto.mutable_domain_exit_sharding() =
2632         user_side_sharding->sharding()->ToProto();
2633   }
2634 
2635   return proto;
2636 }
2637 
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64 dimension)2638 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
2639     const Shape& shape, HloInstruction* operand, int64 dimension)
2640     : HloInstruction(HloOpcode::kGetDimensionSize, shape),
2641       dimension_(dimension) {
2642   AppendOperand(operand);
2643 }
2644 
ToProto() const2645 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
2646   HloInstructionProto proto = HloInstruction::ToProto();
2647   proto.add_dimensions(dimension());
2648   return proto;
2649 }
2650 
ExtraAttributesToStringImpl(const HloPrintOptions &) const2651 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
2652     const HloPrintOptions& /*options*/) const {
2653   return {StrCat("dimensions={", dimension(), "}")};
2654 }
2655 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2656 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
2657     const HloInstruction& other,
2658     const std::function<bool(const HloComputation*, const HloComputation*)>&
2659     /*eq_computations*/) const {
2660   const auto& casted_other =
2661       static_cast<const HloGetDimensionSizeInstruction&>(other);
2662   return dimension() == casted_other.dimension();
2663 }
2664 
2665 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2666 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
2667     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2668     HloCloneContext* /*context*/) const {
2669   if (new_operands.size() != 1) {
2670     LOG(FATAL) << "expects 1 operand";
2671   }
2672   return absl::make_unique<HloGetDimensionSizeInstruction>(
2673       shape, new_operands[0], dimension());
2674 }
2675 
2676 }  // namespace xla
2677