1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17
18 #include <deque>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
33 #include "tensorflow/compiler/xla/window_util.h"
34 #include "tensorflow/core/platform/protobuf.h"
35
36 namespace xla {
37 namespace {
38
39 using absl::CEscape;
40 using absl::StrAppend;
41 using absl::StrCat;
42 using absl::StrJoin;
43
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)44 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
45 const HloInstruction* operand) {
46 const auto operand_indices = instruction->OperandIndices(operand);
47 return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
48 return instruction->IsElementwiseOnOperand(operand_index);
49 });
50 }
51
PrecisionConfigToString(const PrecisionConfig & precision_config)52 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
53 if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
54 return static_cast<PrecisionConfig::Precision>(precision) ==
55 PrecisionConfig::DEFAULT;
56 })) {
57 return "";
58 }
59
60 return StrCat(
61 "operand_precision={",
62 StrJoin(
63 precision_config.operand_precision(), ",",
64 [](string* out, int32 precision) {
65 CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
66 StrAppend(out,
67 PrecisionToString(
68 static_cast<PrecisionConfig::Precision>(precision)));
69 }),
70 "}");
71 }
72 } // namespace
73
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64 feature_index)74 HloBatchNormInstruction::HloBatchNormInstruction(
75 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
76 HloInstruction* scale, float epsilon, int64 feature_index)
77 : HloInstruction(opcode, shape),
78 epsilon_(epsilon),
79 feature_index_(feature_index) {
80 AppendOperand(operand);
81 AppendOperand(scale);
82 }
83
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const84 bool HloBatchNormInstruction::IdenticalSlowPath(
85 const HloInstruction& other,
86 const std::function<bool(const HloComputation*, const HloComputation*)>&
87 eq_computations) const {
88 const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
89 return feature_index() == casted_other.feature_index() &&
90 epsilon() == casted_other.epsilon();
91 }
92
ToProto() const93 HloInstructionProto HloBatchNormInstruction::ToProto() const {
94 HloInstructionProto proto = HloInstruction::ToProto();
95 proto.set_epsilon(epsilon_);
96 proto.set_feature_index(feature_index_);
97 return proto;
98 }
99
ExtraAttributesToStringImpl(const HloPrintOptions & options) const100 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
101 const HloPrintOptions& options) const {
102 return {StrCat("epsilon=", epsilon()),
103 StrCat("feature_index=", feature_index())};
104 }
105
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)106 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
107 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
108 HloInstruction* offset, float epsilon, int64 feature_index)
109 : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
110 scale, epsilon, feature_index) {
111 AppendOperand(offset);
112 }
113
114 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const115 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
116 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
117 HloCloneContext* context) const {
118 CHECK_EQ(new_operands.size(), 3);
119 return absl::make_unique<HloBatchNormTrainingInstruction>(
120 shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
121 feature_index());
122 }
123
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)124 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
125 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
126 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
127 float epsilon, int64 feature_index)
128 : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
129 scale, epsilon, feature_index) {
130 AppendOperand(offset);
131 AppendOperand(mean);
132 AppendOperand(variance);
133 }
134
135 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const136 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
137 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
138 HloCloneContext* context) const {
139 CHECK_EQ(new_operands.size(), 5);
140 return absl::make_unique<HloBatchNormInferenceInstruction>(
141 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
142 new_operands[4], epsilon(), feature_index());
143 }
144
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)145 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
146 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
147 HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
148 float epsilon, int64 feature_index)
149 : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
150 epsilon, feature_index) {
151 AppendOperand(mean);
152 AppendOperand(variance);
153 AppendOperand(grad_output);
154 }
155
156 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const157 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
158 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
159 HloCloneContext* context) const {
160 CHECK_EQ(new_operands.size(), 5);
161 return absl::make_unique<HloBatchNormGradInstruction>(
162 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
163 new_operands[4], epsilon(), feature_index());
164 }
165
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)166 HloFftInstruction::HloFftInstruction(const Shape& shape,
167 HloInstruction* operand, FftType fft_type,
168 absl::Span<const int64> fft_length)
169 : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
170 fft_length_.assign(fft_length.begin(), fft_length.end());
171 AppendOperand(operand);
172 }
173
ToProto() const174 HloInstructionProto HloFftInstruction::ToProto() const {
175 HloInstructionProto proto = HloInstruction::ToProto();
176 proto.set_fft_type(fft_type_);
177 for (int64 fft_len : fft_length_) {
178 proto.add_fft_length(fft_len);
179 }
180 return proto;
181 }
182
ExtraAttributesToStringImpl(const HloPrintOptions & options) const183 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
184 const HloPrintOptions& options) const {
185 return {StrCat("fft_type=", FftType_Name(fft_type())),
186 StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
187 }
188
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const189 bool HloFftInstruction::IdenticalSlowPath(
190 const HloInstruction& other,
191 const std::function<bool(const HloComputation*, const HloComputation*)>&
192 eq_computations) const {
193 const auto& casted_other = static_cast<const HloFftInstruction&>(other);
194 return fft_type() == casted_other.fft_type() &&
195 fft_length() == casted_other.fft_length();
196 }
197
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const198 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
199 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
200 HloCloneContext* context) const {
201 CHECK_EQ(new_operands.size(), 1);
202 return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
203 fft_length_);
204 }
205
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction)206 HloCompareInstruction::HloCompareInstruction(const Shape& shape,
207 HloInstruction* lhs,
208 HloInstruction* rhs,
209 ComparisonDirection direction)
210 : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
211 AppendOperand(lhs);
212 AppendOperand(rhs);
213 }
214
ToProto() const215 HloInstructionProto HloCompareInstruction::ToProto() const {
216 HloInstructionProto proto = HloInstruction::ToProto();
217 proto.set_comparison_direction(ComparisonDirectionToString(direction_));
218 return proto;
219 }
220
ExtraAttributesToStringImpl(const HloPrintOptions & options) const221 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
222 const HloPrintOptions& options) const {
223 return {StrCat("direction=", ComparisonDirectionToString(direction()))};
224 }
225
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const226 bool HloCompareInstruction::IdenticalSlowPath(
227 const HloInstruction& other,
228 const std::function<bool(const HloComputation*, const HloComputation*)>&
229 eq_computations) const {
230 const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
231 return direction() == casted_other.direction();
232 }
233
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const234 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
235 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
236 HloCloneContext* context) const {
237 CHECK_EQ(new_operands.size(), 2);
238 return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
239 new_operands[1], direction());
240 }
241
242 namespace {
243
244 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
245 // of "key=value" attribute strings generically, using protocol buffer
246 // reflection.
247 //
248 // Currently implements a small subset of cases; feel free to add more as
249 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)250 std::vector<string> AttributeProtoToStringVector(
251 const tensorflow::protobuf::Message& message) {
252 const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
253 std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
254 reflection->ListFields(message, &fields);
255
256 std::vector<string> output;
257 for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
258 string s = absl::StrCat(field->name(), "=");
259 CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
260 switch (field->type()) {
261 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
262 bool val = reflection->GetBool(message, field);
263 absl::StrAppend(&s, val ? "true" : "false");
264 break;
265 }
266 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
267 const tensorflow::protobuf::EnumValueDescriptor* evd =
268 reflection->GetEnum(message, field);
269 absl::StrAppend(&s, evd->name());
270 break;
271 }
272 default:
273 LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
274 }
275 output.push_back(std::move(s));
276 }
277 return output;
278 }
279
280 } // namespace
281
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)282 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
283 const Shape& shape, HloInstruction* a, HloInstruction* b,
284 const TriangularSolveOptions& options)
285 : HloInstruction(HloOpcode::kTriangularSolve, shape),
286 triangular_solve_options_(options) {
287 AppendOperand(a);
288 AppendOperand(b);
289 }
290
ToProto() const291 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
292 HloInstructionProto proto = HloInstruction::ToProto();
293 *proto.mutable_triangular_solve_options() = triangular_solve_options_;
294 return proto;
295 }
296
ExtraAttributesToStringImpl(const HloPrintOptions & options) const297 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
298 const HloPrintOptions& options) const {
299 return AttributeProtoToStringVector(triangular_solve_options_);
300 }
301
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const302 bool HloTriangularSolveInstruction::IdenticalSlowPath(
303 const HloInstruction& other,
304 const std::function<bool(const HloComputation*, const HloComputation*)>&
305 eq_computations) const {
306 const auto& casted_other =
307 static_cast<const HloTriangularSolveInstruction&>(other);
308 const auto& options = triangular_solve_options();
309 const auto& other_options = casted_other.triangular_solve_options();
310
311 return options.left_side() == other_options.left_side() &&
312 options.lower() == other_options.lower() &&
313 options.unit_diagonal() == other_options.unit_diagonal() &&
314 options.transpose_a() == other_options.transpose_a();
315 }
316
317 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const318 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
319 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
320 HloCloneContext* context) const {
321 CHECK_EQ(new_operands.size(), 2);
322 return absl::make_unique<HloTriangularSolveInstruction>(
323 shape, new_operands[0], new_operands[1], triangular_solve_options());
324 }
325
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)326 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
327 HloInstruction* a,
328 const CholeskyOptions& options)
329 : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
330 AppendOperand(a);
331 }
332
ToProto() const333 HloInstructionProto HloCholeskyInstruction::ToProto() const {
334 HloInstructionProto proto = HloInstruction::ToProto();
335 *proto.mutable_cholesky_options() = cholesky_options_;
336 return proto;
337 }
338
ExtraAttributesToStringImpl(const HloPrintOptions & options) const339 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
340 const HloPrintOptions& options) const {
341 return AttributeProtoToStringVector(cholesky_options_);
342 }
343
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const344 bool HloCholeskyInstruction::IdenticalSlowPath(
345 const HloInstruction& other,
346 const std::function<bool(const HloComputation*, const HloComputation*)>&
347 eq_computations) const {
348 const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
349 const auto& options = cholesky_options();
350 const auto& other_options = casted_other.cholesky_options();
351
352 return options.lower() == other_options.lower();
353 }
354
355 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const356 HloCholeskyInstruction::CloneWithNewOperandsImpl(
357 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
358 HloCloneContext* context) const {
359 CHECK_EQ(new_operands.size(), 1);
360 return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
361 cholesky_options());
362 }
363
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const absl::optional<int64> & channel_id)364 HloChannelInstruction::HloChannelInstruction(
365 HloOpcode opcode, const Shape& shape,
366 const absl::optional<int64>& channel_id)
367 : HloInstruction(opcode, shape), channel_id_(channel_id) {}
368
set_channel_id(const absl::optional<int64> & channel_id)369 void HloChannelInstruction::set_channel_id(
370 const absl::optional<int64>& channel_id) {
371 channel_id_ = channel_id;
372 }
373
ToProto() const374 HloInstructionProto HloChannelInstruction::ToProto() const {
375 HloInstructionProto proto = HloInstruction::ToProto();
376 if (channel_id_) {
377 CHECK_GT(channel_id_.value(), 0)
378 << "Non-positive channel id is equivalent to no channel id";
379 proto.set_channel_id(*channel_id_);
380 }
381 return proto;
382 }
383
ExtraAttributesToStringImpl(const HloPrintOptions &) const384 std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
385 const HloPrintOptions& /*options*/) const {
386 std::vector<string> result;
387 if (channel_id_) {
388 result.push_back(StrCat("channel_id=", *channel_id_));
389 }
390 return result;
391 }
392
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const393 bool HloChannelInstruction::IdenticalSlowPath(
394 const HloInstruction& other,
395 const std::function<bool(const HloComputation*, const HloComputation*)>&
396 /*eq_computations*/) const {
397 const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
398 return channel_id() == casted_other.channel_id();
399 }
400
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64 channel_id,bool is_host_transfer)401 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
402 const Shape& shape,
403 int64 channel_id,
404 bool is_host_transfer)
405 : HloChannelInstruction(opcode, shape, channel_id),
406 is_host_transfer_(is_host_transfer) {}
407
ToProto() const408 HloInstructionProto HloSendRecvInstruction::ToProto() const {
409 HloInstructionProto proto = HloChannelInstruction::ToProto();
410 proto.set_is_host_transfer(is_host_transfer_);
411 return proto;
412 }
413
ExtraAttributesToStringImpl(const HloPrintOptions & options) const414 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
415 const HloPrintOptions& options) const {
416 std::vector<string> attrs =
417 HloChannelInstruction::ExtraAttributesToStringImpl(options);
418 if (is_host_transfer()) {
419 attrs.push_back("is_host_transfer=true");
420 }
421 return attrs;
422 }
423
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const424 bool HloSendRecvInstruction::IdenticalSlowPath(
425 const HloInstruction& other,
426 const std::function<bool(const HloComputation*, const HloComputation*)>&
427 eq_computations) const {
428 // Not yet supported.
429 return false;
430 }
431
// Send instruction produces a tuple of {aliased operand, U32 context, token}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)433 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
434 HloInstruction* token, int64 channel_id,
435 bool is_host_transfer)
436 : HloSendRecvInstruction(
437 HloOpcode::kSend,
438 ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
439 ShapeUtil::MakeShape(U32, {}),
440 ShapeUtil::MakeTokenShape()}),
441 channel_id, is_host_transfer) {
442 AppendOperand(operand);
443 AppendOperand(token);
444 }
445
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const446 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
447 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
448 HloCloneContext* context) const {
449 CHECK_EQ(new_operands.size(), 2);
450 return absl::make_unique<HloSendInstruction>(
451 new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
452 }
453
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)454 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
455 bool is_host_transfer)
456 : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
457 CHECK_NOTNULL(operand)->channel_id().value(),
458 is_host_transfer) {
459 AppendOperand(operand);
460 }
461
462 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const463 HloSendDoneInstruction::CloneWithNewOperandsImpl(
464 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
465 HloCloneContext* context) const {
466 CHECK_EQ(new_operands.size(), 1);
467 return absl::make_unique<HloSendDoneInstruction>(
468 Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
469 }
470
// Recv instruction produces a tuple of {receive buffer, U32 context, token}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)472 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
473 HloInstruction* token, int64 channel_id,
474 bool is_host_transfer)
475 : HloSendRecvInstruction(
476 HloOpcode::kRecv,
477 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
478 ShapeUtil::MakeTokenShape()}),
479 channel_id, is_host_transfer) {
480 AppendOperand(token);
481 }
482
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const483 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
484 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
485 HloCloneContext* context) const {
486 CHECK_EQ(new_operands.size(), 1);
487 return absl::make_unique<HloRecvInstruction>(
488 ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
489 is_host_transfer());
490 }
491
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)492 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
493 bool is_host_transfer)
494 : HloSendRecvInstruction(
495 HloOpcode::kRecvDone,
496 ShapeUtil::MakeTupleShape(
497 {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
498 ShapeUtil::MakeTokenShape()}),
499 CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
500 AppendOperand(operand);
501 }
502
503 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const504 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
505 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
506 HloCloneContext* context) const {
507 CHECK_EQ(new_operands.size(), 1);
508 return absl::make_unique<HloRecvDoneInstruction>(
509 Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
510 }
511
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,const absl::optional<int64> & channel_id)512 HloCollectiveInstruction::HloCollectiveInstruction(
513 HloOpcode opcode, const Shape& shape,
514 absl::Span<HloInstruction* const> operands,
515 const std::vector<ReplicaGroup>& replica_groups,
516 const absl::optional<int64>& channel_id)
517 : HloChannelInstruction(opcode, shape, channel_id),
518 replica_groups_(replica_groups) {
519 for (auto operand : operands) {
520 AppendOperand(operand);
521 }
522 }
523
ToProto() const524 HloInstructionProto HloCollectiveInstruction::ToProto() const {
525 HloInstructionProto proto = HloChannelInstruction::ToProto();
526 *proto.mutable_replica_groups() = {replica_groups_.begin(),
527 replica_groups_.end()};
528 return proto;
529 }
530
ExtraAttributesToStringImpl(const HloPrintOptions & options) const531 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
532 const HloPrintOptions& options) const {
533 std::vector<string> result =
534 HloChannelInstruction::ExtraAttributesToStringImpl(options);
535 result.push_back(
536 StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
537 return result;
538 }
539
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const540 bool HloCollectiveInstruction::IdenticalSlowPath(
541 const HloInstruction& other,
542 const std::function<bool(const HloComputation*, const HloComputation*)>&
543 eq_computations) const {
544 const auto& casted_other =
545 static_cast<const HloCollectiveInstruction&>(other);
546 return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
547 absl::c_equal(replica_groups(), casted_other.replica_groups(),
548 [](const ReplicaGroup& a, const ReplicaGroup& b) {
549 return absl::c_equal(a.replica_ids(), b.replica_ids());
550 });
551 }
552
HloAllReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)553 HloAllReduceInstruction::HloAllReduceInstruction(
554 const Shape& shape, absl::Span<HloInstruction* const> operands,
555 HloComputation* reduce_computation,
556 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
557 const absl::optional<int64>& channel_id)
558 : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
559 replica_groups, channel_id),
560 constrain_layout_(constrain_layout) {
561 AppendComputation(reduce_computation);
562 }
563
IsNoop() const564 bool HloAllReduceInstruction::IsNoop() const {
565 for (auto replica_group : replica_groups()) {
566 if (replica_group.replica_ids().size() != 1) {
567 return false;
568 }
569 }
570 return !channel_id();
571 }
572
ToProto() const573 HloInstructionProto HloAllReduceInstruction::ToProto() const {
574 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
575 proto.set_constrain_layout(constrain_layout_);
576 return proto;
577 }
578
ExtraAttributesToStringImpl(const HloPrintOptions & options) const579 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
580 const HloPrintOptions& options) const {
581 std::vector<string> result =
582 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
583 if (constrain_layout_) {
584 result.push_back("constrain_layout=true");
585 }
586 return result;
587 }
588
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const589 bool HloAllReduceInstruction::IdenticalSlowPath(
590 const HloInstruction& other,
591 const std::function<bool(const HloComputation*, const HloComputation*)>&
592 eq_computations) const {
593 const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
594 return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
595 constrain_layout() == casted_other.constrain_layout() &&
596 eq_computations(to_apply(), casted_other.to_apply());
597 }
598
599 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const600 HloAllReduceInstruction::CloneWithNewOperandsImpl(
601 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
602 HloCloneContext* /*context*/) const {
603 return absl::make_unique<HloAllReduceInstruction>(
604 shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
605 channel_id());
606 }
607
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)608 HloAllToAllInstruction::HloAllToAllInstruction(
609 const Shape& shape, absl::Span<HloInstruction* const> operands,
610 const std::vector<ReplicaGroup>& replica_groups,
611 const absl::optional<int64>& channel_id,
612 const absl::optional<int64>& split_dimension)
613 : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
614 replica_groups, channel_id),
615 split_dimension_(split_dimension) {}
616
617 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const618 HloAllToAllInstruction::CloneWithNewOperandsImpl(
619 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
620 HloCloneContext* /*context*/) const {
621 return absl::make_unique<HloAllToAllInstruction>(
622 shape, new_operands, replica_groups(), channel_id(), split_dimension());
623 }
624
ToProto() const625 HloInstructionProto HloAllToAllInstruction::ToProto() const {
626 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
627 if (split_dimension_) {
628 proto.add_dimensions(*split_dimension_);
629 }
630 return proto;
631 }
632
ExtraAttributesToStringImpl(const HloPrintOptions & options) const633 std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
634 const HloPrintOptions& options) const {
635 std::vector<string> result =
636 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
637 if (split_dimension_) {
638 result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
639 }
640 return result;
641 }
642
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const643 bool HloAllToAllInstruction::IdenticalSlowPath(
644 const HloInstruction& other,
645 const std::function<bool(const HloComputation*, const HloComputation*)>&
646 eq_computations) const {
647 const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
648 return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
649 split_dimension_ == casted_other.split_dimension();
650 }
651
HloCollectivePermuteInstruction(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)652 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
653 const Shape& shape, HloInstruction* operand,
654 const std::vector<std::pair<int64, int64>>& source_target_pairs,
655 const absl::optional<int64>& channel_id)
656 : HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id),
657 source_target_pairs_(source_target_pairs) {
658 AppendOperand(operand);
659 }
660
ToProto() const661 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
662 HloInstructionProto proto = HloChannelInstruction::ToProto();
663 for (const auto& pair : source_target_pairs()) {
664 auto* proto_pair = proto.add_source_target_pairs();
665 proto_pair->set_source(pair.first);
666 proto_pair->set_target(pair.second);
667 }
668 return proto;
669 }
670
671 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const672 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
673 const HloPrintOptions& options) const {
674 std::vector<string> result =
675 HloChannelInstruction::ExtraAttributesToStringImpl(options);
676 std::vector<string> strs;
677 for (const auto& pair : source_target_pairs()) {
678 strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
679 }
680 result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
681 return result;
682 }
683
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const684 bool HloCollectivePermuteInstruction::IdenticalSlowPath(
685 const HloInstruction& other,
686 const std::function<bool(const HloComputation*, const HloComputation*)>&
687 eq_computations) const {
688 const auto& casted_other =
689 static_cast<const HloCollectivePermuteInstruction&>(other);
690 return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
691 absl::c_equal(source_target_pairs(),
692 casted_other.source_target_pairs(),
693 [](const std::pair<int64, int64>& a,
694 const std::pair<int64, int64>& b) { return a == b; });
695 }
696
697 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const698 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
699 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
700 HloCloneContext* /*context*/) const {
701 return absl::make_unique<HloCollectivePermuteInstruction>(
702 shape, new_operands[0], source_target_pairs(), channel_id());
703 }
704
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)705 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
706 HloInstruction* operand,
707 absl::Span<const int64> dimensions)
708 : HloInstruction(HloOpcode::kReverse, shape),
709 dimensions_(dimensions.begin(), dimensions.end()) {
710 AppendOperand(operand);
711 }
712
ToProto() const713 HloInstructionProto HloReverseInstruction::ToProto() const {
714 HloInstructionProto proto = HloInstruction::ToProto();
715 for (int64 dimension : dimensions_) {
716 proto.add_dimensions(dimension);
717 }
718 return proto;
719 }
720
ExtraAttributesToStringImpl(const HloPrintOptions & options) const721 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
722 const HloPrintOptions& options) const {
723 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
724 }
725
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const726 bool HloReverseInstruction::IdenticalSlowPath(
727 const HloInstruction& other,
728 const std::function<bool(const HloComputation*, const HloComputation*)>&
729 eq_computations) const {
730 const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
731 return dimensions() == casted_other.dimensions();
732 }
733
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const734 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
735 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
736 HloCloneContext* context) const {
737 CHECK_EQ(new_operands.size(), 1);
738 return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
739 dimensions());
740 }
741
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)742 HloConcatenateInstruction::HloConcatenateInstruction(
743 const Shape& shape, absl::Span<HloInstruction* const> operands,
744 int64 dimension)
745 : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
746 for (auto operand : operands) {
747 AppendOperand(operand);
748 }
749 }
750
ToProto() const751 HloInstructionProto HloConcatenateInstruction::ToProto() const {
752 HloInstructionProto proto = HloInstruction::ToProto();
753 for (int64 dimension : dimensions_) {
754 proto.add_dimensions(dimension);
755 }
756 return proto;
757 }
758
ExtraAttributesToStringImpl(const HloPrintOptions & options) const759 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
760 const HloPrintOptions& options) const {
761 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
762 }
763
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const764 bool HloConcatenateInstruction::IdenticalSlowPath(
765 const HloInstruction& other,
766 const std::function<bool(const HloComputation*, const HloComputation*)>&
767 eq_computations) const {
768 const auto& casted_other =
769 static_cast<const HloConcatenateInstruction&>(other);
770 return dimensions() == casted_other.dimensions();
771 }
772
773 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const774 HloConcatenateInstruction::CloneWithNewOperandsImpl(
775 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
776 HloCloneContext* context) const {
777 return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
778 dimensions(0));
779 }
780
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)781 HloReduceInstruction::HloReduceInstruction(
782 const Shape& shape, absl::Span<HloInstruction* const> args,
783 absl::Span<const int64> dimensions_to_reduce,
784 HloComputation* reduce_computation)
785 : HloInstruction(HloOpcode::kReduce, shape),
786 dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
787 for (HloInstruction* arg : args) {
788 AppendOperand(arg);
789 }
790 AppendComputation(reduce_computation);
791 }
792
ToProto() const793 HloInstructionProto HloReduceInstruction::ToProto() const {
794 HloInstructionProto proto = HloInstruction::ToProto();
795 for (int64 dimension : dimensions_) {
796 proto.add_dimensions(dimension);
797 }
798 return proto;
799 }
800
ExtraAttributesToStringImpl(const HloPrintOptions & options) const801 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
802 const HloPrintOptions& options) const {
803 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
804 }
805
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const806 bool HloReduceInstruction::IdenticalSlowPath(
807 const HloInstruction& other,
808 const std::function<bool(const HloComputation*, const HloComputation*)>&
809 eq_computations) const {
810 const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
811 // Reduction results are determined by the reduction dimension and the
812 // reduction computation.
813 return dimensions() == casted_other.dimensions() &&
814 eq_computations(to_apply(), casted_other.to_apply());
815 }
816
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const817 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
818 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
819 HloCloneContext* context) const {
820 CHECK_EQ(new_operands.size() % 2, 0);
821 return absl::make_unique<HloReduceInstruction>(shape, new_operands,
822 dimensions(), to_apply());
823 }
824
HloSortInstruction(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)825 HloSortInstruction::HloSortInstruction(
826 const Shape& shape, int64 dimension,
827 absl::Span<HloInstruction* const> operands, HloComputation* compare,
828 bool is_stable)
829 : HloInstruction(HloOpcode::kSort, shape),
830 dimensions_({dimension}),
831 is_stable_(is_stable) {
832 for (auto* value : operands) {
833 AppendOperand(value);
834 }
835 AppendComputation(compare);
836 }
837
ToProto() const838 HloInstructionProto HloSortInstruction::ToProto() const {
839 HloInstructionProto proto = HloInstruction::ToProto();
840 for (int64 dimension : dimensions_) {
841 proto.add_dimensions(dimension);
842 }
843 proto.set_is_stable(is_stable());
844 return proto;
845 }
846
ExtraAttributesToStringImpl(const HloPrintOptions & options) const847 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
848 const HloPrintOptions& options) const {
849 std::vector<string> attrs;
850 attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
851 if (is_stable()) {
852 attrs.push_back("is_stable=true");
853 }
854 return attrs;
855 }
856
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const857 bool HloSortInstruction::IdenticalSlowPath(
858 const HloInstruction& other,
859 const std::function<bool(const HloComputation*, const HloComputation*)>&
860 eq_computations) const {
861 const auto& casted_other = static_cast<const HloSortInstruction&>(other);
862 if (dimensions() != casted_other.dimensions()) {
863 return false;
864 }
865 if (is_stable() != casted_other.is_stable()) {
866 return false;
867 }
868 return eq_computations(to_apply(), other.to_apply());
869 }
870
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const871 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
872 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
873 HloCloneContext* context) const {
874 return absl::make_unique<HloSortInstruction>(
875 shape, dimensions(0), new_operands, to_apply(), is_stable());
876 }
877
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)878 HloTransposeInstruction::HloTransposeInstruction(
879 const Shape& shape, HloInstruction* operand,
880 absl::Span<const int64> dimensions)
881 : HloInstruction(HloOpcode::kTranspose, shape),
882 dimensions_(dimensions.begin(), dimensions.end()) {
883 AppendOperand(operand);
884 }
885
IsRank2Transpose() const886 bool HloTransposeInstruction::IsRank2Transpose() const {
887 return dimensions() == std::vector<int64>({1, 0}) &&
888 shape().dimensions_size() == 2 &&
889 std::equal(shape().dimensions().begin(), shape().dimensions().end(),
890 operand(0)->shape().dimensions().rbegin());
891 }
892
ToProto() const893 HloInstructionProto HloTransposeInstruction::ToProto() const {
894 HloInstructionProto proto = HloInstruction::ToProto();
895 for (int64 dimension : dimensions_) {
896 proto.add_dimensions(dimension);
897 }
898 return proto;
899 }
900
ExtraAttributesToStringImpl(const HloPrintOptions & options) const901 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
902 const HloPrintOptions& options) const {
903 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
904 }
905
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const906 bool HloTransposeInstruction::IdenticalSlowPath(
907 const HloInstruction& other,
908 const std::function<bool(const HloComputation*, const HloComputation*)>&
909 eq_computations) const {
910 const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
911 return dimensions() == casted_other.dimensions();
912 }
913
914 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const915 HloTransposeInstruction::CloneWithNewOperandsImpl(
916 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
917 HloCloneContext* context) const {
918 CHECK_EQ(new_operands.size(), 1);
919 return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
920 dimensions());
921 }
922
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)923 HloBroadcastInstruction::HloBroadcastInstruction(
924 const Shape& shape, HloInstruction* operand,
925 absl::Span<const int64> broadcast_dimension)
926 : HloInstruction(HloOpcode::kBroadcast, shape),
927 dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
928 AppendOperand(operand);
929 }
930
ToProto() const931 HloInstructionProto HloBroadcastInstruction::ToProto() const {
932 HloInstructionProto proto = HloInstruction::ToProto();
933 for (int64 dimension : dimensions_) {
934 proto.add_dimensions(dimension);
935 }
936 return proto;
937 }
938
ExtraAttributesToStringImpl(const HloPrintOptions & options) const939 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
940 const HloPrintOptions& options) const {
941 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
942 }
943
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const944 bool HloBroadcastInstruction::IdenticalSlowPath(
945 const HloInstruction& other,
946 const std::function<bool(const HloComputation*, const HloComputation*)>&
947 eq_computations) const {
948 const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
949 return dimensions() == casted_other.dimensions();
950 }
951
952 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const953 HloBroadcastInstruction::CloneWithNewOperandsImpl(
954 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
955 HloCloneContext* context) const {
956 CHECK_EQ(new_operands.size(), 1);
957 return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
958 dimensions());
959 }
960
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64 inferred_dimension)961 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
962 HloInstruction* operand,
963 int64 inferred_dimension)
964 : HloInstruction(HloOpcode::kReshape, shape),
965 inferred_dimension_(inferred_dimension) {
966 AppendOperand(operand);
967 }
968
ToProto() const969 HloInstructionProto HloReshapeInstruction::ToProto() const {
970 HloInstructionProto proto = HloInstruction::ToProto();
971 if (inferred_dimension_ != -1) {
972 proto.add_dimensions(inferred_dimension_);
973 }
974 return proto;
975 }
976
ExtraAttributesToStringImpl(const HloPrintOptions & options) const977 std::vector<string> HloReshapeInstruction::ExtraAttributesToStringImpl(
978 const HloPrintOptions& options) const {
979 if (inferred_dimension() == -1) {
980 return {};
981 }
982 return {StrCat("inferred_dimension=", inferred_dimension())};
983 }
984
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const985 bool HloReshapeInstruction::IdenticalSlowPath(
986 const HloInstruction& other,
987 const std::function<bool(const HloComputation*, const HloComputation*)>&
988 eq_computations) const {
989 const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
990 return inferred_dimension() == casted_other.inferred_dimension();
991 }
992
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const993 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
994 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
995 HloCloneContext* context) const {
996 CHECK_EQ(new_operands.size(), 1);
997 return absl::make_unique<HloReshapeInstruction>(shape, new_operands[0],
998 inferred_dimension());
999 }
1000
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1001 HloMapInstruction::HloMapInstruction(const Shape& shape,
1002 absl::Span<HloInstruction* const> operands,
1003 HloComputation* map_computation)
1004 : HloInstruction(HloOpcode::kMap, shape) {
1005 for (auto operand : operands) {
1006 AppendOperand(operand);
1007 }
1008 AppendComputation(map_computation);
1009 // TODO(b/65689298) Remove code below once Map is generalized to accept
1010 // arbitrary map dimensions.
1011 dimensions_.resize(shape.rank());
1012 std::iota(dimensions_.begin(), dimensions_.end(), 0);
1013 }
1014
ToProto() const1015 HloInstructionProto HloMapInstruction::ToProto() const {
1016 HloInstructionProto proto = HloInstruction::ToProto();
1017 for (int64 dimension : dimensions_) {
1018 proto.add_dimensions(dimension);
1019 }
1020 return proto;
1021 }
1022
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1023 bool HloMapInstruction::IsElementwiseImpl(
1024 const absl::optional<int64>& operand_idx) const {
1025 if (!dimensions().empty()) {
1026 // Check that the map is executed in elementwise compatible dimensions.
1027 if (dimensions().size() != shape().dimensions_size()) {
1028 return false;
1029 }
1030 for (int i = 0; i < dimensions().size(); ++i) {
1031 if (dimensions()[i] != i) {
1032 return false;
1033 }
1034 }
1035 }
1036 return true;
1037 }
1038
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1039 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
1040 const HloPrintOptions& options) const {
1041 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1042 }
1043
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1044 bool HloMapInstruction::IdenticalSlowPath(
1045 const HloInstruction& other,
1046 const std::function<bool(const HloComputation*, const HloComputation*)>&
1047 eq_computations) const {
1048 return eq_computations(to_apply(), other.to_apply());
1049 }
1050
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1051 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1052 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1053 HloCloneContext* context) const {
1054 return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1055 }
1056
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1057 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
1058 HloInstruction* operand,
1059 absl::Span<const int64> start_indices,
1060 absl::Span<const int64> limit_indices,
1061 absl::Span<const int64> strides)
1062 : HloInstruction(HloOpcode::kSlice, shape),
1063 slice_starts_(start_indices.begin(), start_indices.end()),
1064 slice_limits_(limit_indices.begin(), limit_indices.end()),
1065 slice_strides_(strides.begin(), strides.end()) {
1066 AppendOperand(operand);
1067 // For backward compatibility with old serialized computations: if there are
1068 // no strides, assume all strides are 1.
1069 // TODO(b/63317920): remove this code.
1070 if (slice_strides_.empty()) {
1071 slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
1072 }
1073 }
1074
ToProto() const1075 HloInstructionProto HloSliceInstruction::ToProto() const {
1076 HloInstructionProto proto = HloInstruction::ToProto();
1077 for (int i = 0; i < slice_starts_.size(); ++i) {
1078 auto* slice_dimension = proto.add_slice_dimensions();
1079 slice_dimension->set_start(slice_starts_[i]);
1080 slice_dimension->set_limit(slice_limits_[i]);
1081 slice_dimension->set_stride(slice_strides_[i]);
1082 }
1083 return proto;
1084 }
1085
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1086 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
1087 const HloPrintOptions& options) const {
1088 std::vector<string> bounds;
1089 bounds.reserve(slice_starts_.size());
1090 const bool omit_stride =
1091 absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
1092 for (int i = 0; i < slice_starts_.size(); ++i) {
1093 string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1094 bounds.push_back(
1095 StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1096 }
1097 return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1098 }
1099
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1100 bool HloSliceInstruction::IdenticalSlowPath(
1101 const HloInstruction& other,
1102 const std::function<bool(const HloComputation*, const HloComputation*)>&
1103 eq_computations) const {
1104 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1105 return slice_starts_ == other_slice.slice_starts_ &&
1106 slice_limits_ == other_slice.slice_limits_ &&
1107 slice_strides_ == other_slice.slice_strides_;
1108 }
1109
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1110 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1111 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1112 HloCloneContext* context) const {
1113 CHECK_EQ(new_operands.size(), 1);
1114 return absl::make_unique<HloSliceInstruction>(
1115 shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1116 }
1117
HloConstantInstruction(Literal literal)1118 HloConstantInstruction::HloConstantInstruction(Literal literal)
1119 : HloInstruction(HloOpcode::kConstant, literal.shape()),
1120 literal_(std::move(literal)) {}
1121
HloConstantInstruction(Literal literal,const Shape & shape)1122 HloConstantInstruction::HloConstantInstruction(Literal literal,
1123 const Shape& shape)
1124 : HloInstruction(HloOpcode::kConstant, shape),
1125 literal_(std::move(literal)) {}
1126
HloConstantInstruction(const Shape & shape)1127 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1128 : HloInstruction(HloOpcode::kConstant, shape) {}
1129
ToProto() const1130 HloInstructionProto HloConstantInstruction::ToProto() const {
1131 HloInstructionProto proto = HloInstruction::ToProto();
1132 if (literal_.has_value()) {
1133 *proto.mutable_literal() = literal_->ToProto();
1134 }
1135 return proto;
1136 }
1137
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1138 bool HloConstantInstruction::IsElementwiseImpl(
1139 const absl::optional<int64>& operand_idx) const {
1140 return true;
1141 }
1142
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1143 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1144 const ShapeIndex& shape_index) {
1145 Shape* mutable_array_subshape =
1146 ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1147 CHECK(mutable_array_subshape->IsArray());
1148
1149 // Normally array_subshape will always have a layout, but this invariant is
1150 // temporarily broken in LayoutAssignment::AssignLayouts.
1151
1152 if (!mutable_array_subshape->has_layout() ||
1153 !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1154 *literal_ = literal_->Relayout(new_layout, shape_index);
1155 *mutable_array_subshape->mutable_layout() = new_layout;
1156 }
1157 }
1158
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1159 bool HloConstantInstruction::IdenticalSlowPath(
1160 const HloInstruction& other,
1161 const std::function<bool(const HloComputation*, const HloComputation*)>&
1162 eq_computations) const {
1163 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1164 return literal() == other_slice.literal();
1165 }
1166
1167 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1168 HloConstantInstruction::CloneWithNewOperandsImpl(
1169 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1170 HloCloneContext* context) const {
1171 CHECK(literal_.has_value());
1172 // Literal's shape may have no/different tiling info. Use this instruction's
1173 // shape instead.
1174 CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1175 this->shape()));
1176 return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
1177 this->shape());
1178 }
1179
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1180 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1181 const HloPrintOptions& options,
1182 CanonicalNameMap* canonical_name_map) const {
1183 string operands;
1184 // For constants, show the actual value in place of an empty operand list.
1185 if (literal_.has_value() &&
1186 ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1187 options.print_large_constants())) {
1188 // Literal::ToString emits multidimensional arrays over multiple
1189 // lines. Compact this into one line by stripping out white space.
1190 string tmp = literal().ToStringWithoutShape();
1191 std::replace(tmp.begin(), tmp.end(), '\n', ' ');
1192 std::vector<string> v = absl::StrSplit(tmp, ' ');
1193 bool first = true;
1194 // Concatenate elements in "v" with spaces separating them, but ignoring
1195 // empty entries.
1196 for (const auto& s : v) {
1197 if (s.empty()) {
1198 continue;
1199 }
1200 StrAppend(&operands, (first ? "" : " "), s);
1201 first = false;
1202 }
1203 } else {
1204 // Do not show large constants or tuples.
1205 operands = "{...}";
1206 }
1207 return operands;
1208 }
1209
HloTraceInstruction(const string & tag,HloInstruction * operand)1210 HloTraceInstruction::HloTraceInstruction(const string& tag,
1211 HloInstruction* operand)
1212 : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1213 literal_(LiteralUtil::CreateR1U8(tag)) {
1214 AppendOperand(operand);
1215 operand->set_tracing(this);
1216 }
1217
ToProto() const1218 HloInstructionProto HloTraceInstruction::ToProto() const {
1219 HloInstructionProto proto = HloInstruction::ToProto();
1220 *proto.mutable_literal() = literal_.ToProto();
1221 return proto;
1222 }
1223
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1224 bool HloTraceInstruction::IdenticalSlowPath(
1225 const HloInstruction& other,
1226 const std::function<bool(const HloComputation*, const HloComputation*)>&
1227 eq_computations) const {
1228 return false;
1229 }
1230
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1231 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1232 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1233 HloCloneContext* context) const {
1234 LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1235 }
1236
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1237 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1238 FusionKind fusion_kind,
1239 HloInstruction* fused_root)
1240 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1241 CHECK(fused_root != nullptr);
1242 SetAndSanitizeName("fusion");
1243 set_parent(fused_root->parent());
1244 set_metadata(fused_root->metadata());
1245 CloneAndFuseInternal(fused_root);
1246 }
1247
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1248 HloFusionInstruction::HloFusionInstruction(
1249 const Shape& shape, FusionKind fusion_kind,
1250 absl::Span<HloInstruction* const> operands,
1251 HloComputation* fusion_computation)
1252 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1253 for (auto operand : operands) {
1254 AppendOperand(operand);
1255 }
1256 SetAndSanitizeName("fusion");
1257 AppendComputation(fusion_computation);
1258 fusion_computation->SetFusionInstruction(this);
1259 }
1260
ToCategory() const1261 string HloFusionInstruction::ToCategory() const {
1262 switch (fusion_kind()) {
1263 case FusionKind::kLoop:
1264 return "loop fusion";
1265 case FusionKind::kInput:
1266 return "input fusion";
1267 case FusionKind::kOutput:
1268 return "output fusion";
1269 case FusionKind::kCustom:
1270 return "custom fusion";
1271 }
1272 }
1273
ToProto() const1274 HloInstructionProto HloFusionInstruction::ToProto() const {
1275 HloInstructionProto proto = HloInstruction::ToProto();
1276 proto.set_fusion_kind(xla::ToString(fusion_kind()));
1277 proto.add_called_computation_ids(
1278 fused_instructions_computation()->unique_id());
1279 return proto;
1280 }
1281
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1282 bool HloFusionInstruction::IsElementwiseImpl(
1283 const absl::optional<int64>& operand_idx) const {
1284 if (!operand_idx.has_value()) {
1285 for (auto* fused : fused_instructions()) {
1286 if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1287 return false;
1288 }
1289 }
1290 return true;
1291 }
1292 // A loop-fusion is elementwise on an operand if all operations (computed
1293 // using BFS) between the operand and the fused root are elementwise.
1294 std::deque<HloInstruction*> worklist;
1295 std::unordered_set<const HloInstruction*> visited;
1296 worklist.push_back(fused_parameter(operand_idx.value()));
1297 visited.insert(fused_parameter(operand_idx.value()));
1298 while (!worklist.empty()) {
1299 HloInstruction* operand = worklist.front();
1300 worklist.pop_front();
1301 for (HloInstruction* user : operand->users()) {
1302 CHECK_GE(user->unique_id(), 0);
1303 if (ContainsKey(visited, user)) {
1304 continue;
1305 }
1306 if (user->IsElementwise() ||
1307 IsInstructionElementwiseOnOperand(user, operand)) {
1308 worklist.push_back(user);
1309 visited.insert(user);
1310 } else {
1311 return false;
1312 }
1313 }
1314 }
1315 return true;
1316 }
1317
AddFusionOperand(HloInstruction * new_operand)1318 HloInstruction* HloFusionInstruction::AddFusionOperand(
1319 HloInstruction* new_operand) {
1320 CHECK_EQ(operand_count(),
1321 fused_instructions_computation()->parameter_instructions().size());
1322 const int64 param_no = operand_count();
1323 string param_name = StrCat("param_", param_no);
1324 HloInstruction* fused_parameter =
1325 fused_instructions_computation()->AddParameter(
1326 HloInstruction::CreateParameter(param_no, new_operand->shape(),
1327 param_name));
1328 AppendOperand(new_operand);
1329 return fused_parameter;
1330 }
1331
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1332 void HloFusionInstruction::MergeFusionInstruction(
1333 HloFusionInstruction* instruction_to_merge) {
1334 CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1335 // Clone the instruction from which to merge fused instructions.
1336 std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1337 HloFusionInstruction* cloned_fusion =
1338 static_cast<HloFusionInstruction*>(cloned.get());
1339 // Replace uses of fused parameters with the corresponding operand of the
1340 // fusion. Add all non-parameter fused instructions to
1341 // 'unfused_instructions' to be merged into 'this'. This is done in reverse
1342 // post order.
1343 std::vector<HloInstruction*> unfused_instructions;
1344 auto fused_instructions = cloned_fusion->fused_instructions_computation()
1345 ->MakeInstructionPostOrder();
1346 for (auto fused_it = fused_instructions.rbegin();
1347 fused_it != fused_instructions.rend(); ++fused_it) {
1348 auto fused_instruction = *fused_it;
1349 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1350 TF_CHECK_OK(
1351 fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1352 fused_instruction->parameter_number())));
1353 } else {
1354 unfused_instructions.push_back(fused_instruction);
1355 }
1356 }
1357
1358 // If there are no unfused instructions, the fused computation must consist
1359 // only of kParameter instructions. Make the operand of the corresponding
1360 // parameter number the new root.
1361 HloInstruction* unfused_root =
1362 unfused_instructions.empty()
1363 ? instruction_to_merge->mutable_operand(
1364 instruction_to_merge->fused_instructions_computation()
1365 ->root_instruction()
1366 ->parameter_number())
1367 : unfused_instructions.front();
1368 CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
1369 unfused_instructions.empty());
1370 // Replace instruction_to_merge use of 'this' with unfused_root.
1371 TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
1372 // Fuse 'unfused_instructions' into 'this'.
1373 for (auto& instruction : unfused_instructions) {
1374 FuseInstruction(instruction);
1375 }
1376 CHECK_EQ(0, cloned_fusion->user_count());
1377 TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1378 cloned_fusion->fused_instructions_computation()));
1379 }
1380
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1381 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1382 HloFusionInstruction* instruction_to_merge) {
1383 // Add all non-parameter fused instructions to 'unfused_instructions' to be
1384 // merged into 'this'. `old_to_new' maps the instructions in the fused node
1385 // to the disassembled fusion instructions.
1386 // Note that we add the unfused instructions to this->parent_ computation.
1387 // This is necessary because the unique_id needs for an instruction and
1388 // it's only added when inserting to the computation.
1389 absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1390 std::vector<HloInstruction*> unfused_instructions;
1391 auto computation_to_merge =
1392 instruction_to_merge->fused_instructions_computation();
1393 auto post_order = computation_to_merge->MakeInstructionPostOrder();
1394 for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1395 auto fused_instruction = *rit;
1396 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1397 InsertOrDie(&old_to_new, fused_instruction,
1398 instruction_to_merge->mutable_operand(
1399 fused_instruction->parameter_number()));
1400 continue;
1401 }
1402
1403 // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
1404 // which clones again. This can be improved.
1405 auto cloned_instruction =
1406 parent()->AddInstruction(fused_instruction->Clone());
1407 unfused_instructions.push_back(cloned_instruction);
1408 InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
1409 }
1410 for (auto unfused_instruction : unfused_instructions) {
1411 for (int64 index = 0; index < unfused_instruction->operand_count();
1412 index++) {
1413 auto new_operand =
1414 FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
1415 TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
1416 }
1417 }
1418
1419 // If there are no unfused instructions, the fused computation must consist
1420 // only of kParameter instructions. Make the operand of the corresponding
1421 // parameter number the new root.
1422 HloInstruction* unfused_root =
1423 unfused_instructions.empty()
1424 ? instruction_to_merge->mutable_operand(
1425 instruction_to_merge->fused_instructions_computation()
1426 ->root_instruction()
1427 ->parameter_number())
1428 : unfused_instructions.front();
1429 TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
1430
1431 TF_CHECK_OK(
1432 instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
1433 if (GetModule()) {
1434 TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
1435 }
1436
1437 // Fuse the root instruction and generate multiple outputs.
1438 if (unfused_instructions.empty()) {
1439 return;
1440 }
1441 FuseInstructionIntoMultiOutput(unfused_root);
1442 TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
1443 // The rest instructions are of normal fusing.
1444 for (int64 i = 1; i < unfused_instructions.size(); i++) {
1445 auto instruction = unfused_instructions[i];
1446 FuseInstruction(instruction);
1447 TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1448 }
1449 }
1450
fused_instructions_computation() const1451 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1452 CHECK(!called_computations().empty());
1453 auto* fused_instructions_computation = called_computations().front();
1454 CHECK(fused_instructions_computation->IsFusionComputation())
1455 << "Computation " << fused_instructions_computation->name()
1456 << " is not a fusion kind";
1457 return fused_instructions_computation;
1458 }
1459
fused_expression_root() const1460 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1461 return fused_instructions_computation()->root_instruction();
1462 }
1463
fused_parameter(int64 parameter_number) const1464 HloInstruction* HloFusionInstruction::fused_parameter(
1465 int64 parameter_number) const {
1466 return fused_instructions_computation()->parameter_instruction(
1467 parameter_number);
1468 }
1469
fused_parameters() const1470 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1471 const {
1472 return fused_instructions_computation()->parameter_instructions();
1473 }
1474
1475 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1476 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1477 HloFusionInstruction::fused_instructions() const {
1478 const HloComputation* subcomp = fused_instructions_computation();
1479 return subcomp->instructions();
1480 }
1481
1482 const tensorflow::gtl::iterator_range<
1483 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1484 HloFusionInstruction::fused_instructions() {
1485 return fused_instructions_computation()->instructions();
1486 }
1487
fused_instruction_count() const1488 int64 HloFusionInstruction::fused_instruction_count() const {
1489 return fused_instructions_computation()->instruction_count();
1490 }
1491
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1492 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1493 HloInstruction* instruction_to_fuse, bool add_output) {
1494 // When add_output is false, this fusion instruction must be a user of
1495 // instruction_to_fuse.
1496 if (!add_output) {
1497 CHECK(IsUserOf(instruction_to_fuse));
1498 }
1499 HloInstruction* fused_instruction =
1500 CloneAndFuseInternal(instruction_to_fuse, add_output);
1501 return fused_instruction;
1502 }
1503
// Fuses 'instruction_to_fuse' (which must be fusible; CHECKed) into this
// fusion and returns the fused instruction ('clone').
//
// - No called computation yet: a new "fused_computation" is built whose only
//   instruction is a clone of 'instruction_to_fuse'. 'add_output' must be
//   false in this case.
// - Existing fused computation:
//   * If 'instruction_to_fuse' is a kTuple, it is NOT cloned; the instruction
//     itself is used as 'clone'.
//   * Otherwise a clone is added to the fused computation.
//   * If 'instruction_to_fuse' was an operand of this fusion, the matching
//     fused parameter's uses are redirected to the clone, and that
//     parameter/operand pair is removed.
//   * If 'instruction_to_fuse' then has no users left, 'add_output' is turned
//     off.
// - Afterwards every operand of the clone is rewired to a fused parameter. A
//   new fusion operand/parameter pair is added when the operand is not
//   already an operand of this fusion.
// - If 'add_output' is (still) true, the clone becomes extra element(s) of
//   the fused root tuple: one element, or each operand of a tuple clone. This
//   fusion's shape is updated, and the external users of
//   'instruction_to_fuse' are redirected to get-tuple-element instructions on
//   this fusion.
CloneAndFuseInternal(HloInstruction * instruction_to_fuse,bool add_output)1504 HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
1505 HloInstruction* instruction_to_fuse, bool add_output) {
1506 CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1507 VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
1508 HloInstruction* clone = nullptr;
1509 if (called_computations().empty()) {
1510 // New fusion instruction. It should not be a multioutput instruction.
1511 CHECK(!add_output);
1512 auto builder = HloComputation::Builder("fused_computation", this);
1513 builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
1514 AppendComputation(
1515 CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
// The single cloned instruction is the root of the new fused computation.
1516 clone = fused_expression_root();
1517 } else {
1518 // When add_output is false, instruction_to_fuse is necessarily an operand
1519 // of the fusion instruction. After fusion this will no longer be the
1520 // case. Remove the operand from the operand list and remove its
1521 // corresponding fused parameter instruction. Renumber parameters as
1522 // necessary to make parameter numbers consistent with their index in the
1523 // fused computation's parameter list.
1524 bool in_operand_list =
1525 absl::c_linear_search(operands(), instruction_to_fuse);
1526 CHECK(add_output || in_operand_list);
1527 if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1528 // We assume all uses of a kTuple operation are GTE ops, not another
1529 // fusion node. In this case, we don't need to clone
1530 // 'instruction_to_fuse'.
1531 CHECK(!in_operand_list);
1532 clone = instruction_to_fuse;
1533 } else {
1534 clone = fused_instructions_computation()->AddInstruction(
1535 instruction_to_fuse->Clone(/*suffix=*/""));
1536 }
1537 const std::vector<HloInstruction*>& fused_parameters =
1538 fused_instructions_computation()->parameter_instructions();
// Only the first operand slot holding instruction_to_fuse is handled; the
// loop breaks after it.
// NOTE(review): if instruction_to_fuse occupies several operand slots, the
// later ones are not removed here. Presumably callers deduplicate fusion
// operands first; confirm.
1539 for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1540 if (instruction_to_fuse == operand(operand_num)) {
1541 // Replace the fused parameter instruction's uses with the clone.
1542 HloInstruction* fused_parameter = fused_parameters[operand_num];
1543 TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
1544
1545 // Remove the corresponding fused parameter and operand from their
1546 // respective vectors.
1547 TF_CHECK_OK(
1548 fused_instructions_computation()->RemoveParameter(operand_num));
1549 RemoveOperandAt(operand_num);
1550 break;
1551 }
1552 }
1553 // We've cloned instruction_to_fuse into this fusion instruction, so this
1554 // fusion instruction is no longer a use of instruction_to_fuse.
1555 if (in_operand_list) {
1556 DetachFrom(instruction_to_fuse);
1557 // When the instruction_to_fuse does not have other users, we don't need
1558 // to generate a multioutput fusion instruction.
1559 if (instruction_to_fuse->user_count() == 0) {
1560 add_output = false;
1561 }
1562 }
1563 }
1564
1565 // Reread the parameters in the computation.
// This is a reference to the computation's parameter vector, so parameters
// appended by AddFusionOperand in the loop below show up through it. The
// per-iteration CHECK_EQ against operands().size() relies on that.
1566 const std::vector<HloInstruction*>& fused_parameters =
1567 fused_instructions_computation()->parameter_instructions();
1568
1569 // Add each operand of the clone as an operand of the fusion instruction. A
1570 // complication is that some clone operands may already be operands of the
1571 // fusion instruction.
1572 for (int64 operand_num = 0; operand_num < clone->operand_count();
1573 ++operand_num) {
1574 HloInstruction* operand = clone->mutable_operand(operand_num);
1575
1576 // See if this operand is already an operand of the fusion node.
1577 CHECK_EQ(operands().size(), fused_parameters.size());
1578 HloInstruction* fused_param = nullptr;
1579 for (int64 i = 0; i < operands().size(); ++i) {
1580 if (this->operand(i) == operand) {
1581 fused_param = fused_parameters[i];
1582 break;
1583 }
1584 }
1585
1586 if (fused_param == nullptr) {
1587 // Clone's operand was not already an operand of the fusion
1588 // instruction. Add it as an operand and add a corresponding fused
1589 // parameter instruction.
1590 fused_param = AddFusionOperand(operand);
1591 }
1592 TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
1593 }
1594
1595 if (add_output) {
1596 CHECK_GT(instruction_to_fuse->user_count(), 0);
1597 // If this is already a multioutput fusion instruction, expand the root
1598 // tuple (by one element, or by each operand of a tuple clone).
1599 HloInstruction* fused_root = fused_expression_root();
1600 HloInstruction::InstructionVector tuple_elements;
1601 bool newly_created_tuple_instr = false;
1602 if (fused_root->opcode() == HloOpcode::kTuple) {
1603 tuple_elements = fused_root->operands();
1604 } else {
1605 tuple_elements.push_back(fused_root);
1606 newly_created_tuple_instr = true;
1607 }
1608 if (clone->opcode() == HloOpcode::kTuple) {
1609 for (auto inst : clone->operands()) {
1610 tuple_elements.push_back(inst);
1611 }
1612 } else {
1613 tuple_elements.push_back(clone);
1614 }
// A fresh root tuple is always built and installed. This fusion's shape is
// updated to match, and an old root tuple is then removed.
1615 HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
1616 HloInstruction::CreateTuple(tuple_elements));
1617 fused_instructions_computation()->set_root_instruction(new_root);
1618 *mutable_shape() = new_root->shape();
1619 if (fused_root->opcode() == HloOpcode::kTuple) {
1620 TF_CHECK_OK(
1621 fused_instructions_computation()->RemoveInstruction(fused_root));
1622 }
1623
1624 // If this is a newly created multioutput instruction, we need to update
1625 // the use of the original fusion instruction.
// Existing users now read the old single output through element 0.
1626 if (newly_created_tuple_instr) {
1627 HloInstruction* new_instr = parent()->AddInstruction(
1628 HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
1629 TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1630 }
// 'index' starts as the number of root tuple elements. The element(s)
// contributed by the clone occupy the trailing slot(s).
1631 int64 index = tuple_elements.size();
1632 if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1633 CHECK_EQ(clone, instruction_to_fuse);
1634 index -= clone->operand_count();
// Every user of the tuple must be a GTE (CHECKed). Each is replaced by a GTE
// on this fusion at the shifted index. The old GTEs are removed only after
// the loop, so clone->users() is not modified while it is iterated.
1635 std::vector<HloInstruction*> to_be_removed;
1636 for (auto old_gte : clone->users()) {
1637 CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1638 int64 old_tuple_index = old_gte->tuple_index();
1639 HloInstruction* new_gte =
1640 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1641 old_gte->shape(), this, index + old_tuple_index));
1642 TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1643 to_be_removed.push_back(old_gte);
1644 }
1645 for (auto old_gte : to_be_removed) {
1646 TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1647 }
1648 } else {
// A non-tuple clone is the last root tuple element.
1649 HloInstruction* new_gte =
1650 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1651 clone->shape(), this, index - 1));
1652 TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
1653 }
1654 }
1655
1656 if (clone != instruction_to_fuse) {
1657 VLOG(2) << "New clone:\n" << clone->ToString();
1658 }
1659 return clone;
1660 }
1661
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1662 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1663 const HloPrintOptions& options) const {
1664 return {StrCat("kind=", xla::ToString(fusion_kind()))};
1665 }
1666
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1667 bool HloFusionInstruction::IdenticalSlowPath(
1668 const HloInstruction& other,
1669 const std::function<bool(const HloComputation*, const HloComputation*)>&
1670 eq_computations) const {
1671 return fusion_kind() == other.fusion_kind() &&
1672 eq_computations(fused_instructions_computation(),
1673 other.fused_instructions_computation());
1674 }
1675
InnerHash() const1676 uint64 HloFusionInstruction::InnerHash() const {
1677 return fused_instructions_computation()->root_instruction()->Hash();
1678 }
1679
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1680 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1681 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1682 HloCloneContext* context) const {
1683 HloModule* module = context != nullptr ? context->module() : GetModule();
1684 HloComputation* new_fused_computation = nullptr;
1685 if (context != nullptr) {
1686 new_fused_computation =
1687 context->FindComputation(fused_instructions_computation());
1688 }
1689 if (new_fused_computation == nullptr) {
1690 new_fused_computation = module->AddEmbeddedComputation(
1691 fused_instructions_computation()->Clone("clone", context));
1692 }
1693 return absl::make_unique<HloFusionInstruction>(
1694 shape, fusion_kind(), new_operands, new_fused_computation);
1695 }
1696
DeduplicateFusionOperands()1697 Status HloFusionInstruction::DeduplicateFusionOperands() {
1698 if (IsCustomFusion()) {
1699 return Status::OK();
1700 }
1701 absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1702 std::vector<int> operands_to_remove;
1703 for (int i = 0; i < operand_count(); ++i) {
1704 auto emplace_result = operand_indices.emplace(operand(i), i);
1705 if (!emplace_result.second) {
1706 TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1707 fused_parameter(emplace_result.first->second)));
1708 operands_to_remove.push_back(i);
1709 }
1710 }
1711 if (operands_to_remove.empty()) {
1712 return Status::OK();
1713 }
1714 TF_RETURN_IF_ERROR(fused_instructions_computation()
1715 ->RemoveUnusedParametersFromFusedComputation());
1716 RemoveOperandsAtAscendingIndices(operands_to_remove);
1717 return Status::OK();
1718 }
1719
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1720 HloRngInstruction::HloRngInstruction(
1721 const Shape& shape, RandomDistribution distribution,
1722 absl::Span<HloInstruction* const> parameters)
1723 : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
1724 for (HloInstruction* param : parameters) {
1725 AppendOperand(param);
1726 }
1727 }
1728
ToProto() const1729 HloInstructionProto HloRngInstruction::ToProto() const {
1730 HloInstructionProto proto = HloInstruction::ToProto();
1731 proto.set_distribution(distribution_);
1732 return proto;
1733 }
1734
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1735 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
1736 const HloPrintOptions& options) const {
1737 return {StrCat("distribution=", RandomDistributionToString(distribution_))};
1738 }
1739
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1740 bool HloRngInstruction::IsElementwiseImpl(
1741 const absl::optional<int64>& operand_idx) const {
1742 return true;
1743 }
1744
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1745 bool HloRngInstruction::IdenticalSlowPath(
1746 const HloInstruction& other,
1747 const std::function<bool(const HloComputation*, const HloComputation*)>&
1748 eq_computations) const {
1749 return true;
1750 }
1751
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1752 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
1753 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1754 HloCloneContext* context) const {
1755 return absl::make_unique<HloRngInstruction>(shape, distribution_,
1756 new_operands);
1757 }
1758
HloParameterInstruction(int64 parameter_number,const Shape & shape,const string & name)1759 HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
1760 const Shape& shape,
1761 const string& name)
1762 : HloInstruction(HloOpcode::kParameter, shape),
1763 parameter_number_(parameter_number) {
1764 SetAndSanitizeName(name);
1765 }
1766
ToProto() const1767 HloInstructionProto HloParameterInstruction::ToProto() const {
1768 HloInstructionProto proto = HloInstruction::ToProto();
1769 proto.set_parameter_number(parameter_number_);
1770 if (parameter_replicated_at_leaf_buffers_) {
1771 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1772 proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
1773 replicated);
1774 }
1775 }
1776 return proto;
1777 }
1778
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1779 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
1780 const HloPrintOptions& options) const {
1781 std::vector<string> result;
1782 if (!parameter_replicated_at_leaf_buffers_) {
1783 return result;
1784 }
1785 std::vector<string> buffers_replicated_strs;
1786 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1787 buffers_replicated_strs.push_back(replicated ? "true" : "false");
1788 }
1789 if (options.print_ids()) {
1790 result.push_back(StrCat("parameter_replication={",
1791 StrJoin(buffers_replicated_strs, ","), "}"));
1792 }
1793 return result;
1794 }
1795
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1796 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
1797 const HloPrintOptions& options,
1798 CanonicalNameMap* canonical_name_map) const {
1799 return StrCat(parameter_number_);
1800 }
1801
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1802 bool HloParameterInstruction::IdenticalSlowPath(
1803 const HloInstruction& other,
1804 const std::function<bool(const HloComputation*, const HloComputation*)>&
1805 eq_computations) const {
1806 const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
1807 return parameter_number() == casted_other.parameter_number();
1808 }
1809
1810 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1811 HloParameterInstruction::CloneWithNewOperandsImpl(
1812 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1813 HloCloneContext* context) const {
1814 return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
1815 name());
1816 }
1817
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64 index)1818 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
1819 const Shape& shape, HloInstruction* operand, int64 index)
1820 : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
1821 AppendOperand(operand);
1822 }
1823
ToProto() const1824 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
1825 HloInstructionProto proto = HloInstruction::ToProto();
1826 proto.set_tuple_index(tuple_index_);
1827 return proto;
1828 }
1829
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1830 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
1831 const HloPrintOptions& options) const {
1832 return {StrCat("index=", tuple_index())};
1833 }
1834
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1835 bool HloGetTupleElementInstruction::IdenticalSlowPath(
1836 const HloInstruction& other,
1837 const std::function<bool(const HloComputation*, const HloComputation*)>&
1838 eq_computations) const {
1839 const auto& casted_other =
1840 static_cast<const HloGetTupleElementInstruction&>(other);
1841 return tuple_index() == casted_other.tuple_index();
1842 }
1843
1844 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1845 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
1846 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1847 HloCloneContext* context) const {
1848 CHECK_EQ(new_operands.size(), 1);
1849 return absl::make_unique<HloGetTupleElementInstruction>(
1850 shape, new_operands[0], tuple_index());
1851 }
1852
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1853 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
1854 const Shape& shape, HloInstruction* operand, const int exponent_bits,
1855 const int mantissa_bits)
1856 : HloInstruction(HloOpcode::kReducePrecision, shape),
1857 exponent_bits_(exponent_bits),
1858 mantissa_bits_(mantissa_bits) {
1859 AppendOperand(operand);
1860 }
1861
ToProto() const1862 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
1863 HloInstructionProto proto = HloInstruction::ToProto();
1864 proto.set_exponent_bits(exponent_bits_);
1865 proto.set_mantissa_bits(mantissa_bits_);
1866 return proto;
1867 }
1868
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1869 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
1870 const HloPrintOptions& options) const {
1871 return {StrCat("exponent_bits=", exponent_bits_),
1872 StrCat("mantissa_bits=", mantissa_bits_)};
1873 }
1874
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1875 bool HloReducePrecisionInstruction::IdenticalSlowPath(
1876 const HloInstruction& other,
1877 const std::function<bool(const HloComputation*, const HloComputation*)>&
1878 eq_computations) const {
1879 const auto& casted_other =
1880 static_cast<const HloReducePrecisionInstruction&>(other);
1881 // A reduce-precision operation is determined by the bit sizes.
1882 return exponent_bits() == casted_other.exponent_bits() &&
1883 mantissa_bits() == casted_other.mantissa_bits();
1884 }
1885
1886 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1887 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
1888 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1889 HloCloneContext* context) const {
1890 CHECK_EQ(new_operands.size(), 1);
1891 return absl::make_unique<HloReducePrecisionInstruction>(
1892 shape, new_operands[0], exponent_bits(), mantissa_bits());
1893 }
1894
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1895 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
1896 HloInstruction* token_operand,
1897 const string& config)
1898 : HloInstruction(HloOpcode::kInfeed,
1899 ShapeUtil::MakeTupleShape(
1900 {infeed_shape, ShapeUtil::MakeTokenShape()})),
1901 infeed_config_(config) {
1902 AppendOperand(token_operand);
1903 }
1904
ToProto() const1905 HloInstructionProto HloInfeedInstruction::ToProto() const {
1906 HloInstructionProto proto = HloInstruction::ToProto();
1907 proto.set_infeed_config(infeed_config_);
1908 return proto;
1909 }
1910
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1911 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
1912 const HloPrintOptions& options) const {
1913 if (infeed_config_.empty()) {
1914 return {};
1915 }
1916 return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
1917 }
1918
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1919 bool HloInfeedInstruction::IdenticalSlowPath(
1920 const HloInstruction& other,
1921 const std::function<bool(const HloComputation*, const HloComputation*)>&
1922 eq_computations) const {
1923 // Not yet supported.
1924 return false;
1925 }
1926
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1927 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
1928 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1929 HloCloneContext* context) const {
1930 CHECK_EQ(new_operands.size(), 1);
1931 return absl::make_unique<HloInfeedInstruction>(
1932 infeed_shape(), new_operands[0], infeed_config());
1933 }
1934
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1935 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
1936 HloInstruction* operand,
1937 HloInstruction* token_operand,
1938 absl::string_view outfeed_config)
1939 : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
1940 outfeed_shape_(outfeed_shape),
1941 outfeed_config_(outfeed_config) {
1942 AppendOperand(operand);
1943 AppendOperand(token_operand);
1944 }
1945
ToProto() const1946 HloInstructionProto HloOutfeedInstruction::ToProto() const {
1947 HloInstructionProto proto = HloInstruction::ToProto();
1948 proto.set_outfeed_config(outfeed_config());
1949 *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
1950 return proto;
1951 }
1952
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1953 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
1954 const HloPrintOptions& options) const {
1955 if (outfeed_config_.empty()) {
1956 return {};
1957 }
1958 return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")};
1959 }
1960
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1961 bool HloOutfeedInstruction::IdenticalSlowPath(
1962 const HloInstruction& other,
1963 const std::function<bool(const HloComputation*, const HloComputation*)>&
1964 eq_computations) const {
1965 // Not yet supported.
1966 return false;
1967 }
1968
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1969 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
1970 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1971 HloCloneContext* context) const {
1972 CHECK_EQ(new_operands.size(), 2);
1973 return absl::make_unique<HloOutfeedInstruction>(
1974 outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
1975 }
1976
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1977 HloConvolutionInstruction::HloConvolutionInstruction(
1978 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1979 int64 feature_group_count, int64 batch_group_count, const Window& window,
1980 const ConvolutionDimensionNumbers& dimension_numbers,
1981 const PrecisionConfig& precision_config)
1982 : HloInstruction(HloOpcode::kConvolution, shape),
1983 feature_group_count_(feature_group_count),
1984 batch_group_count_(batch_group_count),
1985 window_(window),
1986 convolution_dimension_numbers_(dimension_numbers),
1987 precision_config_(precision_config) {
1988 if (window_util::HasBaseDilation(window)) {
1989 SetAndSanitizeName(StrCat(name(), "-base-dilated"));
1990 }
1991 if (window_util::HasWindowDilation(window)) {
1992 SetAndSanitizeName(StrCat(name(), "-window-dilated"));
1993 }
1994 AppendOperand(lhs);
1995 AppendOperand(rhs);
1996 }
1997
ToCategory() const1998 string HloConvolutionInstruction::ToCategory() const {
1999 string category = "convolution";
2000 if (window_util::HasBaseDilation(window())) {
2001 category += " base-dilated";
2002 }
2003 if (window_util::HasWindowDilation(window())) {
2004 category += " window-dilated";
2005 }
2006 return category;
2007 }
2008
ToProto() const2009 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2010 HloInstructionProto proto = HloInstruction::ToProto();
2011 *proto.mutable_window() = window_;
2012 *proto.mutable_convolution_dimension_numbers() =
2013 convolution_dimension_numbers_;
2014 proto.set_feature_group_count(feature_group_count_);
2015 proto.set_batch_group_count(batch_group_count_);
2016 *proto.mutable_precision_config() = precision_config_;
2017 return proto;
2018 }
2019
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2020 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2021 const HloPrintOptions& options) const {
2022 std::vector<string> extra;
2023 if (window_.dimensions_size() != 0) {
2024 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2025 }
2026 extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2027 convolution_dimension_numbers_)));
2028 if (feature_group_count_ != 1) {
2029 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2030 }
2031
2032 if (batch_group_count_ != 1) {
2033 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2034 }
2035
2036 string precision_config_string = PrecisionConfigToString(precision_config_);
2037 if (!precision_config_string.empty()) {
2038 extra.push_back(precision_config_string);
2039 }
2040
2041 return extra;
2042 }
2043
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2044 bool HloConvolutionInstruction::IdenticalSlowPath(
2045 const HloInstruction& other,
2046 const std::function<bool(const HloComputation*, const HloComputation*)>&
2047 eq_computations) const {
2048 const auto& casted_other =
2049 static_cast<const HloConvolutionInstruction&>(other);
2050 if (feature_group_count_ != other.feature_group_count()) {
2051 return false;
2052 }
2053 if (batch_group_count_ != other.batch_group_count()) {
2054 return false;
2055 }
2056 return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2057 protobuf_util::ProtobufEquals(
2058 convolution_dimension_numbers(),
2059 casted_other.convolution_dimension_numbers()) &&
2060 protobuf_util::ProtobufEquals(precision_config(),
2061 casted_other.precision_config());
2062 }
2063
2064 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2065 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2066 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2067 HloCloneContext* context) const {
2068 CHECK_EQ(new_operands.size(), 2);
2069 return absl::make_unique<HloConvolutionInstruction>(
2070 shape, new_operands[0], new_operands[1], feature_group_count_,
2071 batch_group_count_, window(), convolution_dimension_numbers_,
2072 precision_config_);
2073 }
2074
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2075 HloReduceWindowInstruction::HloReduceWindowInstruction(
2076 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2077 const Window& window, HloComputation* reduce_computation)
2078 : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2079 AppendOperand(operand);
2080 AppendOperand(init_value);
2081 AppendComputation(reduce_computation);
2082 }
2083
ToProto() const2084 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2085 HloInstructionProto proto = HloInstruction::ToProto();
2086 *proto.mutable_window() = window_;
2087 return proto;
2088 }
2089
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2090 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2091 const HloPrintOptions& options) const {
2092 std::vector<string> extra;
2093 if (window_.dimensions_size() != 0) {
2094 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2095 }
2096 return extra;
2097 }
2098
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2099 bool HloReduceWindowInstruction::IdenticalSlowPath(
2100 const HloInstruction& other,
2101 const std::function<bool(const HloComputation*, const HloComputation*)>&
2102 eq_computations) const {
2103 const auto& casted_other =
2104 static_cast<const HloReduceWindowInstruction&>(other);
2105 return eq_computations(to_apply(), casted_other.to_apply()) &&
2106 protobuf_util::ProtobufEquals(window(), casted_other.window());
2107 }
2108
2109 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2110 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2111 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2112 HloCloneContext* context) const {
2113 CHECK_EQ(new_operands.size(), 2);
2114 return absl::make_unique<HloReduceWindowInstruction>(
2115 shape, new_operands[0], new_operands[1], window(), to_apply());
2116 }
2117
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2118 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2119 const Shape& shape, HloInstruction* operand, HloComputation* select,
2120 const Window& window, HloInstruction* source, HloInstruction* init_value,
2121 HloComputation* scatter)
2122 : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2123 AppendOperand(operand);
2124 AppendOperand(source);
2125 AppendOperand(init_value);
2126 // Select comes before scatter in the vector.
2127 AppendComputation(select);
2128 AppendComputation(scatter);
2129 }
2130
ToProto() const2131 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2132 HloInstructionProto proto = HloInstruction::ToProto();
2133 *proto.mutable_window() = window_;
2134 return proto;
2135 }
2136
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2137 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2138 const HloPrintOptions& options) const {
2139 std::vector<string> extra;
2140 if (window_.dimensions_size() != 0) {
2141 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2142 }
2143 return extra;
2144 }
2145
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2146 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2147 const HloInstruction& other,
2148 const std::function<bool(const HloComputation*, const HloComputation*)>&
2149 eq_computations) const {
2150 const auto& casted_other =
2151 static_cast<const HloSelectAndScatterInstruction&>(other);
2152 return eq_computations(select(), casted_other.select()) &&
2153 eq_computations(scatter(), casted_other.scatter()) &&
2154 protobuf_util::ProtobufEquals(window(), casted_other.window());
2155 }
2156
2157 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2158 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2159 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2160 HloCloneContext* context) const {
2161 CHECK_EQ(new_operands.size(), 3);
2162 return absl::make_unique<HloSelectAndScatterInstruction>(
2163 shape, new_operands[0], select(), window(), new_operands[1],
2164 new_operands[2], scatter());
2165 }
2166
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque)2167 HloCustomCallInstruction::HloCustomCallInstruction(
2168 const Shape& shape, absl::Span<HloInstruction* const> operands,
2169 absl::string_view custom_call_target, string opaque)
2170 : HloInstruction(HloOpcode::kCustomCall, shape),
2171 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2172 feature_group_count_(1),
2173 batch_group_count_(1),
2174 layout_constrained_(false),
2175 custom_call_has_side_effect_(false) {
2176 set_raw_backend_config_string(std::move(opaque));
2177 for (auto operand : operands) {
2178 AppendOperand(operand);
2179 }
2180 }
2181
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,absl::Span<const Shape> operand_shapes_with_layout)2182 HloCustomCallInstruction::HloCustomCallInstruction(
2183 const Shape& shape, absl::Span<HloInstruction* const> operands,
2184 absl::string_view custom_call_target, string opaque,
2185 absl::Span<const Shape> operand_shapes_with_layout)
2186 : HloInstruction(HloOpcode::kCustomCall, shape),
2187 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2188 feature_group_count_(1),
2189 batch_group_count_(1),
2190 layout_constrained_(true),
2191 operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2192 operand_shapes_with_layout.end()),
2193 custom_call_has_side_effect_(false) {
2194 set_raw_backend_config_string(std::move(opaque));
2195 for (auto operand : operands) {
2196 AppendOperand(operand);
2197 }
2198 }
2199
ToProto() const2200 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2201 HloInstructionProto proto = HloInstruction::ToProto();
2202 if (window_ != nullptr) {
2203 *proto.mutable_window() = *window_;
2204 }
2205 if (convolution_dimension_numbers_ != nullptr) {
2206 *proto.mutable_convolution_dimension_numbers() =
2207 *convolution_dimension_numbers_;
2208 }
2209 proto.set_custom_call_target(custom_call_target_);
2210 proto.set_feature_group_count(feature_group_count_);
2211 proto.set_batch_group_count(batch_group_count_);
2212 if (layout_constrained()) {
2213 proto.set_constrain_layout(true);
2214 for (const Shape& shape : operand_shapes_with_layout_) {
2215 *proto.add_operand_shapes_with_layout() = shape.ToProto();
2216 }
2217 }
2218 proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2219 return proto;
2220 }
2221
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2222 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2223 const HloPrintOptions& options) const {
2224 std::vector<string> extra;
2225 if (window_ != nullptr) {
2226 extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2227 }
2228 if (convolution_dimension_numbers_ != nullptr) {
2229 extra.push_back(StrCat(
2230 "dim_labels=",
2231 ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2232 }
2233 if (feature_group_count_ != 1) {
2234 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2235 }
2236 if (batch_group_count_ != 1) {
2237 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2238 }
2239 // By contract, we print the custom call target even if
2240 // options.print_subcomputation_mode() == kOff, because the call target is not
2241 // an HloComputation.
2242 extra.push_back(
2243 StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2244
2245 if (layout_constrained()) {
2246 std::vector<string> shape_strings;
2247 for (const Shape& shape : operand_shapes_with_layout_) {
2248 shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2249 }
2250 extra.push_back(StrCat("operand_layout_constraints={",
2251 StrJoin(shape_strings, ", "), "}"));
2252 }
2253 if (custom_call_has_side_effect_) {
2254 extra.push_back("custom_call_has_side_effect=true");
2255 }
2256 return extra;
2257 }
2258
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2259 bool HloCustomCallInstruction::IdenticalSlowPath(
2260 const HloInstruction& other,
2261 const std::function<bool(const HloComputation*, const HloComputation*)>&
2262 eq_computations) const {
2263 const auto& casted_other =
2264 static_cast<const HloCustomCallInstruction&>(other);
2265 if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2266 (window_ != nullptr &&
2267 !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2268 return false;
2269 }
2270 if ((convolution_dimension_numbers_ == nullptr) !=
2271 (casted_other.convolution_dimension_numbers_ == nullptr) ||
2272 (convolution_dimension_numbers_ != nullptr &&
2273 !protobuf_util::ProtobufEquals(
2274 convolution_dimension_numbers(),
2275 casted_other.convolution_dimension_numbers()))) {
2276 return false;
2277 }
2278 if (feature_group_count_ != casted_other.feature_group_count_) {
2279 return false;
2280 }
2281 if (batch_group_count_ != casted_other.batch_group_count_) {
2282 return false;
2283 }
2284 if (layout_constrained() != casted_other.layout_constrained()) {
2285 return false;
2286 }
2287 if (layout_constrained()) {
2288 for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2289 if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2290 casted_other.operand_shapes_with_layout_[i])) {
2291 return false;
2292 }
2293 }
2294 }
2295 if (custom_call_has_side_effect_ !=
2296 casted_other.custom_call_has_side_effect()) {
2297 return false;
2298 }
2299 // Note: backend_config comparison is done in Identical, which is the
2300 // intended/exposed way to compare computations, and so not repeated here.
2301 return custom_call_target_ == casted_other.custom_call_target_;
2302 }
2303
2304 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2305 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2306 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2307 HloCloneContext* context) const {
2308 auto cloned = absl::make_unique<HloCustomCallInstruction>(
2309 shape, new_operands, custom_call_target(), opaque());
2310 if (layout_constrained()) {
2311 cloned->layout_constrained_ = true;
2312 cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2313 }
2314 if (window_ != nullptr) {
2315 cloned->set_window(*window_);
2316 }
2317 if (convolution_dimension_numbers_ != nullptr) {
2318 cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2319 }
2320 cloned->set_feature_group_count(feature_group_count_);
2321 cloned->set_batch_group_count(batch_group_count_);
2322 cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2323 return std::move(cloned);
2324 }
2325
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2326 HloPadInstruction::HloPadInstruction(const Shape& shape,
2327 HloInstruction* operand,
2328 HloInstruction* padding_value,
2329 const PaddingConfig& padding_config)
2330 : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2331 AppendOperand(operand);
2332 AppendOperand(padding_value);
2333 }
2334
ToProto() const2335 HloInstructionProto HloPadInstruction::ToProto() const {
2336 HloInstructionProto proto = HloInstruction::ToProto();
2337 *proto.mutable_padding_config() = padding_config_;
2338 return proto;
2339 }
2340
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2341 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2342 const HloPrintOptions& options) const {
2343 return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2344 }
2345
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2346 bool HloPadInstruction::IdenticalSlowPath(
2347 const HloInstruction& other,
2348 const std::function<bool(const HloComputation*, const HloComputation*)>&
2349 eq_computations) const {
2350 const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2351 return protobuf_util::ProtobufEquals(padding_config(),
2352 casted_other.padding_config());
2353 }
2354
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2355 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2356 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2357 HloCloneContext* context) const {
2358 CHECK_EQ(new_operands.size(), 2);
2359 return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2360 new_operands[1], padding_config_);
2361 }
2362
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2363 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2364 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2365 absl::Span<const int64> slice_sizes)
2366 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2367 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2368 AppendOperand(operand);
2369 AppendOperand(start_indices);
2370 }
2371
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2372 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2373 const Shape& shape, HloInstruction* operand,
2374 absl::Span<HloInstruction* const> start_indices,
2375 absl::Span<const int64> slice_sizes)
2376 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2377 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2378 AppendOperand(operand);
2379 for (HloInstruction* index : start_indices) {
2380 AppendOperand(index);
2381 }
2382 }
2383
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2384 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2385 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2386 HloInstruction* start_indices)
2387 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2388 AppendOperand(operand);
2389 AppendOperand(update);
2390 AppendOperand(start_indices);
2391 }
2392
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2393 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2394 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2395 absl::Span<HloInstruction* const> start_indices)
2396 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2397 AppendOperand(operand);
2398 AppendOperand(update);
2399 for (HloInstruction* index : start_indices) {
2400 AppendOperand(index);
2401 }
2402 }
2403
ToProto() const2404 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2405 HloInstructionProto proto = HloInstruction::ToProto();
2406 for (int64 slice_size : dynamic_slice_sizes_) {
2407 proto.add_dynamic_slice_sizes(slice_size);
2408 }
2409 return proto;
2410 }
2411
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2412 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2413 const HloPrintOptions& options) const {
2414 return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2415 "}")};
2416 }
2417
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2418 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2419 const HloInstruction& other,
2420 const std::function<bool(const HloComputation*, const HloComputation*)>&
2421 eq_computations) const {
2422 return true;
2423 }
2424
2425 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2426 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2427 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2428 HloCloneContext* context) const {
2429 if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2430 // TODO(b/118437727): Old form, remove this path.
2431 return absl::make_unique<HloDynamicSliceInstruction>(
2432 shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2433 } else {
2434 return absl::make_unique<HloDynamicSliceInstruction>(
2435 shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2436 }
2437 }
2438
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)2439 HloGatherInstruction::HloGatherInstruction(
2440 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2441 const GatherDimensionNumbers& gather_dim_numbers,
2442 absl::Span<const int64> slice_sizes, bool indices_are_sorted)
2443 : HloInstruction(HloOpcode::kGather, shape),
2444 indices_are_sorted_(indices_are_sorted) {
2445 AppendOperand(operand);
2446 AppendOperand(start_indices);
2447 gather_dimension_numbers_ =
2448 absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2449 absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2450 }
2451
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)2452 /*static*/ string HloGatherInstruction::GatherDimensionNumbersToString(
2453 const GatherDimensionNumbers& gather_dimension_numbers) {
2454 string offset_dims =
2455 StrCat("offset_dims={",
2456 StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
2457 string collapsed_slice_dims = StrCat(
2458 "collapsed_slice_dims={",
2459 StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
2460 string start_index_map =
2461 StrCat("start_index_map={",
2462 StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
2463 string index_vector_dim =
2464 StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
2465
2466 return StrJoin<std::initializer_list<string>>(
2467 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2468 ", ");
2469 }
2470
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64 index_vector_dim)2471 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2472 absl::Span<const int64> offset_dims,
2473 absl::Span<const int64> collapsed_slice_dims,
2474 absl::Span<const int64> start_index_map, int64 index_vector_dim) {
2475 GatherDimensionNumbers gather_dim_numbers;
2476 for (int64 output_window_dim : offset_dims) {
2477 gather_dim_numbers.add_offset_dims(output_window_dim);
2478 }
2479 for (int64 elided_window_dim : collapsed_slice_dims) {
2480 gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2481 }
2482 for (int64 gather_dim_to_input_dim : start_index_map) {
2483 gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2484 }
2485
2486 gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2487 return gather_dim_numbers;
2488 }
2489
ToProto() const2490 HloInstructionProto HloGatherInstruction::ToProto() const {
2491 HloInstructionProto proto = HloInstruction::ToProto();
2492 *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2493 for (int64 bound : gather_slice_sizes()) {
2494 proto.add_gather_slice_sizes(bound);
2495 }
2496 proto.set_indices_are_sorted(indices_are_sorted());
2497 return proto;
2498 }
2499
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2500 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2501 const HloPrintOptions& options) const {
2502 std::vector<string> attrs{
2503 GatherDimensionNumbersToString(gather_dimension_numbers()),
2504 StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2505 if (indices_are_sorted()) {
2506 attrs.push_back("indices_are_sorted=true");
2507 }
2508 return attrs;
2509 }
2510
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2511 bool HloGatherInstruction::IdenticalSlowPath(
2512 const HloInstruction& other,
2513 const std::function<bool(const HloComputation*, const HloComputation*)>&
2514 eq_computations) const {
2515 const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2516 return protobuf_util::ProtobufEquals(
2517 gather_dimension_numbers(),
2518 casted_other.gather_dimension_numbers()) &&
2519 gather_slice_sizes() == casted_other.gather_slice_sizes() &&
2520 indices_are_sorted() == casted_other.indices_are_sorted();
2521 }
2522
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2523 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2524 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2525 HloCloneContext* context) const {
2526 CHECK_EQ(new_operands.size(), 2);
2527 return absl::make_unique<HloGatherInstruction>(
2528 shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2529 gather_slice_sizes(), indices_are_sorted());
2530 }
2531
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)2532 HloScatterInstruction::HloScatterInstruction(
2533 const Shape& shape, HloInstruction* operand,
2534 HloInstruction* scatter_indices, HloInstruction* updates,
2535 HloComputation* update_computation,
2536 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
2537 bool unique_indices)
2538 : HloInstruction(HloOpcode::kScatter, shape),
2539 indices_are_sorted_(indices_are_sorted),
2540 unique_indices_(unique_indices) {
2541 AppendOperand(operand);
2542 AppendOperand(scatter_indices);
2543 AppendOperand(updates);
2544 AppendComputation(update_computation);
2545 scatter_dimension_numbers_ =
2546 absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2547 }
2548
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)2549 /*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString(
2550 const ScatterDimensionNumbers& scatter_dimension_numbers) {
2551 string update_window_dims =
2552 StrCat("update_window_dims={",
2553 StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
2554 string inserted_window_dims = StrCat(
2555 "inserted_window_dims={",
2556 StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
2557 string scatter_dims_to_operand_dims = StrCat(
2558 "scatter_dims_to_operand_dims={",
2559 StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
2560 "}");
2561 string index_vector_dim =
2562 StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
2563
2564 return StrJoin<std::initializer_list<string>>(
2565 {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
2566 index_vector_dim},
2567 ", ");
2568 }
2569
2570 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64 index_vector_dim)2571 HloScatterInstruction::MakeScatterDimNumbers(
2572 absl::Span<const int64> update_window_dims,
2573 absl::Span<const int64> inserted_window_dims,
2574 absl::Span<const int64> scatter_dims_to_operand_dims,
2575 int64 index_vector_dim) {
2576 ScatterDimensionNumbers scatter_dim_numbers;
2577 for (int64 update_window_dim : update_window_dims) {
2578 scatter_dim_numbers.add_update_window_dims(update_window_dim);
2579 }
2580 for (int64 inserted_window_dim : inserted_window_dims) {
2581 scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
2582 }
2583 for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
2584 scatter_dim_numbers.add_scatter_dims_to_operand_dims(
2585 scatter_dim_to_operand_dim);
2586 }
2587 scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
2588 return scatter_dim_numbers;
2589 }
2590
ToProto() const2591 HloInstructionProto HloScatterInstruction::ToProto() const {
2592 HloInstructionProto proto = HloInstruction::ToProto();
2593 *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
2594 proto.set_indices_are_sorted(indices_are_sorted());
2595 proto.set_unique_indices(unique_indices());
2596 return proto;
2597 }
2598
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2599 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
2600 const HloPrintOptions& options) const {
2601 std::vector<string> attrs{
2602 ScatterDimensionNumbersToString(scatter_dimension_numbers())};
2603 if (indices_are_sorted()) {
2604 attrs.push_back("indices_are_sorted=true");
2605 }
2606 if (unique_indices()) {
2607 attrs.push_back("unique_indices=true");
2608 }
2609 return attrs;
2610 }
2611
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2612 bool HloScatterInstruction::IdenticalSlowPath(
2613 const HloInstruction& other,
2614 const std::function<bool(const HloComputation*, const HloComputation*)>&
2615 eq_computations) const {
2616 const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
2617 return protobuf_util::ProtobufEquals(
2618 scatter_dimension_numbers(),
2619 casted_other.scatter_dimension_numbers()) &&
2620 eq_computations(to_apply(), casted_other.to_apply()) &&
2621 indices_are_sorted() == casted_other.indices_are_sorted() &&
2622 unique_indices() == casted_other.unique_indices();
2623 }
2624
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2625 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
2626 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2627 HloCloneContext* context) const {
2628 CHECK_EQ(new_operands.size(), 3);
2629 return absl::make_unique<HloScatterInstruction>(
2630 shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
2631 scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
2632 }
2633
HloIotaInstruction(const Shape & shape,int64 iota_dimension)2634 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
2635 : HloInstruction(HloOpcode::kIota, shape),
2636 iota_dimension_(iota_dimension) {}
2637
ToProto() const2638 HloInstructionProto HloIotaInstruction::ToProto() const {
2639 HloInstructionProto proto = HloInstruction::ToProto();
2640 proto.add_dimensions(iota_dimension());
2641 return proto;
2642 }
2643
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2644 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
2645 const HloPrintOptions& options) const {
2646 return {StrCat("iota_dimension=", iota_dimension())};
2647 }
2648
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2649 bool HloIotaInstruction::IdenticalSlowPath(
2650 const HloInstruction& other,
2651 const std::function<bool(const HloComputation*, const HloComputation*)>&
2652 eq_computations) const {
2653 const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
2654 return iota_dimension() == casted_other.iota_dimension();
2655 }
2656
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2657 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
2658 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2659 HloCloneContext* context) const {
2660 return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
2661 }
2662
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2663 HloDotInstruction::HloDotInstruction(
2664 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2665 const DotDimensionNumbers& dimension_numbers,
2666 const PrecisionConfig& precision_config)
2667 : HloInstruction(HloOpcode::kDot, shape),
2668 dot_dimension_numbers_(dimension_numbers),
2669 precision_config_(precision_config) {
2670 AppendOperand(lhs);
2671 AppendOperand(rhs);
2672 }
2673
ToProto() const2674 HloInstructionProto HloDotInstruction::ToProto() const {
2675 HloInstructionProto proto = HloInstruction::ToProto();
2676 *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
2677 *proto.mutable_precision_config() = precision_config_;
2678 return proto;
2679 }
2680
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2681 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
2682 const HloPrintOptions& options) const {
2683 std::vector<string> extra = {DotDimensionNumbersToString()};
2684
2685 string precision_config_string = PrecisionConfigToString(precision_config_);
2686 if (!precision_config_string.empty()) {
2687 extra.push_back(precision_config_string);
2688 }
2689 return extra;
2690 }
2691
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2692 bool HloDotInstruction::IdenticalSlowPath(
2693 const HloInstruction& other,
2694 const std::function<bool(const HloComputation*, const HloComputation*)>&
2695 eq_computations) const {
2696 const auto& casted_other = static_cast<const HloDotInstruction&>(other);
2697 return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
2698 casted_other.dot_dimension_numbers()) &&
2699 protobuf_util::ProtobufEquals(precision_config(),
2700 casted_other.precision_config());
2701 }
2702
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2703 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
2704 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2705 HloCloneContext* context) const {
2706 CHECK_EQ(new_operands.size(), 2);
2707 return absl::make_unique<HloDotInstruction>(
2708 shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
2709 precision_config_);
2710 }
2711
DotDimensionNumbersToString() const2712 string HloDotInstruction::DotDimensionNumbersToString() const {
2713 std::vector<string> result;
2714 const DotDimensionNumbers& dnums = dot_dimension_numbers_;
2715 if (!dnums.lhs_batch_dimensions().empty()) {
2716 result.push_back(StrCat("lhs_batch_dims={",
2717 StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
2718 }
2719 result.push_back(StrCat("lhs_contracting_dims={",
2720 StrJoin(dnums.lhs_contracting_dimensions(), ","),
2721 "}"));
2722
2723 if (!dnums.rhs_batch_dimensions().empty()) {
2724 result.push_back(StrCat("rhs_batch_dims={",
2725 StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
2726 }
2727 result.push_back(StrCat("rhs_contracting_dims={",
2728 StrJoin(dnums.rhs_contracting_dimensions(), ","),
2729 "}"));
2730
2731 return StrJoin(result, ", ");
2732 }
2733
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)2734 HloDomainInstruction::HloDomainInstruction(
2735 const Shape& shape, HloInstruction* operand,
2736 std::unique_ptr<DomainMetadata> operand_side_metadata,
2737 std::unique_ptr<DomainMetadata> user_side_metadata)
2738 : HloInstruction(HloOpcode::kDomain, shape),
2739 operand_side_metadata_(std::move(operand_side_metadata)),
2740 user_side_metadata_(std::move(user_side_metadata)) {
2741 AppendOperand(operand);
2742 }
2743
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2744 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
2745 const HloPrintOptions& options) const {
2746 if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
2747 return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
2748 "\", entry=", user_side_metadata_->ToString(),
2749 ", exit=", operand_side_metadata_->ToString(), "}")};
2750 }
2751 return {};
2752 }
2753
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2754 bool HloDomainInstruction::IdenticalSlowPath(
2755 const HloInstruction& other,
2756 const std::function<bool(const HloComputation*, const HloComputation*)>&
2757 eq_computations) const {
2758 const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
2759 return operand_side_metadata().Matches(
2760 casted_other.operand_side_metadata()) &&
2761 user_side_metadata().Matches(casted_other.user_side_metadata());
2762 }
2763
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2764 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
2765 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2766 HloCloneContext* context) const {
2767 CHECK_EQ(new_operands.size(), 1);
2768 return absl::make_unique<HloDomainInstruction>(
2769 shape, new_operands[0], operand_side_metadata_->Clone(),
2770 user_side_metadata_->Clone());
2771 }
2772
ToProto() const2773 HloInstructionProto HloDomainInstruction::ToProto() const {
2774 HloInstructionProto proto = HloInstruction::ToProto();
2775 auto operand_side_sharding =
2776 dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
2777 if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
2778 *proto.mutable_domain_entry_sharding() =
2779 operand_side_sharding->sharding()->ToProto();
2780 }
2781
2782 auto user_side_sharding =
2783 dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
2784 if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
2785 *proto.mutable_domain_exit_sharding() =
2786 user_side_sharding->sharding()->ToProto();
2787 }
2788
2789 return proto;
2790 }
2791
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64 dimension)2792 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
2793 const Shape& shape, HloInstruction* operand, int64 dimension)
2794 : HloInstruction(HloOpcode::kGetDimensionSize, shape),
2795 dimension_(dimension) {
2796 AppendOperand(operand);
2797 }
2798
ToProto() const2799 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
2800 HloInstructionProto proto = HloInstruction::ToProto();
2801 proto.add_dimensions(dimension());
2802 return proto;
2803 }
2804
ExtraAttributesToStringImpl(const HloPrintOptions &) const2805 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
2806 const HloPrintOptions& /*options*/) const {
2807 return {StrCat("dimensions={", dimension(), "}")};
2808 }
2809
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2810 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
2811 const HloInstruction& other,
2812 const std::function<bool(const HloComputation*, const HloComputation*)>&
2813 /*eq_computations*/) const {
2814 const auto& casted_other =
2815 static_cast<const HloGetDimensionSizeInstruction&>(other);
2816 return dimension() == casted_other.dimension();
2817 }
2818
2819 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2820 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
2821 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2822 HloCloneContext* /*context*/) const {
2823 if (new_operands.size() != 1) {
2824 LOG(FATAL) << "expects 1 operand";
2825 }
2826 return absl::make_unique<HloGetDimensionSizeInstruction>(
2827 shape, new_operands[0], dimension());
2828 }
2829
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64 dimension)2830 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
2831 const Shape& shape, HloInstruction* operand, HloInstruction* val,
2832 int64 dimension)
2833 : HloInstruction(HloOpcode::kSetDimensionSize, shape),
2834 dimension_(dimension) {
2835 AppendOperand(operand);
2836 AppendOperand(val);
2837 }
2838
ExtraAttributesToStringImpl(const HloPrintOptions &) const2839 std::vector<string> HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
2840 const HloPrintOptions& /*options*/) const {
2841 return {StrCat("dimensions={", dimension(), "}")};
2842 }
2843
ToProto() const2844 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
2845 HloInstructionProto proto = HloInstruction::ToProto();
2846 proto.add_dimensions(dimension());
2847 return proto;
2848 }
2849
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2850 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
2851 const HloInstruction& other,
2852 const std::function<bool(const HloComputation*, const HloComputation*)>&
2853 /*eq_computations*/) const {
2854 const auto& casted_other =
2855 static_cast<const HloSetDimensionSizeInstruction&>(other);
2856 return dimension() == casted_other.dimension();
2857 }
2858
2859 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2860 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
2861 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2862 HloCloneContext* /*context*/) const {
2863 if (new_operands.size() != 2) {
2864 LOG(FATAL) << "expects 2 operand";
2865 }
2866 return absl::make_unique<HloSetDimensionSizeInstruction>(
2867 shape, new_operands[0], new_operands[1], dimension());
2868 }
2869
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64 delta)2870 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
2871 const Shape& shape, int64 delta)
2872 : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
2873
ToProto() const2874 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
2875 HloInstructionProto proto = HloInstruction::ToProto();
2876 proto.set_delta(delta_);
2877 return proto;
2878 }
2879
2880 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const2881 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
2882 const HloPrintOptions& /*options*/) const {
2883 return {StrCat("delta=", delta())};
2884 }
2885
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const2886 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
2887 const HloInstruction& other,
2888 const std::function<bool(const HloComputation*, const HloComputation*)>&
2889 /*eq_computations*/) const {
2890 const auto& casted_other =
2891 static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
2892 return delta() == casted_other.delta();
2893 }
2894
2895 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const2896 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
2897 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2898 HloCloneContext* /*context*/) const {
2899 if (!new_operands.empty()) {
2900 LOG(FATAL) << "expects 0 operand";
2901 }
2902 return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
2903 }
2904
2905 } // namespace xla
2906