1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17
18 #include <deque>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/primitive_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/protobuf.h"
37
38 namespace xla {
39 namespace {
40
41 using absl::CEscape;
42 using absl::StrAppend;
43 using absl::StrCat;
44 using absl::StrJoin;
45
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)46 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
47 const HloInstruction* operand) {
48 const auto operand_indices = instruction->OperandIndices(operand);
49 return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
50 return instruction->IsElementwiseOnOperand(operand_index);
51 });
52 }
53
PrecisionConfigToString(const PrecisionConfig & precision_config)54 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
55 if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
56 return static_cast<PrecisionConfig::Precision>(precision) ==
57 PrecisionConfig::DEFAULT;
58 })) {
59 return "";
60 }
61
62 return StrCat(
63 "operand_precision={",
64 StrJoin(
65 precision_config.operand_precision(), ",",
66 [](string* out, int32 precision) {
67 CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
68 StrAppend(out,
69 PrecisionToString(
70 static_cast<PrecisionConfig::Precision>(precision)));
71 }),
72 "}");
73 }
74 } // namespace
75
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64 feature_index)76 HloBatchNormInstruction::HloBatchNormInstruction(
77 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
78 HloInstruction* scale, float epsilon, int64 feature_index)
79 : HloInstruction(opcode, shape),
80 epsilon_(epsilon),
81 feature_index_(feature_index) {
82 AppendOperand(operand);
83 AppendOperand(scale);
84 }
85
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const86 bool HloBatchNormInstruction::IdenticalSlowPath(
87 const HloInstruction& other,
88 const std::function<bool(const HloComputation*, const HloComputation*)>&
89 eq_computations) const {
90 const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
91 return feature_index() == casted_other.feature_index() &&
92 epsilon() == casted_other.epsilon();
93 }
94
ToProto() const95 HloInstructionProto HloBatchNormInstruction::ToProto() const {
96 HloInstructionProto proto = HloInstruction::ToProto();
97 proto.set_epsilon(epsilon_);
98 proto.set_feature_index(feature_index_);
99 return proto;
100 }
101
ExtraAttributesToStringImpl(const HloPrintOptions & options) const102 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
103 const HloPrintOptions& options) const {
104 return {StrCat("epsilon=", epsilon()),
105 StrCat("feature_index=", feature_index())};
106 }
107
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)108 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
109 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
110 HloInstruction* offset, float epsilon, int64 feature_index)
111 : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
112 scale, epsilon, feature_index) {
113 AppendOperand(offset);
114 }
115
116 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const117 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
118 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
119 HloCloneContext* context) const {
120 CHECK_EQ(new_operands.size(), 3);
121 return absl::make_unique<HloBatchNormTrainingInstruction>(
122 shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
123 feature_index());
124 }
125
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)126 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
127 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
128 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
129 float epsilon, int64 feature_index)
130 : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
131 scale, epsilon, feature_index) {
132 AppendOperand(offset);
133 AppendOperand(mean);
134 AppendOperand(variance);
135 }
136
137 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const138 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
139 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
140 HloCloneContext* context) const {
141 CHECK_EQ(new_operands.size(), 5);
142 return absl::make_unique<HloBatchNormInferenceInstruction>(
143 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
144 new_operands[4], epsilon(), feature_index());
145 }
146
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)147 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
148 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
149 HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
150 float epsilon, int64 feature_index)
151 : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
152 epsilon, feature_index) {
153 AppendOperand(mean);
154 AppendOperand(variance);
155 AppendOperand(grad_output);
156 }
157
158 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const159 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
160 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
161 HloCloneContext* context) const {
162 CHECK_EQ(new_operands.size(), 5);
163 return absl::make_unique<HloBatchNormGradInstruction>(
164 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
165 new_operands[4], epsilon(), feature_index());
166 }
167
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)168 HloFftInstruction::HloFftInstruction(const Shape& shape,
169 HloInstruction* operand, FftType fft_type,
170 absl::Span<const int64> fft_length)
171 : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
172 fft_length_.assign(fft_length.begin(), fft_length.end());
173 AppendOperand(operand);
174 }
175
ToProto() const176 HloInstructionProto HloFftInstruction::ToProto() const {
177 HloInstructionProto proto = HloInstruction::ToProto();
178 proto.set_fft_type(fft_type_);
179 for (int64 fft_len : fft_length_) {
180 proto.add_fft_length(fft_len);
181 }
182 return proto;
183 }
184
ExtraAttributesToStringImpl(const HloPrintOptions & options) const185 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
186 const HloPrintOptions& options) const {
187 return {StrCat("fft_type=", FftType_Name(fft_type())),
188 StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
189 }
190
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const191 bool HloFftInstruction::IdenticalSlowPath(
192 const HloInstruction& other,
193 const std::function<bool(const HloComputation*, const HloComputation*)>&
194 eq_computations) const {
195 const auto& casted_other = static_cast<const HloFftInstruction&>(other);
196 return fft_type() == casted_other.fft_type() &&
197 fft_length() == casted_other.fft_length();
198 }
199
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const200 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
201 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
202 HloCloneContext* context) const {
203 CHECK_EQ(new_operands.size(), 1);
204 return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
205 fft_length_);
206 }
207
HloCopyStartInstruction(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)208 HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
209 HloInstruction* operand,
210 bool is_cross_program_prefetch)
211 : HloInstruction(HloOpcode::kCopyStart, shape),
212 is_cross_program_prefetch_(is_cross_program_prefetch) {
213 AppendOperand(operand);
214 }
215
ToProto() const216 HloInstructionProto HloCopyStartInstruction::ToProto() const {
217 HloInstructionProto proto = HloInstruction::ToProto();
218 proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
219 return proto;
220 }
221
ExtraAttributesToStringImpl(const HloPrintOptions & options) const222 std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
223 const HloPrintOptions& options) const {
224 std::vector<string> result;
225 if (is_cross_program_prefetch()) {
226 result.push_back("is_cross_program_prefetch=true");
227 }
228 return result;
229 }
230
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const231 bool HloCopyStartInstruction::IdenticalSlowPath(
232 const HloInstruction& other,
233 const std::function<bool(const HloComputation*, const HloComputation*)>&
234 eq_computations) const {
235 const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
236 return is_cross_program_prefetch() ==
237 casted_other.is_cross_program_prefetch();
238 }
239
240 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const241 HloCopyStartInstruction::CloneWithNewOperandsImpl(
242 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
243 HloCloneContext* context) const {
244 CHECK_EQ(new_operands.size(), 1);
245 return absl::make_unique<HloCopyStartInstruction>(
246 shape, new_operands[0], is_cross_program_prefetch());
247 }
248
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)249 HloCompareInstruction::HloCompareInstruction(
250 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
251 ComparisonDirection direction, absl::optional<Comparison::Type> type)
252 : HloInstruction(HloOpcode::kCompare, shape),
253 compare_(direction, type ? (*type)
254 : Comparison::DefaultComparisonType(
255 lhs->shape().element_type())) {
256 AppendOperand(lhs);
257 AppendOperand(rhs);
258 }
259
ToProto() const260 HloInstructionProto HloCompareInstruction::ToProto() const {
261 HloInstructionProto proto = HloInstruction::ToProto();
262 proto.set_comparison_direction(
263 ComparisonDirectionToString(compare_.GetDirection()));
264 proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
265 return proto;
266 }
267
ExtraAttributesToStringImpl(const HloPrintOptions & options) const268 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
269 const HloPrintOptions& options) const {
270 std::vector<string> result;
271 result.push_back(
272 StrCat("direction=", ComparisonDirectionToString(direction())));
273 if (compare_.GetType() !=
274 Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
275 result.push_back(
276 StrCat("type=", ComparisonTypeToString(compare_.GetType())));
277 }
278 return result;
279 }
280
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const281 bool HloCompareInstruction::IdenticalSlowPath(
282 const HloInstruction& other,
283 const std::function<bool(const HloComputation*, const HloComputation*)>&
284 eq_computations) const {
285 const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
286 return direction() == casted_other.direction();
287 }
288
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const289 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
290 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
291 HloCloneContext* context) const {
292 CHECK_EQ(new_operands.size(), 2);
293 return absl::make_unique<HloCompareInstruction>(
294 shape, new_operands[0], new_operands[1], direction(), type());
295 }
296
297 namespace {
298
299 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
300 // of "key=value" attribute strings generically, using protocol buffer
301 // reflection.
302 //
303 // Currently implements a small subset of cases; feel free to add more as
304 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)305 std::vector<string> AttributeProtoToStringVector(
306 const tensorflow::protobuf::Message& message) {
307 const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
308 std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
309 reflection->ListFields(message, &fields);
310
311 std::vector<string> output;
312 for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
313 string s = absl::StrCat(field->name(), "=");
314 CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
315 switch (field->type()) {
316 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
317 bool val = reflection->GetBool(message, field);
318 absl::StrAppend(&s, val ? "true" : "false");
319 break;
320 }
321 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
322 const tensorflow::protobuf::EnumValueDescriptor* evd =
323 reflection->GetEnum(message, field);
324 absl::StrAppend(&s, evd->name());
325 break;
326 }
327 default:
328 LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
329 }
330 output.push_back(std::move(s));
331 }
332 return output;
333 }
334
335 } // namespace
336
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)337 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
338 const Shape& shape, HloInstruction* a, HloInstruction* b,
339 const TriangularSolveOptions& options)
340 : HloInstruction(HloOpcode::kTriangularSolve, shape),
341 triangular_solve_options_(options) {
342 AppendOperand(a);
343 AppendOperand(b);
344 }
345
ToProto() const346 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
347 HloInstructionProto proto = HloInstruction::ToProto();
348 *proto.mutable_triangular_solve_options() = triangular_solve_options_;
349 return proto;
350 }
351
ExtraAttributesToStringImpl(const HloPrintOptions & options) const352 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
353 const HloPrintOptions& options) const {
354 return AttributeProtoToStringVector(triangular_solve_options_);
355 }
356
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const357 bool HloTriangularSolveInstruction::IdenticalSlowPath(
358 const HloInstruction& other,
359 const std::function<bool(const HloComputation*, const HloComputation*)>&
360 eq_computations) const {
361 const auto& casted_other =
362 static_cast<const HloTriangularSolveInstruction&>(other);
363 const auto& options = triangular_solve_options();
364 const auto& other_options = casted_other.triangular_solve_options();
365
366 return options.left_side() == other_options.left_side() &&
367 options.lower() == other_options.lower() &&
368 options.unit_diagonal() == other_options.unit_diagonal() &&
369 options.transpose_a() == other_options.transpose_a();
370 }
371
372 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const373 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
374 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
375 HloCloneContext* context) const {
376 CHECK_EQ(new_operands.size(), 2);
377 return absl::make_unique<HloTriangularSolveInstruction>(
378 shape, new_operands[0], new_operands[1], triangular_solve_options());
379 }
380
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)381 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
382 HloInstruction* a,
383 const CholeskyOptions& options)
384 : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
385 AppendOperand(a);
386 }
387
ToProto() const388 HloInstructionProto HloCholeskyInstruction::ToProto() const {
389 HloInstructionProto proto = HloInstruction::ToProto();
390 *proto.mutable_cholesky_options() = cholesky_options_;
391 return proto;
392 }
393
ExtraAttributesToStringImpl(const HloPrintOptions & options) const394 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
395 const HloPrintOptions& options) const {
396 return AttributeProtoToStringVector(cholesky_options_);
397 }
398
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const399 bool HloCholeskyInstruction::IdenticalSlowPath(
400 const HloInstruction& other,
401 const std::function<bool(const HloComputation*, const HloComputation*)>&
402 eq_computations) const {
403 const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
404 const auto& options = cholesky_options();
405 const auto& other_options = casted_other.cholesky_options();
406
407 return options.lower() == other_options.lower();
408 }
409
410 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const411 HloCholeskyInstruction::CloneWithNewOperandsImpl(
412 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
413 HloCloneContext* context) const {
414 CHECK_EQ(new_operands.size(), 1);
415 return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
416 cholesky_options());
417 }
418
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const absl::optional<int64> & channel_id)419 HloChannelInstruction::HloChannelInstruction(
420 HloOpcode opcode, const Shape& shape,
421 const absl::optional<int64>& channel_id)
422 : HloInstruction(opcode, shape), channel_id_(channel_id) {}
423
set_channel_id(const absl::optional<int64> & channel_id)424 void HloChannelInstruction::set_channel_id(
425 const absl::optional<int64>& channel_id) {
426 channel_id_ = channel_id;
427 }
428
ToProto() const429 HloInstructionProto HloChannelInstruction::ToProto() const {
430 HloInstructionProto proto = HloInstruction::ToProto();
431 if (channel_id_) {
432 CHECK_GT(channel_id_.value(), 0)
433 << "Non-positive channel id is equivalent to no channel id";
434 proto.set_channel_id(*channel_id_);
435 }
436 return proto;
437 }
438
ExtraAttributesToStringImpl(const HloPrintOptions &) const439 std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
440 const HloPrintOptions& /*options*/) const {
441 std::vector<string> result;
442 if (channel_id_) {
443 result.push_back(StrCat("channel_id=", *channel_id_));
444 }
445 return result;
446 }
447
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const448 bool HloChannelInstruction::IdenticalSlowPath(
449 const HloInstruction& other,
450 const std::function<bool(const HloComputation*, const HloComputation*)>&
451 eq_computations) const {
452 if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) {
453 return false;
454 }
455 const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
456 return channel_id() == casted_other.channel_id();
457 }
458
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64 channel_id,bool is_host_transfer)459 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
460 const Shape& shape,
461 int64 channel_id,
462 bool is_host_transfer)
463 : HloChannelInstruction(opcode, shape, channel_id),
464 is_host_transfer_(is_host_transfer) {}
465
ToProto() const466 HloInstructionProto HloSendRecvInstruction::ToProto() const {
467 HloInstructionProto proto = HloChannelInstruction::ToProto();
468 proto.set_is_host_transfer(is_host_transfer_);
469 return proto;
470 }
471
ExtraAttributesToStringImpl(const HloPrintOptions & options) const472 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
473 const HloPrintOptions& options) const {
474 std::vector<string> attrs =
475 HloChannelInstruction::ExtraAttributesToStringImpl(options);
476 if (is_host_transfer()) {
477 attrs.push_back("is_host_transfer=true");
478 }
479 return attrs;
480 }
481
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const482 bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
483 const HloInstruction& other,
484 const std::function<bool(const HloComputation*, const HloComputation*)>&
485 eq_computations) const {
486 // Not yet supported.
487 return false;
488 }
489
490 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)491 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
492 HloInstruction* token, int64 channel_id,
493 bool is_host_transfer)
494 : HloSendRecvInstruction(
495 HloOpcode::kSend,
496 ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
497 ShapeUtil::MakeShape(U32, {}),
498 ShapeUtil::MakeTokenShape()}),
499 channel_id, is_host_transfer) {
500 AppendOperand(operand);
501 AppendOperand(token);
502 }
503
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const504 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
505 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
506 HloCloneContext* context) const {
507 CHECK_EQ(new_operands.size(), 2);
508 return absl::make_unique<HloSendInstruction>(
509 new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
510 }
511
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)512 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
513 bool is_host_transfer)
514 : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
515 CHECK_NOTNULL(operand)->channel_id().value(),
516 is_host_transfer) {
517 AppendOperand(operand);
518 }
519
520 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const521 HloSendDoneInstruction::CloneWithNewOperandsImpl(
522 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
523 HloCloneContext* context) const {
524 CHECK_EQ(new_operands.size(), 1);
525 return absl::make_unique<HloSendDoneInstruction>(
526 Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
527 }
528
529 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)530 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
531 HloInstruction* token, int64 channel_id,
532 bool is_host_transfer)
533 : HloSendRecvInstruction(
534 HloOpcode::kRecv,
535 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
536 ShapeUtil::MakeTokenShape()}),
537 channel_id, is_host_transfer) {
538 AppendOperand(token);
539 }
540
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const541 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
542 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
543 HloCloneContext* context) const {
544 CHECK_EQ(new_operands.size(), 1);
545 return absl::make_unique<HloRecvInstruction>(
546 ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
547 is_host_transfer());
548 }
549
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)550 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
551 bool is_host_transfer)
552 : HloSendRecvInstruction(
553 HloOpcode::kRecvDone,
554 ShapeUtil::MakeTupleShape(
555 {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
556 ShapeUtil::MakeTokenShape()}),
557 CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
558 AppendOperand(operand);
559 }
560
561 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const562 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
563 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
564 HloCloneContext* context) const {
565 CHECK_EQ(new_operands.size(), 1);
566 return absl::make_unique<HloRecvDoneInstruction>(
567 Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
568 }
569
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)570 HloCollectiveInstruction::HloCollectiveInstruction(
571 HloOpcode opcode, const Shape& shape,
572 absl::Span<HloInstruction* const> operands,
573 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
574 const absl::optional<int64>& channel_id)
575 : HloChannelInstruction(opcode, shape, channel_id),
576 replica_groups_(replica_groups),
577 constrain_layout_(constrain_layout) {
578 for (auto operand : operands) {
579 AppendOperand(operand);
580 }
581 }
582
ToProto() const583 HloInstructionProto HloCollectiveInstruction::ToProto() const {
584 HloInstructionProto proto = HloChannelInstruction::ToProto();
585 *proto.mutable_replica_groups() = {replica_groups_.begin(),
586 replica_groups_.end()};
587 proto.set_constrain_layout(constrain_layout_);
588 return proto;
589 }
590
ExtraAttributesToStringImpl(const HloPrintOptions & options) const591 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
592 const HloPrintOptions& options) const {
593 std::vector<string> result =
594 HloChannelInstruction::ExtraAttributesToStringImpl(options);
595 result.push_back(
596 StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
597 if (constrain_layout_) {
598 result.push_back("constrain_layout=true");
599 }
600 return result;
601 }
602
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const603 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
604 const HloInstruction& other,
605 const std::function<bool(const HloComputation*, const HloComputation*)>&
606 eq_computations) const {
607 const auto& casted_other =
608 static_cast<const HloCollectiveInstruction&>(other);
609 return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
610 other, eq_computations) &&
611 constrain_layout() == casted_other.constrain_layout() &&
612 absl::c_equal(replica_groups(), casted_other.replica_groups(),
613 [](const ReplicaGroup& a, const ReplicaGroup& b) {
614 return absl::c_equal(a.replica_ids(), b.replica_ids());
615 });
616 }
617
HloAllGatherInstruction(const Shape & shape,HloInstruction * operand,int64 all_gather_dimension,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)618 HloAllGatherInstruction::HloAllGatherInstruction(
619 const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
620 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
621 const absl::optional<int64>& channel_id, bool use_global_device_ids)
622 : HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand},
623 replica_groups, constrain_layout, channel_id),
624 all_gather_dimension_(all_gather_dimension),
625 use_global_device_ids_(use_global_device_ids) {}
626
ExtraAttributesToStringImpl(const HloPrintOptions & options) const627 std::vector<string> HloAllGatherInstruction::ExtraAttributesToStringImpl(
628 const HloPrintOptions& options) const {
629 std::vector<string> result =
630 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
631 result.push_back(StrCat("dimensions={", all_gather_dimension_, "}"));
632 if (use_global_device_ids_) {
633 result.push_back("use_global_device_ids=true");
634 }
635 return result;
636 }
637
638 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const639 HloAllGatherInstruction::CloneWithNewOperandsImpl(
640 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
641 HloCloneContext* /*context*/) const {
642 return absl::make_unique<HloAllGatherInstruction>(
643 shape, new_operands[0], all_gather_dimension(), replica_groups(),
644 constrain_layout(), channel_id(), use_global_device_ids());
645 }
646
ToProto() const647 HloInstructionProto HloAllGatherInstruction::ToProto() const {
648 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
649 proto.add_dimensions(all_gather_dimension_);
650 proto.set_use_global_device_ids(use_global_device_ids_);
651 return proto;
652 }
653
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const654 bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues(
655 const HloInstruction& other,
656 const std::function<bool(const HloComputation*, const HloComputation*)>&
657 eq_computations) const {
658 const auto& casted_other = static_cast<const HloAllGatherInstruction&>(other);
659 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
660 other, eq_computations) &&
661 all_gather_dimension_ == casted_other.all_gather_dimension() &&
662 use_global_device_ids() == casted_other.use_global_device_ids();
663 }
664
HloAllReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)665 HloAllReduceInstruction::HloAllReduceInstruction(
666 const Shape& shape, absl::Span<HloInstruction* const> operands,
667 HloComputation* reduce_computation,
668 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
669 const absl::optional<int64>& channel_id, bool use_global_device_ids)
670 : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
671 replica_groups, constrain_layout, channel_id),
672 use_global_device_ids_(use_global_device_ids) {
673 AppendComputation(reduce_computation);
674 }
675
IsNoop() const676 bool HloAllReduceInstruction::IsNoop() const {
677 for (const auto& replica_group : replica_groups()) {
678 if (replica_group.replica_ids().size() != 1) {
679 return false;
680 }
681 }
682 return !channel_id();
683 }
684
ToProto() const685 HloInstructionProto HloAllReduceInstruction::ToProto() const {
686 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
687 proto.set_use_global_device_ids(use_global_device_ids_);
688 return proto;
689 }
690
ExtraAttributesToStringImpl(const HloPrintOptions & options) const691 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
692 const HloPrintOptions& options) const {
693 std::vector<string> result =
694 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
695 if (use_global_device_ids_) {
696 result.push_back("use_global_device_ids=true");
697 }
698 return result;
699 }
700
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const701 bool HloAllReduceInstruction::IdenticalSlowPathIgnoringChannelIdValues(
702 const HloInstruction& other,
703 const std::function<bool(const HloComputation*, const HloComputation*)>&
704 eq_computations) const {
705 const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
706 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
707 other, eq_computations) &&
708 constrain_layout() == casted_other.constrain_layout() &&
709 use_global_device_ids() == casted_other.use_global_device_ids() &&
710 eq_computations(to_apply(), casted_other.to_apply());
711 }
712
713 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const714 HloAllReduceInstruction::CloneWithNewOperandsImpl(
715 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
716 HloCloneContext* /*context*/) const {
717 return absl::make_unique<HloAllReduceInstruction>(
718 shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
719 channel_id(), use_global_device_ids());
720 }
721
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)722 HloAllToAllInstruction::HloAllToAllInstruction(
723 const Shape& shape, absl::Span<HloInstruction* const> operands,
724 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
725 const absl::optional<int64>& channel_id,
726 const absl::optional<int64>& split_dimension)
727 : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
728 replica_groups, constrain_layout, channel_id),
729 split_dimension_(split_dimension) {}
730
731 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const732 HloAllToAllInstruction::CloneWithNewOperandsImpl(
733 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
734 HloCloneContext* /*context*/) const {
735 return absl::make_unique<HloAllToAllInstruction>(
736 shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
737 split_dimension());
738 }
739
ToProto() const740 HloInstructionProto HloAllToAllInstruction::ToProto() const {
741 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
742 if (split_dimension_) {
743 proto.add_dimensions(*split_dimension_);
744 }
745 return proto;
746 }
747
ExtraAttributesToStringImpl(const HloPrintOptions & options) const748 std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
749 const HloPrintOptions& options) const {
750 std::vector<string> result =
751 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
752 if (split_dimension_) {
753 result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
754 }
755 return result;
756 }
757
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const758 bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
759 const HloInstruction& other,
760 const std::function<bool(const HloComputation*, const HloComputation*)>&
761 eq_computations) const {
762 const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
763 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
764 other, eq_computations) &&
765 split_dimension_ == casted_other.split_dimension();
766 }
767
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)768 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
769 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
770 const std::vector<std::pair<int64, int64>>& source_target_pairs,
771 const absl::optional<int64>& channel_id)
772 : HloChannelInstruction(opcode, shape, channel_id),
773 source_target_pairs_(source_target_pairs) {
774 AppendOperand(operand);
775 }
776
ToProto() const777 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
778 HloInstructionProto proto = HloChannelInstruction::ToProto();
779 for (const auto& pair : source_target_pairs()) {
780 auto* proto_pair = proto.add_source_target_pairs();
781 proto_pair->set_source(pair.first);
782 proto_pair->set_target(pair.second);
783 }
784 return proto;
785 }
786
787 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const788 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
789 const HloPrintOptions& options) const {
790 std::vector<string> result =
791 HloChannelInstruction::ExtraAttributesToStringImpl(options);
792 std::vector<string> strs;
793 for (const auto& pair : source_target_pairs()) {
794 strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
795 }
796 result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
797 return result;
798 }
799
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const800 bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues(
801 const HloInstruction& other,
802 const std::function<bool(const HloComputation*, const HloComputation*)>&
803 eq_computations) const {
804 if (opcode() != other.opcode()) {
805 return false;
806 }
807 const auto& casted_other =
808 static_cast<const HloCollectivePermuteInstruction&>(other);
809 return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
810 other, eq_computations) &&
811 absl::c_equal(source_target_pairs(),
812 casted_other.source_target_pairs(),
813 [](const std::pair<int64, int64>& a,
814 const std::pair<int64, int64>& b) { return a == b; });
815 }
816
817 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const818 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
819 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
820 HloCloneContext* /*context*/) const {
821 return absl::make_unique<HloCollectivePermuteInstruction>(
822 opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
823 }
824
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)825 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
826 HloInstruction* operand,
827 absl::Span<const int64> dimensions)
828 : HloInstruction(HloOpcode::kReverse, shape),
829 dimensions_(dimensions.begin(), dimensions.end()) {
830 AppendOperand(operand);
831 }
832
ToProto() const833 HloInstructionProto HloReverseInstruction::ToProto() const {
834 HloInstructionProto proto = HloInstruction::ToProto();
835 for (int64 dimension : dimensions_) {
836 proto.add_dimensions(dimension);
837 }
838 return proto;
839 }
840
ExtraAttributesToStringImpl(const HloPrintOptions & options) const841 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
842 const HloPrintOptions& options) const {
843 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
844 }
845
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const846 bool HloReverseInstruction::IdenticalSlowPath(
847 const HloInstruction& other,
848 const std::function<bool(const HloComputation*, const HloComputation*)>&
849 eq_computations) const {
850 const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
851 return dimensions() == casted_other.dimensions();
852 }
853
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const854 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
855 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
856 HloCloneContext* context) const {
857 CHECK_EQ(new_operands.size(), 1);
858 return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
859 dimensions());
860 }
861
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)862 HloConcatenateInstruction::HloConcatenateInstruction(
863 const Shape& shape, absl::Span<HloInstruction* const> operands,
864 int64 dimension)
865 : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
866 for (auto operand : operands) {
867 AppendOperand(operand);
868 }
869 }
870
ToProto() const871 HloInstructionProto HloConcatenateInstruction::ToProto() const {
872 HloInstructionProto proto = HloInstruction::ToProto();
873 for (int64 dimension : dimensions_) {
874 proto.add_dimensions(dimension);
875 }
876 return proto;
877 }
878
ExtraAttributesToStringImpl(const HloPrintOptions & options) const879 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
880 const HloPrintOptions& options) const {
881 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
882 }
883
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const884 bool HloConcatenateInstruction::IdenticalSlowPath(
885 const HloInstruction& other,
886 const std::function<bool(const HloComputation*, const HloComputation*)>&
887 eq_computations) const {
888 const auto& casted_other =
889 static_cast<const HloConcatenateInstruction&>(other);
890 return dimensions() == casted_other.dimensions();
891 }
892
893 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const894 HloConcatenateInstruction::CloneWithNewOperandsImpl(
895 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
896 HloCloneContext* context) const {
897 return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
898 dimensions(0));
899 }
900
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)901 HloReduceInstruction::HloReduceInstruction(
902 const Shape& shape, absl::Span<HloInstruction* const> args,
903 absl::Span<const int64> dimensions_to_reduce,
904 HloComputation* reduce_computation)
905 : HloInstruction(HloOpcode::kReduce, shape),
906 dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
907 for (HloInstruction* arg : args) {
908 AppendOperand(arg);
909 }
910 AppendComputation(reduce_computation);
911 }
912
ToProto() const913 HloInstructionProto HloReduceInstruction::ToProto() const {
914 HloInstructionProto proto = HloInstruction::ToProto();
915 for (int64 dimension : dimensions_) {
916 proto.add_dimensions(dimension);
917 }
918 return proto;
919 }
920
ExtraAttributesToStringImpl(const HloPrintOptions & options) const921 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
922 const HloPrintOptions& options) const {
923 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
924 }
925
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const926 bool HloReduceInstruction::IdenticalSlowPath(
927 const HloInstruction& other,
928 const std::function<bool(const HloComputation*, const HloComputation*)>&
929 eq_computations) const {
930 const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
931 // Reduction results are determined by the reduction dimension and the
932 // reduction computation.
933 return dimensions() == casted_other.dimensions() &&
934 eq_computations(to_apply(), casted_other.to_apply());
935 }
936
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const937 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
938 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
939 HloCloneContext* context) const {
940 CHECK_EQ(new_operands.size() % 2, 0);
941 return absl::make_unique<HloReduceInstruction>(shape, new_operands,
942 dimensions(), to_apply());
943 }
944
HloSortInstruction(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)945 HloSortInstruction::HloSortInstruction(
946 const Shape& shape, int64 dimension,
947 absl::Span<HloInstruction* const> operands, HloComputation* compare,
948 bool is_stable)
949 : HloInstruction(HloOpcode::kSort, shape),
950 dimensions_({dimension}),
951 is_stable_(is_stable) {
952 for (auto* value : operands) {
953 AppendOperand(value);
954 }
955 AppendComputation(compare);
956 }
957
ToProto() const958 HloInstructionProto HloSortInstruction::ToProto() const {
959 HloInstructionProto proto = HloInstruction::ToProto();
960 for (int64 dimension : dimensions_) {
961 proto.add_dimensions(dimension);
962 }
963 proto.set_is_stable(is_stable());
964 return proto;
965 }
966
ExtraAttributesToStringImpl(const HloPrintOptions & options) const967 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
968 const HloPrintOptions& options) const {
969 std::vector<string> attrs;
970 attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
971 if (is_stable()) {
972 attrs.push_back("is_stable=true");
973 }
974 return attrs;
975 }
976
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const977 bool HloSortInstruction::IdenticalSlowPath(
978 const HloInstruction& other,
979 const std::function<bool(const HloComputation*, const HloComputation*)>&
980 eq_computations) const {
981 const auto& casted_other = static_cast<const HloSortInstruction&>(other);
982 if (dimensions() != casted_other.dimensions()) {
983 return false;
984 }
985 if (is_stable() != casted_other.is_stable()) {
986 return false;
987 }
988 return eq_computations(to_apply(), other.to_apply());
989 }
990
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const991 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
992 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
993 HloCloneContext* context) const {
994 return absl::make_unique<HloSortInstruction>(
995 shape, dimensions(0), new_operands, to_apply(), is_stable());
996 }
997
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)998 HloTransposeInstruction::HloTransposeInstruction(
999 const Shape& shape, HloInstruction* operand,
1000 absl::Span<const int64> dimensions)
1001 : HloInstruction(HloOpcode::kTranspose, shape),
1002 dimensions_(dimensions.begin(), dimensions.end()) {
1003 AppendOperand(operand);
1004 }
1005
IsRank2Transpose() const1006 bool HloTransposeInstruction::IsRank2Transpose() const {
1007 return dimensions() == std::vector<int64>({1, 0}) &&
1008 shape().dimensions_size() == 2 &&
1009 std::equal(shape().dimensions().begin(), shape().dimensions().end(),
1010 operand(0)->shape().dimensions().rbegin());
1011 }
1012
ToProto() const1013 HloInstructionProto HloTransposeInstruction::ToProto() const {
1014 HloInstructionProto proto = HloInstruction::ToProto();
1015 for (int64 dimension : dimensions_) {
1016 proto.add_dimensions(dimension);
1017 }
1018 return proto;
1019 }
1020
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1021 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
1022 const HloPrintOptions& options) const {
1023 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1024 }
1025
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1026 bool HloTransposeInstruction::IdenticalSlowPath(
1027 const HloInstruction& other,
1028 const std::function<bool(const HloComputation*, const HloComputation*)>&
1029 eq_computations) const {
1030 const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
1031 return dimensions() == casted_other.dimensions();
1032 }
1033
1034 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1035 HloTransposeInstruction::CloneWithNewOperandsImpl(
1036 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1037 HloCloneContext* context) const {
1038 CHECK_EQ(new_operands.size(), 1);
1039 return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
1040 dimensions());
1041 }
1042
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)1043 HloBroadcastInstruction::HloBroadcastInstruction(
1044 const Shape& shape, HloInstruction* operand,
1045 absl::Span<const int64> broadcast_dimension)
1046 : HloInstruction(HloOpcode::kBroadcast, shape),
1047 dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
1048 AppendOperand(operand);
1049 }
1050
ToProto() const1051 HloInstructionProto HloBroadcastInstruction::ToProto() const {
1052 HloInstructionProto proto = HloInstruction::ToProto();
1053 for (int64 dimension : dimensions_) {
1054 proto.add_dimensions(dimension);
1055 }
1056 return proto;
1057 }
1058
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1059 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
1060 const HloPrintOptions& options) const {
1061 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1062 }
1063
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1064 bool HloBroadcastInstruction::IdenticalSlowPath(
1065 const HloInstruction& other,
1066 const std::function<bool(const HloComputation*, const HloComputation*)>&
1067 eq_computations) const {
1068 const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
1069 return dimensions() == casted_other.dimensions();
1070 }
1071
1072 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1073 HloBroadcastInstruction::CloneWithNewOperandsImpl(
1074 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1075 HloCloneContext* context) const {
1076 CHECK_EQ(new_operands.size(), 1);
1077 return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
1078 dimensions());
1079 }
1080
HloDynamicReshapeInstruction(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1081 HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
1082 const Shape& shape, HloInstruction* data_operand,
1083 absl::Span<HloInstruction* const> dim_sizes)
1084 : HloInstruction(HloOpcode::kDynamicReshape, shape) {
1085 AppendOperand(data_operand);
1086 for (auto operand : dim_sizes) {
1087 AppendOperand(operand);
1088 }
1089 }
1090
1091 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1092 HloDynamicReshapeInstruction::CloneWithNewOperandsImpl(
1093 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1094 HloCloneContext* context) const {
1095 CHECK_GE(new_operands.size(), 1);
1096 return absl::make_unique<HloDynamicReshapeInstruction>(
1097 shape, new_operands[0], new_operands.subspan(1));
1098 }
1099
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64 inferred_dimension)1100 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
1101 HloInstruction* operand,
1102 int64 inferred_dimension)
1103 : HloInstruction(HloOpcode::kReshape, shape),
1104 inferred_dimension_(inferred_dimension) {
1105 AppendOperand(operand);
1106 }
1107
ToProto() const1108 HloInstructionProto HloReshapeInstruction::ToProto() const {
1109 HloInstructionProto proto = HloInstruction::ToProto();
1110 if (inferred_dimension_ != -1) {
1111 proto.add_dimensions(inferred_dimension_);
1112 }
1113 return proto;
1114 }
1115
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1116 std::vector<string> HloReshapeInstruction::ExtraAttributesToStringImpl(
1117 const HloPrintOptions& options) const {
1118 if (inferred_dimension() == -1) {
1119 return {};
1120 }
1121 return {StrCat("inferred_dimension=", inferred_dimension())};
1122 }
1123
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1124 bool HloReshapeInstruction::IdenticalSlowPath(
1125 const HloInstruction& other,
1126 const std::function<bool(const HloComputation*, const HloComputation*)>&
1127 eq_computations) const {
1128 const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
1129 return inferred_dimension() == casted_other.inferred_dimension();
1130 }
1131
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1132 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
1133 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1134 HloCloneContext* context) const {
1135 CHECK_EQ(new_operands.size(), 1);
1136 return absl::make_unique<HloReshapeInstruction>(shape, new_operands[0],
1137 inferred_dimension());
1138 }
1139
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1140 HloMapInstruction::HloMapInstruction(const Shape& shape,
1141 absl::Span<HloInstruction* const> operands,
1142 HloComputation* map_computation)
1143 : HloInstruction(HloOpcode::kMap, shape) {
1144 for (auto operand : operands) {
1145 AppendOperand(operand);
1146 }
1147 AppendComputation(map_computation);
1148 // TODO(b/65689298) Remove code below once Map is generalized to accept
1149 // arbitrary map dimensions.
1150 dimensions_.resize(shape.rank());
1151 std::iota(dimensions_.begin(), dimensions_.end(), 0);
1152 }
1153
ToProto() const1154 HloInstructionProto HloMapInstruction::ToProto() const {
1155 HloInstructionProto proto = HloInstruction::ToProto();
1156 for (int64 dimension : dimensions_) {
1157 proto.add_dimensions(dimension);
1158 }
1159 return proto;
1160 }
1161
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1162 bool HloMapInstruction::IsElementwiseImpl(
1163 const absl::optional<int64>& operand_idx) const {
1164 if (!dimensions().empty()) {
1165 // Check that the map is executed in elementwise compatible dimensions.
1166 if (dimensions().size() != shape().dimensions_size()) {
1167 return false;
1168 }
1169 for (int i = 0; i < dimensions().size(); ++i) {
1170 if (dimensions()[i] != i) {
1171 return false;
1172 }
1173 }
1174 }
1175 return true;
1176 }
1177
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1178 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
1179 const HloPrintOptions& options) const {
1180 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1181 }
1182
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1183 bool HloMapInstruction::IdenticalSlowPath(
1184 const HloInstruction& other,
1185 const std::function<bool(const HloComputation*, const HloComputation*)>&
1186 eq_computations) const {
1187 const auto& casted_other = static_cast<const HloMapInstruction&>(other);
1188 return eq_computations(to_apply(), casted_other.to_apply()) &&
1189 dimensions() == casted_other.dimensions();
1190 }
1191
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1192 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1193 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1194 HloCloneContext* context) const {
1195 return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1196 }
1197
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1198 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
1199 HloInstruction* operand,
1200 absl::Span<const int64> start_indices,
1201 absl::Span<const int64> limit_indices,
1202 absl::Span<const int64> strides)
1203 : HloInstruction(HloOpcode::kSlice, shape),
1204 slice_starts_(start_indices.begin(), start_indices.end()),
1205 slice_limits_(limit_indices.begin(), limit_indices.end()),
1206 slice_strides_(strides.begin(), strides.end()) {
1207 AppendOperand(operand);
1208 // For backward compatibility with old serialized computations: if there are
1209 // no strides, assume all strides are 1.
1210 // TODO(b/63317920): remove this code.
1211 if (slice_strides_.empty()) {
1212 slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
1213 }
1214 }
1215
ToProto() const1216 HloInstructionProto HloSliceInstruction::ToProto() const {
1217 HloInstructionProto proto = HloInstruction::ToProto();
1218 for (int i = 0; i < slice_starts_.size(); ++i) {
1219 auto* slice_dimension = proto.add_slice_dimensions();
1220 slice_dimension->set_start(slice_starts_[i]);
1221 slice_dimension->set_limit(slice_limits_[i]);
1222 slice_dimension->set_stride(slice_strides_[i]);
1223 }
1224 return proto;
1225 }
1226
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1227 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
1228 const HloPrintOptions& options) const {
1229 std::vector<string> bounds;
1230 bounds.reserve(slice_starts_.size());
1231 const bool omit_stride =
1232 absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
1233 for (int i = 0; i < slice_starts_.size(); ++i) {
1234 string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1235 bounds.push_back(
1236 StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1237 }
1238 return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1239 }
1240
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1241 bool HloSliceInstruction::IdenticalSlowPath(
1242 const HloInstruction& other,
1243 const std::function<bool(const HloComputation*, const HloComputation*)>&
1244 eq_computations) const {
1245 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1246 return slice_starts_ == other_slice.slice_starts_ &&
1247 slice_limits_ == other_slice.slice_limits_ &&
1248 slice_strides_ == other_slice.slice_strides_;
1249 }
1250
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1251 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1252 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1253 HloCloneContext* context) const {
1254 CHECK_EQ(new_operands.size(), 1);
1255 return absl::make_unique<HloSliceInstruction>(
1256 shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1257 }
1258
HloConstantInstruction(Literal literal)1259 HloConstantInstruction::HloConstantInstruction(Literal literal)
1260 : HloInstruction(HloOpcode::kConstant, literal.shape()),
1261 literal_(std::move(literal)) {}
1262
HloConstantInstruction(Literal literal,const Shape & shape)1263 HloConstantInstruction::HloConstantInstruction(Literal literal,
1264 const Shape& shape)
1265 : HloInstruction(HloOpcode::kConstant, shape),
1266 literal_(std::move(literal)) {}
1267
HloConstantInstruction(const Shape & shape)1268 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1269 : HloInstruction(HloOpcode::kConstant, shape) {}
1270
ToProto() const1271 HloInstructionProto HloConstantInstruction::ToProto() const {
1272 HloInstructionProto proto = HloInstruction::ToProto();
1273 if (literal_.has_value()) {
1274 *proto.mutable_literal() = literal_->ToProto();
1275 }
1276 return proto;
1277 }
1278
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1279 bool HloConstantInstruction::IsElementwiseImpl(
1280 const absl::optional<int64>& operand_idx) const {
1281 return true;
1282 }
1283
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1284 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1285 const ShapeIndex& shape_index) {
1286 Shape* mutable_array_subshape =
1287 ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1288 CHECK(mutable_array_subshape->IsArray());
1289
1290 // Normally array_subshape will always have a layout, but this invariant is
1291 // temporarily broken in LayoutAssignment::AssignLayouts.
1292
1293 if (!mutable_array_subshape->has_layout() ||
1294 !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1295 *literal_ = literal_->Relayout(new_layout, shape_index);
1296 *mutable_array_subshape->mutable_layout() = new_layout;
1297 }
1298 }
1299
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1300 bool HloConstantInstruction::IdenticalSlowPath(
1301 const HloInstruction& other,
1302 const std::function<bool(const HloComputation*, const HloComputation*)>&
1303 eq_computations) const {
1304 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1305 return literal() == other_slice.literal();
1306 }
1307
1308 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1309 HloConstantInstruction::CloneWithNewOperandsImpl(
1310 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1311 HloCloneContext* context) const {
1312 CHECK(literal_.has_value());
1313 // Literal's shape may have no/different tiling info. Use this instruction's
1314 // shape instead.
1315 CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1316 this->shape()));
1317 return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
1318 this->shape());
1319 }
1320
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1321 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1322 const HloPrintOptions& options,
1323 CanonicalNameMap* canonical_name_map) const {
1324 string operands;
1325 // For constants, show the actual value in place of an empty operand list.
1326 if (literal_.has_value() &&
1327 ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1328 options.print_large_constants())) {
1329 // Literal::ToString emits multidimensional arrays over multiple
1330 // lines. Compact this into one line by stripping out white space.
1331 operands = literal_->ToStringWithoutShapeOneline();
1332 } else {
1333 // Do not show large constants or tuples.
1334 operands = "{...}";
1335 }
1336 return operands;
1337 }
1338
HloTraceInstruction(const string & tag,HloInstruction * operand)1339 HloTraceInstruction::HloTraceInstruction(const string& tag,
1340 HloInstruction* operand)
1341 : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
1342 literal_(LiteralUtil::CreateR1U8(tag)) {
1343 AppendOperand(operand);
1344 operand->set_tracing(this);
1345 }
1346
ToProto() const1347 HloInstructionProto HloTraceInstruction::ToProto() const {
1348 HloInstructionProto proto = HloInstruction::ToProto();
1349 *proto.mutable_literal() = literal_.ToProto();
1350 return proto;
1351 }
1352
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1353 bool HloTraceInstruction::IdenticalSlowPath(
1354 const HloInstruction& other,
1355 const std::function<bool(const HloComputation*, const HloComputation*)>&
1356 eq_computations) const {
1357 return false;
1358 }
1359
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1360 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1361 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1362 HloCloneContext* context) const {
1363 LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1364 }
1365
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1366 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1367 FusionKind fusion_kind,
1368 HloInstruction* fused_root)
1369 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1370 CHECK(fused_root != nullptr);
1371 SetAndSanitizeName("fusion");
1372 set_parent(fused_root->parent());
1373 set_metadata(fused_root->metadata());
1374 CloneAndFuseInternal(fused_root);
1375 }
1376
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1377 HloFusionInstruction::HloFusionInstruction(
1378 const Shape& shape, FusionKind fusion_kind,
1379 absl::Span<HloInstruction* const> operands,
1380 HloComputation* fusion_computation)
1381 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1382 for (auto operand : operands) {
1383 AppendOperand(operand);
1384 }
1385 SetAndSanitizeName("fusion");
1386 AppendComputation(fusion_computation);
1387 fusion_computation->SetFusionInstruction(this);
1388 }
1389
ToCategory() const1390 string HloFusionInstruction::ToCategory() const {
1391 switch (fusion_kind()) {
1392 case FusionKind::kLoop:
1393 return "loop fusion";
1394 case FusionKind::kInput:
1395 return "input fusion";
1396 case FusionKind::kOutput:
1397 return "output fusion";
1398 case FusionKind::kCustom:
1399 return "custom fusion";
1400 }
1401 }
1402
ToProto() const1403 HloInstructionProto HloFusionInstruction::ToProto() const {
1404 HloInstructionProto proto = HloInstruction::ToProto();
1405 proto.set_fusion_kind(xla::ToString(fusion_kind()));
1406 proto.add_called_computation_ids(
1407 fused_instructions_computation()->unique_id());
1408 return proto;
1409 }
1410
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1411 bool HloFusionInstruction::IsElementwiseImpl(
1412 const absl::optional<int64>& operand_idx) const {
1413 if (!operand_idx.has_value()) {
1414 for (auto* fused : fused_instructions()) {
1415 if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1416 return false;
1417 }
1418 }
1419 return true;
1420 }
1421 // A loop-fusion is elementwise on an operand if all operations (computed
1422 // using BFS) between the operand and the fused root are elementwise.
1423 std::deque<HloInstruction*> worklist;
1424 std::unordered_set<const HloInstruction*> visited;
1425 worklist.push_back(fused_parameter(operand_idx.value()));
1426 visited.insert(fused_parameter(operand_idx.value()));
1427 while (!worklist.empty()) {
1428 HloInstruction* operand = worklist.front();
1429 worklist.pop_front();
1430 for (HloInstruction* user : operand->users()) {
1431 CHECK_GE(user->unique_id(), 0);
1432 if (ContainsKey(visited, user)) {
1433 continue;
1434 }
1435 if (user->IsElementwise() ||
1436 IsInstructionElementwiseOnOperand(user, operand)) {
1437 worklist.push_back(user);
1438 visited.insert(user);
1439 } else {
1440 return false;
1441 }
1442 }
1443 }
1444 return true;
1445 }
1446
AddFusionOperand(HloInstruction * new_operand)1447 HloInstruction* HloFusionInstruction::AddFusionOperand(
1448 HloInstruction* new_operand) {
1449 CHECK_EQ(operand_count(),
1450 fused_instructions_computation()->parameter_instructions().size());
1451 const int64 param_no = operand_count();
1452 string param_name = StrCat("param_", param_no);
1453 HloInstruction* fused_parameter =
1454 fused_instructions_computation()->AddParameter(
1455 HloInstruction::CreateParameter(param_no, new_operand->shape(),
1456 param_name));
1457 AppendOperand(new_operand);
1458 return fused_parameter;
1459 }
1460
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1461 void HloFusionInstruction::MergeFusionInstruction(
1462 HloFusionInstruction* instruction_to_merge) {
1463 CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1464 // Clone the instruction from which to merge fused instructions.
1465 std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1466 HloFusionInstruction* cloned_fusion =
1467 static_cast<HloFusionInstruction*>(cloned.get());
1468 // Replace uses of fused parameters with the corresponding operand of the
1469 // fusion. Add all non-parameter fused instructions to
1470 // 'unfused_instructions' to be merged into 'this'. This is done in reverse
1471 // post order.
1472 std::vector<HloInstruction*> unfused_instructions;
1473 auto fused_instructions = cloned_fusion->fused_instructions_computation()
1474 ->MakeInstructionPostOrder();
1475 for (auto fused_it = fused_instructions.rbegin();
1476 fused_it != fused_instructions.rend(); ++fused_it) {
1477 auto fused_instruction = *fused_it;
1478 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1479 TF_CHECK_OK(
1480 fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1481 fused_instruction->parameter_number())));
1482 } else {
1483 unfused_instructions.push_back(fused_instruction);
1484 }
1485 }
1486
1487 // If there are no unfused instructions, the fused computation must consist
1488 // only of kParameter instructions. Make the operand of the corresponding
1489 // parameter number the new root.
1490 HloInstruction* unfused_root =
1491 unfused_instructions.empty()
1492 ? instruction_to_merge->mutable_operand(
1493 instruction_to_merge->fused_instructions_computation()
1494 ->root_instruction()
1495 ->parameter_number())
1496 : unfused_instructions.front();
1497 CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
1498 unfused_instructions.empty());
1499 // Replace instruction_to_merge use of 'this' with unfused_root.
1500 TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
1501
1502 // Build a dummy root for the cloned fusion as we may remove the original root
1503 // in the fusion process.
1504 if (!unfused_instructions.empty()) {
1505 HloComputation* computation = unfused_root->parent();
1506 auto* dummy_root = computation->AddInstruction(
1507 HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
1508 computation->set_root_instruction(dummy_root,
1509 /*accept_different_shape=*/true);
1510 }
1511
1512 // Fuse 'unfused_instructions' into 'this'. Everytime we fuse an instruction
1513 // we remove it from the closed fusion node. This is so that we don't add
1514 // extra users to the producer of that instruction (we use user count to
1515 // decide if a side-effectful instruction is fusible).
1516 for (auto& instruction : unfused_instructions) {
1517 auto* fused = FuseInstruction(instruction);
1518 TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
1519 TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1520 }
1521 CHECK_EQ(0, cloned_fusion->user_count());
1522 TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1523 cloned_fusion->fused_instructions_computation()));
1524 }
1525
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1526 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1527 HloFusionInstruction* instruction_to_merge) {
1528 // Add all non-parameter fused instructions to 'unfused_instructions' to be
1529 // merged into 'this'. `old_to_new' maps the instructions in the fused node
1530 // to the disassembled fusion instructions.
1531 // Note that we add the unfused instructions to this->parent_ computation.
1532 // This is necessary because the unique_id needs for an instruction and
1533 // it's only added when inserting to the computation.
1534 absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1535 std::vector<HloInstruction*> unfused_instructions;
1536 auto computation_to_merge =
1537 instruction_to_merge->fused_instructions_computation();
1538 auto post_order = computation_to_merge->MakeInstructionPostOrder();
1539 for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1540 auto fused_instruction = *rit;
1541 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1542 InsertOrDie(&old_to_new, fused_instruction,
1543 instruction_to_merge->mutable_operand(
1544 fused_instruction->parameter_number()));
1545 continue;
1546 }
1547
1548 // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
1549 // which clones again. This can be improved.
1550 auto cloned_instruction =
1551 parent()->AddInstruction(fused_instruction->Clone());
1552 unfused_instructions.push_back(cloned_instruction);
1553 InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
1554 }
1555 for (auto unfused_instruction : unfused_instructions) {
1556 for (int64 index = 0; index < unfused_instruction->operand_count();
1557 index++) {
1558 auto new_operand =
1559 FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
1560 TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
1561 }
1562 }
1563
1564 // If there are no unfused instructions, the fused computation must consist
1565 // only of kParameter instructions. Make the operand of the corresponding
1566 // parameter number the new root.
1567 HloInstruction* unfused_root =
1568 unfused_instructions.empty()
1569 ? instruction_to_merge->mutable_operand(
1570 instruction_to_merge->fused_instructions_computation()
1571 ->root_instruction()
1572 ->parameter_number())
1573 : unfused_instructions.front();
1574 TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
1575
1576 TF_CHECK_OK(
1577 instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
1578 if (GetModule()) {
1579 TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
1580 }
1581
1582 // Fuse the root instruction and generate multiple outputs.
1583 if (unfused_instructions.empty()) {
1584 return;
1585 }
1586 FuseInstructionIntoMultiOutput(unfused_root);
1587 TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
1588 // The rest instructions are of normal fusing.
1589 for (int64 i = 1; i < unfused_instructions.size(); i++) {
1590 auto instruction = unfused_instructions[i];
1591 FuseInstruction(instruction);
1592 TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1593 }
1594 }
1595
fused_instructions_computation() const1596 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1597 CHECK(!called_computations().empty());
1598 auto* fused_instructions_computation = called_computations().front();
1599 CHECK(fused_instructions_computation->IsFusionComputation())
1600 << "Computation " << fused_instructions_computation->name()
1601 << " is not a fusion kind";
1602 return fused_instructions_computation;
1603 }
1604
fused_expression_root() const1605 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1606 return fused_instructions_computation()->root_instruction();
1607 }
1608
fused_parameter(int64 parameter_number) const1609 HloInstruction* HloFusionInstruction::fused_parameter(
1610 int64 parameter_number) const {
1611 return fused_instructions_computation()->parameter_instruction(
1612 parameter_number);
1613 }
1614
fused_parameters() const1615 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1616 const {
1617 return fused_instructions_computation()->parameter_instructions();
1618 }
1619
1620 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1621 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1622 HloFusionInstruction::fused_instructions() const {
1623 const HloComputation* subcomp = fused_instructions_computation();
1624 return subcomp->instructions();
1625 }
1626
1627 const tensorflow::gtl::iterator_range<
1628 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1629 HloFusionInstruction::fused_instructions() {
1630 return fused_instructions_computation()->instructions();
1631 }
1632
fused_instruction_count() const1633 int64 HloFusionInstruction::fused_instruction_count() const {
1634 return fused_instructions_computation()->instruction_count();
1635 }
1636
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1637 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1638 HloInstruction* instruction_to_fuse, bool add_output) {
1639 // When add_output is false, this fusion instruction must be a user of
1640 // instruction_to_fuse.
1641 if (!add_output) {
1642 CHECK(IsUserOf(instruction_to_fuse));
1643 }
1644 HloInstruction* fused_instruction =
1645 CloneAndFuseInternal(instruction_to_fuse, add_output);
1646 return fused_instruction;
1647 }
1648
CloneAndFuseInternal(HloInstruction * instruction_to_fuse,bool add_output)1649 HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
1650 HloInstruction* instruction_to_fuse, bool add_output) {
1651 CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
1652 VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
1653 HloInstruction* clone = nullptr;
1654 if (called_computations().empty()) {
1655 // New fusion instruction. It should not be a multioutput instruction.
1656 CHECK(!add_output);
1657 auto builder = HloComputation::Builder("fused_computation", this);
1658 builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
1659 AppendComputation(
1660 CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1661 clone = fused_expression_root();
1662 } else {
1663 // When add_output is false, instruction_to_fuse is necessarily an operand
1664 // of the fusion instruction. After fusion this will no longer be the
1665 // case. Remove the operand from the operand list and remove its
1666 // corresponding fused parameter instruction. Renumber parameters as
1667 // necessary to make parameter numbers consistent with their index in the
1668 // fused_parameter_ vector.
1669 bool in_operand_list =
1670 absl::c_linear_search(operands(), instruction_to_fuse);
1671 CHECK(add_output || in_operand_list);
1672 if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1673 // We assume all uses of a kTuple operation are GTE ops, not another
1674 // fusion node. In this case, we don't need to clone
1675 // 'instruction_to_fuse'.
1676 CHECK(!in_operand_list);
1677 clone = instruction_to_fuse;
1678 } else {
1679 clone = fused_instructions_computation()->AddInstruction(
1680 instruction_to_fuse->Clone(/*suffix=*/""));
1681 }
1682 const std::vector<HloInstruction*>& fused_parameters =
1683 fused_instructions_computation()->parameter_instructions();
1684 for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1685 if (instruction_to_fuse == operand(operand_num)) {
1686 // replace the fused parameter instruction's uses with the clone.
1687 HloInstruction* fused_parameter = fused_parameters[operand_num];
1688 TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
1689
1690 // Remove the corresponding fused parameter and operand from their
1691 // respective vectors.
1692 TF_CHECK_OK(
1693 fused_instructions_computation()->RemoveParameter(operand_num));
1694 RemoveOperandAt(operand_num);
1695 break;
1696 }
1697 }
1698 // We've cloned instruction_to_fuse into this fusion instruction, so this
1699 // fusion instruction is no longer a use of instruction_to_fuse.
1700 if (in_operand_list) {
1701 DetachFrom(instruction_to_fuse);
1702 // When the instruction_to_fuse does not have other users, we don't need
1703 // to generate a multioutput fusion instruction.
1704 if (instruction_to_fuse->user_count() == 0) {
1705 add_output = false;
1706 }
1707 }
1708 }
1709
1710 // Reread the parameters in the computation.
1711 const std::vector<HloInstruction*>& fused_parameters =
1712 fused_instructions_computation()->parameter_instructions();
1713
1714 // Add each operand of the clone as an operand of the fusion instruction. A
1715 // complication is that some clone operands may already be operands of the
1716 // fusion instruction.
1717 for (int64 operand_num = 0; operand_num < clone->operand_count();
1718 ++operand_num) {
1719 HloInstruction* operand = clone->mutable_operand(operand_num);
1720
1721 // See if this operand is already an operand of the fusion node.
1722 CHECK_EQ(operands().size(), fused_parameters.size());
1723 HloInstruction* fused_param = nullptr;
1724 for (int64 i = 0; i < operands().size(); ++i) {
1725 if (this->operand(i) == operand) {
1726 fused_param = fused_parameters[i];
1727 break;
1728 }
1729 }
1730
1731 if (fused_param == nullptr) {
1732 // Clone's operand was not already an operand of the fusion
1733 // instruction. Add it as an operand and add a corresponding fused
1734 // parameter instruction.
1735 fused_param = AddFusionOperand(operand);
1736 }
1737 TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
1738 }
1739
1740 if (add_output) {
1741 CHECK_GT(instruction_to_fuse->user_count(), 0);
1742 // If this is already a multioutput fusion instruction, expand the root
1743 // tuple by 1.
1744 HloInstruction* fused_root = fused_expression_root();
1745 HloInstruction::InstructionVector tuple_elements;
1746 bool newly_created_tuple_instr = false;
1747 if (fused_root->opcode() == HloOpcode::kTuple) {
1748 tuple_elements = fused_root->operands();
1749 } else {
1750 tuple_elements.push_back(fused_root);
1751 newly_created_tuple_instr = true;
1752 }
1753 if (clone->opcode() == HloOpcode::kTuple) {
1754 for (auto inst : clone->operands()) {
1755 tuple_elements.push_back(inst);
1756 }
1757 } else {
1758 tuple_elements.push_back(clone);
1759 }
1760 HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
1761 HloInstruction::CreateTuple(tuple_elements));
1762 fused_instructions_computation()->set_root_instruction(new_root);
1763 *mutable_shape() = new_root->shape();
1764 if (fused_root->opcode() == HloOpcode::kTuple) {
1765 TF_CHECK_OK(
1766 fused_instructions_computation()->RemoveInstruction(fused_root));
1767 }
1768
1769 // If this is a newly created multioutput instruction, we need to update
1770 // the use of the original fusion instruction.
1771 if (newly_created_tuple_instr) {
1772 HloInstruction* new_instr = parent()->AddInstruction(
1773 HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
1774 TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1775 }
1776 int64 index = tuple_elements.size();
1777 if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
1778 CHECK_EQ(clone, instruction_to_fuse);
1779 index -= clone->operand_count();
1780 std::vector<HloInstruction*> to_be_removed;
1781 for (auto old_gte : clone->users()) {
1782 CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1783 int64 old_tuple_index = old_gte->tuple_index();
1784 HloInstruction* new_gte =
1785 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1786 old_gte->shape(), this, index + old_tuple_index));
1787 TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1788 to_be_removed.push_back(old_gte);
1789 }
1790 for (auto old_gte : to_be_removed) {
1791 TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1792 }
1793 } else {
1794 HloInstruction* new_gte =
1795 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1796 clone->shape(), this, index - 1));
1797 TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
1798 }
1799 }
1800
1801 if (clone != instruction_to_fuse) {
1802 VLOG(2) << "New clone:\n" << clone->ToString();
1803 }
1804 return clone;
1805 }
1806
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1807 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1808 const HloPrintOptions& options) const {
1809 return {StrCat("kind=", xla::ToString(fusion_kind()))};
1810 }
1811
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1812 bool HloFusionInstruction::IdenticalSlowPath(
1813 const HloInstruction& other,
1814 const std::function<bool(const HloComputation*, const HloComputation*)>&
1815 eq_computations) const {
1816 return fusion_kind() == other.fusion_kind() &&
1817 eq_computations(fused_instructions_computation(),
1818 other.fused_instructions_computation());
1819 }
1820
InnerHash() const1821 uint64 HloFusionInstruction::InnerHash() const {
1822 return fused_instructions_computation()->root_instruction()->Hash();
1823 }
1824
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1825 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1826 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1827 HloCloneContext* context) const {
1828 HloModule* module = context != nullptr ? context->module() : GetModule();
1829 HloComputation* new_fused_computation = nullptr;
1830 if (context != nullptr) {
1831 new_fused_computation =
1832 context->FindComputation(fused_instructions_computation());
1833 }
1834 if (new_fused_computation == nullptr) {
1835 new_fused_computation = module->AddEmbeddedComputation(
1836 fused_instructions_computation()->Clone("clone", context));
1837 }
1838 return absl::make_unique<HloFusionInstruction>(
1839 shape, fusion_kind(), new_operands, new_fused_computation);
1840 }
1841
DeduplicateFusionOperands()1842 Status HloFusionInstruction::DeduplicateFusionOperands() {
1843 if (IsCustomFusion()) {
1844 return Status::OK();
1845 }
1846 absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1847 std::vector<int> operands_to_remove;
1848 for (int i = 0; i < operand_count(); ++i) {
1849 auto emplace_result = operand_indices.emplace(operand(i), i);
1850 if (!emplace_result.second) {
1851 TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1852 fused_parameter(emplace_result.first->second)));
1853 operands_to_remove.push_back(i);
1854 }
1855 }
1856 if (operands_to_remove.empty()) {
1857 return Status::OK();
1858 }
1859 TF_RETURN_IF_ERROR(fused_instructions_computation()
1860 ->RemoveUnusedParametersFromFusedComputation());
1861 RemoveOperandsAtAscendingIndices(operands_to_remove);
1862 return Status::OK();
1863 }
1864
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1865 HloRngInstruction::HloRngInstruction(
1866 const Shape& shape, RandomDistribution distribution,
1867 absl::Span<HloInstruction* const> parameters)
1868 : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
1869 for (HloInstruction* param : parameters) {
1870 AppendOperand(param);
1871 }
1872 }
1873
ToProto() const1874 HloInstructionProto HloRngInstruction::ToProto() const {
1875 HloInstructionProto proto = HloInstruction::ToProto();
1876 proto.set_distribution(distribution_);
1877 return proto;
1878 }
1879
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1880 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
1881 const HloPrintOptions& options) const {
1882 return {StrCat("distribution=", RandomDistributionToString(distribution_))};
1883 }
1884
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1885 bool HloRngInstruction::IsElementwiseImpl(
1886 const absl::optional<int64>& operand_idx) const {
1887 return true;
1888 }
1889
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1890 bool HloRngInstruction::IdenticalSlowPath(
1891 const HloInstruction& other,
1892 const std::function<bool(const HloComputation*, const HloComputation*)>&
1893 eq_computations) const {
1894 const auto& casted_other = static_cast<const HloRngInstruction&>(other);
1895 return distribution_ == casted_other.distribution_;
1896 }
1897
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1898 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
1899 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1900 HloCloneContext* context) const {
1901 return absl::make_unique<HloRngInstruction>(shape, distribution_,
1902 new_operands);
1903 }
1904
HloParameterInstruction(int64 parameter_number,const Shape & shape,const string & name)1905 HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
1906 const Shape& shape,
1907 const string& name)
1908 : HloInstruction(HloOpcode::kParameter, shape),
1909 parameter_number_(parameter_number) {
1910 SetAndSanitizeName(name);
1911 }
1912
ToProto() const1913 HloInstructionProto HloParameterInstruction::ToProto() const {
1914 HloInstructionProto proto = HloInstruction::ToProto();
1915 proto.set_parameter_number(parameter_number_);
1916 if (parameter_replicated_at_leaf_buffers_) {
1917 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1918 proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
1919 replicated);
1920 }
1921 }
1922 return proto;
1923 }
1924
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1925 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
1926 const HloPrintOptions& options) const {
1927 std::vector<string> result;
1928 if (!parameter_replicated_at_leaf_buffers_) {
1929 return result;
1930 }
1931 std::vector<string> buffers_replicated_strs;
1932 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
1933 buffers_replicated_strs.push_back(replicated ? "true" : "false");
1934 }
1935 if (options.print_ids()) {
1936 result.push_back(StrCat("parameter_replication={",
1937 StrJoin(buffers_replicated_strs, ","), "}"));
1938 }
1939 return result;
1940 }
1941
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1942 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
1943 const HloPrintOptions& options,
1944 CanonicalNameMap* canonical_name_map) const {
1945 return StrCat(parameter_number_);
1946 }
1947
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1948 bool HloParameterInstruction::IdenticalSlowPath(
1949 const HloInstruction& other,
1950 const std::function<bool(const HloComputation*, const HloComputation*)>&
1951 eq_computations) const {
1952 const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
1953 return parameter_number() == casted_other.parameter_number();
1954 }
1955
1956 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1957 HloParameterInstruction::CloneWithNewOperandsImpl(
1958 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1959 HloCloneContext* context) const {
1960 auto clone = absl::make_unique<HloParameterInstruction>(parameter_number_,
1961 shape, name());
1962 if (parameter_replicated_at_leaf_buffers_ &&
1963 ShapeUtil::Equal(shape, this->shape())) {
1964 clone->set_parameter_replicated_at_leaf_buffers(
1965 *parameter_replicated_at_leaf_buffers_);
1966 }
1967 return clone;
1968 }
1969
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64 index)1970 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
1971 const Shape& shape, HloInstruction* operand, int64 index)
1972 : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
1973 AppendOperand(operand);
1974 }
1975
ToProto() const1976 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
1977 HloInstructionProto proto = HloInstruction::ToProto();
1978 proto.set_tuple_index(tuple_index_);
1979 return proto;
1980 }
1981
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1982 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
1983 const HloPrintOptions& options) const {
1984 return {StrCat("index=", tuple_index())};
1985 }
1986
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1987 bool HloGetTupleElementInstruction::IdenticalSlowPath(
1988 const HloInstruction& other,
1989 const std::function<bool(const HloComputation*, const HloComputation*)>&
1990 eq_computations) const {
1991 const auto& casted_other =
1992 static_cast<const HloGetTupleElementInstruction&>(other);
1993 return tuple_index() == casted_other.tuple_index();
1994 }
1995
1996 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1997 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
1998 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1999 HloCloneContext* context) const {
2000 CHECK_EQ(new_operands.size(), 1);
2001 return absl::make_unique<HloGetTupleElementInstruction>(
2002 shape, new_operands[0], tuple_index());
2003 }
2004
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)2005 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
2006 const Shape& shape, HloInstruction* operand, const int exponent_bits,
2007 const int mantissa_bits)
2008 : HloInstruction(HloOpcode::kReducePrecision, shape),
2009 exponent_bits_(exponent_bits),
2010 mantissa_bits_(mantissa_bits) {
2011 AppendOperand(operand);
2012 }
2013
ToProto() const2014 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
2015 HloInstructionProto proto = HloInstruction::ToProto();
2016 proto.set_exponent_bits(exponent_bits_);
2017 proto.set_mantissa_bits(mantissa_bits_);
2018 return proto;
2019 }
2020
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2021 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
2022 const HloPrintOptions& options) const {
2023 return {StrCat("exponent_bits=", exponent_bits_),
2024 StrCat("mantissa_bits=", mantissa_bits_)};
2025 }
2026
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2027 bool HloReducePrecisionInstruction::IdenticalSlowPath(
2028 const HloInstruction& other,
2029 const std::function<bool(const HloComputation*, const HloComputation*)>&
2030 eq_computations) const {
2031 const auto& casted_other =
2032 static_cast<const HloReducePrecisionInstruction&>(other);
2033 // A reduce-precision operation is determined by the bit sizes.
2034 return exponent_bits() == casted_other.exponent_bits() &&
2035 mantissa_bits() == casted_other.mantissa_bits();
2036 }
2037
2038 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2039 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
2040 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2041 HloCloneContext* context) const {
2042 CHECK_EQ(new_operands.size(), 1);
2043 return absl::make_unique<HloReducePrecisionInstruction>(
2044 shape, new_operands[0], exponent_bits(), mantissa_bits());
2045 }
2046
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)2047 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
2048 HloInstruction* token_operand,
2049 const string& config)
2050 : HloInstruction(HloOpcode::kInfeed,
2051 ShapeUtil::MakeTupleShape(
2052 {infeed_shape, ShapeUtil::MakeTokenShape()})),
2053 infeed_config_(config) {
2054 AppendOperand(token_operand);
2055 }
2056
ToProto() const2057 HloInstructionProto HloInfeedInstruction::ToProto() const {
2058 HloInstructionProto proto = HloInstruction::ToProto();
2059 proto.set_infeed_config(infeed_config_);
2060 return proto;
2061 }
2062
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2063 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
2064 const HloPrintOptions& options) const {
2065 if (infeed_config_.empty()) {
2066 return {};
2067 }
2068 return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
2069 }
2070
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2071 bool HloInfeedInstruction::IdenticalSlowPath(
2072 const HloInstruction& other,
2073 const std::function<bool(const HloComputation*, const HloComputation*)>&
2074 eq_computations) const {
2075 // Not yet supported.
2076 return false;
2077 }
2078
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2079 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
2080 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2081 HloCloneContext* context) const {
2082 CHECK_EQ(new_operands.size(), 1);
2083 return absl::make_unique<HloInfeedInstruction>(
2084 infeed_shape(), new_operands[0], infeed_config());
2085 }
2086
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)2087 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
2088 HloInstruction* operand,
2089 HloInstruction* token_operand,
2090 absl::string_view outfeed_config)
2091 : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
2092 outfeed_shape_(outfeed_shape),
2093 outfeed_config_(outfeed_config) {
2094 AppendOperand(operand);
2095 AppendOperand(token_operand);
2096 }
2097
ToProto() const2098 HloInstructionProto HloOutfeedInstruction::ToProto() const {
2099 HloInstructionProto proto = HloInstruction::ToProto();
2100 proto.set_outfeed_config(outfeed_config());
2101 *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
2102 return proto;
2103 }
2104
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2105 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
2106 const HloPrintOptions& options) const {
2107 std::vector<string> extra;
2108 extra.push_back(StrCat("outfeed_shape=",
2109 ShapeUtil::HumanStringWithLayout(outfeed_shape_)));
2110 if (!outfeed_config_.empty()) {
2111 extra.push_back(
2112 StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
2113 }
2114 return extra;
2115 }
2116
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2117 bool HloOutfeedInstruction::IdenticalSlowPath(
2118 const HloInstruction& other,
2119 const std::function<bool(const HloComputation*, const HloComputation*)>&
2120 eq_computations) const {
2121 // Not yet supported.
2122 return false;
2123 }
2124
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2125 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
2126 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2127 HloCloneContext* context) const {
2128 CHECK_EQ(new_operands.size(), 2);
2129 return absl::make_unique<HloOutfeedInstruction>(
2130 outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
2131 }
2132
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2133 HloConvolutionInstruction::HloConvolutionInstruction(
2134 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2135 int64 feature_group_count, int64 batch_group_count, const Window& window,
2136 const ConvolutionDimensionNumbers& dimension_numbers,
2137 const PrecisionConfig& precision_config)
2138 : HloInstruction(HloOpcode::kConvolution, shape),
2139 feature_group_count_(feature_group_count),
2140 batch_group_count_(batch_group_count),
2141 window_(window),
2142 convolution_dimension_numbers_(dimension_numbers),
2143 precision_config_(precision_config) {
2144 if (window_util::HasBaseDilation(window)) {
2145 SetAndSanitizeName(StrCat(name(), "-base-dilated"));
2146 }
2147 if (window_util::HasWindowDilation(window)) {
2148 SetAndSanitizeName(StrCat(name(), "-window-dilated"));
2149 }
2150 AppendOperand(lhs);
2151 AppendOperand(rhs);
2152 }
2153
ToCategory() const2154 string HloConvolutionInstruction::ToCategory() const {
2155 string category = "convolution";
2156 if (window_util::HasBaseDilation(window())) {
2157 category += " base-dilated";
2158 }
2159 if (window_util::HasWindowDilation(window())) {
2160 category += " window-dilated";
2161 }
2162 return category;
2163 }
2164
ToProto() const2165 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2166 HloInstructionProto proto = HloInstruction::ToProto();
2167 *proto.mutable_window() = window_;
2168 *proto.mutable_convolution_dimension_numbers() =
2169 convolution_dimension_numbers_;
2170 proto.set_feature_group_count(feature_group_count_);
2171 proto.set_batch_group_count(batch_group_count_);
2172 *proto.mutable_precision_config() = precision_config_;
2173 return proto;
2174 }
2175
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2176 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2177 const HloPrintOptions& options) const {
2178 std::vector<string> extra;
2179 if (window_.dimensions_size() != 0) {
2180 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2181 }
2182 extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2183 convolution_dimension_numbers_)));
2184 if (feature_group_count_ != 1) {
2185 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2186 }
2187
2188 if (batch_group_count_ != 1) {
2189 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2190 }
2191
2192 string precision_config_string = PrecisionConfigToString(precision_config_);
2193 if (!precision_config_string.empty()) {
2194 extra.push_back(precision_config_string);
2195 }
2196 return extra;
2197 }
2198
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2199 bool HloConvolutionInstruction::IdenticalSlowPath(
2200 const HloInstruction& other,
2201 const std::function<bool(const HloComputation*, const HloComputation*)>&
2202 eq_computations) const {
2203 const auto& casted_other =
2204 static_cast<const HloConvolutionInstruction&>(other);
2205 if (feature_group_count_ != other.feature_group_count()) {
2206 return false;
2207 }
2208 if (batch_group_count_ != other.batch_group_count()) {
2209 return false;
2210 }
2211 return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2212 protobuf_util::ProtobufEquals(
2213 convolution_dimension_numbers(),
2214 casted_other.convolution_dimension_numbers()) &&
2215 protobuf_util::ProtobufEquals(precision_config(),
2216 casted_other.precision_config());
2217 }
2218
2219 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2220 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2221 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2222 HloCloneContext* context) const {
2223 CHECK_EQ(new_operands.size(), 2);
2224 return absl::make_unique<HloConvolutionInstruction>(
2225 shape, new_operands[0], new_operands[1], feature_group_count_,
2226 batch_group_count_, window(), convolution_dimension_numbers_,
2227 precision_config_);
2228 }
2229
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2230 HloReduceWindowInstruction::HloReduceWindowInstruction(
2231 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2232 const Window& window, HloComputation* reduce_computation)
2233 : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
2234 absl::MakeSpan(&init_value, 1), window,
2235 reduce_computation) {}
2236
HloReduceWindowInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)2237 HloReduceWindowInstruction::HloReduceWindowInstruction(
2238 const Shape& shape, absl::Span<HloInstruction* const> operands,
2239 absl::Span<HloInstruction* const> init_values, const Window& window,
2240 HloComputation* reduce_computation)
2241 : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2242 for (auto* operand : operands) {
2243 AppendOperand(operand);
2244 }
2245 for (auto* init_value : init_values) {
2246 AppendOperand(init_value);
2247 }
2248 AppendComputation(reduce_computation);
2249 }
2250
ToProto() const2251 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2252 HloInstructionProto proto = HloInstruction::ToProto();
2253 *proto.mutable_window() = window_;
2254 return proto;
2255 }
2256
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2257 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2258 const HloPrintOptions& options) const {
2259 std::vector<string> extra;
2260 if (window_.dimensions_size() != 0) {
2261 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2262 }
2263 return extra;
2264 }
2265
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2266 bool HloReduceWindowInstruction::IdenticalSlowPath(
2267 const HloInstruction& other,
2268 const std::function<bool(const HloComputation*, const HloComputation*)>&
2269 eq_computations) const {
2270 const auto& casted_other =
2271 static_cast<const HloReduceWindowInstruction&>(other);
2272 return eq_computations(to_apply(), casted_other.to_apply()) &&
2273 protobuf_util::ProtobufEquals(window(), casted_other.window());
2274 }
2275
2276 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2277 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2278 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2279 HloCloneContext* context) const {
2280 CHECK_EQ(new_operands.size() % 2, 0);
2281 int64 num_operands = new_operands.size() / 2;
2282 return absl::make_unique<HloReduceWindowInstruction>(
2283 shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
2284 absl::MakeSpan(new_operands)
2285 .subspan(num_operands, new_operands.size() / 2),
2286 window(), to_apply());
2287 }
2288
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2289 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2290 const Shape& shape, HloInstruction* operand, HloComputation* select,
2291 const Window& window, HloInstruction* source, HloInstruction* init_value,
2292 HloComputation* scatter)
2293 : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2294 AppendOperand(operand);
2295 AppendOperand(source);
2296 AppendOperand(init_value);
2297 // Select comes before scatter in the vector.
2298 AppendComputation(select);
2299 AppendComputation(scatter);
2300 }
2301
ToProto() const2302 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2303 HloInstructionProto proto = HloInstruction::ToProto();
2304 *proto.mutable_window() = window_;
2305 return proto;
2306 }
2307
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2308 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2309 const HloPrintOptions& options) const {
2310 std::vector<string> extra;
2311 if (window_.dimensions_size() != 0) {
2312 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2313 }
2314 return extra;
2315 }
2316
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2317 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2318 const HloInstruction& other,
2319 const std::function<bool(const HloComputation*, const HloComputation*)>&
2320 eq_computations) const {
2321 const auto& casted_other =
2322 static_cast<const HloSelectAndScatterInstruction&>(other);
2323 return eq_computations(select(), casted_other.select()) &&
2324 eq_computations(scatter(), casted_other.scatter()) &&
2325 protobuf_util::ProtobufEquals(window(), casted_other.window());
2326 }
2327
2328 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2329 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2330 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2331 HloCloneContext* context) const {
2332 CHECK_EQ(new_operands.size(), 3);
2333 return absl::make_unique<HloSelectAndScatterInstruction>(
2334 shape, new_operands[0], select(), window(), new_operands[1],
2335 new_operands[2], scatter());
2336 }
2337
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque)2338 HloCustomCallInstruction::HloCustomCallInstruction(
2339 const Shape& shape, absl::Span<HloInstruction* const> operands,
2340 absl::string_view custom_call_target, string opaque)
2341 : HloInstruction(HloOpcode::kCustomCall, shape),
2342 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2343 feature_group_count_(1),
2344 batch_group_count_(1),
2345 layout_constrained_(false),
2346 padding_type_(PaddingType::PADDING_INVALID),
2347 custom_call_has_side_effect_(false) {
2348 set_raw_backend_config_string(std::move(opaque));
2349 for (auto operand : operands) {
2350 AppendOperand(operand);
2351 }
2352 }
2353
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque)2354 HloCustomCallInstruction::HloCustomCallInstruction(
2355 const Shape& shape, absl::Span<HloInstruction* const> operands,
2356 HloComputation* to_apply, absl::string_view custom_call_target,
2357 string opaque)
2358 : HloInstruction(HloOpcode::kCustomCall, shape),
2359 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2360 feature_group_count_(1),
2361 batch_group_count_(1),
2362 layout_constrained_(false),
2363 padding_type_(PaddingType::PADDING_INVALID),
2364 custom_call_has_side_effect_(false) {
2365 set_raw_backend_config_string(std::move(opaque));
2366 for (auto operand : operands) {
2367 AppendOperand(operand);
2368 }
2369 AppendComputation(to_apply);
2370 }
2371
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque)2372 HloCustomCallInstruction::HloCustomCallInstruction(
2373 const Shape& shape, absl::Span<HloInstruction* const> operands,
2374 absl::Span<HloComputation* const> called_computations,
2375 absl::string_view custom_call_target, string opaque)
2376 : HloInstruction(HloOpcode::kCustomCall, shape),
2377 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2378 feature_group_count_(1),
2379 batch_group_count_(1),
2380 layout_constrained_(false),
2381 padding_type_(PaddingType::PADDING_INVALID),
2382 custom_call_has_side_effect_(false) {
2383 set_raw_backend_config_string(std::move(opaque));
2384 for (auto operand : operands) {
2385 AppendOperand(operand);
2386 }
2387 for (auto comp : called_computations) {
2388 AppendComputation(comp);
2389 }
2390 }
2391
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,absl::Span<const Shape> operand_shapes_with_layout)2392 HloCustomCallInstruction::HloCustomCallInstruction(
2393 const Shape& shape, absl::Span<HloInstruction* const> operands,
2394 absl::string_view custom_call_target, string opaque,
2395 absl::Span<const Shape> operand_shapes_with_layout)
2396 : HloInstruction(HloOpcode::kCustomCall, shape),
2397 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2398 feature_group_count_(1),
2399 batch_group_count_(1),
2400 layout_constrained_(true),
2401 padding_type_(PaddingType::PADDING_INVALID),
2402 operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2403 operand_shapes_with_layout.end()),
2404 custom_call_has_side_effect_(false) {
2405 set_raw_backend_config_string(std::move(opaque));
2406 for (auto operand : operands) {
2407 AppendOperand(operand);
2408 }
2409 }
2410
ToProto() const2411 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2412 HloInstructionProto proto = HloInstruction::ToProto();
2413 if (window_ != nullptr) {
2414 *proto.mutable_window() = *window_;
2415 }
2416 if (convolution_dimension_numbers_ != nullptr) {
2417 *proto.mutable_convolution_dimension_numbers() =
2418 *convolution_dimension_numbers_;
2419 }
2420 proto.set_custom_call_target(custom_call_target_);
2421 proto.set_feature_group_count(feature_group_count_);
2422 proto.set_batch_group_count(batch_group_count_);
2423 *proto.mutable_precision_config() = precision_config_;
2424 proto.set_padding_type(padding_type_);
2425 if (layout_constrained()) {
2426 proto.set_constrain_layout(true);
2427 for (const Shape& shape : operand_shapes_with_layout_) {
2428 *proto.add_operand_shapes_with_layout() = shape.ToProto();
2429 }
2430 }
2431 proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2432 if (literal_.has_value()) {
2433 *proto.mutable_literal() = literal_->ToProto();
2434 }
2435 for (const auto& pair : output_to_operand_aliasing_) {
2436 auto aliasing = proto.add_custom_call_output_operand_aliasing();
2437 aliasing->set_operand_index(pair.second.first);
2438 for (int64 index : pair.first) {
2439 aliasing->add_output_shape_index(index);
2440 }
2441 for (int64 index : pair.second.second) {
2442 aliasing->add_operand_shape_index(index);
2443 }
2444 }
2445 return proto;
2446 }
2447
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2448 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2449 const HloPrintOptions& options) const {
2450 std::vector<string> extra;
2451 if (window_ != nullptr) {
2452 extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2453 }
2454 if (convolution_dimension_numbers_ != nullptr) {
2455 extra.push_back(StrCat(
2456 "dim_labels=",
2457 ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2458 }
2459 if (feature_group_count_ != 1) {
2460 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2461 }
2462 if (batch_group_count_ != 1) {
2463 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2464 }
2465 string precision_config_string = PrecisionConfigToString(precision_config_);
2466 if (!precision_config_string.empty()) {
2467 extra.push_back(precision_config_string);
2468 }
2469 if (padding_type_ != PaddingType::PADDING_INVALID) {
2470 extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
2471 }
2472 // By contract, we print the custom call target even if
2473 // options.print_subcomputation_mode() == kOff, because the call target is not
2474 // an HloComputation.
2475 extra.push_back(
2476 StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2477
2478 if (layout_constrained()) {
2479 std::vector<string> shape_strings;
2480 for (const Shape& shape : operand_shapes_with_layout_) {
2481 shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2482 }
2483 extra.push_back(StrCat("operand_layout_constraints={",
2484 StrJoin(shape_strings, ", "), "}"));
2485 }
2486 if (custom_call_has_side_effect_) {
2487 extra.push_back("custom_call_has_side_effect=true");
2488 }
2489 if (literal_.has_value()) {
2490 extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
2491 }
2492 if (!output_to_operand_aliasing_.empty()) {
2493 std::vector<string> pair_strings;
2494 for (const auto& pair : output_to_operand_aliasing_) {
2495 pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
2496 pair.second.first, ", ",
2497 pair.second.second.ToString(), ")"));
2498 }
2499 extra.push_back(StrCat("output_to_operand_aliasing={",
2500 StrJoin(pair_strings, ", "), "}"));
2501 }
2502 return extra;
2503 }
2504
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2505 bool HloCustomCallInstruction::IdenticalSlowPath(
2506 const HloInstruction& other,
2507 const std::function<bool(const HloComputation*, const HloComputation*)>&
2508 eq_computations) const {
2509 const auto& casted_other =
2510 static_cast<const HloCustomCallInstruction&>(other);
2511 if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2512 (window_ != nullptr &&
2513 !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2514 return false;
2515 }
2516 if ((convolution_dimension_numbers_ == nullptr) !=
2517 (casted_other.convolution_dimension_numbers_ == nullptr) ||
2518 (convolution_dimension_numbers_ != nullptr &&
2519 !protobuf_util::ProtobufEquals(
2520 convolution_dimension_numbers(),
2521 casted_other.convolution_dimension_numbers()))) {
2522 return false;
2523 }
2524 if (feature_group_count_ != casted_other.feature_group_count_) {
2525 return false;
2526 }
2527 if (batch_group_count_ != casted_other.batch_group_count_) {
2528 return false;
2529 }
2530
2531 if (padding_type_ != casted_other.padding_type()) {
2532 return false;
2533 }
2534
2535 if (layout_constrained() != casted_other.layout_constrained()) {
2536 return false;
2537 }
2538 if (layout_constrained()) {
2539 for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2540 if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2541 casted_other.operand_shapes_with_layout_[i])) {
2542 return false;
2543 }
2544 }
2545 }
2546 if (custom_call_has_side_effect_ !=
2547 casted_other.custom_call_has_side_effect()) {
2548 return false;
2549 }
2550 if (output_to_operand_aliasing_ !=
2551 casted_other.output_to_operand_aliasing()) {
2552 return false;
2553 }
2554 if (!protobuf_util::ProtobufEquals(precision_config(),
2555 casted_other.precision_config())) {
2556 return false;
2557 }
2558
2559 if (called_computations().size() != other.called_computations().size()) {
2560 return false;
2561 }
2562 for (int64 i = 0; i < called_computations().size(); ++i) {
2563 if (!eq_computations(called_computations()[i],
2564 other.called_computations()[i])) {
2565 return false;
2566 }
2567 }
2568 if (HasLiteral() == casted_other.HasLiteral()) {
2569 if (HasLiteral() && literal() == casted_other.literal()) {
2570 return false;
2571 }
2572 } else {
2573 return true;
2574 }
2575
2576 // Note: backend_config comparison is done in Identical, which is the
2577 // intended/exposed way to compare computations, and so not repeated here.
2578 return custom_call_target_ == casted_other.custom_call_target_;
2579 }
2580
2581 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2582 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2583 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2584 HloCloneContext* context) const {
2585 auto cloned = absl::make_unique<HloCustomCallInstruction>(
2586 shape, new_operands, custom_call_target(), opaque());
2587 if (layout_constrained()) {
2588 cloned->layout_constrained_ = true;
2589 cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2590 }
2591 if (window_ != nullptr) {
2592 cloned->set_window(*window_);
2593 }
2594 if (convolution_dimension_numbers_ != nullptr) {
2595 cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2596 }
2597 if (HasLiteral()) {
2598 cloned->set_literal(literal().Clone());
2599 }
2600 cloned->set_feature_group_count(feature_group_count_);
2601 cloned->set_batch_group_count(batch_group_count_);
2602 cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2603 cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
2604 cloned->set_padding_type(padding_type_);
2605 *cloned->mutable_precision_config() = precision_config();
2606 return std::move(cloned);
2607 }
2608
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2609 HloPadInstruction::HloPadInstruction(const Shape& shape,
2610 HloInstruction* operand,
2611 HloInstruction* padding_value,
2612 const PaddingConfig& padding_config)
2613 : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2614 AppendOperand(operand);
2615 AppendOperand(padding_value);
2616 }
2617
ToProto() const2618 HloInstructionProto HloPadInstruction::ToProto() const {
2619 HloInstructionProto proto = HloInstruction::ToProto();
2620 *proto.mutable_padding_config() = padding_config_;
2621 return proto;
2622 }
2623
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2624 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2625 const HloPrintOptions& options) const {
2626 return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2627 }
2628
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2629 bool HloPadInstruction::IdenticalSlowPath(
2630 const HloInstruction& other,
2631 const std::function<bool(const HloComputation*, const HloComputation*)>&
2632 eq_computations) const {
2633 const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2634 return protobuf_util::ProtobufEquals(padding_config(),
2635 casted_other.padding_config());
2636 }
2637
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2638 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2639 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2640 HloCloneContext* context) const {
2641 CHECK_EQ(new_operands.size(), 2);
2642 return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2643 new_operands[1], padding_config_);
2644 }
2645
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2646 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2647 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2648 absl::Span<const int64> slice_sizes)
2649 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2650 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2651 AppendOperand(operand);
2652 AppendOperand(start_indices);
2653 }
2654
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2655 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2656 const Shape& shape, HloInstruction* operand,
2657 absl::Span<HloInstruction* const> start_indices,
2658 absl::Span<const int64> slice_sizes)
2659 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2660 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2661 AppendOperand(operand);
2662 for (HloInstruction* index : start_indices) {
2663 AppendOperand(index);
2664 }
2665 }
2666
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2667 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2668 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2669 HloInstruction* start_indices)
2670 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2671 AppendOperand(operand);
2672 AppendOperand(update);
2673 AppendOperand(start_indices);
2674 }
2675
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2676 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2677 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2678 absl::Span<HloInstruction* const> start_indices)
2679 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2680 AppendOperand(operand);
2681 AppendOperand(update);
2682 for (HloInstruction* index : start_indices) {
2683 AppendOperand(index);
2684 }
2685 }
2686
ToProto() const2687 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2688 HloInstructionProto proto = HloInstruction::ToProto();
2689 for (int64 slice_size : dynamic_slice_sizes_) {
2690 proto.add_dynamic_slice_sizes(slice_size);
2691 }
2692 return proto;
2693 }
2694
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2695 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2696 const HloPrintOptions& options) const {
2697 return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2698 "}")};
2699 }
2700
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2701 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2702 const HloInstruction& other,
2703 const std::function<bool(const HloComputation*, const HloComputation*)>&
2704 eq_computations) const {
2705 const auto& casted_other = static_cast<const HloMapInstruction&>(other);
2706 return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
2707 }
2708
2709 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2710 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2711 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2712 HloCloneContext* context) const {
2713 if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2714 // TODO(b/118437727): Old form, remove this path.
2715 return absl::make_unique<HloDynamicSliceInstruction>(
2716 shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2717 } else {
2718 return absl::make_unique<HloDynamicSliceInstruction>(
2719 shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2720 }
2721 }
2722
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)2723 HloGatherInstruction::HloGatherInstruction(
2724 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2725 const GatherDimensionNumbers& gather_dim_numbers,
2726 absl::Span<const int64> slice_sizes, bool indices_are_sorted)
2727 : HloInstruction(HloOpcode::kGather, shape),
2728 indices_are_sorted_(indices_are_sorted) {
2729 AppendOperand(operand);
2730 AppendOperand(start_indices);
2731 gather_dimension_numbers_ =
2732 absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2733 absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2734 }
2735
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)2736 /*static*/ string HloGatherInstruction::GatherDimensionNumbersToString(
2737 const GatherDimensionNumbers& gather_dimension_numbers) {
2738 string offset_dims =
2739 StrCat("offset_dims={",
2740 StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
2741 string collapsed_slice_dims = StrCat(
2742 "collapsed_slice_dims={",
2743 StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
2744 string start_index_map =
2745 StrCat("start_index_map={",
2746 StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
2747 string index_vector_dim =
2748 StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
2749
2750 return StrJoin<std::initializer_list<string>>(
2751 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2752 ", ");
2753 }
2754
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64 index_vector_dim)2755 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2756 absl::Span<const int64> offset_dims,
2757 absl::Span<const int64> collapsed_slice_dims,
2758 absl::Span<const int64> start_index_map, int64 index_vector_dim) {
2759 GatherDimensionNumbers gather_dim_numbers;
2760 for (int64 output_window_dim : offset_dims) {
2761 gather_dim_numbers.add_offset_dims(output_window_dim);
2762 }
2763 for (int64 elided_window_dim : collapsed_slice_dims) {
2764 gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2765 }
2766 for (int64 gather_dim_to_input_dim : start_index_map) {
2767 gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2768 }
2769
2770 gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2771 return gather_dim_numbers;
2772 }
2773
ToProto() const2774 HloInstructionProto HloGatherInstruction::ToProto() const {
2775 HloInstructionProto proto = HloInstruction::ToProto();
2776 *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2777 for (int64 bound : gather_slice_sizes()) {
2778 proto.add_gather_slice_sizes(bound);
2779 }
2780 proto.set_indices_are_sorted(indices_are_sorted());
2781 return proto;
2782 }
2783
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2784 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2785 const HloPrintOptions& options) const {
2786 std::vector<string> attrs{
2787 GatherDimensionNumbersToString(gather_dimension_numbers()),
2788 StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2789 if (indices_are_sorted()) {
2790 attrs.push_back("indices_are_sorted=true");
2791 }
2792 return attrs;
2793 }
2794
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2795 bool HloGatherInstruction::IdenticalSlowPath(
2796 const HloInstruction& other,
2797 const std::function<bool(const HloComputation*, const HloComputation*)>&
2798 eq_computations) const {
2799 const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2800 return protobuf_util::ProtobufEquals(
2801 gather_dimension_numbers(),
2802 casted_other.gather_dimension_numbers()) &&
2803 gather_slice_sizes() == casted_other.gather_slice_sizes() &&
2804 indices_are_sorted() == casted_other.indices_are_sorted();
2805 }
2806
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2807 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2808 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2809 HloCloneContext* context) const {
2810 CHECK_EQ(new_operands.size(), 2);
2811 return absl::make_unique<HloGatherInstruction>(
2812 shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2813 gather_slice_sizes(), indices_are_sorted());
2814 }
2815
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)2816 HloScatterInstruction::HloScatterInstruction(
2817 const Shape& shape, HloInstruction* operand,
2818 HloInstruction* scatter_indices, HloInstruction* updates,
2819 HloComputation* update_computation,
2820 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
2821 bool unique_indices)
2822 : HloInstruction(HloOpcode::kScatter, shape),
2823 indices_are_sorted_(indices_are_sorted),
2824 unique_indices_(unique_indices) {
2825 AppendOperand(operand);
2826 AppendOperand(scatter_indices);
2827 AppendOperand(updates);
2828 AppendComputation(update_computation);
2829 scatter_dimension_numbers_ =
2830 absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2831 }
2832
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)2833 /*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString(
2834 const ScatterDimensionNumbers& scatter_dimension_numbers) {
2835 string update_window_dims =
2836 StrCat("update_window_dims={",
2837 StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
2838 string inserted_window_dims = StrCat(
2839 "inserted_window_dims={",
2840 StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
2841 string scatter_dims_to_operand_dims = StrCat(
2842 "scatter_dims_to_operand_dims={",
2843 StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
2844 "}");
2845 string index_vector_dim =
2846 StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
2847
2848 return StrJoin<std::initializer_list<string>>(
2849 {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
2850 index_vector_dim},
2851 ", ");
2852 }
2853
2854 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64 index_vector_dim)2855 HloScatterInstruction::MakeScatterDimNumbers(
2856 absl::Span<const int64> update_window_dims,
2857 absl::Span<const int64> inserted_window_dims,
2858 absl::Span<const int64> scatter_dims_to_operand_dims,
2859 int64 index_vector_dim) {
2860 ScatterDimensionNumbers scatter_dim_numbers;
2861 for (int64 update_window_dim : update_window_dims) {
2862 scatter_dim_numbers.add_update_window_dims(update_window_dim);
2863 }
2864 for (int64 inserted_window_dim : inserted_window_dims) {
2865 scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
2866 }
2867 for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
2868 scatter_dim_numbers.add_scatter_dims_to_operand_dims(
2869 scatter_dim_to_operand_dim);
2870 }
2871 scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
2872 return scatter_dim_numbers;
2873 }
2874
ToProto() const2875 HloInstructionProto HloScatterInstruction::ToProto() const {
2876 HloInstructionProto proto = HloInstruction::ToProto();
2877 *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
2878 proto.set_indices_are_sorted(indices_are_sorted());
2879 proto.set_unique_indices(unique_indices());
2880 return proto;
2881 }
2882
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2883 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
2884 const HloPrintOptions& options) const {
2885 std::vector<string> attrs{
2886 ScatterDimensionNumbersToString(scatter_dimension_numbers())};
2887 if (indices_are_sorted()) {
2888 attrs.push_back("indices_are_sorted=true");
2889 }
2890 if (unique_indices()) {
2891 attrs.push_back("unique_indices=true");
2892 }
2893 return attrs;
2894 }
2895
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2896 bool HloScatterInstruction::IdenticalSlowPath(
2897 const HloInstruction& other,
2898 const std::function<bool(const HloComputation*, const HloComputation*)>&
2899 eq_computations) const {
2900 const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
2901 return protobuf_util::ProtobufEquals(
2902 scatter_dimension_numbers(),
2903 casted_other.scatter_dimension_numbers()) &&
2904 eq_computations(to_apply(), casted_other.to_apply()) &&
2905 indices_are_sorted() == casted_other.indices_are_sorted() &&
2906 unique_indices() == casted_other.unique_indices();
2907 }
2908
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2909 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
2910 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2911 HloCloneContext* context) const {
2912 CHECK_EQ(new_operands.size(), 3);
2913 return absl::make_unique<HloScatterInstruction>(
2914 shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
2915 scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
2916 }
2917
HloIotaInstruction(const Shape & shape,int64 iota_dimension)2918 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
2919 : HloInstruction(HloOpcode::kIota, shape),
2920 iota_dimension_(iota_dimension) {}
2921
ToProto() const2922 HloInstructionProto HloIotaInstruction::ToProto() const {
2923 HloInstructionProto proto = HloInstruction::ToProto();
2924 proto.add_dimensions(iota_dimension());
2925 return proto;
2926 }
2927
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2928 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
2929 const HloPrintOptions& options) const {
2930 return {StrCat("iota_dimension=", iota_dimension())};
2931 }
2932
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2933 bool HloIotaInstruction::IdenticalSlowPath(
2934 const HloInstruction& other,
2935 const std::function<bool(const HloComputation*, const HloComputation*)>&
2936 eq_computations) const {
2937 const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
2938 return iota_dimension() == casted_other.iota_dimension();
2939 }
2940
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2941 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
2942 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2943 HloCloneContext* context) const {
2944 return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
2945 }
2946
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2947 HloDotInstruction::HloDotInstruction(
2948 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2949 const DotDimensionNumbers& dimension_numbers,
2950 const PrecisionConfig& precision_config)
2951 : HloInstruction(HloOpcode::kDot, shape),
2952 dot_dimension_numbers_(dimension_numbers),
2953 precision_config_(precision_config) {
2954 AppendOperand(lhs);
2955 AppendOperand(rhs);
2956 }
2957
ToProto() const2958 HloInstructionProto HloDotInstruction::ToProto() const {
2959 HloInstructionProto proto = HloInstruction::ToProto();
2960 *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
2961 *proto.mutable_precision_config() = precision_config_;
2962 return proto;
2963 }
2964
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2965 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
2966 const HloPrintOptions& options) const {
2967 std::vector<string> extra = {DotDimensionNumbersToString()};
2968
2969 string precision_config_string = PrecisionConfigToString(precision_config_);
2970 if (!precision_config_string.empty()) {
2971 extra.push_back(precision_config_string);
2972 }
2973 return extra;
2974 }
2975
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2976 bool HloDotInstruction::IdenticalSlowPath(
2977 const HloInstruction& other,
2978 const std::function<bool(const HloComputation*, const HloComputation*)>&
2979 eq_computations) const {
2980 const auto& casted_other = static_cast<const HloDotInstruction&>(other);
2981 return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
2982 casted_other.dot_dimension_numbers()) &&
2983 protobuf_util::ProtobufEquals(precision_config(),
2984 casted_other.precision_config());
2985 }
2986
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2987 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
2988 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2989 HloCloneContext* context) const {
2990 CHECK_EQ(new_operands.size(), 2);
2991 return absl::make_unique<HloDotInstruction>(
2992 shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
2993 precision_config_);
2994 }
2995
DotDimensionNumbersToString() const2996 string HloDotInstruction::DotDimensionNumbersToString() const {
2997 std::vector<string> result;
2998 const DotDimensionNumbers& dnums = dot_dimension_numbers_;
2999 if (!dnums.lhs_batch_dimensions().empty()) {
3000 result.push_back(StrCat("lhs_batch_dims={",
3001 StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
3002 }
3003 result.push_back(StrCat("lhs_contracting_dims={",
3004 StrJoin(dnums.lhs_contracting_dimensions(), ","),
3005 "}"));
3006
3007 if (!dnums.rhs_batch_dimensions().empty()) {
3008 result.push_back(StrCat("rhs_batch_dims={",
3009 StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
3010 }
3011 result.push_back(StrCat("rhs_contracting_dims={",
3012 StrJoin(dnums.rhs_contracting_dimensions(), ","),
3013 "}"));
3014
3015 return StrJoin(result, ", ");
3016 }
3017
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)3018 HloDomainInstruction::HloDomainInstruction(
3019 const Shape& shape, HloInstruction* operand,
3020 std::unique_ptr<DomainMetadata> operand_side_metadata,
3021 std::unique_ptr<DomainMetadata> user_side_metadata)
3022 : HloInstruction(HloOpcode::kDomain, shape),
3023 operand_side_metadata_(std::move(operand_side_metadata)),
3024 user_side_metadata_(std::move(user_side_metadata)) {
3025 AppendOperand(operand);
3026 }
3027
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3028 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
3029 const HloPrintOptions& options) const {
3030 if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
3031 return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
3032 "\", entry=", user_side_metadata_->ToString(),
3033 ", exit=", operand_side_metadata_->ToString(), "}")};
3034 }
3035 return {};
3036 }
3037
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3038 bool HloDomainInstruction::IdenticalSlowPath(
3039 const HloInstruction& other,
3040 const std::function<bool(const HloComputation*, const HloComputation*)>&
3041 eq_computations) const {
3042 const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
3043 return operand_side_metadata().Matches(
3044 casted_other.operand_side_metadata()) &&
3045 user_side_metadata().Matches(casted_other.user_side_metadata());
3046 }
3047
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3048 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
3049 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3050 HloCloneContext* context) const {
3051 CHECK_EQ(new_operands.size(), 1);
3052 return absl::make_unique<HloDomainInstruction>(
3053 shape, new_operands[0], operand_side_metadata_->Clone(),
3054 user_side_metadata_->Clone());
3055 }
3056
ToProto() const3057 HloInstructionProto HloDomainInstruction::ToProto() const {
3058 HloInstructionProto proto = HloInstruction::ToProto();
3059 auto operand_side_sharding =
3060 dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
3061 if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
3062 *proto.mutable_domain_entry_sharding() =
3063 operand_side_sharding->sharding()->ToProto();
3064 }
3065
3066 auto user_side_sharding =
3067 dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
3068 if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
3069 *proto.mutable_domain_exit_sharding() =
3070 user_side_sharding->sharding()->ToProto();
3071 }
3072
3073 return proto;
3074 }
3075
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64 dimension)3076 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
3077 const Shape& shape, HloInstruction* operand, int64 dimension)
3078 : HloInstruction(HloOpcode::kGetDimensionSize, shape),
3079 dimension_(dimension) {
3080 AppendOperand(operand);
3081 }
3082
ToProto() const3083 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
3084 HloInstructionProto proto = HloInstruction::ToProto();
3085 proto.add_dimensions(dimension());
3086 return proto;
3087 }
3088
ExtraAttributesToStringImpl(const HloPrintOptions &) const3089 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3090 const HloPrintOptions& /*options*/) const {
3091 return {StrCat("dimensions={", dimension(), "}")};
3092 }
3093
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3094 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
3095 const HloInstruction& other,
3096 const std::function<bool(const HloComputation*, const HloComputation*)>&
3097 /*eq_computations*/) const {
3098 const auto& casted_other =
3099 static_cast<const HloGetDimensionSizeInstruction&>(other);
3100 return dimension() == casted_other.dimension();
3101 }
3102
3103 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3104 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3105 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3106 HloCloneContext* /*context*/) const {
3107 if (new_operands.size() != 1) {
3108 LOG(FATAL) << "expects 1 operand";
3109 }
3110 return absl::make_unique<HloGetDimensionSizeInstruction>(
3111 shape, new_operands[0], dimension());
3112 }
3113
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64 dimension)3114 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
3115 const Shape& shape, HloInstruction* operand, HloInstruction* val,
3116 int64 dimension)
3117 : HloInstruction(HloOpcode::kSetDimensionSize, shape),
3118 dimension_(dimension) {
3119 AppendOperand(operand);
3120 AppendOperand(val);
3121 }
3122
ExtraAttributesToStringImpl(const HloPrintOptions &) const3123 std::vector<string> HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3124 const HloPrintOptions& /*options*/) const {
3125 return {StrCat("dimensions={", dimension(), "}")};
3126 }
3127
ToProto() const3128 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
3129 HloInstructionProto proto = HloInstruction::ToProto();
3130 proto.add_dimensions(dimension());
3131 return proto;
3132 }
3133
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3134 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
3135 const HloInstruction& other,
3136 const std::function<bool(const HloComputation*, const HloComputation*)>&
3137 /*eq_computations*/) const {
3138 const auto& casted_other =
3139 static_cast<const HloSetDimensionSizeInstruction&>(other);
3140 return dimension() == casted_other.dimension();
3141 }
3142
3143 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3144 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3145 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3146 HloCloneContext* /*context*/) const {
3147 if (new_operands.size() != 2) {
3148 LOG(FATAL) << "expects 2 operand";
3149 }
3150 return absl::make_unique<HloSetDimensionSizeInstruction>(
3151 shape, new_operands[0], new_operands[1], dimension());
3152 }
3153
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64 delta)3154 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
3155 const Shape& shape, int64 delta)
3156 : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
3157
ToProto() const3158 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
3159 HloInstructionProto proto = HloInstruction::ToProto();
3160 proto.set_delta(delta_);
3161 return proto;
3162 }
3163
3164 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3165 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
3166 const HloPrintOptions& /*options*/) const {
3167 return {StrCat("delta=", delta())};
3168 }
3169
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3170 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
3171 const HloInstruction& other,
3172 const std::function<bool(const HloComputation*, const HloComputation*)>&
3173 /*eq_computations*/) const {
3174 const auto& casted_other =
3175 static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
3176 return delta() == casted_other.delta();
3177 }
3178
3179 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3180 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
3181 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3182 HloCloneContext* /*context*/) const {
3183 if (!new_operands.empty()) {
3184 LOG(FATAL) << "expects 0 operand";
3185 }
3186 return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
3187 }
3188
HloRngBitGeneratorInstruction(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)3189 HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
3190 const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
3191 : HloInstruction(HloOpcode::kRngBitGenerator, shape),
3192 algorithm_(algorithm) {
3193 AppendOperand(state);
3194 }
3195
ToProto() const3196 HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
3197 HloInstructionProto proto = HloInstruction::ToProto();
3198 proto.set_rng_algorithm(algorithm_);
3199 return proto;
3200 }
3201
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3202 std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
3203 const HloPrintOptions& options) const {
3204 return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
3205 }
3206
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3207 bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
3208 const HloInstruction& other,
3209 const std::function<bool(const HloComputation*, const HloComputation*)>&
3210 eq_computations) const {
3211 const auto& casted_other =
3212 static_cast<const HloRngBitGeneratorInstruction&>(other);
3213 return algorithm() == casted_other.algorithm();
3214 }
3215
3216 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3217 HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
3218 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3219 HloCloneContext* /*context*/) const {
3220 CHECK_EQ(new_operands.size(), 1);
3221 return absl::make_unique<HloRngBitGeneratorInstruction>(
3222 shape, new_operands[0], algorithm());
3223 }
3224
3225 } // namespace xla
3226