1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17
18 #include <deque>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
32 #include "tensorflow/compiler/xla/window_util.h"
33 #include "tensorflow/core/platform/protobuf.h"
34
35 namespace xla {
36 namespace {
37
38 using absl::CEscape;
39 using absl::StrAppend;
40 using absl::StrCat;
41 using absl::StrJoin;
42
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)43 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
44 const HloInstruction* operand) {
45 std::vector<int64> operand_indices = instruction->OperandIndices(operand);
46 return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
47 return instruction->IsElementwiseOnOperand(operand_index);
48 });
49 }
50
PrecisionConfigToString(const PrecisionConfig & precision_config)51 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
52 if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
53 return static_cast<PrecisionConfig::Precision>(precision) ==
54 PrecisionConfig::DEFAULT;
55 })) {
56 return "";
57 }
58
59 return StrCat(
60 "operand_precision={",
61 StrJoin(
62 precision_config.operand_precision(), ",",
63 [](string* out, int32 precision) {
64 CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
65 StrAppend(out,
66 PrecisionToString(
67 static_cast<PrecisionConfig::Precision>(precision)));
68 }),
69 "}");
70 }
71 } // namespace
72
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64 feature_index)73 HloBatchNormInstruction::HloBatchNormInstruction(
74 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
75 HloInstruction* scale, float epsilon, int64 feature_index)
76 : HloInstruction(opcode, shape),
77 epsilon_(epsilon),
78 feature_index_(feature_index) {
79 AppendOperand(operand);
80 AppendOperand(scale);
81 }
82
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const83 bool HloBatchNormInstruction::IdenticalSlowPath(
84 const HloInstruction& other,
85 const std::function<bool(const HloComputation*, const HloComputation*)>&
86 eq_computations) const {
87 const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
88 return feature_index() == casted_other.feature_index() &&
89 epsilon() == casted_other.epsilon();
90 }
91
ToProto() const92 HloInstructionProto HloBatchNormInstruction::ToProto() const {
93 HloInstructionProto proto = HloInstruction::ToProto();
94 proto.set_epsilon(epsilon_);
95 proto.set_feature_index(feature_index_);
96 return proto;
97 }
98
ExtraAttributesToStringImpl(const HloPrintOptions & options) const99 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
100 const HloPrintOptions& options) const {
101 return {StrCat("epsilon=", epsilon()),
102 StrCat("feature_index=", feature_index())};
103 }
104
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)105 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
106 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
107 HloInstruction* offset, float epsilon, int64 feature_index)
108 : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
109 scale, epsilon, feature_index) {
110 AppendOperand(offset);
111 }
112
113 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const114 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
115 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
116 HloCloneContext* context) const {
117 CHECK_EQ(new_operands.size(), 3);
118 return absl::make_unique<HloBatchNormTrainingInstruction>(
119 shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
120 feature_index());
121 }
122
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)123 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
124 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
125 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
126 float epsilon, int64 feature_index)
127 : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
128 scale, epsilon, feature_index) {
129 AppendOperand(offset);
130 AppendOperand(mean);
131 AppendOperand(variance);
132 }
133
134 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const135 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
136 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
137 HloCloneContext* context) const {
138 CHECK_EQ(new_operands.size(), 5);
139 return absl::make_unique<HloBatchNormInferenceInstruction>(
140 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
141 new_operands[4], epsilon(), feature_index());
142 }
143
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)144 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
145 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
146 HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
147 float epsilon, int64 feature_index)
148 : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
149 epsilon, feature_index) {
150 AppendOperand(mean);
151 AppendOperand(variance);
152 AppendOperand(grad_output);
153 }
154
155 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const156 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
157 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
158 HloCloneContext* context) const {
159 CHECK_EQ(new_operands.size(), 5);
160 return absl::make_unique<HloBatchNormGradInstruction>(
161 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
162 new_operands[4], epsilon(), feature_index());
163 }
164
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)165 HloFftInstruction::HloFftInstruction(const Shape& shape,
166 HloInstruction* operand, FftType fft_type,
167 absl::Span<const int64> fft_length)
168 : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
169 fft_length_.assign(fft_length.begin(), fft_length.end());
170 AppendOperand(operand);
171 }
172
ToProto() const173 HloInstructionProto HloFftInstruction::ToProto() const {
174 HloInstructionProto proto = HloInstruction::ToProto();
175 proto.set_fft_type(fft_type_);
176 for (int64 fft_len : fft_length_) {
177 proto.add_fft_length(fft_len);
178 }
179 return proto;
180 }
181
ExtraAttributesToStringImpl(const HloPrintOptions & options) const182 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
183 const HloPrintOptions& options) const {
184 return {StrCat("fft_type=", FftType_Name(fft_type())),
185 StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
186 }
187
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const188 bool HloFftInstruction::IdenticalSlowPath(
189 const HloInstruction& other,
190 const std::function<bool(const HloComputation*, const HloComputation*)>&
191 eq_computations) const {
192 const auto& casted_other = static_cast<const HloFftInstruction&>(other);
193 return fft_type() == casted_other.fft_type() &&
194 fft_length() == casted_other.fft_length();
195 }
196
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const197 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
198 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
199 HloCloneContext* context) const {
200 CHECK_EQ(new_operands.size(), 1);
201 return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
202 fft_length_);
203 }
204
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction)205 HloCompareInstruction::HloCompareInstruction(const Shape& shape,
206 HloInstruction* lhs,
207 HloInstruction* rhs,
208 ComparisonDirection direction)
209 : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
210 AppendOperand(lhs);
211 AppendOperand(rhs);
212 }
213
ToProto() const214 HloInstructionProto HloCompareInstruction::ToProto() const {
215 HloInstructionProto proto = HloInstruction::ToProto();
216 proto.set_comparison_direction(ComparisonDirectionToString(direction_));
217 return proto;
218 }
219
ExtraAttributesToStringImpl(const HloPrintOptions & options) const220 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
221 const HloPrintOptions& options) const {
222 return {StrCat("direction=", ComparisonDirectionToString(direction()))};
223 }
224
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const225 bool HloCompareInstruction::IdenticalSlowPath(
226 const HloInstruction& other,
227 const std::function<bool(const HloComputation*, const HloComputation*)>&
228 eq_computations) const {
229 const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
230 return direction() == casted_other.direction();
231 }
232
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const233 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
234 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
235 HloCloneContext* context) const {
236 CHECK_EQ(new_operands.size(), 2);
237 return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
238 new_operands[1], direction());
239 }
240
241 namespace {
242
243 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
244 // of "key=value" attribute strings generically, using protocol buffer
245 // reflection.
246 //
247 // Currently implements a small subset of cases; feel free to add more as
248 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)249 std::vector<string> AttributeProtoToStringVector(
250 const tensorflow::protobuf::Message& message) {
251 const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
252 std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
253 reflection->ListFields(message, &fields);
254
255 std::vector<string> output;
256 for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
257 string s = absl::StrCat(field->name(), "=");
258 CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
259 switch (field->type()) {
260 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
261 bool val = reflection->GetBool(message, field);
262 absl::StrAppend(&s, val ? "true" : "false");
263 break;
264 }
265 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
266 const tensorflow::protobuf::EnumValueDescriptor* evd =
267 reflection->GetEnum(message, field);
268 absl::StrAppend(&s, evd->name());
269 break;
270 }
271 default:
272 LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
273 }
274 output.push_back(std::move(s));
275 }
276 return output;
277 }
278
279 } // namespace
280
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)281 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
282 const Shape& shape, HloInstruction* a, HloInstruction* b,
283 const TriangularSolveOptions& options)
284 : HloInstruction(HloOpcode::kTriangularSolve, shape),
285 triangular_solve_options_(options) {
286 AppendOperand(a);
287 AppendOperand(b);
288 }
289
ToProto() const290 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
291 HloInstructionProto proto = HloInstruction::ToProto();
292 *proto.mutable_triangular_solve_options() = triangular_solve_options_;
293 return proto;
294 }
295
ExtraAttributesToStringImpl(const HloPrintOptions & options) const296 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
297 const HloPrintOptions& options) const {
298 return AttributeProtoToStringVector(triangular_solve_options_);
299 }
300
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const301 bool HloTriangularSolveInstruction::IdenticalSlowPath(
302 const HloInstruction& other,
303 const std::function<bool(const HloComputation*, const HloComputation*)>&
304 eq_computations) const {
305 const auto& casted_other =
306 static_cast<const HloTriangularSolveInstruction&>(other);
307 const auto& options = triangular_solve_options();
308 const auto& other_options = casted_other.triangular_solve_options();
309
310 return options.left_side() == other_options.left_side() &&
311 options.lower() == other_options.lower() &&
312 options.unit_diagonal() == other_options.unit_diagonal() &&
313 options.transpose_a() == other_options.transpose_a();
314 }
315
316 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const317 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
318 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
319 HloCloneContext* context) const {
320 CHECK_EQ(new_operands.size(), 2);
321 return absl::make_unique<HloTriangularSolveInstruction>(
322 shape, new_operands[0], new_operands[1], triangular_solve_options());
323 }
324
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)325 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
326 HloInstruction* a,
327 const CholeskyOptions& options)
328 : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
329 AppendOperand(a);
330 }
331
ToProto() const332 HloInstructionProto HloCholeskyInstruction::ToProto() const {
333 HloInstructionProto proto = HloInstruction::ToProto();
334 *proto.mutable_cholesky_options() = cholesky_options_;
335 return proto;
336 }
337
ExtraAttributesToStringImpl(const HloPrintOptions & options) const338 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
339 const HloPrintOptions& options) const {
340 return AttributeProtoToStringVector(cholesky_options_);
341 }
342
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const343 bool HloCholeskyInstruction::IdenticalSlowPath(
344 const HloInstruction& other,
345 const std::function<bool(const HloComputation*, const HloComputation*)>&
346 eq_computations) const {
347 const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
348 const auto& options = cholesky_options();
349 const auto& other_options = casted_other.cholesky_options();
350
351 return options.lower() == other_options.lower();
352 }
353
354 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const355 HloCholeskyInstruction::CloneWithNewOperandsImpl(
356 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
357 HloCloneContext* context) const {
358 CHECK_EQ(new_operands.size(), 1);
359 return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
360 cholesky_options());
361 }
362
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64 channel_id,bool is_host_transfer)363 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
364 const Shape& shape,
365 int64 channel_id,
366 bool is_host_transfer)
367 : HloInstruction(opcode, shape),
368 channel_id_(channel_id),
369 is_host_transfer_(is_host_transfer) {}
370
ToProto() const371 HloInstructionProto HloSendRecvInstruction::ToProto() const {
372 HloInstructionProto proto = HloInstruction::ToProto();
373 proto.set_channel_id(channel_id_);
374 proto.set_is_host_transfer(is_host_transfer_);
375 return proto;
376 }
377
ExtraAttributesToStringImpl(const HloPrintOptions & options) const378 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
379 const HloPrintOptions& options) const {
380 std::vector<string> attrs;
381 attrs.push_back(StrCat("channel_id=", channel_id_));
382 if (is_host_transfer()) {
383 attrs.push_back("is_host_transfer=true");
384 }
385 return attrs;
386 }
387
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const388 bool HloSendRecvInstruction::IdenticalSlowPath(
389 const HloInstruction& other,
390 const std::function<bool(const HloComputation*, const HloComputation*)>&
391 eq_computations) const {
392 // Not yet supported.
393 return false;
394 }
395
396 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)397 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
398 HloInstruction* token, int64 channel_id,
399 bool is_host_transfer)
400 : HloSendRecvInstruction(
401 HloOpcode::kSend,
402 ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
403 ShapeUtil::MakeShape(U32, {}),
404 ShapeUtil::MakeTokenShape()}),
405 channel_id, is_host_transfer) {
406 AppendOperand(operand);
407 AppendOperand(token);
408 }
409
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const410 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
411 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
412 HloCloneContext* context) const {
413 CHECK_EQ(new_operands.size(), 2);
414 return absl::make_unique<HloSendInstruction>(
415 new_operands[0], new_operands[1], channel_id(), is_host_transfer());
416 }
417
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)418 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
419 bool is_host_transfer)
420 : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
421 CHECK_NOTNULL(operand)->channel_id(),
422 is_host_transfer) {
423 AppendOperand(operand);
424 }
425
426 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const427 HloSendDoneInstruction::CloneWithNewOperandsImpl(
428 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
429 HloCloneContext* context) const {
430 CHECK_EQ(new_operands.size(), 1);
431 return absl::make_unique<HloSendDoneInstruction>(
432 Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
433 }
434
435 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)436 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
437 HloInstruction* token, int64 channel_id,
438 bool is_host_transfer)
439 : HloSendRecvInstruction(
440 HloOpcode::kRecv,
441 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
442 ShapeUtil::MakeTokenShape()}),
443 channel_id, is_host_transfer) {
444 AppendOperand(token);
445 }
446
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const447 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
448 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
449 HloCloneContext* context) const {
450 CHECK_EQ(new_operands.size(), 1);
451 return absl::make_unique<HloRecvInstruction>(
452 ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
453 is_host_transfer());
454 }
455
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)456 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
457 bool is_host_transfer)
458 : HloSendRecvInstruction(
459 HloOpcode::kRecvDone,
460 ShapeUtil::MakeTupleShape(
461 {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
462 ShapeUtil::MakeTokenShape()}),
463 CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
464 AppendOperand(operand);
465 }
466
467 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const468 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
469 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
470 HloCloneContext* context) const {
471 CHECK_EQ(new_operands.size(), 1);
472 return absl::make_unique<HloRecvDoneInstruction>(
473 Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
474 }
475
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups)476 HloCollectiveInstruction::HloCollectiveInstruction(
477 HloOpcode opcode, const Shape& shape,
478 absl::Span<HloInstruction* const> operands,
479 const std::vector<ReplicaGroup>& replica_groups)
480 : HloInstruction(opcode, shape), replica_groups_(replica_groups) {
481 for (auto operand : operands) {
482 AppendOperand(operand);
483 }
484 }
485
ToProto() const486 HloInstructionProto HloCollectiveInstruction::ToProto() const {
487 HloInstructionProto proto = HloInstruction::ToProto();
488 *proto.mutable_replica_groups() = {replica_groups_.begin(),
489 replica_groups_.end()};
490 return proto;
491 }
492
ExtraAttributesToStringImpl(const HloPrintOptions &) const493 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
494 const HloPrintOptions& /*options*/) const {
495 std::vector<string> result;
496 std::vector<string> replica_group_str;
497 for (const ReplicaGroup& group : replica_groups()) {
498 replica_group_str.push_back(
499 StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
500 }
501 result.push_back(
502 StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}"));
503 return result;
504 }
505
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const506 bool HloCollectiveInstruction::IdenticalSlowPath(
507 const HloInstruction& other,
508 const std::function<bool(const HloComputation*, const HloComputation*)>&
509 /*eq_computations*/) const {
510 const auto& casted_other =
511 static_cast<const HloCollectiveInstruction&>(other);
512 return absl::c_equal(replica_groups(), casted_other.replica_groups(),
513 [](const ReplicaGroup& a, const ReplicaGroup& b) {
514 return absl::c_equal(a.replica_ids(), b.replica_ids());
515 });
516 }
517
HloAllReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,absl::string_view barrier,const absl::optional<int64> & all_reduce_id)518 HloAllReduceInstruction::HloAllReduceInstruction(
519 const Shape& shape, absl::Span<HloInstruction* const> operands,
520 HloComputation* reduce_computation,
521 const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
522 const absl::optional<int64>& all_reduce_id)
523 : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
524 replica_groups),
525 all_reduce_barrier_(barrier),
526 all_reduce_id_(all_reduce_id) {
527 AppendComputation(reduce_computation);
528 }
529
set_all_reduce_id(const absl::optional<int64> & all_reduce_id)530 void HloAllReduceInstruction::set_all_reduce_id(
531 const absl::optional<int64>& all_reduce_id) {
532 all_reduce_id_ = all_reduce_id;
533 }
534
ToProto() const535 HloInstructionProto HloAllReduceInstruction::ToProto() const {
536 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
537 // Proto3 is so sad.
538 if (all_reduce_id_) {
539 proto.set_all_reduce_id(*all_reduce_id_);
540 }
541 proto.set_all_reduce_barrier(all_reduce_barrier_);
542 return proto;
543 }
544
IsNoop() const545 bool HloAllReduceInstruction::IsNoop() const {
546 for (auto replica_group : replica_groups()) {
547 if (replica_group.replica_ids().size() != 1) {
548 return false;
549 }
550 }
551 return !all_reduce_id();
552 }
553
ExtraAttributesToStringImpl(const HloPrintOptions & options) const554 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
555 const HloPrintOptions& options) const {
556 std::vector<string> result =
557 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
558 if (!all_reduce_barrier().empty()) {
559 result.push_back(StrCat("barrier=\"", all_reduce_barrier(), "\""));
560 }
561 if (all_reduce_id_) {
562 result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
563 }
564 return result;
565 }
566
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const567 bool HloAllReduceInstruction::IdenticalSlowPath(
568 const HloInstruction& other,
569 const std::function<bool(const HloComputation*, const HloComputation*)>&
570 eq_computations) const {
571 const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
572 return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
573 eq_computations(to_apply(), casted_other.to_apply()) &&
574 all_reduce_barrier() == casted_other.all_reduce_barrier() &&
575 all_reduce_id() == casted_other.all_reduce_id();
576 }
577
578 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const579 HloAllReduceInstruction::CloneWithNewOperandsImpl(
580 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
581 HloCloneContext* /*context*/) const {
582 return absl::make_unique<HloAllReduceInstruction>(
583 shape, new_operands, to_apply(), replica_groups(), all_reduce_barrier(),
584 all_reduce_id());
585 }
586
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups)587 HloAllToAllInstruction::HloAllToAllInstruction(
588 const Shape& shape, absl::Span<HloInstruction* const> operands,
589 const std::vector<ReplicaGroup>& replica_groups)
590 : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
591 replica_groups) {}
592
593 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const594 HloAllToAllInstruction::CloneWithNewOperandsImpl(
595 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
596 HloCloneContext* /*context*/) const {
597 return absl::make_unique<HloAllToAllInstruction>(shape, new_operands,
598 replica_groups());
599 }
600
HloCollectivePermuteInstruction(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs)601 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
602 const Shape& shape, HloInstruction* operand,
603 const std::vector<std::pair<int64, int64>>& source_target_pairs)
604 : HloInstruction(HloOpcode::kCollectivePermute, shape),
605 source_target_pairs_(source_target_pairs) {
606 AppendOperand(operand);
607 }
608
ToProto() const609 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
610 HloInstructionProto proto = HloInstruction::ToProto();
611 for (const auto& pair : source_target_pairs()) {
612 auto* proto_pair = proto.add_source_target_pairs();
613 proto_pair->set_source(pair.first);
614 proto_pair->set_target(pair.second);
615 }
616 return proto;
617 }
618
619 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const620 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
621 const HloPrintOptions& /*options*/) const {
622 std::vector<string> result;
623 std::vector<string> strs;
624 for (const auto& pair : source_target_pairs()) {
625 strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
626 }
627 result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
628 return result;
629 }
630
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const631 bool HloCollectivePermuteInstruction::IdenticalSlowPath(
632 const HloInstruction& other,
633 const std::function<bool(const HloComputation*, const HloComputation*)>&
634 /*eq_computations*/) const {
635 const auto& casted_other =
636 static_cast<const HloCollectivePermuteInstruction&>(other);
637 return absl::c_equal(source_target_pairs(),
638 casted_other.source_target_pairs(),
639 [](const std::pair<int64, int64>& a,
640 const std::pair<int64, int64>& b) { return a == b; });
641 }
642
643 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const644 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
645 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
646 HloCloneContext* /*context*/) const {
647 return absl::make_unique<HloCollectivePermuteInstruction>(
648 shape, new_operands[0], source_target_pairs());
649 }
650
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)651 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
652 HloInstruction* operand,
653 absl::Span<const int64> dimensions)
654 : HloInstruction(HloOpcode::kReverse, shape),
655 dimensions_(dimensions.begin(), dimensions.end()) {
656 AppendOperand(operand);
657 }
658
ToProto() const659 HloInstructionProto HloReverseInstruction::ToProto() const {
660 HloInstructionProto proto = HloInstruction::ToProto();
661 for (int64 dimension : dimensions_) {
662 proto.add_dimensions(dimension);
663 }
664 return proto;
665 }
666
ExtraAttributesToStringImpl(const HloPrintOptions & options) const667 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
668 const HloPrintOptions& options) const {
669 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
670 }
671
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const672 bool HloReverseInstruction::IdenticalSlowPath(
673 const HloInstruction& other,
674 const std::function<bool(const HloComputation*, const HloComputation*)>&
675 eq_computations) const {
676 const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
677 return dimensions() == casted_other.dimensions();
678 }
679
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const680 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
681 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
682 HloCloneContext* context) const {
683 CHECK_EQ(new_operands.size(), 1);
684 return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
685 dimensions());
686 }
687
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)688 HloConcatenateInstruction::HloConcatenateInstruction(
689 const Shape& shape, absl::Span<HloInstruction* const> operands,
690 int64 dimension)
691 : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
692 for (auto operand : operands) {
693 AppendOperand(operand);
694 }
695 }
696
ToProto() const697 HloInstructionProto HloConcatenateInstruction::ToProto() const {
698 HloInstructionProto proto = HloInstruction::ToProto();
699 for (int64 dimension : dimensions_) {
700 proto.add_dimensions(dimension);
701 }
702 return proto;
703 }
704
ExtraAttributesToStringImpl(const HloPrintOptions & options) const705 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
706 const HloPrintOptions& options) const {
707 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
708 }
709
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const710 bool HloConcatenateInstruction::IdenticalSlowPath(
711 const HloInstruction& other,
712 const std::function<bool(const HloComputation*, const HloComputation*)>&
713 eq_computations) const {
714 const auto& casted_other =
715 static_cast<const HloConcatenateInstruction&>(other);
716 return dimensions() == casted_other.dimensions();
717 }
718
719 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const720 HloConcatenateInstruction::CloneWithNewOperandsImpl(
721 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
722 HloCloneContext* context) const {
723 return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
724 dimensions(0));
725 }
726
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)727 HloReduceInstruction::HloReduceInstruction(
728 const Shape& shape, absl::Span<HloInstruction* const> args,
729 absl::Span<const int64> dimensions_to_reduce,
730 HloComputation* reduce_computation)
731 : HloInstruction(HloOpcode::kReduce, shape),
732 dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
733 for (HloInstruction* arg : args) {
734 AppendOperand(arg);
735 }
736 AppendComputation(reduce_computation);
737 }
738
ToProto() const739 HloInstructionProto HloReduceInstruction::ToProto() const {
740 HloInstructionProto proto = HloInstruction::ToProto();
741 for (int64 dimension : dimensions_) {
742 proto.add_dimensions(dimension);
743 }
744 return proto;
745 }
746
ExtraAttributesToStringImpl(const HloPrintOptions & options) const747 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
748 const HloPrintOptions& options) const {
749 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
750 }
751
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const752 bool HloReduceInstruction::IdenticalSlowPath(
753 const HloInstruction& other,
754 const std::function<bool(const HloComputation*, const HloComputation*)>&
755 eq_computations) const {
756 const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
757 // Reduction results are determined by the reduction dimension and the
758 // reduction computation.
759 return dimensions() == casted_other.dimensions() &&
760 eq_computations(to_apply(), casted_other.to_apply());
761 }
762
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const763 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
764 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
765 HloCloneContext* context) const {
766 CHECK_EQ(new_operands.size() % 2, 0);
767 return absl::make_unique<HloReduceInstruction>(shape, new_operands,
768 dimensions(), to_apply());
769 }
770
HloSortInstruction(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)771 HloSortInstruction::HloSortInstruction(
772 const Shape& shape, int64 dimension,
773 absl::Span<HloInstruction* const> operands, HloComputation* compare,
774 bool is_stable)
775 : HloInstruction(HloOpcode::kSort, shape),
776 dimensions_({dimension}),
777 is_stable_(is_stable) {
778 for (auto* value : operands) {
779 AppendOperand(value);
780 }
781 AppendComputation(compare);
782 }
783
ToProto() const784 HloInstructionProto HloSortInstruction::ToProto() const {
785 HloInstructionProto proto = HloInstruction::ToProto();
786 for (int64 dimension : dimensions_) {
787 proto.add_dimensions(dimension);
788 }
789 proto.set_is_stable(is_stable());
790 return proto;
791 }
792
ExtraAttributesToStringImpl(const HloPrintOptions & options) const793 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
794 const HloPrintOptions& options) const {
795 std::vector<string> attrs;
796 attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
797 if (is_stable()) {
798 attrs.push_back("is_stable=true");
799 }
800 return attrs;
801 }
802
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const803 bool HloSortInstruction::IdenticalSlowPath(
804 const HloInstruction& other,
805 const std::function<bool(const HloComputation*, const HloComputation*)>&
806 eq_computations) const {
807 const auto& casted_other = static_cast<const HloSortInstruction&>(other);
808 if (dimensions() != casted_other.dimensions()) {
809 return false;
810 }
811 if (is_stable() != casted_other.is_stable()) {
812 return false;
813 }
814 return eq_computations(to_apply(), other.to_apply());
815 }
816
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const817 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
818 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
819 HloCloneContext* context) const {
820 return absl::make_unique<HloSortInstruction>(
821 shape, dimensions(0), new_operands, to_apply(), is_stable());
822 }
823
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)824 HloTransposeInstruction::HloTransposeInstruction(
825 const Shape& shape, HloInstruction* operand,
826 absl::Span<const int64> dimensions)
827 : HloInstruction(HloOpcode::kTranspose, shape),
828 dimensions_(dimensions.begin(), dimensions.end()) {
829 AppendOperand(operand);
830 }
831
IsRank2Transpose() const832 bool HloTransposeInstruction::IsRank2Transpose() const {
833 return dimensions() == std::vector<int64>({1, 0}) &&
834 shape().dimensions_size() == 2 &&
835 std::equal(shape().dimensions().begin(), shape().dimensions().end(),
836 operand(0)->shape().dimensions().rbegin());
837 }
838
ToProto() const839 HloInstructionProto HloTransposeInstruction::ToProto() const {
840 HloInstructionProto proto = HloInstruction::ToProto();
841 for (int64 dimension : dimensions_) {
842 proto.add_dimensions(dimension);
843 }
844 return proto;
845 }
846
ExtraAttributesToStringImpl(const HloPrintOptions & options) const847 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
848 const HloPrintOptions& options) const {
849 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
850 }
851
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const852 bool HloTransposeInstruction::IdenticalSlowPath(
853 const HloInstruction& other,
854 const std::function<bool(const HloComputation*, const HloComputation*)>&
855 eq_computations) const {
856 const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
857 return dimensions() == casted_other.dimensions();
858 }
859
860 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const861 HloTransposeInstruction::CloneWithNewOperandsImpl(
862 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
863 HloCloneContext* context) const {
864 CHECK_EQ(new_operands.size(), 1);
865 return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
866 dimensions());
867 }
868
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)869 HloBroadcastInstruction::HloBroadcastInstruction(
870 const Shape& shape, HloInstruction* operand,
871 absl::Span<const int64> broadcast_dimension)
872 : HloInstruction(HloOpcode::kBroadcast, shape),
873 dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
874 AppendOperand(operand);
875 }
876
ToProto() const877 HloInstructionProto HloBroadcastInstruction::ToProto() const {
878 HloInstructionProto proto = HloInstruction::ToProto();
879 for (int64 dimension : dimensions_) {
880 proto.add_dimensions(dimension);
881 }
882 return proto;
883 }
884
ExtraAttributesToStringImpl(const HloPrintOptions & options) const885 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
886 const HloPrintOptions& options) const {
887 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
888 }
889
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const890 bool HloBroadcastInstruction::IdenticalSlowPath(
891 const HloInstruction& other,
892 const std::function<bool(const HloComputation*, const HloComputation*)>&
893 eq_computations) const {
894 const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
895 return dimensions() == casted_other.dimensions();
896 }
897
898 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const899 HloBroadcastInstruction::CloneWithNewOperandsImpl(
900 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
901 HloCloneContext* context) const {
902 CHECK_EQ(new_operands.size(), 1);
903 return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
904 dimensions());
905 }
906
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)907 HloMapInstruction::HloMapInstruction(const Shape& shape,
908 absl::Span<HloInstruction* const> operands,
909 HloComputation* map_computation)
910 : HloInstruction(HloOpcode::kMap, shape) {
911 for (auto operand : operands) {
912 AppendOperand(operand);
913 }
914 AppendComputation(map_computation);
915 // TODO(b/65689298) Remove code below once Map is generalized to accept
916 // arbitrary map dimensions.
917 dimensions_.resize(shape.rank());
918 std::iota(dimensions_.begin(), dimensions_.end(), 0);
919 }
920
ToProto() const921 HloInstructionProto HloMapInstruction::ToProto() const {
922 HloInstructionProto proto = HloInstruction::ToProto();
923 for (int64 dimension : dimensions_) {
924 proto.add_dimensions(dimension);
925 }
926 return proto;
927 }
928
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const929 bool HloMapInstruction::IsElementwiseImpl(
930 const absl::optional<int64>& operand_idx) const {
931 if (!dimensions().empty()) {
932 // Check that the map is executed in elementwise compatible dimensions.
933 if (dimensions().size() != shape().dimensions_size()) {
934 return false;
935 }
936 for (int i = 0; i < dimensions().size(); ++i) {
937 if (dimensions()[i] != i) {
938 return false;
939 }
940 }
941 }
942 return true;
943 }
944
ExtraAttributesToStringImpl(const HloPrintOptions & options) const945 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
946 const HloPrintOptions& options) const {
947 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
948 }
949
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const950 bool HloMapInstruction::IdenticalSlowPath(
951 const HloInstruction& other,
952 const std::function<bool(const HloComputation*, const HloComputation*)>&
953 eq_computations) const {
954 return eq_computations(to_apply(), other.to_apply());
955 }
956
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const957 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
958 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
959 HloCloneContext* context) const {
960 return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
961 }
962
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)963 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
964 HloInstruction* operand,
965 absl::Span<const int64> start_indices,
966 absl::Span<const int64> limit_indices,
967 absl::Span<const int64> strides)
968 : HloInstruction(HloOpcode::kSlice, shape),
969 slice_starts_(start_indices.begin(), start_indices.end()),
970 slice_limits_(limit_indices.begin(), limit_indices.end()),
971 slice_strides_(strides.begin(), strides.end()) {
972 AppendOperand(operand);
973 // For backward compatibility with old serialized computations: if there are
974 // no strides, assume all strides are 1.
975 // TODO(b/63317920): remove this code.
976 if (slice_strides_.empty()) {
977 slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
978 }
979 }
980
ToProto() const981 HloInstructionProto HloSliceInstruction::ToProto() const {
982 HloInstructionProto proto = HloInstruction::ToProto();
983 for (int i = 0; i < slice_starts_.size(); ++i) {
984 auto* slice_dimension = proto.add_slice_dimensions();
985 slice_dimension->set_start(slice_starts_[i]);
986 slice_dimension->set_limit(slice_limits_[i]);
987 slice_dimension->set_stride(slice_strides_[i]);
988 }
989 return proto;
990 }
991
ExtraAttributesToStringImpl(const HloPrintOptions & options) const992 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
993 const HloPrintOptions& options) const {
994 std::vector<string> bounds;
995 bounds.reserve(slice_starts_.size());
996 const bool omit_stride =
997 absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
998 for (int i = 0; i < slice_starts_.size(); ++i) {
999 string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1000 bounds.push_back(
1001 StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1002 }
1003 return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1004 }
1005
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1006 bool HloSliceInstruction::IdenticalSlowPath(
1007 const HloInstruction& other,
1008 const std::function<bool(const HloComputation*, const HloComputation*)>&
1009 eq_computations) const {
1010 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1011 return slice_starts_ == other_slice.slice_starts_ &&
1012 slice_limits_ == other_slice.slice_limits_ &&
1013 slice_strides_ == other_slice.slice_strides_;
1014 }
1015
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1016 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1017 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1018 HloCloneContext* context) const {
1019 CHECK_EQ(new_operands.size(), 1);
1020 return absl::make_unique<HloSliceInstruction>(
1021 shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1022 }
1023
HloConstantInstruction(Literal literal)1024 HloConstantInstruction::HloConstantInstruction(Literal literal)
1025 : HloInstruction(HloOpcode::kConstant, literal.shape()),
1026 literal_(std::move(literal)) {}
1027
HloConstantInstruction(const Shape & shape)1028 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1029 : HloInstruction(HloOpcode::kConstant, shape) {}
1030
ToProto() const1031 HloInstructionProto HloConstantInstruction::ToProto() const {
1032 HloInstructionProto proto = HloInstruction::ToProto();
1033 if (literal_.has_value()) {
1034 *proto.mutable_literal() = literal_->ToProto();
1035 }
1036 return proto;
1037 }
1038
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1039 bool HloConstantInstruction::IsElementwiseImpl(
1040 const absl::optional<int64>& operand_idx) const {
1041 return true;
1042 }
1043
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1044 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1045 const ShapeIndex& shape_index) {
1046 Shape* mutable_array_subshape =
1047 ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1048 CHECK(mutable_array_subshape->IsArray());
1049
1050 // Normally array_subshape will always have a layout, but this invariant is
1051 // temporarily broken in LayoutAssignment::AssignLayouts.
1052
1053 if (!mutable_array_subshape->has_layout() ||
1054 !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1055 *literal_ = literal_->Relayout(new_layout, shape_index);
1056 *mutable_array_subshape->mutable_layout() = new_layout;
1057 }
1058 }
1059
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1060 bool HloConstantInstruction::IdenticalSlowPath(
1061 const HloInstruction& other,
1062 const std::function<bool(const HloComputation*, const HloComputation*)>&
1063 eq_computations) const {
1064 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1065 return literal() == other_slice.literal();
1066 }
1067
1068 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1069 HloConstantInstruction::CloneWithNewOperandsImpl(
1070 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1071 HloCloneContext* context) const {
1072 CHECK(literal_.has_value());
1073 return absl::make_unique<HloConstantInstruction>(literal_->Clone());
1074 }
1075
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1076 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1077 const HloPrintOptions& options,
1078 CanonicalNameMap* canonical_name_map) const {
1079 string operands;
1080 // For constants, show the actual value in place of an empty operand list.
1081 if (literal_.has_value() &&
1082 ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1083 options.print_large_constants())) {
1084 // Literal::ToString emits multidimensional arrays over multiple
1085 // lines. Compact this into one line by stripping out white space.
1086 string tmp = literal().ToStringWithoutShape();
1087 std::replace(tmp.begin(), tmp.end(), '\n', ' ');
1088 std::vector<string> v = absl::StrSplit(tmp, ' ');
1089 bool first = true;
1090 // Concatenate elements in "v" with spaces separating them, but ignoring
1091 // empty entries.
1092 for (const auto& s : v) {
1093 if (s.empty()) {
1094 continue;
1095 }
1096 StrAppend(&operands, (first ? "" : " "), s);
1097 first = false;
1098 }
1099 } else {
1100 // Do not show large constants or tuples.
1101 operands = "{...}";
1102 }
1103 return operands;
1104 }
1105
HloTraceInstruction(const string & tag,HloInstruction * operand)1106 HloTraceInstruction::HloTraceInstruction(const string& tag,
1107 HloInstruction* operand)
1108 : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1109 literal_(LiteralUtil::CreateR1U8(tag)) {
1110 AppendOperand(operand);
1111 operand->set_tracing(this);
1112 }
1113
ToProto() const1114 HloInstructionProto HloTraceInstruction::ToProto() const {
1115 HloInstructionProto proto = HloInstruction::ToProto();
1116 *proto.mutable_literal() = literal_.ToProto();
1117 return proto;
1118 }
1119
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1120 bool HloTraceInstruction::IdenticalSlowPath(
1121 const HloInstruction& other,
1122 const std::function<bool(const HloComputation*, const HloComputation*)>&
1123 eq_computations) const {
1124 return false;
1125 }
1126
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1127 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1128 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1129 HloCloneContext* context) const {
1130 LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1131 }
1132
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1133 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1134 FusionKind fusion_kind,
1135 HloInstruction* fused_root)
1136 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1137 CHECK(fused_root != nullptr);
1138 SetAndSanitizeName("fusion");
1139 set_parent(fused_root->parent());
1140 set_metadata(fused_root->metadata());
1141 CloneAndFuseInternal(fused_root);
1142 }
1143
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1144 HloFusionInstruction::HloFusionInstruction(
1145 const Shape& shape, FusionKind fusion_kind,
1146 absl::Span<HloInstruction* const> operands,
1147 HloComputation* fusion_computation)
1148 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1149 for (auto operand : operands) {
1150 AppendOperand(operand);
1151 }
1152 SetAndSanitizeName("fusion");
1153 AppendComputation(fusion_computation);
1154 fusion_computation->SetFusionInstruction(this);
1155 }
1156
ToCategory() const1157 string HloFusionInstruction::ToCategory() const {
1158 switch (fusion_kind()) {
1159 case FusionKind::kLoop:
1160 return "loop fusion";
1161 case FusionKind::kInput:
1162 return "input fusion";
1163 case FusionKind::kOutput:
1164 return "output fusion";
1165 case FusionKind::kCustom:
1166 return "custom fusion";
1167 }
1168 }
1169
ToProto() const1170 HloInstructionProto HloFusionInstruction::ToProto() const {
1171 HloInstructionProto proto = HloInstruction::ToProto();
1172 proto.set_fusion_kind(xla::ToString(fusion_kind()));
1173 proto.add_called_computation_ids(
1174 fused_instructions_computation()->unique_id());
1175 return proto;
1176 }
1177
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1178 bool HloFusionInstruction::IsElementwiseImpl(
1179 const absl::optional<int64>& operand_idx) const {
1180 if (!operand_idx.has_value()) {
1181 for (auto* fused : fused_instructions()) {
1182 if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1183 return false;
1184 }
1185 }
1186 return true;
1187 }
1188 // A loop-fusion is elementwise on an operand if all operations (computed
1189 // using BFS) between the operand and the fused root are elementwise.
1190 std::deque<HloInstruction*> worklist;
1191 std::unordered_set<const HloInstruction*> visited;
1192 worklist.push_back(fused_parameter(operand_idx.value()));
1193 visited.insert(fused_parameter(operand_idx.value()));
1194 while (!worklist.empty()) {
1195 HloInstruction* operand = worklist.front();
1196 worklist.pop_front();
1197 for (HloInstruction* user : operand->users()) {
1198 CHECK_GE(user->unique_id(), 0);
1199 if (ContainsKey(visited, user)) {
1200 continue;
1201 }
1202 if (user->IsElementwise() ||
1203 IsInstructionElementwiseOnOperand(user, operand)) {
1204 worklist.push_back(user);
1205 visited.insert(user);
1206 } else {
1207 return false;
1208 }
1209 }
1210 }
1211 return true;
1212 }
1213
AddFusionOperand(HloInstruction * new_operand)1214 HloInstruction* HloFusionInstruction::AddFusionOperand(
1215 HloInstruction* new_operand) {
1216 CHECK_EQ(operand_count(),
1217 fused_instructions_computation()->parameter_instructions().size());
1218 const int64 param_no = operand_count();
1219 // Name the parameter after the instruction it represents in the outer
1220 // (non-fusion) computation.
1221 // string param_name = StrCat(new_operand->name(), ".param_", param_no);
1222 string param_name = StrCat("param_", param_no);
1223 HloInstruction* fused_parameter =
1224 fused_instructions_computation()->AddParameter(
1225 HloInstruction::CreateParameter(param_no, new_operand->shape(),
1226 param_name));
1227 AppendOperand(new_operand);
1228 return fused_parameter;
1229 }
1230
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1231 void HloFusionInstruction::MergeFusionInstruction(
1232 HloFusionInstruction* instruction_to_merge) {
1233 CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1234 // Clone the instruction from which to merge fused instructions.
1235 std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1236 HloFusionInstruction* cloned_fusion =
1237 static_cast<HloFusionInstruction*>(cloned.get());
1238 // Replace uses of fused parameters with the corresponding operand of the
1239 // fusion. Add all non-parameter fused instructions to
1240 // 'unfused_instructions' to be merged into 'this'. This is done in reverse
1241 // post order.
1242 std::vector<HloInstruction*> unfused_instructions;
1243 auto fused_instructions = cloned_fusion->fused_instructions_computation()
1244 ->MakeInstructionPostOrder();
1245 for (auto fused_it = fused_instructions.rbegin();
1246 fused_it != fused_instructions.rend(); ++fused_it) {
1247 auto fused_instruction = *fused_it;
1248 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1249 TF_CHECK_OK(
1250 fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1251 fused_instruction->parameter_number())));
1252 } else {
1253 unfused_instructions.push_back(fused_instruction);
1254 }
1255 }
1256 CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root());
1257 // Replace instruction_to_merge use of 'this' with unfused_root.
1258 TF_CHECK_OK(
1259 instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
1260 // Fuse 'unfused_instructions' into 'this'.
1261 for (auto& instruction : unfused_instructions) {
1262 FuseInstruction(instruction);
1263 }
1264 CHECK_EQ(0, cloned_fusion->user_count());
1265 TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1266 cloned_fusion->fused_instructions_computation()));
1267 }
1268
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1269 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1270 HloFusionInstruction* instruction_to_merge) {
1271 // Add all non-parameter fused instructions to 'unfused_instructions' to be
1272 // merged into 'this'. `old_to_new' maps the instructions in the fused node
1273 // to the disaseembled fusion instructions.
1274 // Note that we add the unfused instructions to this->parent_ computation.
1275 // This is necessary because the unique_id needs for an instruction and
1276 // it's only added when inserting to the computation.
1277 absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1278 std::vector<HloInstruction*> unfused_instructions;
1279 auto computation_to_merge =
1280 instruction_to_merge->fused_instructions_computation();
1281 auto post_order = computation_to_merge->MakeInstructionPostOrder();
1282 for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1283 auto fused_instruction = *rit;
1284 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1285 InsertOrDie(&old_to_new, fused_instruction,
1286 instruction_to_merge->mutable_operand(
1287 fused_instruction->parameter_number()));
1288 continue;
1289 }
1290
1291 // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
1292 // which clones again. This can be improved.
1293 auto cloned_instruction =
1294 parent()->AddInstruction(fused_instruction->Clone());
1295 unfused_instructions.push_back(cloned_instruction);
1296 InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
1297 }
1298 for (auto unfused_instruction : unfused_instructions) {
1299 for (int64 index = 0; index < unfused_instruction->operand_count();
1300 index++) {
1301 auto new_operand =
1302 FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
1303 TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
1304 }
1305 }
1306
1307 HloInstruction* unfused_root = unfused_instructions.front();
1308 TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
1309
1310 TF_CHECK_OK(
1311 instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
1312 if (GetModule()) {
1313 TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
1314 }
1315
1316 // Fuse the root instruction and generate multiple outputs.
1317 FuseInstructionIntoMultiOutput(unfused_root);
1318 TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
1319 // The rest instructions are of normal fusing.
1320 for (int64 i = 1; i < unfused_instructions.size(); i++) {
1321 auto instruction = unfused_instructions[i];
1322 FuseInstruction(instruction);
1323 TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1324 }
1325 }
1326
fused_instructions_computation() const1327 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1328 CHECK(!called_computations().empty());
1329 auto* fused_instructions_computation = called_computations().front();
1330 CHECK(fused_instructions_computation->IsFusionComputation())
1331 << "Computation " << fused_instructions_computation->name()
1332 << " is not a fusion kind";
1333 return fused_instructions_computation;
1334 }
1335
fused_expression_root() const1336 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1337 return fused_instructions_computation()->root_instruction();
1338 }
1339
fused_parameter(int64 parameter_number) const1340 HloInstruction* HloFusionInstruction::fused_parameter(
1341 int64 parameter_number) const {
1342 return fused_instructions_computation()->parameter_instruction(
1343 parameter_number);
1344 }
1345
fused_parameters() const1346 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1347 const {
1348 return fused_instructions_computation()->parameter_instructions();
1349 }
1350
1351 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1352 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1353 HloFusionInstruction::fused_instructions() const {
1354 const HloComputation* subcomp = fused_instructions_computation();
1355 return subcomp->instructions();
1356 }
1357
1358 const tensorflow::gtl::iterator_range<
1359 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1360 HloFusionInstruction::fused_instructions() {
1361 return fused_instructions_computation()->instructions();
1362 }
1363
fused_instruction_count() const1364 int64 HloFusionInstruction::fused_instruction_count() const {
1365 return fused_instructions_computation()->instruction_count();
1366 }
1367
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1368 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1369 HloInstruction* instruction_to_fuse, bool add_output) {
1370 // When add_output is false, this fusion instruction must be a user of
1371 // instruction_to_fuse.
1372 if (!add_output) {
1373 CHECK(IsUserOf(instruction_to_fuse));
1374 }
1375 HloInstruction* fused_instruction =
1376 CloneAndFuseInternal(instruction_to_fuse, add_output);
1377 return fused_instruction;
1378 }
1379
CloneAndFuseInternal(HloInstruction * instruction_to_fuse,bool add_output)1380 HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
1381 HloInstruction* instruction_to_fuse, bool add_output) {
1382 CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1383 VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
1384 HloInstruction* clone = nullptr;
1385 if (called_computations().empty()) {
1386 // New fusion instruction. It should not be a multioutput instruction.
1387 CHECK(!add_output);
1388 auto builder = HloComputation::Builder("fused_computation", this);
1389 builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
1390 AppendComputation(
1391 CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1392 clone = fused_expression_root();
1393 } else {
1394 // When add_output is false, instruction_to_fuse is necessarily an operand
1395 // of the fusion instruction. After fusion this will no longer be the
1396 // case. Remove the operand from the operand list and remove its
1397 // corresponding fused parameter instruction. Renumber parameters as
1398 // necessary to make parameter numbers consistent with their index in the
1399 // fused_parameter_ vector.
1400 bool in_operand_list =
1401 absl::c_linear_search(operands(), instruction_to_fuse);
1402 CHECK(add_output || in_operand_list);
1403 if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1404 // We assume all uses of a kTuple operation are GTE ops, not another
1405 // fusion node. In this case, we don't need to clone
1406 // 'instruction_to_fuse'.
1407 CHECK(!in_operand_list);
1408 clone = instruction_to_fuse;
1409 } else {
1410 clone = fused_instructions_computation()->AddInstruction(
1411 instruction_to_fuse->Clone(/*suffix=*/""));
1412 }
1413 const std::vector<HloInstruction*>& fused_parameters =
1414 fused_instructions_computation()->parameter_instructions();
1415 for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1416 if (instruction_to_fuse == operand(operand_num)) {
1417 // replace the fused parameter instruction's uses with the clone.
1418 HloInstruction* fused_parameter = fused_parameters[operand_num];
1419 TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
1420
1421 // Remove the corresponding fused parameter and operand from their
1422 // respective vectors.
1423 TF_CHECK_OK(
1424 fused_instructions_computation()->RemoveParameter(operand_num));
1425 RemoveOperandAt(operand_num);
1426 break;
1427 }
1428 }
1429 // We've cloned instruction_to_fuse into this fusion instruction, so this
1430 // fusion instruction is no longer a use of instruction_to_fuse.
1431 if (in_operand_list) {
1432 DetachFrom(instruction_to_fuse);
1433 // When the instruction_to_fuse does not have other users, we don't need
1434 // to generate a multioutput fusion instruction.
1435 if (instruction_to_fuse->user_count() == 0) {
1436 add_output = false;
1437 }
1438 }
1439 }
1440
1441 // Reread the parameters in the computation.
1442 const std::vector<HloInstruction*>& fused_parameters =
1443 fused_instructions_computation()->parameter_instructions();
1444
1445 // Add each operand of the clone as an operand of the fusion instruction. A
1446 // complication is that some clone operands may already be operands of the
1447 // fusion instruction.
1448 for (int64 operand_num = 0; operand_num < clone->operand_count();
1449 ++operand_num) {
1450 HloInstruction* operand = clone->mutable_operand(operand_num);
1451
1452 // See if this operand is already an operand of the fusion node.
1453 CHECK_EQ(operands().size(), fused_parameters.size());
1454 HloInstruction* fused_param = nullptr;
1455 for (int64 i = 0; i < operands().size(); ++i) {
1456 if (this->operand(i) == operand) {
1457 fused_param = fused_parameters[i];
1458 break;
1459 }
1460 }
1461
1462 if (fused_param == nullptr) {
1463 // Clone's operand was not already an operand of the fusion
1464 // instruction. Add it as an operand and add a corresponding fused
1465 // parameter instruction.
1466 fused_param = AddFusionOperand(operand);
1467 }
1468 TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
1469 }
1470
1471 if (add_output) {
1472 CHECK_GT(instruction_to_fuse->user_count(), 0);
1473 // If this is already a multioutput fusion instruction, expand the root
1474 // tuple by 1.
1475 HloInstruction* fused_root = fused_expression_root();
1476 HloInstruction::InstructionVector tuple_elements;
1477 bool newly_created_tuple_instr = false;
1478 if (fused_root->opcode() == HloOpcode::kTuple) {
1479 tuple_elements = fused_root->operands();
1480 } else {
1481 tuple_elements.push_back(fused_root);
1482 newly_created_tuple_instr = true;
1483 }
1484 if (clone->opcode() == HloOpcode::kTuple) {
1485 for (auto inst : clone->operands()) {
1486 tuple_elements.push_back(inst);
1487 }
1488 } else {
1489 tuple_elements.push_back(clone);
1490 }
1491 HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
1492 HloInstruction::CreateTuple(tuple_elements));
1493 fused_instructions_computation()->set_root_instruction(new_root);
1494 *mutable_shape() = new_root->shape();
1495 if (fused_root->opcode() == HloOpcode::kTuple) {
1496 TF_CHECK_OK(
1497 fused_instructions_computation()->RemoveInstruction(fused_root));
1498 }
1499
1500 // If this is a newly created multioutput instruction, we need to update
1501 // the use of the original fusion instruction.
1502 if (newly_created_tuple_instr) {
1503 HloInstruction* new_instr = parent()->AddInstruction(
1504 HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
1505 TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1506 }
1507 int64 index = tuple_elements.size();
1508 if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1509 CHECK_EQ(clone, instruction_to_fuse);
1510 index -= clone->operand_count();
1511 std::vector<HloInstruction*> to_be_removed;
1512 for (auto old_gte : clone->users()) {
1513 CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1514 int64 old_tuple_index = old_gte->tuple_index();
1515 HloInstruction* new_gte =
1516 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1517 old_gte->shape(), this, index + old_tuple_index));
1518 TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1519 to_be_removed.push_back(old_gte);
1520 }
1521 for (auto old_gte : to_be_removed) {
1522 TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1523 }
1524 } else {
1525 HloInstruction* new_gte =
1526 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1527 clone->shape(), this, index - 1));
1528 TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
1529 }
1530 }
1531
1532 if (clone != instruction_to_fuse) {
1533 VLOG(2) << "New clone:\n" << clone->ToString();
1534 }
1535 return clone;
1536 }
1537
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1538 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1539 const HloPrintOptions& options) const {
1540 return {StrCat("kind=", xla::ToString(fusion_kind()))};
1541 }
1542
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1543 bool HloFusionInstruction::IdenticalSlowPath(
1544 const HloInstruction& other,
1545 const std::function<bool(const HloComputation*, const HloComputation*)>&
1546 eq_computations) const {
1547 return fusion_kind() == other.fusion_kind() &&
1548 eq_computations(fused_instructions_computation(),
1549 other.fused_instructions_computation());
1550 }
1551
HashOperandRecursive(const HloInstruction * hlo)1552 static uint64 HashOperandRecursive(const HloInstruction* hlo) {
1553 return hlo->Hash(HashOperandRecursive);
1554 }
1555
InnerHash() const1556 uint64 HloFusionInstruction::InnerHash() const {
1557 // Use HashOperandRecursive to recursively compute hash on inner operands.
1558 return fused_instructions_computation()->root_instruction()->Hash(
1559 HashOperandRecursive);
1560 }
1561
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1562 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1563 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1564 HloCloneContext* context) const {
1565 HloModule* module = context != nullptr ? context->module() : GetModule();
1566 HloComputation* new_fused_computation = nullptr;
1567 if (context != nullptr) {
1568 new_fused_computation =
1569 context->FindComputation(fused_instructions_computation());
1570 }
1571 if (new_fused_computation == nullptr) {
1572 new_fused_computation = module->AddEmbeddedComputation(
1573 fused_instructions_computation()->Clone("clone", context));
1574 }
1575 return absl::make_unique<HloFusionInstruction>(
1576 shape, fusion_kind(), new_operands, new_fused_computation);
1577 }
1578
DeduplicateFusionOperands()1579 Status HloFusionInstruction::DeduplicateFusionOperands() {
1580 absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1581 std::vector<int> operands_to_remove;
1582 for (int i = 0; i < operand_count(); ++i) {
1583 auto emplace_result = operand_indices.emplace(operand(i), i);
1584 if (!emplace_result.second) {
1585 TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1586 fused_parameter(emplace_result.first->second)));
1587 operands_to_remove.push_back(i);
1588 }
1589 }
1590 if (operands_to_remove.empty()) {
1591 return Status::OK();
1592 }
1593 TF_RETURN_IF_ERROR(
1594 fused_instructions_computation()->RemoveUnusedParameters());
1595 RemoveOperandsAtAscendingIndices(operands_to_remove);
1596 return Status::OK();
1597 }
1598
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1599 HloRngInstruction::HloRngInstruction(
1600 const Shape& shape, RandomDistribution distribution,
1601 absl::Span<HloInstruction* const> parameters)
1602 : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
1603 for (HloInstruction* param : parameters) {
1604 AppendOperand(param);
1605 }
1606 }
1607
ToProto() const1608 HloInstructionProto HloRngInstruction::ToProto() const {
1609 HloInstructionProto proto = HloInstruction::ToProto();
1610 proto.set_distribution(distribution_);
1611 return proto;
1612 }
1613
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1614 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
1615 const HloPrintOptions& options) const {
1616 return {StrCat("distribution=", RandomDistributionToString(distribution_))};
1617 }
1618
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1619 bool HloRngInstruction::IsElementwiseImpl(
1620 const absl::optional<int64>& operand_idx) const {
1621 return true;
1622 }
1623
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1624 bool HloRngInstruction::IdenticalSlowPath(
1625 const HloInstruction& other,
1626 const std::function<bool(const HloComputation*, const HloComputation*)>&
1627 eq_computations) const {
1628 return false;
1629 }
1630
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1631 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
1632 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1633 HloCloneContext* context) const {
1634 return absl::make_unique<HloRngInstruction>(shape, distribution_,
1635 new_operands);
1636 }
1637
HloParameterInstruction(int64 parameter_number,const Shape & shape,const string & name)1638 HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
1639 const Shape& shape,
1640 const string& name)
1641 : HloInstruction(HloOpcode::kParameter, shape),
1642 parameter_number_(parameter_number) {
1643 SetAndSanitizeName(name);
1644 }
1645
ToProto() const1646 HloInstructionProto HloParameterInstruction::ToProto() const {
1647 HloInstructionProto proto = HloInstruction::ToProto();
1648 proto.set_parameter_number(parameter_number_);
1649 if (parameter_replicated_at_leaf_buffers_) {
1650 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1651 proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
1652 replicated);
1653 }
1654 }
1655 return proto;
1656 }
1657
ExtraAttributesToStringImpl(const HloPrintOptions &) const1658 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
1659 const HloPrintOptions& /*options*/) const {
1660 std::vector<string> result;
1661 if (!parameter_replicated_at_leaf_buffers_) {
1662 return result;
1663 }
1664 std::vector<string> buffers_replicated_strs;
1665 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1666 buffers_replicated_strs.push_back(replicated ? "true" : "false");
1667 }
1668 result.push_back(StrCat("parameter_replication={",
1669 StrJoin(buffers_replicated_strs, ","), "}"));
1670 return result;
1671 }
1672
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1673 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
1674 const HloPrintOptions& options,
1675 CanonicalNameMap* canonical_name_map) const {
1676 return StrCat(parameter_number_);
1677 }
1678
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1679 bool HloParameterInstruction::IdenticalSlowPath(
1680 const HloInstruction& other,
1681 const std::function<bool(const HloComputation*, const HloComputation*)>&
1682 eq_computations) const {
1683 const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
1684 return parameter_number() == casted_other.parameter_number();
1685 }
1686
1687 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1688 HloParameterInstruction::CloneWithNewOperandsImpl(
1689 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1690 HloCloneContext* context) const {
1691 return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
1692 name());
1693 }
1694
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64 index)1695 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
1696 const Shape& shape, HloInstruction* operand, int64 index)
1697 : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
1698 AppendOperand(operand);
1699 }
1700
ToProto() const1701 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
1702 HloInstructionProto proto = HloInstruction::ToProto();
1703 proto.set_tuple_index(tuple_index_);
1704 return proto;
1705 }
1706
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1707 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
1708 const HloPrintOptions& options) const {
1709 return {StrCat("index=", tuple_index())};
1710 }
1711
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1712 bool HloGetTupleElementInstruction::IdenticalSlowPath(
1713 const HloInstruction& other,
1714 const std::function<bool(const HloComputation*, const HloComputation*)>&
1715 eq_computations) const {
1716 const auto& casted_other =
1717 static_cast<const HloGetTupleElementInstruction&>(other);
1718 return tuple_index() == casted_other.tuple_index();
1719 }
1720
1721 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1722 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
1723 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1724 HloCloneContext* context) const {
1725 CHECK_EQ(new_operands.size(), 1);
1726 return absl::make_unique<HloGetTupleElementInstruction>(
1727 shape, new_operands[0], tuple_index());
1728 }
1729
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1730 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
1731 const Shape& shape, HloInstruction* operand, const int exponent_bits,
1732 const int mantissa_bits)
1733 : HloInstruction(HloOpcode::kReducePrecision, shape),
1734 exponent_bits_(exponent_bits),
1735 mantissa_bits_(mantissa_bits) {
1736 AppendOperand(operand);
1737 }
1738
ToProto() const1739 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
1740 HloInstructionProto proto = HloInstruction::ToProto();
1741 proto.set_exponent_bits(exponent_bits_);
1742 proto.set_mantissa_bits(mantissa_bits_);
1743 return proto;
1744 }
1745
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1746 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
1747 const HloPrintOptions& options) const {
1748 return {StrCat("exponent_bits=", exponent_bits_),
1749 StrCat("mantissa_bits=", mantissa_bits_)};
1750 }
1751
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1752 bool HloReducePrecisionInstruction::IdenticalSlowPath(
1753 const HloInstruction& other,
1754 const std::function<bool(const HloComputation*, const HloComputation*)>&
1755 eq_computations) const {
1756 const auto& casted_other =
1757 static_cast<const HloReducePrecisionInstruction&>(other);
1758 // A reduce-precision operation is determined by the bit sizes.
1759 return exponent_bits() == casted_other.exponent_bits() &&
1760 mantissa_bits() == casted_other.mantissa_bits();
1761 }
1762
1763 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1764 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
1765 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1766 HloCloneContext* context) const {
1767 CHECK_EQ(new_operands.size(), 1);
1768 return absl::make_unique<HloReducePrecisionInstruction>(
1769 shape, new_operands[0], exponent_bits(), mantissa_bits());
1770 }
1771
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1772 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
1773 HloInstruction* token_operand,
1774 const string& config)
1775 : HloInstruction(HloOpcode::kInfeed,
1776 ShapeUtil::MakeTupleShape(
1777 {infeed_shape, ShapeUtil::MakeTokenShape()})),
1778 infeed_config_(config) {
1779 AppendOperand(token_operand);
1780 }
1781
ToProto() const1782 HloInstructionProto HloInfeedInstruction::ToProto() const {
1783 HloInstructionProto proto = HloInstruction::ToProto();
1784 proto.set_infeed_config(infeed_config_);
1785 return proto;
1786 }
1787
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1788 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
1789 const HloPrintOptions& options) const {
1790 if (infeed_config_.empty()) {
1791 return {};
1792 }
1793 return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
1794 }
1795
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1796 bool HloInfeedInstruction::IdenticalSlowPath(
1797 const HloInstruction& other,
1798 const std::function<bool(const HloComputation*, const HloComputation*)>&
1799 eq_computations) const {
1800 // Not yet supported.
1801 return false;
1802 }
1803
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1804 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
1805 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1806 HloCloneContext* context) const {
1807 CHECK_EQ(new_operands.size(), 1);
1808 return absl::make_unique<HloInfeedInstruction>(
1809 infeed_shape(), new_operands[0], infeed_config());
1810 }
1811
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1812 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
1813 HloInstruction* operand,
1814 HloInstruction* token_operand,
1815 absl::string_view outfeed_config)
1816 : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
1817 outfeed_shape_(outfeed_shape),
1818 outfeed_config_(outfeed_config) {
1819 AppendOperand(operand);
1820 AppendOperand(token_operand);
1821 }
1822
ToProto() const1823 HloInstructionProto HloOutfeedInstruction::ToProto() const {
1824 HloInstructionProto proto = HloInstruction::ToProto();
1825 proto.set_outfeed_config(outfeed_config());
1826 *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
1827 return proto;
1828 }
1829
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1830 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
1831 const HloPrintOptions& options) const {
1832 if (outfeed_config_.empty()) {
1833 return {};
1834 }
1835 return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")};
1836 }
1837
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1838 bool HloOutfeedInstruction::IdenticalSlowPath(
1839 const HloInstruction& other,
1840 const std::function<bool(const HloComputation*, const HloComputation*)>&
1841 eq_computations) const {
1842 // Not yet supported.
1843 return false;
1844 }
1845
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1846 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
1847 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1848 HloCloneContext* context) const {
1849 CHECK_EQ(new_operands.size(), 2);
1850 return absl::make_unique<HloOutfeedInstruction>(
1851 outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
1852 }
1853
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1854 HloConvolutionInstruction::HloConvolutionInstruction(
1855 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1856 int64 feature_group_count, int64 batch_group_count, const Window& window,
1857 const ConvolutionDimensionNumbers& dimension_numbers,
1858 const PrecisionConfig& precision_config)
1859 : HloInstruction(HloOpcode::kConvolution, shape),
1860 feature_group_count_(feature_group_count),
1861 batch_group_count_(batch_group_count),
1862 window_(window),
1863 convolution_dimension_numbers_(dimension_numbers),
1864 precision_config_(precision_config) {
1865 if (window_util::HasBaseDilation(window)) {
1866 SetAndSanitizeName(StrCat(name(), "-base-dilated"));
1867 }
1868 if (window_util::HasWindowDilation(window)) {
1869 SetAndSanitizeName(StrCat(name(), "-window-dilated"));
1870 }
1871 AppendOperand(lhs);
1872 AppendOperand(rhs);
1873 }
1874
ToCategory() const1875 string HloConvolutionInstruction::ToCategory() const {
1876 string category = "convolution";
1877 if (window_util::HasBaseDilation(window())) {
1878 category += " base-dilated";
1879 }
1880 if (window_util::HasWindowDilation(window())) {
1881 category += " window-dilated";
1882 }
1883 return category;
1884 }
1885
ToProto() const1886 HloInstructionProto HloConvolutionInstruction::ToProto() const {
1887 HloInstructionProto proto = HloInstruction::ToProto();
1888 *proto.mutable_window() = window_;
1889 *proto.mutable_convolution_dimension_numbers() =
1890 convolution_dimension_numbers_;
1891 proto.set_feature_group_count(feature_group_count_);
1892 proto.set_batch_group_count(batch_group_count_);
1893 *proto.mutable_precision_config() = precision_config_;
1894 return proto;
1895 }
1896
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1897 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
1898 const HloPrintOptions& options) const {
1899 std::vector<string> extra;
1900 if (window_.dimensions_size() != 0) {
1901 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
1902 }
1903 extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
1904 convolution_dimension_numbers_)));
1905 if (feature_group_count_ != 1) {
1906 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
1907 }
1908
1909 if (batch_group_count_ != 1) {
1910 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
1911 }
1912
1913 string precision_config_string = PrecisionConfigToString(precision_config_);
1914 if (!precision_config_string.empty()) {
1915 extra.push_back(precision_config_string);
1916 }
1917
1918 return extra;
1919 }
1920
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1921 bool HloConvolutionInstruction::IdenticalSlowPath(
1922 const HloInstruction& other,
1923 const std::function<bool(const HloComputation*, const HloComputation*)>&
1924 eq_computations) const {
1925 const auto& casted_other =
1926 static_cast<const HloConvolutionInstruction&>(other);
1927 if (feature_group_count_ != other.feature_group_count()) {
1928 return false;
1929 }
1930 if (batch_group_count_ != other.batch_group_count()) {
1931 return false;
1932 }
1933 return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
1934 protobuf_util::ProtobufEquals(
1935 convolution_dimension_numbers(),
1936 casted_other.convolution_dimension_numbers()) &&
1937 protobuf_util::ProtobufEquals(precision_config(),
1938 casted_other.precision_config());
1939 }
1940
1941 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1942 HloConvolutionInstruction::CloneWithNewOperandsImpl(
1943 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1944 HloCloneContext* context) const {
1945 CHECK_EQ(new_operands.size(), 2);
1946 return absl::make_unique<HloConvolutionInstruction>(
1947 shape, new_operands[0], new_operands[1], feature_group_count_,
1948 batch_group_count_, window(), convolution_dimension_numbers_,
1949 precision_config_);
1950 }
1951
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1952 HloReduceWindowInstruction::HloReduceWindowInstruction(
1953 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1954 const Window& window, HloComputation* reduce_computation)
1955 : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
1956 AppendOperand(operand);
1957 AppendOperand(init_value);
1958 AppendComputation(reduce_computation);
1959 }
1960
ToProto() const1961 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
1962 HloInstructionProto proto = HloInstruction::ToProto();
1963 *proto.mutable_window() = window_;
1964 return proto;
1965 }
1966
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1967 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
1968 const HloPrintOptions& options) const {
1969 std::vector<string> extra;
1970 if (window_.dimensions_size() != 0) {
1971 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
1972 }
1973 return extra;
1974 }
1975
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1976 bool HloReduceWindowInstruction::IdenticalSlowPath(
1977 const HloInstruction& other,
1978 const std::function<bool(const HloComputation*, const HloComputation*)>&
1979 eq_computations) const {
1980 const auto& casted_other =
1981 static_cast<const HloReduceWindowInstruction&>(other);
1982 return eq_computations(to_apply(), casted_other.to_apply()) &&
1983 protobuf_util::ProtobufEquals(window(), casted_other.window());
1984 }
1985
1986 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1987 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
1988 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1989 HloCloneContext* context) const {
1990 CHECK_EQ(new_operands.size(), 2);
1991 return absl::make_unique<HloReduceWindowInstruction>(
1992 shape, new_operands[0], new_operands[1], window(), to_apply());
1993 }
1994
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1995 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
1996 const Shape& shape, HloInstruction* operand, HloComputation* select,
1997 const Window& window, HloInstruction* source, HloInstruction* init_value,
1998 HloComputation* scatter)
1999 : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2000 AppendOperand(operand);
2001 AppendOperand(source);
2002 AppendOperand(init_value);
2003 // Select comes before scatter in the vector.
2004 AppendComputation(select);
2005 AppendComputation(scatter);
2006 }
2007
ToProto() const2008 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2009 HloInstructionProto proto = HloInstruction::ToProto();
2010 *proto.mutable_window() = window_;
2011 return proto;
2012 }
2013
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2014 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2015 const HloPrintOptions& options) const {
2016 std::vector<string> extra;
2017 if (window_.dimensions_size() != 0) {
2018 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2019 }
2020 return extra;
2021 }
2022
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2023 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2024 const HloInstruction& other,
2025 const std::function<bool(const HloComputation*, const HloComputation*)>&
2026 eq_computations) const {
2027 const auto& casted_other =
2028 static_cast<const HloSelectAndScatterInstruction&>(other);
2029 return eq_computations(select(), casted_other.select()) &&
2030 eq_computations(scatter(), casted_other.scatter()) &&
2031 protobuf_util::ProtobufEquals(window(), casted_other.window());
2032 }
2033
2034 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2035 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2036 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2037 HloCloneContext* context) const {
2038 CHECK_EQ(new_operands.size(), 3);
2039 return absl::make_unique<HloSelectAndScatterInstruction>(
2040 shape, new_operands[0], select(), window(), new_operands[1],
2041 new_operands[2], scatter());
2042 }
2043
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::string_view opaque)2044 HloCustomCallInstruction::HloCustomCallInstruction(
2045 const Shape& shape, absl::Span<HloInstruction* const> operands,
2046 absl::string_view custom_call_target, absl::string_view opaque)
2047 : HloInstruction(HloOpcode::kCustomCall, shape),
2048 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2049 opaque_(opaque.begin(), opaque.end()),
2050 feature_group_count_(1),
2051 batch_group_count_(1),
2052 layout_constrained_(false) {
2053 for (auto operand : operands) {
2054 AppendOperand(operand);
2055 }
2056 }
2057
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::string_view opaque,absl::Span<const Shape> operand_shapes_with_layout)2058 HloCustomCallInstruction::HloCustomCallInstruction(
2059 const Shape& shape, absl::Span<HloInstruction* const> operands,
2060 absl::string_view custom_call_target, absl::string_view opaque,
2061 absl::Span<const Shape> operand_shapes_with_layout)
2062 : HloInstruction(HloOpcode::kCustomCall, shape),
2063 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2064 opaque_(opaque.begin(), opaque.end()),
2065 feature_group_count_(1),
2066 batch_group_count_(1),
2067 layout_constrained_(true),
2068 operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2069 operand_shapes_with_layout.end()) {
2070 for (auto operand : operands) {
2071 AppendOperand(operand);
2072 }
2073 }
2074
ToProto() const2075 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2076 HloInstructionProto proto = HloInstruction::ToProto();
2077 if (window_ != nullptr) {
2078 *proto.mutable_window() = *window_;
2079 }
2080 if (convolution_dimension_numbers_ != nullptr) {
2081 *proto.mutable_convolution_dimension_numbers() =
2082 *convolution_dimension_numbers_;
2083 }
2084 proto.set_custom_call_target(custom_call_target_);
2085 proto.set_custom_call_opaque(opaque_);
2086 proto.set_feature_group_count(feature_group_count_);
2087 proto.set_batch_group_count(batch_group_count_);
2088 if (layout_constrained()) {
2089 proto.set_constrain_layout(true);
2090 for (const Shape& shape : operand_shapes_with_layout_) {
2091 *proto.add_operand_shapes_with_layout() = shape.ToProto();
2092 }
2093 }
2094 return proto;
2095 }
2096
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2097 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2098 const HloPrintOptions& options) const {
2099 std::vector<string> extra;
2100 if (window_ != nullptr && window_->dimensions_size() != 0) {
2101 extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2102 }
2103 if (convolution_dimension_numbers_ != nullptr) {
2104 extra.push_back(StrCat(
2105 "dim_labels=",
2106 ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2107 }
2108 if (feature_group_count_ != 1) {
2109 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2110 }
2111 if (batch_group_count_ != 1) {
2112 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2113 }
2114 // By contract, we print the custom call target even if
2115 // options.print_subcomputation_mode() == kOff, because the call target is not
2116 // an HloComputation.
2117 extra.push_back(
2118 StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2119 // If the opaque string becomes enormous we may want to reconsider printing
2120 // this inline and consider other options.
2121 if (!opaque_.empty()) {
2122 extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
2123 }
2124 if (layout_constrained()) {
2125 std::vector<string> shape_strings;
2126 for (const Shape& shape : operand_shapes_with_layout_) {
2127 shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2128 }
2129 extra.push_back(StrCat("operand_layout_constraints={",
2130 StrJoin(shape_strings, ", "), "}"));
2131 }
2132 return extra;
2133 }
2134
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2135 bool HloCustomCallInstruction::IdenticalSlowPath(
2136 const HloInstruction& other,
2137 const std::function<bool(const HloComputation*, const HloComputation*)>&
2138 eq_computations) const {
2139 const auto& casted_other =
2140 static_cast<const HloCustomCallInstruction&>(other);
2141 if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2142 (window_ != nullptr &&
2143 !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2144 return false;
2145 }
2146 if ((convolution_dimension_numbers_ == nullptr) !=
2147 (casted_other.convolution_dimension_numbers_ == nullptr) ||
2148 (convolution_dimension_numbers_ != nullptr &&
2149 !protobuf_util::ProtobufEquals(
2150 convolution_dimension_numbers(),
2151 casted_other.convolution_dimension_numbers()))) {
2152 return false;
2153 }
2154 if (feature_group_count_ != casted_other.feature_group_count_) {
2155 return false;
2156 }
2157 if (batch_group_count_ != casted_other.batch_group_count_) {
2158 return false;
2159 }
2160 if (layout_constrained() != casted_other.layout_constrained()) {
2161 return false;
2162 }
2163 if (layout_constrained()) {
2164 for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2165 if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2166 casted_other.operand_shapes_with_layout_[i])) {
2167 return false;
2168 }
2169 }
2170 }
2171 return custom_call_target_ == casted_other.custom_call_target_ &&
2172 opaque_ == casted_other.opaque_;
2173 }
2174
2175 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2176 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2177 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2178 HloCloneContext* context) const {
2179 auto cloned = absl::make_unique<HloCustomCallInstruction>(
2180 shape, new_operands, custom_call_target(), opaque());
2181 if (layout_constrained()) {
2182 cloned->layout_constrained_ = true;
2183 cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2184 }
2185 if (window_ != nullptr) {
2186 cloned->set_window(*window_);
2187 }
2188 if (convolution_dimension_numbers_ != nullptr) {
2189 cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2190 }
2191 cloned->set_feature_group_count(feature_group_count_);
2192 cloned->set_batch_group_count(batch_group_count_);
2193 return std::move(cloned);
2194 }
2195
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2196 HloPadInstruction::HloPadInstruction(const Shape& shape,
2197 HloInstruction* operand,
2198 HloInstruction* padding_value,
2199 const PaddingConfig& padding_config)
2200 : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2201 AppendOperand(operand);
2202 AppendOperand(padding_value);
2203 }
2204
ToProto() const2205 HloInstructionProto HloPadInstruction::ToProto() const {
2206 HloInstructionProto proto = HloInstruction::ToProto();
2207 *proto.mutable_padding_config() = padding_config_;
2208 return proto;
2209 }
2210
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2211 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2212 const HloPrintOptions& options) const {
2213 return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2214 }
2215
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2216 bool HloPadInstruction::IdenticalSlowPath(
2217 const HloInstruction& other,
2218 const std::function<bool(const HloComputation*, const HloComputation*)>&
2219 eq_computations) const {
2220 const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2221 return protobuf_util::ProtobufEquals(padding_config(),
2222 casted_other.padding_config());
2223 }
2224
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2225 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2226 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2227 HloCloneContext* context) const {
2228 CHECK_EQ(new_operands.size(), 2);
2229 return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2230 new_operands[1], padding_config_);
2231 }
2232
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2233 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2234 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2235 absl::Span<const int64> slice_sizes)
2236 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2237 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2238 AppendOperand(operand);
2239 AppendOperand(start_indices);
2240 }
2241
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2242 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2243 const Shape& shape, HloInstruction* operand,
2244 absl::Span<HloInstruction* const> start_indices,
2245 absl::Span<const int64> slice_sizes)
2246 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2247 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2248 AppendOperand(operand);
2249 for (HloInstruction* index : start_indices) {
2250 AppendOperand(index);
2251 }
2252 }
2253
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2254 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2255 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2256 HloInstruction* start_indices)
2257 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2258 AppendOperand(operand);
2259 AppendOperand(update);
2260 AppendOperand(start_indices);
2261 }
2262
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2263 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2264 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2265 absl::Span<HloInstruction* const> start_indices)
2266 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2267 AppendOperand(operand);
2268 AppendOperand(update);
2269 for (HloInstruction* index : start_indices) {
2270 AppendOperand(index);
2271 }
2272 }
2273
ToProto() const2274 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2275 HloInstructionProto proto = HloInstruction::ToProto();
2276 for (int64 slice_size : dynamic_slice_sizes_) {
2277 proto.add_dynamic_slice_sizes(slice_size);
2278 }
2279 return proto;
2280 }
2281
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2282 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2283 const HloPrintOptions& options) const {
2284 return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2285 "}")};
2286 }
2287
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2288 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2289 const HloInstruction& other,
2290 const std::function<bool(const HloComputation*, const HloComputation*)>&
2291 eq_computations) const {
2292 return true;
2293 }
2294
2295 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2296 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2297 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2298 HloCloneContext* context) const {
2299 if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2300 // TODO(b/118437727): Old form, remove this path.
2301 return absl::make_unique<HloDynamicSliceInstruction>(
2302 shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2303 } else {
2304 return absl::make_unique<HloDynamicSliceInstruction>(
2305 shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2306 }
2307 }
2308
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes)2309 HloGatherInstruction::HloGatherInstruction(
2310 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2311 const GatherDimensionNumbers& gather_dim_numbers,
2312 absl::Span<const int64> slice_sizes)
2313 : HloInstruction(HloOpcode::kGather, shape) {
2314 AppendOperand(operand);
2315 AppendOperand(start_indices);
2316 gather_dimension_numbers_ =
2317 absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2318 absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2319 }
2320
GatherDimensionNumbersToString() const2321 string HloGatherInstruction::GatherDimensionNumbersToString() const {
2322 CHECK(gather_dimension_numbers_ != nullptr);
2323 string offset_dims =
2324 StrCat("offset_dims={",
2325 StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}");
2326 string collapsed_slice_dims = StrCat(
2327 "collapsed_slice_dims={",
2328 StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
2329 string start_index_map =
2330 StrCat("start_index_map={",
2331 StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}");
2332 string index_vector_dim = StrCat(
2333 "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
2334
2335 return StrJoin<std::initializer_list<string>>(
2336 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2337 ", ");
2338 }
2339
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64 index_vector_dim)2340 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2341 absl::Span<const int64> offset_dims,
2342 absl::Span<const int64> collapsed_slice_dims,
2343 absl::Span<const int64> start_index_map, int64 index_vector_dim) {
2344 GatherDimensionNumbers gather_dim_numbers;
2345 for (int64 output_window_dim : offset_dims) {
2346 gather_dim_numbers.add_offset_dims(output_window_dim);
2347 }
2348 for (int64 elided_window_dim : collapsed_slice_dims) {
2349 gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2350 }
2351 for (int64 gather_dim_to_input_dim : start_index_map) {
2352 gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2353 }
2354
2355 gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2356 return gather_dim_numbers;
2357 }
2358
ToProto() const2359 HloInstructionProto HloGatherInstruction::ToProto() const {
2360 HloInstructionProto proto = HloInstruction::ToProto();
2361 *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2362 for (int64 bound : gather_slice_sizes()) {
2363 proto.add_gather_slice_sizes(bound);
2364 }
2365 return proto;
2366 }
2367
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2368 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2369 const HloPrintOptions& options) const {
2370 return {GatherDimensionNumbersToString(),
2371 StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2372 }
2373
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2374 bool HloGatherInstruction::IdenticalSlowPath(
2375 const HloInstruction& other,
2376 const std::function<bool(const HloComputation*, const HloComputation*)>&
2377 eq_computations) const {
2378 const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2379 return protobuf_util::ProtobufEquals(
2380 gather_dimension_numbers(),
2381 casted_other.gather_dimension_numbers()) &&
2382 gather_slice_sizes() == casted_other.gather_slice_sizes();
2383 }
2384
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2385 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2386 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2387 HloCloneContext* context) const {
2388 CHECK_EQ(new_operands.size(), 2);
2389 return absl::make_unique<HloGatherInstruction>(
2390 shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2391 gather_slice_sizes());
2392 }
2393
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers)2394 HloScatterInstruction::HloScatterInstruction(
2395 const Shape& shape, HloInstruction* operand,
2396 HloInstruction* scatter_indices, HloInstruction* updates,
2397 HloComputation* update_computation,
2398 const ScatterDimensionNumbers& scatter_dim_numbers)
2399 : HloInstruction(HloOpcode::kScatter, shape) {
2400 AppendOperand(operand);
2401 AppendOperand(scatter_indices);
2402 AppendOperand(updates);
2403 AppendComputation(update_computation);
2404 scatter_dimension_numbers_ =
2405 absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2406 }
2407
ScatterDimensionNumbersToString() const2408 string HloScatterInstruction::ScatterDimensionNumbersToString() const {
2409 string update_window_dims = StrCat(
2410 "update_window_dims={",
2411 StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}");
2412 string inserted_window_dims = StrCat(
2413 "inserted_window_dims={",
2414 StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
2415 string scatter_dims_to_operand_dims = StrCat(
2416 "scatter_dims_to_operand_dims={",
2417 StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
2418 "}");
2419 string index_vector_dim = StrCat(
2420 "index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
2421
2422 return StrJoin<std::initializer_list<string>>(
2423 {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
2424 index_vector_dim},
2425 ", ");
2426 }
2427
2428 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64 index_vector_dim)2429 HloScatterInstruction::MakeScatterDimNumbers(
2430 absl::Span<const int64> update_window_dims,
2431 absl::Span<const int64> inserted_window_dims,
2432 absl::Span<const int64> scatter_dims_to_operand_dims,
2433 int64 index_vector_dim) {
2434 ScatterDimensionNumbers scatter_dim_numbers;
2435 for (int64 update_window_dim : update_window_dims) {
2436 scatter_dim_numbers.add_update_window_dims(update_window_dim);
2437 }
2438 for (int64 inserted_window_dim : inserted_window_dims) {
2439 scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
2440 }
2441 for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
2442 scatter_dim_numbers.add_scatter_dims_to_operand_dims(
2443 scatter_dim_to_operand_dim);
2444 }
2445 scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
2446 return scatter_dim_numbers;
2447 }
2448
ToProto() const2449 HloInstructionProto HloScatterInstruction::ToProto() const {
2450 HloInstructionProto proto = HloInstruction::ToProto();
2451 *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
2452 return proto;
2453 }
2454
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2455 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
2456 const HloPrintOptions& options) const {
2457 return {ScatterDimensionNumbersToString()};
2458 }
2459
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2460 bool HloScatterInstruction::IdenticalSlowPath(
2461 const HloInstruction& other,
2462 const std::function<bool(const HloComputation*, const HloComputation*)>&
2463 eq_computations) const {
2464 const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
2465 return protobuf_util::ProtobufEquals(
2466 scatter_dimension_numbers(),
2467 casted_other.scatter_dimension_numbers()) &&
2468 eq_computations(to_apply(), casted_other.to_apply());
2469 }
2470
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2471 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
2472 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2473 HloCloneContext* context) const {
2474 CHECK_EQ(new_operands.size(), 3);
2475 return absl::make_unique<HloScatterInstruction>(
2476 shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
2477 scatter_dimension_numbers());
2478 }
2479
HloIotaInstruction(const Shape & shape,int64 iota_dimension)2480 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
2481 : HloInstruction(HloOpcode::kIota, shape),
2482 iota_dimension_(iota_dimension) {}
2483
ToProto() const2484 HloInstructionProto HloIotaInstruction::ToProto() const {
2485 HloInstructionProto proto = HloInstruction::ToProto();
2486 proto.add_dimensions(iota_dimension());
2487 return proto;
2488 }
2489
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2490 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
2491 const HloPrintOptions& options) const {
2492 return {StrCat("iota_dimension=", iota_dimension())};
2493 }
2494
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2495 bool HloIotaInstruction::IdenticalSlowPath(
2496 const HloInstruction& other,
2497 const std::function<bool(const HloComputation*, const HloComputation*)>&
2498 eq_computations) const {
2499 const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
2500 return iota_dimension() == casted_other.iota_dimension();
2501 }
2502
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2503 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
2504 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2505 HloCloneContext* context) const {
2506 return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
2507 }
2508
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2509 HloDotInstruction::HloDotInstruction(
2510 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2511 const DotDimensionNumbers& dimension_numbers,
2512 const PrecisionConfig& precision_config)
2513 : HloInstruction(HloOpcode::kDot, shape),
2514 dot_dimension_numbers_(dimension_numbers),
2515 precision_config_(precision_config) {
2516 AppendOperand(lhs);
2517 AppendOperand(rhs);
2518 }
2519
ToProto() const2520 HloInstructionProto HloDotInstruction::ToProto() const {
2521 HloInstructionProto proto = HloInstruction::ToProto();
2522 *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
2523 *proto.mutable_precision_config() = precision_config_;
2524 return proto;
2525 }
2526
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2527 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
2528 const HloPrintOptions& options) const {
2529 std::vector<string> extra = {DotDimensionNumbersToString()};
2530
2531 string precision_config_string = PrecisionConfigToString(precision_config_);
2532 if (!precision_config_string.empty()) {
2533 extra.push_back(precision_config_string);
2534 }
2535 return extra;
2536 }
2537
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2538 bool HloDotInstruction::IdenticalSlowPath(
2539 const HloInstruction& other,
2540 const std::function<bool(const HloComputation*, const HloComputation*)>&
2541 eq_computations) const {
2542 const auto& casted_other = static_cast<const HloDotInstruction&>(other);
2543 return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
2544 casted_other.dot_dimension_numbers()) &&
2545 protobuf_util::ProtobufEquals(precision_config(),
2546 casted_other.precision_config());
2547 }
2548
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2549 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
2550 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2551 HloCloneContext* context) const {
2552 CHECK_EQ(new_operands.size(), 2);
2553 return absl::make_unique<HloDotInstruction>(
2554 shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
2555 precision_config_);
2556 }
2557
DotDimensionNumbersToString() const2558 string HloDotInstruction::DotDimensionNumbersToString() const {
2559 std::vector<string> result;
2560 const DotDimensionNumbers& dnums = dot_dimension_numbers_;
2561 if (!dnums.lhs_batch_dimensions().empty()) {
2562 result.push_back(StrCat("lhs_batch_dims={",
2563 StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
2564 }
2565 result.push_back(StrCat("lhs_contracting_dims={",
2566 StrJoin(dnums.lhs_contracting_dimensions(), ","),
2567 "}"));
2568
2569 if (!dnums.rhs_batch_dimensions().empty()) {
2570 result.push_back(StrCat("rhs_batch_dims={",
2571 StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
2572 }
2573 result.push_back(StrCat("rhs_contracting_dims={",
2574 StrJoin(dnums.rhs_contracting_dimensions(), ","),
2575 "}"));
2576
2577 return StrJoin(result, ", ");
2578 }
2579
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)2580 HloDomainInstruction::HloDomainInstruction(
2581 const Shape& shape, HloInstruction* operand,
2582 std::unique_ptr<DomainMetadata> operand_side_metadata,
2583 std::unique_ptr<DomainMetadata> user_side_metadata)
2584 : HloInstruction(HloOpcode::kDomain, shape),
2585 operand_side_metadata_(std::move(operand_side_metadata)),
2586 user_side_metadata_(std::move(user_side_metadata)) {
2587 AppendOperand(operand);
2588 }
2589
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2590 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
2591 const HloPrintOptions& options) const {
2592 if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
2593 return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
2594 "\", entry=", user_side_metadata_->ToString(),
2595 ", exit=", operand_side_metadata_->ToString(), "}")};
2596 }
2597 return {};
2598 }
2599
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2600 bool HloDomainInstruction::IdenticalSlowPath(
2601 const HloInstruction& other,
2602 const std::function<bool(const HloComputation*, const HloComputation*)>&
2603 eq_computations) const {
2604 const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
2605 return operand_side_metadata().Matches(
2606 casted_other.operand_side_metadata()) &&
2607 user_side_metadata().Matches(casted_other.user_side_metadata());
2608 }
2609
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2610 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
2611 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2612 HloCloneContext* context) const {
2613 CHECK_EQ(new_operands.size(), 1);
2614 return absl::make_unique<HloDomainInstruction>(
2615 shape, new_operands[0], operand_side_metadata_->Clone(),
2616 user_side_metadata_->Clone());
2617 }
2618
ToProto() const2619 HloInstructionProto HloDomainInstruction::ToProto() const {
2620 HloInstructionProto proto = HloInstruction::ToProto();
2621 auto operand_side_sharding =
2622 dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
2623 if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
2624 *proto.mutable_domain_entry_sharding() =
2625 operand_side_sharding->sharding()->ToProto();
2626 }
2627
2628 auto user_side_sharding =
2629 dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
2630 if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
2631 *proto.mutable_domain_exit_sharding() =
2632 user_side_sharding->sharding()->ToProto();
2633 }
2634
2635 return proto;
2636 }
2637
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64 dimension)2638 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
2639 const Shape& shape, HloInstruction* operand, int64 dimension)
2640 : HloInstruction(HloOpcode::kGetDimensionSize, shape),
2641 dimension_(dimension) {
2642 AppendOperand(operand);
2643 }
2644
ToProto() const2645 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
2646 HloInstructionProto proto = HloInstruction::ToProto();
2647 proto.add_dimensions(dimension());
2648 return proto;
2649 }
2650
ExtraAttributesToStringImpl(const HloPrintOptions &) const2651 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
2652 const HloPrintOptions& /*options*/) const {
2653 return {StrCat("dimensions={", dimension(), "}")};
2654 }
2655
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2656 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
2657 const HloInstruction& other,
2658 const std::function<bool(const HloComputation*, const HloComputation*)>&
2659 /*eq_computations*/) const {
2660 const auto& casted_other =
2661 static_cast<const HloGetDimensionSizeInstruction&>(other);
2662 return dimension() == casted_other.dimension();
2663 }
2664
2665 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2666 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
2667 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2668 HloCloneContext* /*context*/) const {
2669 if (new_operands.size() != 1) {
2670 LOG(FATAL) << "expects 1 operand";
2671 }
2672 return absl::make_unique<HloGetDimensionSizeInstruction>(
2673 shape, new_operands[0], dimension());
2674 }
2675
2676 } // namespace xla
2677