• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/container/inlined_vector.h"
31 #include "absl/strings/escaping.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/strings/str_split.h"
35 #include "absl/strings/string_view.h"
36 #include "tensorflow/compiler/xla/literal_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/protobuf_util.h"
39 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
46 #include "tensorflow/compiler/xla/window_util.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48 #include "tensorflow/core/platform/protobuf.h"
49 
50 namespace xla {
51 namespace {
52 
53 using absl::CEscape;
54 using absl::StrAppend;
55 using absl::StrCat;
56 using absl::StrJoin;
57 
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)58 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
59                                        const HloInstruction* operand) {
60   const auto operand_indices = instruction->OperandIndices(operand);
61   return absl::c_all_of(operand_indices, [instruction](int64_t operand_index) {
62     return instruction->IsElementwiseOnOperand(operand_index);
63   });
64 }
65 
PrecisionConfigToString(const PrecisionConfig & precision_config)66 std::string PrecisionConfigToString(const PrecisionConfig& precision_config) {
67   if (absl::c_all_of(
68           precision_config.operand_precision(), [](int32_t precision) {
69             return static_cast<PrecisionConfig::Precision>(precision) ==
70                    PrecisionConfig::DEFAULT;
71           })) {
72     return "";
73   }
74 
75   return StrCat(
76       "operand_precision={",
77       StrJoin(
78           precision_config.operand_precision(), ",",
79           [](std::string* out, int32_t precision) {
80             CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
81             StrAppend(out,
82                       PrecisionToString(
83                           static_cast<PrecisionConfig::Precision>(precision)));
84           }),
85       "}");
86 }
87 
SetThreadName(HloComputation * called_computation,absl::string_view execution_thread,bool skip_async_execution_thread_overwrite)88 void SetThreadName(HloComputation* called_computation,
89                    absl::string_view execution_thread,
90                    bool skip_async_execution_thread_overwrite) {
91   called_computation->SetExecutionThread(execution_thread);
92   for (HloInstruction* instr : called_computation->instructions()) {
93     if (instr->IsAsynchronous()) {
94       if (!skip_async_execution_thread_overwrite) {
95         // Set async instruction thread name and also recursively set async
96         // computations.
97         instr->set_async_execution_thread(execution_thread);
98       }
99       continue;
100     }
101     for (HloComputation* nested_called_computation :
102          instr->called_computations()) {
103       SetThreadName(nested_called_computation, execution_thread,
104                     skip_async_execution_thread_overwrite);
105     }
106   }
107 }
108 
109 }  // namespace
110 
HloBatchNormInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloInstruction * scale,float epsilon,int64_t feature_index)111 HloBatchNormInstruction::HloBatchNormInstruction(
112     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
113     HloInstruction* scale, float epsilon, int64_t feature_index)
114     : HloInstruction(opcode, shape),
115       epsilon_(epsilon),
116       feature_index_(feature_index) {
117   AppendOperand(operand);
118   AppendOperand(scale);
119 }
120 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const121 bool HloBatchNormInstruction::IdenticalSlowPath(
122     const HloInstruction& other,
123     const std::function<bool(const HloComputation*, const HloComputation*)>&
124         eq_computations) const {
125   const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
126   return feature_index() == casted_other.feature_index() &&
127          epsilon() == casted_other.epsilon();
128 }
129 
ToProto() const130 HloInstructionProto HloBatchNormInstruction::ToProto() const {
131   HloInstructionProto proto = HloInstruction::ToProto();
132   proto.set_epsilon(epsilon_);
133   proto.set_feature_index(feature_index_);
134   return proto;
135 }
136 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const137 std::vector<std::string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
138     const HloPrintOptions& options) const {
139   return {StrCat("epsilon=", epsilon()),
140           StrCat("feature_index=", feature_index())};
141 }
142 
HloBatchNormTrainingInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)143 HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
144     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
145     HloInstruction* offset, float epsilon, int64_t feature_index)
146     : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
147                               scale, epsilon, feature_index) {
148   AppendOperand(offset);
149 }
150 
151 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const152 HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
153     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
154     HloCloneContext* context) const {
155   CHECK_EQ(new_operands.size(), 3);
156   return std::make_unique<HloBatchNormTrainingInstruction>(
157       shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
158       feature_index());
159 }
160 
HloBatchNormInferenceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)161 HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
162     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
163     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
164     float epsilon, int64_t feature_index)
165     : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
166                               scale, epsilon, feature_index) {
167   AppendOperand(offset);
168   AppendOperand(mean);
169   AppendOperand(variance);
170 }
171 
172 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const173 HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
174     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
175     HloCloneContext* context) const {
176   CHECK_EQ(new_operands.size(), 5);
177   return std::make_unique<HloBatchNormInferenceInstruction>(
178       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
179       new_operands[4], epsilon(), feature_index());
180 }
181 
HloBatchNormGradInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)182 HloBatchNormGradInstruction::HloBatchNormGradInstruction(
183     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
184     HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
185     float epsilon, int64_t feature_index)
186     : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
187                               epsilon, feature_index) {
188   AppendOperand(mean);
189   AppendOperand(variance);
190   AppendOperand(grad_output);
191 }
192 
193 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const194 HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
195     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
196     HloCloneContext* context) const {
197   CHECK_EQ(new_operands.size(), 5);
198   return std::make_unique<HloBatchNormGradInstruction>(
199       shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
200       new_operands[4], epsilon(), feature_index());
201 }
202 
HloFftInstruction(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64_t> fft_length)203 HloFftInstruction::HloFftInstruction(const Shape& shape,
204                                      HloInstruction* operand, FftType fft_type,
205                                      absl::Span<const int64_t> fft_length)
206     : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
207   fft_length_.assign(fft_length.begin(), fft_length.end());
208   AppendOperand(operand);
209 }
210 
ToProto() const211 HloInstructionProto HloFftInstruction::ToProto() const {
212   HloInstructionProto proto = HloInstruction::ToProto();
213   proto.set_fft_type(fft_type_);
214   for (int64_t fft_len : fft_length_) {
215     proto.add_fft_length(fft_len);
216   }
217   return proto;
218 }
219 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const220 std::vector<std::string> HloFftInstruction::ExtraAttributesToStringImpl(
221     const HloPrintOptions& options) const {
222   return {StrCat("fft_type=", FftType_Name(fft_type())),
223           StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
224 }
225 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const226 bool HloFftInstruction::IdenticalSlowPath(
227     const HloInstruction& other,
228     const std::function<bool(const HloComputation*, const HloComputation*)>&
229         eq_computations) const {
230   const auto& casted_other = static_cast<const HloFftInstruction&>(other);
231   return fft_type() == casted_other.fft_type() &&
232          fft_length() == casted_other.fft_length();
233 }
234 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const235 std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
236     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
237     HloCloneContext* context) const {
238   CHECK_EQ(new_operands.size(), 1);
239   return std::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
240                                              fft_length_);
241 }
242 
HloAsyncInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)243 HloAsyncInstruction::HloAsyncInstruction(
244     HloOpcode opcode, const Shape& shape,
245     absl::Span<HloInstruction* const> operands,
246     HloComputation* async_computation, std::optional<int64_t> async_group_id,
247     absl::string_view async_execution_thread)
248     : HloInstruction(opcode, shape),
249       async_group_id_(async_group_id),
250       async_execution_thread_(async_execution_thread) {
251   CHECK(opcode == HloOpcode::kAsyncStart || operands.size() == 1);
252   for (auto operand : operands) {
253     AppendOperand(operand);
254   }
255   AppendComputation(async_computation);
256   CHECK(!async_computation->IsCustomCallComputation());
257   CHECK(!async_computation->IsFusionComputation());
258   async_computation->AddAsyncInstruction(this);
259   set_async_execution_thread(async_execution_thread);
260 }
261 
HloAsyncInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)262 HloAsyncInstruction::HloAsyncInstruction(
263     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
264     HloComputation* async_computation, std::optional<int64_t> async_group_id,
265     absl::string_view async_execution_thread)
266     : HloInstruction(opcode, shape),
267       async_group_id_(async_group_id),
268       async_execution_thread_(async_execution_thread) {
269   AppendOperand(operand);
270   AppendComputation(async_computation);
271   CHECK(!async_computation->IsCustomCallComputation());
272   CHECK(!async_computation->IsFusionComputation());
273   async_computation->AddAsyncInstruction(this);
274   set_async_execution_thread(async_execution_thread);
275 }
276 
~HloAsyncInstruction()277 HloAsyncInstruction::~HloAsyncInstruction() {
278   ClearAsyncComputationInstruction();
279   ClearCalledComputations();
280 }
281 
ClearAsyncComputationInstruction()282 void HloAsyncInstruction::ClearAsyncComputationInstruction() {
283   // Each async instruction calls a single computation, but we use
284   // called_computations() instead of async_wrapped_instruction(), because the
285   // order in which things get destructed can vary; the async computation's
286   // back-pointer may already be null, which violates a check in
287   // async_wrapped_instruction.
288   for (HloComputation* computation : called_computations()) {
289     CHECK(computation != nullptr);
290     if (computation->IsAsyncComputation()) {
291       computation->RemoveAsyncInstruction(this);
292     }
293   }
294 }
295 
async_wrapped_instruction() const296 HloInstruction* HloAsyncInstruction::async_wrapped_instruction() const {
297   CHECK(!called_computations().empty());
298   return called_computations()[0]->root_instruction();
299 }
300 
async_wrapped_opcode() const301 HloOpcode HloAsyncInstruction::async_wrapped_opcode() const {
302   return async_wrapped_instruction()->opcode();
303 }
304 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const305 std::vector<std::string> HloAsyncInstruction::ExtraAttributesToStringImpl(
306     const HloPrintOptions& options) const {
307   std::vector<std::string> result;
308   if (async_group_id_.has_value()) {
309     result.push_back(StrCat("async_group_id=", *async_group_id_));
310   }
311   if (async_execution_thread_ != kMainExecutionThread) {
312     result.push_back(
313         StrCat("async_execution_thread=\"", async_execution_thread_, "\""));
314   }
315   if (options.syntax_sugar_async_ops()) {
316     std::vector<std::string> wrapped_extra_attributes =
317         async_wrapped_instruction()->ExtraAttributesToString(options);
318     absl::c_copy(wrapped_extra_attributes, std::back_inserter(result));
319   }
320   return result;
321 }
322 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const323 bool HloAsyncInstruction::IdenticalSlowPath(
324     const HloInstruction& other,
325     const std::function<bool(const HloComputation*, const HloComputation*)>&
326         eq_computations) const {
327   return opcode() == other.opcode() &&
328          eq_computations(async_wrapped_computation(),
329                          other.async_wrapped_computation());
330 }
331 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const332 std::unique_ptr<HloInstruction> HloAsyncInstruction::CloneWithNewOperandsImpl(
333     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
334     HloCloneContext* context) const {
335   HloModule* module = context != nullptr ? context->module() : GetModule();
336   HloComputation* new_wrapped_computation = nullptr;
337   if (context != nullptr) {
338     new_wrapped_computation =
339         context->FindComputation(async_wrapped_computation());
340   }
341   if (new_wrapped_computation == nullptr) {
342     new_wrapped_computation = module->AddEmbeddedComputation(
343         async_wrapped_computation()->Clone("clone", context));
344   }
345   return std::make_unique<HloAsyncInstruction>(
346       opcode(), shape, new_operands, new_wrapped_computation, async_group_id_,
347       async_execution_thread_);
348 }
349 
set_async_group_id(std::optional<int64_t> async_group_id)350 void HloAsyncInstruction::set_async_group_id(
351     std::optional<int64_t> async_group_id) {
352   async_group_id_ = async_group_id;
353 }
354 
set_async_execution_thread(absl::string_view async_execution_thread)355 void HloAsyncInstruction::set_async_execution_thread(
356     absl::string_view async_execution_thread) {
357   async_execution_thread_ = std::string(async_execution_thread);
358   SetThreadName(async_wrapped_computation(), async_execution_thread,
359                 /*skip_async_execution_thread_overwrite=*/false);
360 }
361 
ToProto() const362 HloInstructionProto HloAsyncInstruction::ToProto() const {
363   HloInstructionProto proto = HloInstruction::ToProto();
364   proto.set_async_group_id(async_group_id_.has_value() ? *async_group_id_ : -1);
365   proto.set_async_execution_thread(async_execution_thread_ ==
366                                            HloInstruction::kMainExecutionThread
367                                        ? ""
368                                        : async_execution_thread_);
369   return proto;
370 }
371 
HloCopyStartInstruction(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)372 HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
373                                                  HloInstruction* operand,
374                                                  bool is_cross_program_prefetch)
375     : HloInstruction(HloOpcode::kCopyStart, shape),
376       is_cross_program_prefetch_(is_cross_program_prefetch) {
377   AppendOperand(operand);
378 }
379 
ToProto() const380 HloInstructionProto HloCopyStartInstruction::ToProto() const {
381   HloInstructionProto proto = HloInstruction::ToProto();
382   proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
383   return proto;
384 }
385 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const386 std::vector<std::string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
387     const HloPrintOptions& options) const {
388   std::vector<std::string> result;
389   if (is_cross_program_prefetch()) {
390     result.push_back("is_cross_program_prefetch=true");
391   }
392   return result;
393 }
394 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const395 bool HloCopyStartInstruction::IdenticalSlowPath(
396     const HloInstruction& other,
397     const std::function<bool(const HloComputation*, const HloComputation*)>&
398         eq_computations) const {
399   const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
400   return is_cross_program_prefetch() ==
401          casted_other.is_cross_program_prefetch();
402 }
403 
404 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const405 HloCopyStartInstruction::CloneWithNewOperandsImpl(
406     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
407     HloCloneContext* context) const {
408   CHECK_EQ(new_operands.size(), 1);
409   return std::make_unique<HloCopyStartInstruction>(shape, new_operands[0],
410                                                    is_cross_program_prefetch());
411 }
412 
HloCompareInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,std::optional<Comparison::Type> type)413 HloCompareInstruction::HloCompareInstruction(
414     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
415     ComparisonDirection direction, std::optional<Comparison::Type> type)
416     : HloInstruction(HloOpcode::kCompare, shape),
417       compare_(type.has_value()
418                    ? Comparison(direction, *type)
419                    : Comparison(direction, lhs->shape().element_type())) {
420   AppendOperand(lhs);
421   AppendOperand(rhs);
422 }
423 
ToProto() const424 HloInstructionProto HloCompareInstruction::ToProto() const {
425   HloInstructionProto proto = HloInstruction::ToProto();
426   proto.set_comparison_direction(
427       ComparisonDirectionToString(compare_.GetDirection()));
428   proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
429   return proto;
430 }
431 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const432 std::vector<std::string> HloCompareInstruction::ExtraAttributesToStringImpl(
433     const HloPrintOptions& options) const {
434   std::vector<std::string> result;
435   result.push_back(
436       StrCat("direction=", ComparisonDirectionToString(direction())));
437   if (compare_.GetType() !=
438       Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
439     result.push_back(
440         StrCat("type=", ComparisonTypeToString(compare_.GetType())));
441   }
442   return result;
443 }
444 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const445 bool HloCompareInstruction::IdenticalSlowPath(
446     const HloInstruction& other,
447     const std::function<bool(const HloComputation*, const HloComputation*)>&
448         eq_computations) const {
449   const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
450   return direction() == casted_other.direction();
451 }
452 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const453 std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
454     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
455     HloCloneContext* context) const {
456   CHECK_EQ(new_operands.size(), 2);
457   return std::make_unique<HloCompareInstruction>(
458       shape, new_operands[0], new_operands[1], direction(), type());
459 }
460 
461 namespace {
462 
463 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
464 // of "key=value" attribute strings generically, using protocol buffer
465 // reflection.
466 //
467 // Currently implements a small subset of cases; feel free to add more as
468 // needed.
AttributeProtoToStringVector(const tensorflow::protobuf::Message & message)469 std::vector<std::string> AttributeProtoToStringVector(
470     const tensorflow::protobuf::Message& message) {
471   const tensorflow::protobuf::Reflection* reflection = message.GetReflection();
472   std::vector<const tensorflow::protobuf::FieldDescriptor*> fields;
473   reflection->ListFields(message, &fields);
474 
475   std::vector<std::string> output;
476   for (const tensorflow::protobuf::FieldDescriptor* field : fields) {
477     std::string s = absl::StrCat(field->name(), "=");
478     CHECK(!field->is_repeated()) << "Repeated fields aren't implemented";
479     switch (field->type()) {
480       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
481         bool val = reflection->GetBool(message, field);
482         absl::StrAppend(&s, val ? "true" : "false");
483         break;
484       }
485       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
486         const tensorflow::protobuf::EnumValueDescriptor* evd =
487             reflection->GetEnum(message, field);
488         absl::StrAppend(&s, evd->name());
489         break;
490       }
491       default:
492         LOG(FATAL) << "Unimplemented field type: " << field->DebugString();
493     }
494     output.push_back(std::move(s));
495   }
496   return output;
497 }
498 
499 }  // namespace
500 
HloTriangularSolveInstruction(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)501 HloTriangularSolveInstruction::HloTriangularSolveInstruction(
502     const Shape& shape, HloInstruction* a, HloInstruction* b,
503     const TriangularSolveOptions& options)
504     : HloInstruction(HloOpcode::kTriangularSolve, shape),
505       triangular_solve_options_(options) {
506   AppendOperand(a);
507   AppendOperand(b);
508 }
509 
ToProto() const510 HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
511   HloInstructionProto proto = HloInstruction::ToProto();
512   *proto.mutable_triangular_solve_options() = triangular_solve_options_;
513   return proto;
514 }
515 
516 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const517 HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
518     const HloPrintOptions& options) const {
519   return AttributeProtoToStringVector(triangular_solve_options_);
520 }
521 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const522 bool HloTriangularSolveInstruction::IdenticalSlowPath(
523     const HloInstruction& other,
524     const std::function<bool(const HloComputation*, const HloComputation*)>&
525         eq_computations) const {
526   const auto& casted_other =
527       static_cast<const HloTriangularSolveInstruction&>(other);
528   const auto& options = triangular_solve_options();
529   const auto& other_options = casted_other.triangular_solve_options();
530 
531   return options.left_side() == other_options.left_side() &&
532          options.lower() == other_options.lower() &&
533          options.unit_diagonal() == other_options.unit_diagonal() &&
534          options.transpose_a() == other_options.transpose_a();
535 }
536 
537 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const538 HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
539     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
540     HloCloneContext* context) const {
541   CHECK_EQ(new_operands.size(), 2);
542   return std::make_unique<HloTriangularSolveInstruction>(
543       shape, new_operands[0], new_operands[1], triangular_solve_options());
544 }
545 
HloCholeskyInstruction(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)546 HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
547                                                HloInstruction* a,
548                                                const CholeskyOptions& options)
549     : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
550   AppendOperand(a);
551 }
552 
ToProto() const553 HloInstructionProto HloCholeskyInstruction::ToProto() const {
554   HloInstructionProto proto = HloInstruction::ToProto();
555   *proto.mutable_cholesky_options() = cholesky_options_;
556   return proto;
557 }
558 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const559 std::vector<std::string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
560     const HloPrintOptions& options) const {
561   return AttributeProtoToStringVector(cholesky_options_);
562 }
563 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const564 bool HloCholeskyInstruction::IdenticalSlowPath(
565     const HloInstruction& other,
566     const std::function<bool(const HloComputation*, const HloComputation*)>&
567         eq_computations) const {
568   const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
569   const auto& options = cholesky_options();
570   const auto& other_options = casted_other.cholesky_options();
571 
572   return options.lower() == other_options.lower();
573 }
574 
575 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const576 HloCholeskyInstruction::CloneWithNewOperandsImpl(
577     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
578     HloCloneContext* context) const {
579   CHECK_EQ(new_operands.size(), 1);
580   return std::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
581                                                   cholesky_options());
582 }
583 
HloChannelInstruction(HloOpcode opcode,const Shape & shape,const std::optional<int64_t> & channel_id)584 HloChannelInstruction::HloChannelInstruction(
585     HloOpcode opcode, const Shape& shape,
586     const std::optional<int64_t>& channel_id)
587     : HloInstruction(opcode, shape), channel_id_(channel_id) {}
588 
set_channel_id(const std::optional<int64_t> & channel_id)589 void HloChannelInstruction::set_channel_id(
590     const std::optional<int64_t>& channel_id) {
591   channel_id_ = channel_id;
592 }
593 
ToProto() const594 HloInstructionProto HloChannelInstruction::ToProto() const {
595   HloInstructionProto proto = HloInstruction::ToProto();
596   if (channel_id_) {
597     CHECK_GT(channel_id_.value(), 0)
598         << "Non-positive channel id is equivalent to no channel id";
599     proto.set_channel_id(*channel_id_);
600   }
601   return proto;
602 }
603 
ExtraAttributesToStringImpl(const HloPrintOptions &) const604 std::vector<std::string> HloChannelInstruction::ExtraAttributesToStringImpl(
605     const HloPrintOptions& /*options*/) const {
606   std::vector<std::string> result;
607   if (channel_id_) {
608     result.push_back(StrCat("channel_id=", *channel_id_));
609   }
610   return result;
611 }
612 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const613 bool HloChannelInstruction::IdenticalSlowPath(
614     const HloInstruction& other,
615     const std::function<bool(const HloComputation*, const HloComputation*)>&
616         eq_computations) const {
617   if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) {
618     return false;
619   }
620   const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
621   return channel_id() == casted_other.channel_id();
622 }
623 
HloSendRecvInstruction(HloOpcode opcode,const Shape & shape,int64_t channel_id,bool is_host_transfer)624 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
625                                                const Shape& shape,
626                                                int64_t channel_id,
627                                                bool is_host_transfer)
628     : HloChannelInstruction(opcode, shape, channel_id),
629       is_host_transfer_(is_host_transfer) {}
630 
ToProto() const631 HloInstructionProto HloSendRecvInstruction::ToProto() const {
632   HloInstructionProto proto = HloChannelInstruction::ToProto();
633   proto.set_is_host_transfer(is_host_transfer_);
634   return proto;
635 }
636 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const637 std::vector<std::string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
638     const HloPrintOptions& options) const {
639   std::vector<std::string> attrs =
640       HloChannelInstruction::ExtraAttributesToStringImpl(options);
641   if (is_host_transfer()) {
642     attrs.push_back("is_host_transfer=true");
643   }
644   return attrs;
645 }
646 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const647 bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
648     const HloInstruction& other,
649     const std::function<bool(const HloComputation*, const HloComputation*)>&
650         eq_computations) const {
651   // Not yet supported.
652   return false;
653 }
654 
655 // Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)656 HloSendInstruction::HloSendInstruction(HloInstruction* operand,
657                                        HloInstruction* token,
658                                        int64_t channel_id,
659                                        bool is_host_transfer)
660     : HloSendRecvInstruction(
661           HloOpcode::kSend,
662           ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
663                                      ShapeUtil::MakeShape(U32, {}),
664                                      ShapeUtil::MakeTokenShape()}),
665           channel_id, is_host_transfer) {
666   AppendOperand(operand);
667   AppendOperand(token);
668 }
669 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const670 std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
671     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
672     HloCloneContext* context) const {
673   CHECK_EQ(new_operands.size(), 2);
674   return std::make_unique<HloSendInstruction>(
675       new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
676 }
677 
HloSendDoneInstruction(HloSendInstruction * operand,bool is_host_transfer)678 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
679                                                bool is_host_transfer)
680     : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
681                              CHECK_NOTNULL(operand)->channel_id().value(),
682                              is_host_transfer) {
683   AppendOperand(operand);
684 }
685 
686 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const687 HloSendDoneInstruction::CloneWithNewOperandsImpl(
688     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
689     HloCloneContext* context) const {
690   CHECK_EQ(new_operands.size(), 1);
691   return std::make_unique<HloSendDoneInstruction>(
692       Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
693 }
694 
695 // Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)696 HloRecvInstruction::HloRecvInstruction(const Shape& shape,
697                                        HloInstruction* token,
698                                        int64_t channel_id,
699                                        bool is_host_transfer)
700     : HloSendRecvInstruction(
701           HloOpcode::kRecv,
702           ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
703                                      ShapeUtil::MakeTokenShape()}),
704           channel_id, is_host_transfer) {
705   AppendOperand(token);
706 }
707 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const708 std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
709     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
710     HloCloneContext* context) const {
711   CHECK_EQ(new_operands.size(), 1);
712   return std::make_unique<HloRecvInstruction>(
713       ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
714       is_host_transfer());
715 }
716 
HloRecvDoneInstruction(HloRecvInstruction * operand,bool is_host_transfer)717 HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
718                                                bool is_host_transfer)
719     : HloSendRecvInstruction(
720           HloOpcode::kRecvDone,
721           ShapeUtil::MakeTupleShape(
722               {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
723                ShapeUtil::MakeTokenShape()}),
724           CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
725   AppendOperand(operand);
726 }
727 
728 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const729 HloRecvDoneInstruction::CloneWithNewOperandsImpl(
730     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
731     HloCloneContext* context) const {
732   CHECK_EQ(new_operands.size(), 1);
733   return std::make_unique<HloRecvDoneInstruction>(
734       Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
735 }
736 
HloCollectiveInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id)737 HloCollectiveInstruction::HloCollectiveInstruction(
738     HloOpcode opcode, const Shape& shape,
739     absl::Span<HloInstruction* const> operands,
740     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
741     const std::optional<int64_t>& channel_id)
742     : HloChannelInstruction(opcode, shape, channel_id),
743       replica_groups_(SpanToVector(replica_groups)),
744       constrain_layout_(constrain_layout) {
745   for (auto operand : operands) {
746     AppendOperand(operand);
747   }
748 }
749 
ToProto() const750 HloInstructionProto HloCollectiveInstruction::ToProto() const {
751   HloInstructionProto proto = HloChannelInstruction::ToProto();
752   *proto.mutable_replica_groups() = {replica_groups_.begin(),
753                                      replica_groups_.end()};
754   proto.set_constrain_layout(constrain_layout_);
755   return proto;
756 }
757 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const758 std::vector<std::string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
759     const HloPrintOptions& options) const {
760   std::vector<std::string> result =
761       HloChannelInstruction::ExtraAttributesToStringImpl(options);
762   result.push_back(
763       StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
764   if (constrain_layout_) {
765     result.push_back("constrain_layout=true");
766   }
767   return result;
768 }
769 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const770 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
771     const HloInstruction& other,
772     const std::function<bool(const HloComputation*, const HloComputation*)>&
773         eq_computations) const {
774   const auto& casted_other =
775       static_cast<const HloCollectiveInstruction&>(other);
776   return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
777              other, eq_computations) &&
778          constrain_layout() == casted_other.constrain_layout() &&
779          absl::c_equal(replica_groups(), casted_other.replica_groups(),
780                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
781                          return absl::c_equal(a.replica_ids(), b.replica_ids());
782                        });
783 }
784 
HloAllGatherInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)785 HloAllGatherInstruction::HloAllGatherInstruction(
786     HloOpcode opcode, const Shape& shape,
787     absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension,
788     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
789     const std::optional<int64_t>& channel_id, bool use_global_device_ids)
790     : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
791                                constrain_layout, channel_id),
792       all_gather_dimension_(all_gather_dimension),
793       use_global_device_ids_(use_global_device_ids) {}
794 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const795 std::vector<std::string> HloAllGatherInstruction::ExtraAttributesToStringImpl(
796     const HloPrintOptions& options) const {
797   std::vector<std::string> result =
798       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
799   result.push_back(StrCat("dimensions={", all_gather_dimension_, "}"));
800   if (use_global_device_ids_) {
801     result.push_back("use_global_device_ids=true");
802   }
803   return result;
804 }
805 
806 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const807 HloAllGatherInstruction::CloneWithNewOperandsImpl(
808     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
809     HloCloneContext* /*context*/) const {
810   return std::make_unique<HloAllGatherInstruction>(
811       opcode(), shape, new_operands, all_gather_dimension(), replica_groups(),
812       constrain_layout(), channel_id(), use_global_device_ids());
813 }
814 
ToProto() const815 HloInstructionProto HloAllGatherInstruction::ToProto() const {
816   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
817   proto.add_dimensions(all_gather_dimension_);
818   proto.set_use_global_device_ids(use_global_device_ids_);
819   return proto;
820 }
821 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const822 bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues(
823     const HloInstruction& other,
824     const std::function<bool(const HloComputation*, const HloComputation*)>&
825         eq_computations) const {
826   const auto& casted_other = static_cast<const HloAllGatherInstruction&>(other);
827   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
828              other, eq_computations) &&
829          all_gather_dimension_ == casted_other.all_gather_dimension() &&
830          use_global_device_ids() == casted_other.use_global_device_ids();
831 }
832 
HloAllReduceInstructionBase(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)833 HloAllReduceInstructionBase::HloAllReduceInstructionBase(
834     HloOpcode opcode, const Shape& shape,
835     absl::Span<HloInstruction* const> operands,
836     HloComputation* reduce_computation,
837     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
838     const std::optional<int64_t>& channel_id, bool use_global_device_ids)
839     : HloCollectiveInstruction(opcode, shape, operands, replica_groups,
840                                constrain_layout, channel_id),
841       use_global_device_ids_(use_global_device_ids) {
842   AppendComputation(reduce_computation);
843 }
844 
ToProto() const845 HloInstructionProto HloAllReduceInstructionBase::ToProto() const {
846   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
847   proto.set_use_global_device_ids(use_global_device_ids_);
848   return proto;
849 }
850 
851 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const852 HloAllReduceInstructionBase::ExtraAttributesToStringImpl(
853     const HloPrintOptions& options) const {
854   std::vector<std::string> result =
855       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
856   if (use_global_device_ids_) {
857     result.push_back("use_global_device_ids=true");
858   }
859   return result;
860 }
861 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const862 bool HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
863     const HloInstruction& other,
864     const std::function<bool(const HloComputation*, const HloComputation*)>&
865         eq_computations) const {
866   if (opcode() != other.opcode()) {
867     return false;
868   }
869   const auto& casted_other =
870       static_cast<const HloAllReduceInstructionBase&>(other);
871   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
872              other, eq_computations) &&
873          constrain_layout() == casted_other.constrain_layout() &&
874          use_global_device_ids() == casted_other.use_global_device_ids() &&
875          eq_computations(to_apply(), casted_other.to_apply());
876 }
877 
IsNoop() const878 bool HloAllReduceInstruction::IsNoop() const {
879   for (const auto& replica_group : replica_groups()) {
880     if (replica_group.replica_ids().size() != 1) {
881       return false;
882     }
883   }
884   return !channel_id();
885 }
886 
887 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const888 HloAllReduceInstruction::CloneWithNewOperandsImpl(
889     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
890     HloCloneContext* /*context*/) const {
891   return std::make_unique<HloAllReduceInstruction>(
892       opcode(), shape, new_operands, to_apply(), replica_groups(),
893       constrain_layout(), channel_id(), use_global_device_ids());
894 }
895 
HloReduceScatterInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)896 HloReduceScatterInstruction::HloReduceScatterInstruction(
897     const Shape& shape, absl::Span<HloInstruction* const> operands,
898     HloComputation* reduce_computation,
899     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
900     const std::optional<int64_t>& channel_id, bool use_global_device_ids,
901     int64_t scatter_dimension)
902     : HloAllReduceInstructionBase(
903           HloOpcode::kReduceScatter, shape, operands, reduce_computation,
904           replica_groups, constrain_layout, channel_id, use_global_device_ids),
905       scatter_dimension_(scatter_dimension) {}
906 
907 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const908 HloReduceScatterInstruction::ExtraAttributesToStringImpl(
909     const HloPrintOptions& options) const {
910   std::vector<std::string> result =
911       HloAllReduceInstructionBase::ExtraAttributesToStringImpl(options);
912   result.push_back(StrCat("dimensions={", scatter_dimension_, "}"));
913   return result;
914 }
915 
ToProto() const916 HloInstructionProto HloReduceScatterInstruction::ToProto() const {
917   HloInstructionProto proto = HloAllReduceInstructionBase::ToProto();
918   proto.add_dimensions(scatter_dimension_);
919   return proto;
920 }
921 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const922 bool HloReduceScatterInstruction::IdenticalSlowPathIgnoringChannelIdValues(
923     const HloInstruction& other,
924     const std::function<bool(const HloComputation*, const HloComputation*)>&
925         eq_computations) const {
926   const auto& casted_other =
927       static_cast<const HloReduceScatterInstruction&>(other);
928   return HloAllReduceInstructionBase::IdenticalSlowPathIgnoringChannelIdValues(
929              other, eq_computations) &&
930          scatter_dimension_ == casted_other.scatter_dimension();
931 }
932 
933 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const934 HloReduceScatterInstruction::CloneWithNewOperandsImpl(
935     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
936     HloCloneContext* /*context*/) const {
937   return std::make_unique<HloReduceScatterInstruction>(
938       shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
939       channel_id(), use_global_device_ids(), scatter_dimension());
940 }
941 
HloAllToAllInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,const std::optional<int64_t> & split_dimension)942 HloAllToAllInstruction::HloAllToAllInstruction(
943     const Shape& shape, absl::Span<HloInstruction* const> operands,
944     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
945     const std::optional<int64_t>& channel_id,
946     const std::optional<int64_t>& split_dimension)
947     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
948                                replica_groups, constrain_layout, channel_id),
949       split_dimension_(split_dimension) {}
950 
951 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const952 HloAllToAllInstruction::CloneWithNewOperandsImpl(
953     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
954     HloCloneContext* /*context*/) const {
955   return std::make_unique<HloAllToAllInstruction>(
956       shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
957       split_dimension());
958 }
959 
ToProto() const960 HloInstructionProto HloAllToAllInstruction::ToProto() const {
961   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
962   if (split_dimension_) {
963     proto.add_dimensions(*split_dimension_);
964   }
965   return proto;
966 }
967 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const968 std::vector<std::string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
969     const HloPrintOptions& options) const {
970   std::vector<std::string> result =
971       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
972   if (split_dimension_) {
973     result.push_back(StrCat("dimensions={", *split_dimension_, "}"));
974   }
975   return result;
976 }
977 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const978 bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
979     const HloInstruction& other,
980     const std::function<bool(const HloComputation*, const HloComputation*)>&
981         eq_computations) const {
982   const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
983   return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
984              other, eq_computations) &&
985          split_dimension_ == casted_other.split_dimension();
986 }
987 
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const std::optional<int64_t> & channel_id)988 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
989     HloOpcode opcode, const Shape& shape, HloInstruction* operand,
990     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
991     const std::optional<int64_t>& channel_id)
992     : HloChannelInstruction(opcode, shape, channel_id),
993       source_target_pairs_(source_target_pairs) {
994   AppendOperand(operand);
995 }
996 
HloCollectivePermuteInstruction(HloOpcode opcode,const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const std::optional<int64_t> & channel_id)997 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
998     HloOpcode opcode, const Shape& shape, HloInstruction* input,
999     HloInstruction* output, HloInstruction* input_start_indices,
1000     HloInstruction* output_start_indices,
1001     absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1002     absl::Span<const std::vector<int64_t>> slice_sizes,
1003     const std::optional<int64_t>& channel_id)
1004     : HloChannelInstruction(opcode, shape, channel_id),
1005       source_target_pairs_(source_target_pairs.begin(),
1006                            source_target_pairs.end()),
1007       slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
1008   AppendOperand(input);
1009   AppendOperand(output);
1010   AppendOperand(input_start_indices);
1011   AppendOperand(output_start_indices);
1012 }
1013 
ToProto() const1014 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
1015   HloInstructionProto proto = HloChannelInstruction::ToProto();
1016   for (const auto& pair : source_target_pairs()) {
1017     auto* proto_pair = proto.add_source_target_pairs();
1018     proto_pair->set_source(pair.first);
1019     proto_pair->set_target(pair.second);
1020   }
1021   for (const auto& slice_size : dynamic_slice_sizes_list()) {
1022     for (const auto& dimension_slice_size : slice_size) {
1023       proto.add_dynamic_slice_sizes(dimension_slice_size);
1024     }
1025   }
1026   return proto;
1027 }
1028 
1029 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1030 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
1031     const HloPrintOptions& options) const {
1032   std::vector<std::string> result =
1033       HloChannelInstruction::ExtraAttributesToStringImpl(options);
1034   {
1035     std::vector<std::string> strs;
1036     const auto& pairs = source_target_pairs();
1037     strs.reserve(pairs.size());
1038     for (const auto& pair : pairs) {
1039       strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
1040     }
1041     result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
1042   }
1043   if (!dynamic_slice_sizes_list().empty()) {
1044     std::vector<std::string> strs;
1045     const auto& sizes_list = dynamic_slice_sizes_list();
1046     strs.reserve(sizes_list.size());
1047     for (const auto& slice_sizes : dynamic_slice_sizes_list()) {
1048       strs.push_back(StrCat("{", StrJoin(slice_sizes, ","), "}"));
1049     }
1050     result.push_back(StrCat("slice_sizes={", StrJoin(strs, ","), "}"));
1051   }
1052   return result;
1053 }
1054 
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1055 bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues(
1056     const HloInstruction& other,
1057     const std::function<bool(const HloComputation*, const HloComputation*)>&
1058         eq_computations) const {
1059   if (opcode() != other.opcode()) {
1060     return false;
1061   }
1062   const auto& casted_other =
1063       static_cast<const HloCollectivePermuteInstruction&>(other);
1064   return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues(
1065              other, eq_computations) &&
1066          absl::c_equal(
1067              source_target_pairs(), casted_other.source_target_pairs(),
1068              [](const std::pair<int64_t, int64_t>& a,
1069                 const std::pair<int64_t, int64_t>& b) { return a == b; }) &&
1070          absl::c_equal(
1071              dynamic_slice_sizes_list(),
1072              casted_other.dynamic_slice_sizes_list(),
1073              [](const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
1074                return absl::c_equal(a, b);
1075              });
1076 }
1077 
1078 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const1079 HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
1080     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1081     HloCloneContext* /*context*/) const {
1082   if (dynamic_slice_sizes_list().empty()) {
1083     return std::make_unique<HloCollectivePermuteInstruction>(
1084         opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
1085   } else {
1086     return std::make_unique<HloCollectivePermuteInstruction>(
1087         opcode(), shape, new_operands[0], new_operands[1], new_operands[2],
1088         new_operands[3], source_target_pairs(), dynamic_slice_sizes_list(),
1089         channel_id());
1090   }
1091 }
1092 
HloReverseInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1093 HloReverseInstruction::HloReverseInstruction(
1094     const Shape& shape, HloInstruction* operand,
1095     absl::Span<const int64_t> dimensions)
1096     : HloDimensionsInstruction(HloOpcode::kReverse, shape, dimensions) {
1097   AppendOperand(operand);
1098 }
1099 
ToProto() const1100 HloInstructionProto HloDimensionsInstruction::ToProto() const {
1101   HloInstructionProto proto = HloInstruction::ToProto();
1102   for (int64_t dimension : dimensions_) {
1103     proto.add_dimensions(dimension);
1104   }
1105   return proto;
1106 }
1107 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1108 std::vector<std::string> HloDimensionsInstruction::ExtraAttributesToStringImpl(
1109     const HloPrintOptions& options) const {
1110   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1111 }
1112 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1113 bool HloDimensionsInstruction::IdenticalSlowPath(
1114     const HloInstruction& other,
1115     const std::function<bool(const HloComputation*, const HloComputation*)>&
1116         eq_computations) const {
1117   const auto& casted_other =
1118       static_cast<const HloDimensionsInstruction&>(other);
1119   return dimensions() == casted_other.dimensions();
1120 }
1121 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1122 std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
1123     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1124     HloCloneContext* context) const {
1125   CHECK_EQ(new_operands.size(), 1);
1126   return std::make_unique<HloReverseInstruction>(shape, new_operands[0],
1127                                                  dimensions());
1128 }
1129 
HloConcatenateInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)1130 HloConcatenateInstruction::HloConcatenateInstruction(
1131     const Shape& shape, absl::Span<HloInstruction* const> operands,
1132     int64_t dimension)
1133     : HloDimensionsInstruction(HloOpcode::kConcatenate, shape, {dimension}) {
1134   for (auto operand : operands) {
1135     AppendOperand(operand);
1136   }
1137 }
1138 
1139 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1140 HloConcatenateInstruction::CloneWithNewOperandsImpl(
1141     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1142     HloCloneContext* context) const {
1143   return std::make_unique<HloConcatenateInstruction>(shape, new_operands,
1144                                                      concatenate_dimension());
1145 }
1146 
HloReduceInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1147 HloReduceInstruction::HloReduceInstruction(
1148     const Shape& shape, absl::Span<HloInstruction* const> args,
1149     absl::Span<const int64_t> dimensions_to_reduce,
1150     HloComputation* reduce_computation)
1151     : HloDimensionsInstruction(HloOpcode::kReduce, shape,
1152                                dimensions_to_reduce) {
1153   for (HloInstruction* arg : args) {
1154     AppendOperand(arg);
1155   }
1156   AppendComputation(reduce_computation);
1157 }
1158 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1159 bool HloReduceInstruction::IdenticalSlowPath(
1160     const HloInstruction& other,
1161     const std::function<bool(const HloComputation*, const HloComputation*)>&
1162         eq_computations) const {
1163   const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
1164   // Reduction results are determined by the reduction dimension and the
1165   // reduction computation.
1166   return dimensions() == casted_other.dimensions() &&
1167          eq_computations(to_apply(), casted_other.to_apply());
1168 }
1169 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1170 std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
1171     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1172     HloCloneContext* context) const {
1173   CHECK_EQ(new_operands.size() % 2, 0);
1174   return std::make_unique<HloReduceInstruction>(shape, new_operands,
1175                                                 dimensions(), to_apply());
1176 }
1177 
HloSortInstruction(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1178 HloSortInstruction::HloSortInstruction(
1179     const Shape& shape, int64_t dimension,
1180     absl::Span<HloInstruction* const> operands, HloComputation* compare,
1181     bool is_stable)
1182     : HloDimensionsInstruction(HloOpcode::kSort, shape, {dimension}),
1183       is_stable_(is_stable) {
1184   for (auto* value : operands) {
1185     AppendOperand(value);
1186   }
1187   AppendComputation(compare);
1188 }
1189 
ToProto() const1190 HloInstructionProto HloSortInstruction::ToProto() const {
1191   HloInstructionProto proto = HloInstruction::ToProto();
1192   for (int64_t dimension : dimensions_) {
1193     proto.add_dimensions(dimension);
1194   }
1195   proto.set_is_stable(is_stable());
1196   return proto;
1197 }
1198 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1199 std::vector<std::string> HloSortInstruction::ExtraAttributesToStringImpl(
1200     const HloPrintOptions& options) const {
1201   std::vector<std::string> attrs;
1202   attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
1203   if (is_stable()) {
1204     attrs.push_back("is_stable=true");
1205   }
1206   return attrs;
1207 }
1208 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1209 bool HloSortInstruction::IdenticalSlowPath(
1210     const HloInstruction& other,
1211     const std::function<bool(const HloComputation*, const HloComputation*)>&
1212         eq_computations) const {
1213   const auto& casted_other = static_cast<const HloSortInstruction&>(other);
1214   if (dimensions() != casted_other.dimensions()) {
1215     return false;
1216   }
1217   if (is_stable() != casted_other.is_stable()) {
1218     return false;
1219   }
1220   return eq_computations(to_apply(), other.to_apply());
1221 }
1222 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1223 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
1224     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1225     HloCloneContext* context) const {
1226   return std::make_unique<HloSortInstruction>(
1227       shape, dimensions_[0], new_operands, to_apply(), is_stable());
1228 }
1229 
HloTransposeInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1230 HloTransposeInstruction::HloTransposeInstruction(
1231     const Shape& shape, HloInstruction* operand,
1232     absl::Span<const int64_t> dimensions)
1233     : HloDimensionsInstruction(HloOpcode::kTranspose, shape, dimensions) {
1234   AppendOperand(operand);
1235 }
1236 
IsRank2Transpose() const1237 bool HloTransposeInstruction::IsRank2Transpose() const {
1238   return dimensions() == std::vector<int64_t>({1, 0}) &&
1239          shape().dimensions_size() == 2 &&
1240          std::equal(shape().dimensions().begin(), shape().dimensions().end(),
1241                     operand(0)->shape().dimensions().rbegin());
1242 }
1243 
1244 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1245 HloTransposeInstruction::CloneWithNewOperandsImpl(
1246     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1247     HloCloneContext* context) const {
1248   CHECK_EQ(new_operands.size(), 1);
1249   return std::make_unique<HloTransposeInstruction>(shape, new_operands[0],
1250                                                    dimensions());
1251 }
1252 
HloBroadcastInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> broadcast_dimension)1253 HloBroadcastInstruction::HloBroadcastInstruction(
1254     const Shape& shape, HloInstruction* operand,
1255     absl::Span<const int64_t> broadcast_dimension)
1256     : HloDimensionsInstruction(HloOpcode::kBroadcast, shape,
1257                                broadcast_dimension) {
1258   AppendOperand(operand);
1259 }
1260 
1261 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1262 HloBroadcastInstruction::CloneWithNewOperandsImpl(
1263     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1264     HloCloneContext* context) const {
1265   CHECK_EQ(new_operands.size(), 1);
1266   return std::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
1267                                                    dimensions());
1268 }
1269 
HloDynamicReshapeInstruction(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1270 HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
1271     const Shape& shape, HloInstruction* data_operand,
1272     absl::Span<HloInstruction* const> dim_sizes)
1273     : HloInstruction(HloOpcode::kDynamicReshape, shape) {
1274   AppendOperand(data_operand);
1275   for (auto operand : dim_sizes) {
1276     AppendOperand(operand);
1277   }
1278 }
1279 
1280 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1281 HloDynamicReshapeInstruction::CloneWithNewOperandsImpl(
1282     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1283     HloCloneContext* context) const {
1284   CHECK_GE(new_operands.size(), 1);
1285   return std::make_unique<HloDynamicReshapeInstruction>(
1286       shape, new_operands[0], new_operands.subspan(1));
1287 }
1288 
HloReshapeInstruction(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1289 HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
1290                                              HloInstruction* operand,
1291                                              int64_t inferred_dimension)
1292     : HloInstruction(HloOpcode::kReshape, shape),
1293       inferred_dimension_(inferred_dimension) {
1294   AppendOperand(operand);
1295 }
1296 
ToProto() const1297 HloInstructionProto HloReshapeInstruction::ToProto() const {
1298   HloInstructionProto proto = HloInstruction::ToProto();
1299   if (inferred_dimension_ != -1) {
1300     proto.add_dimensions(inferred_dimension_);
1301   }
1302   return proto;
1303 }
1304 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1305 std::vector<std::string> HloReshapeInstruction::ExtraAttributesToStringImpl(
1306     const HloPrintOptions& options) const {
1307   if (inferred_dimension() == -1) {
1308     return {};
1309   }
1310   return {StrCat("inferred_dimension=", inferred_dimension())};
1311 }
1312 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1313 bool HloReshapeInstruction::IdenticalSlowPath(
1314     const HloInstruction& other,
1315     const std::function<bool(const HloComputation*, const HloComputation*)>&
1316         eq_computations) const {
1317   const auto& casted_other = static_cast<const HloReshapeInstruction&>(other);
1318   return inferred_dimension() == casted_other.inferred_dimension();
1319 }
1320 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1321 std::unique_ptr<HloInstruction> HloReshapeInstruction::CloneWithNewOperandsImpl(
1322     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1323     HloCloneContext* context) const {
1324   CHECK_EQ(new_operands.size(), 1);
1325   return std::make_unique<HloReshapeInstruction>(shape, new_operands[0],
1326                                                  inferred_dimension());
1327 }
1328 
HloMapInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1329 HloMapInstruction::HloMapInstruction(const Shape& shape,
1330                                      absl::Span<HloInstruction* const> operands,
1331                                      HloComputation* map_computation)
1332     : HloInstruction(HloOpcode::kMap, shape) {
1333   for (auto operand : operands) {
1334     AppendOperand(operand);
1335   }
1336   AppendComputation(map_computation);
1337   // TODO(b/65689298) Remove code below once Map is generalized to accept
1338   // arbitrary map dimensions.
1339   dimensions_.resize(shape.rank());
1340   std::iota(dimensions_.begin(), dimensions_.end(), 0);
1341 }
1342 
ToProto() const1343 HloInstructionProto HloMapInstruction::ToProto() const {
1344   HloInstructionProto proto = HloInstruction::ToProto();
1345   for (int64_t dimension : dimensions_) {
1346     proto.add_dimensions(dimension);
1347   }
1348   return proto;
1349 }
1350 
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const1351 bool HloMapInstruction::IsElementwiseImpl(
1352     const std::optional<int64_t>& operand_idx) const {
1353   if (!dimensions().empty()) {
1354     // Check that the map is executed in elementwise compatible dimensions.
1355     if (dimensions().size() != shape().dimensions_size()) {
1356       return false;
1357     }
1358     for (int i = 0; i < dimensions().size(); ++i) {
1359       if (dimensions()[i] != i) {
1360         return false;
1361       }
1362     }
1363   }
1364   return true;
1365 }
1366 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1367 std::vector<std::string> HloMapInstruction::ExtraAttributesToStringImpl(
1368     const HloPrintOptions& options) const {
1369   return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
1370 }
1371 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1372 bool HloMapInstruction::IdenticalSlowPath(
1373     const HloInstruction& other,
1374     const std::function<bool(const HloComputation*, const HloComputation*)>&
1375         eq_computations) const {
1376   const auto& casted_other = static_cast<const HloMapInstruction&>(other);
1377   return eq_computations(to_apply(), casted_other.to_apply()) &&
1378          dimensions() == casted_other.dimensions();
1379 }
1380 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1381 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
1382     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1383     HloCloneContext* context) const {
1384   return std::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
1385 }
1386 
HloSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)1387 HloSliceInstruction::HloSliceInstruction(
1388     const Shape& shape, HloInstruction* operand,
1389     absl::Span<const int64_t> start_indices,
1390     absl::Span<const int64_t> limit_indices, absl::Span<const int64_t> strides)
1391     : HloInstruction(HloOpcode::kSlice, shape),
1392       slice_starts_(start_indices.begin(), start_indices.end()),
1393       slice_limits_(limit_indices.begin(), limit_indices.end()),
1394       slice_strides_(strides.begin(), strides.end()) {
1395   AppendOperand(operand);
1396   // For backward compatibility with old serialized computations: if there are
1397   // no strides, assume all strides are 1.
1398   // TODO(b/63317920): remove this code.
1399   if (slice_strides_.empty()) {
1400     slice_strides_ = std::vector<int64_t>(start_indices.size(), 1LL);
1401   }
1402 }
1403 
ToProto() const1404 HloInstructionProto HloSliceInstruction::ToProto() const {
1405   HloInstructionProto proto = HloInstruction::ToProto();
1406   for (int i = 0; i < slice_starts_.size(); ++i) {
1407     auto* slice_dimension = proto.add_slice_dimensions();
1408     slice_dimension->set_start(slice_starts_[i]);
1409     slice_dimension->set_limit(slice_limits_[i]);
1410     slice_dimension->set_stride(slice_strides_[i]);
1411   }
1412   return proto;
1413 }
1414 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const1415 std::vector<std::string> HloSliceInstruction::ExtraAttributesToStringImpl(
1416     const HloPrintOptions& options) const {
1417   std::vector<std::string> bounds;
1418   bounds.reserve(slice_starts_.size());
1419   const bool omit_stride = absl::c_all_of(
1420       slice_strides_, [](int64_t stride) { return stride == 1; });
1421   for (int i = 0; i < slice_starts_.size(); ++i) {
1422     std::string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
1423     bounds.push_back(
1424         StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
1425   }
1426   return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
1427 }
1428 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1429 bool HloSliceInstruction::IdenticalSlowPath(
1430     const HloInstruction& other,
1431     const std::function<bool(const HloComputation*, const HloComputation*)>&
1432         eq_computations) const {
1433   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1434   return slice_starts_ == other_slice.slice_starts_ &&
1435          slice_limits_ == other_slice.slice_limits_ &&
1436          slice_strides_ == other_slice.slice_strides_;
1437 }
1438 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1439 std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
1440     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1441     HloCloneContext* context) const {
1442   CHECK_EQ(new_operands.size(), 1);
1443   return std::make_unique<HloSliceInstruction>(
1444       shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
1445 }
1446 
HloConstantInstruction(Literal literal)1447 HloConstantInstruction::HloConstantInstruction(Literal literal)
1448     : HloInstruction(HloOpcode::kConstant, literal.shape()),
1449       literal_(std::move(literal)) {}
1450 
HloConstantInstruction(Literal literal,const Shape & shape)1451 HloConstantInstruction::HloConstantInstruction(Literal literal,
1452                                                const Shape& shape)
1453     : HloInstruction(HloOpcode::kConstant, shape),
1454       literal_(std::move(literal)) {}
1455 
HloConstantInstruction(const Shape & shape)1456 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
1457     : HloInstruction(HloOpcode::kConstant, shape) {}
1458 
ToProto() const1459 HloInstructionProto HloConstantInstruction::ToProto() const {
1460   HloInstructionProto proto = HloInstruction::ToProto();
1461   if (literal_.has_value()) {
1462     *proto.mutable_literal() = literal_->ToProto();
1463   }
1464   return proto;
1465 }
1466 
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const1467 bool HloConstantInstruction::IsElementwiseImpl(
1468     const std::optional<int64_t>& operand_idx) const {
1469   return true;
1470 }
1471 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)1472 void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
1473                                               const ShapeIndex& shape_index) {
1474   Shape* mutable_array_subshape =
1475       ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
1476   CHECK(mutable_array_subshape->IsArray());
1477 
1478   // Normally array_subshape will always have a layout, but this invariant is
1479   // temporarily broken in LayoutAssignment::AssignLayouts.
1480 
1481   if (!mutable_array_subshape->has_layout() ||
1482       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
1483     *literal_ = literal_->Relayout(new_layout, shape_index);
1484     *mutable_array_subshape->mutable_layout() = new_layout;
1485   }
1486 }
1487 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const1488 bool HloConstantInstruction::IdenticalSlowPath(
1489     const HloInstruction& other,
1490     const std::function<bool(const HloComputation*, const HloComputation*)>&
1491         eq_computations) const {
1492   const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
1493   return literal() == other_slice.literal();
1494 }
1495 
1496 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1497 HloConstantInstruction::CloneWithNewOperandsImpl(
1498     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1499     HloCloneContext* context) const {
1500   if (!literal_.has_value()) {
1501     return std::make_unique<HloConstantInstruction>(this->shape());
1502   }
1503   CHECK(literal_.has_value());
1504   // Literal's shape may have no/different tiling info. Use this instruction's
1505   // shape instead.
1506   CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
1507                                                   this->shape()));
1508   return std::make_unique<HloConstantInstruction>(literal_->Clone(),
1509                                                   this->shape());
1510 }
1511 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const1512 std::string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
1513     const HloPrintOptions& options,
1514     CanonicalNameMap* canonical_name_map) const {
1515   if (options.print_only_essential_constants()) {
1516     if (!literal_.has_value()) {
1517       return "{...}";
1518     }
1519     if (literal().IsAll(0)) {
1520       return "0";
1521     }
1522     if (literal().IsAll(1)) {
1523       return "1";
1524     }
1525     if (shape().IsInteger()) {
1526       return literal_->ToStringWithoutShapeOneline();
1527     }
1528     return "{...}";
1529   }
1530 
1531   // For constants, show the actual value in place of an empty operand list.
1532   if (literal_.has_value() &&
1533       ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) ||
1534        options.print_large_constants())) {
1535     // Literal::ToString emits multidimensional arrays over multiple
1536     // lines. Compact this into one line by stripping out white space.
1537     return literal_->ToStringWithoutShapeOneline();
1538   } else {
1539     // Do not show large constants or tuples.
1540     return "{...}";
1541   }
1542 }
1543 
HloCallableInstruction(HloOpcode opcode,const Shape & shape)1544 HloCallableInstruction::HloCallableInstruction(HloOpcode opcode,
1545                                                const Shape& shape)
1546     : HloInstruction(opcode, shape) {}
1547 
HloCallableInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands)1548 HloCallableInstruction::HloCallableInstruction(
1549     HloOpcode opcode, const Shape& shape,
1550     absl::Span<HloInstruction* const> operands)
1551     : HloInstruction(opcode, shape) {
1552   for (auto operand : operands) {
1553     AppendOperand(operand);
1554   }
1555   SetAndSanitizeName(HloOpcodeString(opcode));
1556 }
1557 
HloCallableInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * called_computation,absl::string_view prefix)1558 HloCallableInstruction::HloCallableInstruction(
1559     HloOpcode opcode, const Shape& shape,
1560     absl::Span<HloInstruction* const> operands,
1561     HloComputation* called_computation, absl::string_view prefix)
1562     : HloInstruction(opcode, shape) {
1563   for (auto operand : operands) {
1564     AppendOperand(operand);
1565   }
1566   SetAndSanitizeName(std::string(prefix) + HloOpcodeString(opcode));
1567   AppendComputation(called_computation);
1568 }
1569 
HloCallableInstruction(HloOpcode opcode,const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations)1570 HloCallableInstruction::HloCallableInstruction(
1571     HloOpcode opcode, const Shape& shape,
1572     absl::Span<HloInstruction* const> operands,
1573     absl::Span<HloComputation* const> called_computations)
1574     : HloInstruction(opcode, shape) {
1575   for (auto operand : operands) {
1576     AppendOperand(operand);
1577   }
1578   SetAndSanitizeName(HloOpcodeString(opcode));
1579   for (auto called_computation : called_computations) {
1580     AppendComputation(called_computation);
1581   }
1582 }
1583 
~HloCallableInstruction()1584 HloCallableInstruction::~HloCallableInstruction() { ClearCalledComputations(); }
1585 
called_computation() const1586 HloComputation* HloCallableInstruction::called_computation() const {
1587   CHECK(!called_computations().empty());
1588   return called_computations().front();
1589 }
1590 
called_computation_root() const1591 HloInstruction* HloCallableInstruction::called_computation_root() const {
1592   return called_computation()->root_instruction();
1593 }
1594 
AddCallOperand(HloInstruction * new_operand)1595 HloInstruction* HloCallableInstruction::AddCallOperand(
1596     HloInstruction* new_operand) {
1597   CHECK_EQ(operand_count(),
1598            called_computation()->parameter_instructions().size());
1599   const int64_t param_no = operand_count();
1600   std::string param_name = StrCat("param_", param_no);
1601   HloInstruction* called_computation_parameter =
1602       called_computation()->AddParameter(HloInstruction::CreateParameter(
1603           param_no, new_operand->shape(), param_name));
1604   AppendOperand(new_operand);
1605   return called_computation_parameter;
1606 }
1607 
AppendInstructionIntoCalledComputation(HloInstruction * instruction_to_append,bool add_output)1608 HloInstruction* HloCallableInstruction::AppendInstructionIntoCalledComputation(
1609     HloInstruction* instruction_to_append, bool add_output) {
1610   // When add_output is false, this callable instruction must be a user of
1611   // instruction_to_append.
1612   if (!add_output) {
1613     CHECK(IsUserOf(instruction_to_append));
1614   }
1615   return CloneAndAppendInstructionIntoCalledComputation(instruction_to_append,
1616                                                         add_output);
1617 }
1618 
1619 HloInstruction*
CloneAndAppendInstructionIntoCalledComputation(HloInstruction * instruction_to_append,bool add_output)1620 HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation(
1621     HloInstruction* instruction_to_append, bool add_output) {
1622   CHECK(instruction_to_append->IsFusible())
1623       << instruction_to_append->ToString();
1624   VLOG(3) << "CloneAndAppendInstructionIntoCalledComputation:\n"
1625           << instruction_to_append->ToString();
1626   HloInstruction* clone = nullptr;
1627   if (called_computations().empty()) {
1628     // New fusion instruction. It should not be a multioutput instruction.
1629     CHECK(!add_output);
1630     auto builder = HloComputation::Builder(
1631         default_called_computation_name(),
1632         opcode() == HloOpcode::kFusion ? this : nullptr);
1633     builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/""));
1634     AppendComputation(
1635         CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
1636     clone = called_computation_root();
1637   } else {
1638     // When add_output is false, instruction_to_append is necessarily an operand
1639     // of the callable instruction. After appending this will no longer be the
1640     // case. Remove the operand from the operand list and remove its
1641     // corresponding called computation parameter instruction.
1642     bool in_operand_list =
1643         absl::c_linear_search(operands(), instruction_to_append);
1644     CHECK(add_output || in_operand_list);
1645     if (instruction_to_append->opcode() == HloOpcode::kTuple) {
1646       // We assume all uses of a kTuple operation are GTE ops. In this case, we
1647       // don't need to clone 'instruction_to_append'.
1648       CHECK(!in_operand_list);
1649       clone = instruction_to_append;
1650     } else {
1651       clone = called_computation()->AddInstruction(
1652           instruction_to_append->Clone(/*suffix=*/""));
1653     }
1654     const std::vector<HloInstruction*>& called_computation_parameters =
1655         called_computation()->parameter_instructions();
1656     for (int64_t operand_num = 0; operand_num < operand_count();
1657          ++operand_num) {
1658       if (instruction_to_append == operand(operand_num)) {
1659         // Replace the called computation parameter instruction's uses with the
1660         // clone.
1661         HloInstruction* called_computation_parameter =
1662             called_computation_parameters[operand_num];
1663         TF_CHECK_OK(called_computation_parameter->ReplaceAllUsesWith(clone));
1664 
1665         // Remove the corresponding called computation parameter and operand
1666         // from their respective vectors.
1667         TF_CHECK_OK(called_computation()->RemoveParameter(operand_num));
1668         RemoveOperandAt(operand_num);
1669         break;
1670       }
1671     }
1672     // We've cloned instruction_to_append into this callable instruction, so
1673     // this callable instruction is no longer a use of instruction_to_append.
1674     if (in_operand_list) {
1675       DetachFrom(instruction_to_append);
1676       // When the instruction_to_append does not have other users, we don't need
1677       // to generate a multioutput instruction.
1678       if (instruction_to_append->user_count() == 0) {
1679         add_output = false;
1680       }
1681     }
1682   }
1683 
1684   // Reread the parameters in the computation.
1685   const std::vector<HloInstruction*>& called_computation_parameters =
1686       called_computation()->parameter_instructions();
1687 
1688   // Add each operand of the clone as an operand of the callable instruction. A
1689   // complication is that some clone operands may already be operands of the
1690   // callable instruction.
1691   for (int64_t operand_num = 0; operand_num < clone->operand_count();
1692        ++operand_num) {
1693     HloInstruction* operand = clone->mutable_operand(operand_num);
1694 
1695     // See if this operand is already an operand of the callable instruction.
1696     CHECK_EQ(operands().size(), called_computation_parameters.size());
1697     HloInstruction* called_computation_parameter = nullptr;
1698     for (int64_t i = 0; i < operands().size(); ++i) {
1699       if (this->operand(i) == operand) {
1700         called_computation_parameter = called_computation_parameters[i];
1701         break;
1702       }
1703     }
1704 
1705     if (called_computation_parameter == nullptr) {
1706       // Clone's operand was not already an operand of the callable instruction.
1707       // Add it as an operand and add a corresponding called computation
1708       // parameter instruction.
1709       called_computation_parameter = AddCallOperand(operand);
1710     }
1711     TF_CHECK_OK(
1712         clone->ReplaceOperandWith(operand_num, called_computation_parameter));
1713   }
1714 
1715   if (add_output) {
1716     CHECK_GT(instruction_to_append->user_count(), 0);
1717     // If this is already a multioutput instruction, expand the root tuple by 1.
1718     HloInstruction* root = called_computation_root();
1719     HloInstruction::InstructionVector tuple_elements;
1720     bool newly_created_tuple_instr = false;
1721     if (root->opcode() == HloOpcode::kTuple) {
1722       tuple_elements = root->operands();
1723     } else {
1724       tuple_elements.push_back(root);
1725       newly_created_tuple_instr = true;
1726     }
1727     if (clone->opcode() == HloOpcode::kTuple) {
1728       for (auto inst : clone->operands()) {
1729         tuple_elements.push_back(inst);
1730       }
1731     } else {
1732       tuple_elements.push_back(clone);
1733     }
1734     HloInstruction* new_root = called_computation()->AddInstruction(
1735         HloInstruction::CreateTuple(tuple_elements));
1736     called_computation()->set_root_instruction(new_root,
1737                                                /*accept_different_shape=*/true);
1738     *mutable_shape() = new_root->shape();
1739     if (root->opcode() == HloOpcode::kTuple) {
1740       TF_CHECK_OK(called_computation()->RemoveInstruction(root));
1741     }
1742 
1743     // If this is a newly created multioutput instruction, we need to update
1744     // the use of the original callable instruction.
1745     if (newly_created_tuple_instr) {
1746       HloInstruction* new_instr = parent()->AddInstruction(
1747           HloInstruction::CreateGetTupleElement(root->shape(), this, 0));
1748       TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr));
1749     }
1750     int64_t index = tuple_elements.size();
1751     if (instruction_to_append->opcode() == HloOpcode::kTuple) {
1752       CHECK_EQ(clone, instruction_to_append);
1753       index -= clone->operand_count();
1754       std::vector<HloInstruction*> to_be_removed;
1755       const auto& users = clone->users();
1756       to_be_removed.reserve(users.size());
1757       for (auto old_gte : users) {
1758         CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
1759         int64_t old_tuple_index = old_gte->tuple_index();
1760         HloInstruction* new_gte =
1761             parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1762                 old_gte->shape(), this, index + old_tuple_index));
1763         TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
1764         to_be_removed.push_back(old_gte);
1765       }
1766       for (auto old_gte : to_be_removed) {
1767         TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
1768       }
1769     } else {
1770       HloInstruction* new_gte =
1771           parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
1772               clone->shape(), this, index - 1));
1773       TF_CHECK_OK(instruction_to_append->ReplaceAllUsesWith(new_gte));
1774     }
1775   }
1776 
1777   if (clone != instruction_to_append) {
1778     VLOG(2) << "New clone:\n" << clone->ToString();
1779   }
1780   return clone;
1781 }
1782 
1783 absl::InlinedVector<HloComputation*, 1>
GetOrCloneCalledComputations(HloCloneContext * context) const1784 HloCallableInstruction::GetOrCloneCalledComputations(
1785     HloCloneContext* context) const {
1786   HloModule* module = context != nullptr ? context->module() : GetModule();
1787   absl::InlinedVector<HloComputation*, 1> new_called_computations;
1788   for (auto* comp : called_computations()) {
1789     HloComputation* new_custom_call_computation = nullptr;
1790     if (context != nullptr) {
1791       new_custom_call_computation = context->FindComputation(comp);
1792     }
1793     if (new_custom_call_computation == nullptr) {
1794       new_custom_call_computation =
1795           module->AddEmbeddedComputation(comp->Clone("clone", context));
1796     }
1797     new_called_computations.push_back(new_custom_call_computation);
1798   }
1799   return new_called_computations;
1800 }
1801 
RecursivelySetComputationsThreadName(absl::string_view execution_thread,bool skip_async_execution_thread_overwrite)1802 void HloCallableInstruction::RecursivelySetComputationsThreadName(
1803     absl::string_view execution_thread,
1804     bool skip_async_execution_thread_overwrite) {
1805   for (HloComputation* comp : called_computations()) {
1806     SetThreadName(comp, execution_thread,
1807                   skip_async_execution_thread_overwrite);
1808   }
1809 }
1810 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1811 HloFusionInstruction::HloFusionInstruction(const Shape& shape,
1812                                            FusionKind fusion_kind,
1813                                            HloInstruction* fused_root)
1814     : HloCallableInstruction(HloOpcode::kFusion, shape),
1815       fusion_kind_(fusion_kind) {
1816   CHECK(fused_root != nullptr);
1817   SetAndSanitizeName(HloOpcodeString(opcode()));
1818   set_parent(fused_root->parent());
1819   set_metadata(fused_root->metadata());
1820   CHECK(fused_root->IsFusible()) << fused_root->ToString();
1821   CloneAndAppendInstructionIntoCalledComputation(fused_root);
1822 }
1823 
HloFusionInstruction(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation,absl::string_view prefix)1824 HloFusionInstruction::HloFusionInstruction(
1825     const Shape& shape, FusionKind fusion_kind,
1826     absl::Span<HloInstruction* const> operands,
1827     HloComputation* fusion_computation, absl::string_view prefix)
1828     : HloCallableInstruction(HloOpcode::kFusion, shape, operands,
1829                              fusion_computation, prefix),
1830       fusion_kind_(fusion_kind) {
1831   fusion_computation->SetFusionInstruction(this);
1832 }
1833 
~HloFusionInstruction()1834 HloFusionInstruction::~HloFusionInstruction() {
1835   ClearFusionComputationInstruction();
1836 }
1837 
ClearFusionComputationInstruction()1838 void HloFusionInstruction::ClearFusionComputationInstruction() {
1839   // Each fusion calls a single computation, but we use called_computations()
1840   // instead of fused_instructions_computation(), because the order in which
1841   // things get destructed can vary; the fusion computation's back-pointer may
1842   // already be null, which violates a check in fused_instructions_computation.
1843   for (HloComputation* computation : called_computations()) {
1844     // Some passes that rewrite fusions may reassign a fusion computation to a
1845     // different fusion instruction as this instruction gets destructed.
1846     if (computation->FusionInstruction() == this) {
1847       computation->SetFusionInstruction(nullptr);
1848     }
1849   }
1850 }
1851 
ClearCalledComputations()1852 void HloFusionInstruction::ClearCalledComputations() {
1853   ClearFusionComputationInstruction();
1854   HloInstruction::ClearCalledComputations();
1855 }
1856 
ToCategory() const1857 std::string HloFusionInstruction::ToCategory() const {
1858   switch (fusion_kind()) {
1859     case FusionKind::kLoop:
1860       return "loop fusion";
1861     case FusionKind::kInput:
1862       return "input fusion";
1863     case FusionKind::kOutput:
1864       return "output fusion";
1865     case FusionKind::kCustom:
1866       return "custom fusion";
1867   }
1868 }
1869 
ToProto() const1870 HloInstructionProto HloFusionInstruction::ToProto() const {
1871   HloInstructionProto proto = HloInstruction::ToProto();
1872   proto.set_fusion_kind(xla::ToString(fusion_kind()));
1873   proto.add_called_computation_ids(
1874       fused_instructions_computation()->unique_id());
1875   return proto;
1876 }
1877 
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const1878 bool HloFusionInstruction::IsElementwiseImpl(
1879     const std::optional<int64_t>& operand_idx) const {
1880   if (!operand_idx.has_value()) {
1881     for (auto* fused : fused_instructions()) {
1882       if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
1883         return false;
1884       }
1885     }
1886     return true;
1887   }
1888   // A loop-fusion is elementwise on an operand if all operations (computed
1889   // using BFS) between the operand and the fused root are elementwise.
1890   std::deque<HloInstruction*> worklist;
1891   absl::flat_hash_set<const HloInstruction*> visited;
1892   worklist.push_back(fused_parameter(operand_idx.value()));
1893   visited.insert(fused_parameter(operand_idx.value()));
1894   while (!worklist.empty()) {
1895     HloInstruction* operand = worklist.front();
1896     worklist.pop_front();
1897     for (HloInstruction* user : operand->users()) {
1898       CHECK_GE(user->unique_id(), 0);
1899       if (ContainsKey(visited, user)) {
1900         continue;
1901       }
1902       if (user->IsElementwise() ||
1903           IsInstructionElementwiseOnOperand(user, operand)) {
1904         worklist.push_back(user);
1905         visited.insert(user);
1906       } else {
1907         return false;
1908       }
1909     }
1910   }
1911   return true;
1912 }
1913 
AddFusionOperand(HloInstruction * new_operand)1914 HloInstruction* HloFusionInstruction::AddFusionOperand(
1915     HloInstruction* new_operand) {
1916   return AddCallOperand(new_operand);
1917 }
1918 
MergeFusionInstruction(HloFusionInstruction * instruction_to_merge)1919 void HloFusionInstruction::MergeFusionInstruction(
1920     HloFusionInstruction* instruction_to_merge) {
1921   CHECK(absl::c_linear_search(operands(), instruction_to_merge));
1922   // Clone the instruction from which to merge fused instructions.
1923   std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
1924   HloFusionInstruction* cloned_fusion =
1925       static_cast<HloFusionInstruction*>(cloned.get());
1926   // Replace uses of fused parameters with the corresponding operand of the
1927   // fusion.  Add all non-parameter fused instructions to
1928   // 'unfused_instructions' to be merged into 'this'.  This is done in reverse
1929   // post order.
1930   std::vector<HloInstruction*> unfused_instructions;
1931   auto fused_instructions = cloned_fusion->fused_instructions_computation()
1932                                 ->MakeInstructionPostOrder();
1933   for (auto fused_it = fused_instructions.rbegin();
1934        fused_it != fused_instructions.rend(); ++fused_it) {
1935     auto fused_instruction = *fused_it;
1936     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1937       TF_CHECK_OK(
1938           fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
1939               fused_instruction->parameter_number())));
1940     } else {
1941       unfused_instructions.push_back(fused_instruction);
1942     }
1943   }
1944 
1945   // If there are no unfused instructions, the fused computation must consist
1946   // only of kParameter instructions. Make the operand of the corresponding
1947   // parameter number the new root.
1948   HloInstruction* unfused_root =
1949       unfused_instructions.empty()
1950           ? instruction_to_merge->mutable_operand(
1951                 instruction_to_merge->fused_instructions_computation()
1952                     ->root_instruction()
1953                     ->parameter_number())
1954           : unfused_instructions.front();
1955   CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
1956         unfused_instructions.empty());
1957   // Replace instruction_to_merge use of 'this' with unfused_root.
1958   TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
1959 
1960   // Build a dummy root for the cloned fusion as we may remove the original root
1961   // in the fusion process.
1962   if (!unfused_instructions.empty()) {
1963     HloComputation* computation = unfused_root->parent();
1964     auto* dummy_root = computation->AddInstruction(
1965         HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
1966     computation->set_root_instruction(dummy_root,
1967                                       /*accept_different_shape=*/true);
1968   }
1969 
1970   // Fuse 'unfused_instructions' into 'this'. Everytime we fuse an instruction
1971   // we remove it from the closed fusion node. This is so that we don't add
1972   // extra users to the producer of that instruction (we use user count to
1973   // decide if a side-effectful instruction is fusible).
1974   for (auto& instruction : unfused_instructions) {
1975     auto* fused = FuseInstruction(instruction);
1976     TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
1977     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
1978   }
1979   CHECK_EQ(0, cloned_fusion->user_count());
1980   TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
1981       cloned_fusion->fused_instructions_computation()));
1982 }
1983 
MergeFusionInstructionIntoMultiOutput(HloFusionInstruction * instruction_to_merge)1984 void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
1985     HloFusionInstruction* instruction_to_merge) {
1986   // Add all non-parameter fused instructions to 'unfused_instructions' to be
1987   // merged into 'this'. `old_to_new' maps the instructions in the fused node
1988   // to the disassembled fusion instructions.
1989   // Note that we add the unfused instructions to this->parent_ computation.
1990   // This is necessary because the unique_id needs for an instruction and
1991   // it's only added when inserting to the computation.
1992   absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
1993   std::vector<HloInstruction*> unfused_instructions;
1994   auto computation_to_merge =
1995       instruction_to_merge->fused_instructions_computation();
1996   auto post_order = computation_to_merge->MakeInstructionPostOrder();
1997   for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
1998     auto fused_instruction = *rit;
1999     if (fused_instruction->opcode() == HloOpcode::kParameter) {
2000       InsertOrDie(&old_to_new, fused_instruction,
2001                   instruction_to_merge->mutable_operand(
2002                       fused_instruction->parameter_number()));
2003       continue;
2004     }
2005 
2006     // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
2007     // which clones again. This can be improved.
2008     auto cloned_instruction =
2009         parent()->AddInstruction(fused_instruction->Clone());
2010     unfused_instructions.push_back(cloned_instruction);
2011     InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
2012   }
2013   for (auto unfused_instruction : unfused_instructions) {
2014     for (int64_t index = 0; index < unfused_instruction->operand_count();
2015          index++) {
2016       auto new_operand =
2017           FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
2018       TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
2019     }
2020   }
2021 
2022   // If there are no unfused instructions, the fused computation must consist
2023   // only of kParameter instructions. Make the operand of the corresponding
2024   // parameter number the new root.
2025   HloInstruction* unfused_root =
2026       unfused_instructions.empty()
2027           ? instruction_to_merge->mutable_operand(
2028                 instruction_to_merge->fused_instructions_computation()
2029                     ->root_instruction()
2030                     ->parameter_number())
2031           : unfused_instructions.front();
2032   TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
2033 
2034   TF_CHECK_OK(
2035       instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
2036   if (GetModule()) {
2037     TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
2038   }
2039 
2040   // Fuse the root instruction and generate multiple outputs.
2041   if (unfused_instructions.empty()) {
2042     return;
2043   }
2044   FuseInstructionIntoMultiOutput(unfused_root);
2045   TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
2046   // The rest instructions are of normal fusing.
2047   for (int64_t i = 1; i < unfused_instructions.size(); i++) {
2048     auto instruction = unfused_instructions[i];
2049     FuseInstruction(instruction);
2050     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
2051   }
2052 }
2053 
fused_instructions_computation() const2054 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
2055   CHECK(!called_computations().empty());
2056   auto* fused_instructions_computation = called_computations().front();
2057   CHECK(fused_instructions_computation->IsFusionComputation())
2058       << "Computation " << fused_instructions_computation->name()
2059       << " is not a fusion kind";
2060   return fused_instructions_computation;
2061 }
2062 
fused_expression_root() const2063 HloInstruction* HloFusionInstruction::fused_expression_root() const {
2064   return fused_instructions_computation()->root_instruction();
2065 }
2066 
fused_parameter(int64_t parameter_number) const2067 HloInstruction* HloFusionInstruction::fused_parameter(
2068     int64_t parameter_number) const {
2069   return fused_instructions_computation()->parameter_instruction(
2070       parameter_number);
2071 }
2072 
fused_parameters() const2073 const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
2074     const {
2075   return fused_instructions_computation()->parameter_instructions();
2076 }
2077 
2078 const tensorflow::gtl::iterator_range<UnwrappingIterator<
2079     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const2080 HloFusionInstruction::fused_instructions() const {
2081   const HloComputation* subcomp = fused_instructions_computation();
2082   return subcomp->instructions();
2083 }
2084 
2085 const tensorflow::gtl::iterator_range<
2086     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()2087 HloFusionInstruction::fused_instructions() {
2088   return fused_instructions_computation()->instructions();
2089 }
2090 
fused_instruction_count() const2091 int64_t HloFusionInstruction::fused_instruction_count() const {
2092   return fused_instructions_computation()->instruction_count();
2093 }
2094 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2095 std::vector<std::string> HloFusionInstruction::ExtraAttributesToStringImpl(
2096     const HloPrintOptions& options) const {
2097   return {StrCat("kind=", xla::ToString(fusion_kind()))};
2098 }
2099 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2100 bool HloFusionInstruction::IdenticalSlowPath(
2101     const HloInstruction& other,
2102     const std::function<bool(const HloComputation*, const HloComputation*)>&
2103         eq_computations) const {
2104   return fusion_kind() == other.fusion_kind() &&
2105          eq_computations(fused_instructions_computation(),
2106                          other.fused_instructions_computation());
2107 }
2108 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2109 std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
2110     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2111     HloCloneContext* context) const {
2112   auto new_fused_computation = GetOrCloneCalledComputations(context);
2113   CHECK_EQ(new_fused_computation.size(), 1);
2114   return std::make_unique<HloFusionInstruction>(
2115       shape, fusion_kind(), new_operands, new_fused_computation.front());
2116 }
2117 
DeduplicateFusionOperands()2118 Status HloFusionInstruction::DeduplicateFusionOperands() {
2119   if (IsCustomFusion()) {
2120     return OkStatus();
2121   }
2122   absl::flat_hash_map<const HloInstruction*, int> operand_indices;
2123   std::vector<int> operands_to_remove;
2124   const int count = operand_count();
2125   operands_to_remove.reserve(count);
2126   for (int i = 0; i < count; ++i) {
2127     auto emplace_result = operand_indices.emplace(operand(i), i);
2128     if (!emplace_result.second) {
2129       TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
2130           fused_parameter(emplace_result.first->second)));
2131       operands_to_remove.push_back(i);
2132     }
2133   }
2134   if (operands_to_remove.empty()) {
2135     return OkStatus();
2136   }
2137   TF_RETURN_IF_ERROR(fused_instructions_computation()
2138                          ->RemoveUnusedParametersFromFusedComputation());
2139   RemoveOperandsAtAscendingIndices(operands_to_remove);
2140   return OkStatus();
2141 }
2142 
HloCallInstruction(const Shape & shape,HloInstruction * called_computation_root)2143 HloCallInstruction::HloCallInstruction(const Shape& shape,
2144                                        HloInstruction* called_computation_root)
2145     : HloCallableInstruction(HloOpcode::kCall, shape) {
2146   CHECK(called_computation_root != nullptr);
2147   SetAndSanitizeName(HloOpcodeString(opcode()));
2148   set_parent(called_computation_root->parent());
2149   set_metadata(called_computation_root->metadata());
2150   CloneAndAppendInstructionIntoCalledComputation(called_computation_root);
2151 }
2152 
HloCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * called_computation)2153 HloCallInstruction::HloCallInstruction(
2154     const Shape& shape, absl::Span<HloInstruction* const> operands,
2155     HloComputation* called_computation)
2156     : HloCallableInstruction(HloOpcode::kCall, shape, operands,
2157                              called_computation) {}
2158 
HloRngInstruction(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)2159 HloRngInstruction::HloRngInstruction(
2160     const Shape& shape, RandomDistribution distribution,
2161     absl::Span<HloInstruction* const> parameters)
2162     : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
2163   for (HloInstruction* param : parameters) {
2164     AppendOperand(param);
2165   }
2166 }
2167 
ToProto() const2168 HloInstructionProto HloRngInstruction::ToProto() const {
2169   HloInstructionProto proto = HloInstruction::ToProto();
2170   proto.set_distribution(distribution_);
2171   return proto;
2172 }
2173 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2174 std::vector<std::string> HloRngInstruction::ExtraAttributesToStringImpl(
2175     const HloPrintOptions& options) const {
2176   return {StrCat("distribution=", RandomDistributionToString(distribution_))};
2177 }
2178 
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const2179 bool HloRngInstruction::IsElementwiseImpl(
2180     const std::optional<int64_t>& operand_idx) const {
2181   return true;
2182 }
2183 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2184 bool HloRngInstruction::IdenticalSlowPath(
2185     const HloInstruction& other,
2186     const std::function<bool(const HloComputation*, const HloComputation*)>&
2187         eq_computations) const {
2188   const auto& casted_other = static_cast<const HloRngInstruction&>(other);
2189   return distribution_ == casted_other.distribution_;
2190 }
2191 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2192 std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
2193     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2194     HloCloneContext* context) const {
2195   return std::make_unique<HloRngInstruction>(shape, distribution_,
2196                                              new_operands);
2197 }
2198 
HloParameterInstruction(int64_t parameter_number,const Shape & shape,const std::string & name)2199 HloParameterInstruction::HloParameterInstruction(int64_t parameter_number,
2200                                                  const Shape& shape,
2201                                                  const std::string& name)
2202     : HloInstruction(HloOpcode::kParameter, shape),
2203       parameter_number_(parameter_number) {
2204   SetAndSanitizeName(name);
2205 }
2206 
ToProto() const2207 HloInstructionProto HloParameterInstruction::ToProto() const {
2208   HloInstructionProto proto = HloInstruction::ToProto();
2209   proto.set_parameter_number(parameter_number_);
2210   if (parameter_replicated_at_leaf_buffers_) {
2211     for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2212       proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers(
2213           replicated);
2214     }
2215   }
2216   return proto;
2217 }
2218 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2219 std::vector<std::string> HloParameterInstruction::ExtraAttributesToStringImpl(
2220     const HloPrintOptions& options) const {
2221   std::vector<std::string> result;
2222   if (!parameter_replicated_at_leaf_buffers_) {
2223     return result;
2224   }
2225   std::vector<std::string> buffers_replicated_strs;
2226   buffers_replicated_strs.reserve(
2227       parameter_replicated_at_leaf_buffers_->size());
2228   for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
2229     buffers_replicated_strs.push_back(replicated ? "true" : "false");
2230   }
2231   if (options.print_ids()) {
2232     result.push_back(StrCat("parameter_replication={",
2233                             StrJoin(buffers_replicated_strs, ","), "}"));
2234   }
2235   return result;
2236 }
2237 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2238 std::string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
2239     const HloPrintOptions& options,
2240     CanonicalNameMap* canonical_name_map) const {
2241   return StrCat(parameter_number_);
2242 }
2243 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2244 bool HloParameterInstruction::IdenticalSlowPath(
2245     const HloInstruction& other,
2246     const std::function<bool(const HloComputation*, const HloComputation*)>&
2247         eq_computations) const {
2248   const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
2249   return parameter_number() == casted_other.parameter_number();
2250 }
2251 
2252 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2253 HloParameterInstruction::CloneWithNewOperandsImpl(
2254     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2255     HloCloneContext* context) const {
2256   auto clone = std::make_unique<HloParameterInstruction>(parameter_number_,
2257                                                          shape, name());
2258   if (parameter_replicated_at_leaf_buffers_ &&
2259       ShapeUtil::Equal(shape, this->shape())) {
2260     clone->set_parameter_replicated_at_leaf_buffers(
2261         *parameter_replicated_at_leaf_buffers_);
2262   }
2263   return clone;
2264 }
2265 
HloGetTupleElementInstruction(const Shape & shape,HloInstruction * operand,int64_t index)2266 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
2267     const Shape& shape, HloInstruction* operand, int64_t index)
2268     : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
2269   AppendOperand(operand);
2270 }
2271 
ToProto() const2272 HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
2273   HloInstructionProto proto = HloInstruction::ToProto();
2274   proto.set_tuple_index(tuple_index_);
2275   return proto;
2276 }
2277 
2278 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2279 HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
2280     const HloPrintOptions& options) const {
2281   return {StrCat("index=", tuple_index())};
2282 }
2283 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2284 bool HloGetTupleElementInstruction::IdenticalSlowPath(
2285     const HloInstruction& other,
2286     const std::function<bool(const HloComputation*, const HloComputation*)>&
2287         eq_computations) const {
2288   const auto& casted_other =
2289       static_cast<const HloGetTupleElementInstruction&>(other);
2290   return tuple_index() == casted_other.tuple_index();
2291 }
2292 
2293 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2294 HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
2295     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2296     HloCloneContext* context) const {
2297   CHECK_EQ(new_operands.size(), 1);
2298   return std::make_unique<HloGetTupleElementInstruction>(shape, new_operands[0],
2299                                                          tuple_index());
2300 }
2301 
HloReducePrecisionInstruction(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)2302 HloReducePrecisionInstruction::HloReducePrecisionInstruction(
2303     const Shape& shape, HloInstruction* operand, const int exponent_bits,
2304     const int mantissa_bits)
2305     : HloInstruction(HloOpcode::kReducePrecision, shape),
2306       exponent_bits_(exponent_bits),
2307       mantissa_bits_(mantissa_bits) {
2308   AppendOperand(operand);
2309 }
2310 
ToProto() const2311 HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
2312   HloInstructionProto proto = HloInstruction::ToProto();
2313   proto.set_exponent_bits(exponent_bits_);
2314   proto.set_mantissa_bits(mantissa_bits_);
2315   return proto;
2316 }
2317 
2318 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2319 HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
2320     const HloPrintOptions& options) const {
2321   return {StrCat("exponent_bits=", exponent_bits_),
2322           StrCat("mantissa_bits=", mantissa_bits_)};
2323 }
2324 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2325 bool HloReducePrecisionInstruction::IdenticalSlowPath(
2326     const HloInstruction& other,
2327     const std::function<bool(const HloComputation*, const HloComputation*)>&
2328         eq_computations) const {
2329   const auto& casted_other =
2330       static_cast<const HloReducePrecisionInstruction&>(other);
2331   // A reduce-precision operation is determined by the bit sizes.
2332   return exponent_bits() == casted_other.exponent_bits() &&
2333          mantissa_bits() == casted_other.mantissa_bits();
2334 }
2335 
2336 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2337 HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
2338     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2339     HloCloneContext* context) const {
2340   CHECK_EQ(new_operands.size(), 1);
2341   return std::make_unique<HloReducePrecisionInstruction>(
2342       shape, new_operands[0], exponent_bits(), mantissa_bits());
2343 }
2344 
HloInfeedInstruction(const Shape & infeed_shape,HloInstruction * token_operand,const std::string & config)2345 HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
2346                                            HloInstruction* token_operand,
2347                                            const std::string& config)
2348     : HloInstruction(HloOpcode::kInfeed,
2349                      ShapeUtil::MakeTupleShape(
2350                          {infeed_shape, ShapeUtil::MakeTokenShape()})),
2351       infeed_config_(config) {
2352   AppendOperand(token_operand);
2353 }
2354 
ToProto() const2355 HloInstructionProto HloInfeedInstruction::ToProto() const {
2356   HloInstructionProto proto = HloInstruction::ToProto();
2357   proto.set_infeed_config(infeed_config_);
2358   return proto;
2359 }
2360 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2361 std::vector<std::string> HloInfeedInstruction::ExtraAttributesToStringImpl(
2362     const HloPrintOptions& options) const {
2363   if (!options.print_infeed_outfeed_config() || infeed_config_.empty()) {
2364     return {};
2365   }
2366   return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
2367 }
2368 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2369 bool HloInfeedInstruction::IdenticalSlowPath(
2370     const HloInstruction& other,
2371     const std::function<bool(const HloComputation*, const HloComputation*)>&
2372         eq_computations) const {
2373   // Not yet supported.
2374   return false;
2375 }
2376 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2377 std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
2378     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2379     HloCloneContext* context) const {
2380   CHECK_EQ(new_operands.size(), 1);
2381   return std::make_unique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
2382                                                 infeed_config());
2383 }
2384 
HloOutfeedInstruction(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)2385 HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
2386                                              HloInstruction* operand,
2387                                              HloInstruction* token_operand,
2388                                              absl::string_view outfeed_config)
2389     : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
2390       outfeed_shape_(outfeed_shape),
2391       outfeed_config_(outfeed_config) {
2392   AppendOperand(operand);
2393   AppendOperand(token_operand);
2394 }
2395 
ToProto() const2396 HloInstructionProto HloOutfeedInstruction::ToProto() const {
2397   HloInstructionProto proto = HloInstruction::ToProto();
2398   proto.set_outfeed_config(outfeed_config());
2399   *proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
2400   return proto;
2401 }
2402 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2403 std::vector<std::string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
2404     const HloPrintOptions& options) const {
2405   std::vector<std::string> extra;
2406   extra.push_back(StrCat("outfeed_shape=",
2407                          ShapeUtil::HumanStringWithLayout(outfeed_shape_)));
2408   if (options.print_infeed_outfeed_config() && !outfeed_config_.empty()) {
2409     extra.push_back(
2410         StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
2411   }
2412   return extra;
2413 }
2414 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2415 bool HloOutfeedInstruction::IdenticalSlowPath(
2416     const HloInstruction& other,
2417     const std::function<bool(const HloComputation*, const HloComputation*)>&
2418         eq_computations) const {
2419   // Not yet supported.
2420   return false;
2421 }
2422 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2423 std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
2424     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2425     HloCloneContext* context) const {
2426   CHECK_EQ(new_operands.size(), 2);
2427   return std::make_unique<HloOutfeedInstruction>(
2428       outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
2429 }
2430 
HloConvolutionInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)2431 HloConvolutionInstruction::HloConvolutionInstruction(
2432     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
2433     int64_t feature_group_count, int64_t batch_group_count,
2434     const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
2435     const PrecisionConfig& precision_config)
2436     : HloInstruction(HloOpcode::kConvolution, shape),
2437       feature_group_count_(feature_group_count),
2438       batch_group_count_(batch_group_count),
2439       window_(window),
2440       convolution_dimension_numbers_(dimension_numbers),
2441       precision_config_(precision_config) {
2442   if (window_util::HasBaseDilation(window)) {
2443     SetAndSanitizeName(StrCat(name(), "-base-dilated"));
2444   }
2445   if (window_util::HasWindowDilation(window)) {
2446     SetAndSanitizeName(StrCat(name(), "-window-dilated"));
2447   }
2448   AppendOperand(lhs);
2449   AppendOperand(rhs);
2450 }
2451 
ToCategory() const2452 std::string HloConvolutionInstruction::ToCategory() const {
2453   std::string category = "convolution";
2454   if (window_util::HasBaseDilation(window())) {
2455     category += " base-dilated";
2456   }
2457   if (window_util::HasWindowDilation(window())) {
2458     category += " window-dilated";
2459   }
2460   return category;
2461 }
2462 
ToProto() const2463 HloInstructionProto HloConvolutionInstruction::ToProto() const {
2464   HloInstructionProto proto = HloInstruction::ToProto();
2465   *proto.mutable_window() = window_;
2466   *proto.mutable_convolution_dimension_numbers() =
2467       convolution_dimension_numbers_;
2468   proto.set_feature_group_count(feature_group_count_);
2469   proto.set_batch_group_count(batch_group_count_);
2470   *proto.mutable_precision_config() = precision_config_;
2471   return proto;
2472 }
2473 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2474 std::vector<std::string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
2475     const HloPrintOptions& options) const {
2476   std::vector<std::string> extra;
2477   if (window_.dimensions_size() != 0) {
2478     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2479   }
2480   extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
2481                                             convolution_dimension_numbers_)));
2482   if (feature_group_count_ != 1) {
2483     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2484   }
2485 
2486   if (batch_group_count_ != 1) {
2487     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2488   }
2489 
2490   std::string precision_config_string =
2491       PrecisionConfigToString(precision_config_);
2492   if (!precision_config_string.empty()) {
2493     extra.push_back(precision_config_string);
2494   }
2495   return extra;
2496 }
2497 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2498 bool HloConvolutionInstruction::IdenticalSlowPath(
2499     const HloInstruction& other,
2500     const std::function<bool(const HloComputation*, const HloComputation*)>&
2501         eq_computations) const {
2502   const auto& casted_other =
2503       static_cast<const HloConvolutionInstruction&>(other);
2504   if (feature_group_count_ != other.feature_group_count()) {
2505     return false;
2506   }
2507   if (batch_group_count_ != other.batch_group_count()) {
2508     return false;
2509   }
2510   return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
2511          protobuf_util::ProtobufEquals(
2512              convolution_dimension_numbers(),
2513              casted_other.convolution_dimension_numbers()) &&
2514          protobuf_util::ProtobufEquals(precision_config(),
2515                                        casted_other.precision_config());
2516 }
2517 
2518 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2519 HloConvolutionInstruction::CloneWithNewOperandsImpl(
2520     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2521     HloCloneContext* context) const {
2522   CHECK_EQ(new_operands.size(), 2);
2523   return std::make_unique<HloConvolutionInstruction>(
2524       shape, new_operands[0], new_operands[1], feature_group_count_,
2525       batch_group_count_, window(), convolution_dimension_numbers_,
2526       precision_config_);
2527 }
2528 
HloReduceWindowInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)2529 HloReduceWindowInstruction::HloReduceWindowInstruction(
2530     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
2531     const Window& window, HloComputation* reduce_computation)
2532     : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
2533                                  absl::MakeSpan(&init_value, 1), window,
2534                                  reduce_computation) {}
2535 
HloReduceWindowInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)2536 HloReduceWindowInstruction::HloReduceWindowInstruction(
2537     const Shape& shape, absl::Span<HloInstruction* const> operands,
2538     absl::Span<HloInstruction* const> init_values, const Window& window,
2539     HloComputation* reduce_computation)
2540     : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
2541   for (auto* operand : operands) {
2542     AppendOperand(operand);
2543   }
2544   for (auto* init_value : init_values) {
2545     AppendOperand(init_value);
2546   }
2547   AppendComputation(reduce_computation);
2548 }
2549 
ToProto() const2550 HloInstructionProto HloReduceWindowInstruction::ToProto() const {
2551   HloInstructionProto proto = HloInstruction::ToProto();
2552   *proto.mutable_window() = window_;
2553   return proto;
2554 }
2555 
2556 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2557 HloReduceWindowInstruction::ExtraAttributesToStringImpl(
2558     const HloPrintOptions& options) const {
2559   std::vector<std::string> extra;
2560   if (window_.dimensions_size() != 0) {
2561     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2562   }
2563   return extra;
2564 }
2565 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2566 bool HloReduceWindowInstruction::IdenticalSlowPath(
2567     const HloInstruction& other,
2568     const std::function<bool(const HloComputation*, const HloComputation*)>&
2569         eq_computations) const {
2570   const auto& casted_other =
2571       static_cast<const HloReduceWindowInstruction&>(other);
2572   return eq_computations(to_apply(), casted_other.to_apply()) &&
2573          protobuf_util::ProtobufEquals(window(), casted_other.window());
2574 }
2575 
2576 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2577 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
2578     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2579     HloCloneContext* context) const {
2580   CHECK_EQ(new_operands.size() % 2, 0);
2581   int64_t num_operands = new_operands.size() / 2;
2582   return std::make_unique<HloReduceWindowInstruction>(
2583       shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
2584       absl::MakeSpan(new_operands)
2585           .subspan(num_operands, new_operands.size() / 2),
2586       window(), to_apply());
2587 }
2588 
HloSelectAndScatterInstruction(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)2589 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
2590     const Shape& shape, HloInstruction* operand, HloComputation* select,
2591     const Window& window, HloInstruction* source, HloInstruction* init_value,
2592     HloComputation* scatter)
2593     : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
2594   AppendOperand(operand);
2595   AppendOperand(source);
2596   AppendOperand(init_value);
2597   // Select comes before scatter in the vector.
2598   AppendComputation(select);
2599   AppendComputation(scatter);
2600 }
2601 
ToProto() const2602 HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
2603   HloInstructionProto proto = HloInstruction::ToProto();
2604   *proto.mutable_window() = window_;
2605   return proto;
2606 }
2607 
2608 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2609 HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
2610     const HloPrintOptions& options) const {
2611   std::vector<std::string> extra;
2612   if (window_.dimensions_size() != 0) {
2613     extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
2614   }
2615   return extra;
2616 }
2617 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2618 bool HloSelectAndScatterInstruction::IdenticalSlowPath(
2619     const HloInstruction& other,
2620     const std::function<bool(const HloComputation*, const HloComputation*)>&
2621         eq_computations) const {
2622   const auto& casted_other =
2623       static_cast<const HloSelectAndScatterInstruction&>(other);
2624   return eq_computations(select(), casted_other.select()) &&
2625          eq_computations(scatter(), casted_other.scatter()) &&
2626          protobuf_util::ProtobufEquals(window(), casted_other.window());
2627 }
2628 
2629 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2630 HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
2631     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2632     HloCloneContext* context) const {
2633   CHECK_EQ(new_operands.size(), 3);
2634   return std::make_unique<HloSelectAndScatterInstruction>(
2635       shape, new_operands[0], select(), window(), new_operands[1],
2636       new_operands[2], scatter());
2637 }
2638 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)2639 HloCustomCallInstruction::HloCustomCallInstruction(
2640     const Shape& shape, absl::Span<HloInstruction* const> operands,
2641     absl::string_view custom_call_target, std::string opaque,
2642     CustomCallApiVersion api_version)
2643     : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands),
2644       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2645       feature_group_count_(1),
2646       batch_group_count_(1),
2647       layout_constrained_(false),
2648       padding_type_(PaddingType::PADDING_INVALID),
2649       custom_call_has_side_effect_(false),
2650       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2651       api_version_(api_version) {
2652   set_raw_backend_config_string(std::move(opaque));
2653 }
2654 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)2655 HloCustomCallInstruction::HloCustomCallInstruction(
2656     const Shape& shape, absl::Span<HloInstruction* const> operands,
2657     HloComputation* to_apply, absl::string_view custom_call_target,
2658     std::string opaque, CustomCallApiVersion api_version)
2659     : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands, to_apply),
2660       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2661       feature_group_count_(1),
2662       batch_group_count_(1),
2663       layout_constrained_(false),
2664       padding_type_(PaddingType::PADDING_INVALID),
2665       custom_call_has_side_effect_(false),
2666       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2667       api_version_(api_version) {
2668   set_raw_backend_config_string(std::move(opaque));
2669   to_apply->SetCustomCallInstruction(this);
2670 }
2671 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)2672 HloCustomCallInstruction::HloCustomCallInstruction(
2673     const Shape& shape, absl::Span<HloInstruction* const> operands,
2674     absl::Span<HloComputation* const> called_computations,
2675     absl::string_view custom_call_target, std::string opaque,
2676     CustomCallApiVersion api_version)
2677     : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands,
2678                              called_computations),
2679       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2680       feature_group_count_(1),
2681       batch_group_count_(1),
2682       layout_constrained_(false),
2683       padding_type_(PaddingType::PADDING_INVALID),
2684       custom_call_has_side_effect_(false),
2685       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2686       api_version_(api_version) {
2687   set_raw_backend_config_string(std::move(opaque));
2688   for (auto comp : called_computations) {
2689     comp->SetCustomCallInstruction(this);
2690   }
2691 }
2692 
HloCustomCallInstruction(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,std::string opaque,absl::Span<const Shape> operand_shapes_with_layout,CustomCallApiVersion api_version)2693 HloCustomCallInstruction::HloCustomCallInstruction(
2694     const Shape& shape, absl::Span<HloInstruction* const> operands,
2695     absl::string_view custom_call_target, std::string opaque,
2696     absl::Span<const Shape> operand_shapes_with_layout,
2697     CustomCallApiVersion api_version)
2698     : HloCallableInstruction(HloOpcode::kCustomCall, shape, operands),
2699       custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
2700       feature_group_count_(1),
2701       batch_group_count_(1),
2702       layout_constrained_(true),
2703       padding_type_(PaddingType::PADDING_INVALID),
2704       operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
2705                                   operand_shapes_with_layout.end()),
2706       custom_call_has_side_effect_(false),
2707       custom_call_schedule_(CustomCallSchedule::SCHEDULE_NONE),
2708       api_version_(api_version) {
2709   set_raw_backend_config_string(std::move(opaque));
2710 }
2711 
ToProto() const2712 HloInstructionProto HloCustomCallInstruction::ToProto() const {
2713   HloInstructionProto proto = HloInstruction::ToProto();
2714   if (window_ != nullptr) {
2715     *proto.mutable_window() = *window_;
2716   }
2717   if (convolution_dimension_numbers_ != nullptr) {
2718     *proto.mutable_convolution_dimension_numbers() =
2719         *convolution_dimension_numbers_;
2720   }
2721   proto.set_custom_call_target(custom_call_target_);
2722   proto.set_feature_group_count(feature_group_count_);
2723   proto.set_batch_group_count(batch_group_count_);
2724   *proto.mutable_precision_config() = precision_config_;
2725   proto.set_padding_type(padding_type_);
2726   if (layout_constrained()) {
2727     proto.set_constrain_layout(true);
2728     for (const Shape& shape : operand_shapes_with_layout_) {
2729       *proto.add_operand_shapes_with_layout() = shape.ToProto();
2730     }
2731   }
2732   proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
2733   if (literal_.has_value()) {
2734     *proto.mutable_literal() = literal_->ToProto();
2735   }
2736   for (const auto& pair : output_to_operand_aliasing_) {
2737     auto aliasing = proto.add_custom_call_output_operand_aliasing();
2738     aliasing->set_operand_index(pair.second.first);
2739     for (int64_t index : pair.first) {
2740       aliasing->add_output_shape_index(index);
2741     }
2742     for (int64_t index : pair.second.second) {
2743       aliasing->add_operand_shape_index(index);
2744     }
2745   }
2746   proto.set_custom_call_schedule(custom_call_schedule_);
2747   proto.set_custom_call_api_version(api_version_);
2748   return proto;
2749 }
2750 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2751 std::vector<std::string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
2752     const HloPrintOptions& options) const {
2753   std::vector<std::string> extra;
2754   if (window_ != nullptr) {
2755     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
2756   }
2757   if (convolution_dimension_numbers_ != nullptr) {
2758     extra.push_back(StrCat(
2759         "dim_labels=",
2760         ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
2761   }
2762   if (feature_group_count_ != 1) {
2763     extra.push_back(StrCat("feature_group_count=", feature_group_count_));
2764   }
2765   if (batch_group_count_ != 1) {
2766     extra.push_back(StrCat("batch_group_count=", batch_group_count_));
2767   }
2768   std::string precision_config_string =
2769       PrecisionConfigToString(precision_config_);
2770   if (!precision_config_string.empty()) {
2771     extra.push_back(precision_config_string);
2772   }
2773   if (padding_type_ != PaddingType::PADDING_INVALID) {
2774     extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
2775   }
2776   // By contract, we print the custom call target even if
2777   // options.print_subcomputation_mode() == kOff, because the call target is not
2778   // an HloComputation.
2779   extra.push_back(
2780       StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
2781 
2782   if (layout_constrained()) {
2783     std::vector<std::string> shape_strings;
2784     shape_strings.reserve(operand_shapes_with_layout_.size());
2785     for (const Shape& shape : operand_shapes_with_layout_) {
2786       shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
2787     }
2788     extra.push_back(StrCat("operand_layout_constraints={",
2789                            StrJoin(shape_strings, ", "), "}"));
2790   }
2791   if (custom_call_has_side_effect_) {
2792     extra.push_back("custom_call_has_side_effect=true");
2793   }
2794   if (literal_.has_value()) {
2795     extra.push_back(StrCat("literal=", literal_->ToStringWithLayoutOneline()));
2796   }
2797   if (!output_to_operand_aliasing_.empty()) {
2798     std::vector<std::string> pair_strings;
2799     pair_strings.reserve(output_to_operand_aliasing_.size());
2800     for (const auto& pair : output_to_operand_aliasing_) {
2801       pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
2802                                     pair.second.first, ", ",
2803                                     pair.second.second.ToString(), ")"));
2804     }
2805     extra.push_back(StrCat("output_to_operand_aliasing={",
2806                            StrJoin(pair_strings, ", "), "}"));
2807   }
2808   if (custom_call_schedule_ != CustomCallSchedule::SCHEDULE_NONE) {
2809     extra.push_back(
2810         StrCat("schedule=", CustomCallSchedule_Name(custom_call_schedule_)));
2811   }
2812   if (api_version_ != CustomCallApiVersion::API_VERSION_ORIGINAL) {
2813     extra.push_back(
2814         StrCat("api_version=", CustomCallApiVersion_Name(api_version_)));
2815   }
2816   return extra;
2817 }
2818 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2819 bool HloCustomCallInstruction::IdenticalSlowPath(
2820     const HloInstruction& other,
2821     const std::function<bool(const HloComputation*, const HloComputation*)>&
2822         eq_computations) const {
2823   const auto& casted_other =
2824       static_cast<const HloCustomCallInstruction&>(other);
2825   if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
2826       (window_ != nullptr &&
2827        !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
2828     return false;
2829   }
2830   if ((convolution_dimension_numbers_ == nullptr) !=
2831           (casted_other.convolution_dimension_numbers_ == nullptr) ||
2832       (convolution_dimension_numbers_ != nullptr &&
2833        !protobuf_util::ProtobufEquals(
2834            convolution_dimension_numbers(),
2835            casted_other.convolution_dimension_numbers()))) {
2836     return false;
2837   }
2838   if (feature_group_count_ != casted_other.feature_group_count_) {
2839     return false;
2840   }
2841   if (batch_group_count_ != casted_other.batch_group_count_) {
2842     return false;
2843   }
2844 
2845   if (padding_type_ != casted_other.padding_type()) {
2846     return false;
2847   }
2848 
2849   if (layout_constrained() != casted_other.layout_constrained()) {
2850     return false;
2851   }
2852   if (layout_constrained()) {
2853     for (int64_t i = 0; i < operand_shapes_with_layout_.size(); ++i) {
2854       if (!ShapeUtil::Equal(operand_shapes_with_layout_[i],
2855                             casted_other.operand_shapes_with_layout_[i])) {
2856         return false;
2857       }
2858     }
2859   }
2860   if (custom_call_has_side_effect_ !=
2861       casted_other.custom_call_has_side_effect()) {
2862     return false;
2863   }
2864   if (output_to_operand_aliasing_ !=
2865       casted_other.output_to_operand_aliasing()) {
2866     return false;
2867   }
2868   if (!protobuf_util::ProtobufEquals(precision_config(),
2869                                      casted_other.precision_config())) {
2870     return false;
2871   }
2872 
2873   if (called_computations().size() != other.called_computations().size()) {
2874     return false;
2875   }
2876   for (int64_t i = 0; i < called_computations().size(); ++i) {
2877     if (!eq_computations(called_computations()[i],
2878                          other.called_computations()[i])) {
2879       return false;
2880     }
2881   }
2882   if (custom_call_schedule_ != casted_other.custom_call_schedule()) {
2883     return false;
2884   }
2885   if (HasLiteral() != casted_other.HasLiteral()) {
2886     return false;
2887   }
2888   if (HasLiteral() && literal() != casted_other.literal()) {
2889     return false;
2890   }
2891   if (api_version_ != casted_other.api_version_) {
2892     return false;
2893   }
2894   // Note: backend_config comparison is done in Identical, which is the
2895   // intended/exposed way to compare computations, and so not repeated here.
2896   return custom_call_target_ == casted_other.custom_call_target_;
2897 }
2898 
2899 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2900 HloCustomCallInstruction::CloneWithNewOperandsImpl(
2901     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2902     HloCloneContext* context) const {
2903   absl::InlinedVector<HloComputation*, 1> new_called_computations =
2904       GetOrCloneCalledComputations(context);
2905 
2906   auto cloned = std::make_unique<HloCustomCallInstruction>(
2907       shape, new_operands, new_called_computations, custom_call_target(),
2908       opaque(), api_version_);
2909   if (layout_constrained()) {
2910     cloned->layout_constrained_ = true;
2911     cloned->operand_shapes_with_layout_ = operand_shapes_with_layout();
2912   }
2913   if (window_ != nullptr) {
2914     cloned->set_window(*window_);
2915   }
2916   if (convolution_dimension_numbers_ != nullptr) {
2917     cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
2918   }
2919   if (HasLiteral()) {
2920     cloned->set_literal(literal().Clone());
2921   }
2922   cloned->set_feature_group_count(feature_group_count_);
2923   cloned->set_batch_group_count(batch_group_count_);
2924   cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
2925   cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
2926   cloned->set_padding_type(padding_type_);
2927   *cloned->mutable_precision_config() = precision_config();
2928   cloned->set_custom_call_schedule(custom_call_schedule_);
2929   return std::move(cloned);
2930 }
2931 
HloPadInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)2932 HloPadInstruction::HloPadInstruction(const Shape& shape,
2933                                      HloInstruction* operand,
2934                                      HloInstruction* padding_value,
2935                                      const PaddingConfig& padding_config)
2936     : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
2937   AppendOperand(operand);
2938   AppendOperand(padding_value);
2939 }
2940 
ToProto() const2941 HloInstructionProto HloPadInstruction::ToProto() const {
2942   HloInstructionProto proto = HloInstruction::ToProto();
2943   *proto.mutable_padding_config() = padding_config_;
2944   return proto;
2945 }
2946 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const2947 std::vector<std::string> HloPadInstruction::ExtraAttributesToStringImpl(
2948     const HloPrintOptions& options) const {
2949   return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
2950 }
2951 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2952 bool HloPadInstruction::IdenticalSlowPath(
2953     const HloInstruction& other,
2954     const std::function<bool(const HloComputation*, const HloComputation*)>&
2955         eq_computations) const {
2956   const auto& casted_other = static_cast<const HloPadInstruction&>(other);
2957   return protobuf_util::ProtobufEquals(padding_config(),
2958                                        casted_other.padding_config());
2959 }
2960 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const2961 std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
2962     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2963     HloCloneContext* context) const {
2964   CHECK_EQ(new_operands.size(), 2);
2965   return std::make_unique<HloPadInstruction>(shape, new_operands[0],
2966                                              new_operands[1], padding_config_);
2967 }
2968 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64_t> slice_sizes)2969 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2970     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
2971     absl::Span<const int64_t> slice_sizes)
2972     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2973       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2974   AppendOperand(operand);
2975   AppendOperand(start_indices);
2976 }
2977 
HloDynamicSliceInstruction(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64_t> slice_sizes)2978 HloDynamicSliceInstruction::HloDynamicSliceInstruction(
2979     const Shape& shape, HloInstruction* operand,
2980     absl::Span<HloInstruction* const> start_indices,
2981     absl::Span<const int64_t> slice_sizes)
2982     : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
2983       dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
2984   AppendOperand(operand);
2985   for (HloInstruction* index : start_indices) {
2986     AppendOperand(index);
2987   }
2988 }
2989 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)2990 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
2991     const Shape& shape, HloInstruction* operand, HloInstruction* update,
2992     HloInstruction* start_indices)
2993     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
2994   AppendOperand(operand);
2995   AppendOperand(update);
2996   AppendOperand(start_indices);
2997 }
2998 
HloDynamicUpdateSliceInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)2999 HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
3000     const Shape& shape, HloInstruction* operand, HloInstruction* update,
3001     absl::Span<HloInstruction* const> start_indices)
3002     : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
3003   AppendOperand(operand);
3004   AppendOperand(update);
3005   for (HloInstruction* index : start_indices) {
3006     AppendOperand(index);
3007   }
3008 }
3009 
ToProto() const3010 HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
3011   HloInstructionProto proto = HloInstruction::ToProto();
3012   for (int64_t slice_size : dynamic_slice_sizes_) {
3013     proto.add_dynamic_slice_sizes(slice_size);
3014   }
3015   return proto;
3016 }
3017 
3018 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3019 HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
3020     const HloPrintOptions& options) const {
3021   return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
3022                  "}")};
3023 }
3024 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3025 bool HloDynamicSliceInstruction::IdenticalSlowPath(
3026     const HloInstruction& other,
3027     const std::function<bool(const HloComputation*, const HloComputation*)>&
3028         eq_computations) const {
3029   const auto& casted_other = static_cast<const HloMapInstruction&>(other);
3030   return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
3031 }
3032 
3033 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3034 HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
3035     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3036     HloCloneContext* context) const {
3037   if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) {
3038     // TODO(b/118437727): Old form, remove this path.
3039     return std::make_unique<HloDynamicSliceInstruction>(
3040         shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
3041   } else {
3042     return std::make_unique<HloDynamicSliceInstruction>(
3043         shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_);
3044   }
3045 }
3046 
HloGatherInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64_t> slice_sizes,bool indices_are_sorted)3047 HloGatherInstruction::HloGatherInstruction(
3048     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
3049     const GatherDimensionNumbers& gather_dim_numbers,
3050     absl::Span<const int64_t> slice_sizes, bool indices_are_sorted)
3051     : HloInstruction(HloOpcode::kGather, shape),
3052       indices_are_sorted_(indices_are_sorted) {
3053   AppendOperand(operand);
3054   AppendOperand(start_indices);
3055   gather_dimension_numbers_ =
3056       std::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
3057   absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
3058 }
3059 
GatherDimensionNumbersToString(const GatherDimensionNumbers & gather_dimension_numbers)3060 /*static*/ std::string HloGatherInstruction::GatherDimensionNumbersToString(
3061     const GatherDimensionNumbers& gather_dimension_numbers) {
3062   std::string offset_dims =
3063       StrCat("offset_dims={",
3064              StrJoin(gather_dimension_numbers.offset_dims(), ","), "}");
3065   std::string collapsed_slice_dims = StrCat(
3066       "collapsed_slice_dims={",
3067       StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}");
3068   std::string start_index_map =
3069       StrCat("start_index_map={",
3070              StrJoin(gather_dimension_numbers.start_index_map(), ","), "}");
3071   std::string index_vector_dim =
3072       StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim());
3073 
3074   return StrJoin<std::initializer_list<std::string>>(
3075       {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
3076       ", ");
3077 }
3078 
MakeGatherDimNumbers(absl::Span<const int64_t> offset_dims,absl::Span<const int64_t> collapsed_slice_dims,absl::Span<const int64_t> start_index_map,int64_t index_vector_dim)3079 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
3080     absl::Span<const int64_t> offset_dims,
3081     absl::Span<const int64_t> collapsed_slice_dims,
3082     absl::Span<const int64_t> start_index_map, int64_t index_vector_dim) {
3083   GatherDimensionNumbers gather_dim_numbers;
3084   for (int64_t output_window_dim : offset_dims) {
3085     gather_dim_numbers.add_offset_dims(output_window_dim);
3086   }
3087   for (int64_t elided_window_dim : collapsed_slice_dims) {
3088     gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
3089   }
3090   for (int64_t gather_dim_to_input_dim : start_index_map) {
3091     gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
3092   }
3093 
3094   gather_dim_numbers.set_index_vector_dim(index_vector_dim);
3095   return gather_dim_numbers;
3096 }
3097 
ToProto() const3098 HloInstructionProto HloGatherInstruction::ToProto() const {
3099   HloInstructionProto proto = HloInstruction::ToProto();
3100   *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
3101   for (int64_t bound : gather_slice_sizes()) {
3102     proto.add_gather_slice_sizes(bound);
3103   }
3104   proto.set_indices_are_sorted(indices_are_sorted());
3105   return proto;
3106 }
3107 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3108 std::vector<std::string> HloGatherInstruction::ExtraAttributesToStringImpl(
3109     const HloPrintOptions& options) const {
3110   std::vector<std::string> attrs{
3111       GatherDimensionNumbersToString(gather_dimension_numbers()),
3112       StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
3113   if (indices_are_sorted()) {
3114     attrs.push_back("indices_are_sorted=true");
3115   }
3116   return attrs;
3117 }
3118 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3119 bool HloGatherInstruction::IdenticalSlowPath(
3120     const HloInstruction& other,
3121     const std::function<bool(const HloComputation*, const HloComputation*)>&
3122         eq_computations) const {
3123   const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
3124   return protobuf_util::ProtobufEquals(
3125              gather_dimension_numbers(),
3126              casted_other.gather_dimension_numbers()) &&
3127          gather_slice_sizes() == casted_other.gather_slice_sizes() &&
3128          indices_are_sorted() == casted_other.indices_are_sorted();
3129 }
3130 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3131 std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
3132     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3133     HloCloneContext* context) const {
3134   CHECK_EQ(new_operands.size(), 2);
3135   return std::make_unique<HloGatherInstruction>(
3136       shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
3137       gather_slice_sizes(), indices_are_sorted());
3138 }
3139 
HloScatterInstruction(const Shape & shape,absl::Span<HloInstruction * const> args,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)3140 HloScatterInstruction::HloScatterInstruction(
3141     const Shape& shape, absl::Span<HloInstruction* const> args,
3142     HloComputation* update_computation,
3143     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
3144     bool unique_indices)
3145     : HloInstruction(HloOpcode::kScatter, shape),
3146       indices_are_sorted_(indices_are_sorted),
3147       unique_indices_(unique_indices) {
3148   mutable_operands().reserve(args.size());
3149   for (HloInstruction* arg : args) {
3150     AppendOperand(arg);
3151   }
3152   AppendComputation(update_computation);
3153   scatter_dimension_numbers_ =
3154       std::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
3155 }
3156 
ScatterDimensionNumbersToString(const ScatterDimensionNumbers & scatter_dimension_numbers)3157 /*static*/ std::string HloScatterInstruction::ScatterDimensionNumbersToString(
3158     const ScatterDimensionNumbers& scatter_dimension_numbers) {
3159   std::string update_window_dims =
3160       StrCat("update_window_dims={",
3161              StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}");
3162   std::string inserted_window_dims = StrCat(
3163       "inserted_window_dims={",
3164       StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}");
3165   std::string scatter_dims_to_operand_dims = StrCat(
3166       "scatter_dims_to_operand_dims={",
3167       StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","),
3168       "}");
3169   std::string index_vector_dim =
3170       StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim());
3171 
3172   return StrJoin<std::initializer_list<std::string>>(
3173       {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
3174        index_vector_dim},
3175       ", ");
3176 }
3177 
3178 /* static */ ScatterDimensionNumbers
MakeScatterDimNumbers(absl::Span<const int64_t> update_window_dims,absl::Span<const int64_t> inserted_window_dims,absl::Span<const int64_t> scatter_dims_to_operand_dims,int64_t index_vector_dim)3179 HloScatterInstruction::MakeScatterDimNumbers(
3180     absl::Span<const int64_t> update_window_dims,
3181     absl::Span<const int64_t> inserted_window_dims,
3182     absl::Span<const int64_t> scatter_dims_to_operand_dims,
3183     int64_t index_vector_dim) {
3184   ScatterDimensionNumbers scatter_dim_numbers;
3185   for (int64_t update_window_dim : update_window_dims) {
3186     scatter_dim_numbers.add_update_window_dims(update_window_dim);
3187   }
3188   for (int64_t inserted_window_dim : inserted_window_dims) {
3189     scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
3190   }
3191   for (int64_t scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
3192     scatter_dim_numbers.add_scatter_dims_to_operand_dims(
3193         scatter_dim_to_operand_dim);
3194   }
3195   scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
3196   return scatter_dim_numbers;
3197 }
3198 
ToProto() const3199 HloInstructionProto HloScatterInstruction::ToProto() const {
3200   HloInstructionProto proto = HloInstruction::ToProto();
3201   *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
3202   proto.set_indices_are_sorted(indices_are_sorted());
3203   proto.set_unique_indices(unique_indices());
3204   return proto;
3205 }
3206 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3207 std::vector<std::string> HloScatterInstruction::ExtraAttributesToStringImpl(
3208     const HloPrintOptions& options) const {
3209   std::vector<std::string> attrs{
3210       ScatterDimensionNumbersToString(scatter_dimension_numbers())};
3211   if (indices_are_sorted()) {
3212     attrs.push_back("indices_are_sorted=true");
3213   }
3214   if (unique_indices()) {
3215     attrs.push_back("unique_indices=true");
3216   }
3217   return attrs;
3218 }
3219 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3220 bool HloScatterInstruction::IdenticalSlowPath(
3221     const HloInstruction& other,
3222     const std::function<bool(const HloComputation*, const HloComputation*)>&
3223         eq_computations) const {
3224   const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
3225   return protobuf_util::ProtobufEquals(
3226              scatter_dimension_numbers(),
3227              casted_other.scatter_dimension_numbers()) &&
3228          eq_computations(to_apply(), casted_other.to_apply()) &&
3229          indices_are_sorted() == casted_other.indices_are_sorted() &&
3230          unique_indices() == casted_other.unique_indices();
3231 }
3232 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3233 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
3234     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3235     HloCloneContext* context) const {
3236   return std::make_unique<HloScatterInstruction>(
3237       shape, new_operands, to_apply(), scatter_dimension_numbers(),
3238       indices_are_sorted(), unique_indices());
3239 }
3240 
HloIotaInstruction(const Shape & shape,int64_t iota_dimension)3241 HloIotaInstruction::HloIotaInstruction(const Shape& shape,
3242                                        int64_t iota_dimension)
3243     : HloInstruction(HloOpcode::kIota, shape),
3244       iota_dimension_(iota_dimension) {}
3245 
ToProto() const3246 HloInstructionProto HloIotaInstruction::ToProto() const {
3247   HloInstructionProto proto = HloInstruction::ToProto();
3248   proto.add_dimensions(iota_dimension());
3249   return proto;
3250 }
3251 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3252 std::vector<std::string> HloIotaInstruction::ExtraAttributesToStringImpl(
3253     const HloPrintOptions& options) const {
3254   return {StrCat("iota_dimension=", iota_dimension())};
3255 }
3256 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3257 bool HloIotaInstruction::IdenticalSlowPath(
3258     const HloInstruction& other,
3259     const std::function<bool(const HloComputation*, const HloComputation*)>&
3260         eq_computations) const {
3261   const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
3262   return iota_dimension() == casted_other.iota_dimension();
3263 }
3264 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3265 std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
3266     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3267     HloCloneContext* context) const {
3268   return std::make_unique<HloIotaInstruction>(shape, iota_dimension());
3269 }
3270 
HloDotInstruction(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)3271 HloDotInstruction::HloDotInstruction(
3272     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
3273     const DotDimensionNumbers& dimension_numbers,
3274     const PrecisionConfig& precision_config)
3275     : HloInstruction(HloOpcode::kDot, shape),
3276       dot_dimension_numbers_(dimension_numbers),
3277       precision_config_(precision_config) {
3278   AppendOperand(lhs);
3279   AppendOperand(rhs);
3280 }
3281 
ToProto() const3282 HloInstructionProto HloDotInstruction::ToProto() const {
3283   HloInstructionProto proto = HloInstruction::ToProto();
3284   *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
3285   *proto.mutable_precision_config() = precision_config_;
3286   return proto;
3287 }
3288 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3289 std::vector<std::string> HloDotInstruction::ExtraAttributesToStringImpl(
3290     const HloPrintOptions& options) const {
3291   std::vector<std::string> extra = {
3292       DotDimensionNumbersToString(dot_dimension_numbers_)};
3293 
3294   std::string precision_config_string =
3295       PrecisionConfigToString(precision_config_);
3296   if (!precision_config_string.empty()) {
3297     extra.push_back(precision_config_string);
3298   }
3299   return extra;
3300 }
3301 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3302 bool HloDotInstruction::IdenticalSlowPath(
3303     const HloInstruction& other,
3304     const std::function<bool(const HloComputation*, const HloComputation*)>&
3305         eq_computations) const {
3306   const auto& casted_other = static_cast<const HloDotInstruction&>(other);
3307   return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
3308                                        casted_other.dot_dimension_numbers()) &&
3309          protobuf_util::ProtobufEquals(precision_config(),
3310                                        casted_other.precision_config());
3311 }
3312 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3313 std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
3314     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3315     HloCloneContext* context) const {
3316   CHECK_EQ(new_operands.size(), 2);
3317   return std::make_unique<HloDotInstruction>(
3318       shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
3319       precision_config_);
3320 }
3321 
HloDomainInstruction(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)3322 HloDomainInstruction::HloDomainInstruction(
3323     const Shape& shape, HloInstruction* operand,
3324     std::unique_ptr<DomainMetadata> operand_side_metadata,
3325     std::unique_ptr<DomainMetadata> user_side_metadata)
3326     : HloInstruction(HloOpcode::kDomain, shape),
3327       operand_side_metadata_(std::move(operand_side_metadata)),
3328       user_side_metadata_(std::move(user_side_metadata)) {
3329   AppendOperand(operand);
3330 }
3331 
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3332 std::vector<std::string> HloDomainInstruction::ExtraAttributesToStringImpl(
3333     const HloPrintOptions& options) const {
3334   if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
3335     return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
3336                    "\", entry=", user_side_metadata_->ToString(),
3337                    ", exit=", operand_side_metadata_->ToString(), "}")};
3338   }
3339   return {};
3340 }
3341 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3342 bool HloDomainInstruction::IdenticalSlowPath(
3343     const HloInstruction& other,
3344     const std::function<bool(const HloComputation*, const HloComputation*)>&
3345         eq_computations) const {
3346   const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
3347   return operand_side_metadata().Matches(
3348              casted_other.operand_side_metadata()) &&
3349          user_side_metadata().Matches(casted_other.user_side_metadata());
3350 }
3351 
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const3352 std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
3353     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3354     HloCloneContext* context) const {
3355   CHECK_EQ(new_operands.size(), 1);
3356   return std::make_unique<HloDomainInstruction>(shape, new_operands[0],
3357                                                 operand_side_metadata_->Clone(),
3358                                                 user_side_metadata_->Clone());
3359 }
3360 
ToProto() const3361 HloInstructionProto HloDomainInstruction::ToProto() const {
3362   HloInstructionProto proto = HloInstruction::ToProto();
3363   auto operand_side_sharding =
3364       dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
3365   if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) {
3366     *proto.mutable_domain_entry_sharding() =
3367         operand_side_sharding->sharding()->ToProto();
3368   }
3369 
3370   auto user_side_sharding =
3371       dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
3372   if (user_side_sharding && user_side_sharding->sharding() != nullptr) {
3373     *proto.mutable_domain_exit_sharding() =
3374         user_side_sharding->sharding()->ToProto();
3375   }
3376 
3377   return proto;
3378 }
3379 
HloGetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,int64_t dimension)3380 HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction(
3381     const Shape& shape, HloInstruction* operand, int64_t dimension)
3382     : HloInstruction(HloOpcode::kGetDimensionSize, shape),
3383       dimension_(dimension) {
3384   AppendOperand(operand);
3385 }
3386 
ToProto() const3387 HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const {
3388   HloInstructionProto proto = HloInstruction::ToProto();
3389   proto.add_dimensions(dimension());
3390   return proto;
3391 }
3392 
3393 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3394 HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3395     const HloPrintOptions& /*options*/) const {
3396   return {StrCat("dimensions={", dimension(), "}")};
3397 }
3398 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3399 bool HloGetDimensionSizeInstruction::IdenticalSlowPath(
3400     const HloInstruction& other,
3401     const std::function<bool(const HloComputation*, const HloComputation*)>&
3402     /*eq_computations*/) const {
3403   const auto& casted_other =
3404       static_cast<const HloGetDimensionSizeInstruction&>(other);
3405   return dimension() == casted_other.dimension();
3406 }
3407 
3408 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3409 HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3410     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3411     HloCloneContext* /*context*/) const {
3412   if (new_operands.size() != 1) {
3413     LOG(FATAL) << "expects 1 operand";
3414   }
3415   return std::make_unique<HloGetDimensionSizeInstruction>(
3416       shape, new_operands[0], dimension());
3417 }
3418 
HloSetDimensionSizeInstruction(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)3419 HloSetDimensionSizeInstruction::HloSetDimensionSizeInstruction(
3420     const Shape& shape, HloInstruction* operand, HloInstruction* val,
3421     int64_t dimension)
3422     : HloInstruction(HloOpcode::kSetDimensionSize, shape),
3423       dimension_(dimension) {
3424   AppendOperand(operand);
3425   AppendOperand(val);
3426 }
3427 
3428 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3429 HloSetDimensionSizeInstruction::ExtraAttributesToStringImpl(
3430     const HloPrintOptions& /*options*/) const {
3431   return {StrCat("dimensions={", dimension(), "}")};
3432 }
3433 
ToProto() const3434 HloInstructionProto HloSetDimensionSizeInstruction::ToProto() const {
3435   HloInstructionProto proto = HloInstruction::ToProto();
3436   proto.add_dimensions(dimension());
3437   return proto;
3438 }
3439 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3440 bool HloSetDimensionSizeInstruction::IdenticalSlowPath(
3441     const HloInstruction& other,
3442     const std::function<bool(const HloComputation*, const HloComputation*)>&
3443     /*eq_computations*/) const {
3444   const auto& casted_other =
3445       static_cast<const HloSetDimensionSizeInstruction&>(other);
3446   return dimension() == casted_other.dimension();
3447 }
3448 
3449 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3450 HloSetDimensionSizeInstruction::CloneWithNewOperandsImpl(
3451     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3452     HloCloneContext* /*context*/) const {
3453   if (new_operands.size() != 2) {
3454     LOG(FATAL) << "expects 2 operand";
3455   }
3456   return std::make_unique<HloSetDimensionSizeInstruction>(
3457       shape, new_operands[0], new_operands[1], dimension());
3458 }
3459 
HloRngGetAndUpdateStateInstruction(const Shape & shape,int64_t delta)3460 HloRngGetAndUpdateStateInstruction::HloRngGetAndUpdateStateInstruction(
3461     const Shape& shape, int64_t delta)
3462     : HloInstruction(HloOpcode::kRngGetAndUpdateState, shape), delta_(delta) {}
3463 
ToProto() const3464 HloInstructionProto HloRngGetAndUpdateStateInstruction::ToProto() const {
3465   HloInstructionProto proto = HloInstruction::ToProto();
3466   proto.set_delta(delta_);
3467   return proto;
3468 }
3469 
3470 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions &) const3471 HloRngGetAndUpdateStateInstruction::ExtraAttributesToStringImpl(
3472     const HloPrintOptions& /*options*/) const {
3473   return {StrCat("delta=", delta())};
3474 }
3475 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> &) const3476 bool HloRngGetAndUpdateStateInstruction::IdenticalSlowPath(
3477     const HloInstruction& other,
3478     const std::function<bool(const HloComputation*, const HloComputation*)>&
3479     /*eq_computations*/) const {
3480   const auto& casted_other =
3481       static_cast<const HloRngGetAndUpdateStateInstruction&>(other);
3482   return delta() == casted_other.delta();
3483 }
3484 
3485 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3486 HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
3487     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3488     HloCloneContext* /*context*/) const {
3489   if (!new_operands.empty()) {
3490     LOG(FATAL) << "expects 0 operand";
3491   }
3492   return std::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
3493 }
3494 
HloRngBitGeneratorInstruction(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)3495 HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
3496     const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
3497     : HloInstruction(HloOpcode::kRngBitGenerator, shape),
3498       algorithm_(algorithm) {
3499   AppendOperand(state);
3500 }
3501 
ToProto() const3502 HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
3503   HloInstructionProto proto = HloInstruction::ToProto();
3504   proto.set_rng_algorithm(algorithm_);
3505   return proto;
3506 }
3507 
3508 std::vector<std::string>
ExtraAttributesToStringImpl(const HloPrintOptions & options) const3509 HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
3510     const HloPrintOptions& options) const {
3511   return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
3512 }
3513 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const3514 bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
3515     const HloInstruction& other,
3516     const std::function<bool(const HloComputation*, const HloComputation*)>&
3517         eq_computations) const {
3518   const auto& casted_other =
3519       static_cast<const HloRngBitGeneratorInstruction&>(other);
3520   return algorithm() == casted_other.algorithm();
3521 }
3522 
3523 std::unique_ptr<HloInstruction>
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext *) const3524 HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
3525     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
3526     HloCloneContext* /*context*/) const {
3527   CHECK_EQ(new_operands.size(), 1);
3528   return std::make_unique<HloRngBitGeneratorInstruction>(shape, new_operands[0],
3529                                                          algorithm());
3530 }
3531 
3532 }  // namespace xla
3533