1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17
18 #include <deque>
19 #include <string>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_split.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
31 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_module.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/protobuf.h"
39
40 namespace xla {
41 namespace {
42
43 using absl::CEscape;
44 using absl::StrAppend;
45 using absl::StrCat;
46 using absl::StrJoin;
47
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)48 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
49 const HloInstruction* operand) {
50 const auto operand_indices = instruction->OperandIndices(operand);
51 return absl::c_all_of(operand_indices, [instruction](int64_t operand_index) {
52 return instruction->IsElementwiseOnOperand(operand_index);
53 });
54 }
55
PrecisionConfigToString(const PrecisionConfig & precision_config)56 string PrecisionConfigToString(const PrecisionConfig& precision_config) {
57 if (absl::c_all_of(
58 precision_config.operand_precision(), [](int32_t precision) {
59 return static_cast<PrecisionConfig::Precision>(precision) ==
60 PrecisionConfig::DEFAULT;
61 })) {
62 return "";
63 }
64
65 return StrCat(
66 "operand_precision={",
67 StrJoin(
68 precision_config.operand_precision(), ",",
69 [](string* out, int32_t precision) {
70 CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
71 StrAppend(out,
72 PrecisionToString(
73 static_cast<PrecisionConfig::Precision>(precision)));
74 }),
75 "}");
76 }
77 } // namespace
78
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64_t feature_index)79 HloBatchNormInstruction::HloBatchNormInstruction(
80 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
81 HloInstruction* scale, float epsilon, int64_t feature_index)
82 : HloInstruction(opcode, shape),
83 epsilon_(epsilon),
84 feature_index_(feature_index) {
85 AppendOperand(operand);
86 AppendOperand(scale);
87 }
88
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const89 bool HloBatchNormInstruction::IdenticalSlowPath(
90 const HloInstruction& other,
91 const std::function<bool(const HloComputation*, const HloComputation*)>&
92 eq_computations) const {
93 const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
94 return feature_index() == casted_other.feature_index() &&
95 epsilon() == casted_other.epsilon();
96 }
97
ToProto() const98 HloInstructionProto HloBatchNormInstruction::ToProto() const {
99 HloInstructionProto proto = HloInstruction::ToProto();
100 proto.set_epsilon(epsilon_);
101 proto.set_feature_index(feature_index_);
102 return proto;
103 }
104
ExtraAttributesToStringImpl(const HloPrintOptions & options) const105 std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
106 const HloPrintOptions& options) const {
107 return {StrCat("epsilon=", epsilon()),
108 StrCat("feature_index=", feature_index())};
109 }
110
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)111 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
112 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
113 HloInstruction* offset, float epsilon, int64_t feature_index)
114 : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
115 scale, epsilon, feature_index) {
116 AppendOperand(offset);
117 }
118
119 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const120 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
121 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
122 HloCloneContext* context) const {
123 CHECK_EQ(new_operands.size(), 3);
124 return absl::make_unique<HloBatchNormTrainingInstruction>(
125 shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
126 feature_index());
127 }
128
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)129 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
130 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
131 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
132 float epsilon, int64_t feature_index)
133 : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
134 scale, epsilon, feature_index) {
135 AppendOperand(offset);
136 AppendOperand(mean);
137 AppendOperand(variance);
138 }
139
140 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const141 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
142 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
143 HloCloneContext* context) const {
144 CHECK_EQ(new_operands.size(), 5);
145 return absl::make_unique<HloBatchNormInferenceInstruction>(
146 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
147 new_operands[4], epsilon(), feature_index());
148 }
149
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)150 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
151 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
152 HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
153 float epsilon, int64_t feature_index)
154 : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
155 epsilon, feature_index) {
156 AppendOperand(mean);
157 AppendOperand(variance);
158 AppendOperand(grad_output);
159 }
160
161 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const162 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
163 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
164 HloCloneContext* context) const {
165 CHECK_EQ(new_operands.size(), 5);
166 return absl::make_unique<HloBatchNormGradInstruction>(
167 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
168 new_operands[4], epsilon(), feature_index());
169 }
170
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)171 HloFftInstruction::HloFftInstruction(const Shape& shape,
172 HloInstruction* operand, FftType fft_type,
173 absl::Span<const int64> fft_length)
174 : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
175 fft_length_.assign(fft_length.begin(), fft_length.end());
176 AppendOperand(operand);
177 }
178
ToProto() const179 HloInstructionProto HloFftInstruction::ToProto() const {
180 HloInstructionProto proto = HloInstruction::ToProto();
181 proto.set_fft_type(fft_type_);
182 for (int64_t fft_len : fft_length_) {
183 proto.add_fft_length(fft_len);
184 }
185 return proto;
186 }
187
ExtraAttributesToStringImpl(const HloPrintOptions & options) const188 std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
189 const HloPrintOptions& options) const {
190 return {StrCat("fft_type=", FftType_Name(fft_type())),
191 StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
192 }
193
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const194 bool HloFftInstruction::IdenticalSlowPath(
195 const HloInstruction& other,
196 const std::function<bool(const HloComputation*, const HloComputation*)>&
197 eq_computations) const {
198 const auto& casted_other = static_cast<const HloFftInstruction&>(other);
199 return fft_type() == casted_other.fft_type() &&
200 fft_length() == casted_other.fft_length();
201 }
202
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const203 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
204 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
205 HloCloneContext* context) const {
206 CHECK_EQ(new_operands.size(), 1);
207 return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
208 fft_length_);
209 }
210
HloCopyStartInstruction(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)211 HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
212 HloInstruction* operand,
213 bool is_cross_program_prefetch)
214 : HloInstruction(HloOpcode::kCopyStart, shape),
215 is_cross_program_prefetch_(is_cross_program_prefetch) {
216 AppendOperand(operand);
217 }
218
ToProto() const219 HloInstructionProto HloCopyStartInstruction::ToProto() const {
220 HloInstructionProto proto = HloInstruction::ToProto();
221 proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
222 return proto;
223 }
224
ExtraAttributesToStringImpl(const HloPrintOptions & options) const225 std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
226 const HloPrintOptions& options) const {
227 std::vector<string> result;
228 if (is_cross_program_prefetch()) {
229 result.push_back("is_cross_program_prefetch=true");
230 }
231 return result;
232 }
233
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const234 bool HloCopyStartInstruction::IdenticalSlowPath(
235 const HloInstruction& other,
236 const std::function<bool(const HloComputation*, const HloComputation*)>&
237 eq_computations) const {
238 const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
239 return is_cross_program_prefetch() ==
240 casted_other.is_cross_program_prefetch();
241 }
242
243 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const244 HloCopyStartInstruction::CloneWithNewOperandsImpl(
245 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
246 HloCloneContext* context) const {
247 CHECK_EQ(new_operands.size(), 1);
248 return absl::make_unique<HloCopyStartInstruction>(
249 shape, new_operands[0], is_cross_program_prefetch());
250 }
251
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)252 HloCompareInstruction::HloCompareInstruction(
253 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
254 ComparisonDirection direction, absl::optional<Comparison::Type> type)
255 : HloInstruction(HloOpcode::kCompare, shape),
256 compare_(direction, type ? (*type)
257 : Comparison::DefaultComparisonType(
258 lhs->shape().element_type())) {
259 AppendOperand(lhs);
260 AppendOperand(rhs);
261 }
262
ToProto() const263 HloInstructionProto HloCompareInstruction::ToProto() const {
264 HloInstructionProto proto = HloInstruction::ToProto();
265 proto.set_comparison_direction(
266 ComparisonDirectionToString(compare_.GetDirection()));
267 proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
268 return proto;
269 }
270
ExtraAttributesToStringImpl(const HloPrintOptions & options) const271 std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
272 const HloPrintOptions& options) const {
273 std::vector<string> result;
274 result.push_back(
275 StrCat("direction=", ComparisonDirectionToString(direction())));
276 if (compare_.GetType() !=
277 Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
278 result.push_back(
279 StrCat("type=", ComparisonTypeToString(compare_.GetType())));
280 }
281 return result;
282 }
283
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const284 bool HloCompareInstruction::IdenticalSlowPath(
285 const HloInstruction& other,
286 const std::function<bool(const HloComputation*, const HloComputation*)>&
287 eq_computations) const {
288 const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
289 return direction() == casted_other.direction();
290 }
291
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const292 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
293 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
294 HloCloneContext* context) const {
295 CHECK_EQ(new_operands.size(), 2);
296 return absl::make_unique<HloCompareInstruction>(
297 shape, new_operands[0], new_operands[1], direction(), type());
298 }
299
300 namespace {
301
302 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
303 // of "key=value" attribute strings generically, using protocol buffer
304 // reflection.
305 //
306 // Currently implements a small subset of cases; feel free to add more as
307 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)308 std::vector<string> AttributeProtoToStringVector(
309 const tensorflow::protobuf::Message& message) {
310 const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
311 std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
312 reflection->ListFields(message, &fields);
313
314 std::vector<string> output;
315 for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
316 string s = absl::StrCat(field->name(), "=");
317 CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
318 switch (field->type()) {
319 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
320 bool val = reflection->GetBool(message, field);
321 absl::StrAppend(&s, val ? "true" : "false");
322 break;
323 }
324 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
325 const tensorflow::protobuf::EnumValueDescriptor* evd =
326 reflection->GetEnum(message, field);
327 absl::StrAppend(&s, evd->name());
328 break;
329 }
330 default:
331 LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
332 }
333 output.push_back(std::move(s));
334 }
335 return output;
336 }
337
338 } // namespace
339
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)340 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
341 const Shape& shape, HloInstruction* a, HloInstruction* b,
342 const TriangularSolveOptions& options)
343 : HloInstruction(HloOpcode::kTriangularSolve, shape),
344 triangular_solve_options_(options) {
345 AppendOperand(a);
346 AppendOperand(b);
347 }
348
ToProto() const349 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
350 HloInstructionProto proto = HloInstruction::ToProto();
351 *proto.mutable_triangular_solve_options() = triangular_solve_options_;
352 return proto;
353 }
354
ExtraAttributesToStringImpl(const HloPrintOptions & options) const355 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
356 const HloPrintOptions& options) const {
357 return AttributeProtoToStringVector(triangular_solve_options_);
358 }
359
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const360 bool HloTriangularSolveInstruction::IdenticalSlowPath(
361 const HloInstruction& other,
362 const std::function<bool(const HloComputation*, const HloComputation*)>&
363 eq_computations) const {
364 const auto& casted_other =
365 static_cast<const HloTriangularSolveInstruction&>(other);
366 const auto& options = triangular_solve_options();
367 const auto& other_options = casted_other.triangular_solve_options();
368
369 return options.left_side() == other_options.left_side() &&
370 options.lower() == other_options.lower() &&
371 options.unit_diagonal() == other_options.unit_diagonal() &&
372 options.transpose_a() == other_options.transpose_a();
373 }
374
375 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const376 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
377 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
378 HloCloneContext* context) const {
379 CHECK_EQ(new_operands.size(), 2);
380 return absl::make_unique<HloTriangularSolveInstruction>(
381 shape, new_operands[0], new_operands[1], triangular_solve_options());
382 }
383
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)384 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
385 HloInstruction* a,
386 const CholeskyOptions& options)
387 : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
388 AppendOperand(a);
389 }
390
ToProto() const391 HloInstructionProto HloCholeskyInstruction::ToProto() const {
392 HloInstructionProto proto = HloInstruction::ToProto();
393 *proto.mutable_cholesky_options() = cholesky_options_;
394 return proto;
395 }
396
ExtraAttributesToStringImpl(const HloPrintOptions & options) const397 std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
398 const HloPrintOptions& options) const {
399 return AttributeProtoToStringVector(cholesky_options_);
400 }
401
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const402 bool HloCholeskyInstruction::IdenticalSlowPath(
403 const HloInstruction& other,
404 const std::function<bool(const HloComputation*, const HloComputation*)>&
405 eq_computations) const {
406 const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
407 const auto& options = cholesky_options();
408 const auto& other_options = casted_other.cholesky_options();
409
410 return options.lower() == other_options.lower();
411 }
412
413 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const414 HloCholeskyInstruction::CloneWithNewOperandsImpl(
415 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
416 HloCloneContext* context) const {
417 CHECK_EQ(new_operands.size(), 1);
418 return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
419 cholesky_options());
420 }
421
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const absl::optional<int64> & channel_id)422 HloChannelInstruction::HloChannelInstruction(
423 HloOpcode opcode, const Shape& shape,
424 const absl::optional<int64>& channel_id)
425 : HloInstruction(opcode, shape), channel_id_(channel_id) {}
426
set_channel_id(const absl::optional<int64> & channel_id)427 void HloChannelInstruction::set_channel_id(
428 const absl::optional<int64>& channel_id) {
429 channel_id_ = channel_id;
430 }
431
ToProto() const432 HloInstructionProto HloChannelInstruction::ToProto() const {
433 HloInstructionProto proto = HloInstruction::ToProto();
434 if (channel_id_) {
435 CHECK_GT(channel_id_.value(), 0)
436 << "Non-positive channel id is equivalent to no channel id";
437 proto.set_channel_id(*channel_id_);
438 }
439 return proto;
440 }
441
ExtraAttributesToStringImpl(const HloPrintOptions &) const442 std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
443 const HloPrintOptions& /*options*/) const {
444 std::vector<string> result;
445 if (channel_id_) {
446 result.push_back(StrCat("channel_id=", *channel_id_));
447 }
448 return result;
449 }
450
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const451 bool HloChannelInstruction::IdenticalSlowPath(
452 const HloInstruction& other,
453 const std::function<bool(const HloComputation*, const HloComputation*)>&
454 eq_computations) const {
455 if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) {
456 return false;
457 }
458 const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
459 return channel_id() == casted_other.channel_id();
460 }
461
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64_t channel_id,bool is_host_transfer)462 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
463 const Shape& shape,
464 int64_t channel_id,
465 bool is_host_transfer)
466 : HloChannelInstruction(opcode, shape, channel_id),
467 is_host_transfer_(is_host_transfer) {}
468
ToProto() const469 HloInstructionProto HloSendRecvInstruction::ToProto() const {
470 HloInstructionProto proto = HloChannelInstruction::ToProto();
471 proto.set_is_host_transfer(is_host_transfer_);
472 return proto;
473 }
474
ExtraAttributesToStringImpl(const HloPrintOptions & options) const475 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
476 const HloPrintOptions& options) const {
477 std::vector<string> attrs =
478 HloChannelInstruction::ExtraAttributesToStringImpl(options);
479 if (is_host_transfer()) {
480 attrs.push_back("is_host_transfer=true");
481 }
482 return attrs;
483 }
484
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const485 bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
486 const HloInstruction& other,
487 const std::function<bool(const HloComputation*, const HloComputation*)>&
488 eq_computations) const {
489 // Not yet supported.
490 return false;
491 }
492
493 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)494 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
495 HloInstruction* token,
496 int64_t channel_id,
497 bool is_host_transfer)
498 : HloSendRecvInstruction(
499 HloOpcode::kSend,
500 ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
501 ShapeUtil::MakeShape(U32, {}),
502 ShapeUtil::MakeTokenShape()}),
503 channel_id, is_host_transfer) {
504 AppendOperand(operand);
505 AppendOperand(token);
506 }
507
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const508 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
509 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
510 HloCloneContext* context) const {
511 CHECK_EQ(new_operands.size(), 2);
512 return absl::make_unique<HloSendInstruction>(
513 new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
514 }
515
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)516 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
517 bool is_host_transfer)
518 : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
519 CHECK_NOTNULL(operand)->channel_id().value(),
520 is_host_transfer) {
521 AppendOperand(operand);
522 }
523
524 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const525 HloSendDoneInstruction::CloneWithNewOperandsImpl(
526 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
527 HloCloneContext* context) const {
528 CHECK_EQ(new_operands.size(), 1);
529 return absl::make_unique<HloSendDoneInstruction>(
530 Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
531 }
532
533 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)534 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
535 HloInstruction* token,
536 int64_t channel_id,
537 bool is_host_transfer)
538 : HloSendRecvInstruction(
539 HloOpcode::kRecv,
540 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
541 ShapeUtil::MakeTokenShape()}),
542 channel_id, is_host_transfer) {
543 AppendOperand(token);
544 }
545
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const546 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
547 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
548 HloCloneContext* context) const {
549 CHECK_EQ(new_operands.size(), 1);
550 return absl::make_unique<HloRecvInstruction>(
551 ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
552 is_host_transfer());
553 }
554
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)555 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
556 bool is_host_transfer)
557 : HloSendRecvInstruction(
558 HloOpcode::kRecvDone,
559 ShapeUtil::MakeTupleShape(
560 {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
561 ShapeUtil::MakeTokenShape()}),
562 CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
563 AppendOperand(operand);
564 }
565
566 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const567 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
568 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
569 HloCloneContext* context) const {
570 CHECK_EQ(new_operands.size(), 1);
571 return absl::make_unique<HloRecvDoneInstruction>(
572 Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
573 }
574
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)575 HloCollectiveInstruction::HloCollectiveInstruction(
576 HloOpcode opcode, const Shape& shape,
577 absl::Span<HloInstruction* const> operands,
578 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
579 const absl::optional<int64>& channel_id)
580 : HloChannelInstruction(opcode, shape, channel_id),
581 replica_groups_(SpanToVector(replica_groups)),
582 constrain_layout_(constrain_layout) {
583 for (auto operand : operands) {
584 AppendOperand(operand);
585 }
586 }
587
ToProto() const588 HloInstructionProto HloCollectiveInstruction::ToProto() const {
589 HloInstructionProto proto = HloChannelInstruction::ToProto();
590 *proto.mutable_replica_groups() = {replica_groups_.begin(),
591 replica_groups_.end()};
592 proto.set_constrain_layout(constrain_layout_);
593 return proto;
594 }
595
ExtraAttributesToStringImpl(const HloPrintOptions & options) const596 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
597 const HloPrintOptions& options) const {
598 std::vector<string> result =
599 HloChannelInstruction::ExtraAttributesToStringImpl(options);
600 result.push_back(
601 StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
602 if (constrain_layout_) {
603 result.push_back("constrain_layout=true");
604 }
605 return result;
606 }
607
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const608 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
609 const HloInstruction& other,
610 const std::function<bool(const HloComputation*, const HloComputation*)>&
611 eq_computations) const {
612 const auto& casted_other =
613 static_cast<const HloCollectiveInstruction&>(other);
614 return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
615 other, eq_computations) &&
616 constrain_layout() == casted_other.constrain_layout() &&
617 absl::c_equal(replica_groups(), casted_other.replica_groups(),
618 [](const ReplicaGroup& a, const ReplicaGroup& b) {
619 return absl::c_equal(a.replica_ids(), b.replica_ids());
620 });
621 }
622
HloAllGatherInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)623 HloAllGatherInstruction::HloAllGatherInstruction(
624 HloOpcode opcode, const Shape& shape,
625 absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
626 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
627 const absl::optional<int64>& channel_id, bool use_global_device_ids)
628 : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
629 constrain_layout, channel_id),
630 all_gather_dimension_(all_gather_dimension),
631 use_global_device_ids_(use_global_device_ids) {}
632
ExtraAttributesToStringImpl(const HloPrintOptions & options) const633 std::vector<string> HloAllGatherInstruction::ExtraAttributesToStringImpl(
634 const HloPrintOptions& options) const {
635 std::vector<string> result =
636 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
637 result.push_back(StrCat("dimensions={", all_gather_dimension_, "}"));
638 if (use_global_device_ids_) {
639 result.push_back("use_global_device_ids=true");
640 }
641 return result;
642 }
643
644 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const645 HloAllGatherInstruction::CloneWithNewOperandsImpl(
646 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
647 HloCloneContext* /*context*/) const {
648 return absl::make_unique<HloAllGatherInstruction>(
649 opcode(), shape, new_operands, all_gather_dimension(), replica_groups(),
650 constrain_layout(), channel_id(), use_global_device_ids());
651 }
652
ToProto() const653 HloInstructionProto HloAllGatherInstruction::ToProto() const {
654 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
655 proto.add_dimensions(all_gather_dimension_);
656 proto.set_use_global_device_ids(use_global_device_ids_);
657 return proto;
658 }
659
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const660 bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues(
661 const HloInstruction& other,
662 const std::function<bool(const HloComputation*, const HloComputation*)>&
663 eq_computations) const {
664 const auto& casted_other = static_cast<const HloAllGatherInstruction&>(other);
665 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
666 other, eq_computations) &&
667 all_gather_dimension_ == casted_other.all_gather_dimension() &&
668 use_global_device_ids() == casted_other.use_global_device_ids();
669 }
670
HloAllReduceInstructionBase(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)671 HloAllReduceInstructionBase::HloAllReduceInstructionBase(
672 HloOpcode opcode, const Shape& shape,
673 absl::Span<HloInstruction* const> operands,
674 HloComputation* reduce_computation,
675 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
676 const absl::optional<int64>& channel_id, bool use_global_device_ids)
677 : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
678 constrain_layout, channel_id),
679 use_global_device_ids_(use_global_device_ids) {
680 AppendComputation(reduce_computation);
681 }
682
ToProto() const683 HloInstructionProto HloAllReduceInstructionBase::ToProto() const {
684 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
685 proto.set_use_global_device_ids(use_global_device_ids_);
686 return proto;
687 }
688
ExtraAttributesToStringImpl(const HloPrintOptions & options) const689 std::vector<string> HloAllReduceInstructionBase::ExtraAttributesToStringImpl(
690 const HloPrintOptions& options) const {
691 std::vector<string> result =
692 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
693 if (use_global_device_ids_) {
694 result.push_back("use_global_device_ids=true");
695 }
696 return result;
697 }
698
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const699 bool HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
700 const HloInstruction& other,
701 const std::function<bool(const HloComputation*, const HloComputation*)>&
702 eq_computations) const {
703 if (opcode() != other.opcode()) {
704 return false;
705 }
706 const auto& casted_other =
707 static_cast<const HloAllReduceInstructionBase&>(other);
708 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
709 other, eq_computations) &&
710 constrain_layout() == casted_other.constrain_layout() &&
711 use_global_device_ids() == casted_other.use_global_device_ids() &&
712 eq_computations(to_apply(), casted_other.to_apply());
713 }
714
IsNoop() const715 bool HloAllReduceInstruction::IsNoop() const {
716 for (const auto& replica_group : replica_groups()) {
717 if (replica_group.replica_ids().size() != 1) {
718 return false;
719 }
720 }
721 return !channel_id();
722 }
723
724 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const725 HloAllReduceInstruction::CloneWithNewOperandsImpl(
726 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
727 HloCloneContext* /*context*/) const {
728 return absl::make_unique<HloAllReduceInstruction>(
729 opcode(), shape, new_operands, to_apply(), replica_groups(),
730 constrain_layout(), channel_id(), use_global_device_ids());
731 }
732
HloReduceScatterInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)733 HloReduceScatterInstruction::HloReduceScatterInstruction(
734 const Shape& shape, absl::Span<HloInstruction* const> operands,
735 HloComputation* reduce_computation,
736 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
737 const absl::optional<int64>& channel_id, bool use_global_device_ids,
738 int64_t scatter_dimension)
739 : HloAllReduceInstructionBase(
740 HloOpcode::kReduceScatter, shape, operands, reduce_computation,
741 replica_groups, constrain_layout, channel_id, use_global_device_ids),
742 scatter_dimension_(scatter_dimension) {}
743
ExtraAttributesToStringImpl(const HloPrintOptions & options) const744 std::vector<string> HloReduceScatterInstruction::ExtraAttributesToStringImpl(
745 const HloPrintOptions& options) const {
746 std::vector<string> result =
747 HloAllReduceInstructionBase::ExtraAttributesToStringImpl(options);
748 result.push_back(StrCat("dimensions={", scatter_dimension_, "}"));
749 return result;
750 }
751
ToProto() const752 HloInstructionProto HloReduceScatterInstruction::ToProto() const {
753 HloInstructionProto proto = HloAllReduceInstructionBase::ToProto();
754 proto.add_dimensions(scatter_dimension_);
755 return proto;
756 }
757
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const758 bool HloReduceScatterInstruction::IdenticalSlowPathIgnoringChannelIdValues(
759 const HloInstruction& other,
760 const std::function<bool(const HloComputation*, const HloComputation*)>&
761 eq_computations) const {
762 const auto& casted_other =
763 static_cast<const HloReduceScatterInstruction&>(other);
764 return HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
765 other, eq_computations) &&
766 scatter_dimension_ == casted_other.scatter_dimension();
767 }
768
769 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const770 HloReduceScatterInstruction::CloneWithNewOperandsImpl(
771 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
772 HloCloneContext* /*context*/) const {
773 return absl::make_unique<HloReduceScatterInstruction>(
774 shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
775 channel_id(), use_global_device_ids(), scatter_dimension());
776 }
777
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)778 HloAllToAllInstruction::HloAllToAllInstruction(
779 const Shape& shape, absl::Span<HloInstruction* const> operands,
780 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
781 const absl::optional<int64>& channel_id,
782 const absl::optional<int64>& split_dimension)
783 : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
784 replica_groups, constrain_layout, channel_id),
785 split_dimension_(split_dimension) {}
786
787 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const788 HloAllToAllInstruction::CloneWithNewOperandsImpl(
789 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
790 HloCloneContext* /*context*/) const {
791 return absl::make_unique<HloAllToAllInstruction>(
792 shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
793 split_dimension());
794 }
795
ToProto() const796 HloInstructionProto HloAllToAllInstruction::ToProto() const {
797 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
798 if (split_dimension_) {
799 proto.add_dimensions(*split_dimension_);
800 }
801 return proto;
802 }
803
ExtraAttributesToStringImpl(const HloPrintOptions & options) const804 std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
805 const HloPrintOptions& options) const {
806 std::vector<string> result =
807 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
808 if (split_dimension_) {
809 result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
810 }
811 return result;
812 }
813
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const814 bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
815 const HloInstruction& other,
816 const std::function<bool(const HloComputation*, const HloComputation*)>&
817 eq_computations) const {
818 const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
819 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
820 other, eq_computations) &&
821 split_dimension_ == casted_other.split_dimension();
822 }
823
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)824 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
825 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
826 const std::vector<std::pair<int64, int64>>& source_target_pairs,
827 const absl::optional<int64>& channel_id)
828 : HloChannelInstruction(opcode, shape, channel_id),
829 source_target_pairs_(source_target_pairs) {
830 AppendOperand(operand);
831 }
832
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const absl::optional<int64_t> & channel_id)833 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
834 HloOpcode opcode, const Shape& shape, HloInstruction* input,
835 HloInstruction* output, HloInstruction* input_start_indices,
836 HloInstruction* output_start_indices,
837 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
838 absl::Span<const std::vector<int64_t>> slice_sizes,
839 const absl::optional<int64_t>& channel_id)
840 : HloChannelInstruction(opcode, shape, channel_id),
841 source_target_pairs_(source_target_pairs.begin(),
842 source_target_pairs.end()),
843 slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
844 AppendOperand(input);
845 AppendOperand(output);
846 AppendOperand(input_start_indices);
847 AppendOperand(output_start_indices);
848 }
849
ToProto() const850 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
851 HloInstructionProto proto = HloChannelInstruction::ToProto();
852 for (const auto& pair : source_target_pairs()) {
853 auto* proto_pair = proto.add_source_target_pairs();
854 proto_pair->set_source(pair.first);
855 proto_pair->set_target(pair.second);
856 }
857 for (const auto& slice_size : dynamic_slice_sizes_list()) {
858 for (const auto& dimension_slice_size : slice_size) {
859 proto.add_dynamic_slice_sizes(dimension_slice_size);
860 }
861 }
862 return proto;
863 }
864
865 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const866 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
867 const HloPrintOptions& options) const {
868 std::vector<string> result =
869 HloChannelInstruction::ExtraAttributesToStringImpl(options);
870 {
871 std::vector<string> strs;
872 for (const auto& pair : source_target_pairs()) {
873 strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
874 }
875 result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
876 }
877 if (!dynamic_slice_sizes_list().empty()) {
878 std::vector<string> strs;
879 for (const auto& slice_sizes : dynamic_slice_sizes_list()) {
880 strs.push_back(StrCat("{", StrJoin(slice_sizes, ","), "}"));
881 }
882 result.push_back(StrCat("slice_sizes={", StrJoin(strs, ","), "}"));
883 }
884 return result;
885 }
886
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const887 bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues(
888 const HloInstruction& other,
889 const std::function<bool(const HloComputation*, const HloComputation*)>&
890 eq_computations) const {
891 if (opcode() != other.opcode()) {
892 return false;
893 }
894 const auto& casted_other =
895 static_cast<const HloCollectivePermuteInstruction&>(other);
896 return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
897 other, eq_computations) &&
898 absl::c_equal(
899 source_target_pairs(), casted_other.source_target_pairs(),
900 [](const std::pair<int64, int64>& a,
901 const std::pair<int64, int64>& b) { return a == b; }) &&
902 absl::c_equal(
903 dynamic_slice_sizes_list(),
904 casted_other.dynamic_slice_sizes_list(),
905 [](const std::vector<int64>& a, const std::vector<int64>& b) {
906 return absl::c_equal(a, b);
907 });
908 }
909
910 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const911 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
912 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
913 HloCloneContext* /*context*/) const {
914 if (dynamic_slice_sizes_list().empty()) {
915 return absl::make_unique<HloCollectivePermuteInstruction>(
916 opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
917 } else {
918 return absl::make_unique<HloCollectivePermuteInstruction>(
919 opcode(), shape, new_operands[0], new_operands[1], new_operands[2],
920 new_operands[3], source_target_pairs(), dynamic_slice_sizes_list(),
921 channel_id());
922 }
923 }
924
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)925 HloReverseInstruction::HloReverseInstruction(const Shape& shape,
926 HloInstruction* operand,
927 absl::Span<const int64> dimensions)
928 : HloInstruction(HloOpcode::kReverse, shape),
929 dimensions_(dimensions.begin(), dimensions.end()) {
930 AppendOperand(operand);
931 }
932
ToProto() const933 HloInstructionProto HloReverseInstruction::ToProto() const {
934 HloInstructionProto proto = HloInstruction::ToProto();
935 for (int64_t dimension : dimensions_) {
936 proto.add_dimensions(dimension);
937 }
938 return proto;
939 }
940
ExtraAttributesToStringImpl(const HloPrintOptions & options) const941 std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
942 const HloPrintOptions& options) const {
943 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
944 }
945
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const946 bool HloReverseInstruction::IdenticalSlowPath(
947 const HloInstruction& other,
948 const std::function<bool(const HloComputation*, const HloComputation*)>&
949 eq_computations) const {
950 const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
951 return dimensions() == casted_other.dimensions();
952 }
953
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const954 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
955 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
956 HloCloneContext* context) const {
957 CHECK_EQ(new_operands.size(), 1);
958 return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
959 dimensions());
960 }
961
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)962 HloConcatenateInstruction::HloConcatenateInstruction(
963 const Shape& shape, absl::Span<HloInstruction* const> operands,
964 int64_t dimension)
965 : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
966 for (auto operand : operands) {
967 AppendOperand(operand);
968 }
969 }
970
ToProto() const971 HloInstructionProto HloConcatenateInstruction::ToProto() const {
972 HloInstructionProto proto = HloInstruction::ToProto();
973 for (int64_t dimension : dimensions_) {
974 proto.add_dimensions(dimension);
975 }
976 return proto;
977 }
978
ExtraAttributesToStringImpl(const HloPrintOptions & options) const979 std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
980 const HloPrintOptions& options) const {
981 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
982 }
983
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const984 bool HloConcatenateInstruction::IdenticalSlowPath(
985 const HloInstruction& other,
986 const std::function<bool(const HloComputation*, const HloComputation*)>&
987 eq_computations) const {
988 const auto& casted_other =
989 static_cast<const HloConcatenateInstruction&>(other);
990 return dimensions() == casted_other.dimensions();
991 }
992
993 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const994 HloConcatenateInstruction::CloneWithNewOperandsImpl(
995 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
996 HloCloneContext* context) const {
997 return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
998 dimensions(0));
999 }
1000
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1001 HloReduceInstruction::HloReduceInstruction(
1002 const Shape& shape, absl::Span<HloInstruction* const> args,
1003 absl::Span<const int64> dimensions_to_reduce,
1004 HloComputation* reduce_computation)
1005 : HloInstruction(HloOpcode::kReduce, shape),
1006 dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
1007 for (HloInstruction* arg : args) {
1008 AppendOperand(arg);
1009 }
1010 AppendComputation(reduce_computation);
1011 }
1012
ToProto() const1013 HloInstructionProto HloReduceInstruction::ToProto() const {
1014 HloInstructionProto proto = HloInstruction::ToProto();
1015 for (int64_t dimension : dimensions_) {
1016 proto.add_dimensions(dimension);
1017 }
1018 return proto;
1019 }
1020
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1021 std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
1022 const HloPrintOptions& options) const {
1023 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1024 }
1025
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1026 bool HloReduceInstruction::IdenticalSlowPath(
1027 const HloInstruction& other,
1028 const std::function<bool(const HloComputation*, const HloComputation*)>&
1029 eq_computations) const {
1030 const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
1031 // Reduction results are determined by the reduction dimension and the
1032 // reduction computation.
1033 return dimensions() == casted_other.dimensions() &&
1034 eq_computations(to_apply(), casted_other.to_apply());
1035 }
1036
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1037 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
1038 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1039 HloCloneContext* context) const {
1040 CHECK_EQ(new_operands.size() % 2, 0);
1041 return absl::make_unique<HloReduceInstruction>(shape, new_operands,
1042 dimensions(), to_apply());
1043 }
1044
HloSortInstruction(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1045 HloSortInstruction::HloSortInstruction(
1046 const Shape& shape, int64_t dimension,
1047 absl::Span<HloInstruction* const> operands, HloComputation* compare,
1048 bool is_stable)
1049 : HloInstruction(HloOpcode::kSort, shape),
1050 dimensions_({dimension}),
1051 is_stable_(is_stable) {
1052 for (auto* value : operands) {
1053 AppendOperand(value);
1054 }
1055 AppendComputation(compare);
1056 }
1057
ToProto() const1058 HloInstructionProto HloSortInstruction::ToProto() const {
1059 HloInstructionProto proto = HloInstruction::ToProto();
1060 for (int64_t dimension : dimensions_) {
1061 proto.add_dimensions(dimension);
1062 }
1063 proto.set_is_stable(is_stable());
1064 return proto;
1065 }
1066
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1067 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
1068 const HloPrintOptions& options) const {
1069 std::vector<string> attrs;
1070 attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
1071 if (is_stable()) {
1072 attrs.push_back("is_stable=true");
1073 }
1074 return attrs;
1075 }
1076
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1077 bool HloSortInstruction::IdenticalSlowPath(
1078 const HloInstruction& other,
1079 const std::function<bool(const HloComputation*, const HloComputation*)>&
1080 eq_computations) const {
1081 const auto& casted_other = static_cast<const HloSortInstruction&>(other);
1082 if (dimensions() != casted_other.dimensions()) {
1083 return false;
1084 }
1085 if (is_stable() != casted_other.is_stable()) {
1086 return false;
1087 }
1088 return eq_computations(to_apply(), other.to_apply());
1089 }
1090
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1091 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
1092 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1093 HloCloneContext* context) const {
1094 return absl::make_unique<HloSortInstruction>(
1095 shape, dimensions(0), new_operands, to_apply(), is_stable());
1096 }
1097
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1098 HloTransposeInstruction::HloTransposeInstruction(
1099 const Shape& shape, HloInstruction* operand,
1100 absl::Span<const int64> dimensions)
1101 : HloInstruction(HloOpcode::kTranspose, shape),
1102 dimensions_(dimensions.begin(), dimensions.end()) {
1103 AppendOperand(operand);
1104 }
1105
IsRank2Transpose() const1106 bool HloTransposeInstruction::IsRank2Transpose() const {
1107 return dimensions() == std::vector<int64>({1, 0}) &&
1108 shape().dimensions_size() == 2 &&
1109 std::equal(shape().dimensions().begin(), shape().dimensions().end(),
1110 operand(0)->shape().dimensions().rbegin());
1111 }
1112
ToProto() const1113 HloInstructionProto HloTransposeInstruction::ToProto() const {
1114 HloInstructionProto proto = HloInstruction::ToProto();
1115 for (int64_t dimension : dimensions_) {
1116 proto.add_dimensions(dimension);
1117 }
1118 return proto;
1119 }
1120
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1121 std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
1122 const HloPrintOptions& options) const {
1123 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1124 }
1125
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1126 bool HloTransposeInstruction::IdenticalSlowPath(
1127 const HloInstruction& other,
1128 const std::function<bool(const HloComputation*, const HloComputation*)>&
1129 eq_computations) const {
1130 const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
1131 return dimensions() == casted_other.dimensions();
1132 }
1133
1134 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1135 HloTransposeInstruction::CloneWithNewOperandsImpl(
1136 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1137 HloCloneContext* context) const {
1138 CHECK_EQ(new_operands.size(), 1);
1139 return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
1140 dimensions());
1141 }
1142
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimension)1143 HloBroadcastInstruction::HloBroadcastInstruction(
1144 const Shape& shape, HloInstruction* operand,
1145 absl::Span<const int64> broadcast_dimension)
1146 : HloInstruction(HloOpcode::kBroadcast, shape),
1147 dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
1148 AppendOperand(operand);
1149 }
1150
ToProto() const1151 HloInstructionProto HloBroadcastInstruction::ToProto() const {
1152 HloInstructionProto proto = HloInstruction::ToProto();
1153 for (int64_t dimension : dimensions_) {
1154 proto.add_dimensions(dimension);
1155 }
1156 return proto;
1157 }
1158
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1159 std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
1160 const HloPrintOptions& options) const {
1161 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1162 }
1163
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1164 bool HloBroadcastInstruction::IdenticalSlowPath(
1165 const HloInstruction& other,
1166 const std::function<bool(const HloComputation*, const HloComputation*)>&
1167 eq_computations) const {
1168 const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
1169 return dimensions() == casted_other.dimensions();
1170 }
1171
1172 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1173 HloBroadcastInstruction::CloneWithNewOperandsImpl(
1174 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1175 HloCloneContext* context) const {
1176 CHECK_EQ(new_operands.size(), 1);
1177 return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
1178 dimensions());
1179 }
1180
HloDynamicReshapeInstruction(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1181 HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
1182 const Shape& shape, HloInstruction* data_operand,
1183 absl::Span<HloInstruction* const> dim_sizes)
1184 : HloInstruction(HloOpcode::kDynamicReshape, shape) {
1185 AppendOperand(data_operand);
1186 for (auto operand : dim_sizes) {
1187 AppendOperand(operand);
1188 }
1189 }
1190
1191 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1192 HloDynamicReshapeInstruction::CloneWithNewOperandsImpl(
1193 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1194 HloCloneContext* context) const {
1195 CHECK_GE(new_operands.size(), 1);
1196 return absl::make_unique<HloDynamicReshapeInstruction>(
1197 shape, new_operands[0], new_operands.subspan(1));
1198 }
1199
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1200 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
1201 HloInstruction* operand,
1202 int64_t inferred_dimension)
1203 : HloInstruction(HloOpcode::kReshape, shape),
1204 inferred_dimension_(inferred_dimension) {
1205 AppendOperand(operand);
1206 }
1207
ToProto() const1208 HloInstructionProto HloReshapeInstruction::ToProto() const {
1209 HloInstructionProto proto = HloInstruction::ToProto();
1210 if (inferred_dimension_ != -1) {
1211 proto.add_dimensions(inferred_dimension_);
1212 }
1213 return proto;
1214 }
1215
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1216 std::vector<string> HloReshapeInstruction::ExtraAttributesToStringImpl(
1217 const HloPrintOptions& options) const {
1218 if (inferred_dimension() == -1) {
1219 return {};
1220 }
1221 return {StrCat("inferred_dimension=", inferred_dimension())};
1222 }
1223
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1224 bool HloReshapeInstruction::IdenticalSlowPath(
1225 const HloInstruction& other,
1226 const std::function<bool(const HloComputation*, const HloComputation*)>&
1227 eq_computations) const {
1228 const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
1229 return inferred_dimension() == casted_other.inferred_dimension();
1230 }
1231
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1232 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
1233 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1234 HloCloneContext* context) const {
1235 CHECK_EQ(new_operands.size(), 1);
1236 return absl::make_unique<HloReshapeInstruction>(shape, new_operands[0],
1237 inferred_dimension());
1238 }
1239
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1240 HloMapInstruction::HloMapInstruction(const Shape& shape,
1241 absl::Span<HloInstruction* const> operands,
1242 HloComputation* map_computation)
1243 : HloInstruction(HloOpcode::kMap, shape) {
1244 for (auto operand : operands) {
1245 AppendOperand(operand);
1246 }
1247 AppendComputation(map_computation);
1248 // TODO(b/65689298) Remove code below once Map is generalized to accept
1249 // arbitrary map dimensions.
1250 dimensions_.resize(shape.rank());
1251 std::iota(dimensions_.begin(), dimensions_.end(), 0);
1252 }
1253
ToProto() const1254 HloInstructionProto HloMapInstruction::ToProto() const {
1255 HloInstructionProto proto = HloInstruction::ToProto();
1256 for (int64_t dimension : dimensions_) {
1257 proto.add_dimensions(dimension);
1258 }
1259 return proto;
1260 }
1261
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1262 bool HloMapInstruction::IsElementwiseImpl(
1263 const absl::optional<int64>& operand_idx) const {
1264 if (!dimensions().empty()) {
1265 // Check that the map is executed in elementwise compatible dimensions.
1266 if (dimensions().size() != shape().dimensions_size()) {
1267 return false;
1268 }
1269 for (int i = 0; i < dimensions().size(); ++i) {
1270 if (dimensions()[i] != i) {
1271 return false;
1272 }
1273 }
1274 }
1275 return true;
1276 }
1277
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1278 std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
1279 const HloPrintOptions& options) const {
1280 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1281 }
1282
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1283 bool HloMapInstruction::IdenticalSlowPath(
1284 const HloInstruction& other,
1285 const std::function<bool(const HloComputation*, const HloComputation*)>&
1286 eq_computations) const {
1287 const auto& casted_other = static_cast<const HloMapInstruction&>(other);
1288 return eq_computations(to_apply(), casted_other.to_apply()) &&
1289 dimensions() == casted_other.dimensions();
1290 }
1291
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1292 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1293 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1294 HloCloneContext* context) const {
1295 return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1296 }
1297
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1298 HloSliceInstruction::HloSliceInstruction(const Shape& shape,
1299 HloInstruction* operand,
1300 absl::Span<const int64> start_indices,
1301 absl::Span<const int64> limit_indices,
1302 absl::Span<const int64> strides)
1303 : HloInstruction(HloOpcode::kSlice, shape),
1304 slice_starts_(start_indices.begin(), start_indices.end()),
1305 slice_limits_(limit_indices.begin(), limit_indices.end()),
1306 slice_strides_(strides.begin(), strides.end()) {
1307 AppendOperand(operand);
1308 // For backward compatibility with old serialized computations: if there are
1309 // no strides, assume all strides are 1.
1310 // TODO(b/63317920): remove this code.
1311 if (slice_strides_.empty()) {
1312 slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
1313 }
1314 }
1315
ToProto() const1316 HloInstructionProto HloSliceInstruction::ToProto() const {
1317 HloInstructionProto proto = HloInstruction::ToProto();
1318 for (int i = 0; i < slice_starts_.size(); ++i) {
1319 auto* slice_dimension = proto.add_slice_dimensions();
1320 slice_dimension->set_start(slice_starts_[i]);
1321 slice_dimension->set_limit(slice_limits_[i]);
1322 slice_dimension->set_stride(slice_strides_[i]);
1323 }
1324 return proto;
1325 }
1326
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1327 std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
1328 const HloPrintOptions& options) const {
1329 std::vector<string> bounds;
1330 bounds.reserve(slice_starts_.size());
1331 const bool omit_stride = absl::c_all_of(
1332 slice_strides_, [](int64_t stride) { return stride == 1; });
1333 for (int i = 0; i < slice_starts_.size(); ++i) {
1334 string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1335 bounds.push_back(
1336 StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1337 }
1338 return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1339 }
1340
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1341 bool HloSliceInstruction::IdenticalSlowPath(
1342 const HloInstruction& other,
1343 const std::function<bool(const HloComputation*, const HloComputation*)>&
1344 eq_computations) const {
1345 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1346 return slice_starts_ == other_slice.slice_starts_ &&
1347 slice_limits_ == other_slice.slice_limits_ &&
1348 slice_strides_ == other_slice.slice_strides_;
1349 }
1350
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1351 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1352 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1353 HloCloneContext* context) const {
1354 CHECK_EQ(new_operands.size(), 1);
1355 return absl::make_unique<HloSliceInstruction>(
1356 shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1357 }
1358
HloConstantInstruction(Literal literal)1359 HloConstantInstruction::HloConstantInstruction(Literal literal)
1360 : HloInstruction(HloOpcode::kConstant, literal.shape()),
1361 literal_(std::move(literal)) {}
1362
HloConstantInstruction(Literal literal,const Shape & shape)1363 HloConstantInstruction::HloConstantInstruction(Literal literal,
1364 const Shape& shape)
1365 : HloInstruction(HloOpcode::kConstant, shape),
1366 literal_(std::move(literal)) {}
1367
HloConstantInstruction(const Shape & shape)1368 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1369 : HloInstruction(HloOpcode::kConstant, shape) {}
1370
ToProto() const1371 HloInstructionProto HloConstantInstruction::ToProto() const {
1372 HloInstructionProto proto = HloInstruction::ToProto();
1373 if (literal_.has_value()) {
1374 *proto.mutable_literal() = literal_->ToProto();
1375 }
1376 return proto;
1377 }
1378
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1379 bool HloConstantInstruction::IsElementwiseImpl(
1380 const absl::optional<int64>& operand_idx) const {
1381 return true;
1382 }
1383
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1384 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1385 const ShapeIndex& shape_index) {
1386 Shape* mutable_array_subshape =
1387 ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1388 CHECK(mutable_array_subshape->IsArray());
1389
1390 // Normally array_subshape will always have a layout, but this invariant is
1391 // temporarily broken in LayoutAssignment::AssignLayouts.
1392
1393 if (!mutable_array_subshape->has_layout() ||
1394 !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1395 *literal_ = literal_->Relayout(new_layout, shape_index);
1396 *mutable_array_subshape->mutable_layout() = new_layout;
1397 }
1398 }
1399
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1400 bool HloConstantInstruction::IdenticalSlowPath(
1401 const HloInstruction& other,
1402 const std::function<bool(const HloComputation*, const HloComputation*)>&
1403 eq_computations) const {
1404 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1405 return literal() == other_slice.literal();
1406 }
1407
1408 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1409 HloConstantInstruction::CloneWithNewOperandsImpl(
1410 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1411 HloCloneContext* context) const {
1412 CHECK(literal_.has_value());
1413 // Literal's shape may have no/different tiling info. Use this instruction's
1414 // shape instead.
1415 CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1416 this->shape()));
1417 return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
1418 this->shape());
1419 }
1420
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1421 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1422 const HloPrintOptions& options,
1423 CanonicalNameMap* canonical_name_map) const {
1424 if (options.print_only_essential_constants()) {
1425 if (!literal_.has_value()) {
1426 return "{...}";
1427 }
1428 if (literal().IsAll(0)) {
1429 return "0";
1430 }
1431 if (literal().IsAll(1)) {
1432 return "1";
1433 }
1434 if (shape().IsInteger()) {
1435 return literal_->ToStringWithoutShapeOneline();
1436 }
1437 return "{...}";
1438 }
1439
1440 // For constants, show the actual value in place of an empty operand list.
1441 if (literal_.has_value() &&
1442 ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1443 options.print_large_constants())) {
1444 // Literal::ToString emits multidimensional arrays over multiple
1445 // lines. Compact this into one line by stripping out white space.
1446 return literal_->ToStringWithoutShapeOneline();
1447 } else {
1448 // Do not show large constants or tuples.
1449 return "{...}";
1450 }
1451 }
1452
// Creates a kTrace instruction with nil shape that stores `tag` as a literal
// built by LiteralUtil::CreateR1U8. `operand` is appended as the sole operand,
// and this instruction is then registered as its tracer via set_tracing.
HloTraceInstruction::HloTraceInstruction(const string& tag,
                                         HloInstruction* operand)
    : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
      literal_(LiteralUtil::CreateR1U8(tag)) {
  AppendOperand(operand);
  operand->set_tracing(this);
}
1460
ToProto() const1461 HloInstructionProto HloTraceInstruction::ToProto() const {
1462 HloInstructionProto proto = HloInstruction::ToProto();
1463 *proto.mutable_literal() = literal_.ToProto();
1464 return proto;
1465 }
1466
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1467 bool HloTraceInstruction::IdenticalSlowPath(
1468 const HloInstruction& other,
1469 const std::function<bool(const HloComputation*, const HloComputation*)>&
1470 eq_computations) const {
1471 return false;
1472 }
1473
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1474 std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
1475 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1476 HloCloneContext* context) const {
1477 LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
1478 }
1479
// Creates a fusion instruction named "fusion" whose fused computation starts
// out holding a clone of `fused_root`. The fusion adopts `fused_root`'s parent
// computation and metadata. The parent is set before CloneAndFuseInternal,
// which (with no called computation yet) builds the fused computation and
// registers it with GetModule(). Presumably GetModule() is resolved through
// that parent; verify before reordering.
HloFusionInstruction::HloFusionInstruction(const Shape& shape,
                                           FusionKind fusion_kind,
                                           HloInstruction* fused_root)
    : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
  CHECK(fused_root != nullptr);
  SetAndSanitizeName("fusion");
  set_parent(fused_root->parent());
  set_metadata(fused_root->metadata());
  CloneAndFuseInternal(fused_root);
}
1490
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1491 HloFusionInstruction::HloFusionInstruction(
1492 const Shape& shape, FusionKind fusion_kind,
1493 absl::Span<HloInstruction* const> operands,
1494 HloComputation* fusion_computation)
1495 : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
1496 for (auto operand : operands) {
1497 AppendOperand(operand);
1498 }
1499 SetAndSanitizeName("fusion");
1500 AppendComputation(fusion_computation);
1501 fusion_computation->SetFusionInstruction(this);
1502 }
1503
// On destruction, clears the fusion back-pointer of any called computation
// that still points at this instruction, so that pointer does not dangle.
HloFusionInstruction::~HloFusionInstruction() {
  ClearFusionComputationInstruction();
}
1507
ClearFusionComputationInstruction()1508 void HloFusionInstruction::ClearFusionComputationInstruction() {
1509 // Each fusion calls a single computation, but we use called_computations()
1510 // instead of fused_instructions_computation(), because the order in which
1511 // things get destructed can vary; the fusion computation's back-pointer may
1512 // already be null, which violates a check in fused_instructions_computation.
1513 for (HloComputation* computation : called_computations()) {
1514 // Some passes that rewrite fusions may reassign a fusion computation to a
1515 // different fusion instruction as this instruction gets destructed.
1516 if (computation->FusionInstruction() == this) {
1517 computation->SetFusionInstruction(nullptr);
1518 }
1519 }
1520 }
1521
// Clears this fusion's called computations. The back-pointers are reset first
// because ClearFusionComputationInstruction iterates called_computations(),
// which the base-class call then clears.
void HloFusionInstruction::ClearCalledComputations() {
  ClearFusionComputationInstruction();
  HloInstruction::ClearCalledComputations();
}
1526
ToCategory() const1527 string HloFusionInstruction::ToCategory() const {
1528 switch (fusion_kind()) {
1529 case FusionKind::kLoop:
1530 return "loop fusion";
1531 case FusionKind::kInput:
1532 return "input fusion";
1533 case FusionKind::kOutput:
1534 return "output fusion";
1535 case FusionKind::kCustom:
1536 return "custom fusion";
1537 }
1538 }
1539
ToProto() const1540 HloInstructionProto HloFusionInstruction::ToProto() const {
1541 HloInstructionProto proto = HloInstruction::ToProto();
1542 proto.set_fusion_kind(xla::ToString(fusion_kind()));
1543 proto.add_called_computation_ids(
1544 fused_instructions_computation()->unique_id());
1545 return proto;
1546 }
1547
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const1548 bool HloFusionInstruction::IsElementwiseImpl(
1549 const absl::optional<int64>& operand_idx) const {
1550 if (!operand_idx.has_value()) {
1551 for (auto* fused : fused_instructions()) {
1552 if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1553 return false;
1554 }
1555 }
1556 return true;
1557 }
1558 // A loop-fusion is elementwise on an operand if all operations (computed
1559 // using BFS) between the operand and the fused root are elementwise.
1560 std::deque<HloInstruction*> worklist;
1561 std::unordered_set<const HloInstruction*> visited;
1562 worklist.push_back(fused_parameter(operand_idx.value()));
1563 visited.insert(fused_parameter(operand_idx.value()));
1564 while (!worklist.empty()) {
1565 HloInstruction* operand = worklist.front();
1566 worklist.pop_front();
1567 for (HloInstruction* user : operand->users()) {
1568 CHECK_GE(user->unique_id(), 0);
1569 if (ContainsKey(visited, user)) {
1570 continue;
1571 }
1572 if (user->IsElementwise() ||
1573 IsInstructionElementwiseOnOperand(user, operand)) {
1574 worklist.push_back(user);
1575 visited.insert(user);
1576 } else {
1577 return false;
1578 }
1579 }
1580 }
1581 return true;
1582 }
1583
AddFusionOperand(HloInstruction * new_operand)1584 HloInstruction* HloFusionInstruction::AddFusionOperand(
1585 HloInstruction* new_operand) {
1586 CHECK_EQ(operand_count(),
1587 fused_instructions_computation()->parameter_instructions().size());
1588 const int64_t param_no = operand_count();
1589 string param_name = StrCat("param_", param_no);
1590 HloInstruction* fused_parameter =
1591 fused_instructions_computation()->AddParameter(
1592 HloInstruction::CreateParameter(param_no, new_operand->shape(),
1593 param_name));
1594 AppendOperand(new_operand);
1595 return fused_parameter;
1596 }
1597
// Merges the fused body of `instruction_to_merge`, which must be an operand of
// this fusion, into this fusion.
//
// The work is done on a clone of `instruction_to_merge`. The clone's fused
// parameters are rewired to the clone's operands. Its remaining instructions
// are then fused into `this` one at a time (root first, walking the post order
// in reverse). Finally, the clone's fused computation is removed from the
// module. `instruction_to_merge` itself is not removed here; only its use by
// `this` is redirected.
void HloFusionInstruction::MergeFusionInstruction(
    HloFusionInstruction* instruction_to_merge) {
  CHECK(absl::c_linear_search(operands(), instruction_to_merge));
  // Clone the instruction from which to merge fused instructions.
  std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
  HloFusionInstruction* cloned_fusion =
      static_cast<HloFusionInstruction*>(cloned.get());
  // Replace uses of fused parameters with the corresponding operand of the
  // fusion. Add all non-parameter fused instructions to
  // 'unfused_instructions' to be merged into 'this'. This is done in reverse
  // post order.
  std::vector<HloInstruction*> unfused_instructions;
  auto fused_instructions = cloned_fusion->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
  for (auto fused_it = fused_instructions.rbegin();
       fused_it != fused_instructions.rend(); ++fused_it) {
    auto fused_instruction = *fused_it;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      TF_CHECK_OK(
          fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
              fused_instruction->parameter_number())));
    } else {
      unfused_instructions.push_back(fused_instruction);
    }
  }

  // If there are no unfused instructions, the fused computation must consist
  // only of kParameter instructions. Make the operand of the corresponding
  // parameter number the new root.
  HloInstruction* unfused_root =
      unfused_instructions.empty()
          ? instruction_to_merge->mutable_operand(
                instruction_to_merge->fused_instructions_computation()
                    ->root_instruction()
                    ->parameter_number())
          : unfused_instructions.front();
  // Otherwise the first instruction in reverse post order must be the clone's
  // fused root.
  CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
        unfused_instructions.empty());
  // Make 'this' use unfused_root wherever it used instruction_to_merge.
  TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));

  // Build a dummy root for the cloned fusion as we may remove the original root
  // in the fusion process.
  if (!unfused_instructions.empty()) {
    HloComputation* computation = unfused_root->parent();
    auto* dummy_root = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
    computation->set_root_instruction(dummy_root,
                                      /*accept_different_shape=*/true);
  }

  // Fuse 'unfused_instructions' into 'this'. Every time we fuse an instruction
  // we remove it from the cloned fusion node. This is so that we don't add
  // extra users to the producer of that instruction (we use user count to
  // decide if a side-effectful instruction is fusible).
  for (auto& instruction : unfused_instructions) {
    auto* fused = FuseInstruction(instruction);
    TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
    TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
  }
  // The clone must be fully consumed before its computation is dropped from
  // the module (parent()->parent()).
  CHECK_EQ(0, cloned_fusion->user_count());
  TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
      cloned_fusion->fused_instructions_computation()));
}
1662
// Merges `instruction_to_merge` into this fusion, turning this fusion into a
// multi-output fusion that also produces `instruction_to_merge`'s result.
//
// The fused instructions of `instruction_to_merge` are first copied into this
// fusion's parent computation and rewired among themselves and to
// `instruction_to_merge`'s operands. Users of `instruction_to_merge` are
// redirected to the copied root, and `instruction_to_merge` plus its fused
// computation are removed. The copied root is then fused as an extra output,
// and the remaining copies are fused normally.
void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
    HloFusionInstruction* instruction_to_merge) {
  // Add all non-parameter fused instructions to 'unfused_instructions' to be
  // merged into 'this'. `old_to_new' maps the instructions in the fused node
  // to the disassembled fusion instructions.
  // Note that we add the unfused instructions to this->parent_ computation.
  // This is necessary because an instruction needs a unique_id, which is only
  // assigned when it is inserted into a computation.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
  std::vector<HloInstruction*> unfused_instructions;
  auto computation_to_merge =
      instruction_to_merge->fused_instructions_computation();
  auto post_order = computation_to_merge->MakeInstructionPostOrder();
  for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
    auto fused_instruction = *rit;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameters map to the corresponding operand of the merged fusion.
      InsertOrDie(&old_to_new, fused_instruction,
                  instruction_to_merge->mutable_operand(
                      fused_instruction->parameter_number()));
      continue;
    }

    // Here we clone the instruction and call FuseInstructionIntoMultiOutput()
    // which clones again. This can be improved.
    auto cloned_instruction =
        parent()->AddInstruction(fused_instruction->Clone());
    unfused_instructions.push_back(cloned_instruction);
    InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
  }
  // The copies still point at the old fused instructions; rewire every
  // operand through old_to_new now that the map is complete.
  for (auto unfused_instruction : unfused_instructions) {
    for (int64_t index = 0; index < unfused_instruction->operand_count();
         index++) {
      auto new_operand =
          FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
      TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
    }
  }

  // If there are no unfused instructions, the fused computation must consist
  // only of kParameter instructions. Make the operand of the corresponding
  // parameter number the new root.
  HloInstruction* unfused_root =
      unfused_instructions.empty()
          ? instruction_to_merge->mutable_operand(
                instruction_to_merge->fused_instructions_computation()
                    ->root_instruction()
                    ->parameter_number())
          : unfused_instructions.front();
  TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));

  TF_CHECK_OK(
      instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
  if (GetModule()) {
    TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
  }

  // Fuse the root instruction and generate multiple outputs.
  if (unfused_instructions.empty()) {
    return;
  }
  FuseInstructionIntoMultiOutput(unfused_root);
  TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
  // The remaining instructions are fused normally (single output).
  for (int64_t i = 1; i < unfused_instructions.size(); i++) {
    auto instruction = unfused_instructions[i];
    FuseInstruction(instruction);
    TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
  }
}
1732
fused_instructions_computation() const1733 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
1734 CHECK(!called_computations().empty());
1735 auto* fused_instructions_computation = called_computations().front();
1736 CHECK(fused_instructions_computation->IsFusionComputation())
1737 << "Computation " << fused_instructions_computation->name()
1738 << " is not a fusion kind";
1739 return fused_instructions_computation;
1740 }
1741
fused_expression_root() const1742 HloInstruction* HloFusionInstruction::fused_expression_root() const {
1743 return fused_instructions_computation()->root_instruction();
1744 }
1745
fused_parameter(int64_t parameter_number) const1746 HloInstruction* HloFusionInstruction::fused_parameter(
1747 int64_t parameter_number) const {
1748 return fused_instructions_computation()->parameter_instruction(
1749 parameter_number);
1750 }
1751
fused_parameters() const1752 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
1753 const {
1754 return fused_instructions_computation()->parameter_instructions();
1755 }
1756
1757 const tensorflow::gtl::iterator_range<UnwrappingIterator<
1758 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const1759 HloFusionInstruction::fused_instructions() const {
1760 const HloComputation* subcomp = fused_instructions_computation();
1761 return subcomp->instructions();
1762 }
1763
1764 const tensorflow::gtl::iterator_range<
1765 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()1766 HloFusionInstruction::fused_instructions() {
1767 return fused_instructions_computation()->instructions();
1768 }
1769
fused_instruction_count() const1770 int64 HloFusionInstruction::fused_instruction_count() const {
1771 return fused_instructions_computation()->instruction_count();
1772 }
1773
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)1774 HloInstruction* HloFusionInstruction::FuseInstructionInternal(
1775 HloInstruction* instruction_to_fuse, bool add_output) {
1776 // When add_output is false, this fusion instruction must be a user of
1777 // instruction_to_fuse.
1778 if (!add_output) {
1779 CHECK(IsUserOf(instruction_to_fuse));
1780 }
1781 HloInstruction* fused_instruction =
1782 CloneAndFuseInternal(instruction_to_fuse, add_output);
1783 return fused_instruction;
1784 }
1785
// Clones `instruction_to_fuse` into this fusion's fused computation and
// returns the clone. When `instruction_to_fuse` is a kTuple being added as an
// output, the tuple itself is returned instead of a clone.
//
// - With no fused computation yet, a computation named "fused_computation" is
//   built around the clone and registered with the module (add_output must be
//   false).
// - Otherwise, if `instruction_to_fuse` is an operand of this fusion, that
//   operand and its fused parameter are removed, and the parameter's uses are
//   redirected to the clone. If `instruction_to_fuse` then has no other users,
//   add_output is forced to false.
// - The clone's operands are then rewired to fused parameters, adding new
//   fusion operands for any that are not already operands.
// - If add_output is still true, the fused root becomes a tuple that also
//   exposes the clone. This fusion's shape is updated, and outside users are
//   redirected to get-tuple-element instructions on `this`.
HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
    HloInstruction* instruction_to_fuse, bool add_output) {
  CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
  VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
  HloInstruction* clone = nullptr;
  if (called_computations().empty()) {
    // New fusion instruction. It should not be a multioutput instruction.
    CHECK(!add_output);
    auto builder = HloComputation::Builder("fused_computation", this);
    builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
    AppendComputation(
        CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
    clone = fused_expression_root();
  } else {
    // When add_output is false, instruction_to_fuse is necessarily an operand
    // of the fusion instruction. After fusion this will no longer be the
    // case. Remove the operand from the operand list and remove its
    // corresponding fused parameter instruction. Parameter numbers must stay
    // consistent with their index in the fused computation's parameter list
    // (presumably RemoveParameter renumbers them).
    bool in_operand_list =
        absl::c_linear_search(operands(), instruction_to_fuse);
    CHECK(add_output || in_operand_list);
    if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      // We assume all uses of a kTuple operation are GTE ops, not another
      // fusion node. In this case, we don't need to clone
      // 'instruction_to_fuse'. (Together with the CHECK above, this path
      // implies add_output is true.)
      CHECK(!in_operand_list);
      clone = instruction_to_fuse;
    } else {
      clone = fused_instructions_computation()->AddInstruction(
          instruction_to_fuse->Clone(/*suffix=*/""));
    }
    const std::vector<HloInstruction*>& fused_parameters =
        fused_instructions_computation()->parameter_instructions();
    for (int64_t operand_num = 0; operand_num < operand_count();
         ++operand_num) {
      if (instruction_to_fuse == operand(operand_num)) {
        // Replace the fused parameter instruction's uses with the clone.
        HloInstruction* fused_parameter = fused_parameters[operand_num];
        TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));

        // Remove the corresponding fused parameter and operand from their
        // respective vectors.
        TF_CHECK_OK(
            fused_instructions_computation()->RemoveParameter(operand_num));
        RemoveOperandAt(operand_num);
        break;
      }
    }
    // We've cloned instruction_to_fuse into this fusion instruction, so this
    // fusion instruction is no longer a use of instruction_to_fuse.
    if (in_operand_list) {
      DetachFrom(instruction_to_fuse);
      // When the instruction_to_fuse does not have other users, we don't need
      // to generate a multioutput fusion instruction.
      if (instruction_to_fuse->user_count() == 0) {
        add_output = false;
      }
    }
  }

  // Reread the parameters in the computation (the list may have changed above).
  const std::vector<HloInstruction*>& fused_parameters =
      fused_instructions_computation()->parameter_instructions();

  // Add each operand of the clone as an operand of the fusion instruction. A
  // complication is that some clone operands may already be operands of the
  // fusion instruction.
  for (int64_t operand_num = 0; operand_num < clone->operand_count();
       ++operand_num) {
    HloInstruction* operand = clone->mutable_operand(operand_num);

    // See if this operand is already an operand of the fusion node.
    CHECK_EQ(operands().size(), fused_parameters.size());
    HloInstruction* fused_param = nullptr;
    for (int64_t i = 0; i < operands().size(); ++i) {
      if (this->operand(i) == operand) {
        fused_param = fused_parameters[i];
        break;
      }
    }

    if (fused_param == nullptr) {
      // Clone's operand was not already an operand of the fusion
      // instruction. Add it as an operand and add a corresponding fused
      // parameter instruction.
      fused_param = AddFusionOperand(operand);
    }
    TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
  }

  if (add_output) {
    CHECK_GT(instruction_to_fuse->user_count(), 0);
    // If this is already a multioutput fusion instruction, expand the root
    // tuple by 1.
    HloInstruction* fused_root = fused_expression_root();
    HloInstruction::InstructionVector tuple_elements;
    bool newly_created_tuple_instr = false;
    if (fused_root->opcode() == HloOpcode::kTuple) {
      tuple_elements = fused_root->operands();
    } else {
      tuple_elements.push_back(fused_root);
      newly_created_tuple_instr = true;
    }
    // A fused kTuple contributes its elements individually rather than as a
    // nested tuple.
    if (clone->opcode() == HloOpcode::kTuple) {
      for (auto inst : clone->operands()) {
        tuple_elements.push_back(inst);
      }
    } else {
      tuple_elements.push_back(clone);
    }
    HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
        HloInstruction::CreateTuple(tuple_elements));
    fused_instructions_computation()->set_root_instruction(new_root);
    *mutable_shape() = new_root->shape();
    if (fused_root->opcode() == HloOpcode::kTuple) {
      TF_CHECK_OK(
          fused_instructions_computation()->RemoveInstruction(fused_root));
    }

    // If this is a newly created multioutput instruction, we need to update
    // the use of the original fusion instruction.
    if (newly_created_tuple_instr) {
      HloInstruction* new_instr = parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
      TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
    }
    // Index of the first new output in the root tuple is computed from the
    // tuple size: for a tuple, subtract its element count; otherwise the new
    // output is the last element (index - 1).
    int64_t index = tuple_elements.size();
    if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      CHECK_EQ(clone, instruction_to_fuse);
      index -= clone->operand_count();
      std::vector<HloInstruction*> to_be_removed;
      for (auto old_gte : clone->users()) {
        CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
        int64_t old_tuple_index = old_gte->tuple_index();
        HloInstruction* new_gte =
            parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_gte->shape(), this, index + old_tuple_index));
        TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
        to_be_removed.push_back(old_gte);
      }
      // Removal is deferred so clone->users() is not mutated while iterating.
      for (auto old_gte : to_be_removed) {
        TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
      }
    } else {
      HloInstruction* new_gte =
          parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
              clone->shape(), this, index - 1));
      TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
    }
  }

  if (clone != instruction_to_fuse) {
    VLOG(2) << "New clone:\n" << clone->ToString();
  }
  return clone;
}
1944
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1945 std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
1946 const HloPrintOptions& options) const {
1947 return {StrCat("kind=", xla::ToString(fusion_kind()))};
1948 }
1949
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1950 bool HloFusionInstruction::IdenticalSlowPath(
1951 const HloInstruction& other,
1952 const std::function<bool(const HloComputation*, const HloComputation*)>&
1953 eq_computations) const {
1954 return fusion_kind() == other.fusion_kind() &&
1955 eq_computations(fused_instructions_computation(),
1956 other.fused_instructions_computation());
1957 }
1958
InnerHash() const1959 uint64 HloFusionInstruction::InnerHash() const {
1960 return fused_instructions_computation()->root_instruction()->Hash();
1961 }
1962
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1963 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
1964 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1965 HloCloneContext* context) const {
1966 HloModule* module = context != nullptr ? context->module() : GetModule();
1967 HloComputation* new_fused_computation = nullptr;
1968 if (context != nullptr) {
1969 new_fused_computation =
1970 context->FindComputation(fused_instructions_computation());
1971 }
1972 if (new_fused_computation == nullptr) {
1973 new_fused_computation = module->AddEmbeddedComputation(
1974 fused_instructions_computation()->Clone("clone", context));
1975 }
1976 return absl::make_unique<HloFusionInstruction>(
1977 shape, fusion_kind(), new_operands, new_fused_computation);
1978 }
1979
DeduplicateFusionOperands()1980 Status HloFusionInstruction::DeduplicateFusionOperands() {
1981 if (IsCustomFusion()) {
1982 return Status::OK();
1983 }
1984 absl::flat_hash_map<const HloInstruction*, int> operand_indices;
1985 std::vector<int> operands_to_remove;
1986 for (int i = 0; i < operand_count(); ++i) {
1987 auto emplace_result = operand_indices.emplace(operand(i), i);
1988 if (!emplace_result.second) {
1989 TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
1990 fused_parameter(emplace_result.first->second)));
1991 operands_to_remove.push_back(i);
1992 }
1993 }
1994 if (operands_to_remove.empty()) {
1995 return Status::OK();
1996 }
1997 TF_RETURN_IF_ERROR(fused_instructions_computation()
1998 ->RemoveUnusedParametersFromFusedComputation());
1999 RemoveOperandsAtAscendingIndices(operands_to_remove);
2000 return Status::OK();
2001 }
2002
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)2003 HloRngInstruction::HloRngInstruction(
2004 const Shape& shape, RandomDistribution distribution,
2005 absl::Span<HloInstruction* const> parameters)
2006 : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
2007 for (HloInstruction* param : parameters) {
2008 AppendOperand(param);
2009 }
2010 }
2011
ToProto() const2012 HloInstructionProto HloRngInstruction::ToProto() const {
2013 HloInstructionProto proto = HloInstruction::ToProto();
2014 proto.set_distribution(distribution_);
2015 return proto;
2016 }
2017
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2018 std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
2019 const HloPrintOptions& options) const {
2020 return {StrCat("distribution=", RandomDistributionToString(distribution_))};
2021 }
2022
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const2023 bool HloRngInstruction::IsElementwiseImpl(
2024 const absl::optional<int64>& operand_idx) const {
2025 return true;
2026 }
2027
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2028 bool HloRngInstruction::IdenticalSlowPath(
2029 const HloInstruction& other,
2030 const std::function<bool(const HloComputation*, const HloComputation*)>&
2031 eq_computations) const {
2032 const auto& casted_other = static_cast<const HloRngInstruction&>(other);
2033 return distribution_ == casted_other.distribution_;
2034 }
2035
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2036 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
2037 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2038 HloCloneContext* context) const {
2039 return absl::make_unique<HloRngInstruction>(shape, distribution_,
2040 new_operands);
2041 }
2042
HloParameterInstruction(int64_t parameter_number,const Shape & shape,const string & name)2043 HloParameterInstruction::HloParameterInstruction(int64_t parameter_number,
2044 const Shape& shape,
2045 const string& name)
2046 : HloInstruction(HloOpcode::kParameter, shape),
2047 parameter_number_(parameter_number) {
2048 SetAndSanitizeName(name);
2049 }
2050
ToProto() const2051 HloInstructionProto HloParameterInstruction::ToProto() const {
2052 HloInstructionProto proto = HloInstruction::ToProto();
2053 proto.set_parameter_number(parameter_number_);
2054 if (parameter_replicated_at_leaf_buffers_) {
2055 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2056 proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
2057 replicated);
2058 }
2059 }
2060 return proto;
2061 }
2062
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2063 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
2064 const HloPrintOptions& options) const {
2065 std::vector<string> result;
2066 if (!parameter_replicated_at_leaf_buffers_) {
2067 return result;
2068 }
2069 std::vector<string> buffers_replicated_strs;
2070 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2071 buffers_replicated_strs.push_back(replicated ? "true" : "false");
2072 }
2073 if (options.print_ids()) {
2074 result.push_back(StrCat("parameter_replication={",
2075 StrJoin(buffers_replicated_strs, ","), "}"));
2076 }
2077 return result;
2078 }
2079
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2080 string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
2081 const HloPrintOptions& options,
2082 CanonicalNameMap* canonical_name_map) const {
2083 return StrCat(parameter_number_);
2084 }
2085
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2086 bool HloParameterInstruction::IdenticalSlowPath(
2087 const HloInstruction& other,
2088 const std::function<bool(const HloComputation*, const HloComputation*)>&
2089 eq_computations) const {
2090 const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
2091 return parameter_number() == casted_other.parameter_number();
2092 }
2093
2094 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2095 HloParameterInstruction::CloneWithNewOperandsImpl(
2096 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2097 HloCloneContext* context) const {
2098 auto clone = absl::make_unique<HloParameterInstruction>(parameter_number_,
2099 shape, name());
2100 if (parameter_replicated_at_leaf_buffers_ &&
2101 ShapeUtil::Equal(shape, this->shape())) {
2102 clone->set_parameter_replicated_at_leaf_buffers(
2103 *parameter_replicated_at_leaf_buffers_);
2104 }
2105 return clone;
2106 }
2107
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64_t index)2108 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
2109 const Shape& shape, HloInstruction* operand, int64_t index)
2110 : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
2111 AppendOperand(operand);
2112 }
2113
ToProto() const2114 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
2115 HloInstructionProto proto = HloInstruction::ToProto();
2116 proto.set_tuple_index(tuple_index_);
2117 return proto;
2118 }
2119
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2120 std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
2121 const HloPrintOptions& options) const {
2122 return {StrCat("index=", tuple_index())};
2123 }
2124
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2125 bool HloGetTupleElementInstruction::IdenticalSlowPath(
2126 const HloInstruction& other,
2127 const std::function<bool(const HloComputation*, const HloComputation*)>&
2128 eq_computations) const {
2129 const auto& casted_other =
2130 static_cast<const HloGetTupleElementInstruction&>(other);
2131 return tuple_index() == casted_other.tuple_index();
2132 }
2133
2134 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2135 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
2136 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2137 HloCloneContext* context) const {
2138 CHECK_EQ(new_operands.size(), 1);
2139 return absl::make_unique<HloGetTupleElementInstruction>(
2140 shape, new_operands[0], tuple_index());
2141 }
2142
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)2143 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
2144 const Shape& shape, HloInstruction* operand, const int exponent_bits,
2145 const int mantissa_bits)
2146 : HloInstruction(HloOpcode::kReducePrecision, shape),
2147 exponent_bits_(exponent_bits),
2148 mantissa_bits_(mantissa_bits) {
2149 AppendOperand(operand);
2150 }
2151
ToProto() const2152 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
2153 HloInstructionProto proto = HloInstruction::ToProto();
2154 proto.set_exponent_bits(exponent_bits_);
2155 proto.set_mantissa_bits(mantissa_bits_);
2156 return proto;
2157 }
2158
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2159 std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
2160 const HloPrintOptions& options) const {
2161 return {StrCat("exponent_bits=", exponent_bits_),
2162 StrCat("mantissa_bits=", mantissa_bits_)};
2163 }
2164
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2165 bool HloReducePrecisionInstruction::IdenticalSlowPath(
2166 const HloInstruction& other,
2167 const std::function<bool(const HloComputation*, const HloComputation*)>&
2168 eq_computations) const {
2169 const auto& casted_other =
2170 static_cast<const HloReducePrecisionInstruction&>(other);
2171 // A reduce-precision operation is determined by the bit sizes.
2172 return exponent_bits() == casted_other.exponent_bits() &&
2173 mantissa_bits() == casted_other.mantissa_bits();
2174 }
2175
2176 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2177 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
2178 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2179 HloCloneContext* context) const {
2180 CHECK_EQ(new_operands.size(), 1);
2181 return absl::make_unique<HloReducePrecisionInstruction>(
2182 shape, new_operands[0], exponent_bits(), mantissa_bits());
2183 }
2184
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)2185 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
2186 HloInstruction* token_operand,
2187 const string& config)
2188 : HloInstruction(HloOpcode::kInfeed,
2189 ShapeUtil::MakeTupleShape(
2190 {infeed_shape, ShapeUtil::MakeTokenShape()})),
2191 infeed_config_(config) {
2192 AppendOperand(token_operand);
2193 }
2194
ToProto() const2195 HloInstructionProto HloInfeedInstruction::ToProto() const {
2196 HloInstructionProto proto = HloInstruction::ToProto();
2197 proto.set_infeed_config(infeed_config_);
2198 return proto;
2199 }
2200
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2201 std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
2202 const HloPrintOptions& options) const {
2203 if (!options.print_infeed_outfeed_config() || infeed_config_.empty()) {
2204 return {};
2205 }
2206 return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
2207 }
2208
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2209 bool HloInfeedInstruction::IdenticalSlowPath(
2210 const HloInstruction& other,
2211 const std::function<bool(const HloComputation*, const HloComputation*)>&
2212 eq_computations) const {
2213 // Not yet supported.
2214 return false;
2215 }
2216
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2217 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
2218 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2219 HloCloneContext* context) const {
2220 CHECK_EQ(new_operands.size(), 1);
2221 return absl::make_unique<HloInfeedInstruction>(
2222 infeed_shape(), new_operands[0], infeed_config());
2223 }
2224
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)2225 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
2226 HloInstruction* operand,
2227 HloInstruction* token_operand,
2228 absl::string_view outfeed_config)
2229 : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
2230 outfeed_shape_(outfeed_shape),
2231 outfeed_config_(outfeed_config) {
2232 AppendOperand(operand);
2233 AppendOperand(token_operand);
2234 }
2235
ToProto() const2236 HloInstructionProto HloOutfeedInstruction::ToProto() const {
2237 HloInstructionProto proto = HloInstruction::ToProto();
2238 proto.set_outfeed_config(outfeed_config());
2239 *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
2240 return proto;
2241 }
2242
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2243 std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
2244 const HloPrintOptions& options) const {
2245 std::vector<string> extra;
2246 extra.push_back(StrCat("outfeed_shape=",
2247 ShapeUtil::HumanStringWithLayout(outfeed_shape_)));
2248 if (options.print_infeed_outfeed_config() && !outfeed_config_.empty()) {
2249 extra.push_back(
2250 StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
2251 }
2252 return extra;
2253 }
2254
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2255 bool HloOutfeedInstruction::IdenticalSlowPath(
2256 const HloInstruction& other,
2257 const std::function<bool(const HloComputation*, const HloComputation*)>&
2258 eq_computations) const {
2259 // Not yet supported.
2260 return false;
2261 }
2262
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2263 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
2264 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2265 HloCloneContext* context) const {
2266 CHECK_EQ(new_operands.size(), 2);
2267 return absl::make_unique<HloOutfeedInstruction>(
2268 outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
2269 }
2270
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2271 HloConvolutionInstruction::HloConvolutionInstruction(
2272 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2273 int64_t feature_group_count, int64_t batch_group_count,
2274 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
2275 const PrecisionConfig& precision_config)
2276 : HloInstruction(HloOpcode::kConvolution, shape),
2277 feature_group_count_(feature_group_count),
2278 batch_group_count_(batch_group_count),
2279 window_(window),
2280 convolution_dimension_numbers_(dimension_numbers),
2281 precision_config_(precision_config) {
2282 if (window_util::HasBaseDilation(window)) {
2283 SetAndSanitizeName(StrCat(name(), "-base-dilated"));
2284 }
2285 if (window_util::HasWindowDilation(window)) {
2286 SetAndSanitizeName(StrCat(name(), "-window-dilated"));
2287 }
2288 AppendOperand(lhs);
2289 AppendOperand(rhs);
2290 }
2291
ToCategory() const2292 string HloConvolutionInstruction::ToCategory() const {
2293 string category = "convolution";
2294 if (window_util::HasBaseDilation(window())) {
2295 category += " base-dilated";
2296 }
2297 if (window_util::HasWindowDilation(window())) {
2298 category += " window-dilated";
2299 }
2300 return category;
2301 }
2302
ToProto() const2303 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2304 HloInstructionProto proto = HloInstruction::ToProto();
2305 *proto.mutable_window() = window_;
2306 *proto.mutable_convolution_dimension_numbers() =
2307 convolution_dimension_numbers_;
2308 proto.set_feature_group_count(feature_group_count_);
2309 proto.set_batch_group_count(batch_group_count_);
2310 *proto.mutable_precision_config() = precision_config_;
2311 return proto;
2312 }
2313
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2314 std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2315 const HloPrintOptions& options) const {
2316 std::vector<string> extra;
2317 if (window_.dimensions_size() != 0) {
2318 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2319 }
2320 extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2321 convolution_dimension_numbers_)));
2322 if (feature_group_count_ != 1) {
2323 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2324 }
2325
2326 if (batch_group_count_ != 1) {
2327 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2328 }
2329
2330 string precision_config_string = PrecisionConfigToString(precision_config_);
2331 if (!precision_config_string.empty()) {
2332 extra.push_back(precision_config_string);
2333 }
2334 return extra;
2335 }
2336
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2337 bool HloConvolutionInstruction::IdenticalSlowPath(
2338 const HloInstruction& other,
2339 const std::function<bool(const HloComputation*, const HloComputation*)>&
2340 eq_computations) const {
2341 const auto& casted_other =
2342 static_cast<const HloConvolutionInstruction&>(other);
2343 if (feature_group_count_ != other.feature_group_count()) {
2344 return false;
2345 }
2346 if (batch_group_count_ != other.batch_group_count()) {
2347 return false;
2348 }
2349 return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2350 protobuf_util::ProtobufEquals(
2351 convolution_dimension_numbers(),
2352 casted_other.convolution_dimension_numbers()) &&
2353 protobuf_util::ProtobufEquals(precision_config(),
2354 casted_other.precision_config());
2355 }
2356
2357 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2358 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2359 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2360 HloCloneContext* context) const {
2361 CHECK_EQ(new_operands.size(), 2);
2362 return absl::make_unique<HloConvolutionInstruction>(
2363 shape, new_operands[0], new_operands[1], feature_group_count_,
2364 batch_group_count_, window(), convolution_dimension_numbers_,
2365 precision_config_);
2366 }
2367
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2368 HloReduceWindowInstruction::HloReduceWindowInstruction(
2369 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2370 const Window& window, HloComputation* reduce_computation)
2371 : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
2372 absl::MakeSpan(&init_value, 1), window,
2373 reduce_computation) {}
2374
HloReduceWindowInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)2375 HloReduceWindowInstruction::HloReduceWindowInstruction(
2376 const Shape& shape, absl::Span<HloInstruction* const> operands,
2377 absl::Span<HloInstruction* const> init_values, const Window& window,
2378 HloComputation* reduce_computation)
2379 : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2380 for (auto* operand : operands) {
2381 AppendOperand(operand);
2382 }
2383 for (auto* init_value : init_values) {
2384 AppendOperand(init_value);
2385 }
2386 AppendComputation(reduce_computation);
2387 }
2388
ToProto() const2389 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2390 HloInstructionProto proto = HloInstruction::ToProto();
2391 *proto.mutable_window() = window_;
2392 return proto;
2393 }
2394
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2395 std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2396 const HloPrintOptions& options) const {
2397 std::vector<string> extra;
2398 if (window_.dimensions_size() != 0) {
2399 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2400 }
2401 return extra;
2402 }
2403
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2404 bool HloReduceWindowInstruction::IdenticalSlowPath(
2405 const HloInstruction& other,
2406 const std::function<bool(const HloComputation*, const HloComputation*)>&
2407 eq_computations) const {
2408 const auto& casted_other =
2409 static_cast<const HloReduceWindowInstruction&>(other);
2410 return eq_computations(to_apply(), casted_other.to_apply()) &&
2411 protobuf_util::ProtobufEquals(window(), casted_other.window());
2412 }
2413
2414 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2415 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2416 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2417 HloCloneContext* context) const {
2418 CHECK_EQ(new_operands.size() % 2, 0);
2419 int64_t num_operands = new_operands.size() / 2;
2420 return absl::make_unique<HloReduceWindowInstruction>(
2421 shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
2422 absl::MakeSpan(new_operands)
2423 .subspan(num_operands, new_operands.size() / 2),
2424 window(), to_apply());
2425 }
2426
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2427 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2428 const Shape& shape, HloInstruction* operand, HloComputation* select,
2429 const Window& window, HloInstruction* source, HloInstruction* init_value,
2430 HloComputation* scatter)
2431 : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2432 AppendOperand(operand);
2433 AppendOperand(source);
2434 AppendOperand(init_value);
2435 // Select comes before scatter in the vector.
2436 AppendComputation(select);
2437 AppendComputation(scatter);
2438 }
2439
ToProto() const2440 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2441 HloInstructionProto proto = HloInstruction::ToProto();
2442 *proto.mutable_window() = window_;
2443 return proto;
2444 }
2445
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2446 std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2447 const HloPrintOptions& options) const {
2448 std::vector<string> extra;
2449 if (window_.dimensions_size() != 0) {
2450 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2451 }
2452 return extra;
2453 }
2454
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2455 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2456 const HloInstruction& other,
2457 const std::function<bool(const HloComputation*, const HloComputation*)>&
2458 eq_computations) const {
2459 const auto& casted_other =
2460 static_cast<const HloSelectAndScatterInstruction&>(other);
2461 return eq_computations(select(), casted_other.select()) &&
2462 eq_computations(scatter(), casted_other.scatter()) &&
2463 protobuf_util::ProtobufEquals(window(), casted_other.window());
2464 }
2465
2466 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2467 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2468 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2469 HloCloneContext* context) const {
2470 CHECK_EQ(new_operands.size(), 3);
2471 return absl::make_unique<HloSelectAndScatterInstruction>(
2472 shape, new_operands[0], select(), window(), new_operands[1],
2473 new_operands[2], scatter());
2474 }
2475
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)2476 HloCustomCallInstruction::HloCustomCallInstruction(
2477 const Shape& shape, absl::Span<HloInstruction* const> operands,
2478 absl::string_view custom_call_target, string opaque,
2479 CustomCallApiVersion api_version)
2480 : HloInstruction(HloOpcode::kCustomCall, shape),
2481 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2482 feature_group_count_(1),
2483 batch_group_count_(1),
2484 layout_constrained_(false),
2485 padding_type_(PaddingType::PADDING_INVALID),
2486 custom_call_has_side_effect_(false),
2487 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2488 api_version_(api_version) {
2489 set_raw_backend_config_string(std::move(opaque));
2490 for (auto operand : operands) {
2491 AppendOperand(operand);
2492 }
2493 }
2494
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)2495 HloCustomCallInstruction::HloCustomCallInstruction(
2496 const Shape& shape, absl::Span<HloInstruction* const> operands,
2497 HloComputation* to_apply, absl::string_view custom_call_target,
2498 string opaque, CustomCallApiVersion api_version)
2499 : HloInstruction(HloOpcode::kCustomCall, shape),
2500 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2501 feature_group_count_(1),
2502 batch_group_count_(1),
2503 layout_constrained_(false),
2504 padding_type_(PaddingType::PADDING_INVALID),
2505 custom_call_has_side_effect_(false),
2506 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2507 api_version_(api_version) {
2508 set_raw_backend_config_string(std::move(opaque));
2509 for (auto operand : operands) {
2510 AppendOperand(operand);
2511 }
2512 AppendComputation(to_apply);
2513 to_apply->SetCustomCallInstruction(this);
2514 }
2515
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)2516 HloCustomCallInstruction::HloCustomCallInstruction(
2517 const Shape& shape, absl::Span<HloInstruction* const> operands,
2518 absl::Span<HloComputation* const> called_computations,
2519 absl::string_view custom_call_target, string opaque,
2520 CustomCallApiVersion api_version)
2521 : HloInstruction(HloOpcode::kCustomCall, shape),
2522 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2523 feature_group_count_(1),
2524 batch_group_count_(1),
2525 layout_constrained_(false),
2526 padding_type_(PaddingType::PADDING_INVALID),
2527 custom_call_has_side_effect_(false),
2528 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2529 api_version_(api_version) {
2530 set_raw_backend_config_string(std::move(opaque));
2531 for (auto operand : operands) {
2532 AppendOperand(operand);
2533 }
2534 for (auto comp : called_computations) {
2535 AppendComputation(comp);
2536 }
2537 }
2538
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,absl::Span<const Shape> operand_shapes_with_layout,CustomCallApiVersion api_version)2539 HloCustomCallInstruction::HloCustomCallInstruction(
2540 const Shape& shape, absl::Span<HloInstruction* const> operands,
2541 absl::string_view custom_call_target, string opaque,
2542 absl::Span<const Shape> operand_shapes_with_layout,
2543 CustomCallApiVersion api_version)
2544 : HloInstruction(HloOpcode::kCustomCall, shape),
2545 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2546 feature_group_count_(1),
2547 batch_group_count_(1),
2548 layout_constrained_(true),
2549 padding_type_(PaddingType::PADDING_INVALID),
2550 operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2551 operand_shapes_with_layout.end()),
2552 custom_call_has_side_effect_(false),
2553 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2554 api_version_(api_version) {
2555 set_raw_backend_config_string(std::move(opaque));
2556 for (auto operand : operands) {
2557 AppendOperand(operand);
2558 }
2559 }
2560
ToProto() const2561 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2562 HloInstructionProto proto = HloInstruction::ToProto();
2563 if (window_ != nullptr) {
2564 *proto.mutable_window() = *window_;
2565 }
2566 if (convolution_dimension_numbers_ != nullptr) {
2567 *proto.mutable_convolution_dimension_numbers() =
2568 *convolution_dimension_numbers_;
2569 }
2570 proto.set_custom_call_target(custom_call_target_);
2571 proto.set_feature_group_count(feature_group_count_);
2572 proto.set_batch_group_count(batch_group_count_);
2573 *proto.mutable_precision_config() = precision_config_;
2574 proto.set_padding_type(padding_type_);
2575 if (layout_constrained()) {
2576 proto.set_constrain_layout(true);
2577 for (const Shape& shape : operand_shapes_with_layout_) {
2578 *proto.add_operand_shapes_with_layout() = shape.ToProto();
2579 }
2580 }
2581 proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2582 if (literal_.has_value()) {
2583 *proto.mutable_literal() = literal_->ToProto();
2584 }
2585 for (const auto& pair : output_to_operand_aliasing_) {
2586 auto aliasing = proto.add_custom_call_output_operand_aliasing();
2587 aliasing->set_operand_index(pair.second.first);
2588 for (int64_t index : pair.first) {
2589 aliasing->add_output_shape_index(index);
2590 }
2591 for (int64_t index : pair.second.second) {
2592 aliasing->add_operand_shape_index(index);
2593 }
2594 }
2595 proto.set_custom_call_schedule(custom_call_schedule_);
2596 proto.set_custom_call_api_version(api_version_);
2597 return proto;
2598 }
2599
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2600 std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2601 const HloPrintOptions& options) const {
2602 std::vector<string> extra;
2603 if (window_ != nullptr) {
2604 extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2605 }
2606 if (convolution_dimension_numbers_ != nullptr) {
2607 extra.push_back(StrCat(
2608 "dim_labels=",
2609 ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2610 }
2611 if (feature_group_count_ != 1) {
2612 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2613 }
2614 if (batch_group_count_ != 1) {
2615 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2616 }
2617 string precision_config_string = PrecisionConfigToString(precision_config_);
2618 if (!precision_config_string.empty()) {
2619 extra.push_back(precision_config_string);
2620 }
2621 if (padding_type_ != PaddingType::PADDING_INVALID) {
2622 extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
2623 }
2624 // By contract, we print the custom call target even if
2625 // options.print_subcomputation_mode() == kOff, because the call target is not
2626 // an HloComputation.
2627 extra.push_back(
2628 StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2629
2630 if (layout_constrained()) {
2631 std::vector<string> shape_strings;
2632 for (const Shape& shape : operand_shapes_with_layout_) {
2633 shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2634 }
2635 extra.push_back(StrCat("operand_layout_constraints={",
2636 StrJoin(shape_strings, ", "), "}"));
2637 }
2638 if (custom_call_has_side_effect_) {
2639 extra.push_back("custom_call_has_side_effect=true");
2640 }
2641 if (literal_.has_value()) {
2642 extra.push_back(StrCat("literal=", literal_->ToStringWithLayoutOneline()));
2643 }
2644 if (!output_to_operand_aliasing_.empty()) {
2645 std::vector<string> pair_strings;
2646 for (const auto& pair : output_to_operand_aliasing_) {
2647 pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
2648 pair.second.first, ", ",
2649 pair.second.second.ToString(), ")"));
2650 }
2651 extra.push_back(StrCat("output_to_operand_aliasing={",
2652 StrJoin(pair_strings, ", "), "}"));
2653 }
2654 if (custom_call_schedule_ != CustomCallSchedule::SCHEDULE_NONE) {
2655 extra.push_back(
2656 StrCat("schedule=", CustomCallSchedule_Name(custom_call_schedule_)));
2657 }
2658 if (api_version_ != CustomCallApiVersion::API_VERSION_ORIGINAL) {
2659 extra.push_back(
2660 StrCat("api_version=", CustomCallApiVersion_Name(api_version_)));
2661 }
2662 return extra;
2663 }
2664
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2665 bool HloCustomCallInstruction::IdenticalSlowPath(
2666 const HloInstruction& other,
2667 const std::function<bool(const HloComputation*, const HloComputation*)>&
2668 eq_computations) const {
2669 const auto& casted_other =
2670 static_cast<const HloCustomCallInstruction&>(other);
2671 if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2672 (window_ != nullptr &&
2673 !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2674 return false;
2675 }
2676 if ((convolution_dimension_numbers_ == nullptr) !=
2677 (casted_other.convolution_dimension_numbers_ == nullptr) ||
2678 (convolution_dimension_numbers_ != nullptr &&
2679 !protobuf_util::ProtobufEquals(
2680 convolution_dimension_numbers(),
2681 casted_other.convolution_dimension_numbers()))) {
2682 return false;
2683 }
2684 if (feature_group_count_ != casted_other.feature_group_count_) {
2685 return false;
2686 }
2687 if (batch_group_count_ != casted_other.batch_group_count_) {
2688 return false;
2689 }
2690
2691 if (padding_type_ != casted_other.padding_type()) {
2692 return false;
2693 }
2694
2695 if (layout_constrained() != casted_other.layout_constrained()) {
2696 return false;
2697 }
2698 if (layout_constrained()) {
2699 for (int64_t i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2700 if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2701 casted_other.operand_shapes_with_layout_[i])) {
2702 return false;
2703 }
2704 }
2705 }
2706 if (custom_call_has_side_effect_ !=
2707 casted_other.custom_call_has_side_effect()) {
2708 return false;
2709 }
2710 if (output_to_operand_aliasing_ !=
2711 casted_other.output_to_operand_aliasing()) {
2712 return false;
2713 }
2714 if (!protobuf_util::ProtobufEquals(precision_config(),
2715 casted_other.precision_config())) {
2716 return false;
2717 }
2718
2719 if (called_computations().size() != other.called_computations().size()) {
2720 return false;
2721 }
2722 for (int64_t i = 0; i < called_computations().size(); ++i) {
2723 if (!eq_computations(called_computations()[i],
2724 other.called_computations()[i])) {
2725 return false;
2726 }
2727 }
2728 if (custom_call_schedule_ != casted_other.custom_call_schedule()) {
2729 return false;
2730 }
2731 if (HasLiteral() == casted_other.HasLiteral()) {
2732 if (HasLiteral() && literal() == casted_other.literal()) {
2733 return false;
2734 }
2735 } else {
2736 return true;
2737 }
2738 // Note: backend_config comparison is done in Identical, which is the
2739 // intended/exposed way to compare computations, and so not repeated here.
2740 return custom_call_target_ == casted_other.custom_call_target_;
2741 }
2742
2743 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2744 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2745 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2746 HloCloneContext* context) const {
2747 auto cloned = absl::make_unique<HloCustomCallInstruction>(
2748 shape, new_operands, called_computations(), custom_call_target(),
2749 opaque(), api_version_);
2750 if (layout_constrained()) {
2751 cloned->layout_constrained_ = true;
2752 cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2753 }
2754 if (window_ != nullptr) {
2755 cloned->set_window(*window_);
2756 }
2757 if (convolution_dimension_numbers_ != nullptr) {
2758 cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2759 }
2760 if (HasLiteral()) {
2761 cloned->set_literal(literal().Clone());
2762 }
2763 cloned->set_feature_group_count(feature_group_count_);
2764 cloned->set_batch_group_count(batch_group_count_);
2765 cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2766 cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
2767 cloned->set_padding_type(padding_type_);
2768 *cloned->mutable_precision_config() = precision_config();
2769 cloned->set_custom_call_schedule(custom_call_schedule_);
2770 return std::move(cloned);
2771 }
2772
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2773 HloPadInstruction::HloPadInstruction(const Shape& shape,
2774 HloInstruction* operand,
2775 HloInstruction* padding_value,
2776 const PaddingConfig& padding_config)
2777 : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2778 AppendOperand(operand);
2779 AppendOperand(padding_value);
2780 }
2781
ToProto() const2782 HloInstructionProto HloPadInstruction::ToProto() const {
2783 HloInstructionProto proto = HloInstruction::ToProto();
2784 *proto.mutable_padding_config() = padding_config_;
2785 return proto;
2786 }
2787
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2788 std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
2789 const HloPrintOptions& options) const {
2790 return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2791 }
2792
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2793 bool HloPadInstruction::IdenticalSlowPath(
2794 const HloInstruction& other,
2795 const std::function<bool(const HloComputation*, const HloComputation*)>&
2796 eq_computations) const {
2797 const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2798 return protobuf_util::ProtobufEquals(padding_config(),
2799 casted_other.padding_config());
2800 }
2801
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2802 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2803 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2804 HloCloneContext* context) const {
2805 CHECK_EQ(new_operands.size(), 2);
2806 return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
2807 new_operands[1], padding_config_);
2808 }
2809
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)2810 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2811 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2812 absl::Span<const int64> slice_sizes)
2813 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2814 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2815 AppendOperand(operand);
2816 AppendOperand(start_indices);
2817 }
2818
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)2819 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2820 const Shape& shape, HloInstruction* operand,
2821 absl::Span<HloInstruction* const> start_indices,
2822 absl::Span<const int64> slice_sizes)
2823 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2824 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2825 AppendOperand(operand);
2826 for (HloInstruction* index : start_indices) {
2827 AppendOperand(index);
2828 }
2829 }
2830
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2831 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2832 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2833 HloInstruction* start_indices)
2834 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2835 AppendOperand(operand);
2836 AppendOperand(update);
2837 AppendOperand(start_indices);
2838 }
2839
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2840 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2841 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2842 absl::Span<HloInstruction* const> start_indices)
2843 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2844 AppendOperand(operand);
2845 AppendOperand(update);
2846 for (HloInstruction* index : start_indices) {
2847 AppendOperand(index);
2848 }
2849 }
2850
ToProto() const2851 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
2852 HloInstructionProto proto = HloInstruction::ToProto();
2853 for (int64_t slice_size : dynamic_slice_sizes_) {
2854 proto.add_dynamic_slice_sizes(slice_size);
2855 }
2856 return proto;
2857 }
2858
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2859 std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
2860 const HloPrintOptions& options) const {
2861 return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
2862 "}")};
2863 }
2864
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2865 bool HloDynamicSliceInstruction::IdenticalSlowPath(
2866 const HloInstruction& other,
2867 const std::function<bool(const HloComputation*, const HloComputation*)>&
2868 eq_computations) const {
2869 const auto& casted_other = static_cast<const HloMapInstruction&>(other);
2870 return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
2871 }
2872
2873 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2874 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
2875 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2876 HloCloneContext* context) const {
2877 if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
2878 // TODO(b/118437727): Old form, remove this path.
2879 return absl::make_unique<HloDynamicSliceInstruction>(
2880 shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
2881 } else {
2882 return absl::make_unique<HloDynamicSliceInstruction>(
2883 shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
2884 }
2885 }
2886
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)2887 HloGatherInstruction::HloGatherInstruction(
2888 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2889 const GatherDimensionNumbers& gather_dim_numbers,
2890 absl::Span<const int64> slice_sizes, bool indices_are_sorted)
2891 : HloInstruction(HloOpcode::kGather, shape),
2892 indices_are_sorted_(indices_are_sorted) {
2893 AppendOperand(operand);
2894 AppendOperand(start_indices);
2895 gather_dimension_numbers_ =
2896 absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
2897 absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
2898 }
2899
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)2900 /*static*/ string HloGatherInstruction::GatherDimensionNumbersToString(
2901 const GatherDimensionNumbers& gather_dimension_numbers) {
2902 string offset_dims =
2903 StrCat("offset_dims={",
2904 StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
2905 string collapsed_slice_dims = StrCat(
2906 "collapsed_slice_dims={",
2907 StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
2908 string start_index_map =
2909 StrCat("start_index_map={",
2910 StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
2911 string index_vector_dim =
2912 StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
2913
2914 return StrJoin<std::initializer_list<string>>(
2915 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
2916 ", ");
2917 }
2918
MakeGatherDimNumbers(absl::Span<const int64> offset_dims,absl::Span<const int64> collapsed_slice_dims,absl::Span<const int64> start_index_map,int64_t index_vector_dim)2919 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
2920 absl::Span<const int64> offset_dims,
2921 absl::Span<const int64> collapsed_slice_dims,
2922 absl::Span<const int64> start_index_map, int64_t index_vector_dim) {
2923 GatherDimensionNumbers gather_dim_numbers;
2924 for (int64_t output_window_dim : offset_dims) {
2925 gather_dim_numbers.add_offset_dims(output_window_dim);
2926 }
2927 for (int64_t elided_window_dim : collapsed_slice_dims) {
2928 gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
2929 }
2930 for (int64_t gather_dim_to_input_dim : start_index_map) {
2931 gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
2932 }
2933
2934 gather_dim_numbers.set_index_vector_dim(index_vector_dim);
2935 return gather_dim_numbers;
2936 }
2937
ToProto() const2938 HloInstructionProto HloGatherInstruction::ToProto() const {
2939 HloInstructionProto proto = HloInstruction::ToProto();
2940 *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
2941 for (int64_t bound : gather_slice_sizes()) {
2942 proto.add_gather_slice_sizes(bound);
2943 }
2944 proto.set_indices_are_sorted(indices_are_sorted());
2945 return proto;
2946 }
2947
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2948 std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
2949 const HloPrintOptions& options) const {
2950 std::vector<string> attrs{
2951 GatherDimensionNumbersToString(gather_dimension_numbers()),
2952 StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
2953 if (indices_are_sorted()) {
2954 attrs.push_back("indices_are_sorted=true");
2955 }
2956 return attrs;
2957 }
2958
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2959 bool HloGatherInstruction::IdenticalSlowPath(
2960 const HloInstruction& other,
2961 const std::function<bool(const HloComputation*, const HloComputation*)>&
2962 eq_computations) const {
2963 const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
2964 return protobuf_util::ProtobufEquals(
2965 gather_dimension_numbers(),
2966 casted_other.gather_dimension_numbers()) &&
2967 gather_slice_sizes() == casted_other.gather_slice_sizes() &&
2968 indices_are_sorted() == casted_other.indices_are_sorted();
2969 }
2970
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2971 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
2972 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2973 HloCloneContext* context) const {
2974 CHECK_EQ(new_operands.size(), 2);
2975 return absl::make_unique<HloGatherInstruction>(
2976 shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
2977 gather_slice_sizes(), indices_are_sorted());
2978 }
2979
HloScatterInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)2980 HloScatterInstruction::HloScatterInstruction(
2981 const Shape& shape, HloInstruction* operand,
2982 HloInstruction* scatter_indices, HloInstruction* updates,
2983 HloComputation* update_computation,
2984 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
2985 bool unique_indices)
2986 : HloInstruction(HloOpcode::kScatter, shape),
2987 indices_are_sorted_(indices_are_sorted),
2988 unique_indices_(unique_indices) {
2989 AppendOperand(operand);
2990 AppendOperand(scatter_indices);
2991 AppendOperand(updates);
2992 AppendComputation(update_computation);
2993 scatter_dimension_numbers_ =
2994 absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
2995 }
2996
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)2997 /*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString(
2998 const ScatterDimensionNumbers& scatter_dimension_numbers) {
2999 string update_window_dims =
3000 StrCat("update_window_dims={",
3001 StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
3002 string inserted_window_dims = StrCat(
3003 "inserted_window_dims={",
3004 StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
3005 string scatter_dims_to_operand_dims = StrCat(
3006 "scatter_dims_to_operand_dims={",
3007 StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
3008 "}");
3009 string index_vector_dim =
3010 StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
3011
3012 return StrJoin<std::initializer_list<string>>(
3013 {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
3014 index_vector_dim},
3015 ", ");
3016 }
3017
3018 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64> update_window_dims,absl::Span<const int64> inserted_window_dims,absl::Span<const int64> scatter_dims_to_operand_dims,int64_t index_vector_dim)3019 HloScatterInstruction::MakeScatterDimNumbers(
3020 absl::Span<const int64> update_window_dims,
3021 absl::Span<const int64> inserted_window_dims,
3022 absl::Span<const int64> scatter_dims_to_operand_dims,
3023 int64_t index_vector_dim) {
3024 ScatterDimensionNumbers scatter_dim_numbers;
3025 for (int64_t update_window_dim : update_window_dims) {
3026 scatter_dim_numbers.add_update_window_dims(update_window_dim);
3027 }
3028 for (int64_t inserted_window_dim : inserted_window_dims) {
3029 scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
3030 }
3031 for (int64_t scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
3032 scatter_dim_numbers.add_scatter_dims_to_operand_dims(
3033 scatter_dim_to_operand_dim);
3034 }
3035 scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
3036 return scatter_dim_numbers;
3037 }
3038
ToProto() const3039 HloInstructionProto HloScatterInstruction::ToProto() const {
3040 HloInstructionProto proto = HloInstruction::ToProto();
3041 *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
3042 proto.set_indices_are_sorted(indices_are_sorted());
3043 proto.set_unique_indices(unique_indices());
3044 return proto;
3045 }
3046
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3047 std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
3048 const HloPrintOptions& options) const {
3049 std::vector<string> attrs{
3050 ScatterDimensionNumbersToString(scatter_dimension_numbers())};
3051 if (indices_are_sorted()) {
3052 attrs.push_back("indices_are_sorted=true");
3053 }
3054 if (unique_indices()) {
3055 attrs.push_back("unique_indices=true");
3056 }
3057 return attrs;
3058 }
3059
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3060 bool HloScatterInstruction::IdenticalSlowPath(
3061 const HloInstruction& other,
3062 const std::function<bool(const HloComputation*, const HloComputation*)>&
3063 eq_computations) const {
3064 const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
3065 return protobuf_util::ProtobufEquals(
3066 scatter_dimension_numbers(),
3067 casted_other.scatter_dimension_numbers()) &&
3068 eq_computations(to_apply(), casted_other.to_apply()) &&
3069 indices_are_sorted() == casted_other.indices_are_sorted() &&
3070 unique_indices() == casted_other.unique_indices();
3071 }
3072
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3073 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
3074 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3075 HloCloneContext* context) const {
3076 CHECK_EQ(new_operands.size(), 3);
3077 return absl::make_unique<HloScatterInstruction>(
3078 shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
3079 scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
3080 }
3081
HloIotaInstruction(const Shape & shape,int64_t iota_dimension)3082 HloIotaInstruction::HloIotaInstruction(const Shape& shape,
3083 int64_t iota_dimension)
3084 : HloInstruction(HloOpcode::kIota, shape),
3085 iota_dimension_(iota_dimension) {}
3086
ToProto() const3087 HloInstructionProto HloIotaInstruction::ToProto() const {
3088 HloInstructionProto proto = HloInstruction::ToProto();
3089 proto.add_dimensions(iota_dimension());
3090 return proto;
3091 }
3092
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3093 std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
3094 const HloPrintOptions& options) const {
3095 return {StrCat("iota_dimension=", iota_dimension())};
3096 }
3097
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3098 bool HloIotaInstruction::IdenticalSlowPath(
3099 const HloInstruction& other,
3100 const std::function<bool(const HloComputation*, const HloComputation*)>&
3101 eq_computations) const {
3102 const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
3103 return iota_dimension() == casted_other.iota_dimension();
3104 }
3105
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3106 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
3107 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3108 HloCloneContext* context) const {
3109 return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
3110 }
3111
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)3112 HloDotInstruction::HloDotInstruction(
3113 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
3114 const DotDimensionNumbers& dimension_numbers,
3115 const PrecisionConfig& precision_config)
3116 : HloInstruction(HloOpcode::kDot, shape),
3117 dot_dimension_numbers_(dimension_numbers),
3118 precision_config_(precision_config) {
3119 AppendOperand(lhs);
3120 AppendOperand(rhs);
3121 }
3122
ToProto() const3123 HloInstructionProto HloDotInstruction::ToProto() const {
3124 HloInstructionProto proto = HloInstruction::ToProto();
3125 *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
3126 *proto.mutable_precision_config() = precision_config_;
3127 return proto;
3128 }
3129
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3130 std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
3131 const HloPrintOptions& options) const {
3132 std::vector<string> extra = {DotDimensionNumbersToString()};
3133
3134 string precision_config_string = PrecisionConfigToString(precision_config_);
3135 if (!precision_config_string.empty()) {
3136 extra.push_back(precision_config_string);
3137 }
3138 return extra;
3139 }
3140
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3141 bool HloDotInstruction::IdenticalSlowPath(
3142 const HloInstruction& other,
3143 const std::function<bool(const HloComputation*, const HloComputation*)>&
3144 eq_computations) const {
3145 const auto& casted_other = static_cast<const HloDotInstruction&>(other);
3146 return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
3147 casted_other.dot_dimension_numbers()) &&
3148 protobuf_util::ProtobufEquals(precision_config(),
3149 casted_other.precision_config());
3150 }
3151
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3152 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
3153 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3154 HloCloneContext* context) const {
3155 CHECK_EQ(new_operands.size(), 2);
3156 return absl::make_unique<HloDotInstruction>(
3157 shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
3158 precision_config_);
3159 }
3160
DotDimensionNumbersToString() const3161 string HloDotInstruction::DotDimensionNumbersToString() const {
3162 std::vector<string> result;
3163 const DotDimensionNumbers& dnums = dot_dimension_numbers_;
3164 if (!dnums.lhs_batch_dimensions().empty()) {
3165 result.push_back(StrCat("lhs_batch_dims={",
3166 StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
3167 }
3168 result.push_back(StrCat("lhs_contracting_dims={",
3169 StrJoin(dnums.lhs_contracting_dimensions(), ","),
3170 "}"));
3171
3172 if (!dnums.rhs_batch_dimensions().empty()) {
3173 result.push_back(StrCat("rhs_batch_dims={",
3174 StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
3175 }
3176 result.push_back(StrCat("rhs_contracting_dims={",
3177 StrJoin(dnums.rhs_contracting_dimensions(), ","),
3178 "}"));
3179
3180 return StrJoin(result, ", ");
3181 }
3182
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)3183 HloDomainInstruction::HloDomainInstruction(
3184 const Shape& shape, HloInstruction* operand,
3185 std::unique_ptr<DomainMetadata> operand_side_metadata,
3186 std::unique_ptr<DomainMetadata> user_side_metadata)
3187 : HloInstruction(HloOpcode::kDomain, shape),
3188 operand_side_metadata_(std::move(operand_side_metadata)),
3189 user_side_metadata_(std::move(user_side_metadata)) {
3190 AppendOperand(operand);
3191 }
3192
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3193 std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
3194 const HloPrintOptions& options) const {
3195 if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
3196 return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
3197 "\", entry=", user_side_metadata_->ToString(),
3198 ", exit=", operand_side_metadata_->ToString(), "}")};
3199 }
3200 return {};
3201 }
3202
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3203 bool HloDomainInstruction::IdenticalSlowPath(
3204 const HloInstruction& other,
3205 const std::function<bool(const HloComputation*, const HloComputation*)>&
3206 eq_computations) const {
3207 const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
3208 return operand_side_metadata().Matches(
3209 casted_other.operand_side_metadata()) &&
3210 user_side_metadata().Matches(casted_other.user_side_metadata());
3211 }
3212
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3213 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
3214 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3215 HloCloneContext* context) const {
3216 CHECK_EQ(new_operands.size(), 1);
3217 return absl::make_unique<HloDomainInstruction>(
3218 shape, new_operands[0], operand_side_metadata_->Clone(),
3219 user_side_metadata_->Clone());
3220 }
3221
ToProto() const3222 HloInstructionProto HloDomainInstruction::ToProto() const {
3223 HloInstructionProto proto = HloInstruction::ToProto();
3224 auto operand_side_sharding =
3225 dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
3226 if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
3227 *proto.mutable_domain_entry_sharding() =
3228 operand_side_sharding->sharding()->ToProto();
3229 }
3230
3231 auto user_side_sharding =
3232 dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
3233 if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
3234 *proto.mutable_domain_exit_sharding() =
3235 user_side_sharding->sharding()->ToProto();
3236 }
3237
3238 return proto;
3239 }
3240
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64_t dimension)3241 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
3242 const Shape& shape, HloInstruction* operand, int64_t dimension)
3243 : HloInstruction(HloOpcode::kGetDimensionSize, shape),
3244 dimension_(dimension) {
3245 AppendOperand(operand);
3246 }
3247
ToProto() const3248 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
3249 HloInstructionProto proto = HloInstruction::ToProto();
3250 proto.add_dimensions(dimension());
3251 return proto;
3252 }
3253
ExtraAttributesToStringImpl(const HloPrintOptions &) const3254 std::vector<string> HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3255 const HloPrintOptions& /*options*/) const {
3256 return {StrCat("dimensions={", dimension(), "}")};
3257 }
3258
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3259 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
3260 const HloInstruction& other,
3261 const std::function<bool(const HloComputation*, const HloComputation*)>&
3262 /*eq_computations*/) const {
3263 const auto& casted_other =
3264 static_cast<const HloGetDimensionSizeInstruction&>(other);
3265 return dimension() == casted_other.dimension();
3266 }
3267
3268 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3269 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3270 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3271 HloCloneContext* /*context*/) const {
3272 if (new_operands.size() != 1) {
3273 LOG(FATAL) << "expects 1 operand";
3274 }
3275 return absl::make_unique<HloGetDimensionSizeInstruction>(
3276 shape, new_operands[0], dimension());
3277 }
3278
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)3279 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
3280 const Shape& shape, HloInstruction* operand, HloInstruction* val,
3281 int64_t dimension)
3282 : HloInstruction(HloOpcode::kSetDimensionSize, shape),
3283 dimension_(dimension) {
3284 AppendOperand(operand);
3285 AppendOperand(val);
3286 }
3287
ExtraAttributesToStringImpl(const HloPrintOptions &) const3288 std::vector<string> HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3289 const HloPrintOptions& /*options*/) const {
3290 return {StrCat("dimensions={", dimension(), "}")};
3291 }
3292
ToProto() const3293 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
3294 HloInstructionProto proto = HloInstruction::ToProto();
3295 proto.add_dimensions(dimension());
3296 return proto;
3297 }
3298
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3299 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
3300 const HloInstruction& other,
3301 const std::function<bool(const HloComputation*, const HloComputation*)>&
3302 /*eq_computations*/) const {
3303 const auto& casted_other =
3304 static_cast<const HloSetDimensionSizeInstruction&>(other);
3305 return dimension() == casted_other.dimension();
3306 }
3307
3308 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3309 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3310 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3311 HloCloneContext* /*context*/) const {
3312 if (new_operands.size() != 2) {
3313 LOG(FATAL) << "expects 2 operand";
3314 }
3315 return absl::make_unique<HloSetDimensionSizeInstruction>(
3316 shape, new_operands[0], new_operands[1], dimension());
3317 }
3318
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64_t delta)3319 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
3320 const Shape& shape, int64_t delta)
3321 : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
3322
ToProto() const3323 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
3324 HloInstructionProto proto = HloInstruction::ToProto();
3325 proto.set_delta(delta_);
3326 return proto;
3327 }
3328
3329 std::vector<string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3330 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
3331 const HloPrintOptions& /*options*/) const {
3332 return {StrCat("delta=", delta())};
3333 }
3334
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3335 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
3336 const HloInstruction& other,
3337 const std::function<bool(const HloComputation*, const HloComputation*)>&
3338 /*eq_computations*/) const {
3339 const auto& casted_other =
3340 static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
3341 return delta() == casted_other.delta();
3342 }
3343
3344 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3345 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
3346 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3347 HloCloneContext* /*context*/) const {
3348 if (!new_operands.empty()) {
3349 LOG(FATAL) << "expects 0 operand";
3350 }
3351 return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
3352 }
3353
HloRngBitGeneratorInstruction(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)3354 HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
3355 const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
3356 : HloInstruction(HloOpcode::kRngBitGenerator, shape),
3357 algorithm_(algorithm) {
3358 AppendOperand(state);
3359 }
3360
ToProto() const3361 HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
3362 HloInstructionProto proto = HloInstruction::ToProto();
3363 proto.set_rng_algorithm(algorithm_);
3364 return proto;
3365 }
3366
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3367 std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
3368 const HloPrintOptions& options) const {
3369 return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
3370 }
3371
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3372 bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
3373 const HloInstruction& other,
3374 const std::function<bool(const HloComputation*, const HloComputation*)>&
3375 eq_computations) const {
3376 const auto& casted_other =
3377 static_cast<const HloRngBitGeneratorInstruction&>(other);
3378 return algorithm() == casted_other.algorithm();
3379 }
3380
3381 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3382 HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
3383 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3384 HloCloneContext* /*context*/) const {
3385 CHECK_EQ(new_operands.size(), 1);
3386 return absl::make_unique<HloRngBitGeneratorInstruction>(
3387 shape, new_operands[0], algorithm());
3388 }
3389
3390 } // namespace xla
3391