1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <optional>
24 #include <string>
25 #include <utility>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/container/inlined_vector.h"
31 #include "absl/strings/escaping.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/strings/str_split.h"
35 #include "absl/strings/string_view.h"
36 #include "tensorflow/compiler/xla/literal_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/protobuf_util.h"
39 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
46 #include "tensorflow/compiler/xla/window_util.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48 #include "tensorflow/core/platform/protobuf.h"
49
50 namespace xla {
51 namespace {
52
53 using absl::CEscape;
54 using absl::StrAppend;
55 using absl::StrCat;
56 using absl::StrJoin;
57
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)58 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
59 const HloInstruction* operand) {
60 const auto operand_indices = instruction->OperandIndices(operand);
61 return absl::c_all_of(operand_indices, [instruction](int64_t operand_index) {
62 return instruction->IsElementwiseOnOperand(operand_index);
63 });
64 }
65
PrecisionConfigToString(const PrecisionConfig & precision_config)66 std::string PrecisionConfigToString(const PrecisionConfig& precision_config) {
67 if (absl::c_all_of(
68 precision_config.operand_precision(), [](int32_t precision) {
69 return static_cast<PrecisionConfig::Precision>(precision) ==
70 PrecisionConfig::DEFAULT;
71 })) {
72 return "";
73 }
74
75 return StrCat(
76 "operand_precision={",
77 StrJoin(
78 precision_config.operand_precision(), ",",
79 [](std::string* out, int32_t precision) {
80 CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
81 StrAppend(out,
82 PrecisionToString(
83 static_cast<PrecisionConfig::Precision>(precision)));
84 }),
85 "}");
86 }
87
SetThreadName(HloComputation * called_computation,absl::string_view execution_thread,bool skip_async_execution_thread_overwrite)88 void SetThreadName(HloComputation* called_computation,
89 absl::string_view execution_thread,
90 bool skip_async_execution_thread_overwrite) {
91 called_computation->SetExecutionThread(execution_thread);
92 for (HloInstruction* instr : called_computation->instructions()) {
93 if (instr->IsAsynchronous()) {
94 if (!skip_async_execution_thread_overwrite) {
95 // Set async instruction thread name and also recursively set async
96 // computations.
97 instr->set_async_execution_thread(execution_thread);
98 }
99 continue;
100 }
101 for (HloComputation* nested_called_computation :
102 instr->called_computations()) {
103 SetThreadName(nested_called_computation, execution_thread,
104 skip_async_execution_thread_overwrite);
105 }
106 }
107 }
108
109 } // namespace
110
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64_t feature_index)111 HloBatchNormInstruction::HloBatchNormInstruction(
112 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
113 HloInstruction* scale, float epsilon, int64_t feature_index)
114 : HloInstruction(opcode, shape),
115 epsilon_(epsilon),
116 feature_index_(feature_index) {
117 AppendOperand(operand);
118 AppendOperand(scale);
119 }
120
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const121 bool HloBatchNormInstruction::IdenticalSlowPath(
122 const HloInstruction& other,
123 const std::function<bool(const HloComputation*, const HloComputation*)>&
124 eq_computations) const {
125 const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
126 return feature_index() == casted_other.feature_index() &&
127 epsilon() == casted_other.epsilon();
128 }
129
ToProto() const130 HloInstructionProto HloBatchNormInstruction::ToProto() const {
131 HloInstructionProto proto = HloInstruction::ToProto();
132 proto.set_epsilon(epsilon_);
133 proto.set_feature_index(feature_index_);
134 return proto;
135 }
136
ExtraAttributesToStringImpl(const HloPrintOptions & options) const137 std::vector<std::string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
138 const HloPrintOptions& options) const {
139 return {StrCat("epsilon=", epsilon()),
140 StrCat("feature_index=", feature_index())};
141 }
142
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)143 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
144 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
145 HloInstruction* offset, float epsilon, int64_t feature_index)
146 : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
147 scale, epsilon, feature_index) {
148 AppendOperand(offset);
149 }
150
151 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const152 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
153 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
154 HloCloneContext* context) const {
155 CHECK_EQ(new_operands.size(), 3);
156 return std::make_unique<HloBatchNormTrainingInstruction>(
157 shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
158 feature_index());
159 }
160
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)161 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
162 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
163 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
164 float epsilon, int64_t feature_index)
165 : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
166 scale, epsilon, feature_index) {
167 AppendOperand(offset);
168 AppendOperand(mean);
169 AppendOperand(variance);
170 }
171
172 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const173 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
174 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
175 HloCloneContext* context) const {
176 CHECK_EQ(new_operands.size(), 5);
177 return std::make_unique<HloBatchNormInferenceInstruction>(
178 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
179 new_operands[4], epsilon(), feature_index());
180 }
181
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)182 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
183 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
184 HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
185 float epsilon, int64_t feature_index)
186 : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
187 epsilon, feature_index) {
188 AppendOperand(mean);
189 AppendOperand(variance);
190 AppendOperand(grad_output);
191 }
192
193 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const194 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
195 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
196 HloCloneContext* context) const {
197 CHECK_EQ(new_operands.size(), 5);
198 return std::make_unique<HloBatchNormGradInstruction>(
199 shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
200 new_operands[4], epsilon(), feature_index());
201 }
202
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64_t> fft_length)203 HloFftInstruction::HloFftInstruction(const Shape& shape,
204 HloInstruction* operand, FftType fft_type,
205 absl::Span<const int64_t> fft_length)
206 : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
207 fft_length_.assign(fft_length.begin(), fft_length.end());
208 AppendOperand(operand);
209 }
210
ToProto() const211 HloInstructionProto HloFftInstruction::ToProto() const {
212 HloInstructionProto proto = HloInstruction::ToProto();
213 proto.set_fft_type(fft_type_);
214 for (int64_t fft_len : fft_length_) {
215 proto.add_fft_length(fft_len);
216 }
217 return proto;
218 }
219
ExtraAttributesToStringImpl(const HloPrintOptions & options) const220 std::vector<std::string> HloFftInstruction::ExtraAttributesToStringImpl(
221 const HloPrintOptions& options) const {
222 return {StrCat("fft_type=", FftType_Name(fft_type())),
223 StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
224 }
225
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const226 bool HloFftInstruction::IdenticalSlowPath(
227 const HloInstruction& other,
228 const std::function<bool(const HloComputation*, const HloComputation*)>&
229 eq_computations) const {
230 const auto& casted_other = static_cast<const HloFftInstruction&>(other);
231 return fft_type() == casted_other.fft_type() &&
232 fft_length() == casted_other.fft_length();
233 }
234
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const235 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
236 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
237 HloCloneContext* context) const {
238 CHECK_EQ(new_operands.size(), 1);
239 return std::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
240 fft_length_);
241 }
242
HloAsyncInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)243 HloAsyncInstruction::HloAsyncInstruction(
244 HloOpcode opcode, const Shape& shape,
245 absl::Span<HloInstruction* const> operands,
246 HloComputation* async_computation, std::optional<int64_t> async_group_id,
247 absl::string_view async_execution_thread)
248 : HloInstruction(opcode, shape),
249 async_group_id_(async_group_id),
250 async_execution_thread_(async_execution_thread) {
251 CHECK(opcode == HloOpcode::kAsyncStart || operands.size() == 1);
252 for (auto operand : operands) {
253 AppendOperand(operand);
254 }
255 AppendComputation(async_computation);
256 CHECK(!async_computation->IsCustomCallComputation());
257 CHECK(!async_computation->IsFusionComputation());
258 async_computation->AddAsyncInstruction(this);
259 set_async_execution_thread(async_execution_thread);
260 }
261
HloAsyncInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)262 HloAsyncInstruction::HloAsyncInstruction(
263 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
264 HloComputation* async_computation, std::optional<int64_t> async_group_id,
265 absl::string_view async_execution_thread)
266 : HloInstruction(opcode, shape),
267 async_group_id_(async_group_id),
268 async_execution_thread_(async_execution_thread) {
269 AppendOperand(operand);
270 AppendComputation(async_computation);
271 CHECK(!async_computation->IsCustomCallComputation());
272 CHECK(!async_computation->IsFusionComputation());
273 async_computation->AddAsyncInstruction(this);
274 set_async_execution_thread(async_execution_thread);
275 }
276
~HloAsyncInstruction()277 HloAsyncInstruction::~HloAsyncInstruction() {
278 ClearAsyncComputationInstruction();
279 ClearCalledComputations();
280 }
281
ClearAsyncComputationInstruction()282 void HloAsyncInstruction::ClearAsyncComputationInstruction() {
283 // Each async instruction calls a single computation, but we use
284 // called_computations() instead of async_wrapped_instruction(), because the
285 // order in which things get destructed can vary; the async computation's
286 // back-pointer may already be null, which violates a check in
287 // async_wrapped_instruction.
288 for (HloComputation* computation : called_computations()) {
289 CHECK(computation != nullptr);
290 if (computation->IsAsyncComputation()) {
291 computation->RemoveAsyncInstruction(this);
292 }
293 }
294 }
295
async_wrapped_instruction() const296 HloInstruction* HloAsyncInstruction::async_wrapped_instruction() const {
297 CHECK(!called_computations().empty());
298 return called_computations()[0]->root_instruction();
299 }
300
async_wrapped_opcode() const301 HloOpcode HloAsyncInstruction::async_wrapped_opcode() const {
302 return async_wrapped_instruction()->opcode();
303 }
304
ExtraAttributesToStringImpl(const HloPrintOptions & options) const305 std::vector<std::string> HloAsyncInstruction::ExtraAttributesToStringImpl(
306 const HloPrintOptions& options) const {
307 std::vector<std::string> result;
308 if (async_group_id_.has_value()) {
309 result.push_back(StrCat("async_group_id=", *async_group_id_));
310 }
311 if (async_execution_thread_ != kMainExecutionThread) {
312 result.push_back(
313 StrCat("async_execution_thread=\"", async_execution_thread_, "\""));
314 }
315 if (options.syntax_sugar_async_ops()) {
316 std::vector<std::string> wrapped_extra_attributes =
317 async_wrapped_instruction()->ExtraAttributesToString(options);
318 absl::c_copy(wrapped_extra_attributes, std::back_inserter(result));
319 }
320 return result;
321 }
322
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const323 bool HloAsyncInstruction::IdenticalSlowPath(
324 const HloInstruction& other,
325 const std::function<bool(const HloComputation*, const HloComputation*)>&
326 eq_computations) const {
327 return opcode() == other.opcode() &&
328 eq_computations(async_wrapped_computation(),
329 other.async_wrapped_computation());
330 }
331
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const332 std::unique_ptr<HloInstruction> HloAsyncInstruction::CloneWithNewOperandsImpl(
333 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
334 HloCloneContext* context) const {
335 HloModule* module = context != nullptr ? context->module() : GetModule();
336 HloComputation* new_wrapped_computation = nullptr;
337 if (context != nullptr) {
338 new_wrapped_computation =
339 context->FindComputation(async_wrapped_computation());
340 }
341 if (new_wrapped_computation == nullptr) {
342 new_wrapped_computation = module->AddEmbeddedComputation(
343 async_wrapped_computation()->Clone("clone", context));
344 }
345 return std::make_unique<HloAsyncInstruction>(
346 opcode(), shape, new_operands, new_wrapped_computation, async_group_id_,
347 async_execution_thread_);
348 }
349
set_async_group_id(std::optional<int64_t> async_group_id)350 void HloAsyncInstruction::set_async_group_id(
351 std::optional<int64_t> async_group_id) {
352 async_group_id_ = async_group_id;
353 }
354
set_async_execution_thread(absl::string_view async_execution_thread)355 void HloAsyncInstruction::set_async_execution_thread(
356 absl::string_view async_execution_thread) {
357 async_execution_thread_ = std::string(async_execution_thread);
358 SetThreadName(async_wrapped_computation(), async_execution_thread,
359 /*skip_async_execution_thread_overwrite=*/false);
360 }
361
ToProto() const362 HloInstructionProto HloAsyncInstruction::ToProto() const {
363 HloInstructionProto proto = HloInstruction::ToProto();
364 proto.set_async_group_id(async_group_id_.has_value() ? *async_group_id_ : -1);
365 proto.set_async_execution_thread(async_execution_thread_ ==
366 HloInstruction::kMainExecutionThread
367 ? ""
368 : async_execution_thread_);
369 return proto;
370 }
371
HloCopyStartInstruction(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)372 HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
373 HloInstruction* operand,
374 bool is_cross_program_prefetch)
375 : HloInstruction(HloOpcode::kCopyStart, shape),
376 is_cross_program_prefetch_(is_cross_program_prefetch) {
377 AppendOperand(operand);
378 }
379
ToProto() const380 HloInstructionProto HloCopyStartInstruction::ToProto() const {
381 HloInstructionProto proto = HloInstruction::ToProto();
382 proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
383 return proto;
384 }
385
ExtraAttributesToStringImpl(const HloPrintOptions & options) const386 std::vector<std::string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
387 const HloPrintOptions& options) const {
388 std::vector<std::string> result;
389 if (is_cross_program_prefetch()) {
390 result.push_back("is_cross_program_prefetch=true");
391 }
392 return result;
393 }
394
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const395 bool HloCopyStartInstruction::IdenticalSlowPath(
396 const HloInstruction& other,
397 const std::function<bool(const HloComputation*, const HloComputation*)>&
398 eq_computations) const {
399 const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
400 return is_cross_program_prefetch() ==
401 casted_other.is_cross_program_prefetch();
402 }
403
404 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const405 HloCopyStartInstruction::CloneWithNewOperandsImpl(
406 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
407 HloCloneContext* context) const {
408 CHECK_EQ(new_operands.size(), 1);
409 return std::make_unique<HloCopyStartInstruction>(shape, new_operands[0],
410 is_cross_program_prefetch());
411 }
412
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,std::optional<Comparison::Type> type)413 HloCompareInstruction::HloCompareInstruction(
414 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
415 ComparisonDirection direction, std::optional<Comparison::Type> type)
416 : HloInstruction(HloOpcode::kCompare, shape),
417 compare_(type.has_value()
418 ? Comparison(direction, *type)
419 : Comparison(direction, lhs->shape().element_type())) {
420 AppendOperand(lhs);
421 AppendOperand(rhs);
422 }
423
ToProto() const424 HloInstructionProto HloCompareInstruction::ToProto() const {
425 HloInstructionProto proto = HloInstruction::ToProto();
426 proto.set_comparison_direction(
427 ComparisonDirectionToString(compare_.GetDirection()));
428 proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
429 return proto;
430 }
431
ExtraAttributesToStringImpl(const HloPrintOptions & options) const432 std::vector<std::string> HloCompareInstruction::ExtraAttributesToStringImpl(
433 const HloPrintOptions& options) const {
434 std::vector<std::string> result;
435 result.push_back(
436 StrCat("direction=", ComparisonDirectionToString(direction())));
437 if (compare_.GetType() !=
438 Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
439 result.push_back(
440 StrCat("type=", ComparisonTypeToString(compare_.GetType())));
441 }
442 return result;
443 }
444
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const445 bool HloCompareInstruction::IdenticalSlowPath(
446 const HloInstruction& other,
447 const std::function<bool(const HloComputation*, const HloComputation*)>&
448 eq_computations) const {
449 const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
450 return direction() == casted_other.direction();
451 }
452
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const453 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
454 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
455 HloCloneContext* context) const {
456 CHECK_EQ(new_operands.size(), 2);
457 return std::make_unique<HloCompareInstruction>(
458 shape, new_operands[0], new_operands[1], direction(), type());
459 }
460
461 namespace {
462
463 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
464 // of "key=value" attribute strings generically, using protocol buffer
465 // reflection.
466 //
467 // Currently implements a small subset of cases; feel free to add more as
468 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)469 std::vector<std::string> AttributeProtoToStringVector(
470 const tensorflow::protobuf::Message& message) {
471 const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
472 std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
473 reflection->ListFields(message, &fields);
474
475 std::vector<std::string> output;
476 for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
477 std::string s = absl::StrCat(field->name(), "=");
478 CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
479 switch (field->type()) {
480 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
481 bool val = reflection->GetBool(message, field);
482 absl::StrAppend(&s, val ? "true" : "false");
483 break;
484 }
485 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
486 const tensorflow::protobuf::EnumValueDescriptor* evd =
487 reflection->GetEnum(message, field);
488 absl::StrAppend(&s, evd->name());
489 break;
490 }
491 default:
492 LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
493 }
494 output.push_back(std::move(s));
495 }
496 return output;
497 }
498
499 } // namespace
500
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)501 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
502 const Shape& shape, HloInstruction* a, HloInstruction* b,
503 const TriangularSolveOptions& options)
504 : HloInstruction(HloOpcode::kTriangularSolve, shape),
505 triangular_solve_options_(options) {
506 AppendOperand(a);
507 AppendOperand(b);
508 }
509
ToProto() const510 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
511 HloInstructionProto proto = HloInstruction::ToProto();
512 *proto.mutable_triangular_solve_options() = triangular_solve_options_;
513 return proto;
514 }
515
516 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const517 HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
518 const HloPrintOptions& options) const {
519 return AttributeProtoToStringVector(triangular_solve_options_);
520 }
521
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const522 bool HloTriangularSolveInstruction::IdenticalSlowPath(
523 const HloInstruction& other,
524 const std::function<bool(const HloComputation*, const HloComputation*)>&
525 eq_computations) const {
526 const auto& casted_other =
527 static_cast<const HloTriangularSolveInstruction&>(other);
528 const auto& options = triangular_solve_options();
529 const auto& other_options = casted_other.triangular_solve_options();
530
531 return options.left_side() == other_options.left_side() &&
532 options.lower() == other_options.lower() &&
533 options.unit_diagonal() == other_options.unit_diagonal() &&
534 options.transpose_a() == other_options.transpose_a();
535 }
536
537 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const538 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
539 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
540 HloCloneContext* context) const {
541 CHECK_EQ(new_operands.size(), 2);
542 return std::make_unique<HloTriangularSolveInstruction>(
543 shape, new_operands[0], new_operands[1], triangular_solve_options());
544 }
545
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)546 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
547 HloInstruction* a,
548 const CholeskyOptions& options)
549 : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
550 AppendOperand(a);
551 }
552
ToProto() const553 HloInstructionProto HloCholeskyInstruction::ToProto() const {
554 HloInstructionProto proto = HloInstruction::ToProto();
555 *proto.mutable_cholesky_options() = cholesky_options_;
556 return proto;
557 }
558
ExtraAttributesToStringImpl(const HloPrintOptions & options) const559 std::vector<std::string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
560 const HloPrintOptions& options) const {
561 return AttributeProtoToStringVector(cholesky_options_);
562 }
563
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const564 bool HloCholeskyInstruction::IdenticalSlowPath(
565 const HloInstruction& other,
566 const std::function<bool(const HloComputation*, const HloComputation*)>&
567 eq_computations) const {
568 const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
569 const auto& options = cholesky_options();
570 const auto& other_options = casted_other.cholesky_options();
571
572 return options.lower() == other_options.lower();
573 }
574
575 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const576 HloCholeskyInstruction::CloneWithNewOperandsImpl(
577 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
578 HloCloneContext* context) const {
579 CHECK_EQ(new_operands.size(), 1);
580 return std::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
581 cholesky_options());
582 }
583
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const std::optional<int64_t> & channel_id)584 HloChannelInstruction::HloChannelInstruction(
585 HloOpcode opcode, const Shape& shape,
586 const std::optional<int64_t>& channel_id)
587 : HloInstruction(opcode, shape), channel_id_(channel_id) {}
588
set_channel_id(const std::optional<int64_t> & channel_id)589 void HloChannelInstruction::set_channel_id(
590 const std::optional<int64_t>& channel_id) {
591 channel_id_ = channel_id;
592 }
593
ToProto() const594 HloInstructionProto HloChannelInstruction::ToProto() const {
595 HloInstructionProto proto = HloInstruction::ToProto();
596 if (channel_id_) {
597 CHECK_GT(channel_id_.value(), 0)
598 << "Non-positive channel id is equivalent to no channel id";
599 proto.set_channel_id(*channel_id_);
600 }
601 return proto;
602 }
603
ExtraAttributesToStringImpl(const HloPrintOptions &) const604 std::vector<std::string> HloChannelInstruction::ExtraAttributesToStringImpl(
605 const HloPrintOptions& /*options*/) const {
606 std::vector<std::string> result;
607 if (channel_id_) {
608 result.push_back(StrCat("channel_id=", *channel_id_));
609 }
610 return result;
611 }
612
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const613 bool HloChannelInstruction::IdenticalSlowPath(
614 const HloInstruction& other,
615 const std::function<bool(const HloComputation*, const HloComputation*)>&
616 eq_computations) const {
617 if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) {
618 return false;
619 }
620 const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
621 return channel_id() == casted_other.channel_id();
622 }
623
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64_t channel_id,bool is_host_transfer)624 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
625 const Shape& shape,
626 int64_t channel_id,
627 bool is_host_transfer)
628 : HloChannelInstruction(opcode, shape, channel_id),
629 is_host_transfer_(is_host_transfer) {}
630
ToProto() const631 HloInstructionProto HloSendRecvInstruction::ToProto() const {
632 HloInstructionProto proto = HloChannelInstruction::ToProto();
633 proto.set_is_host_transfer(is_host_transfer_);
634 return proto;
635 }
636
ExtraAttributesToStringImpl(const HloPrintOptions & options) const637 std::vector<std::string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
638 const HloPrintOptions& options) const {
639 std::vector<std::string> attrs =
640 HloChannelInstruction::ExtraAttributesToStringImpl(options);
641 if (is_host_transfer()) {
642 attrs.push_back("is_host_transfer=true");
643 }
644 return attrs;
645 }
646
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const647 bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
648 const HloInstruction& other,
649 const std::function<bool(const HloComputation*, const HloComputation*)>&
650 eq_computations) const {
651 // Not yet supported.
652 return false;
653 }
654
655 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)656 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
657 HloInstruction* token,
658 int64_t channel_id,
659 bool is_host_transfer)
660 : HloSendRecvInstruction(
661 HloOpcode::kSend,
662 ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
663 ShapeUtil::MakeShape(U32, {}),
664 ShapeUtil::MakeTokenShape()}),
665 channel_id, is_host_transfer) {
666 AppendOperand(operand);
667 AppendOperand(token);
668 }
669
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const670 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
671 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
672 HloCloneContext* context) const {
673 CHECK_EQ(new_operands.size(), 2);
674 return std::make_unique<HloSendInstruction>(
675 new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
676 }
677
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)678 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
679 bool is_host_transfer)
680 : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
681 CHECK_NOTNULL(operand)->channel_id().value(),
682 is_host_transfer) {
683 AppendOperand(operand);
684 }
685
686 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const687 HloSendDoneInstruction::CloneWithNewOperandsImpl(
688 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
689 HloCloneContext* context) const {
690 CHECK_EQ(new_operands.size(), 1);
691 return std::make_unique<HloSendDoneInstruction>(
692 Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
693 }
694
695 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)696 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
697 HloInstruction* token,
698 int64_t channel_id,
699 bool is_host_transfer)
700 : HloSendRecvInstruction(
701 HloOpcode::kRecv,
702 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
703 ShapeUtil::MakeTokenShape()}),
704 channel_id, is_host_transfer) {
705 AppendOperand(token);
706 }
707
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const708 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
709 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
710 HloCloneContext* context) const {
711 CHECK_EQ(new_operands.size(), 1);
712 return std::make_unique<HloRecvInstruction>(
713 ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
714 is_host_transfer());
715 }
716
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)717 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
718 bool is_host_transfer)
719 : HloSendRecvInstruction(
720 HloOpcode::kRecvDone,
721 ShapeUtil::MakeTupleShape(
722 {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
723 ShapeUtil::MakeTokenShape()}),
724 CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
725 AppendOperand(operand);
726 }
727
728 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const729 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
730 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
731 HloCloneContext* context) const {
732 CHECK_EQ(new_operands.size(), 1);
733 return std::make_unique<HloRecvDoneInstruction>(
734 Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
735 }
736
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id)737 HloCollectiveInstruction::HloCollectiveInstruction(
738 HloOpcode opcode, const Shape& shape,
739 absl::Span<HloInstruction* const> operands,
740 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
741 const std::optional<int64_t>& channel_id)
742 : HloChannelInstruction(opcode, shape, channel_id),
743 replica_groups_(SpanToVector(replica_groups)),
744 constrain_layout_(constrain_layout) {
745 for (auto operand : operands) {
746 AppendOperand(operand);
747 }
748 }
749
ToProto() const750 HloInstructionProto HloCollectiveInstruction::ToProto() const {
751 HloInstructionProto proto = HloChannelInstruction::ToProto();
752 *proto.mutable_replica_groups() = {replica_groups_.begin(),
753 replica_groups_.end()};
754 proto.set_constrain_layout(constrain_layout_);
755 return proto;
756 }
757
ExtraAttributesToStringImpl(const HloPrintOptions & options) const758 std::vector<std::string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
759 const HloPrintOptions& options) const {
760 std::vector<std::string> result =
761 HloChannelInstruction::ExtraAttributesToStringImpl(options);
762 result.push_back(
763 StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
764 if (constrain_layout_) {
765 result.push_back("constrain_layout=true");
766 }
767 return result;
768 }
769
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const770 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
771 const HloInstruction& other,
772 const std::function<bool(const HloComputation*, const HloComputation*)>&
773 eq_computations) const {
774 const auto& casted_other =
775 static_cast<const HloCollectiveInstruction&>(other);
776 return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
777 other, eq_computations) &&
778 constrain_layout() == casted_other.constrain_layout() &&
779 absl::c_equal(replica_groups(), casted_other.replica_groups(),
780 [](const ReplicaGroup& a, const ReplicaGroup& b) {
781 return absl::c_equal(a.replica_ids(), b.replica_ids());
782 });
783 }
784
HloAllGatherInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)785 HloAllGatherInstruction::HloAllGatherInstruction(
786 HloOpcode opcode, const Shape& shape,
787 absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
788 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
789 const std::optional<int64_t>& channel_id, bool use_global_device_ids)
790 : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
791 constrain_layout, channel_id),
792 all_gather_dimension_(all_gather_dimension),
793 use_global_device_ids_(use_global_device_ids) {}
794
ExtraAttributesToStringImpl(const HloPrintOptions & options) const795 std::vector<std::string> HloAllGatherInstruction::ExtraAttributesToStringImpl(
796 const HloPrintOptions& options) const {
797 std::vector<std::string> result =
798 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
799 result.push_back(StrCat("dimensions={", all_gather_dimension_, "}"));
800 if (use_global_device_ids_) {
801 result.push_back("use_global_device_ids=true");
802 }
803 return result;
804 }
805
806 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const807 HloAllGatherInstruction::CloneWithNewOperandsImpl(
808 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
809 HloCloneContext* /*context*/) const {
810 return std::make_unique<HloAllGatherInstruction>(
811 opcode(), shape, new_operands, all_gather_dimension(), replica_groups(),
812 constrain_layout(), channel_id(), use_global_device_ids());
813 }
814
ToProto() const815 HloInstructionProto HloAllGatherInstruction::ToProto() const {
816 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
817 proto.add_dimensions(all_gather_dimension_);
818 proto.set_use_global_device_ids(use_global_device_ids_);
819 return proto;
820 }
821
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const822 bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues(
823 const HloInstruction& other,
824 const std::function<bool(const HloComputation*, const HloComputation*)>&
825 eq_computations) const {
826 const auto& casted_other = static_cast<const HloAllGatherInstruction&>(other);
827 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
828 other, eq_computations) &&
829 all_gather_dimension_ == casted_other.all_gather_dimension() &&
830 use_global_device_ids() == casted_other.use_global_device_ids();
831 }
832
HloAllReduceInstructionBase(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)833 HloAllReduceInstructionBase::HloAllReduceInstructionBase(
834 HloOpcode opcode, const Shape& shape,
835 absl::Span<HloInstruction* const> operands,
836 HloComputation* reduce_computation,
837 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
838 const std::optional<int64_t>& channel_id, bool use_global_device_ids)
839 : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
840 constrain_layout, channel_id),
841 use_global_device_ids_(use_global_device_ids) {
842 AppendComputation(reduce_computation);
843 }
844
ToProto() const845 HloInstructionProto HloAllReduceInstructionBase::ToProto() const {
846 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
847 proto.set_use_global_device_ids(use_global_device_ids_);
848 return proto;
849 }
850
851 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const852 HloAllReduceInstructionBase::ExtraAttributesToStringImpl(
853 const HloPrintOptions& options) const {
854 std::vector<std::string> result =
855 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
856 if (use_global_device_ids_) {
857 result.push_back("use_global_device_ids=true");
858 }
859 return result;
860 }
861
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const862 bool HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
863 const HloInstruction& other,
864 const std::function<bool(const HloComputation*, const HloComputation*)>&
865 eq_computations) const {
866 if (opcode() != other.opcode()) {
867 return false;
868 }
869 const auto& casted_other =
870 static_cast<const HloAllReduceInstructionBase&>(other);
871 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
872 other, eq_computations) &&
873 constrain_layout() == casted_other.constrain_layout() &&
874 use_global_device_ids() == casted_other.use_global_device_ids() &&
875 eq_computations(to_apply(), casted_other.to_apply());
876 }
877
IsNoop() const878 bool HloAllReduceInstruction::IsNoop() const {
879 for (const auto& replica_group : replica_groups()) {
880 if (replica_group.replica_ids().size() != 1) {
881 return false;
882 }
883 }
884 return !channel_id();
885 }
886
887 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const888 HloAllReduceInstruction::CloneWithNewOperandsImpl(
889 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
890 HloCloneContext* /*context*/) const {
891 return std::make_unique<HloAllReduceInstruction>(
892 opcode(), shape, new_operands, to_apply(), replica_groups(),
893 constrain_layout(), channel_id(), use_global_device_ids());
894 }
895
HloReduceScatterInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)896 HloReduceScatterInstruction::HloReduceScatterInstruction(
897 const Shape& shape, absl::Span<HloInstruction* const> operands,
898 HloComputation* reduce_computation,
899 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
900 const std::optional<int64_t>& channel_id, bool use_global_device_ids,
901 int64_t scatter_dimension)
902 : HloAllReduceInstructionBase(
903 HloOpcode::kReduceScatter, shape, operands, reduce_computation,
904 replica_groups, constrain_layout, channel_id, use_global_device_ids),
905 scatter_dimension_(scatter_dimension) {}
906
907 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const908 HloReduceScatterInstruction::ExtraAttributesToStringImpl(
909 const HloPrintOptions& options) const {
910 std::vector<std::string> result =
911 HloAllReduceInstructionBase::ExtraAttributesToStringImpl(options);
912 result.push_back(StrCat("dimensions={", scatter_dimension_, "}"));
913 return result;
914 }
915
ToProto() const916 HloInstructionProto HloReduceScatterInstruction::ToProto() const {
917 HloInstructionProto proto = HloAllReduceInstructionBase::ToProto();
918 proto.add_dimensions(scatter_dimension_);
919 return proto;
920 }
921
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const922 bool HloReduceScatterInstruction::IdenticalSlowPathIgnoringChannelIdValues(
923 const HloInstruction& other,
924 const std::function<bool(const HloComputation*, const HloComputation*)>&
925 eq_computations) const {
926 const auto& casted_other =
927 static_cast<const HloReduceScatterInstruction&>(other);
928 return HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
929 other, eq_computations) &&
930 scatter_dimension_ == casted_other.scatter_dimension();
931 }
932
933 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const934 HloReduceScatterInstruction::CloneWithNewOperandsImpl(
935 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
936 HloCloneContext* /*context*/) const {
937 return std::make_unique<HloReduceScatterInstruction>(
938 shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
939 channel_id(), use_global_device_ids(), scatter_dimension());
940 }
941
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,const std::optional<int64_t> & split_dimension)942 HloAllToAllInstruction::HloAllToAllInstruction(
943 const Shape& shape, absl::Span<HloInstruction* const> operands,
944 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
945 const std::optional<int64_t>& channel_id,
946 const std::optional<int64_t>& split_dimension)
947 : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
948 replica_groups, constrain_layout, channel_id),
949 split_dimension_(split_dimension) {}
950
951 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const952 HloAllToAllInstruction::CloneWithNewOperandsImpl(
953 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
954 HloCloneContext* /*context*/) const {
955 return std::make_unique<HloAllToAllInstruction>(
956 shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
957 split_dimension());
958 }
959
ToProto() const960 HloInstructionProto HloAllToAllInstruction::ToProto() const {
961 HloInstructionProto proto = HloCollectiveInstruction::ToProto();
962 if (split_dimension_) {
963 proto.add_dimensions(*split_dimension_);
964 }
965 return proto;
966 }
967
ExtraAttributesToStringImpl(const HloPrintOptions & options) const968 std::vector<std::string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
969 const HloPrintOptions& options) const {
970 std::vector<std::string> result =
971 HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
972 if (split_dimension_) {
973 result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
974 }
975 return result;
976 }
977
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const978 bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
979 const HloInstruction& other,
980 const std::function<bool(const HloComputation*, const HloComputation*)>&
981 eq_computations) const {
982 const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
983 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
984 other, eq_computations) &&
985 split_dimension_ == casted_other.split_dimension();
986 }
987
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const std::optional<int64_t> & channel_id)988 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
989 HloOpcode opcode, const Shape& shape, HloInstruction* operand,
990 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
991 const std::optional<int64_t>& channel_id)
992 : HloChannelInstruction(opcode, shape, channel_id),
993 source_target_pairs_(source_target_pairs) {
994 AppendOperand(operand);
995 }
996
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const std::optional<int64_t> & channel_id)997 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
998 HloOpcode opcode, const Shape& shape, HloInstruction* input,
999 HloInstruction* output, HloInstruction* input_start_indices,
1000 HloInstruction* output_start_indices,
1001 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1002 absl::Span<const std::vector<int64_t>> slice_sizes,
1003 const std::optional<int64_t>& channel_id)
1004 : HloChannelInstruction(opcode, shape, channel_id),
1005 source_target_pairs_(source_target_pairs.begin(),
1006 source_target_pairs.end()),
1007 slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
1008 AppendOperand(input);
1009 AppendOperand(output);
1010 AppendOperand(input_start_indices);
1011 AppendOperand(output_start_indices);
1012 }
1013
ToProto() const1014 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
1015 HloInstructionProto proto = HloChannelInstruction::ToProto();
1016 for (const auto& pair : source_target_pairs()) {
1017 auto* proto_pair = proto.add_source_target_pairs();
1018 proto_pair->set_source(pair.first);
1019 proto_pair->set_target(pair.second);
1020 }
1021 for (const auto& slice_size : dynamic_slice_sizes_list()) {
1022 for (const auto& dimension_slice_size : slice_size) {
1023 proto.add_dynamic_slice_sizes(dimension_slice_size);
1024 }
1025 }
1026 return proto;
1027 }
1028
1029 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1030 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
1031 const HloPrintOptions& options) const {
1032 std::vector<std::string> result =
1033 HloChannelInstruction::ExtraAttributesToStringImpl(options);
1034 {
1035 std::vector<std::string> strs;
1036 const auto& pairs = source_target_pairs();
1037 strs.reserve(pairs.size());
1038 for (const auto& pair : pairs) {
1039 strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
1040 }
1041 result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
1042 }
1043 if (!dynamic_slice_sizes_list().empty()) {
1044 std::vector<std::string> strs;
1045 const auto& sizes_list = dynamic_slice_sizes_list();
1046 strs.reserve(sizes_list.size());
1047 for (const auto& slice_sizes : dynamic_slice_sizes_list()) {
1048 strs.push_back(StrCat("{", StrJoin(slice_sizes, ","), "}"));
1049 }
1050 result.push_back(StrCat("slice_sizes={", StrJoin(strs, ","), "}"));
1051 }
1052 return result;
1053 }
1054
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1055 bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues(
1056 const HloInstruction& other,
1057 const std::function<bool(const HloComputation*, const HloComputation*)>&
1058 eq_computations) const {
1059 if (opcode() != other.opcode()) {
1060 return false;
1061 }
1062 const auto& casted_other =
1063 static_cast<const HloCollectivePermuteInstruction&>(other);
1064 return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
1065 other, eq_computations) &&
1066 absl::c_equal(
1067 source_target_pairs(), casted_other.source_target_pairs(),
1068 [](const std::pair<int64_t, int64_t>& a,
1069 const std::pair<int64_t, int64_t>& b) { return a == b; }) &&
1070 absl::c_equal(
1071 dynamic_slice_sizes_list(),
1072 casted_other.dynamic_slice_sizes_list(),
1073 [](const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
1074 return absl::c_equal(a, b);
1075 });
1076 }
1077
1078 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const1079 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
1080 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1081 HloCloneContext* /*context*/) const {
1082 if (dynamic_slice_sizes_list().empty()) {
1083 return std::make_unique<HloCollectivePermuteInstruction>(
1084 opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
1085 } else {
1086 return std::make_unique<HloCollectivePermuteInstruction>(
1087 opcode(), shape, new_operands[0], new_operands[1], new_operands[2],
1088 new_operands[3], source_target_pairs(), dynamic_slice_sizes_list(),
1089 channel_id());
1090 }
1091 }
1092
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1093 HloReverseInstruction::HloReverseInstruction(
1094 const Shape& shape, HloInstruction* operand,
1095 absl::Span<const int64_t> dimensions)
1096 : HloDimensionsInstruction(HloOpcode::kReverse, shape, dimensions) {
1097 AppendOperand(operand);
1098 }
1099
ToProto() const1100 HloInstructionProto HloDimensionsInstruction::ToProto() const {
1101 HloInstructionProto proto = HloInstruction::ToProto();
1102 for (int64_t dimension : dimensions_) {
1103 proto.add_dimensions(dimension);
1104 }
1105 return proto;
1106 }
1107
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1108 std::vector<std::string> HloDimensionsInstruction::ExtraAttributesToStringImpl(
1109 const HloPrintOptions& options) const {
1110 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1111 }
1112
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1113 bool HloDimensionsInstruction::IdenticalSlowPath(
1114 const HloInstruction& other,
1115 const std::function<bool(const HloComputation*, const HloComputation*)>&
1116 eq_computations) const {
1117 const auto& casted_other =
1118 static_cast<const HloDimensionsInstruction&>(other);
1119 return dimensions() == casted_other.dimensions();
1120 }
1121
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1122 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
1123 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1124 HloCloneContext* context) const {
1125 CHECK_EQ(new_operands.size(), 1);
1126 return std::make_unique<HloReverseInstruction>(shape, new_operands[0],
1127 dimensions());
1128 }
1129
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)1130 HloConcatenateInstruction::HloConcatenateInstruction(
1131 const Shape& shape, absl::Span<HloInstruction* const> operands,
1132 int64_t dimension)
1133 : HloDimensionsInstruction(HloOpcode::kConcatenate, shape, {dimension}) {
1134 for (auto operand : operands) {
1135 AppendOperand(operand);
1136 }
1137 }
1138
1139 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1140 HloConcatenateInstruction::CloneWithNewOperandsImpl(
1141 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1142 HloCloneContext* context) const {
1143 return std::make_unique<HloConcatenateInstruction>(shape, new_operands,
1144 concatenate_dimension());
1145 }
1146
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1147 HloReduceInstruction::HloReduceInstruction(
1148 const Shape& shape, absl::Span<HloInstruction* const> args,
1149 absl::Span<const int64_t> dimensions_to_reduce,
1150 HloComputation* reduce_computation)
1151 : HloDimensionsInstruction(HloOpcode::kReduce, shape,
1152 dimensions_to_reduce) {
1153 for (HloInstruction* arg : args) {
1154 AppendOperand(arg);
1155 }
1156 AppendComputation(reduce_computation);
1157 }
1158
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1159 bool HloReduceInstruction::IdenticalSlowPath(
1160 const HloInstruction& other,
1161 const std::function<bool(const HloComputation*, const HloComputation*)>&
1162 eq_computations) const {
1163 const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
1164 // Reduction results are determined by the reduction dimension and the
1165 // reduction computation.
1166 return dimensions() == casted_other.dimensions() &&
1167 eq_computations(to_apply(), casted_other.to_apply());
1168 }
1169
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1170 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
1171 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1172 HloCloneContext* context) const {
1173 CHECK_EQ(new_operands.size() % 2, 0);
1174 return std::make_unique<HloReduceInstruction>(shape, new_operands,
1175 dimensions(), to_apply());
1176 }
1177
HloSortInstruction(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1178 HloSortInstruction::HloSortInstruction(
1179 const Shape& shape, int64_t dimension,
1180 absl::Span<HloInstruction* const> operands, HloComputation* compare,
1181 bool is_stable)
1182 : HloDimensionsInstruction(HloOpcode::kSort, shape, {dimension}),
1183 is_stable_(is_stable) {
1184 for (auto* value : operands) {
1185 AppendOperand(value);
1186 }
1187 AppendComputation(compare);
1188 }
1189
ToProto() const1190 HloInstructionProto HloSortInstruction::ToProto() const {
1191 HloInstructionProto proto = HloInstruction::ToProto();
1192 for (int64_t dimension : dimensions_) {
1193 proto.add_dimensions(dimension);
1194 }
1195 proto.set_is_stable(is_stable());
1196 return proto;
1197 }
1198
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1199 std::vector<std::string> HloSortInstruction::ExtraAttributesToStringImpl(
1200 const HloPrintOptions& options) const {
1201 std::vector<std::string> attrs;
1202 attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
1203 if (is_stable()) {
1204 attrs.push_back("is_stable=true");
1205 }
1206 return attrs;
1207 }
1208
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1209 bool HloSortInstruction::IdenticalSlowPath(
1210 const HloInstruction& other,
1211 const std::function<bool(const HloComputation*, const HloComputation*)>&
1212 eq_computations) const {
1213 const auto& casted_other = static_cast<const HloSortInstruction&>(other);
1214 if (dimensions() != casted_other.dimensions()) {
1215 return false;
1216 }
1217 if (is_stable() != casted_other.is_stable()) {
1218 return false;
1219 }
1220 return eq_computations(to_apply(), other.to_apply());
1221 }
1222
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1223 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
1224 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1225 HloCloneContext* context) const {
1226 return std::make_unique<HloSortInstruction>(
1227 shape, dimensions_[0], new_operands, to_apply(), is_stable());
1228 }
1229
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1230 HloTransposeInstruction::HloTransposeInstruction(
1231 const Shape& shape, HloInstruction* operand,
1232 absl::Span<const int64_t> dimensions)
1233 : HloDimensionsInstruction(HloOpcode::kTranspose, shape, dimensions) {
1234 AppendOperand(operand);
1235 }
1236
IsRank2Transpose() const1237 bool HloTransposeInstruction::IsRank2Transpose() const {
1238 return dimensions() == std::vector<int64_t>({1, 0}) &&
1239 shape().dimensions_size() == 2 &&
1240 std::equal(shape().dimensions().begin(), shape().dimensions().end(),
1241 operand(0)->shape().dimensions().rbegin());
1242 }
1243
1244 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1245 HloTransposeInstruction::CloneWithNewOperandsImpl(
1246 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1247 HloCloneContext* context) const {
1248 CHECK_EQ(new_operands.size(), 1);
1249 return std::make_unique<HloTransposeInstruction>(shape, new_operands[0],
1250 dimensions());
1251 }
1252
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> broadcast_dimension)1253 HloBroadcastInstruction::HloBroadcastInstruction(
1254 const Shape& shape, HloInstruction* operand,
1255 absl::Span<const int64_t> broadcast_dimension)
1256 : HloDimensionsInstruction(HloOpcode::kBroadcast, shape,
1257 broadcast_dimension) {
1258 AppendOperand(operand);
1259 }
1260
1261 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1262 HloBroadcastInstruction::CloneWithNewOperandsImpl(
1263 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1264 HloCloneContext* context) const {
1265 CHECK_EQ(new_operands.size(), 1);
1266 return std::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
1267 dimensions());
1268 }
1269
HloDynamicReshapeInstruction(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1270 HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
1271 const Shape& shape, HloInstruction* data_operand,
1272 absl::Span<HloInstruction* const> dim_sizes)
1273 : HloInstruction(HloOpcode::kDynamicReshape, shape) {
1274 AppendOperand(data_operand);
1275 for (auto operand : dim_sizes) {
1276 AppendOperand(operand);
1277 }
1278 }
1279
1280 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1281 HloDynamicReshapeInstruction::CloneWithNewOperandsImpl(
1282 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1283 HloCloneContext* context) const {
1284 CHECK_GE(new_operands.size(), 1);
1285 return std::make_unique<HloDynamicReshapeInstruction>(
1286 shape, new_operands[0], new_operands.subspan(1));
1287 }
1288
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1289 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
1290 HloInstruction* operand,
1291 int64_t inferred_dimension)
1292 : HloInstruction(HloOpcode::kReshape, shape),
1293 inferred_dimension_(inferred_dimension) {
1294 AppendOperand(operand);
1295 }
1296
ToProto() const1297 HloInstructionProto HloReshapeInstruction::ToProto() const {
1298 HloInstructionProto proto = HloInstruction::ToProto();
1299 if (inferred_dimension_ != -1) {
1300 proto.add_dimensions(inferred_dimension_);
1301 }
1302 return proto;
1303 }
1304
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1305 std::vector<std::string> HloReshapeInstruction::ExtraAttributesToStringImpl(
1306 const HloPrintOptions& options) const {
1307 if (inferred_dimension() == -1) {
1308 return {};
1309 }
1310 return {StrCat("inferred_dimension=", inferred_dimension())};
1311 }
1312
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1313 bool HloReshapeInstruction::IdenticalSlowPath(
1314 const HloInstruction& other,
1315 const std::function<bool(const HloComputation*, const HloComputation*)>&
1316 eq_computations) const {
1317 const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
1318 return inferred_dimension() == casted_other.inferred_dimension();
1319 }
1320
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1321 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
1322 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1323 HloCloneContext* context) const {
1324 CHECK_EQ(new_operands.size(), 1);
1325 return std::make_unique<HloReshapeInstruction>(shape, new_operands[0],
1326 inferred_dimension());
1327 }
1328
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1329 HloMapInstruction::HloMapInstruction(const Shape& shape,
1330 absl::Span<HloInstruction* const> operands,
1331 HloComputation* map_computation)
1332 : HloInstruction(HloOpcode::kMap, shape) {
1333 for (auto operand : operands) {
1334 AppendOperand(operand);
1335 }
1336 AppendComputation(map_computation);
1337 // TODO(b/65689298) Remove code below once Map is generalized to accept
1338 // arbitrary map dimensions.
1339 dimensions_.resize(shape.rank());
1340 std::iota(dimensions_.begin(), dimensions_.end(), 0);
1341 }
1342
ToProto() const1343 HloInstructionProto HloMapInstruction::ToProto() const {
1344 HloInstructionProto proto = HloInstruction::ToProto();
1345 for (int64_t dimension : dimensions_) {
1346 proto.add_dimensions(dimension);
1347 }
1348 return proto;
1349 }
1350
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const1351 bool HloMapInstruction::IsElementwiseImpl(
1352 const std::optional<int64_t>& operand_idx) const {
1353 if (!dimensions().empty()) {
1354 // Check that the map is executed in elementwise compatible dimensions.
1355 if (dimensions().size() != shape().dimensions_size()) {
1356 return false;
1357 }
1358 for (int i = 0; i < dimensions().size(); ++i) {
1359 if (dimensions()[i] != i) {
1360 return false;
1361 }
1362 }
1363 }
1364 return true;
1365 }
1366
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1367 std::vector<std::string> HloMapInstruction::ExtraAttributesToStringImpl(
1368 const HloPrintOptions& options) const {
1369 return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1370 }
1371
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1372 bool HloMapInstruction::IdenticalSlowPath(
1373 const HloInstruction& other,
1374 const std::function<bool(const HloComputation*, const HloComputation*)>&
1375 eq_computations) const {
1376 const auto& casted_other = static_cast<const HloMapInstruction&>(other);
1377 return eq_computations(to_apply(), casted_other.to_apply()) &&
1378 dimensions() == casted_other.dimensions();
1379 }
1380
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1381 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1382 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1383 HloCloneContext* context) const {
1384 return std::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1385 }
1386
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)1387 HloSliceInstruction::HloSliceInstruction(
1388 const Shape& shape, HloInstruction* operand,
1389 absl::Span<const int64_t> start_indices,
1390 absl::Span<const int64_t> limit_indices, absl::Span<const int64_t> strides)
1391 : HloInstruction(HloOpcode::kSlice, shape),
1392 slice_starts_(start_indices.begin(), start_indices.end()),
1393 slice_limits_(limit_indices.begin(), limit_indices.end()),
1394 slice_strides_(strides.begin(), strides.end()) {
1395 AppendOperand(operand);
1396 // For backward compatibility with old serialized computations: if there are
1397 // no strides, assume all strides are 1.
1398 // TODO(b/63317920): remove this code.
1399 if (slice_strides_.empty()) {
1400 slice_strides_ = std::vector<int64_t>(start_indices.size(), 1LL);
1401 }
1402 }
1403
ToProto() const1404 HloInstructionProto HloSliceInstruction::ToProto() const {
1405 HloInstructionProto proto = HloInstruction::ToProto();
1406 for (int i = 0; i < slice_starts_.size(); ++i) {
1407 auto* slice_dimension = proto.add_slice_dimensions();
1408 slice_dimension->set_start(slice_starts_[i]);
1409 slice_dimension->set_limit(slice_limits_[i]);
1410 slice_dimension->set_stride(slice_strides_[i]);
1411 }
1412 return proto;
1413 }
1414
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1415 std::vector<std::string> HloSliceInstruction::ExtraAttributesToStringImpl(
1416 const HloPrintOptions& options) const {
1417 std::vector<std::string> bounds;
1418 bounds.reserve(slice_starts_.size());
1419 const bool omit_stride = absl::c_all_of(
1420 slice_strides_, [](int64_t stride) { return stride == 1; });
1421 for (int i = 0; i < slice_starts_.size(); ++i) {
1422 std::string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1423 bounds.push_back(
1424 StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1425 }
1426 return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1427 }
1428
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1429 bool HloSliceInstruction::IdenticalSlowPath(
1430 const HloInstruction& other,
1431 const std::function<bool(const HloComputation*, const HloComputation*)>&
1432 eq_computations) const {
1433 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1434 return slice_starts_ == other_slice.slice_starts_ &&
1435 slice_limits_ == other_slice.slice_limits_ &&
1436 slice_strides_ == other_slice.slice_strides_;
1437 }
1438
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1439 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1440 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1441 HloCloneContext* context) const {
1442 CHECK_EQ(new_operands.size(), 1);
1443 return std::make_unique<HloSliceInstruction>(
1444 shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1445 }
1446
HloConstantInstruction(Literal literal)1447 HloConstantInstruction::HloConstantInstruction(Literal literal)
1448 : HloInstruction(HloOpcode::kConstant, literal.shape()),
1449 literal_(std::move(literal)) {}
1450
HloConstantInstruction(Literal literal,const Shape & shape)1451 HloConstantInstruction::HloConstantInstruction(Literal literal,
1452 const Shape& shape)
1453 : HloInstruction(HloOpcode::kConstant, shape),
1454 literal_(std::move(literal)) {}
1455
HloConstantInstruction(const Shape & shape)1456 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1457 : HloInstruction(HloOpcode::kConstant, shape) {}
1458
ToProto() const1459 HloInstructionProto HloConstantInstruction::ToProto() const {
1460 HloInstructionProto proto = HloInstruction::ToProto();
1461 if (literal_.has_value()) {
1462 *proto.mutable_literal() = literal_->ToProto();
1463 }
1464 return proto;
1465 }
1466
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const1467 bool HloConstantInstruction::IsElementwiseImpl(
1468 const std::optional<int64_t>& operand_idx) const {
1469 return true;
1470 }
1471
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1472 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1473 const ShapeIndex& shape_index) {
1474 Shape* mutable_array_subshape =
1475 ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1476 CHECK(mutable_array_subshape->IsArray());
1477
1478 // Normally array_subshape will always have a layout, but this invariant is
1479 // temporarily broken in LayoutAssignment::AssignLayouts.
1480
1481 if (!mutable_array_subshape->has_layout() ||
1482 !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1483 *literal_ = literal_->Relayout(new_layout, shape_index);
1484 *mutable_array_subshape->mutable_layout() = new_layout;
1485 }
1486 }
1487
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1488 bool HloConstantInstruction::IdenticalSlowPath(
1489 const HloInstruction& other,
1490 const std::function<bool(const HloComputation*, const HloComputation*)>&
1491 eq_computations) const {
1492 const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1493 return literal() == other_slice.literal();
1494 }
1495
1496 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1497 HloConstantInstruction::CloneWithNewOperandsImpl(
1498 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1499 HloCloneContext* context) const {
1500 if (!literal_.has_value()) {
1501 return std::make_unique<HloConstantInstruction>(this->shape());
1502 }
1503 CHECK(literal_.has_value());
1504 // Literal's shape may have no/different tiling info. Use this instruction's
1505 // shape instead.
1506 CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1507 this->shape()));
1508 return std::make_unique<HloConstantInstruction>(literal_->Clone(),
1509 this->shape());
1510 }
1511
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1512 std::string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1513 const HloPrintOptions& options,
1514 CanonicalNameMap* canonical_name_map) const {
1515 if (options.print_only_essential_constants()) {
1516 if (!literal_.has_value()) {
1517 return "{...}";
1518 }
1519 if (literal().IsAll(0)) {
1520 return "0";
1521 }
1522 if (literal().IsAll(1)) {
1523 return "1";
1524 }
1525 if (shape().IsInteger()) {
1526 return literal_->ToStringWithoutShapeOneline();
1527 }
1528 return "{...}";
1529 }
1530
1531 // For constants, show the actual value in place of an empty operand list.
1532 if (literal_.has_value() &&
1533 ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1534 options.print_large_constants())) {
1535 // Literal::ToString emits multidimensional arrays over multiple
1536 // lines. Compact this into one line by stripping out white space.
1537 return literal_->ToStringWithoutShapeOneline();
1538 } else {
1539 // Do not show large constants or tuples.
1540 return "{...}";
1541 }
1542 }
1543
HloCallableInstruction(HloOpcode opcode,const Shape & shape)1544 HloCallableInstruction::HloCallableInstruction(HloOpcode opcode,
1545 const Shape& shape)
1546 : HloInstruction(opcode, shape) {}
1547
HloCallableInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands)1548 HloCallableInstruction::HloCallableInstruction(
1549 HloOpcode opcode, const Shape& shape,
1550 absl::Span<HloInstruction* const> operands)
1551 : HloInstruction(opcode, shape) {
1552 for (auto operand : operands) {
1553 AppendOperand(operand);
1554 }
1555 SetAndSanitizeName(HloOpcodeString(opcode));
1556 }
1557
HloCallableInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * called_computation,absl::string_view prefix)1558 HloCallableInstruction::HloCallableInstruction(
1559 HloOpcode opcode, const Shape& shape,
1560 absl::Span<HloInstruction* const> operands,
1561 HloComputation* called_computation, absl::string_view prefix)
1562 : HloInstruction(opcode, shape) {
1563 for (auto operand : operands) {
1564 AppendOperand(operand);
1565 }
1566 SetAndSanitizeName(std::string(prefix) + HloOpcodeString(opcode));
1567 AppendComputation(called_computation);
1568 }
1569
HloCallableInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations)1570 HloCallableInstruction::HloCallableInstruction(
1571 HloOpcode opcode, const Shape& shape,
1572 absl::Span<HloInstruction* const> operands,
1573 absl::Span<HloComputation* const> called_computations)
1574 : HloInstruction(opcode, shape) {
1575 for (auto operand : operands) {
1576 AppendOperand(operand);
1577 }
1578 SetAndSanitizeName(HloOpcodeString(opcode));
1579 for (auto called_computation : called_computations) {
1580 AppendComputation(called_computation);
1581 }
1582 }
1583
~HloCallableInstruction()1584 HloCallableInstruction::~HloCallableInstruction() { ClearCalledComputations(); }
1585
called_computation() const1586 HloComputation* HloCallableInstruction::called_computation() const {
1587 CHECK(!called_computations().empty());
1588 return called_computations().front();
1589 }
1590
called_computation_root() const1591 HloInstruction* HloCallableInstruction::called_computation_root() const {
1592 return called_computation()->root_instruction();
1593 }
1594
AddCallOperand(HloInstruction * new_operand)1595 HloInstruction* HloCallableInstruction::AddCallOperand(
1596 HloInstruction* new_operand) {
1597 CHECK_EQ(operand_count(),
1598 called_computation()->parameter_instructions().size());
1599 const int64_t param_no = operand_count();
1600 std::string param_name = StrCat("param_", param_no);
1601 HloInstruction* called_computation_parameter =
1602 called_computation()->AddParameter(HloInstruction::CreateParameter(
1603 param_no, new_operand->shape(), param_name));
1604 AppendOperand(new_operand);
1605 return called_computation_parameter;
1606 }
1607
AppendInstructionIntoCalledComputation(HloInstruction * instruction_to_append,bool add_output)1608 HloInstruction* HloCallableInstruction::AppendInstructionIntoCalledComputation(
1609 HloInstruction* instruction_to_append, bool add_output) {
1610 // When add_output is false, this callable instruction must be a user of
1611 // instruction_to_append.
1612 if (!add_output) {
1613 CHECK(IsUserOf(instruction_to_append));
1614 }
1615 return CloneAndAppendInstructionIntoCalledComputation(instruction_to_append,
1616 add_output);
1617 }
1618
// Moves `instruction_to_append` into this instruction's called computation
// and returns the instruction that now represents it there. Normally that is a
// fresh clone; a kTuple appended to an existing computation is not cloned and
// is returned as-is.
//
// `instruction_to_append` must be fusible. If there is no called computation
// yet, one is built that contains only the clone and is registered with the
// module; `add_output` must be false in that case. Otherwise the clone is
// added to the existing computation and, if `instruction_to_append` is an
// operand of this instruction, the uses of the matching parameter are
// redirected to the clone and that operand/parameter pair is removed.
//
// Afterwards every operand of the clone is rewired to a parameter of the
// called computation, adding new call operands where needed. If `add_output`
// is still true at that point, the computation root is replaced by a tuple
// that additionally exposes the clone, this instruction's shape is updated to
// the new root's shape, and users of `instruction_to_append` are redirected to
// get-tuple-element instructions reading from this instruction.
1619 HloInstruction*
CloneAndAppendInstructionIntoCalledComputation(HloInstruction * instruction_to_append,bool add_output)1620 HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation(
1621 HloInstruction* instruction_to_append, bool add_output) {
1622 CHECK(instruction_to_append->IsFusible())
1623 << instruction_to_append->ToString();
1624 VLOG(3) << "CloneAndAppendInstructionIntoCalledComputation:\n"
1625 << instruction_to_append->ToString();
1626 HloInstruction* clone = nullptr;
1627 if (called_computations().empty()) {
1628 // New fusion instruction. It should not be a multioutput instruction.
1629 CHECK(!add_output);
     // The builder is given `this` as the owning fusion instruction only when
     // this instruction's opcode is kFusion.
1630 auto builder = HloComputation::Builder(
1631 default_called_computation_name(),
1632 opcode() == HloOpcode::kFusion ? this : nullptr);
1633 builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/""));
1634 AppendComputation(
1635 CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1636 clone = called_computation_root();
1637 } else {
1638 // When add_output is false, instruction_to_append is necessarily an operand
1639 // of the callable instruction. After appending this will no longer be the
1640 // case. Remove the operand from the operand list and remove its
1641 // corresponding called computation parameter instruction.
1642 bool in_operand_list =
1643 absl::c_linear_search(operands(), instruction_to_append);
1644 CHECK(add_output || in_operand_list);
1645 if (instruction_to_append->opcode() == HloOpcode::kTuple) {
1646 // We assume all uses of a kTuple operation are GTE ops. In this case, we
1647 // don't need to clone 'instruction_to_append'.
1648 CHECK(!in_operand_list);
1649 clone = instruction_to_append;
1650 } else {
1651 clone = called_computation()->AddInstruction(
1652 instruction_to_append->Clone(/*suffix=*/""));
1653 }
1654 const std::vector<HloInstruction*>& called_computation_parameters =
1655 called_computation()->parameter_instructions();
     // Only the first operand slot that refers to instruction_to_append is
     // handled here; the loop stops after it.
1656 for (int64_t operand_num = 0; operand_num < operand_count();
1657 ++operand_num) {
1658 if (instruction_to_append == operand(operand_num)) {
1659 // Replace the called computation parameter instruction's uses with the
1660 // clone.
1661 HloInstruction* called_computation_parameter =
1662 called_computation_parameters[operand_num];
1663 TF_CHECK_OK(called_computation_parameter->ReplaceAllUsesWith(clone));
1664
1665 // Remove the corresponding called computation parameter and operand
1666 // from their respective vectors.
1667 TF_CHECK_OK(called_computation()->RemoveParameter(operand_num));
1668 RemoveOperandAt(operand_num);
1669 break;
1670 }
1671 }
1672 // We've cloned instruction_to_append into this callable instruction, so
1673 // this callable instruction is no longer a use of instruction_to_append.
1674 if (in_operand_list) {
1675 DetachFrom(instruction_to_append);
1676 // When the instruction_to_append does not have other users, we don't need
1677 // to generate a multioutput instruction.
1678 if (instruction_to_append->user_count() == 0) {
1679 add_output = false;
1680 }
1681 }
1682 }
1683
1684 // Reread the parameters in the computation.
1685 const std::vector<HloInstruction*>& called_computation_parameters =
1686 called_computation()->parameter_instructions();
1687
1688 // Add each operand of the clone as an operand of the callable instruction. A
1689 // complication is that some clone operands may already be operands of the
1690 // callable instruction.
1691 for (int64_t operand_num = 0; operand_num < clone->operand_count();
1692 ++operand_num) {
1693 HloInstruction* operand = clone->mutable_operand(operand_num);
1694
1695 // See if this operand is already an operand of the callable instruction.
1696 CHECK_EQ(operands().size(), called_computation_parameters.size());
1697 HloInstruction* called_computation_parameter = nullptr;
1698 for (int64_t i = 0; i < operands().size(); ++i) {
1699 if (this->operand(i) == operand) {
1700 called_computation_parameter = called_computation_parameters[i];
1701 break;
1702 }
1703 }
1704
1705 if (called_computation_parameter == nullptr) {
1706 // Clone's operand was not already an operand of the callable instruction.
1707 // Add it as an operand and add a corresponding called computation
1708 // parameter instruction.
1709 called_computation_parameter = AddCallOperand(operand);
1710 }
1711 TF_CHECK_OK(
1712 clone->ReplaceOperandWith(operand_num, called_computation_parameter));
1713 }
1714
1715 if (add_output) {
1716 CHECK_GT(instruction_to_append->user_count(), 0);
1717 // If this is already a multioutput instruction, expand the root tuple by 1.
1718 HloInstruction* root = called_computation_root();
1719 HloInstruction::InstructionVector tuple_elements;
1720 bool newly_created_tuple_instr = false;
1721 if (root->opcode() == HloOpcode::kTuple) {
1722 tuple_elements = root->operands();
1723 } else {
1724 tuple_elements.push_back(root);
1725 newly_created_tuple_instr = true;
1726 }
     // A kTuple clone contributes each of its operands as a separate output;
     // any other clone contributes itself as a single output.
1727 if (clone->opcode() == HloOpcode::kTuple) {
1728 for (auto inst : clone->operands()) {
1729 tuple_elements.push_back(inst);
1730 }
1731 } else {
1732 tuple_elements.push_back(clone);
1733 }
1734 HloInstruction* new_root = called_computation()->AddInstruction(
1735 HloInstruction::CreateTuple(tuple_elements));
1736 called_computation()->set_root_instruction(new_root,
1737 /*accept_different_shape=*/true);
1738 *mutable_shape() = new_root->shape();
     // An old root that was itself a tuple has been superseded by new_root and
     // is removed from the called computation.
1739 if (root->opcode() == HloOpcode::kTuple) {
1740 TF_CHECK_OK(called_computation()->RemoveInstruction(root));
1741 }
1742
1743 // If this is a newly created multioutput instruction, we need to update
1744 // the use of the original callable instruction.
1745 if (newly_created_tuple_instr) {
1746 HloInstruction* new_instr = parent()->AddInstruction(
1747 HloInstruction::CreateGetTupleElement(root->shape(), this, 0));
1748 TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1749 }
     // `index` starts one past the last tuple element. Below it is adjusted to
     // the position of the first element contributed by the clone.
1750 int64_t index = tuple_elements.size();
1751 if (instruction_to_append->opcode() == HloOpcode::kTuple) {
1752 CHECK_EQ(clone, instruction_to_append);
1753 index -= clone->operand_count();
1754 std::vector<HloInstruction*> to_be_removed;
1755 const auto& users = clone->users();
1756 to_be_removed.reserve(users.size());
     // Each get-tuple-element user of the tuple is replaced by a
     // get-tuple-element on this instruction at the shifted index. The old
     // users are collected first and removed in a separate pass.
1757 for (auto old_gte : users) {
1758 CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1759 int64_t old_tuple_index = old_gte->tuple_index();
1760 HloInstruction* new_gte =
1761 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1762 old_gte->shape(), this, index + old_tuple_index));
1763 TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1764 to_be_removed.push_back(old_gte);
1765 }
1766 for (auto old_gte : to_be_removed) {
1767 TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1768 }
1769 } else {
     // The clone is the last tuple element, i.e. at position index - 1.
1770 HloInstruction* new_gte =
1771 parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1772 clone->shape(), this, index - 1));
1773 TF_CHECK_OK(instruction_to_append->ReplaceAllUsesWith(new_gte));
1774 }
1775 }
1776
1777 if (clone != instruction_to_append) {
1778 VLOG(2) << "New clone:\n" << clone->ToString();
1779 }
1780 return clone;
1781 }
1782
1783 absl::InlinedVector<HloComputation*, 1>
GetOrCloneCalledComputations(HloCloneContext * context) const1784 HloCallableInstruction::GetOrCloneCalledComputations(
1785 HloCloneContext* context) const {
1786 HloModule* module = context != nullptr ? context->module() : GetModule();
1787 absl::InlinedVector<HloComputation*, 1> new_called_computations;
1788 for (auto* comp : called_computations()) {
1789 HloComputation* new_custom_call_computation = nullptr;
1790 if (context != nullptr) {
1791 new_custom_call_computation = context->FindComputation(comp);
1792 }
1793 if (new_custom_call_computation == nullptr) {
1794 new_custom_call_computation =
1795 module->AddEmbeddedComputation(comp->Clone("clone", context));
1796 }
1797 new_called_computations.push_back(new_custom_call_computation);
1798 }
1799 return new_called_computations;
1800 }
1801
RecursivelySetComputationsThreadName(absl::string_view execution_thread,bool skip_async_execution_thread_overwrite)1802 void HloCallableInstruction::RecursivelySetComputationsThreadName(
1803 absl::string_view execution_thread,
1804 bool skip_async_execution_thread_overwrite) {
1805 for (HloComputation* comp : called_computations()) {
1806 SetThreadName(comp, execution_thread,
1807 skip_async_execution_thread_overwrite);
1808 }
1809 }
1810
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1811 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1812 FusionKind fusion_kind,
1813 HloInstruction* fused_root)
1814 : HloCallableInstruction(HloOpcode::kFusion, shape),
1815 fusion_kind_(fusion_kind) {
1816 CHECK(fused_root != nullptr);
1817 SetAndSanitizeName(HloOpcodeString(opcode()));
1818 set_parent(fused_root->parent());
1819 set_metadata(fused_root->metadata());
1820 CHECK(fused_root->IsFusible()) << fused_root->ToString();
1821 CloneAndAppendInstructionIntoCalledComputation(fused_root);
1822 }
1823
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation,absl::string_view prefix)1824 HloFusionInstruction::HloFusionInstruction(
1825 const Shape& shape, FusionKind fusion_kind,
1826 absl::Span<HloInstruction* const> operands,
1827 HloComputation* fusion_computation, absl::string_view prefix)
1828 : HloCallableInstruction(HloOpcode::kFusion, shape, operands,
1829 fusion_computation, prefix),
1830 fusion_kind_(fusion_kind) {
1831 fusion_computation->SetFusionInstruction(this);
1832 }
1833
~HloFusionInstruction()1834 HloFusionInstruction::~HloFusionInstruction() {
1835 ClearFusionComputationInstruction();
1836 }
1837
ClearFusionComputationInstruction()1838 void HloFusionInstruction::ClearFusionComputationInstruction() {
1839 // Each fusion calls a single computation, but we use called_computations()
1840 // instead of fused_instructions_computation(), because the order in which
1841 // things get destructed can vary; the fusion computation's back-pointer may
1842 // already be null, which violates a check in fused_instructions_computation.
1843 for (HloComputation* computation : called_computations()) {
1844 // Some passes that rewrite fusions may reassign a fusion computation to a
1845 // different fusion instruction as this instruction gets destructed.
1846 if (computation->FusionInstruction() == this) {
1847 computation->SetFusionInstruction(nullptr);
1848 }
1849 }
1850 }
1851
ClearCalledComputations()1852 void HloFusionInstruction::ClearCalledComputations() {
1853 ClearFusionComputationInstruction();
1854 HloInstruction::ClearCalledComputations();
1855 }
1856
ToCategory() const1857 std::string HloFusionInstruction::ToCategory() const {
1858 switch (fusion_kind()) {
1859 case FusionKind::kLoop:
1860 return "loop fusion";
1861 case FusionKind::kInput:
1862 return "input fusion";
1863 case FusionKind::kOutput:
1864 return "output fusion";
1865 case FusionKind::kCustom:
1866 return "custom fusion";
1867 }
1868 }
1869
ToProto() const1870 HloInstructionProto HloFusionInstruction::ToProto() const {
1871 HloInstructionProto proto = HloInstruction::ToProto();
1872 proto.set_fusion_kind(xla::ToString(fusion_kind()));
1873 proto.add_called_computation_ids(
1874 fused_instructions_computation()->unique_id());
1875 return proto;
1876 }
1877
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const1878 bool HloFusionInstruction::IsElementwiseImpl(
1879 const std::optional<int64_t>& operand_idx) const {
1880 if (!operand_idx.has_value()) {
1881 for (auto* fused : fused_instructions()) {
1882 if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1883 return false;
1884 }
1885 }
1886 return true;
1887 }
1888 // A loop-fusion is elementwise on an operand if all operations (computed
1889 // using BFS) between the operand and the fused root are elementwise.
1890 std::deque<HloInstruction*> worklist;
1891 absl::flat_hash_set<const HloInstruction*> visited;
1892 worklist.push_back(fused_parameter(operand_idx.value()));
1893 visited.insert(fused_parameter(operand_idx.value()));
1894 while (!worklist.empty()) {
1895 HloInstruction* operand = worklist.front();
1896 worklist.pop_front();
1897 for (HloInstruction* user : operand->users()) {
1898 CHECK_GE(user->unique_id(), 0);
1899 if (ContainsKey(visited, user)) {
1900 continue;
1901 }
1902 if (user->IsElementwise() ||
1903 IsInstructionElementwiseOnOperand(user, operand)) {
1904 worklist.push_back(user);
1905 visited.insert(user);
1906 } else {
1907 return false;
1908 }
1909 }
1910 }
1911 return true;
1912 }
1913
AddFusionOperand(HloInstruction * new_operand)1914 HloInstruction* HloFusionInstruction::AddFusionOperand(
1915 HloInstruction* new_operand) {
1916 return AddCallOperand(new_operand);
1917 }
1918
// Merges the fused instructions of `instruction_to_merge` into this fusion
// instruction. `instruction_to_merge` must be an operand of this instruction.
//
// The merge operates on a clone of `instruction_to_merge`: uses of the clone's
// fused parameters are replaced by the clone's operands, the clone's
// non-parameter fused instructions are fused into `this` one at a time
// (starting with the fused root), and finally the clone's fused computation is
// removed from the module.
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1919 void HloFusionInstruction::MergeFusionInstruction(
1920 HloFusionInstruction* instruction_to_merge) {
1921 CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1922 // Clone the instruction from which to merge fused instructions.
1923 std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1924 HloFusionInstruction* cloned_fusion =
1925 static_cast<HloFusionInstruction*>(cloned.get());
1926 // Replace uses of fused parameters with the corresponding operand of the
1927 // fusion. Add all non-parameter fused instructions to
1928 // 'unfused_instructions' to be merged into 'this'. This is done in reverse
1929 // post order.
1930 std::vector<HloInstruction*> unfused_instructions;
1931 auto fused_instructions = cloned_fusion->fused_instructions_computation()
1932 ->MakeInstructionPostOrder();
1933 for (auto fused_it = fused_instructions.rbegin();
1934 fused_it != fused_instructions.rend(); ++fused_it) {
1935 auto fused_instruction = *fused_it;
1936 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1937 TF_CHECK_OK(
1938 fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1939 fused_instruction->parameter_number())));
1940 } else {
1941 unfused_instructions.push_back(fused_instruction);
1942 }
1943 }
1944
1945 // If there are no unfused instructions, the fused computation must consist
1946 // only of kParameter instructions. Make the operand of the corresponding
1947 // parameter number the new root.
1948 HloInstruction* unfused_root =
1949 unfused_instructions.empty()
1950 ? instruction_to_merge->mutable_operand(
1951 instruction_to_merge->fused_instructions_computation()
1952 ->root_instruction()
1953 ->parameter_number())
1954 : unfused_instructions.front();
     // Because of the reverse post order above, the first unfused instruction
     // is expected to be the cloned fusion's root.
1955 CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
1956 unfused_instructions.empty());
1957 // Replace instruction_to_merge use of 'this' with unfused_root.
1958 TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
1959
1960 // Build a dummy root for the cloned fusion as we may remove the original root
1961 // in the fusion process.
1962 if (!unfused_instructions.empty()) {
1963 HloComputation* computation = unfused_root->parent();
1964 auto* dummy_root = computation->AddInstruction(
1965 HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
1966 computation->set_root_instruction(dummy_root,
1967 /*accept_different_shape=*/true);
1968 }
1969
1970 // Fuse 'unfused_instructions' into 'this'. Every time we fuse an instruction
1971 // we remove it from the cloned fusion node. This is so that we don't add
1972 // extra users to the producer of that instruction (we use user count to
1973 // decide if a side-effectful instruction is fusible).
1974 for (auto& instruction : unfused_instructions) {
1975 auto* fused = FuseInstruction(instruction);
1976 TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
1977 TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1978 }
     // The clone must have no users left; its fused computation is then
     // removed from the module that owns this instruction's parent computation.
1979 CHECK_EQ(0, cloned_fusion->user_count());
1980 TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1981 cloned_fusion->fused_instructions_computation()));
1982 }
1983
// Merges the fusion instruction `instruction_to_merge` into this fusion
// instruction as part of a multi-output fusion.
//
// The fused computation of `instruction_to_merge` is first disassembled into
// standalone clones placed in this instruction's parent computation. All uses
// of `instruction_to_merge` are then redirected to the clone chosen as root,
// `instruction_to_merge` and its fused computation are removed, and finally
// the clones are fused into 'this' (the root via
// FuseInstructionIntoMultiOutput(), the rest via FuseInstruction()).
void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
    HloFusionInstruction* instruction_to_merge) {
  // Add all non-parameter fused instructions to 'unfused_instructions' to be
  // merged into 'this'. 'old_to_new' maps the instructions in the fused node
  // to the disassembled fusion instructions.
  // Note that we add the unfused instructions to this->parent_ computation.
  // This is necessary because an instruction needs a unique_id, and that id
  // is only assigned when the instruction is inserted into a computation.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
  std::vector<HloInstruction*> unfused_instructions;
  auto computation_to_merge =
      instruction_to_merge->fused_instructions_computation();
  auto post_order = computation_to_merge->MakeInstructionPostOrder();
  // Walk the post order backwards, so the first clone pushed onto
  // 'unfused_instructions' corresponds to the last non-parameter instruction
  // of the post order; it is used below as the root of the unfused expression.
  for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
    auto fused_instruction = *rit;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // A fused parameter maps to the matching operand of the fusion
      // instruction being merged; no clone is created for it.
      InsertOrDie(&old_to_new, fused_instruction,
                  instruction_to_merge->mutable_operand(
                      fused_instruction->parameter_number()));
      continue;
    }

    // Here we clone the instruction and call FuseInstructionIntoMultiOutput()
    // which clones again. This can be improved.
    auto cloned_instruction =
        parent()->AddInstruction(fused_instruction->Clone());
    unfused_instructions.push_back(cloned_instruction);
    InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
  }
  // Rewire the operands of every clone: each operand is looked up in
  // 'old_to_new' (FindOrDie dies if it is missing) and replaced by its mapped
  // counterpart outside the fused computation.
  for (auto unfused_instruction : unfused_instructions) {
    for (int64_t index = 0; index < unfused_instruction->operand_count();
         index++) {
      auto new_operand =
          FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
      TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
    }
  }

  // If there are no unfused instructions, the fused computation must consist
  // only of kParameter instructions. Make the operand of the corresponding
  // parameter number the new root.
  HloInstruction* unfused_root =
      unfused_instructions.empty()
          ? instruction_to_merge->mutable_operand(
                instruction_to_merge->fused_instructions_computation()
                    ->root_instruction()
                    ->parameter_number())
          : unfused_instructions.front();
  TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));

  // The merged fusion instruction has no remaining uses; remove it and, when
  // this instruction belongs to a module, its fused computation as well.
  TF_CHECK_OK(
      instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
  if (GetModule()) {
    TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
  }

  // Fuse the root instruction and generate multiple outputs.
  if (unfused_instructions.empty()) {
    // Nothing was disassembled, so there is nothing to fuse into 'this'.
    return;
  }
  FuseInstructionIntoMultiOutput(unfused_root);
  TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
  // The remaining instructions are fused normally. Index 0 is 'unfused_root',
  // which was handled above.
  for (int64_t i = 1; i < unfused_instructions.size(); i++) {
    auto instruction = unfused_instructions[i];
    FuseInstruction(instruction);
    TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
  }
}
2053
fused_instructions_computation() const2054 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
2055 CHECK(!called_computations().empty());
2056 auto* fused_instructions_computation = called_computations().front();
2057 CHECK(fused_instructions_computation->IsFusionComputation())
2058 << "Computation " << fused_instructions_computation->name()
2059 << " is not a fusion kind";
2060 return fused_instructions_computation;
2061 }
2062
fused_expression_root() const2063 HloInstruction* HloFusionInstruction::fused_expression_root() const {
2064 return fused_instructions_computation()->root_instruction();
2065 }
2066
fused_parameter(int64_t parameter_number) const2067 HloInstruction* HloFusionInstruction::fused_parameter(
2068 int64_t parameter_number) const {
2069 return fused_instructions_computation()->parameter_instruction(
2070 parameter_number);
2071 }
2072
fused_parameters() const2073 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
2074 const {
2075 return fused_instructions_computation()->parameter_instructions();
2076 }
2077
2078 const tensorflow::gtl::iterator_range<UnwrappingIterator<
2079 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const2080 HloFusionInstruction::fused_instructions() const {
2081 const HloComputation* subcomp = fused_instructions_computation();
2082 return subcomp->instructions();
2083 }
2084
2085 const tensorflow::gtl::iterator_range<
2086 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()2087 HloFusionInstruction::fused_instructions() {
2088 return fused_instructions_computation()->instructions();
2089 }
2090
fused_instruction_count() const2091 int64_t HloFusionInstruction::fused_instruction_count() const {
2092 return fused_instructions_computation()->instruction_count();
2093 }
2094
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2095 std::vector<std::string> HloFusionInstruction::ExtraAttributesToStringImpl(
2096 const HloPrintOptions& options) const {
2097 return {StrCat("kind=", xla::ToString(fusion_kind()))};
2098 }
2099
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2100 bool HloFusionInstruction::IdenticalSlowPath(
2101 const HloInstruction& other,
2102 const std::function<bool(const HloComputation*, const HloComputation*)>&
2103 eq_computations) const {
2104 return fusion_kind() == other.fusion_kind() &&
2105 eq_computations(fused_instructions_computation(),
2106 other.fused_instructions_computation());
2107 }
2108
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2109 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
2110 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2111 HloCloneContext* context) const {
2112 auto new_fused_computation = GetOrCloneCalledComputations(context);
2113 CHECK_EQ(new_fused_computation.size(), 1);
2114 return std::make_unique<HloFusionInstruction>(
2115 shape, fusion_kind(), new_operands, new_fused_computation.front());
2116 }
2117
DeduplicateFusionOperands()2118 Status HloFusionInstruction::DeduplicateFusionOperands() {
2119 if (IsCustomFusion()) {
2120 return OkStatus();
2121 }
2122 absl::flat_hash_map<const HloInstruction*, int> operand_indices;
2123 std::vector<int> operands_to_remove;
2124 const int count = operand_count();
2125 operands_to_remove.reserve(count);
2126 for (int i = 0; i < count; ++i) {
2127 auto emplace_result = operand_indices.emplace(operand(i), i);
2128 if (!emplace_result.second) {
2129 TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
2130 fused_parameter(emplace_result.first->second)));
2131 operands_to_remove.push_back(i);
2132 }
2133 }
2134 if (operands_to_remove.empty()) {
2135 return OkStatus();
2136 }
2137 TF_RETURN_IF_ERROR(fused_instructions_computation()
2138 ->RemoveUnusedParametersFromFusedComputation());
2139 RemoveOperandsAtAscendingIndices(operands_to_remove);
2140 return OkStatus();
2141 }
2142
HloCallInstruction(const Shape & shape,HloInstruction * called_computation_root)2143 HloCallInstruction::HloCallInstruction(const Shape& shape,
2144 HloInstruction* called_computation_root)
2145 : HloCallableInstruction(HloOpcode::kCall, shape) {
2146 CHECK(called_computation_root != nullptr);
2147 SetAndSanitizeName(HloOpcodeString(opcode()));
2148 set_parent(called_computation_root->parent());
2149 set_metadata(called_computation_root->metadata());
2150 CloneAndAppendInstructionIntoCalledComputation(called_computation_root);
2151 }
2152
HloCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * called_computation)2153 HloCallInstruction::HloCallInstruction(
2154 const Shape& shape, absl::Span<HloInstruction* const> operands,
2155 HloComputation* called_computation)
2156 : HloCallableInstruction(HloOpcode::kCall, shape, operands,
2157 called_computation) {}
2158
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)2159 HloRngInstruction::HloRngInstruction(
2160 const Shape& shape, RandomDistribution distribution,
2161 absl::Span<HloInstruction* const> parameters)
2162 : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
2163 for (HloInstruction* param : parameters) {
2164 AppendOperand(param);
2165 }
2166 }
2167
ToProto() const2168 HloInstructionProto HloRngInstruction::ToProto() const {
2169 HloInstructionProto proto = HloInstruction::ToProto();
2170 proto.set_distribution(distribution_);
2171 return proto;
2172 }
2173
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2174 std::vector<std::string> HloRngInstruction::ExtraAttributesToStringImpl(
2175 const HloPrintOptions& options) const {
2176 return {StrCat("distribution=", RandomDistributionToString(distribution_))};
2177 }
2178
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const2179 bool HloRngInstruction::IsElementwiseImpl(
2180 const std::optional<int64_t>& operand_idx) const {
2181 return true;
2182 }
2183
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2184 bool HloRngInstruction::IdenticalSlowPath(
2185 const HloInstruction& other,
2186 const std::function<bool(const HloComputation*, const HloComputation*)>&
2187 eq_computations) const {
2188 const auto& casted_other = static_cast<const HloRngInstruction&>(other);
2189 return distribution_ == casted_other.distribution_;
2190 }
2191
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2192 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
2193 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2194 HloCloneContext* context) const {
2195 return std::make_unique<HloRngInstruction>(shape, distribution_,
2196 new_operands);
2197 }
2198
HloParameterInstruction(int64_t parameter_number,const Shape & shape,const std::string & name)2199 HloParameterInstruction::HloParameterInstruction(int64_t parameter_number,
2200 const Shape& shape,
2201 const std::string& name)
2202 : HloInstruction(HloOpcode::kParameter, shape),
2203 parameter_number_(parameter_number) {
2204 SetAndSanitizeName(name);
2205 }
2206
ToProto() const2207 HloInstructionProto HloParameterInstruction::ToProto() const {
2208 HloInstructionProto proto = HloInstruction::ToProto();
2209 proto.set_parameter_number(parameter_number_);
2210 if (parameter_replicated_at_leaf_buffers_) {
2211 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2212 proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
2213 replicated);
2214 }
2215 }
2216 return proto;
2217 }
2218
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2219 std::vector<std::string> HloParameterInstruction::ExtraAttributesToStringImpl(
2220 const HloPrintOptions& options) const {
2221 std::vector<std::string> result;
2222 if (!parameter_replicated_at_leaf_buffers_) {
2223 return result;
2224 }
2225 std::vector<std::string> buffers_replicated_strs;
2226 buffers_replicated_strs.reserve(
2227 parameter_replicated_at_leaf_buffers_->size());
2228 for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2229 buffers_replicated_strs.push_back(replicated ? "true" : "false");
2230 }
2231 if (options.print_ids()) {
2232 result.push_back(StrCat("parameter_replication={",
2233 StrJoin(buffers_replicated_strs, ","), "}"));
2234 }
2235 return result;
2236 }
2237
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2238 std::string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
2239 const HloPrintOptions& options,
2240 CanonicalNameMap* canonical_name_map) const {
2241 return StrCat(parameter_number_);
2242 }
2243
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2244 bool HloParameterInstruction::IdenticalSlowPath(
2245 const HloInstruction& other,
2246 const std::function<bool(const HloComputation*, const HloComputation*)>&
2247 eq_computations) const {
2248 const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
2249 return parameter_number() == casted_other.parameter_number();
2250 }
2251
2252 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2253 HloParameterInstruction::CloneWithNewOperandsImpl(
2254 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2255 HloCloneContext* context) const {
2256 auto clone = std::make_unique<HloParameterInstruction>(parameter_number_,
2257 shape, name());
2258 if (parameter_replicated_at_leaf_buffers_ &&
2259 ShapeUtil::Equal(shape, this->shape())) {
2260 clone->set_parameter_replicated_at_leaf_buffers(
2261 *parameter_replicated_at_leaf_buffers_);
2262 }
2263 return clone;
2264 }
2265
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64_t index)2266 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
2267 const Shape& shape, HloInstruction* operand, int64_t index)
2268 : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
2269 AppendOperand(operand);
2270 }
2271
ToProto() const2272 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
2273 HloInstructionProto proto = HloInstruction::ToProto();
2274 proto.set_tuple_index(tuple_index_);
2275 return proto;
2276 }
2277
2278 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2279 HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
2280 const HloPrintOptions& options) const {
2281 return {StrCat("index=", tuple_index())};
2282 }
2283
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2284 bool HloGetTupleElementInstruction::IdenticalSlowPath(
2285 const HloInstruction& other,
2286 const std::function<bool(const HloComputation*, const HloComputation*)>&
2287 eq_computations) const {
2288 const auto& casted_other =
2289 static_cast<const HloGetTupleElementInstruction&>(other);
2290 return tuple_index() == casted_other.tuple_index();
2291 }
2292
2293 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2294 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
2295 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2296 HloCloneContext* context) const {
2297 CHECK_EQ(new_operands.size(), 1);
2298 return std::make_unique<HloGetTupleElementInstruction>(shape, new_operands[0],
2299 tuple_index());
2300 }
2301
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)2302 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
2303 const Shape& shape, HloInstruction* operand, const int exponent_bits,
2304 const int mantissa_bits)
2305 : HloInstruction(HloOpcode::kReducePrecision, shape),
2306 exponent_bits_(exponent_bits),
2307 mantissa_bits_(mantissa_bits) {
2308 AppendOperand(operand);
2309 }
2310
ToProto() const2311 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
2312 HloInstructionProto proto = HloInstruction::ToProto();
2313 proto.set_exponent_bits(exponent_bits_);
2314 proto.set_mantissa_bits(mantissa_bits_);
2315 return proto;
2316 }
2317
2318 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2319 HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
2320 const HloPrintOptions& options) const {
2321 return {StrCat("exponent_bits=", exponent_bits_),
2322 StrCat("mantissa_bits=", mantissa_bits_)};
2323 }
2324
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2325 bool HloReducePrecisionInstruction::IdenticalSlowPath(
2326 const HloInstruction& other,
2327 const std::function<bool(const HloComputation*, const HloComputation*)>&
2328 eq_computations) const {
2329 const auto& casted_other =
2330 static_cast<const HloReducePrecisionInstruction&>(other);
2331 // A reduce-precision operation is determined by the bit sizes.
2332 return exponent_bits() == casted_other.exponent_bits() &&
2333 mantissa_bits() == casted_other.mantissa_bits();
2334 }
2335
2336 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2337 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
2338 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2339 HloCloneContext* context) const {
2340 CHECK_EQ(new_operands.size(), 1);
2341 return std::make_unique<HloReducePrecisionInstruction>(
2342 shape, new_operands[0], exponent_bits(), mantissa_bits());
2343 }
2344
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const std::string & config)2345 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
2346 HloInstruction* token_operand,
2347 const std::string& config)
2348 : HloInstruction(HloOpcode::kInfeed,
2349 ShapeUtil::MakeTupleShape(
2350 {infeed_shape, ShapeUtil::MakeTokenShape()})),
2351 infeed_config_(config) {
2352 AppendOperand(token_operand);
2353 }
2354
ToProto() const2355 HloInstructionProto HloInfeedInstruction::ToProto() const {
2356 HloInstructionProto proto = HloInstruction::ToProto();
2357 proto.set_infeed_config(infeed_config_);
2358 return proto;
2359 }
2360
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2361 std::vector<std::string> HloInfeedInstruction::ExtraAttributesToStringImpl(
2362 const HloPrintOptions& options) const {
2363 if (!options.print_infeed_outfeed_config() || infeed_config_.empty()) {
2364 return {};
2365 }
2366 return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
2367 }
2368
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2369 bool HloInfeedInstruction::IdenticalSlowPath(
2370 const HloInstruction& other,
2371 const std::function<bool(const HloComputation*, const HloComputation*)>&
2372 eq_computations) const {
2373 // Not yet supported.
2374 return false;
2375 }
2376
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2377 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
2378 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2379 HloCloneContext* context) const {
2380 CHECK_EQ(new_operands.size(), 1);
2381 return std::make_unique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
2382 infeed_config());
2383 }
2384
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)2385 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
2386 HloInstruction* operand,
2387 HloInstruction* token_operand,
2388 absl::string_view outfeed_config)
2389 : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
2390 outfeed_shape_(outfeed_shape),
2391 outfeed_config_(outfeed_config) {
2392 AppendOperand(operand);
2393 AppendOperand(token_operand);
2394 }
2395
ToProto() const2396 HloInstructionProto HloOutfeedInstruction::ToProto() const {
2397 HloInstructionProto proto = HloInstruction::ToProto();
2398 proto.set_outfeed_config(outfeed_config());
2399 *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
2400 return proto;
2401 }
2402
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2403 std::vector<std::string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
2404 const HloPrintOptions& options) const {
2405 std::vector<std::string> extra;
2406 extra.push_back(StrCat("outfeed_shape=",
2407 ShapeUtil::HumanStringWithLayout(outfeed_shape_)));
2408 if (options.print_infeed_outfeed_config() && !outfeed_config_.empty()) {
2409 extra.push_back(
2410 StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
2411 }
2412 return extra;
2413 }
2414
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2415 bool HloOutfeedInstruction::IdenticalSlowPath(
2416 const HloInstruction& other,
2417 const std::function<bool(const HloComputation*, const HloComputation*)>&
2418 eq_computations) const {
2419 // Not yet supported.
2420 return false;
2421 }
2422
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2423 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
2424 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2425 HloCloneContext* context) const {
2426 CHECK_EQ(new_operands.size(), 2);
2427 return std::make_unique<HloOutfeedInstruction>(
2428 outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
2429 }
2430
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2431 HloConvolutionInstruction::HloConvolutionInstruction(
2432 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2433 int64_t feature_group_count, int64_t batch_group_count,
2434 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
2435 const PrecisionConfig& precision_config)
2436 : HloInstruction(HloOpcode::kConvolution, shape),
2437 feature_group_count_(feature_group_count),
2438 batch_group_count_(batch_group_count),
2439 window_(window),
2440 convolution_dimension_numbers_(dimension_numbers),
2441 precision_config_(precision_config) {
2442 if (window_util::HasBaseDilation(window)) {
2443 SetAndSanitizeName(StrCat(name(), "-base-dilated"));
2444 }
2445 if (window_util::HasWindowDilation(window)) {
2446 SetAndSanitizeName(StrCat(name(), "-window-dilated"));
2447 }
2448 AppendOperand(lhs);
2449 AppendOperand(rhs);
2450 }
2451
ToCategory() const2452 std::string HloConvolutionInstruction::ToCategory() const {
2453 std::string category = "convolution";
2454 if (window_util::HasBaseDilation(window())) {
2455 category += " base-dilated";
2456 }
2457 if (window_util::HasWindowDilation(window())) {
2458 category += " window-dilated";
2459 }
2460 return category;
2461 }
2462
ToProto() const2463 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2464 HloInstructionProto proto = HloInstruction::ToProto();
2465 *proto.mutable_window() = window_;
2466 *proto.mutable_convolution_dimension_numbers() =
2467 convolution_dimension_numbers_;
2468 proto.set_feature_group_count(feature_group_count_);
2469 proto.set_batch_group_count(batch_group_count_);
2470 *proto.mutable_precision_config() = precision_config_;
2471 return proto;
2472 }
2473
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2474 std::vector<std::string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2475 const HloPrintOptions& options) const {
2476 std::vector<std::string> extra;
2477 if (window_.dimensions_size() != 0) {
2478 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2479 }
2480 extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2481 convolution_dimension_numbers_)));
2482 if (feature_group_count_ != 1) {
2483 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2484 }
2485
2486 if (batch_group_count_ != 1) {
2487 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2488 }
2489
2490 std::string precision_config_string =
2491 PrecisionConfigToString(precision_config_);
2492 if (!precision_config_string.empty()) {
2493 extra.push_back(precision_config_string);
2494 }
2495 return extra;
2496 }
2497
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2498 bool HloConvolutionInstruction::IdenticalSlowPath(
2499 const HloInstruction& other,
2500 const std::function<bool(const HloComputation*, const HloComputation*)>&
2501 eq_computations) const {
2502 const auto& casted_other =
2503 static_cast<const HloConvolutionInstruction&>(other);
2504 if (feature_group_count_ != other.feature_group_count()) {
2505 return false;
2506 }
2507 if (batch_group_count_ != other.batch_group_count()) {
2508 return false;
2509 }
2510 return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2511 protobuf_util::ProtobufEquals(
2512 convolution_dimension_numbers(),
2513 casted_other.convolution_dimension_numbers()) &&
2514 protobuf_util::ProtobufEquals(precision_config(),
2515 casted_other.precision_config());
2516 }
2517
2518 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2519 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2520 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2521 HloCloneContext* context) const {
2522 CHECK_EQ(new_operands.size(), 2);
2523 return std::make_unique<HloConvolutionInstruction>(
2524 shape, new_operands[0], new_operands[1], feature_group_count_,
2525 batch_group_count_, window(), convolution_dimension_numbers_,
2526 precision_config_);
2527 }
2528
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2529 HloReduceWindowInstruction::HloReduceWindowInstruction(
2530 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2531 const Window& window, HloComputation* reduce_computation)
2532 : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
2533 absl::MakeSpan(&init_value, 1), window,
2534 reduce_computation) {}
2535
HloReduceWindowInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)2536 HloReduceWindowInstruction::HloReduceWindowInstruction(
2537 const Shape& shape, absl::Span<HloInstruction* const> operands,
2538 absl::Span<HloInstruction* const> init_values, const Window& window,
2539 HloComputation* reduce_computation)
2540 : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2541 for (auto* operand : operands) {
2542 AppendOperand(operand);
2543 }
2544 for (auto* init_value : init_values) {
2545 AppendOperand(init_value);
2546 }
2547 AppendComputation(reduce_computation);
2548 }
2549
ToProto() const2550 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2551 HloInstructionProto proto = HloInstruction::ToProto();
2552 *proto.mutable_window() = window_;
2553 return proto;
2554 }
2555
2556 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2557 HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2558 const HloPrintOptions& options) const {
2559 std::vector<std::string> extra;
2560 if (window_.dimensions_size() != 0) {
2561 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2562 }
2563 return extra;
2564 }
2565
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2566 bool HloReduceWindowInstruction::IdenticalSlowPath(
2567 const HloInstruction& other,
2568 const std::function<bool(const HloComputation*, const HloComputation*)>&
2569 eq_computations) const {
2570 const auto& casted_other =
2571 static_cast<const HloReduceWindowInstruction&>(other);
2572 return eq_computations(to_apply(), casted_other.to_apply()) &&
2573 protobuf_util::ProtobufEquals(window(), casted_other.window());
2574 }
2575
2576 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2577 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2578 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2579 HloCloneContext* context) const {
2580 CHECK_EQ(new_operands.size() % 2, 0);
2581 int64_t num_operands = new_operands.size() / 2;
2582 return std::make_unique<HloReduceWindowInstruction>(
2583 shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
2584 absl::MakeSpan(new_operands)
2585 .subspan(num_operands, new_operands.size() / 2),
2586 window(), to_apply());
2587 }
2588
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2589 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2590 const Shape& shape, HloInstruction* operand, HloComputation* select,
2591 const Window& window, HloInstruction* source, HloInstruction* init_value,
2592 HloComputation* scatter)
2593 : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2594 AppendOperand(operand);
2595 AppendOperand(source);
2596 AppendOperand(init_value);
2597 // Select comes before scatter in the vector.
2598 AppendComputation(select);
2599 AppendComputation(scatter);
2600 }
2601
ToProto() const2602 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2603 HloInstructionProto proto = HloInstruction::ToProto();
2604 *proto.mutable_window() = window_;
2605 return proto;
2606 }
2607
2608 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2609 HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2610 const HloPrintOptions& options) const {
2611 std::vector<std::string> extra;
2612 if (window_.dimensions_size() != 0) {
2613 extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2614 }
2615 return extra;
2616 }
2617
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2618 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2619 const HloInstruction& other,
2620 const std::function<bool(const HloComputation*, const HloComputation*)>&
2621 eq_computations) const {
2622 const auto& casted_other =
2623 static_cast<const HloSelectAndScatterInstruction&>(other);
2624 return eq_computations(select(), casted_other.select()) &&
2625 eq_computations(scatter(), casted_other.scatter()) &&
2626 protobuf_util::ProtobufEquals(window(), casted_other.window());
2627 }
2628
2629 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2630 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2631 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2632 HloCloneContext* context) const {
2633 CHECK_EQ(new_operands.size(), 3);
2634 return std::make_unique<HloSelectAndScatterInstruction>(
2635 shape, new_operands[0], select(), window(), new_operands[1],
2636 new_operands[2], scatter());
2637 }
2638
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)2639 HloCustomCallInstruction::HloCustomCallInstruction(
2640 const Shape& shape, absl::Span<HloInstruction* const> operands,
2641 absl::string_view custom_call_target, std::string opaque,
2642 CustomCallApiVersion api_version)
2643 : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands),
2644 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2645 feature_group_count_(1),
2646 batch_group_count_(1),
2647 layout_constrained_(false),
2648 padding_type_(PaddingType::PADDING_INVALID),
2649 custom_call_has_side_effect_(false),
2650 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2651 api_version_(api_version) {
2652 set_raw_backend_config_string(std::move(opaque));
2653 }
2654
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)2655 HloCustomCallInstruction::HloCustomCallInstruction(
2656 const Shape& shape, absl::Span<HloInstruction* const> operands,
2657 HloComputation* to_apply, absl::string_view custom_call_target,
2658 std::string opaque, CustomCallApiVersion api_version)
2659 : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands, to_apply),
2660 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2661 feature_group_count_(1),
2662 batch_group_count_(1),
2663 layout_constrained_(false),
2664 padding_type_(PaddingType::PADDING_INVALID),
2665 custom_call_has_side_effect_(false),
2666 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2667 api_version_(api_version) {
2668 set_raw_backend_config_string(std::move(opaque));
2669 to_apply->SetCustomCallInstruction(this);
2670 }
2671
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)2672 HloCustomCallInstruction::HloCustomCallInstruction(
2673 const Shape& shape, absl::Span<HloInstruction* const> operands,
2674 absl::Span<HloComputation* const> called_computations,
2675 absl::string_view custom_call_target, std::string opaque,
2676 CustomCallApiVersion api_version)
2677 : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands,
2678 called_computations),
2679 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2680 feature_group_count_(1),
2681 batch_group_count_(1),
2682 layout_constrained_(false),
2683 padding_type_(PaddingType::PADDING_INVALID),
2684 custom_call_has_side_effect_(false),
2685 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2686 api_version_(api_version) {
2687 set_raw_backend_config_string(std::move(opaque));
2688 for (auto comp : called_computations) {
2689 comp->SetCustomCallInstruction(this);
2690 }
2691 }
2692
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,std::string opaque,absl::Span<const Shape> operand_shapes_with_layout,CustomCallApiVersion api_version)2693 HloCustomCallInstruction::HloCustomCallInstruction(
2694 const Shape& shape, absl::Span<HloInstruction* const> operands,
2695 absl::string_view custom_call_target, std::string opaque,
2696 absl::Span<const Shape> operand_shapes_with_layout,
2697 CustomCallApiVersion api_version)
2698 : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands),
2699 custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2700 feature_group_count_(1),
2701 batch_group_count_(1),
2702 layout_constrained_(true),
2703 padding_type_(PaddingType::PADDING_INVALID),
2704 operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2705 operand_shapes_with_layout.end()),
2706 custom_call_has_side_effect_(false),
2707 custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2708 api_version_(api_version) {
2709 set_raw_backend_config_string(std::move(opaque));
2710 }
2711
ToProto() const2712 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2713 HloInstructionProto proto = HloInstruction::ToProto();
2714 if (window_ != nullptr) {
2715 *proto.mutable_window() = *window_;
2716 }
2717 if (convolution_dimension_numbers_ != nullptr) {
2718 *proto.mutable_convolution_dimension_numbers() =
2719 *convolution_dimension_numbers_;
2720 }
2721 proto.set_custom_call_target(custom_call_target_);
2722 proto.set_feature_group_count(feature_group_count_);
2723 proto.set_batch_group_count(batch_group_count_);
2724 *proto.mutable_precision_config() = precision_config_;
2725 proto.set_padding_type(padding_type_);
2726 if (layout_constrained()) {
2727 proto.set_constrain_layout(true);
2728 for (const Shape& shape : operand_shapes_with_layout_) {
2729 *proto.add_operand_shapes_with_layout() = shape.ToProto();
2730 }
2731 }
2732 proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2733 if (literal_.has_value()) {
2734 *proto.mutable_literal() = literal_->ToProto();
2735 }
2736 for (const auto& pair : output_to_operand_aliasing_) {
2737 auto aliasing = proto.add_custom_call_output_operand_aliasing();
2738 aliasing->set_operand_index(pair.second.first);
2739 for (int64_t index : pair.first) {
2740 aliasing->add_output_shape_index(index);
2741 }
2742 for (int64_t index : pair.second.second) {
2743 aliasing->add_operand_shape_index(index);
2744 }
2745 }
2746 proto.set_custom_call_schedule(custom_call_schedule_);
2747 proto.set_custom_call_api_version(api_version_);
2748 return proto;
2749 }
2750
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2751 std::vector<std::string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2752 const HloPrintOptions& options) const {
2753 std::vector<std::string> extra;
2754 if (window_ != nullptr) {
2755 extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2756 }
2757 if (convolution_dimension_numbers_ != nullptr) {
2758 extra.push_back(StrCat(
2759 "dim_labels=",
2760 ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2761 }
2762 if (feature_group_count_ != 1) {
2763 extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2764 }
2765 if (batch_group_count_ != 1) {
2766 extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2767 }
2768 std::string precision_config_string =
2769 PrecisionConfigToString(precision_config_);
2770 if (!precision_config_string.empty()) {
2771 extra.push_back(precision_config_string);
2772 }
2773 if (padding_type_ != PaddingType::PADDING_INVALID) {
2774 extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
2775 }
2776 // By contract, we print the custom call target even if
2777 // options.print_subcomputation_mode() == kOff, because the call target is not
2778 // an HloComputation.
2779 extra.push_back(
2780 StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2781
2782 if (layout_constrained()) {
2783 std::vector<std::string> shape_strings;
2784 shape_strings.reserve(operand_shapes_with_layout_.size());
2785 for (const Shape& shape : operand_shapes_with_layout_) {
2786 shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2787 }
2788 extra.push_back(StrCat("operand_layout_constraints={",
2789 StrJoin(shape_strings, ", "), "}"));
2790 }
2791 if (custom_call_has_side_effect_) {
2792 extra.push_back("custom_call_has_side_effect=true");
2793 }
2794 if (literal_.has_value()) {
2795 extra.push_back(StrCat("literal=", literal_->ToStringWithLayoutOneline()));
2796 }
2797 if (!output_to_operand_aliasing_.empty()) {
2798 std::vector<std::string> pair_strings;
2799 pair_strings.reserve(output_to_operand_aliasing_.size());
2800 for (const auto& pair : output_to_operand_aliasing_) {
2801 pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
2802 pair.second.first, ", ",
2803 pair.second.second.ToString(), ")"));
2804 }
2805 extra.push_back(StrCat("output_to_operand_aliasing={",
2806 StrJoin(pair_strings, ", "), "}"));
2807 }
2808 if (custom_call_schedule_ != CustomCallSchedule::SCHEDULE_NONE) {
2809 extra.push_back(
2810 StrCat("schedule=", CustomCallSchedule_Name(custom_call_schedule_)));
2811 }
2812 if (api_version_ != CustomCallApiVersion::API_VERSION_ORIGINAL) {
2813 extra.push_back(
2814 StrCat("api_version=", CustomCallApiVersion_Name(api_version_)));
2815 }
2816 return extra;
2817 }
2818
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2819 bool HloCustomCallInstruction::IdenticalSlowPath(
2820 const HloInstruction& other,
2821 const std::function<bool(const HloComputation*, const HloComputation*)>&
2822 eq_computations) const {
2823 const auto& casted_other =
2824 static_cast<const HloCustomCallInstruction&>(other);
2825 if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2826 (window_ != nullptr &&
2827 !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2828 return false;
2829 }
2830 if ((convolution_dimension_numbers_ == nullptr) !=
2831 (casted_other.convolution_dimension_numbers_ == nullptr) ||
2832 (convolution_dimension_numbers_ != nullptr &&
2833 !protobuf_util::ProtobufEquals(
2834 convolution_dimension_numbers(),
2835 casted_other.convolution_dimension_numbers()))) {
2836 return false;
2837 }
2838 if (feature_group_count_ != casted_other.feature_group_count_) {
2839 return false;
2840 }
2841 if (batch_group_count_ != casted_other.batch_group_count_) {
2842 return false;
2843 }
2844
2845 if (padding_type_ != casted_other.padding_type()) {
2846 return false;
2847 }
2848
2849 if (layout_constrained() != casted_other.layout_constrained()) {
2850 return false;
2851 }
2852 if (layout_constrained()) {
2853 for (int64_t i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2854 if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2855 casted_other.operand_shapes_with_layout_[i])) {
2856 return false;
2857 }
2858 }
2859 }
2860 if (custom_call_has_side_effect_ !=
2861 casted_other.custom_call_has_side_effect()) {
2862 return false;
2863 }
2864 if (output_to_operand_aliasing_ !=
2865 casted_other.output_to_operand_aliasing()) {
2866 return false;
2867 }
2868 if (!protobuf_util::ProtobufEquals(precision_config(),
2869 casted_other.precision_config())) {
2870 return false;
2871 }
2872
2873 if (called_computations().size() != other.called_computations().size()) {
2874 return false;
2875 }
2876 for (int64_t i = 0; i < called_computations().size(); ++i) {
2877 if (!eq_computations(called_computations()[i],
2878 other.called_computations()[i])) {
2879 return false;
2880 }
2881 }
2882 if (custom_call_schedule_ != casted_other.custom_call_schedule()) {
2883 return false;
2884 }
2885 if (HasLiteral() != casted_other.HasLiteral()) {
2886 return false;
2887 }
2888 if (HasLiteral() && literal() != casted_other.literal()) {
2889 return false;
2890 }
2891 if (api_version_ != casted_other.api_version_) {
2892 return false;
2893 }
2894 // Note: backend_config comparison is done in Identical, which is the
2895 // intended/exposed way to compare computations, and so not repeated here.
2896 return custom_call_target_ == casted_other.custom_call_target_;
2897 }
2898
2899 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2900 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2901 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2902 HloCloneContext* context) const {
2903 absl::InlinedVector<HloComputation*, 1> new_called_computations =
2904 GetOrCloneCalledComputations(context);
2905
2906 auto cloned = std::make_unique<HloCustomCallInstruction>(
2907 shape, new_operands, new_called_computations, custom_call_target(),
2908 opaque(), api_version_);
2909 if (layout_constrained()) {
2910 cloned->layout_constrained_ = true;
2911 cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2912 }
2913 if (window_ != nullptr) {
2914 cloned->set_window(*window_);
2915 }
2916 if (convolution_dimension_numbers_ != nullptr) {
2917 cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2918 }
2919 if (HasLiteral()) {
2920 cloned->set_literal(literal().Clone());
2921 }
2922 cloned->set_feature_group_count(feature_group_count_);
2923 cloned->set_batch_group_count(batch_group_count_);
2924 cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2925 cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
2926 cloned->set_padding_type(padding_type_);
2927 *cloned->mutable_precision_config() = precision_config();
2928 cloned->set_custom_call_schedule(custom_call_schedule_);
2929 return std::move(cloned);
2930 }
2931
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2932 HloPadInstruction::HloPadInstruction(const Shape& shape,
2933 HloInstruction* operand,
2934 HloInstruction* padding_value,
2935 const PaddingConfig& padding_config)
2936 : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2937 AppendOperand(operand);
2938 AppendOperand(padding_value);
2939 }
2940
ToProto() const2941 HloInstructionProto HloPadInstruction::ToProto() const {
2942 HloInstructionProto proto = HloInstruction::ToProto();
2943 *proto.mutable_padding_config() = padding_config_;
2944 return proto;
2945 }
2946
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2947 std::vector<std::string> HloPadInstruction::ExtraAttributesToStringImpl(
2948 const HloPrintOptions& options) const {
2949 return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2950 }
2951
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2952 bool HloPadInstruction::IdenticalSlowPath(
2953 const HloInstruction& other,
2954 const std::function<bool(const HloComputation*, const HloComputation*)>&
2955 eq_computations) const {
2956 const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2957 return protobuf_util::ProtobufEquals(padding_config(),
2958 casted_other.padding_config());
2959 }
2960
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2961 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2962 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2963 HloCloneContext* context) const {
2964 CHECK_EQ(new_operands.size(), 2);
2965 return std::make_unique<HloPadInstruction>(shape, new_operands[0],
2966 new_operands[1], padding_config_);
2967 }
2968
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64_t> slice_sizes)2969 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2970 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2971 absl::Span<const int64_t> slice_sizes)
2972 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2973 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2974 AppendOperand(operand);
2975 AppendOperand(start_indices);
2976 }
2977
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64_t> slice_sizes)2978 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2979 const Shape& shape, HloInstruction* operand,
2980 absl::Span<HloInstruction* const> start_indices,
2981 absl::Span<const int64_t> slice_sizes)
2982 : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2983 dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2984 AppendOperand(operand);
2985 for (HloInstruction* index : start_indices) {
2986 AppendOperand(index);
2987 }
2988 }
2989
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2990 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2991 const Shape& shape, HloInstruction* operand, HloInstruction* update,
2992 HloInstruction* start_indices)
2993 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2994 AppendOperand(operand);
2995 AppendOperand(update);
2996 AppendOperand(start_indices);
2997 }
2998
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2999 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
3000 const Shape& shape, HloInstruction* operand, HloInstruction* update,
3001 absl::Span<HloInstruction* const> start_indices)
3002 : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
3003 AppendOperand(operand);
3004 AppendOperand(update);
3005 for (HloInstruction* index : start_indices) {
3006 AppendOperand(index);
3007 }
3008 }
3009
ToProto() const3010 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
3011 HloInstructionProto proto = HloInstruction::ToProto();
3012 for (int64_t slice_size : dynamic_slice_sizes_) {
3013 proto.add_dynamic_slice_sizes(slice_size);
3014 }
3015 return proto;
3016 }
3017
3018 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3019 HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
3020 const HloPrintOptions& options) const {
3021 return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
3022 "}")};
3023 }
3024
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3025 bool HloDynamicSliceInstruction::IdenticalSlowPath(
3026 const HloInstruction& other,
3027 const std::function<bool(const HloComputation*, const HloComputation*)>&
3028 eq_computations) const {
3029 const auto& casted_other = static_cast<const HloMapInstruction&>(other);
3030 return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
3031 }
3032
3033 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3034 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
3035 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3036 HloCloneContext* context) const {
3037 if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
3038 // TODO(b/118437727): Old form, remove this path.
3039 return std::make_unique<HloDynamicSliceInstruction>(
3040 shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
3041 } else {
3042 return std::make_unique<HloDynamicSliceInstruction>(
3043 shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
3044 }
3045 }
3046
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64_t> slice_sizes,bool indices_are_sorted)3047 HloGatherInstruction::HloGatherInstruction(
3048 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
3049 const GatherDimensionNumbers& gather_dim_numbers,
3050 absl::Span<const int64_t> slice_sizes, bool indices_are_sorted)
3051 : HloInstruction(HloOpcode::kGather, shape),
3052 indices_are_sorted_(indices_are_sorted) {
3053 AppendOperand(operand);
3054 AppendOperand(start_indices);
3055 gather_dimension_numbers_ =
3056 std::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
3057 absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
3058 }
3059
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)3060 /*static*/ std::string HloGatherInstruction::GatherDimensionNumbersToString(
3061 const GatherDimensionNumbers& gather_dimension_numbers) {
3062 std::string offset_dims =
3063 StrCat("offset_dims={",
3064 StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
3065 std::string collapsed_slice_dims = StrCat(
3066 "collapsed_slice_dims={",
3067 StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
3068 std::string start_index_map =
3069 StrCat("start_index_map={",
3070 StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
3071 std::string index_vector_dim =
3072 StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
3073
3074 return StrJoin<std::initializer_list<std::string>>(
3075 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
3076 ", ");
3077 }
3078
MakeGatherDimNumbers(absl::Span<const int64_t> offset_dims,absl::Span<const int64_t> collapsed_slice_dims,absl::Span<const int64_t> start_index_map,int64_t index_vector_dim)3079 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
3080 absl::Span<const int64_t> offset_dims,
3081 absl::Span<const int64_t> collapsed_slice_dims,
3082 absl::Span<const int64_t> start_index_map, int64_t index_vector_dim) {
3083 GatherDimensionNumbers gather_dim_numbers;
3084 for (int64_t output_window_dim : offset_dims) {
3085 gather_dim_numbers.add_offset_dims(output_window_dim);
3086 }
3087 for (int64_t elided_window_dim : collapsed_slice_dims) {
3088 gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
3089 }
3090 for (int64_t gather_dim_to_input_dim : start_index_map) {
3091 gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
3092 }
3093
3094 gather_dim_numbers.set_index_vector_dim(index_vector_dim);
3095 return gather_dim_numbers;
3096 }
3097
ToProto() const3098 HloInstructionProto HloGatherInstruction::ToProto() const {
3099 HloInstructionProto proto = HloInstruction::ToProto();
3100 *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
3101 for (int64_t bound : gather_slice_sizes()) {
3102 proto.add_gather_slice_sizes(bound);
3103 }
3104 proto.set_indices_are_sorted(indices_are_sorted());
3105 return proto;
3106 }
3107
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3108 std::vector<std::string> HloGatherInstruction::ExtraAttributesToStringImpl(
3109 const HloPrintOptions& options) const {
3110 std::vector<std::string> attrs{
3111 GatherDimensionNumbersToString(gather_dimension_numbers()),
3112 StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
3113 if (indices_are_sorted()) {
3114 attrs.push_back("indices_are_sorted=true");
3115 }
3116 return attrs;
3117 }
3118
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3119 bool HloGatherInstruction::IdenticalSlowPath(
3120 const HloInstruction& other,
3121 const std::function<bool(const HloComputation*, const HloComputation*)>&
3122 eq_computations) const {
3123 const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
3124 return protobuf_util::ProtobufEquals(
3125 gather_dimension_numbers(),
3126 casted_other.gather_dimension_numbers()) &&
3127 gather_slice_sizes() == casted_other.gather_slice_sizes() &&
3128 indices_are_sorted() == casted_other.indices_are_sorted();
3129 }
3130
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3131 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
3132 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3133 HloCloneContext* context) const {
3134 CHECK_EQ(new_operands.size(), 2);
3135 return std::make_unique<HloGatherInstruction>(
3136 shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
3137 gather_slice_sizes(), indices_are_sorted());
3138 }
3139
HloScatterInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)3140 HloScatterInstruction::HloScatterInstruction(
3141 const Shape& shape, absl::Span<HloInstruction* const> args,
3142 HloComputation* update_computation,
3143 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
3144 bool unique_indices)
3145 : HloInstruction(HloOpcode::kScatter, shape),
3146 indices_are_sorted_(indices_are_sorted),
3147 unique_indices_(unique_indices) {
3148 mutable_operands().reserve(args.size());
3149 for (HloInstruction* arg : args) {
3150 AppendOperand(arg);
3151 }
3152 AppendComputation(update_computation);
3153 scatter_dimension_numbers_ =
3154 std::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
3155 }
3156
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)3157 /*static*/ std::string HloScatterInstruction::ScatterDimensionNumbersToString(
3158 const ScatterDimensionNumbers& scatter_dimension_numbers) {
3159 std::string update_window_dims =
3160 StrCat("update_window_dims={",
3161 StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
3162 std::string inserted_window_dims = StrCat(
3163 "inserted_window_dims={",
3164 StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
3165 std::string scatter_dims_to_operand_dims = StrCat(
3166 "scatter_dims_to_operand_dims={",
3167 StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
3168 "}");
3169 std::string index_vector_dim =
3170 StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
3171
3172 return StrJoin<std::initializer_list<std::string>>(
3173 {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
3174 index_vector_dim},
3175 ", ");
3176 }
3177
3178 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64_t> update_window_dims,absl::Span<const int64_t> inserted_window_dims,absl::Span<const int64_t> scatter_dims_to_operand_dims,int64_t index_vector_dim)3179 HloScatterInstruction::MakeScatterDimNumbers(
3180 absl::Span<const int64_t> update_window_dims,
3181 absl::Span<const int64_t> inserted_window_dims,
3182 absl::Span<const int64_t> scatter_dims_to_operand_dims,
3183 int64_t index_vector_dim) {
3184 ScatterDimensionNumbers scatter_dim_numbers;
3185 for (int64_t update_window_dim : update_window_dims) {
3186 scatter_dim_numbers.add_update_window_dims(update_window_dim);
3187 }
3188 for (int64_t inserted_window_dim : inserted_window_dims) {
3189 scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
3190 }
3191 for (int64_t scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
3192 scatter_dim_numbers.add_scatter_dims_to_operand_dims(
3193 scatter_dim_to_operand_dim);
3194 }
3195 scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
3196 return scatter_dim_numbers;
3197 }
3198
ToProto() const3199 HloInstructionProto HloScatterInstruction::ToProto() const {
3200 HloInstructionProto proto = HloInstruction::ToProto();
3201 *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
3202 proto.set_indices_are_sorted(indices_are_sorted());
3203 proto.set_unique_indices(unique_indices());
3204 return proto;
3205 }
3206
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3207 std::vector<std::string> HloScatterInstruction::ExtraAttributesToStringImpl(
3208 const HloPrintOptions& options) const {
3209 std::vector<std::string> attrs{
3210 ScatterDimensionNumbersToString(scatter_dimension_numbers())};
3211 if (indices_are_sorted()) {
3212 attrs.push_back("indices_are_sorted=true");
3213 }
3214 if (unique_indices()) {
3215 attrs.push_back("unique_indices=true");
3216 }
3217 return attrs;
3218 }
3219
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3220 bool HloScatterInstruction::IdenticalSlowPath(
3221 const HloInstruction& other,
3222 const std::function<bool(const HloComputation*, const HloComputation*)>&
3223 eq_computations) const {
3224 const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
3225 return protobuf_util::ProtobufEquals(
3226 scatter_dimension_numbers(),
3227 casted_other.scatter_dimension_numbers()) &&
3228 eq_computations(to_apply(), casted_other.to_apply()) &&
3229 indices_are_sorted() == casted_other.indices_are_sorted() &&
3230 unique_indices() == casted_other.unique_indices();
3231 }
3232
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3233 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
3234 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3235 HloCloneContext* context) const {
3236 return std::make_unique<HloScatterInstruction>(
3237 shape, new_operands, to_apply(), scatter_dimension_numbers(),
3238 indices_are_sorted(), unique_indices());
3239 }
3240
HloIotaInstruction(const Shape & shape,int64_t iota_dimension)3241 HloIotaInstruction::HloIotaInstruction(const Shape& shape,
3242 int64_t iota_dimension)
3243 : HloInstruction(HloOpcode::kIota, shape),
3244 iota_dimension_(iota_dimension) {}
3245
ToProto() const3246 HloInstructionProto HloIotaInstruction::ToProto() const {
3247 HloInstructionProto proto = HloInstruction::ToProto();
3248 proto.add_dimensions(iota_dimension());
3249 return proto;
3250 }
3251
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3252 std::vector<std::string> HloIotaInstruction::ExtraAttributesToStringImpl(
3253 const HloPrintOptions& options) const {
3254 return {StrCat("iota_dimension=", iota_dimension())};
3255 }
3256
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3257 bool HloIotaInstruction::IdenticalSlowPath(
3258 const HloInstruction& other,
3259 const std::function<bool(const HloComputation*, const HloComputation*)>&
3260 eq_computations) const {
3261 const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
3262 return iota_dimension() == casted_other.iota_dimension();
3263 }
3264
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3265 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
3266 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3267 HloCloneContext* context) const {
3268 return std::make_unique<HloIotaInstruction>(shape, iota_dimension());
3269 }
3270
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)3271 HloDotInstruction::HloDotInstruction(
3272 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
3273 const DotDimensionNumbers& dimension_numbers,
3274 const PrecisionConfig& precision_config)
3275 : HloInstruction(HloOpcode::kDot, shape),
3276 dot_dimension_numbers_(dimension_numbers),
3277 precision_config_(precision_config) {
3278 AppendOperand(lhs);
3279 AppendOperand(rhs);
3280 }
3281
ToProto() const3282 HloInstructionProto HloDotInstruction::ToProto() const {
3283 HloInstructionProto proto = HloInstruction::ToProto();
3284 *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
3285 *proto.mutable_precision_config() = precision_config_;
3286 return proto;
3287 }
3288
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3289 std::vector<std::string> HloDotInstruction::ExtraAttributesToStringImpl(
3290 const HloPrintOptions& options) const {
3291 std::vector<std::string> extra = {
3292 DotDimensionNumbersToString(dot_dimension_numbers_)};
3293
3294 std::string precision_config_string =
3295 PrecisionConfigToString(precision_config_);
3296 if (!precision_config_string.empty()) {
3297 extra.push_back(precision_config_string);
3298 }
3299 return extra;
3300 }
3301
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3302 bool HloDotInstruction::IdenticalSlowPath(
3303 const HloInstruction& other,
3304 const std::function<bool(const HloComputation*, const HloComputation*)>&
3305 eq_computations) const {
3306 const auto& casted_other = static_cast<const HloDotInstruction&>(other);
3307 return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
3308 casted_other.dot_dimension_numbers()) &&
3309 protobuf_util::ProtobufEquals(precision_config(),
3310 casted_other.precision_config());
3311 }
3312
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3313 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
3314 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3315 HloCloneContext* context) const {
3316 CHECK_EQ(new_operands.size(), 2);
3317 return std::make_unique<HloDotInstruction>(
3318 shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
3319 precision_config_);
3320 }
3321
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)3322 HloDomainInstruction::HloDomainInstruction(
3323 const Shape& shape, HloInstruction* operand,
3324 std::unique_ptr<DomainMetadata> operand_side_metadata,
3325 std::unique_ptr<DomainMetadata> user_side_metadata)
3326 : HloInstruction(HloOpcode::kDomain, shape),
3327 operand_side_metadata_(std::move(operand_side_metadata)),
3328 user_side_metadata_(std::move(user_side_metadata)) {
3329 AppendOperand(operand);
3330 }
3331
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3332 std::vector<std::string> HloDomainInstruction::ExtraAttributesToStringImpl(
3333 const HloPrintOptions& options) const {
3334 if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
3335 return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
3336 "\", entry=", user_side_metadata_->ToString(),
3337 ", exit=", operand_side_metadata_->ToString(), "}")};
3338 }
3339 return {};
3340 }
3341
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3342 bool HloDomainInstruction::IdenticalSlowPath(
3343 const HloInstruction& other,
3344 const std::function<bool(const HloComputation*, const HloComputation*)>&
3345 eq_computations) const {
3346 const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
3347 return operand_side_metadata().Matches(
3348 casted_other.operand_side_metadata()) &&
3349 user_side_metadata().Matches(casted_other.user_side_metadata());
3350 }
3351
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3352 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
3353 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3354 HloCloneContext* context) const {
3355 CHECK_EQ(new_operands.size(), 1);
3356 return std::make_unique<HloDomainInstruction>(shape, new_operands[0],
3357 operand_side_metadata_->Clone(),
3358 user_side_metadata_->Clone());
3359 }
3360
ToProto() const3361 HloInstructionProto HloDomainInstruction::ToProto() const {
3362 HloInstructionProto proto = HloInstruction::ToProto();
3363 auto operand_side_sharding =
3364 dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
3365 if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
3366 *proto.mutable_domain_entry_sharding() =
3367 operand_side_sharding->sharding()->ToProto();
3368 }
3369
3370 auto user_side_sharding =
3371 dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
3372 if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
3373 *proto.mutable_domain_exit_sharding() =
3374 user_side_sharding->sharding()->ToProto();
3375 }
3376
3377 return proto;
3378 }
3379
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64_t dimension)3380 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
3381 const Shape& shape, HloInstruction* operand, int64_t dimension)
3382 : HloInstruction(HloOpcode::kGetDimensionSize, shape),
3383 dimension_(dimension) {
3384 AppendOperand(operand);
3385 }
3386
ToProto() const3387 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
3388 HloInstructionProto proto = HloInstruction::ToProto();
3389 proto.add_dimensions(dimension());
3390 return proto;
3391 }
3392
3393 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3394 HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3395 const HloPrintOptions& /*options*/) const {
3396 return {StrCat("dimensions={", dimension(), "}")};
3397 }
3398
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3399 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
3400 const HloInstruction& other,
3401 const std::function<bool(const HloComputation*, const HloComputation*)>&
3402 /*eq_computations*/) const {
3403 const auto& casted_other =
3404 static_cast<const HloGetDimensionSizeInstruction&>(other);
3405 return dimension() == casted_other.dimension();
3406 }
3407
3408 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3409 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3410 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3411 HloCloneContext* /*context*/) const {
3412 if (new_operands.size() != 1) {
3413 LOG(FATAL) << "expects 1 operand";
3414 }
3415 return std::make_unique<HloGetDimensionSizeInstruction>(
3416 shape, new_operands[0], dimension());
3417 }
3418
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)3419 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
3420 const Shape& shape, HloInstruction* operand, HloInstruction* val,
3421 int64_t dimension)
3422 : HloInstruction(HloOpcode::kSetDimensionSize, shape),
3423 dimension_(dimension) {
3424 AppendOperand(operand);
3425 AppendOperand(val);
3426 }
3427
3428 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3429 HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3430 const HloPrintOptions& /*options*/) const {
3431 return {StrCat("dimensions={", dimension(), "}")};
3432 }
3433
ToProto() const3434 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
3435 HloInstructionProto proto = HloInstruction::ToProto();
3436 proto.add_dimensions(dimension());
3437 return proto;
3438 }
3439
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3440 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
3441 const HloInstruction& other,
3442 const std::function<bool(const HloComputation*, const HloComputation*)>&
3443 /*eq_computations*/) const {
3444 const auto& casted_other =
3445 static_cast<const HloSetDimensionSizeInstruction&>(other);
3446 return dimension() == casted_other.dimension();
3447 }
3448
3449 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3450 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3451 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3452 HloCloneContext* /*context*/) const {
3453 if (new_operands.size() != 2) {
3454 LOG(FATAL) << "expects 2 operand";
3455 }
3456 return std::make_unique<HloSetDimensionSizeInstruction>(
3457 shape, new_operands[0], new_operands[1], dimension());
3458 }
3459
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64_t delta)3460 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
3461 const Shape& shape, int64_t delta)
3462 : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
3463
ToProto() const3464 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
3465 HloInstructionProto proto = HloInstruction::ToProto();
3466 proto.set_delta(delta_);
3467 return proto;
3468 }
3469
3470 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3471 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
3472 const HloPrintOptions& /*options*/) const {
3473 return {StrCat("delta=", delta())};
3474 }
3475
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3476 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
3477 const HloInstruction& other,
3478 const std::function<bool(const HloComputation*, const HloComputation*)>&
3479 /*eq_computations*/) const {
3480 const auto& casted_other =
3481 static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
3482 return delta() == casted_other.delta();
3483 }
3484
3485 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3486 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
3487 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3488 HloCloneContext* /*context*/) const {
3489 if (!new_operands.empty()) {
3490 LOG(FATAL) << "expects 0 operand";
3491 }
3492 return std::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
3493 }
3494
HloRngBitGeneratorInstruction(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)3495 HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
3496 const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
3497 : HloInstruction(HloOpcode::kRngBitGenerator, shape),
3498 algorithm_(algorithm) {
3499 AppendOperand(state);
3500 }
3501
ToProto() const3502 HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
3503 HloInstructionProto proto = HloInstruction::ToProto();
3504 proto.set_rng_algorithm(algorithm_);
3505 return proto;
3506 }
3507
3508 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3509 HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
3510 const HloPrintOptions& options) const {
3511 return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
3512 }
3513
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3514 bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
3515 const HloInstruction& other,
3516 const std::function<bool(const HloComputation*, const HloComputation*)>&
3517 eq_computations) const {
3518 const auto& casted_other =
3519 static_cast<const HloRngBitGeneratorInstruction&>(other);
3520 return algorithm() == casted_other.algorithm();
3521 }
3522
3523 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3524 HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
3525 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3526 HloCloneContext* /*context*/) const {
3527 CHECK_EQ(new_operands.size(), 1);
3528 return std::make_unique<HloRngBitGeneratorInstruction>(shape, new_operands[0],
3529 algorithm());
3530 }
3531
3532 } // namespace xla
3533