• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17 
18 #include <deque>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/primitive_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 
38 namespace xla {
39 namespace {
40 
41 using absl::CEscape;
42 using absl::StrAppend;
43 using absl::StrCat;
44 using absl::StrJoin;
45 
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)46 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
47                                        const HloInstruction* operand) {
48   const auto operand_indices = instruction->OperandIndices(operand);
49   return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
50     return instruction->IsElementwiseOnOperand(operand_index);
51   });
52 }
53 
PrecisionConfigToString(const PrecisionConfig & precision_config)54 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
55   if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
56         return static_cast<PrecisionConfig::Precision>(precision) ==
57                PrecisionConfig::DEFAULT;
58       })) {
59     return "";
60   }
61 
62   return StrCat(
63       "operand_precision={",
64       StrJoin(
65           precision_config.operand_precision(), ",",
66           [](string* out, int32 precision) {
67             CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
68             StrAppend(out,
69                       PrecisionToString(
70                           static_cast<PrecisionConfig::Precision>(precision)));
71           }),
72       "}");
73 }
74 }  // namespace
75 
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64 feature_index)76 HloBatchNormInstruction::HloBatchNormInstruction(
77     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
78     HloInstruction* scale, float epsilon, int64 feature_index)
79     : HloInstruction(opcode, shape),
80       epsilon_(epsilon),
81       feature_index_(feature_index) {
82   AppendOperand(operand);
83   AppendOperand(scale);
84 }
85 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const86 bool HloBatchNormInstruction::IdenticalSlowPath(
87     const HloInstruction& other,
88     const std::function<bool(const HloComputation*, const HloComputation*)>&
89         eq_computations) const {
90   const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
91   return feature_index() == casted_other.feature_index() &&
92          epsilon() == casted_other.epsilon();
93 }
94 
ToProto() const95 HloInstructionProto HloBatchNormInstruction::ToProto() const {
96   HloInstructionProto proto = HloInstruction::ToProto();
97   proto.set_epsilon(epsilon_);
98   proto.set_feature_index(feature_index_);
99   return proto;
100 }
101 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const102 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
103     const HloPrintOptions& options) const {
104   return {StrCat("epsilon=", epsilon()),
105           StrCat("feature_index=", feature_index())};
106 }
107 
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)108 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
109     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
110     HloInstruction* offset, float epsilon, int64 feature_index)
111     : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
112                               scale, epsilon, feature_index) {
113   AppendOperand(offset);
114 }
115 
116 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const117 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
118     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
119     HloCloneContext* context) const {
120   CHECK_EQ(new_operands.size(), 3);
121   return absl::make_unique<HloBatchNormTrainingInstruction>(
122       shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
123       feature_index());
124 }
125 
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)126 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
127     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
128     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
129     float epsilon, int64 feature_index)
130     : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
131                               scale, epsilon, feature_index) {
132   AppendOperand(offset);
133   AppendOperand(mean);
134   AppendOperand(variance);
135 }
136 
137 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const138 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
139     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
140     HloCloneContext* context) const {
141   CHECK_EQ(new_operands.size(), 5);
142   return absl::make_unique<HloBatchNormInferenceInstruction>(
143       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
144       new_operands[4], epsilon(), feature_index());
145 }
146 
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)147 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
148     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
149     HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
150     float epsilon, int64 feature_index)
151     : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
152                               epsilon, feature_index) {
153   AppendOperand(mean);
154   AppendOperand(variance);
155   AppendOperand(grad_output);
156 }
157 
158 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const159 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
160     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
161     HloCloneContext* context) const {
162   CHECK_EQ(new_operands.size(), 5);
163   return absl::make_unique<HloBatchNormGradInstruction>(
164       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
165       new_operands[4], epsilon(), feature_index());
166 }
167 
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)168 HloFftInstruction::HloFftInstruction(const Shape& shape,
169                                      HloInstruction* operand, FftType fft_type,
170                                      absl::Span<const int64> fft_length)
171     : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
172   fft_length_.assign(fft_length.begin(), fft_length.end());
173   AppendOperand(operand);
174 }
175 
ToProto() const176 HloInstructionProto HloFftInstruction::ToProto() const {
177   HloInstructionProto proto = HloInstruction::ToProto();
178   proto.set_fft_type(fft_type_);
179   for (int64 fft_len : fft_length_) {
180     proto.add_fft_length(fft_len);
181   }
182   return proto;
183 }
184 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const185 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
186     const HloPrintOptions& options) const {
187   return {StrCat("fft_type=", FftType_Name(fft_type())),
188           StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
189 }
190 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const191 bool HloFftInstruction::IdenticalSlowPath(
192     const HloInstruction& other,
193     const std::function<bool(const HloComputation*, const HloComputation*)>&
194         eq_computations) const {
195   const auto& casted_other = static_cast<const HloFftInstruction&>(other);
196   return fft_type() == casted_other.fft_type() &&
197          fft_length() == casted_other.fft_length();
198 }
199 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const200 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
201     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
202     HloCloneContext* context) const {
203   CHECK_EQ(new_operands.size(), 1);
204   return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
205                                               fft_length_);
206 }
207 
HloCopyStartInstruction(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)208 HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
209                                                  HloInstruction* operand,
210                                                  bool is_cross_program_prefetch)
211     : HloInstruction(HloOpcode::kCopyStart, shape),
212       is_cross_program_prefetch_(is_cross_program_prefetch) {
213   AppendOperand(operand);
214 }
215 
ToProto() const216 HloInstructionProto HloCopyStartInstruction::ToProto() const {
217   HloInstructionProto proto = HloInstruction::ToProto();
218   proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
219   return proto;
220 }
221 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const222 std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
223     const HloPrintOptions& options) const {
224   std::vector<string> result;
225   if (is_cross_program_prefetch()) {
226     result.push_back("is_cross_program_prefetch=true");
227   }
228   return result;
229 }
230 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const231 bool HloCopyStartInstruction::IdenticalSlowPath(
232     const HloInstruction& other,
233     const std::function<bool(const HloComputation*, const HloComputation*)>&
234         eq_computations) const {
235   const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
236   return is_cross_program_prefetch() ==
237          casted_other.is_cross_program_prefetch();
238 }
239 
240 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const241 HloCopyStartInstruction::CloneWithNewOperandsImpl(
242     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
243     HloCloneContext* context) const {
244   CHECK_EQ(new_operands.size(), 1);
245   return absl::make_unique<HloCopyStartInstruction>(
246       shape, new_operands[0], is_cross_program_prefetch());
247 }
248 
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)249 HloCompareInstruction::HloCompareInstruction(
250     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
251     ComparisonDirection direction, absl::optional<Comparison::Type> type)
252     : HloInstruction(HloOpcode::kCompare, shape),
253       compare_(direction, type ? (*type)
254                                : Comparison::DefaultComparisonType(
255                                      lhs->shape().element_type())) {
256   AppendOperand(lhs);
257   AppendOperand(rhs);
258 }
259 
ToProto() const260 HloInstructionProto HloCompareInstruction::ToProto() const {
261   HloInstructionProto proto = HloInstruction::ToProto();
262   proto.set_comparison_direction(
263       ComparisonDirectionToString(compare_.GetDirection()));
264   proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
265   return proto;
266 }
267 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const268 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
269     const HloPrintOptions& options) const {
270   std::vector<string> result;
271   result.push_back(
272       StrCat("direction=", ComparisonDirectionToString(direction())));
273   if (compare_.GetType() !=
274       Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
275     result.push_back(
276         StrCat("type=", ComparisonTypeToString(compare_.GetType())));
277   }
278   return result;
279 }
280 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const281 bool HloCompareInstruction::IdenticalSlowPath(
282     const HloInstruction& other,
283     const std::function<bool(const HloComputation*, const HloComputation*)>&
284         eq_computations) const {
285   const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
286   return direction() == casted_other.direction();
287 }
288 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const289 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
290     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
291     HloCloneContext* context) const {
292   CHECK_EQ(new_operands.size(), 2);
293   return absl::make_unique<HloCompareInstruction>(
294       shape, new_operands[0], new_operands[1], direction(), type());
295 }
296 
297 namespace {
298 
299 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
300 // of "key=value" attribute strings generically, using protocol buffer
301 // reflection.
302 //
303 // Currently implements a small subset of cases; feel free to add more as
304 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)305 std::vector<string> AttributeProtoToStringVector(
306     const tensorflow::protobuf::Message& message) {
307   const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
308   std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
309   reflection->ListFields(message, &fields);
310 
311   std::vector<string> output;
312   for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
313     string s = absl::StrCat(field->name(), "=");
314     CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
315     switch (field->type()) {
316       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
317         bool val = reflection->GetBool(message, field);
318         absl::StrAppend(&s, val ? "true" : "false");
319         break;
320       }
321       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
322         const tensorflow::protobuf::EnumValueDescriptor* evd =
323             reflection->GetEnum(message, field);
324         absl::StrAppend(&s, evd->name());
325         break;
326       }
327       default:
328         LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
329     }
330     output.push_back(std::move(s));
331   }
332   return output;
333 }
334 
335 }  // namespace
336 
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)337 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
338     const Shape& shape, HloInstruction* a, HloInstruction* b,
339     const TriangularSolveOptions& options)
340     : HloInstruction(HloOpcode::kTriangularSolve, shape),
341       triangular_solve_options_(options) {
342   AppendOperand(a);
343   AppendOperand(b);
344 }
345 
ToProto() const346 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
347   HloInstructionProto proto = HloInstruction::ToProto();
348   *proto.mutable_triangular_solve_options() = triangular_solve_options_;
349   return proto;
350 }
351 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const352 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
353     const HloPrintOptions& options) const {
354   return AttributeProtoToStringVector(triangular_solve_options_);
355 }
356 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const357 bool HloTriangularSolveInstruction::IdenticalSlowPath(
358     const HloInstruction& other,
359     const std::function<bool(const HloComputation*, const HloComputation*)>&
360         eq_computations) const {
361   const auto& casted_other =
362       static_cast<const HloTriangularSolveInstruction&>(other);
363   const auto& options = triangular_solve_options();
364   const auto& other_options = casted_other.triangular_solve_options();
365 
366   return options.left_side() == other_options.left_side() &&
367          options.lower() == other_options.lower() &&
368          options.unit_diagonal() == other_options.unit_diagonal() &&
369          options.transpose_a() == other_options.transpose_a();
370 }
371 
372 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const373 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
374     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
375     HloCloneContext* context) const {
376   CHECK_EQ(new_operands.size(), 2);
377   return absl::make_unique<HloTriangularSolveInstruction>(
378       shape, new_operands[0], new_operands[1], triangular_solve_options());
379 }
380 
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)381 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
382                                                HloInstruction* a,
383                                                const CholeskyOptions& options)
384     : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
385   AppendOperand(a);
386 }
387 
ToProto() const388 HloInstructionProto HloCholeskyInstruction::ToProto() const {
389   HloInstructionProto proto = HloInstruction::ToProto();
390   *proto.mutable_cholesky_options() = cholesky_options_;
391   return proto;
392 }
393 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const394 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
395     const HloPrintOptions& options) const {
396   return AttributeProtoToStringVector(cholesky_options_);
397 }
398 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const399 bool HloCholeskyInstruction::IdenticalSlowPath(
400     const HloInstruction& other,
401     const std::function<bool(const HloComputation*, const HloComputation*)>&
402         eq_computations) const {
403   const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
404   const auto& options = cholesky_options();
405   const auto& other_options = casted_other.cholesky_options();
406 
407   return options.lower() == other_options.lower();
408 }
409 
410 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const411 HloCholeskyInstruction::CloneWithNewOperandsImpl(
412     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
413     HloCloneContext* context) const {
414   CHECK_EQ(new_operands.size(), 1);
415   return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
416                                                    cholesky_options());
417 }
418 
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const absl::optional<int64> & channel_id)419 HloChannelInstruction::HloChannelInstruction(
420     HloOpcode opcode, const Shape& shape,
421     const absl::optional<int64>& channel_id)
422     : HloInstruction(opcode, shape), channel_id_(channel_id) {}
423 
set_channel_id(const absl::optional<int64> & channel_id)424 void HloChannelInstruction::set_channel_id(
425     const absl::optional<int64>& channel_id) {
426   channel_id_ = channel_id;
427 }
428 
ToProto() const429 HloInstructionProto HloChannelInstruction::ToProto() const {
430   HloInstructionProto proto = HloInstruction::ToProto();
431   if (channel_id_) {
432     CHECK_GT(channel_id_.value(), 0)
433         << "Non-positive channel id is equivalent to no channel id";
434     proto.set_channel_id(*channel_id_);
435   }
436   return proto;
437 }
438 
ExtraAttributesToStringImpl(const HloPrintOptions &) const439 std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
440     const HloPrintOptions& /*options*/) const {
441   std::vector<string> result;
442   if (channel_id_) {
443     result.push_back(StrCat("channel_id=", *channel_id_));
444   }
445   return result;
446 }
447 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const448 bool HloChannelInstruction::IdenticalSlowPath(
449     const HloInstruction& other,
450     const std::function<bool(const HloComputation*, const HloComputation*)>&
451         eq_computations) const {
452   if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) {
453     return false;
454   }
455   const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
456   return channel_id() == casted_other.channel_id();
457 }
458 
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64 channel_id,bool is_host_transfer)459 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
460                                                const Shape& shape,
461                                                int64 channel_id,
462                                                bool is_host_transfer)
463     : HloChannelInstruction(opcode, shape, channel_id),
464       is_host_transfer_(is_host_transfer) {}
465 
ToProto() const466 HloInstructionProto HloSendRecvInstruction::ToProto() const {
467   HloInstructionProto proto = HloChannelInstruction::ToProto();
468   proto.set_is_host_transfer(is_host_transfer_);
469   return proto;
470 }
471 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const472 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
473     const HloPrintOptions& options) const {
474   std::vector<string> attrs =
475       HloChannelInstruction::ExtraAttributesToStringImpl(options);
476   if (is_host_transfer()) {
477     attrs.push_back("is_host_transfer=true");
478   }
479   return attrs;
480 }
481 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const482 bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
483     const HloInstruction& other,
484     const std::function<bool(const HloComputation*, const HloComputation*)>&
485         eq_computations) const {
486   // Not yet supported.
487   return false;
488 }
489 
490 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)491 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
492                                        HloInstruction* token, int64 channel_id,
493                                        bool is_host_transfer)
494     : HloSendRecvInstruction(
495           HloOpcode::kSend,
496           ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
497                                      ShapeUtil::MakeShape(U32, {}),
498                                      ShapeUtil::MakeTokenShape()}),
499           channel_id, is_host_transfer) {
500   AppendOperand(operand);
501   AppendOperand(token);
502 }
503 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const504 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
505     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
506     HloCloneContext* context) const {
507   CHECK_EQ(new_operands.size(), 2);
508   return absl::make_unique<HloSendInstruction>(
509       new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
510 }
511 
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)512 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
513                                                bool is_host_transfer)
514     : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
515                              CHECK_NOTNULL(operand)->channel_id().value(),
516                              is_host_transfer) {
517   AppendOperand(operand);
518 }
519 
520 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const521 HloSendDoneInstruction::CloneWithNewOperandsImpl(
522     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
523     HloCloneContext* context) const {
524   CHECK_EQ(new_operands.size(), 1);
525   return absl::make_unique<HloSendDoneInstruction>(
526       Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
527 }
528 
529 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)530 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
531                                        HloInstruction* token, int64 channel_id,
532                                        bool is_host_transfer)
533     : HloSendRecvInstruction(
534           HloOpcode::kRecv,
535           ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
536                                      ShapeUtil::MakeTokenShape()}),
537           channel_id, is_host_transfer) {
538   AppendOperand(token);
539 }
540 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const541 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
542     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
543     HloCloneContext* context) const {
544   CHECK_EQ(new_operands.size(), 1);
545   return absl::make_unique<HloRecvInstruction>(
546       ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
547       is_host_transfer());
548 }
549 
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)550 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
551                                                bool is_host_transfer)
552     : HloSendRecvInstruction(
553           HloOpcode::kRecvDone,
554           ShapeUtil::MakeTupleShape(
555               {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
556                ShapeUtil::MakeTokenShape()}),
557           CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
558   AppendOperand(operand);
559 }
560 
561 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const562 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
563     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
564     HloCloneContext* context) const {
565   CHECK_EQ(new_operands.size(), 1);
566   return absl::make_unique<HloRecvDoneInstruction>(
567       Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
568 }
569 
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)570 HloCollectiveInstruction::HloCollectiveInstruction(
571     HloOpcode opcode, const Shape& shape,
572     absl::Span<HloInstruction* const> operands,
573     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
574     const absl::optional<int64>& channel_id)
575     : HloChannelInstruction(opcode, shape, channel_id),
576       replica_groups_(replica_groups),
577       constrain_layout_(constrain_layout) {
578   for (auto operand : operands) {
579     AppendOperand(operand);
580   }
581 }
582 
ToProto() const583 HloInstructionProto HloCollectiveInstruction::ToProto() const {
584   HloInstructionProto proto = HloChannelInstruction::ToProto();
585   *proto.mutable_replica_groups() = {replica_groups_.begin(),
586                                      replica_groups_.end()};
587   proto.set_constrain_layout(constrain_layout_);
588   return proto;
589 }
590 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const591 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
592     const HloPrintOptions& options) const {
593   std::vector<string> result =
594       HloChannelInstruction::ExtraAttributesToStringImpl(options);
595   result.push_back(
596       StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
597   if (constrain_layout_) {
598     result.push_back("constrain_layout=true");
599   }
600   return result;
601 }
602 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const603 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
604     const HloInstruction& other,
605     const std::function<bool(const HloComputation*, const HloComputation*)>&
606         eq_computations) const {
607   const auto& casted_other =
608       static_cast<const HloCollectiveInstruction&>(other);
609   return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
610              other, eq_computations) &&
611          constrain_layout() == casted_other.constrain_layout() &&
612          absl::c_equal(replica_groups(), casted_other.replica_groups(),
613                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
614                          return absl::c_equal(a.replica_ids(), b.replica_ids());
615                        });
616 }
617 
HloAllGatherInstruction(const Shape & shape,HloInstruction * operand,int64 all_gather_dimension,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)618 HloAllGatherInstruction::HloAllGatherInstruction(
619     const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
620     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
621     const absl::optional<int64>& channel_id, bool use_global_device_ids)
622     : HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand},
623                                replica_groups, constrain_layout, channel_id),
624       all_gather_dimension_(all_gather_dimension),
625       use_global_device_ids_(use_global_device_ids) {}
626 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const627 std::vector<string> HloAllGatherInstruction::ExtraAttributesToStringImpl(
628     const HloPrintOptions& options) const {
629   std::vector<string> result =
630       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
631   result.push_back(StrCat("dimensions={", all_gather_dimension_, "}"));
632   if (use_global_device_ids_) {
633     result.push_back("use_global_device_ids=true");
634   }
635   return result;
636 }
637 
638 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const639 HloAllGatherInstruction::CloneWithNewOperandsImpl(
640     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
641     HloCloneContext* /*context*/) const {
642   return absl::make_unique<HloAllGatherInstruction>(
643       shape, new_operands[0], all_gather_dimension(), replica_groups(),
644       constrain_layout(), channel_id(), use_global_device_ids());
645 }
646 
ToProto() const647 HloInstructionProto HloAllGatherInstruction::ToProto() const {
648   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
649   proto.add_dimensions(all_gather_dimension_);
650   proto.set_use_global_device_ids(use_global_device_ids_);
651   return proto;
652 }
653 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const654 bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues(
655     const HloInstruction& other,
656     const std::function<bool(const HloComputation*, const HloComputation*)>&
657         eq_computations) const {
658   const auto& casted_other = static_cast<const HloAllGatherInstruction&>(other);
659   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
660              other, eq_computations) &&
661          all_gather_dimension_ == casted_other.all_gather_dimension() &&
662          use_global_device_ids() == casted_other.use_global_device_ids();
663 }
664 
HloAllReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)665 HloAllReduceInstruction::HloAllReduceInstruction(
666     const Shape& shape, absl::Span<HloInstruction* const> operands,
667     HloComputation* reduce_computation,
668     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
669     const absl::optional<int64>& channel_id, bool use_global_device_ids)
670     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
671                                replica_groups, constrain_layout, channel_id),
672       use_global_device_ids_(use_global_device_ids) {
673   AppendComputation(reduce_computation);
674 }
675 
IsNoop() const676 bool HloAllReduceInstruction::IsNoop() const {
677   for (const auto& replica_group : replica_groups()) {
678     if (replica_group.replica_ids().size() != 1) {
679       return false;
680     }
681   }
682   return !channel_id();
683 }
684 
ToProto() const685 HloInstructionProto HloAllReduceInstruction::ToProto() const {
686   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
687   proto.set_use_global_device_ids(use_global_device_ids_);
688   return proto;
689 }
690 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const691 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
692     const HloPrintOptions& options) const {
693   std::vector<string> result =
694       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
695   if (use_global_device_ids_) {
696     result.push_back("use_global_device_ids=true");
697   }
698   return result;
699 }
700 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const701 bool HloAllReduceInstruction::IdenticalSlowPathIgnoringChannelIdValues(
702     const HloInstruction& other,
703     const std::function<bool(const HloComputation*, const HloComputation*)>&
704         eq_computations) const {
705   const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
706   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
707              other, eq_computations) &&
708          constrain_layout() == casted_other.constrain_layout() &&
709          use_global_device_ids() == casted_other.use_global_device_ids() &&
710          eq_computations(to_apply(), casted_other.to_apply());
711 }
712 
713 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const714 HloAllReduceInstruction::CloneWithNewOperandsImpl(
715     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
716     HloCloneContext* /*context*/) const {
717   return absl::make_unique<HloAllReduceInstruction>(
718       shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
719       channel_id(), use_global_device_ids());
720 }
721 
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)722 HloAllToAllInstruction::HloAllToAllInstruction(
723     const Shape& shape, absl::Span<HloInstruction* const> operands,
724     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
725     const absl::optional<int64>& channel_id,
726     const absl::optional<int64>& split_dimension)
727     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
728                                replica_groups, constrain_layout, channel_id),
729       split_dimension_(split_dimension) {}
730 
731 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const732 HloAllToAllInstruction::CloneWithNewOperandsImpl(
733     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
734     HloCloneContext* /*context*/) const {
735   return absl::make_unique<HloAllToAllInstruction>(
736       shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
737       split_dimension());
738 }
739 
ToProto() const740 HloInstructionProto HloAllToAllInstruction::ToProto() const {
741   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
742   if (split_dimension_) {
743     proto.add_dimensions(*split_dimension_);
744   }
745   return proto;
746 }
747 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const748 std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
749     const HloPrintOptions& options) const {
750   std::vector<string> result =
751       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
752   if (split_dimension_) {
753     result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
754   }
755   return result;
756 }
757 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const758 bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
759     const HloInstruction& other,
760     const std::function<bool(const HloComputation*, const HloComputation*)>&
761         eq_computations) const {
762   const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
763   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
764              other, eq_computations) &&
765          split_dimension_ == casted_other.split_dimension();
766 }
767 
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)768 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
769     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
770     const std::vector<std::pair<int64, int64>>& source_target_pairs,
771     const absl::optional<int64>& channel_id)
772     : HloChannelInstruction(opcode, shape, channel_id),
773       source_target_pairs_(source_target_pairs) {
774   AppendOperand(operand);
775 }
776 
ToProto() const777 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
778   HloInstructionProto proto = HloChannelInstruction::ToProto();
779   for (const auto& pair : source_target_pairs()) {
780     auto* proto_pair = proto.add_source_target_pairs();
781     proto_pair->set_source(pair.first);
782     proto_pair->set_target(pair.second);
783   }
784   return proto;
785 }
786 
787 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const788 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
789     const HloPrintOptions& options) const {
790   std::vector<string> result =
791       HloChannelInstruction::ExtraAttributesToStringImpl(options);
792   std::vector<string> strs;
793   for (const auto& pair : source_target_pairs()) {
794     strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
795   }
796   result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
797   return result;
798 }
799 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const800 bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues(
801     const HloInstruction& other,
802     const std::function<bool(const HloComputation*, const HloComputation*)>&
803         eq_computations) const {
804   if (opcode() != other.opcode()) {
805     return false;
806   }
807   const auto& casted_other =
808       static_cast<const HloCollectivePermuteInstruction&>(other);
809   return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
810              other, eq_computations) &&
811          absl::c_equal(source_target_pairs(),
812                        casted_other.source_target_pairs(),
813                        [](const std::pair<int64, int64>& a,
814                           const std::pair<int64, int64>& b) { return a == b; });
815 }
816 
817 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const818 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
819     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
820     HloCloneContext* /*context*/) const {
821   return absl::make_unique<HloCollectivePermuteInstruction>(
822       opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
823 }
824 
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)825 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
826                                              HloInstruction* operand,
827                                              absl::Span<const int64> dimensions)
828     : HloInstruction(HloOpcode::kReverse, shape),
829       dimensions_(dimensions.begin(), dimensions.end()) {
830   AppendOperand(operand);
831 }
832 
ToProto() const833 HloInstructionProto HloReverseInstruction::ToProto() const {
834   HloInstructionProto proto = HloInstruction::ToProto();
835   for (int64 dimension : dimensions_) {
836     proto.add_dimensions(dimension);
837   }
838   return proto;
839 }
840 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const841 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
842     const HloPrintOptions& options) const {
843   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
844 }
845 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const846 bool HloReverseInstruction::IdenticalSlowPath(
847     const HloInstruction& other,
848     const std::function<bool(const HloComputation*, const HloComputation*)>&
849         eq_computations) const {
850   const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
851   return dimensions() == casted_other.dimensions();
852 }
853 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const854 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
855     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
856     HloCloneContext* context) const {
857   CHECK_EQ(new_operands.size(), 1);
858   return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
859                                                   dimensions());
860 }
861 
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)862 HloConcatenateInstruction::HloConcatenateInstruction(
863     const Shape& shape, absl::Span<HloInstruction* const> operands,
864     int64 dimension)
865     : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
866   for (auto operand : operands) {
867     AppendOperand(operand);
868   }
869 }
870 
ToProto() const871 HloInstructionProto HloConcatenateInstruction::ToProto() const {
872   HloInstructionProto proto = HloInstruction::ToProto();
873   for (int64 dimension : dimensions_) {
874     proto.add_dimensions(dimension);
875   }
876   return proto;
877 }
878 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const879 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
880     const HloPrintOptions& options) const {
881   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
882 }
883 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const884 bool HloConcatenateInstruction::IdenticalSlowPath(
885     const HloInstruction& other,
886     const std::function<bool(const HloComputation*, const HloComputation*)>&
887         eq_computations) const {
888   const auto& casted_other =
889       static_cast<const HloConcatenateInstruction&>(other);
890   return dimensions() == casted_other.dimensions();
891 }
892 
893 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const894 HloConcatenateInstruction::CloneWithNewOperandsImpl(
895     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
896     HloCloneContext* context) const {
897   return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
898                                                       dimensions(0));
899 }
900 
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)901 HloReduceInstruction::HloReduceInstruction(
902     const Shape& shape, absl::Span<HloInstruction* const> args,
903     absl::Span<const int64> dimensions_to_reduce,
904     HloComputation* reduce_computation)
905     : HloInstruction(HloOpcode::kReduce, shape),
906       dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
907   for (HloInstruction* arg : args) {
908     AppendOperand(arg);
909   }
910   AppendComputation(reduce_computation);
911 }
912 
ToProto() const913 HloInstructionProto HloReduceInstruction::ToProto() const {
914   HloInstructionProto proto = HloInstruction::ToProto();
915   for (int64 dimension : dimensions_) {
916     proto.add_dimensions(dimension);
917   }
918   return proto;
919 }
920 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const921 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
922     const HloPrintOptions& options) const {
923   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
924 }
925 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const926 bool HloReduceInstruction::IdenticalSlowPath(
927     const HloInstruction& other,
928     const std::function<bool(const HloComputation*, const HloComputation*)>&
929         eq_computations) const {
930   const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
931   // Reduction results are determined by the reduction dimension and the
932   // reduction computation.
933   return dimensions() == casted_other.dimensions() &&
934          eq_computations(to_apply(), casted_other.to_apply());
935 }
936 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const937 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
938     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
939     HloCloneContext* context) const {
940   CHECK_EQ(new_operands.size() % 2, 0);
941   return absl::make_unique<HloReduceInstruction>(shape, new_operands,
942                                                  dimensions(), to_apply());
943 }
944 
HloSortInstruction(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)945 HloSortInstruction::HloSortInstruction(
946     const Shape& shape, int64 dimension,
947     absl::Span<HloInstruction* const> operands, HloComputation* compare,
948     bool is_stable)
949     : HloInstruction(HloOpcode::kSort, shape),
950       dimensions_({dimension}),
951       is_stable_(is_stable) {
952   for (auto* value : operands) {
953     AppendOperand(value);
954   }
955   AppendComputation(compare);
956 }
957 
ToProto() const958 HloInstructionProto HloSortInstruction::ToProto() const {
959   HloInstructionProto proto = HloInstruction::ToProto();
960   for (int64 dimension : dimensions_) {
961     proto.add_dimensions(dimension);
962   }
963   proto.set_is_stable(is_stable());
964   return proto;
965 }
966 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const967 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
968     const HloPrintOptions& options) const {
969   std::vector<string> attrs;
970   attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
971   if (is_stable()) {
972     attrs.push_back("is_stable=true");
973   }
974   return attrs;
975 }
976 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const977 bool HloSortInstruction::IdenticalSlowPath(
978     const HloInstruction& other,
979     const std::function<bool(const HloComputation*, const HloComputation*)>&
980         eq_computations) const {
981   const auto& casted_other = static_cast<const HloSortInstruction&>(other);
982   if (dimensions() != casted_other.dimensions()) {
983     return false;
984   }
985   if (is_stable() != casted_other.is_stable()) {
986     return false;
987   }
988   return eq_computations(to_apply(), other.to_apply());
989 }
990 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const991 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
992     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
993     HloCloneContext* context) const {
994   return absl::make_unique<HloSortInstruction>(
995       shape, dimensions(0), new_operands, to_apply(), is_stable());
996 }
997 
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)998 HloTransposeInstruction::HloTransposeInstruction(
999     const Shape& shape, HloInstruction* operand,
1000     absl::Span<const int64> dimensions)
1001     : HloInstruction(HloOpcode::kTranspose, shape),
1002       dimensions_(dimensions.begin(), dimensions.end()) {
1003   AppendOperand(operand);
1004 }
1005 
IsRank2Transpose() const1006 bool HloTransposeInstruction::IsRank2Transpose() const {
1007   return dimensions() == std::vector<int64>({1, 0}) &&
1008          shape().dimensions_size() == 2 &&
1009          std::equal(shape().dimensions().begin(), shape().dimensions().end(),
1010                     operand(0)->shape().dimensions().rbegin());
1011 }
1012 
ToProto() const1013 HloInstructionProto HloTransposeInstruction::ToProto() const {
1014   HloInstructionProto proto = HloInstruction::ToProto();
1015   for (int64 dimension : dimensions_) {
1016     proto.add_dimensions(dimension);
1017   }
1018   return proto;
1019 }
1020 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1021 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
1022     const HloPrintOptions& options) const {
1023   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1024 }
1025 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1026 bool HloTransposeInstruction::IdenticalSlowPath(
1027     const HloInstruction& other,
1028     const std::function<bool(const HloComputation*, const HloComputation*)>&
1029         eq_computations) const {
1030   const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
1031   return dimensions() == casted_other.dimensions();
1032 }
1033 
1034 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1035 HloTransposeInstruction::CloneWithNewOperandsImpl(
1036     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1037     HloCloneContext* context) const {
1038   CHECK_EQ(new_operands.size(), 1);
1039   return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
1040                                                     dimensions());
1041 }
1042 
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)1043 HloBroadcastInstruction::HloBroadcastInstruction(
1044     const Shape& shape, HloInstruction* operand,
1045     absl::Span<const int64> broadcast_dimension)
1046     : HloInstruction(HloOpcode::kBroadcast, shape),
1047       dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
1048   AppendOperand(operand);
1049 }
1050 
ToProto() const1051 HloInstructionProto HloBroadcastInstruction::ToProto() const {
1052   HloInstructionProto proto = HloInstruction::ToProto();
1053   for (int64 dimension : dimensions_) {
1054     proto.add_dimensions(dimension);
1055   }
1056   return proto;
1057 }
1058 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1059 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
1060     const HloPrintOptions& options) const {
1061   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1062 }
1063 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1064 bool HloBroadcastInstruction::IdenticalSlowPath(
1065     const HloInstruction& other,
1066     const std::function<bool(const HloComputation*, const HloComputation*)>&
1067         eq_computations) const {
1068   const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
1069   return dimensions() == casted_other.dimensions();
1070 }
1071 
1072 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1073 HloBroadcastInstruction::CloneWithNewOperandsImpl(
1074     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1075     HloCloneContext* context) const {
1076   CHECK_EQ(new_operands.size(), 1);
1077   return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
1078                                                     dimensions());
1079 }
1080 
HloDynamicReshapeInstruction(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1081 HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
1082     const Shape& shape, HloInstruction* data_operand,
1083     absl::Span<HloInstruction* const> dim_sizes)
1084     : HloInstruction(HloOpcode::kDynamicReshape, shape) {
1085   AppendOperand(data_operand);
1086   for (auto operand : dim_sizes) {
1087     AppendOperand(operand);
1088   }
1089 }
1090 
1091 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1092 HloDynamicReshapeInstruction::CloneWithNewOperandsImpl(
1093     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1094     HloCloneContext* context) const {
1095   CHECK_GE(new_operands.size(), 1);
1096   return absl::make_unique<HloDynamicReshapeInstruction>(
1097       shape, new_operands[0], new_operands.subspan(1));
1098 }
1099 
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64 inferred_dimension)1100 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
1101                                              HloInstruction* operand,
1102                                              int64 inferred_dimension)
1103     : HloInstruction(HloOpcode::kReshape, shape),
1104       inferred_dimension_(inferred_dimension) {
1105   AppendOperand(operand);
1106 }
1107 
ToProto() const1108 HloInstructionProto HloReshapeInstruction::ToProto() const {
1109   HloInstructionProto proto = HloInstruction::ToProto();
1110   if (inferred_dimension_ != -1) {
1111     proto.add_dimensions(inferred_dimension_);
1112   }
1113   return proto;
1114 }
1115 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1116 std::vector<string> HloReshapeInstruction::ExtraAttributesToStringImpl(
1117     const HloPrintOptions& options) const {
1118   if (inferred_dimension() == -1) {
1119     return {};
1120   }
1121   return {StrCat("inferred_dimension=", inferred_dimension())};
1122 }
1123 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1124 bool HloReshapeInstruction::IdenticalSlowPath(
1125     const HloInstruction& other,
1126     const std::function<bool(const HloComputation*, const HloComputation*)>&
1127         eq_computations) const {
1128   const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
1129   return inferred_dimension() == casted_other.inferred_dimension();
1130 }
1131 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1132 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
1133     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1134     HloCloneContext* context) const {
1135   CHECK_EQ(new_operands.size(), 1);
1136   return absl::make_unique<HloReshapeInstruction>(shape, new_operands[0],
1137                                                   inferred_dimension());
1138 }
1139 
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1140 HloMapInstruction::HloMapInstruction(const Shape& shape,
1141                                      absl::Span<HloInstruction* const> operands,
1142                                      HloComputation* map_computation)
1143     : HloInstruction(HloOpcode::kMap, shape) {
1144   for (auto operand : operands) {
1145     AppendOperand(operand);
1146   }
1147   AppendComputation(map_computation);
1148   // TODO(b/65689298) Remove code below once Map is generalized to accept
1149   // arbitrary map dimensions.
1150   dimensions_.resize(shape.rank());
1151   std::iota(dimensions_.begin(), dimensions_.end(), 0);
1152 }
1153 
ToProto() const1154 HloInstructionProto HloMapInstruction::ToProto() const {
1155   HloInstructionProto proto = HloInstruction::ToProto();
1156   for (int64 dimension : dimensions_) {
1157     proto.add_dimensions(dimension);
1158   }
1159   return proto;
1160 }
1161 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1162 bool HloMapInstruction::IsElementwiseImpl(
1163     const absl::optional<int64>& operand_idx) const {
1164   if (!dimensions().empty()) {
1165     // Check that the map is executed in elementwise compatible dimensions.
1166     if (dimensions().size() != shape().dimensions_size()) {
1167       return false;
1168     }
1169     for (int i = 0; i < dimensions().size(); ++i) {
1170       if (dimensions()[i] != i) {
1171         return false;
1172       }
1173     }
1174   }
1175   return true;
1176 }
1177 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1178 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
1179     const HloPrintOptions& options) const {
1180   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1181 }
1182 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1183 bool HloMapInstruction::IdenticalSlowPath(
1184     const HloInstruction& other,
1185     const std::function<bool(const HloComputation*, const HloComputation*)>&
1186         eq_computations) const {
1187   const auto& casted_other = static_cast<const HloMapInstruction&>(other);
1188   return eq_computations(to_apply(), casted_other.to_apply()) &&
1189          dimensions() == casted_other.dimensions();
1190 }
1191 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1192 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1193     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1194     HloCloneContext* context) const {
1195   return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1196 }
1197 
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1198 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
1199                                          HloInstruction* operand,
1200                                          absl::Span<const int64> start_indices,
1201                                          absl::Span<const int64> limit_indices,
1202                                          absl::Span<const int64> strides)
1203     : HloInstruction(HloOpcode::kSlice, shape),
1204       slice_starts_(start_indices.begin(), start_indices.end()),
1205       slice_limits_(limit_indices.begin(), limit_indices.end()),
1206       slice_strides_(strides.begin(), strides.end()) {
1207   AppendOperand(operand);
1208   // For backward compatibility with old serialized computations: if there are
1209   // no strides, assume all strides are 1.
1210   // TODO(b/63317920): remove this code.
1211   if (slice_strides_.empty()) {
1212     slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
1213   }
1214 }
1215 
ToProto() const1216 HloInstructionProto HloSliceInstruction::ToProto() const {
1217   HloInstructionProto proto = HloInstruction::ToProto();
1218   for (int i = 0; i < slice_starts_.size(); ++i) {
1219     auto* slice_dimension = proto.add_slice_dimensions();
1220     slice_dimension->set_start(slice_starts_[i]);
1221     slice_dimension->set_limit(slice_limits_[i]);
1222     slice_dimension->set_stride(slice_strides_[i]);
1223   }
1224   return proto;
1225 }
1226 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1227 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
1228     const HloPrintOptions& options) const {
1229   std::vector<string> bounds;
1230   bounds.reserve(slice_starts_.size());
1231   const bool omit_stride =
1232       absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
1233   for (int i = 0; i < slice_starts_.size(); ++i) {
1234     string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1235     bounds.push_back(
1236         StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1237   }
1238   return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1239 }
1240 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1241 bool HloSliceInstruction::IdenticalSlowPath(
1242     const HloInstruction& other,
1243     const std::function<bool(const HloComputation*, const HloComputation*)>&
1244         eq_computations) const {
1245   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1246   return slice_starts_ == other_slice.slice_starts_ &&
1247          slice_limits_ == other_slice.slice_limits_ &&
1248          slice_strides_ == other_slice.slice_strides_;
1249 }
1250 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1251 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1252     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1253     HloCloneContext* context) const {
1254   CHECK_EQ(new_operands.size(), 1);
1255   return absl::make_unique<HloSliceInstruction>(
1256       shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1257 }
1258 
HloConstantInstruction(Literal literal)1259 HloConstantInstruction::HloConstantInstruction(Literal literal)
1260     : HloInstruction(HloOpcode::kConstant, literal.shape()),
1261       literal_(std::move(literal)) {}
1262 
HloConstantInstruction(Literal literal,const Shape & shape)1263 HloConstantInstruction::HloConstantInstruction(Literal literal,
1264                                                const Shape& shape)
1265     : HloInstruction(HloOpcode::kConstant, shape),
1266       literal_(std::move(literal)) {}
1267 
HloConstantInstruction(const Shape & shape)1268 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1269     : HloInstruction(HloOpcode::kConstant, shape) {}
1270 
ToProto() const1271 HloInstructionProto HloConstantInstruction::ToProto() const {
1272   HloInstructionProto proto = HloInstruction::ToProto();
1273   if (literal_.has_value()) {
1274     *proto.mutable_literal() = literal_->ToProto();
1275   }
1276   return proto;
1277 }
1278 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1279 bool HloConstantInstruction::IsElementwiseImpl(
1280     const absl::optional<int64>& operand_idx) const {
1281   return true;
1282 }
1283 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1284 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1285                                               const ShapeIndex& shape_index) {
1286   Shape* mutable_array_subshape =
1287       ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1288   CHECK(mutable_array_subshape->IsArray());
1289 
1290   // Normally array_subshape will always have a layout, but this invariant is
1291   // temporarily broken in LayoutAssignment::AssignLayouts.
1292 
1293   if (!mutable_array_subshape->has_layout() ||
1294       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1295     *literal_ = literal_->Relayout(new_layout, shape_index);
1296     *mutable_array_subshape->mutable_layout() = new_layout;
1297   }
1298 }
1299 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1300 bool HloConstantInstruction::IdenticalSlowPath(
1301     const HloInstruction& other,
1302     const std::function<bool(const HloComputation*, const HloComputation*)>&
1303         eq_computations) const {
1304   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1305   return literal() == other_slice.literal();
1306 }
1307 
1308 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1309 HloConstantInstruction::CloneWithNewOperandsImpl(
1310     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1311     HloCloneContext* context) const {
1312   CHECK(literal_.has_value());
1313   // Literal's shape may have no/different tiling info. Use this instruction's
1314   // shape instead.
1315   CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1316                                                   this->shape()));
1317   return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
1318                                                    this->shape());
1319 }
1320 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1321 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1322     const HloPrintOptions& options,
1323     CanonicalNameMap* canonical_name_map) const {
1324   string operands;
1325   // For constants, show the actual value in place of an empty operand list.
1326   if (literal_.has_value() &&
1327       ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1328        options.print_large_constants())) {
1329     // Literal::ToString emits multidimensional arrays over multiple
1330     // lines. Compact this into one line by stripping out white space.
1331     operands = literal_->ToStringWithoutShapeOneline();
1332   } else {
1333     // Do not show large constants or tuples.
1334     operands = "{...}";
1335   }
1336   return operands;
1337 }
1338 
HloTraceInstruction(const string & tag,HloInstruction * operand)1339 HloTraceInstruction::HloTraceInstruction(const string& tag,
1340                                          HloInstruction* operand)
1341     : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1342       literal_(LiteralUtil::CreateR1U8(tag)) {
1343   AppendOperand(operand);
1344   operand->set_tracing(this);
1345 }
1346 
ToProto() const1347 HloInstructionProto HloTraceInstruction::ToProto() const {
1348   HloInstructionProto proto = HloInstruction::ToProto();
1349   *proto.mutable_literal() = literal_.ToProto();
1350   return proto;
1351 }
1352 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1353 bool HloTraceInstruction::IdenticalSlowPath(
1354     const HloInstruction& other,
1355     const std::function<bool(const HloComputation*, const HloComputation*)>&
1356         eq_computations) const {
1357   return false;
1358 }
1359 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1360 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1361     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1362     HloCloneContext* context) const {
1363   LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1364 }
1365 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1366 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1367                                            FusionKind fusion_kind,
1368                                            HloInstruction* fused_root)
1369     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1370   CHECK(fused_root != nullptr);
1371   SetAndSanitizeName("fusion");
1372   set_parent(fused_root->parent());
1373   set_metadata(fused_root->metadata());
1374   CloneAndFuseInternal(fused_root);
1375 }
1376 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1377 HloFusionInstruction::HloFusionInstruction(
1378     const Shape& shape, FusionKind fusion_kind,
1379     absl::Span<HloInstruction* const> operands,
1380     HloComputation* fusion_computation)
1381     : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1382   for (auto operand : operands) {
1383     AppendOperand(operand);
1384   }
1385   SetAndSanitizeName("fusion");
1386   AppendComputation(fusion_computation);
1387   fusion_computation->SetFusionInstruction(this);
1388 }
1389 
ToCategory() const1390 string HloFusionInstruction::ToCategory() const {
1391   switch (fusion_kind()) {
1392     case FusionKind::kLoop:
1393       return "loop fusion";
1394     case FusionKind::kInput:
1395       return "input fusion";
1396     case FusionKind::kOutput:
1397       return "output fusion";
1398     case FusionKind::kCustom:
1399       return "custom fusion";
1400   }
1401 }
1402 
ToProto() const1403 HloInstructionProto HloFusionInstruction::ToProto() const {
1404   HloInstructionProto proto = HloInstruction::ToProto();
1405   proto.set_fusion_kind(xla::ToString(fusion_kind()));
1406   proto.add_called_computation_ids(
1407       fused_instructions_computation()->unique_id());
1408   return proto;
1409 }
1410 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1411 bool HloFusionInstruction::IsElementwiseImpl(
1412     const absl::optional<int64>& operand_idx) const {
1413   if (!operand_idx.has_value()) {
1414     for (auto* fused : fused_instructions()) {
1415       if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1416         return false;
1417       }
1418     }
1419     return true;
1420   }
1421   // A loop-fusion is elementwise on an operand if all operations (computed
1422   // using BFS) between the operand and the fused root are elementwise.
1423   std::deque<HloInstruction*> worklist;
1424   std::unordered_set<const HloInstruction*> visited;
1425   worklist.push_back(fused_parameter(operand_idx.value()));
1426   visited.insert(fused_parameter(operand_idx.value()));
1427   while (!worklist.empty()) {
1428     HloInstruction* operand = worklist.front();
1429     worklist.pop_front();
1430     for (HloInstruction* user : operand->users()) {
1431       CHECK_GE(user->unique_id(), 0);
1432       if (ContainsKey(visited, user)) {
1433         continue;
1434       }
1435       if (user->IsElementwise() ||
1436           IsInstructionElementwiseOnOperand(user, operand)) {
1437         worklist.push_back(user);
1438         visited.insert(user);
1439       } else {
1440         return false;
1441       }
1442     }
1443   }
1444   return true;
1445 }
1446 
AddFusionOperand(HloInstruction * new_operand)1447 HloInstruction* HloFusionInstruction::AddFusionOperand(
1448     HloInstruction* new_operand) {
1449   CHECK_EQ(operand_count(),
1450            fused_instructions_computation()->parameter_instructions().size());
1451   const int64 param_no = operand_count();
1452   string param_name = StrCat("param_", param_no);
1453   HloInstruction* fused_parameter =
1454       fused_instructions_computation()->AddParameter(
1455           HloInstruction::CreateParameter(param_no, new_operand->shape(),
1456                                           param_name));
1457   AppendOperand(new_operand);
1458   return fused_parameter;
1459 }
1460 
// Merges the fused instructions of `instruction_to_merge` into this fusion.
// `instruction_to_merge` must be an operand of 'this' (CHECKed below). The
// work happens on a temporary clone of `instruction_to_merge`:
//   1. The clone's fused parameters are replaced by the clone's operands.
//   2. 'this' is made to use the merged root in place of
//      `instruction_to_merge`.
//   3. The clone's non-parameter instructions are fused into 'this' one at a
//      time.
//   4. The clone's fused computation is removed from the module.
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1461 void HloFusionInstruction::MergeFusionInstruction(
1462     HloFusionInstruction* instruction_to_merge) {
1463   CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1464   // Clone the instruction from which to merge fused instructions.
  // 'cloned' owns the temporary clone. It is not added to any computation in
  // this function and is destroyed when the function returns.
1465   std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1466   HloFusionInstruction* cloned_fusion =
1467       static_cast<HloFusionInstruction*>(cloned.get());
1468   // Replace uses of fused parameters with the corresponding operand of the
1469   // fusion.  Add all non-parameter fused instructions to
1470   // 'unfused_instructions' to be merged into 'this'.  This is done in reverse
1471   // post order.
1472   std::vector<HloInstruction*> unfused_instructions;
1473   auto fused_instructions = cloned_fusion->fused_instructions_computation()
1474                                 ->MakeInstructionPostOrder();
1475   for (auto fused_it = fused_instructions.rbegin();
1476        fused_it != fused_instructions.rend(); ++fused_it) {
1477     auto fused_instruction = *fused_it;
1478     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1479       TF_CHECK_OK(
1480           fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1481               fused_instruction->parameter_number())));
1482     } else {
1483       unfused_instructions.push_back(fused_instruction);
1484     }
1485   }
1486
1487   // If there are no unfused instructions, the fused computation must consist
1488   // only of kParameter instructions. Make the operand of the corresponding
1489   // parameter number the new root.
  // Otherwise the first instruction in reverse post order is expected to be
  // the cloned fused root; the CHECK below enforces this.
1490   HloInstruction* unfused_root =
1491       unfused_instructions.empty()
1492           ? instruction_to_merge->mutable_operand(
1493                 instruction_to_merge->fused_instructions_computation()
1494                     ->root_instruction()
1495                     ->parameter_number())
1496           : unfused_instructions.front();
1497   CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
1498         unfused_instructions.empty());
1499   // Make 'this' use unfused_root wherever it currently uses
  // instruction_to_merge.
1500   TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
1501
1502   // Build a dummy root for the cloned fusion as we may remove the original root
1503   // in the fusion process.
1504   if (!unfused_instructions.empty()) {
1505     HloComputation* computation = unfused_root->parent();
1506     auto* dummy_root = computation->AddInstruction(
1507         HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
1508     computation->set_root_instruction(dummy_root,
1509                                       /*accept_different_shape=*/true);
1510   }
1511
1512   // Fuse 'unfused_instructions' into 'this'. Every time we fuse an instruction
1513   // we remove it from the cloned fusion node. This is so that we don't add
1514   // extra users to the producer of that instruction (we use user count to
1515   // decide if a side-effectful instruction is fusible).
1516   for (auto& instruction : unfused_instructions) {
1517     auto* fused = FuseInstruction(instruction);
1518     TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
1519     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1520   }
  // Nothing may still refer to the temporary clone before its fused
  // computation is removed from the module (parent()->parent()).
1521   CHECK_EQ(0, cloned_fusion->user_count());
1522   TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1523       cloned_fusion->fused_instructions_computation()));
1524 }
1525 
// Merges `instruction_to_merge` into this fusion as additional output(s):
//   1. The fused instructions of `instruction_to_merge` are unfused (cloned)
//      into this fusion's parent computation.
//   2. `instruction_to_merge` and its fused computation are removed.
//   3. The unfused root is fused via FuseInstructionIntoMultiOutput().
//   4. The remaining unfused instructions are fused normally.
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1526 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1527     HloFusionInstruction* instruction_to_merge) {
1528   // Add all non-parameter fused instructions to 'unfused_instructions' to be
1529   // merged into 'this'. `old_to_new' maps the instructions in the fused node
1530   // to the disassembled fusion instructions.
1531   // Note that we add the unfused instructions to this->parent_ computation.
1532   // This is necessary because an instruction needs a unique_id, which is
1533   // only assigned when it is inserted into a computation.
1534   absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1535   std::vector<HloInstruction*> unfused_instructions;
1536   auto computation_to_merge =
1537       instruction_to_merge->fused_instructions_computation();
1538   auto post_order = computation_to_merge->MakeInstructionPostOrder();
1539   for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1540     auto fused_instruction = *rit;
    // Fused parameters map to the corresponding operand of
    // instruction_to_merge; they are not cloned.
1541     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1542       InsertOrDie(&old_to_new, fused_instruction,
1543                   instruction_to_merge->mutable_operand(
1544                       fused_instruction->parameter_number()));
1545       continue;
1546     }
1547
1548     // Here we clone the instruction and call FuseInstructionIntoMultiOutput()
1549     // which clones again. This can be improved.
1550     auto cloned_instruction =
1551         parent()->AddInstruction(fused_instruction->Clone());
1552     unfused_instructions.push_back(cloned_instruction);
1553     InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
1554   }
  // The clones still reference the original fused instructions as operands;
  // rewire each operand to its counterpart in old_to_new.
1555   for (auto unfused_instruction : unfused_instructions) {
1556     for (int64 index = 0; index < unfused_instruction->operand_count();
1557          index++) {
1558       auto new_operand =
1559           FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
1560       TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
1561     }
1562   }
1563
1564   // If there are no unfused instructions, the fused computation must consist
1565   // only of kParameter instructions. Make the operand of the corresponding
1566   // parameter number the new root.
  // Otherwise the first entry (first in reverse post order) is used as the
  // root.
1567   HloInstruction* unfused_root =
1568       unfused_instructions.empty()
1569           ? instruction_to_merge->mutable_operand(
1570                 instruction_to_merge->fused_instructions_computation()
1571                     ->root_instruction()
1572                     ->parameter_number())
1573           : unfused_instructions.front();
1574   TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
1575
1576   TF_CHECK_OK(
1577       instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
1578   if (GetModule()) {
1579     TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
1580   }
1581
1582   // Fuse the root instruction and generate multiple outputs.
1583   if (unfused_instructions.empty()) {
1584     return;
1585   }
1586   FuseInstructionIntoMultiOutput(unfused_root);
1587   TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
1588   // The remaining instructions are fused normally (without adding outputs).
1589   for (int64 i = 1; i < unfused_instructions.size(); i++) {
1590     auto instruction = unfused_instructions[i];
1591     FuseInstruction(instruction);
1592     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1593   }
1594 }
1595 
fused_instructions_computation() const1596 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1597   CHECK(!called_computations().empty());
1598   auto* fused_instructions_computation = called_computations().front();
1599   CHECK(fused_instructions_computation->IsFusionComputation())
1600       << "Computation " << fused_instructions_computation->name()
1601       << " is not a fusion kind";
1602   return fused_instructions_computation;
1603 }
1604 
fused_expression_root() const1605 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1606   return fused_instructions_computation()->root_instruction();
1607 }
1608 
fused_parameter(int64 parameter_number) const1609 HloInstruction* HloFusionInstruction::fused_parameter(
1610     int64 parameter_number) const {
1611   return fused_instructions_computation()->parameter_instruction(
1612       parameter_number);
1613 }
1614 
fused_parameters() const1615 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1616     const {
1617   return fused_instructions_computation()->parameter_instructions();
1618 }
1619 
1620 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1621     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1622 HloFusionInstruction::fused_instructions() const {
1623   const HloComputation* subcomp = fused_instructions_computation();
1624   return subcomp->instructions();
1625 }
1626 
1627 const tensorflow::gtl::iterator_range<
1628     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1629 HloFusionInstruction::fused_instructions() {
1630   return fused_instructions_computation()->instructions();
1631 }
1632 
fused_instruction_count() const1633 int64 HloFusionInstruction::fused_instruction_count() const {
1634   return fused_instructions_computation()->instruction_count();
1635 }
1636 
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1637 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1638     HloInstruction* instruction_to_fuse, bool add_output) {
1639   // When add_output is false, this fusion instruction must be a user of
1640   // instruction_to_fuse.
1641   if (!add_output) {
1642     CHECK(IsUserOf(instruction_to_fuse));
1643   }
1644   HloInstruction* fused_instruction =
1645       CloneAndFuseInternal(instruction_to_fuse, add_output);
1646   return fused_instruction;
1647 }
1648 
// Clones `instruction_to_fuse` into this fusion's fused computation, creating
// that computation first if this fusion does not call one yet. Each operand
// of the clone is then rewired to a fused parameter, adding new fusion
// operands where needed.
//
// If `add_output` is still true at that point (it is reset when
// instruction_to_fuse has no users left), the clone also becomes an extra
// element of a tuple root, and the remaining users of instruction_to_fuse
// are redirected to get-tuple-element instructions on this fusion.
//
// A kTuple instruction_to_fuse is not cloned: the instruction itself is
// reused, its operands are rewired to fused parameters, and those operands
// (not the tuple) become new root-tuple elements.
//
// Returns the clone, or instruction_to_fuse itself in the kTuple case.
CloneAndFuseInternal(HloInstruction * instruction_to_fuse,bool add_output)1649 HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
1650     HloInstruction* instruction_to_fuse, bool add_output) {
1651   CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1652   VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
1653   HloInstruction* clone = nullptr;
1654   if (called_computations().empty()) {
1655     // New fusion instruction. It should not be a multioutput instruction.
    // The clone becomes the root of a freshly built "fused_computation",
    // which is added to the module as an embedded computation.
1656     CHECK(!add_output);
1657     auto builder = HloComputation::Builder("fused_computation", this);
1658     builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
1659     AppendComputation(
1660         CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1661     clone = fused_expression_root();
1662   } else {
1663     // When add_output is false, instruction_to_fuse is necessarily an operand
1664     // of the fusion instruction. After fusion this will no longer be the
1665     // case. Remove the operand from the operand list and remove its
1666     // corresponding fused parameter instruction. Renumber parameters as
1667     // necessary to make parameter numbers consistent with their index in the
1668     // fused computation's parameter list.
1669     bool in_operand_list =
1670         absl::c_linear_search(operands(), instruction_to_fuse);
1671     CHECK(add_output || in_operand_list);
1672     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1673       // We assume all uses of a kTuple operation are GTE ops, not another
1674       // fusion node. In this case, we don't need to clone
1675       // 'instruction_to_fuse'.
1676       CHECK(!in_operand_list);
1677       clone = instruction_to_fuse;
1678     } else {
1679       clone = fused_instructions_computation()->AddInstruction(
1680           instruction_to_fuse->Clone(/*suffix=*/""));
1681     }
1682     const std::vector<HloInstruction*>& fused_parameters =
1683         fused_instructions_computation()->parameter_instructions();
1684     for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1685       if (instruction_to_fuse == operand(operand_num)) {
1686         // replace the fused parameter instruction's uses with the clone.
1687         HloInstruction* fused_parameter = fused_parameters[operand_num];
1688         TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
1689
1690         // Remove the corresponding fused parameter and operand from their
1691         // respective vectors.
1692         TF_CHECK_OK(
1693             fused_instructions_computation()->RemoveParameter(operand_num));
1694         RemoveOperandAt(operand_num);
1695         break;
1696       }
1697     }
1698     // We've cloned instruction_to_fuse into this fusion instruction, so this
1699     // fusion instruction is no longer a use of instruction_to_fuse.
1700     if (in_operand_list) {
1701       DetachFrom(instruction_to_fuse);
1702       // When the instruction_to_fuse does not have other users, we don't need
1703       // to generate a multioutput fusion instruction.
1704       if (instruction_to_fuse->user_count() == 0) {
1705         add_output = false;
1706       }
1707     }
1708   }
1709
1710   // Reread the parameters in the computation.
1711   const std::vector<HloInstruction*>& fused_parameters =
1712       fused_instructions_computation()->parameter_instructions();
1713
1714   // Add each operand of the clone as an operand of the fusion instruction. A
1715   // complication is that some clone operands may already be operands of the
1716   // fusion instruction.
  // NOTE(review): in the kTuple case clone == instruction_to_fuse, so this
  // loop rewrites the original tuple's operands to fused parameters of this
  // fusion. The caller presumably removes that tuple afterwards; confirm.
1717   for (int64 operand_num = 0; operand_num < clone->operand_count();
1718        ++operand_num) {
1719     HloInstruction* operand = clone->mutable_operand(operand_num);
1720
1721     // See if this operand is already an operand of the fusion node.
1722     CHECK_EQ(operands().size(), fused_parameters.size());
1723     HloInstruction* fused_param = nullptr;
1724     for (int64 i = 0; i < operands().size(); ++i) {
1725       if (this->operand(i) == operand) {
1726         fused_param = fused_parameters[i];
1727         break;
1728       }
1729     }
1730
1731     if (fused_param == nullptr) {
1732       // Clone's operand was not already an operand of the fusion
1733       // instruction. Add it as an operand and add a corresponding fused
1734       // parameter instruction.
1735       fused_param = AddFusionOperand(operand);
1736     }
1737     TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
1738   }
1739
1740   if (add_output) {
1741     CHECK_GT(instruction_to_fuse->user_count(), 0);
1742     // Build the new root tuple: start from the current root tuple's elements
1743     // (already multi-output) or from the current root alone, then append the
    // new output(s): the clone, or each of its operands if it is a kTuple.
1744     HloInstruction* fused_root = fused_expression_root();
1745     HloInstruction::InstructionVector tuple_elements;
1746     bool newly_created_tuple_instr = false;
1747     if (fused_root->opcode() == HloOpcode::kTuple) {
1748       tuple_elements = fused_root->operands();
1749     } else {
1750       tuple_elements.push_back(fused_root);
1751       newly_created_tuple_instr = true;
1752     }
1753     if (clone->opcode() == HloOpcode::kTuple) {
1754       for (auto inst : clone->operands()) {
1755         tuple_elements.push_back(inst);
1756       }
1757     } else {
1758       tuple_elements.push_back(clone);
1759     }
1760     HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
1761         HloInstruction::CreateTuple(tuple_elements));
1762     fused_instructions_computation()->set_root_instruction(new_root);
    // The fusion's own shape now follows the widened root tuple.
1763     *mutable_shape() = new_root->shape();
1764     if (fused_root->opcode() == HloOpcode::kTuple) {
1765       TF_CHECK_OK(
1766           fused_instructions_computation()->RemoveInstruction(fused_root));
1767     }
1768
1769     // If this is a newly created multioutput instruction, we need to update
1770     // the use of the original fusion instruction.
    // Existing users now read element 0 (the old root) through a GTE.
1771     if (newly_created_tuple_instr) {
1772       HloInstruction* new_instr = parent()->AddInstruction(
1773           HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
1774       TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1775     }
    // `index` starts one past the last root-tuple element. The new output(s)
    // occupy the tail of the tuple.
1776     int64 index = tuple_elements.size();
1777     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      // Redirect each existing GTE of the fused tuple to the matching
      // element at the tail of this fusion's output tuple.
1778       CHECK_EQ(clone, instruction_to_fuse);
1779       index -= clone->operand_count();
1780       std::vector<HloInstruction*> to_be_removed;
1781       for (auto old_gte : clone->users()) {
1782         CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1783         int64 old_tuple_index = old_gte->tuple_index();
1784         HloInstruction* new_gte =
1785             parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1786                 old_gte->shape(), this, index + old_tuple_index));
1787         TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1788         to_be_removed.push_back(old_gte);
1789       }
      // Removal is deferred so clone->users() is not mutated while it is
      // being iterated.
1790       for (auto old_gte : to_be_removed) {
1791         TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1792       }
1793     } else {
      // A single output was appended, at position index - 1.
1794       HloInstruction* new_gte =
1795           parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1796               clone->shape(), this, index - 1));
1797       TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
1798     }
1799   }
1800
1801   if (clone != instruction_to_fuse) {
1802     VLOG(2) << "New clone:\n" << clone->ToString();
1803   }
1804   return clone;
1805 }
1806 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1807 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1808     const HloPrintOptions& options) const {
1809   return {StrCat("kind=", xla::ToString(fusion_kind()))};
1810 }
1811 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1812 bool HloFusionInstruction::IdenticalSlowPath(
1813     const HloInstruction& other,
1814     const std::function<bool(const HloComputation*, const HloComputation*)>&
1815         eq_computations) const {
1816   return fusion_kind() == other.fusion_kind() &&
1817          eq_computations(fused_instructions_computation(),
1818                          other.fused_instructions_computation());
1819 }
1820 
InnerHash() const1821 uint64 HloFusionInstruction::InnerHash() const {
1822   return fused_instructions_computation()->root_instruction()->Hash();
1823 }
1824 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1825 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1826     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1827     HloCloneContext* context) const {
1828   HloModule* module = context != nullptr ? context->module() : GetModule();
1829   HloComputation* new_fused_computation = nullptr;
1830   if (context != nullptr) {
1831     new_fused_computation =
1832         context->FindComputation(fused_instructions_computation());
1833   }
1834   if (new_fused_computation == nullptr) {
1835     new_fused_computation = module->AddEmbeddedComputation(
1836         fused_instructions_computation()->Clone("clone", context));
1837   }
1838   return absl::make_unique<HloFusionInstruction>(
1839       shape, fusion_kind(), new_operands, new_fused_computation);
1840 }
1841 
DeduplicateFusionOperands()1842 Status HloFusionInstruction::DeduplicateFusionOperands() {
1843   if (IsCustomFusion()) {
1844     return Status::OK();
1845   }
1846   absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1847   std::vector<int> operands_to_remove;
1848   for (int i = 0; i < operand_count(); ++i) {
1849     auto emplace_result = operand_indices.emplace(operand(i), i);
1850     if (!emplace_result.second) {
1851       TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1852           fused_parameter(emplace_result.first->second)));
1853       operands_to_remove.push_back(i);
1854     }
1855   }
1856   if (operands_to_remove.empty()) {
1857     return Status::OK();
1858   }
1859   TF_RETURN_IF_ERROR(fused_instructions_computation()
1860                          ->RemoveUnusedParametersFromFusedComputation());
1861   RemoveOperandsAtAscendingIndices(operands_to_remove);
1862   return Status::OK();
1863 }
1864 
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1865 HloRngInstruction::HloRngInstruction(
1866     const Shape& shape, RandomDistribution distribution,
1867     absl::Span<HloInstruction* const> parameters)
1868     : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
1869   for (HloInstruction* param : parameters) {
1870     AppendOperand(param);
1871   }
1872 }
1873 
ToProto() const1874 HloInstructionProto HloRngInstruction::ToProto() const {
1875   HloInstructionProto proto = HloInstruction::ToProto();
1876   proto.set_distribution(distribution_);
1877   return proto;
1878 }
1879 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1880 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
1881     const HloPrintOptions& options) const {
1882   return {StrCat("distribution=", RandomDistributionToString(distribution_))};
1883 }
1884 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1885 bool HloRngInstruction::IsElementwiseImpl(
1886     const absl::optional<int64>& operand_idx) const {
1887   return true;
1888 }
1889 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1890 bool HloRngInstruction::IdenticalSlowPath(
1891     const HloInstruction& other,
1892     const std::function<bool(const HloComputation*, const HloComputation*)>&
1893         eq_computations) const {
1894   const auto& casted_other = static_cast<const HloRngInstruction&>(other);
1895   return distribution_ == casted_other.distribution_;
1896 }
1897 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1898 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
1899     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1900     HloCloneContext* context) const {
1901   return absl::make_unique<HloRngInstruction>(shape, distribution_,
1902                                               new_operands);
1903 }
1904 
HloParameterInstruction(int64 parameter_number,const Shape & shape,const string & name)1905 HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
1906                                                  const Shape& shape,
1907                                                  const string& name)
1908     : HloInstruction(HloOpcode::kParameter, shape),
1909       parameter_number_(parameter_number) {
1910   SetAndSanitizeName(name);
1911 }
1912 
ToProto() const1913 HloInstructionProto HloParameterInstruction::ToProto() const {
1914   HloInstructionProto proto = HloInstruction::ToProto();
1915   proto.set_parameter_number(parameter_number_);
1916   if (parameter_replicated_at_leaf_buffers_) {
1917     for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1918       proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
1919           replicated);
1920     }
1921   }
1922   return proto;
1923 }
1924 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1925 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
1926     const HloPrintOptions& options) const {
1927   std::vector<string> result;
1928   if (!parameter_replicated_at_leaf_buffers_) {
1929     return result;
1930   }
1931   std::vector<string> buffers_replicated_strs;
1932   for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1933     buffers_replicated_strs.push_back(replicated ? "true" : "false");
1934   }
1935   if (options.print_ids()) {
1936     result.push_back(StrCat("parameter_replication={",
1937                             StrJoin(buffers_replicated_strs, ","), "}"));
1938   }
1939   return result;
1940 }
1941 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1942 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
1943     const HloPrintOptions& options,
1944     CanonicalNameMap* canonical_name_map) const {
1945   return StrCat(parameter_number_);
1946 }
1947 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1948 bool HloParameterInstruction::IdenticalSlowPath(
1949     const HloInstruction& other,
1950     const std::function<bool(const HloComputation*, const HloComputation*)>&
1951         eq_computations) const {
1952   const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
1953   return parameter_number() == casted_other.parameter_number();
1954 }
1955 
1956 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1957 HloParameterInstruction::CloneWithNewOperandsImpl(
1958     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1959     HloCloneContext* context) const {
1960   auto clone = absl::make_unique<HloParameterInstruction>(parameter_number_,
1961                                                           shape, name());
1962   if (parameter_replicated_at_leaf_buffers_ &&
1963       ShapeUtil::Equal(shape, this->shape())) {
1964     clone->set_parameter_replicated_at_leaf_buffers(
1965         *parameter_replicated_at_leaf_buffers_);
1966   }
1967   return clone;
1968 }
1969 
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64 index)1970 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
1971     const Shape& shape, HloInstruction* operand, int64 index)
1972     : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
1973   AppendOperand(operand);
1974 }
1975 
ToProto() const1976 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
1977   HloInstructionProto proto = HloInstruction::ToProto();
1978   proto.set_tuple_index(tuple_index_);
1979   return proto;
1980 }
1981 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1982 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
1983     const HloPrintOptions& options) const {
1984   return {StrCat("index=", tuple_index())};
1985 }
1986 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1987 bool HloGetTupleElementInstruction::IdenticalSlowPath(
1988     const HloInstruction& other,
1989     const std::function<bool(const HloComputation*, const HloComputation*)>&
1990         eq_computations) const {
1991   const auto& casted_other =
1992       static_cast<const HloGetTupleElementInstruction&>(other);
1993   return tuple_index() == casted_other.tuple_index();
1994 }
1995 
1996 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1997 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
1998     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1999     HloCloneContext* context) const {
2000   CHECK_EQ(new_operands.size(), 1);
2001   return absl::make_unique<HloGetTupleElementInstruction>(
2002       shape, new_operands[0], tuple_index());
2003 }
2004 
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)2005 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
2006     const Shape& shape, HloInstruction* operand, const int exponent_bits,
2007     const int mantissa_bits)
2008     : HloInstruction(HloOpcode::kReducePrecision, shape),
2009       exponent_bits_(exponent_bits),
2010       mantissa_bits_(mantissa_bits) {
2011   AppendOperand(operand);
2012 }
2013 
ToProto() const2014 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
2015   HloInstructionProto proto = HloInstruction::ToProto();
2016   proto.set_exponent_bits(exponent_bits_);
2017   proto.set_mantissa_bits(mantissa_bits_);
2018   return proto;
2019 }
2020 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2021 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
2022     const HloPrintOptions& options) const {
2023   return {StrCat("exponent_bits=", exponent_bits_),
2024           StrCat("mantissa_bits=", mantissa_bits_)};
2025 }
2026 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2027 bool HloReducePrecisionInstruction::IdenticalSlowPath(
2028     const HloInstruction& other,
2029     const std::function<bool(const HloComputation*, const HloComputation*)>&
2030         eq_computations) const {
2031   const auto& casted_other =
2032       static_cast<const HloReducePrecisionInstruction&>(other);
2033   // A reduce-precision operation is determined by the bit sizes.
2034   return exponent_bits() == casted_other.exponent_bits() &&
2035          mantissa_bits() == casted_other.mantissa_bits();
2036 }
2037 
2038 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2039 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
2040     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2041     HloCloneContext* context) const {
2042   CHECK_EQ(new_operands.size(), 1);
2043   return absl::make_unique<HloReducePrecisionInstruction>(
2044       shape, new_operands[0], exponent_bits(), mantissa_bits());
2045 }
2046 
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)2047 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
2048                                            HloInstruction* token_operand,
2049                                            const string& config)
2050     : HloInstruction(HloOpcode::kInfeed,
2051                      ShapeUtil::MakeTupleShape(
2052                          {infeed_shape, ShapeUtil::MakeTokenShape()})),
2053       infeed_config_(config) {
2054   AppendOperand(token_operand);
2055 }
2056 
ToProto() const2057 HloInstructionProto HloInfeedInstruction::ToProto() const {
2058   HloInstructionProto proto = HloInstruction::ToProto();
2059   proto.set_infeed_config(infeed_config_);
2060   return proto;
2061 }
2062 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2063 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
2064     const HloPrintOptions& options) const {
2065   if (infeed_config_.empty()) {
2066     return {};
2067   }
2068   return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
2069 }
2070 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2071 bool HloInfeedInstruction::IdenticalSlowPath(
2072     const HloInstruction& other,
2073     const std::function<bool(const HloComputation*, const HloComputation*)>&
2074         eq_computations) const {
2075   // Not yet supported.
2076   return false;
2077 }
2078 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2079 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
2080     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2081     HloCloneContext* context) const {
2082   CHECK_EQ(new_operands.size(), 1);
2083   return absl::make_unique<HloInfeedInstruction>(
2084       infeed_shape(), new_operands[0], infeed_config());
2085 }
2086 
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)2087 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
2088                                              HloInstruction* operand,
2089                                              HloInstruction* token_operand,
2090                                              absl::string_view outfeed_config)
2091     : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
2092       outfeed_shape_(outfeed_shape),
2093       outfeed_config_(outfeed_config) {
2094   AppendOperand(operand);
2095   AppendOperand(token_operand);
2096 }
2097 
ToProto() const2098 HloInstructionProto HloOutfeedInstruction::ToProto() const {
2099   HloInstructionProto proto = HloInstruction::ToProto();
2100   proto.set_outfeed_config(outfeed_config());
2101   *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
2102   return proto;
2103 }
2104 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2105 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
2106     const HloPrintOptions& options) const {
2107   std::vector<string> extra;
2108   extra.push_back(StrCat("outfeed_shape=",
2109                          ShapeUtil::HumanStringWithLayout(outfeed_shape_)));
2110   if (!outfeed_config_.empty()) {
2111     extra.push_back(
2112         StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
2113   }
2114   return extra;
2115 }
2116 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2117 bool HloOutfeedInstruction::IdenticalSlowPath(
2118     const HloInstruction& other,
2119     const std::function<bool(const HloComputation*, const HloComputation*)>&
2120         eq_computations) const {
2121   // Not yet supported.
2122   return false;
2123 }
2124 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2125 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
2126     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2127     HloCloneContext* context) const {
2128   CHECK_EQ(new_operands.size(), 2);
2129   return absl::make_unique<HloOutfeedInstruction>(
2130       outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
2131 }
2132 
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2133 HloConvolutionInstruction::HloConvolutionInstruction(
2134     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2135     int64 feature_group_count, int64 batch_group_count, const Window& window,
2136     const ConvolutionDimensionNumbers& dimension_numbers,
2137     const PrecisionConfig& precision_config)
2138     : HloInstruction(HloOpcode::kConvolution, shape),
2139       feature_group_count_(feature_group_count),
2140       batch_group_count_(batch_group_count),
2141       window_(window),
2142       convolution_dimension_numbers_(dimension_numbers),
2143       precision_config_(precision_config) {
2144   if (window_util::HasBaseDilation(window)) {
2145     SetAndSanitizeName(StrCat(name(), "-base-dilated"));
2146   }
2147   if (window_util::HasWindowDilation(window)) {
2148     SetAndSanitizeName(StrCat(name(), "-window-dilated"));
2149   }
2150   AppendOperand(lhs);
2151   AppendOperand(rhs);
2152 }
2153 
ToCategory() const2154 string HloConvolutionInstruction::ToCategory() const {
2155   string category = "convolution";
2156   if (window_util::HasBaseDilation(window())) {
2157     category += " base-dilated";
2158   }
2159   if (window_util::HasWindowDilation(window())) {
2160     category += " window-dilated";
2161   }
2162   return category;
2163 }
2164 
ToProto() const2165 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2166   HloInstructionProto proto = HloInstruction::ToProto();
2167   *proto.mutable_window() = window_;
2168   *proto.mutable_convolution_dimension_numbers() =
2169       convolution_dimension_numbers_;
2170   proto.set_feature_group_count(feature_group_count_);
2171   proto.set_batch_group_count(batch_group_count_);
2172   *proto.mutable_precision_config() = precision_config_;
2173   return proto;
2174 }
2175 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2176 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2177     const HloPrintOptions& options) const {
2178   std::vector<string> extra;
2179   if (window_.dimensions_size() != 0) {
2180     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2181   }
2182   extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2183                                             convolution_dimension_numbers_)));
2184   if (feature_group_count_ != 1) {
2185     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2186   }
2187 
2188   if (batch_group_count_ != 1) {
2189     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2190   }
2191 
2192   string precision_config_string = PrecisionConfigToString(precision_config_);
2193   if (!precision_config_string.empty()) {
2194     extra.push_back(precision_config_string);
2195   }
2196   return extra;
2197 }
2198 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2199 bool HloConvolutionInstruction::IdenticalSlowPath(
2200     const HloInstruction& other,
2201     const std::function<bool(const HloComputation*, const HloComputation*)>&
2202         eq_computations) const {
2203   const auto& casted_other =
2204       static_cast<const HloConvolutionInstruction&>(other);
2205   if (feature_group_count_ != other.feature_group_count()) {
2206     return false;
2207   }
2208   if (batch_group_count_ != other.batch_group_count()) {
2209     return false;
2210   }
2211   return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2212          protobuf_util::ProtobufEquals(
2213              convolution_dimension_numbers(),
2214              casted_other.convolution_dimension_numbers()) &&
2215          protobuf_util::ProtobufEquals(precision_config(),
2216                                        casted_other.precision_config());
2217 }
2218 
2219 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2220 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2221     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2222     HloCloneContext* context) const {
2223   CHECK_EQ(new_operands.size(), 2);
2224   return absl::make_unique<HloConvolutionInstruction>(
2225       shape, new_operands[0], new_operands[1], feature_group_count_,
2226       batch_group_count_, window(), convolution_dimension_numbers_,
2227       precision_config_);
2228 }
2229 
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2230 HloReduceWindowInstruction::HloReduceWindowInstruction(
2231     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2232     const Window& window, HloComputation* reduce_computation)
2233     : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
2234                                  absl::MakeSpan(&init_value, 1), window,
2235                                  reduce_computation) {}
2236 
HloReduceWindowInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)2237 HloReduceWindowInstruction::HloReduceWindowInstruction(
2238     const Shape& shape, absl::Span<HloInstruction* const> operands,
2239     absl::Span<HloInstruction* const> init_values, const Window& window,
2240     HloComputation* reduce_computation)
2241     : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2242   for (auto* operand : operands) {
2243     AppendOperand(operand);
2244   }
2245   for (auto* init_value : init_values) {
2246     AppendOperand(init_value);
2247   }
2248   AppendComputation(reduce_computation);
2249 }
2250 
ToProto() const2251 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2252   HloInstructionProto proto = HloInstruction::ToProto();
2253   *proto.mutable_window() = window_;
2254   return proto;
2255 }
2256 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2257 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2258     const HloPrintOptions& options) const {
2259   std::vector<string> extra;
2260   if (window_.dimensions_size() != 0) {
2261     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2262   }
2263   return extra;
2264 }
2265 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2266 bool HloReduceWindowInstruction::IdenticalSlowPath(
2267     const HloInstruction& other,
2268     const std::function<bool(const HloComputation*, const HloComputation*)>&
2269         eq_computations) const {
2270   const auto& casted_other =
2271       static_cast<const HloReduceWindowInstruction&>(other);
2272   return eq_computations(to_apply(), casted_other.to_apply()) &&
2273          protobuf_util::ProtobufEquals(window(), casted_other.window());
2274 }
2275 
2276 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2277 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2278     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2279     HloCloneContext* context) const {
2280   CHECK_EQ(new_operands.size() % 2, 0);
2281   int64 num_operands = new_operands.size() / 2;
2282   return absl::make_unique<HloReduceWindowInstruction>(
2283       shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
2284       absl::MakeSpan(new_operands)
2285           .subspan(num_operands, new_operands.size() / 2),
2286       window(), to_apply());
2287 }
2288 
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2289 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2290     const Shape& shape, HloInstruction* operand, HloComputation* select,
2291     const Window& window, HloInstruction* source, HloInstruction* init_value,
2292     HloComputation* scatter)
2293     : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2294   AppendOperand(operand);
2295   AppendOperand(source);
2296   AppendOperand(init_value);
2297   // Select comes before scatter in the vector.
2298   AppendComputation(select);
2299   AppendComputation(scatter);
2300 }
2301 
ToProto() const2302 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2303   HloInstructionProto proto = HloInstruction::ToProto();
2304   *proto.mutable_window() = window_;
2305   return proto;
2306 }
2307 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2308 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2309     const HloPrintOptions& options) const {
2310   std::vector<string> extra;
2311   if (window_.dimensions_size() != 0) {
2312     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2313   }
2314   return extra;
2315 }
2316 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2317 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2318     const HloInstruction& other,
2319     const std::function<bool(const HloComputation*, const HloComputation*)>&
2320         eq_computations) const {
2321   const auto& casted_other =
2322       static_cast<const HloSelectAndScatterInstruction&>(other);
2323   return eq_computations(select(), casted_other.select()) &&
2324          eq_computations(scatter(), casted_other.scatter()) &&
2325          protobuf_util::ProtobufEquals(window(), casted_other.window());
2326 }
2327 
2328 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2329 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2330     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2331     HloCloneContext* context) const {
2332   CHECK_EQ(new_operands.size(), 3);
2333   return absl::make_unique<HloSelectAndScatterInstruction>(
2334       shape, new_operands[0], select(), window(), new_operands[1],
2335       new_operands[2], scatter());
2336 }
2337 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque)2338 HloCustomCallInstruction::HloCustomCallInstruction(
2339     const Shape& shape, absl::Span<HloInstruction* const> operands,
2340     absl::string_view custom_call_target, string opaque)
2341     : HloInstruction(HloOpcode::kCustomCall, shape),
2342       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2343       feature_group_count_(1),
2344       batch_group_count_(1),
2345       layout_constrained_(false),
2346       padding_type_(PaddingType::PADDING_INVALID),
2347       custom_call_has_side_effect_(false) {
2348   set_raw_backend_config_string(std::move(opaque));
2349   for (auto operand : operands) {
2350     AppendOperand(operand);
2351   }
2352 }
2353 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque)2354 HloCustomCallInstruction::HloCustomCallInstruction(
2355     const Shape& shape, absl::Span<HloInstruction* const> operands,
2356     HloComputation* to_apply, absl::string_view custom_call_target,
2357     string opaque)
2358     : HloInstruction(HloOpcode::kCustomCall, shape),
2359       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2360       feature_group_count_(1),
2361       batch_group_count_(1),
2362       layout_constrained_(false),
2363       padding_type_(PaddingType::PADDING_INVALID),
2364       custom_call_has_side_effect_(false) {
2365   set_raw_backend_config_string(std::move(opaque));
2366   for (auto operand : operands) {
2367     AppendOperand(operand);
2368   }
2369   AppendComputation(to_apply);
2370 }
2371 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque)2372 HloCustomCallInstruction::HloCustomCallInstruction(
2373     const Shape& shape, absl::Span<HloInstruction* const> operands,
2374     absl::Span<HloComputation* const> called_computations,
2375     absl::string_view custom_call_target, string opaque)
2376     : HloInstruction(HloOpcode::kCustomCall, shape),
2377       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2378       feature_group_count_(1),
2379       batch_group_count_(1),
2380       layout_constrained_(false),
2381       padding_type_(PaddingType::PADDING_INVALID),
2382       custom_call_has_side_effect_(false) {
2383   set_raw_backend_config_string(std::move(opaque));
2384   for (auto operand : operands) {
2385     AppendOperand(operand);
2386   }
2387   for (auto comp : called_computations) {
2388     AppendComputation(comp);
2389   }
2390 }
2391 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,absl::Span<const Shape> operand_shapes_with_layout)2392 HloCustomCallInstruction::HloCustomCallInstruction(
2393     const Shape& shape, absl::Span<HloInstruction* const> operands,
2394     absl::string_view custom_call_target, string opaque,
2395     absl::Span<const Shape> operand_shapes_with_layout)
2396     : HloInstruction(HloOpcode::kCustomCall, shape),
2397       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2398       feature_group_count_(1),
2399       batch_group_count_(1),
2400       layout_constrained_(true),
2401       padding_type_(PaddingType::PADDING_INVALID),
2402       operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2403                                   operand_shapes_with_layout.end()),
2404       custom_call_has_side_effect_(false) {
2405   set_raw_backend_config_string(std::move(opaque));
2406   for (auto operand : operands) {
2407     AppendOperand(operand);
2408   }
2409 }
2410 
ToProto() const2411 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2412   HloInstructionProto proto = HloInstruction::ToProto();
2413   if (window_ != nullptr) {
2414     *proto.mutable_window() = *window_;
2415   }
2416   if (convolution_dimension_numbers_ != nullptr) {
2417     *proto.mutable_convolution_dimension_numbers() =
2418         *convolution_dimension_numbers_;
2419   }
2420   proto.set_custom_call_target(custom_call_target_);
2421   proto.set_feature_group_count(feature_group_count_);
2422   proto.set_batch_group_count(batch_group_count_);
2423   *proto.mutable_precision_config() = precision_config_;
2424   proto.set_padding_type(padding_type_);
2425   if (layout_constrained()) {
2426     proto.set_constrain_layout(true);
2427     for (const Shape& shape : operand_shapes_with_layout_) {
2428       *proto.add_operand_shapes_with_layout() = shape.ToProto();
2429     }
2430   }
2431   proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2432   if (literal_.has_value()) {
2433     *proto.mutable_literal() = literal_->ToProto();
2434   }
2435   for (const auto& pair : output_to_operand_aliasing_) {
2436     auto aliasing = proto.add_custom_call_output_operand_aliasing();
2437     aliasing->set_operand_index(pair.second.first);
2438     for (int64 index : pair.first) {
2439       aliasing->add_output_shape_index(index);
2440     }
2441     for (int64 index : pair.second.second) {
2442       aliasing->add_operand_shape_index(index);
2443     }
2444   }
2445   return proto;
2446 }
2447 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2448 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2449     const HloPrintOptions& options) const {
2450   std::vector<string> extra;
2451   if (window_ != nullptr) {
2452     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2453   }
2454   if (convolution_dimension_numbers_ != nullptr) {
2455     extra.push_back(StrCat(
2456         "dim_labels=",
2457         ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2458   }
2459   if (feature_group_count_ != 1) {
2460     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2461   }
2462   if (batch_group_count_ != 1) {
2463     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2464   }
2465   string precision_config_string = PrecisionConfigToString(precision_config_);
2466   if (!precision_config_string.empty()) {
2467     extra.push_back(precision_config_string);
2468   }
2469   if (padding_type_ != PaddingType::PADDING_INVALID) {
2470     extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
2471   }
2472   // By contract, we print the custom call target even if
2473   // options.print_subcomputation_mode() == kOff, because the call target is not
2474   // an HloComputation.
2475   extra.push_back(
2476       StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2477 
2478   if (layout_constrained()) {
2479     std::vector<string> shape_strings;
2480     for (const Shape& shape : operand_shapes_with_layout_) {
2481       shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2482     }
2483     extra.push_back(StrCat("operand_layout_constraints={",
2484                            StrJoin(shape_strings, ", "), "}"));
2485   }
2486   if (custom_call_has_side_effect_) {
2487     extra.push_back("custom_call_has_side_effect=true");
2488   }
2489   if (literal_.has_value()) {
2490     extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
2491   }
2492   if (!output_to_operand_aliasing_.empty()) {
2493     std::vector<string> pair_strings;
2494     for (const auto& pair : output_to_operand_aliasing_) {
2495       pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
2496                                     pair.second.first, ", ",
2497                                     pair.second.second.ToString(), ")"));
2498     }
2499     extra.push_back(StrCat("output_to_operand_aliasing={",
2500                            StrJoin(pair_strings, ", "), "}"));
2501   }
2502   return extra;
2503 }
2504 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2505 bool HloCustomCallInstruction::IdenticalSlowPath(
2506     const HloInstruction& other,
2507     const std::function<bool(const HloComputation*, const HloComputation*)>&
2508         eq_computations) const {
2509   const auto& casted_other =
2510       static_cast<const HloCustomCallInstruction&>(other);
2511   if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2512       (window_ != nullptr &&
2513        !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2514     return false;
2515   }
2516   if ((convolution_dimension_numbers_ == nullptr) !=
2517           (casted_other.convolution_dimension_numbers_ == nullptr) ||
2518       (convolution_dimension_numbers_ != nullptr &&
2519        !protobuf_util::ProtobufEquals(
2520            convolution_dimension_numbers(),
2521            casted_other.convolution_dimension_numbers()))) {
2522     return false;
2523   }
2524   if (feature_group_count_ != casted_other.feature_group_count_) {
2525     return false;
2526   }
2527   if (batch_group_count_ != casted_other.batch_group_count_) {
2528     return false;
2529   }
2530 
2531   if (padding_type_ != casted_other.padding_type()) {
2532     return false;
2533   }
2534 
2535   if (layout_constrained() != casted_other.layout_constrained()) {
2536     return false;
2537   }
2538   if (layout_constrained()) {
2539     for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2540       if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2541                             casted_other.operand_shapes_with_layout_[i])) {
2542         return false;
2543       }
2544     }
2545   }
2546   if (custom_call_has_side_effect_ !=
2547       casted_other.custom_call_has_side_effect()) {
2548     return false;
2549   }
2550   if (output_to_operand_aliasing_ !=
2551       casted_other.output_to_operand_aliasing()) {
2552     return false;
2553   }
2554   if (!protobuf_util::ProtobufEquals(precision_config(),
2555                                      casted_other.precision_config())) {
2556     return false;
2557   }
2558 
2559   if (called_computations().size() != other.called_computations().size()) {
2560     return false;
2561   }
2562   for (int64 i = 0; i < called_computations().size(); ++i) {
2563     if (!eq_computations(called_computations()[i],
2564                          other.called_computations()[i])) {
2565       return false;
2566     }
2567   }
2568   if (HasLiteral() == casted_other.HasLiteral()) {
2569     if (HasLiteral() && literal() == casted_other.literal()) {
2570       return false;
2571     }
2572   } else {
2573     return true;
2574   }
2575 
2576   // Note: backend_config comparison is done in Identical, which is the
2577   // intended/exposed way to compare computations, and so not repeated here.
2578   return custom_call_target_ == casted_other.custom_call_target_;
2579 }
2580 
2581 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2582 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2583     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2584     HloCloneContext* context) const {
2585   auto cloned = absl::make_unique<HloCustomCallInstruction>(
2586       shape, new_operands, custom_call_target(), opaque());
2587   if (layout_constrained()) {
2588     cloned->layout_constrained_ = true;
2589     cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2590   }
2591   if (window_ != nullptr) {
2592     cloned->set_window(*window_);
2593   }
2594   if (convolution_dimension_numbers_ != nullptr) {
2595     cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2596   }
2597   if (HasLiteral()) {
2598     cloned->set_literal(literal().Clone());
2599   }
2600   cloned->set_feature_group_count(feature_group_count_);
2601   cloned->set_batch_group_count(batch_group_count_);
2602   cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2603   cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
2604   cloned->set_padding_type(padding_type_);
2605   *cloned->mutable_precision_config() = precision_config();
2606   return std::move(cloned);
2607 }
2608 
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2609 HloPadInstruction::HloPadInstruction(const Shape& shape,
2610                                      HloInstruction* operand,
2611                                      HloInstruction* padding_value,
2612                                      const PaddingConfig& padding_config)
2613     : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2614   AppendOperand(operand);
2615   AppendOperand(padding_value);
2616 }
2617 
ToProto() const2618 HloInstructionProto HloPadInstruction::ToProto() const {
2619   HloInstructionProto proto = HloInstruction::ToProto();
2620   *proto.mutable_padding_config() = padding_config_;
2621   return proto;
2622 }
2623 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2624 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2625     const HloPrintOptions& options) const {
2626   return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2627 }
2628 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2629 bool HloPadInstruction::IdenticalSlowPath(
2630     const HloInstruction& other,
2631     const std::function<bool(const HloComputation*, const HloComputation*)>&
2632         eq_computations) const {
2633   const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2634   return protobuf_util::ProtobufEquals(padding_config(),
2635                                        casted_other.padding_config());
2636 }
2637 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2638 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2639     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2640     HloCloneContext* context) const {
2641   CHECK_EQ(new_operands.size(), 2);
2642   return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2643                                               new_operands[1], padding_config_);
2644 }
2645 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2646 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2647     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2648     absl::Span<const int64> slice_sizes)
2649     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2650       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2651   AppendOperand(operand);
2652   AppendOperand(start_indices);
2653 }
2654 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2655 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2656     const Shape& shape, HloInstruction* operand,
2657     absl::Span<HloInstruction* const> start_indices,
2658     absl::Span<const int64> slice_sizes)
2659     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2660       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2661   AppendOperand(operand);
2662   for (HloInstruction* index : start_indices) {
2663     AppendOperand(index);
2664   }
2665 }
2666 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2667 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2668     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2669     HloInstruction* start_indices)
2670     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2671   AppendOperand(operand);
2672   AppendOperand(update);
2673   AppendOperand(start_indices);
2674 }
2675 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2676 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2677     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2678     absl::Span<HloInstruction* const> start_indices)
2679     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2680   AppendOperand(operand);
2681   AppendOperand(update);
2682   for (HloInstruction* index : start_indices) {
2683     AppendOperand(index);
2684   }
2685 }
2686 
ToProto() const2687 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2688   HloInstructionProto proto = HloInstruction::ToProto();
2689   for (int64 slice_size : dynamic_slice_sizes_) {
2690     proto.add_dynamic_slice_sizes(slice_size);
2691   }
2692   return proto;
2693 }
2694 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2695 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2696     const HloPrintOptions& options) const {
2697   return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2698                  "}")};
2699 }
2700 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2701 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2702     const HloInstruction& other,
2703     const std::function<bool(const HloComputation*, const HloComputation*)>&
2704         eq_computations) const {
2705   const auto& casted_other = static_cast<const HloMapInstruction&>(other);
2706   return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
2707 }
2708 
2709 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2710 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2711     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2712     HloCloneContext* context) const {
2713   if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2714     // TODO(b/118437727): Old form, remove this path.
2715     return absl::make_unique<HloDynamicSliceInstruction>(
2716         shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2717   } else {
2718     return absl::make_unique<HloDynamicSliceInstruction>(
2719         shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2720   }
2721 }
2722 
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)2723 HloGatherInstruction::HloGatherInstruction(
2724     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2725     const GatherDimensionNumbers& gather_dim_numbers,
2726     absl::Span<const int64> slice_sizes, bool indices_are_sorted)
2727     : HloInstruction(HloOpcode::kGather, shape),
2728       indices_are_sorted_(indices_are_sorted) {
2729   AppendOperand(operand);
2730   AppendOperand(start_indices);
2731   gather_dimension_numbers_ =
2732       absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2733   absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2734 }
2735 
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)2736 /*static*/ string HloGatherInstruction::GatherDimensionNumbersToString(
2737     const GatherDimensionNumbers& gather_dimension_numbers) {
2738   string offset_dims =
2739       StrCat("offset_dims={",
2740              StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
2741   string collapsed_slice_dims = StrCat(
2742       "collapsed_slice_dims={",
2743       StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
2744   string start_index_map =
2745       StrCat("start_index_map={",
2746              StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
2747   string index_vector_dim =
2748       StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
2749 
2750   return StrJoin<std::initializer_list<string>>(
2751       {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2752       ", ");
2753 }
2754 
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64 index_vector_dim)2755 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2756     absl::Span<const int64> offset_dims,
2757     absl::Span<const int64> collapsed_slice_dims,
2758     absl::Span<const int64> start_index_map, int64 index_vector_dim) {
2759   GatherDimensionNumbers gather_dim_numbers;
2760   for (int64 output_window_dim : offset_dims) {
2761     gather_dim_numbers.add_offset_dims(output_window_dim);
2762   }
2763   for (int64 elided_window_dim : collapsed_slice_dims) {
2764     gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2765   }
2766   for (int64 gather_dim_to_input_dim : start_index_map) {
2767     gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2768   }
2769 
2770   gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2771   return gather_dim_numbers;
2772 }
2773 
ToProto() const2774 HloInstructionProto HloGatherInstruction::ToProto() const {
2775   HloInstructionProto proto = HloInstruction::ToProto();
2776   *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2777   for (int64 bound : gather_slice_sizes()) {
2778     proto.add_gather_slice_sizes(bound);
2779   }
2780   proto.set_indices_are_sorted(indices_are_sorted());
2781   return proto;
2782 }
2783 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2784 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2785     const HloPrintOptions& options) const {
2786   std::vector<string> attrs{
2787       GatherDimensionNumbersToString(gather_dimension_numbers()),
2788       StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2789   if (indices_are_sorted()) {
2790     attrs.push_back("indices_are_sorted=true");
2791   }
2792   return attrs;
2793 }
2794 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2795 bool HloGatherInstruction::IdenticalSlowPath(
2796     const HloInstruction& other,
2797     const std::function<bool(const HloComputation*, const HloComputation*)>&
2798         eq_computations) const {
2799   const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2800   return protobuf_util::ProtobufEquals(
2801              gather_dimension_numbers(),
2802              casted_other.gather_dimension_numbers()) &&
2803          gather_slice_sizes() == casted_other.gather_slice_sizes() &&
2804          indices_are_sorted() == casted_other.indices_are_sorted();
2805 }
2806 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2807 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2808     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2809     HloCloneContext* context) const {
2810   CHECK_EQ(new_operands.size(), 2);
2811   return absl::make_unique<HloGatherInstruction>(
2812       shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2813       gather_slice_sizes(), indices_are_sorted());
2814 }
2815 
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)2816 HloScatterInstruction::HloScatterInstruction(
2817     const Shape& shape, HloInstruction* operand,
2818     HloInstruction* scatter_indices, HloInstruction* updates,
2819     HloComputation* update_computation,
2820     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
2821     bool unique_indices)
2822     : HloInstruction(HloOpcode::kScatter, shape),
2823       indices_are_sorted_(indices_are_sorted),
2824       unique_indices_(unique_indices) {
2825   AppendOperand(operand);
2826   AppendOperand(scatter_indices);
2827   AppendOperand(updates);
2828   AppendComputation(update_computation);
2829   scatter_dimension_numbers_ =
2830       absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2831 }
2832 
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)2833 /*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString(
2834     const ScatterDimensionNumbers& scatter_dimension_numbers) {
2835   string update_window_dims =
2836       StrCat("update_window_dims={",
2837              StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
2838   string inserted_window_dims = StrCat(
2839       "inserted_window_dims={",
2840       StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
2841   string scatter_dims_to_operand_dims = StrCat(
2842       "scatter_dims_to_operand_dims={",
2843       StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
2844       "}");
2845   string index_vector_dim =
2846       StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
2847 
2848   return StrJoin<std::initializer_list<string>>(
2849       {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
2850        index_vector_dim},
2851       ", ");
2852 }
2853 
2854 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64 index_vector_dim)2855 HloScatterInstruction::MakeScatterDimNumbers(
2856     absl::Span<const int64> update_window_dims,
2857     absl::Span<const int64> inserted_window_dims,
2858     absl::Span<const int64> scatter_dims_to_operand_dims,
2859     int64 index_vector_dim) {
2860   ScatterDimensionNumbers scatter_dim_numbers;
2861   for (int64 update_window_dim : update_window_dims) {
2862     scatter_dim_numbers.add_update_window_dims(update_window_dim);
2863   }
2864   for (int64 inserted_window_dim : inserted_window_dims) {
2865     scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
2866   }
2867   for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
2868     scatter_dim_numbers.add_scatter_dims_to_operand_dims(
2869         scatter_dim_to_operand_dim);
2870   }
2871   scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
2872   return scatter_dim_numbers;
2873 }
2874 
ToProto() const2875 HloInstructionProto HloScatterInstruction::ToProto() const {
2876   HloInstructionProto proto = HloInstruction::ToProto();
2877   *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
2878   proto.set_indices_are_sorted(indices_are_sorted());
2879   proto.set_unique_indices(unique_indices());
2880   return proto;
2881 }
2882 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2883 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
2884     const HloPrintOptions& options) const {
2885   std::vector<string> attrs{
2886       ScatterDimensionNumbersToString(scatter_dimension_numbers())};
2887   if (indices_are_sorted()) {
2888     attrs.push_back("indices_are_sorted=true");
2889   }
2890   if (unique_indices()) {
2891     attrs.push_back("unique_indices=true");
2892   }
2893   return attrs;
2894 }
2895 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2896 bool HloScatterInstruction::IdenticalSlowPath(
2897     const HloInstruction& other,
2898     const std::function<bool(const HloComputation*, const HloComputation*)>&
2899         eq_computations) const {
2900   const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
2901   return protobuf_util::ProtobufEquals(
2902              scatter_dimension_numbers(),
2903              casted_other.scatter_dimension_numbers()) &&
2904          eq_computations(to_apply(), casted_other.to_apply()) &&
2905          indices_are_sorted() == casted_other.indices_are_sorted() &&
2906          unique_indices() == casted_other.unique_indices();
2907 }
2908 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2909 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
2910     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2911     HloCloneContext* context) const {
2912   CHECK_EQ(new_operands.size(), 3);
2913   return absl::make_unique<HloScatterInstruction>(
2914       shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
2915       scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
2916 }
2917 
HloIotaInstruction(const Shape & shape,int64 iota_dimension)2918 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
2919     : HloInstruction(HloOpcode::kIota, shape),
2920       iota_dimension_(iota_dimension) {}
2921 
ToProto() const2922 HloInstructionProto HloIotaInstruction::ToProto() const {
2923   HloInstructionProto proto = HloInstruction::ToProto();
2924   proto.add_dimensions(iota_dimension());
2925   return proto;
2926 }
2927 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2928 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
2929     const HloPrintOptions& options) const {
2930   return {StrCat("iota_dimension=", iota_dimension())};
2931 }
2932 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2933 bool HloIotaInstruction::IdenticalSlowPath(
2934     const HloInstruction& other,
2935     const std::function<bool(const HloComputation*, const HloComputation*)>&
2936         eq_computations) const {
2937   const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
2938   return iota_dimension() == casted_other.iota_dimension();
2939 }
2940 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2941 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
2942     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2943     HloCloneContext* context) const {
2944   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
2945 }
2946 
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2947 HloDotInstruction::HloDotInstruction(
2948     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2949     const DotDimensionNumbers& dimension_numbers,
2950     const PrecisionConfig& precision_config)
2951     : HloInstruction(HloOpcode::kDot, shape),
2952       dot_dimension_numbers_(dimension_numbers),
2953       precision_config_(precision_config) {
2954   AppendOperand(lhs);
2955   AppendOperand(rhs);
2956 }
2957 
ToProto() const2958 HloInstructionProto HloDotInstruction::ToProto() const {
2959   HloInstructionProto proto = HloInstruction::ToProto();
2960   *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
2961   *proto.mutable_precision_config() = precision_config_;
2962   return proto;
2963 }
2964 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2965 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
2966     const HloPrintOptions& options) const {
2967   std::vector<string> extra = {DotDimensionNumbersToString()};
2968 
2969   string precision_config_string = PrecisionConfigToString(precision_config_);
2970   if (!precision_config_string.empty()) {
2971     extra.push_back(precision_config_string);
2972   }
2973   return extra;
2974 }
2975 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2976 bool HloDotInstruction::IdenticalSlowPath(
2977     const HloInstruction& other,
2978     const std::function<bool(const HloComputation*, const HloComputation*)>&
2979         eq_computations) const {
2980   const auto& casted_other = static_cast<const HloDotInstruction&>(other);
2981   return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
2982                                        casted_other.dot_dimension_numbers()) &&
2983          protobuf_util::ProtobufEquals(precision_config(),
2984                                        casted_other.precision_config());
2985 }
2986 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2987 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
2988     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2989     HloCloneContext* context) const {
2990   CHECK_EQ(new_operands.size(), 2);
2991   return absl::make_unique<HloDotInstruction>(
2992       shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
2993       precision_config_);
2994 }
2995 
DotDimensionNumbersToString() const2996 string HloDotInstruction::DotDimensionNumbersToString() const {
2997   std::vector<string> result;
2998   const DotDimensionNumbers& dnums = dot_dimension_numbers_;
2999   if (!dnums.lhs_batch_dimensions().empty()) {
3000     result.push_back(StrCat("lhs_batch_dims={",
3001                             StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
3002   }
3003   result.push_back(StrCat("lhs_contracting_dims={",
3004                           StrJoin(dnums.lhs_contracting_dimensions(), ","),
3005                           "}"));
3006 
3007   if (!dnums.rhs_batch_dimensions().empty()) {
3008     result.push_back(StrCat("rhs_batch_dims={",
3009                             StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
3010   }
3011   result.push_back(StrCat("rhs_contracting_dims={",
3012                           StrJoin(dnums.rhs_contracting_dimensions(), ","),
3013                           "}"));
3014 
3015   return StrJoin(result, ", ");
3016 }
3017 
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)3018 HloDomainInstruction::HloDomainInstruction(
3019     const Shape& shape, HloInstruction* operand,
3020     std::unique_ptr<DomainMetadata> operand_side_metadata,
3021     std::unique_ptr<DomainMetadata> user_side_metadata)
3022     : HloInstruction(HloOpcode::kDomain, shape),
3023       operand_side_metadata_(std::move(operand_side_metadata)),
3024       user_side_metadata_(std::move(user_side_metadata)) {
3025   AppendOperand(operand);
3026 }
3027 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3028 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
3029     const HloPrintOptions& options) const {
3030   if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
3031     return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
3032                    "\", entry=", user_side_metadata_->ToString(),
3033                    ", exit=", operand_side_metadata_->ToString(), "}")};
3034   }
3035   return {};
3036 }
3037 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3038 bool HloDomainInstruction::IdenticalSlowPath(
3039     const HloInstruction& other,
3040     const std::function<bool(const HloComputation*, const HloComputation*)>&
3041         eq_computations) const {
3042   const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
3043   return operand_side_metadata().Matches(
3044              casted_other.operand_side_metadata()) &&
3045          user_side_metadata().Matches(casted_other.user_side_metadata());
3046 }
3047 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3048 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
3049     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3050     HloCloneContext* context) const {
3051   CHECK_EQ(new_operands.size(), 1);
3052   return absl::make_unique<HloDomainInstruction>(
3053       shape, new_operands[0], operand_side_metadata_->Clone(),
3054       user_side_metadata_->Clone());
3055 }
3056 
ToProto() const3057 HloInstructionProto HloDomainInstruction::ToProto() const {
3058   HloInstructionProto proto = HloInstruction::ToProto();
3059   auto operand_side_sharding =
3060       dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
3061   if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
3062     *proto.mutable_domain_entry_sharding() =
3063         operand_side_sharding->sharding()->ToProto();
3064   }
3065 
3066   auto user_side_sharding =
3067       dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
3068   if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
3069     *proto.mutable_domain_exit_sharding() =
3070         user_side_sharding->sharding()->ToProto();
3071   }
3072 
3073   return proto;
3074 }
3075 
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64 dimension)3076 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
3077     const Shape& shape, HloInstruction* operand, int64 dimension)
3078     : HloInstruction(HloOpcode::kGetDimensionSize, shape),
3079       dimension_(dimension) {
3080   AppendOperand(operand);
3081 }
3082 
ToProto() const3083 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
3084   HloInstructionProto proto = HloInstruction::ToProto();
3085   proto.add_dimensions(dimension());
3086   return proto;
3087 }
3088 
ExtraAttributesToStringImpl(const HloPrintOptions &) const3089 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3090     const HloPrintOptions& /*options*/) const {
3091   return {StrCat("dimensions={", dimension(), "}")};
3092 }
3093 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3094 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
3095     const HloInstruction& other,
3096     const std::function<bool(const HloComputation*, const HloComputation*)>&
3097     /*eq_computations*/) const {
3098   const auto& casted_other =
3099       static_cast<const HloGetDimensionSizeInstruction&>(other);
3100   return dimension() == casted_other.dimension();
3101 }
3102 
3103 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3104 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3105     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3106     HloCloneContext* /*context*/) const {
3107   if (new_operands.size() != 1) {
3108     LOG(FATAL) << "expects 1 operand";
3109   }
3110   return absl::make_unique<HloGetDimensionSizeInstruction>(
3111       shape, new_operands[0], dimension());
3112 }
3113 
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64 dimension)3114 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
3115     const Shape& shape, HloInstruction* operand, HloInstruction* val,
3116     int64 dimension)
3117     : HloInstruction(HloOpcode::kSetDimensionSize, shape),
3118       dimension_(dimension) {
3119   AppendOperand(operand);
3120   AppendOperand(val);
3121 }
3122 
ExtraAttributesToStringImpl(const HloPrintOptions &) const3123 std::vector<string> HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3124     const HloPrintOptions& /*options*/) const {
3125   return {StrCat("dimensions={", dimension(), "}")};
3126 }
3127 
ToProto() const3128 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
3129   HloInstructionProto proto = HloInstruction::ToProto();
3130   proto.add_dimensions(dimension());
3131   return proto;
3132 }
3133 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3134 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
3135     const HloInstruction& other,
3136     const std::function<bool(const HloComputation*, const HloComputation*)>&
3137     /*eq_computations*/) const {
3138   const auto& casted_other =
3139       static_cast<const HloSetDimensionSizeInstruction&>(other);
3140   return dimension() == casted_other.dimension();
3141 }
3142 
3143 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3144 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3145     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3146     HloCloneContext* /*context*/) const {
3147   if (new_operands.size() != 2) {
3148     LOG(FATAL) << "expects 2 operand";
3149   }
3150   return absl::make_unique<HloSetDimensionSizeInstruction>(
3151       shape, new_operands[0], new_operands[1], dimension());
3152 }
3153 
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64 delta)3154 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
3155     const Shape& shape, int64 delta)
3156     : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
3157 
ToProto() const3158 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
3159   HloInstructionProto proto = HloInstruction::ToProto();
3160   proto.set_delta(delta_);
3161   return proto;
3162 }
3163 
3164 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3165 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
3166     const HloPrintOptions& /*options*/) const {
3167   return {StrCat("delta=", delta())};
3168 }
3169 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3170 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
3171     const HloInstruction& other,
3172     const std::function<bool(const HloComputation*, const HloComputation*)>&
3173     /*eq_computations*/) const {
3174   const auto& casted_other =
3175       static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
3176   return delta() == casted_other.delta();
3177 }
3178 
3179 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3180 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
3181     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3182     HloCloneContext* /*context*/) const {
3183   if (!new_operands.empty()) {
3184     LOG(FATAL) << "expects 0 operand";
3185   }
3186   return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
3187 }
3188 
HloRngBitGeneratorInstruction(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)3189 HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
3190     const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
3191     : HloInstruction(HloOpcode::kRngBitGenerator, shape),
3192       algorithm_(algorithm) {
3193   AppendOperand(state);
3194 }
3195 
ToProto() const3196 HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
3197   HloInstructionProto proto = HloInstruction::ToProto();
3198   proto.set_rng_algorithm(algorithm_);
3199   return proto;
3200 }
3201 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3202 std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
3203     const HloPrintOptions& options) const {
3204   return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
3205 }
3206 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3207 bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
3208     const HloInstruction& other,
3209     const std::function<bool(const HloComputation*, const HloComputation*)>&
3210         eq_computations) const {
3211   const auto& casted_other =
3212       static_cast<const HloRngBitGeneratorInstruction&>(other);
3213   return algorithm() == casted_other.algorithm();
3214 }
3215 
3216 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3217 HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
3218     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3219     HloCloneContext* /*context*/) const {
3220   CHECK_EQ(new_operands.size(), 1);
3221   return absl::make_unique<HloRngBitGeneratorInstruction>(
3222       shape, new_operands[0], algorithm());
3223 }
3224 
3225 }  // namespace xla
3226