• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <iostream>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <set>
25 #include <string>
26 #include <utility>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/container/inlined_vector.h"
32 #include "absl/strings/ascii.h"
33 #include "absl/strings/escaping.h"
34 #include "absl/strings/numbers.h"
35 #include "absl/strings/str_cat.h"
36 #include "absl/strings/str_join.h"
37 #include "absl/strings/string_view.h"
38 #include "absl/types/span.h"
39 #include "tensorflow/compiler/xla/layout_util.h"
40 #include "tensorflow/compiler/xla/literal.h"
41 #include "tensorflow/compiler/xla/protobuf_util.h"
42 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
43 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_computation.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_module.h"
47 #include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
50 #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
51 #include "tensorflow/compiler/xla/service/name_uniquer.h"
52 #include "tensorflow/compiler/xla/shape_util.h"
53 #include "tensorflow/compiler/xla/status_macros.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/gtl/map_util.h"
59 #include "tensorflow/core/platform/errors.h"
60 #include "tensorflow/core/platform/human_readable_json.h"
61 #include "tensorflow/core/platform/logging.h"
62 
63 namespace xla {
64 
65 using absl::CEscape;
66 using absl::StrAppend;
67 using absl::StrCat;
68 using absl::StrJoin;
69 
AddInstruction(std::unique_ptr<HloInstruction> derived_instruction)70 HloInstruction* HloInstruction::AddInstruction(
71     std::unique_ptr<HloInstruction> derived_instruction) {
72   HloInstruction* derived =
73       parent()->AddInstruction(std::move(derived_instruction));
74   const bool has_prior_sharding = derived->has_sharding();
75   SetupDerivedInstruction(derived);
76   if (!has_prior_sharding && (derived->opcode() == HloOpcode::kReshape ||
77                               derived->opcode() == HloOpcode::kTranspose)) {
78     derived->clear_sharding();
79   }
80   return derived;
81 }
82 
83 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64_t,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64_t,HloComputation * > & computation_map,bool prohibit_empty_literal)84 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
85     const HloInstructionProto& proto,
86     const absl::flat_hash_map<int64_t, HloInstruction*>& instruction_map,
87     const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
88     bool prohibit_empty_literal) {
89   TF_RET_CHECK(!proto.opcode().empty());
90   HloOpcode opcode;
91   auto opcode_or = StringToHloOpcode(proto.opcode());
92   std::optional<ComparisonDirection> comparison_direction;
93   if (opcode_or.ok()) {
94     opcode = std::move(opcode_or).value();
95   } else {
96     // Unknown opcode. Try auto-upgrading deprecated "less-than",
97     // "greater-than", etc opcodes, which are now rolled into the kCompare
98     // opcode.
99     if (proto.opcode() == "equal-to") {
100       comparison_direction = ComparisonDirection::kEq;
101     } else if (proto.opcode() == "not-equal-to") {
102       comparison_direction = ComparisonDirection::kNe;
103     } else if (proto.opcode() == "greater-than-or-equal-to") {
104       comparison_direction = ComparisonDirection::kGe;
105     } else if (proto.opcode() == "greater-than") {
106       comparison_direction = ComparisonDirection::kGt;
107     } else if (proto.opcode() == "less-than-or-equal-to") {
108       comparison_direction = ComparisonDirection::kLe;
109     } else if (proto.opcode() == "less-than") {
110       comparison_direction = ComparisonDirection::kLt;
111     }
112     if (comparison_direction) {
113       opcode = HloOpcode::kCompare;
114     } else {
115       return InvalidArgument("Unknown opcode: %s", proto.opcode());
116     }
117   }
118 
119   TF_RET_CHECK(proto.has_shape());
120 
121   std::unique_ptr<HloInstruction> instruction;
122   const auto operands = [&instruction_map, &proto](int index) {
123     return instruction_map.at(proto.operand_ids(index));
124   };
125   const auto all_operands = [&instruction_map, &proto]() {
126     std::vector<HloInstruction*> result(proto.operand_ids_size());
127     std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
128                    result.begin(), [&instruction_map](int64_t operand_id) {
129                      return instruction_map.at(operand_id);
130                    });
131     return result;
132   };
133   const auto computations = [&computation_map, &proto](int index) {
134     return computation_map.at(proto.called_computation_ids(index));
135   };
136   const auto all_computations = [&computation_map, &proto]() {
137     std::vector<HloComputation*> result(proto.called_computation_ids_size());
138     std::transform(proto.called_computation_ids().begin(),
139                    proto.called_computation_ids().end(), result.begin(),
140                    [&computation_map](int64_t computation_id) {
141                      return computation_map.at(computation_id);
142                    });
143     return result;
144   };
145 
146   TF_RET_CHECK(
147       absl::c_all_of(proto.operand_ids(),
148                      [&](int64_t id) { return instruction_map.contains(id); }))
149       << proto.name() << " instruction contains invalid operand id(s)";
150 
151   TF_RET_CHECK(
152       absl::c_all_of(proto.called_computation_ids(),
153                      [&](int64_t id) { return computation_map.contains(id); }))
154       << proto.name() << " instruction references invalid computation id(s)";
155 
156   Shape shape(proto.shape());
157   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
158 
159   std::optional<int> arity = HloOpcodeArity(opcode);
160   if (arity) {
161     TF_RET_CHECK(proto.operand_ids_size() == *arity)
162         << proto.opcode() << " instruction should have " << *arity
163         << " operands but sees " << proto.operand_ids_size();
164   }
165 
166   switch (opcode) {
167     // Ops migrated to subclasses.
168     case HloOpcode::kBatchNormTraining:
169       instruction =
170           CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
171                                   proto.epsilon(), proto.feature_index());
172       break;
173     case HloOpcode::kBatchNormInference:
174       instruction = CreateBatchNormInference(
175           shape, operands(0), operands(1), operands(2), operands(3),
176           operands(4), proto.epsilon(), proto.feature_index());
177       break;
178     case HloOpcode::kBatchNormGrad:
179       instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
180                                         operands(2), operands(3), operands(4),
181                                         proto.epsilon(), proto.feature_index());
182       break;
183     case HloOpcode::kFft: {
184       std::vector<int64_t> fft_length(proto.fft_length().begin(),
185                                       proto.fft_length().end());
186       instruction = CreateFft(shape, operands(0), proto.fft_type(),
187                               absl::Span<const int64_t>(fft_length));
188       break;
189     }
190     case HloOpcode::kAsyncStart: {
191       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
192           << "Async start instruction should have 1 called computation but "
193              "sees "
194           << proto.called_computation_ids_size();
195       std::optional<int64_t> async_group_id;
196       if (proto.async_group_id() >= 0) {
197         async_group_id = proto.async_group_id();
198       }
199       instruction = CreateAsyncStart(shape, all_operands(), computations(0),
200                                      async_group_id,
201                                      proto.async_execution_thread().empty()
202                                          ? kMainExecutionThread
203                                          : proto.async_execution_thread());
204       break;
205     }
206     case HloOpcode::kAsyncUpdate: {
207       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
208           << "Async update instruction should have 1 called computation but "
209              "sees "
210           << proto.called_computation_ids_size();
211       std::optional<int64_t> async_group_id;
212       if (proto.async_group_id() >= 0) {
213         async_group_id = proto.async_group_id();
214       }
215       instruction =
216           CreateAsyncUpdate(shape, operands(0), computations(0), async_group_id,
217                             proto.async_execution_thread().empty()
218                                 ? kMainExecutionThread
219                                 : proto.async_execution_thread());
220       break;
221     }
222     case HloOpcode::kAsyncDone: {
223       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
224           << "Async done instruction should have 1 called computation but sees "
225           << proto.called_computation_ids_size();
226       std::optional<int64_t> async_group_id;
227       if (proto.async_group_id() >= 0) {
228         async_group_id = proto.async_group_id();
229       }
230       instruction =
231           CreateAsyncDone(shape, operands(0), computations(0), async_group_id,
232                           proto.async_execution_thread().empty()
233                               ? kMainExecutionThread
234                               : proto.async_execution_thread());
235       break;
236     }
237     case HloOpcode::kCopyStart: {
238       instruction = CreateCopyStart(shape, operands(0),
239                                     proto.is_cross_program_prefetch());
240       break;
241     }
242     case HloOpcode::kCompare: {
243       // Auto-upgraded from deprecated opcode skips the following.
244       if (!comparison_direction) {
245         TF_ASSIGN_OR_RETURN(
246             comparison_direction,
247             StringToComparisonDirection(proto.comparison_direction()));
248       }
249       auto comparison_type_str = proto.comparison_type();
250       if (!comparison_type_str.empty()) {
251         // If a comparison type is specified, it *must* be valid.
252         TF_ASSIGN_OR_RETURN(auto comparison_type,
253                             StringToComparisonType(comparison_type_str));
254         instruction = CreateCompare(shape, operands(0), operands(1),
255                                     *comparison_direction, comparison_type);
256       } else {
257         // Allow the comparison type to be left unspecified.
258         // The comparison type will be determined by the types of the operands.
259         instruction = CreateCompare(shape, operands(0), operands(1),
260                                     *comparison_direction);
261       }
262       break;
263     }
264     case HloOpcode::kTriangularSolve: {
265       instruction = CreateTriangularSolve(shape, operands(0), operands(1),
266                                           proto.triangular_solve_options());
267       break;
268     }
269     case HloOpcode::kCholesky: {
270       instruction =
271           CreateCholesky(shape, operands(0), proto.cholesky_options());
272       break;
273     }
274     case HloOpcode::kSend:
275       instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
276                                proto.is_host_transfer());
277       break;
278     case HloOpcode::kSendDone:
279       instruction = CreateSendDone(operands(0), proto.is_host_transfer());
280       break;
281     case HloOpcode::kRecv:
282       instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
283                                proto.channel_id(), proto.is_host_transfer());
284       break;
285     case HloOpcode::kRecvDone:
286       instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
287       break;
288     case HloOpcode::kReverse:
289       instruction =
290           CreateReverse(shape, operands(0),
291                         std::vector<int64_t>(proto.dimensions().begin(),
292                                              proto.dimensions().end()));
293       break;
294     case HloOpcode::kConcatenate:
295       TF_RET_CHECK(proto.dimensions_size() == 1)
296           << "Concatenate instruction should have 1 dimension but sees "
297           << proto.dimensions_size();
298       instruction =
299           CreateConcatenate(shape, all_operands(), proto.dimensions(0));
300       break;
301     case HloOpcode::kConditional: {
302       TF_RET_CHECK(proto.called_computation_ids_size() > 0)
303           << "conditional should have at least 1 called computation";
304       if (operands(0)->shape().element_type() == PRED) {
305         TF_RET_CHECK(proto.called_computation_ids_size() == 2)
306             << "conditional should have exactly 2 called computations but got "
307             << proto.called_computation_ids_size();
308       }
309       TF_RET_CHECK(proto.operand_ids_size() ==
310                    proto.called_computation_ids_size() + 1)
311           << "conditional should have one branch_index operand plus one "
312              "operand per called computation but got "
313           << proto.operand_ids_size() << " operands for "
314           << proto.called_computation_ids_size() << " branch computations";
315       auto cond_operands = all_operands();
316       instruction =
317           CreateConditional(shape, cond_operands[0], all_computations(),
318                             absl::MakeSpan(cond_operands).subspan(1));
319       break;
320     }
321     case HloOpcode::kReduce:
322       TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
323           << "Reduce instruction should have an even number of operands but "
324              "sees "
325           << proto.operand_ids_size();
326       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
327           << "Reduce instruction should have 1 called computation but sees "
328           << proto.called_computation_ids_size();
329       {
330         const auto reduce_operands = all_operands();
331         auto inputs = absl::MakeSpan(reduce_operands)
332                           .subspan(0, reduce_operands.size() / 2);
333         auto init_values =
334             absl::MakeSpan(reduce_operands)
335                 .subspan(reduce_operands.size() / 2, reduce_operands.size());
336         instruction =
337             CreateReduce(shape, inputs, init_values,
338                          std::vector<int64_t>(proto.dimensions().begin(),
339                                               proto.dimensions().end()),
340                          computations(0));
341       }
342       break;
343     case HloOpcode::kSort: {
344       TF_RET_CHECK(proto.operand_ids_size() >= 1)
345           << "Sort instruction should have at least 1 operand but has "
346           << proto.operand_ids_size();
347       TF_RET_CHECK(proto.dimensions().size() == 1)
348           << "Sort instruction should have 1 dimension";
349       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
350           << "Sort instruction should one called computation but sees "
351           << proto.called_computation_ids_size();
352       auto sort_operands = all_operands();
353       instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
354                                computations(0), proto.is_stable());
355       break;
356     }
357     case HloOpcode::kTranspose:
358       instruction =
359           CreateTranspose(shape, operands(0),
360                           std::vector<int64_t>(proto.dimensions().begin(),
361                                                proto.dimensions().end()));
362       break;
363     case HloOpcode::kBroadcast:
364       instruction =
365           CreateBroadcast(shape, operands(0),
366                           std::vector<int64_t>(proto.dimensions().begin(),
367                                                proto.dimensions().end()));
368       break;
369     case HloOpcode::kMap:
370       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
371           << "Map instruction should have 1 called computation but sees "
372           << proto.called_computation_ids_size();
373       instruction = CreateMap(shape, all_operands(), computations(0));
374       break;
375     case HloOpcode::kSlice: {
376       std::vector<int64_t> slice_starts, slice_limits, slice_strides;
377       for (const HloInstructionProto::SliceDimensions& slice_dimensions :
378            proto.slice_dimensions()) {
379         slice_starts.push_back(slice_dimensions.start());
380         slice_limits.push_back(slice_dimensions.limit());
381         slice_strides.push_back(slice_dimensions.stride());
382       }
383       instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
384                                 slice_strides);
385       break;
386     }
387     case HloOpcode::kConstant: {
388       // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
389       if (proto.has_literal()) {
390         TF_ASSIGN_OR_RETURN(
391             auto literal,
392             Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
393         instruction = CreateConstant(std::move(literal));
394         // Literal's shape may have no/different tiling info.
395         TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
396             instruction->shape(), shape))
397             << instruction->shape().ToString(true) << " vs "
398             << shape.ToString(true);
399         *instruction->mutable_shape() = shape;
400       } else {
401         instruction = std::make_unique<HloConstantInstruction>(shape);
402       }
403       break;
404     }
405     case HloOpcode::kFusion: {
406       // In the proto, fused computations are held exclusively within the
407       // HloInstructionProto and do not appear as an HloComputationProto within
408       // the HloModuleProto.
409       TF_RET_CHECK(!proto.fusion_kind().empty());
410       TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
411                           StringToFusionKind(proto.fusion_kind()));
412 
413       // Find the fused computation and set its fusion instruction.
414       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
415           << "Expect 1 called computation for fusion instruction but sees "
416           << proto.called_computation_ids_size();
417       const int64_t fusion_id = proto.called_computation_ids(0);
418       auto* fused_computation =
419           tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
420       TF_RET_CHECK(fused_computation != nullptr)
421           << "No fusion computation with id " << fusion_id;
422       instruction =
423           CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
424       break;
425     }
426     case HloOpcode::kRng:
427       instruction = CreateRng(shape, proto.distribution(), all_operands());
428       break;
429     case HloOpcode::kRngBitGenerator:
430       instruction =
431           CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
432       break;
433     case HloOpcode::kRngGetAndUpdateState:
434       instruction = CreateRngGetAndUpdateState(shape, proto.delta());
435       break;
436     case HloOpcode::kParameter:
437       instruction =
438           CreateParameter(proto.parameter_number(), shape, proto.name());
439       if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
440         instruction->set_parameter_replicated_at_leaf_buffers(
441             proto.parameter_replication().replicated_at_leaf_buffers());
442       }
443       break;
444     case HloOpcode::kGetTupleElement:
445       instruction =
446           CreateGetTupleElement(shape, operands(0), proto.tuple_index());
447       break;
448     case HloOpcode::kReducePrecision:
449       instruction = CreateReducePrecision(
450           shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
451       break;
452     case HloOpcode::kInfeed: {
453       TF_RET_CHECK(shape.IsTuple() &&
454                    (ShapeUtil::TupleElementCount(shape) == 2))
455           << "Infeed should have a tuple shape with 2 operands, but has: "
456           << shape;
457       const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
458       instruction =
459           CreateInfeed(data_shape, operands(0), proto.infeed_config());
460     } break;
461     case HloOpcode::kOutfeed: {
462       Shape outfeed_shape(proto.outfeed_shape());
463       TF_RETURN_IF_ERROR(
464           ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
465       instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
466                                   proto.outfeed_config());
467       break;
468     }
469     case HloOpcode::kAllGather:
470     case HloOpcode::kAllGatherStart: {
471       std::optional<int64_t> channel_id;
472       if (proto.channel_id() > 0) {
473         channel_id = proto.channel_id();
474       }
475 
476       TF_RET_CHECK(proto.dimensions_size() == 1)
477           << "AllGather cannot have more than 1 all-gather dimensions";
478       int64_t all_gather_dimension = proto.dimensions(0);
479       if (opcode == HloOpcode::kAllGather) {
480         instruction = CreateAllGather(
481             shape, all_operands(), all_gather_dimension,
482             std::vector<ReplicaGroup>(proto.replica_groups().begin(),
483                                       proto.replica_groups().end()),
484             proto.constrain_layout(), channel_id,
485             proto.use_global_device_ids());
486       } else {
487         instruction = CreateAllGatherStart(
488             shape, all_operands(), all_gather_dimension,
489             std::vector<ReplicaGroup>(proto.replica_groups().begin(),
490                                       proto.replica_groups().end()),
491             proto.constrain_layout(), channel_id,
492             proto.use_global_device_ids());
493       }
494       break;
495     }
496     case HloOpcode::kAllReduce:
497     case HloOpcode::kAllReduceStart:
498     case HloOpcode::kReduceScatter: {
499       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
500           << "AllReduce should have 1 called computation but sees "
501           << proto.called_computation_ids_size();
502       TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
503           << "AllReduce cannot have both channel_id() and all_reduce_id()";
504       std::optional<int64_t> channel_id;
505       if (proto.channel_id() > 0) {
506         channel_id = proto.channel_id();
507       }
508       if (proto.all_reduce_id() > 0) {
509         channel_id = proto.all_reduce_id();
510       }
511       std::vector<ReplicaGroup> replica_groups(proto.replica_groups().begin(),
512                                                proto.replica_groups().end());
513       if (opcode == HloOpcode::kAllReduce) {
514         instruction =
515             CreateAllReduce(shape, all_operands(), computations(0),
516                             replica_groups, proto.constrain_layout(),
517                             channel_id, proto.use_global_device_ids());
518       } else if (opcode == HloOpcode::kReduceScatter) {
519         TF_RET_CHECK(proto.dimensions_size() == 1)
520             << "ReduceScatter cannot have more than 1 scatter dimensions";
521         int64_t scatter_dimension = proto.dimensions(0);
522         instruction = CreateReduceScatter(
523             shape, all_operands(), computations(0), replica_groups,
524             proto.constrain_layout(), channel_id, proto.use_global_device_ids(),
525             scatter_dimension);
526       } else {
527         instruction =
528             CreateAllReduceStart(shape, all_operands(), computations(0),
529                                  replica_groups, proto.constrain_layout(),
530                                  channel_id, proto.use_global_device_ids());
531       }
532       break;
533     }
534     case HloOpcode::kAllToAll: {
535       std::optional<int64_t> channel_id;
536       if (proto.channel_id() > 0) {
537         channel_id = proto.channel_id();
538       }
539       std::optional<int64_t> split_dimension;
540       if (proto.dimensions_size() > 0) {
541         TF_RET_CHECK(proto.dimensions_size() == 1)
542             << "AllToAll cannot have more than 1 dimension (split dimension)";
543         TF_RET_CHECK(all_operands().size() == 1)
544             << "AllToAll must have a single operand when the split dimension "
545                "is specified";
546         split_dimension = proto.dimensions(0);
547       }
548       instruction = CreateAllToAll(
549           shape, all_operands(),
550           /*replica_groups=*/
551           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
552                                     proto.replica_groups().end()),
553           /*constrain_layout=*/proto.constrain_layout(),
554           /*channel_id=*/channel_id, split_dimension);
555       break;
556     }
557     case HloOpcode::kCollectivePermute:
558     case HloOpcode::kCollectivePermuteStart: {
559       TF_RET_CHECK(proto.operand_ids().size() == 1 ||
560                    proto.operand_ids().size() == 4);
561       std::vector<std::pair<int64_t, int64_t>> source_target_pairs(
562           proto.source_target_pairs_size());
563       std::optional<int64_t> channel_id;
564       if (proto.channel_id() > 0) {
565         channel_id = proto.channel_id();
566       }
567       for (int i = 0; i < source_target_pairs.size(); ++i) {
568         source_target_pairs[i].first = proto.source_target_pairs(i).source();
569         source_target_pairs[i].second = proto.source_target_pairs(i).target();
570       }
571       if (proto.dynamic_slice_sizes_size() == 0) {
572         if (opcode == HloOpcode::kCollectivePermute) {
573           instruction = CreateCollectivePermute(
574               shape, operands(0), source_target_pairs, channel_id);
575         } else if (opcode == HloOpcode::kCollectivePermuteStart) {
576           instruction = CreateCollectivePermuteStart(
577               shape, operands(0), source_target_pairs, channel_id);
578         } else {
579           LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
580                      << "but got " << HloOpcodeString(opcode);
581         }
582       } else {
583         std::vector<std::vector<int64_t>> slice_sizes;
584         HloInstruction* input = operands(0);
585         HloInstruction* input_start_indices = operands(2);
586         if (input->shape().IsTuple() &&
587             input->shape().tuple_shapes_size() > 1) {
588           slice_sizes.resize(input->shape().tuple_shapes_size());
589         } else {
590           slice_sizes.resize(1);
591         }
592         int proto_index = 0;
593         if (input->shape().IsTuple()) {
594           if (input_start_indices->shape()
595                   .tuple_shapes(0)
596                   .tuple_shapes(0)
597                   .IsArray()) {
598             slice_sizes.resize(input->shape().tuple_shapes_size());
599             for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
600               slice_sizes[i].resize(
601                   input->shape().tuple_shapes(i).dimensions_size());
602               for (int j = 0;
603                    j < input->shape().tuple_shapes(i).dimensions_size(); ++j) {
604                 CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
605                 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
606                 proto_index += 1;
607               }
608             }
609           } else {
610             slice_sizes.resize(
611                 input->shape().tuple_shapes_size() *
612                 ShapeUtil::TupleElementCount(
613                     input_start_indices->shape().tuple_shapes(0)));
614             int slice_sizes_count = 0;
615             for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
616               for (int j = 0;
617                    j < ShapeUtil::TupleElementCount(
618                            input_start_indices->shape().tuple_shapes(i));
619                    ++j) {
620                 slice_sizes[slice_sizes_count].resize(
621                     input->shape().tuple_shapes(i).rank());
622                 for (int k = 0; k < input->shape().tuple_shapes(i).rank();
623                      ++k) {
624                   CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
625                   slice_sizes[slice_sizes_count][k] =
626                       proto.dynamic_slice_sizes(proto_index);
627                   proto_index += 1;
628                 }
629                 slice_sizes_count += 1;
630               }
631             }
632           }
633         } else {
634           slice_sizes.resize(
635               ShapeUtil::TupleElementCount(input_start_indices->shape()));
636           if (input_start_indices->shape().tuple_shapes(0).IsTuple()) {
637             for (int i = 0;
638                  i < ShapeUtil::TupleElementCount(input_start_indices->shape());
639                  ++i) {
640               slice_sizes[i].resize(input->shape().dimensions_size());
641               for (int j = 0; j < input->shape().dimensions_size(); ++j) {
642                 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
643                 proto_index += 1;
644               }
645             }
646           } else {
647             slice_sizes.resize(1);
648             slice_sizes[0].resize(input->shape().dimensions_size());
649             for (int j = 0; j < input->shape().dimensions_size(); ++j) {
650               slice_sizes[0][j] = proto.dynamic_slice_sizes(proto_index);
651               proto_index += 1;
652             }
653           }
654         }
655         if (opcode == HloOpcode::kCollectivePermute) {
656           instruction = CreateCollectivePermute(
657               shape, operands(0), operands(1), operands(2), operands(3),
658               source_target_pairs, slice_sizes, channel_id);
659         } else if (opcode == HloOpcode::kCollectivePermuteStart) {
660           instruction = CreateCollectivePermuteStart(
661               shape, operands(0), operands(1), operands(2), operands(3),
662               source_target_pairs, slice_sizes, channel_id);
663         } else {
664           LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
665                      << "but got " << HloOpcodeString(opcode);
666         }
667       }
668       break;
669     }
670     case HloOpcode::kReplicaId: {
671       instruction = CreateReplicaId(shape);
672       break;
673     }
674     case HloOpcode::kPartitionId: {
675       instruction = CreatePartitionId(shape);
676       break;
677     }
678     case HloOpcode::kConvolution: {
679       TF_RET_CHECK(proto.has_window());
680       TF_RET_CHECK(proto.has_convolution_dimension_numbers());
681       PrecisionConfig precision_config = proto.precision_config();
682       precision_config.mutable_operand_precision()->Resize(
683           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
684       instruction = CreateConvolve(
685           shape, operands(0), operands(1),
686           std::max<int64_t>(proto.feature_group_count(), 1),
687           std::max<int64_t>(proto.batch_group_count(), 1), proto.window(),
688           proto.convolution_dimension_numbers(), precision_config);
689       break;
690     }
691     case HloOpcode::kReduceWindow:
692       TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
693           << "Reduce window should have an even number of operands but "
694              "sees "
695           << proto.operand_ids_size();
696       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
697           << "ReduceWindow should have 1 called computation but sees "
698           << proto.called_computation_ids_size();
699       {
700         const auto reduce_operands = all_operands();
701         auto inputs = absl::MakeSpan(reduce_operands)
702                           .subspan(0, reduce_operands.size() / 2);
703         auto init_values =
704             absl::MakeSpan(reduce_operands)
705                 .subspan(reduce_operands.size() / 2, reduce_operands.size());
706         instruction = CreateReduceWindow(shape, inputs, init_values,
707                                          proto.window(), computations(0));
708       }
709       break;
710     case HloOpcode::kSelectAndScatter:
711       TF_RET_CHECK(proto.called_computation_ids_size() == 2)
712           << "SelectAndScatter should have 2 called computations but sees "
713           << proto.called_computation_ids_size();
714       instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
715                                            proto.window(), operands(1),
716                                            operands(2), computations(1));
717       break;
718     case HloOpcode::kCustomCall: {
719       if (proto.constrain_layout()) {
720         // A proto RepeatedPtrField cannot be converted to a Span (it is a
721         // vector of pointers essentially) so create a vector of shapes to pass
722         // in.
723         std::vector<Shape> operand_shapes;
724         const auto& operand_shapes_with_layout =
725             proto.operand_shapes_with_layout();
726         operand_shapes.reserve(operand_shapes_with_layout.size());
727         for (const ShapeProto& shape_proto : operand_shapes_with_layout) {
728           operand_shapes.emplace_back(shape_proto);
729         }
730         instruction =
731             CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
732                              operand_shapes, proto.backend_config());
733       } else {
734         if (proto.called_computation_ids_size() == 1) {
735           instruction = CreateCustomCall(shape, all_operands(), computations(0),
736                                          proto.custom_call_target(),
737                                          proto.backend_config());
738         } else if (proto.called_computation_ids_size() > 1) {
739           instruction = CreateCustomCall(
740               shape, all_operands(), all_computations(),
741               proto.custom_call_target(), proto.backend_config());
742 
743         } else {
744           instruction = CreateCustomCall(shape, all_operands(),
745                                          proto.custom_call_target(),
746                                          proto.backend_config());
747         }
748       }
749       auto custom_call_instr =
750           Cast<HloCustomCallInstruction>(instruction.get());
751       if (proto.has_window()) {
752         custom_call_instr->set_window(proto.window());
753       }
754       if (proto.has_literal()) {
755         TF_ASSIGN_OR_RETURN(
756             auto literal,
757             Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
758         custom_call_instr->set_literal(std::move(literal));
759       }
760       if (proto.has_convolution_dimension_numbers()) {
761         custom_call_instr->set_convolution_dimension_numbers(
762             proto.convolution_dimension_numbers());
763       }
764       custom_call_instr->set_feature_group_count(std::max(
765           static_cast<int64_t>(proto.feature_group_count()), int64_t{1}));
766       custom_call_instr->set_batch_group_count(std::max(
767           static_cast<int64_t>(proto.batch_group_count()), int64_t{1}));
768       custom_call_instr->set_custom_call_has_side_effect(
769           proto.custom_call_has_side_effect());
770       custom_call_instr->set_padding_type(proto.padding_type());
771 
772       PrecisionConfig precision_config = proto.precision_config();
773       precision_config.mutable_operand_precision()->Resize(
774           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
775       *custom_call_instr->mutable_precision_config() = precision_config;
776       std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
777           output_to_operand_aliasing;
778       for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
779         output_to_operand_aliasing.emplace_back(
780             ShapeIndex(aliasing.output_shape_index().begin(),
781                        aliasing.output_shape_index().end()),
782             std::pair<int64_t, ShapeIndex>{
783                 aliasing.operand_index(),
784                 ShapeIndex(aliasing.operand_shape_index().begin(),
785                            aliasing.operand_shape_index().end())});
786       }
787       custom_call_instr->set_output_to_operand_aliasing(
788           std::move(output_to_operand_aliasing));
789       custom_call_instr->set_custom_call_schedule(proto.custom_call_schedule());
790       custom_call_instr->set_api_version(proto.custom_call_api_version());
791       break;
792     }
793     case HloOpcode::kPad:
794       TF_RET_CHECK(proto.has_padding_config());
795       instruction =
796           CreatePad(shape, operands(0), operands(1), proto.padding_config());
797       break;
798     case HloOpcode::kDynamicSlice: {
799       std::vector<int64_t> slice_sizes(proto.dynamic_slice_sizes_size());
800       absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
801       TF_RET_CHECK(proto.operand_ids_size() >= 1)
802           << "DynamicSlice instruction should have at least 1 operands but "
803              "sees "
804           << proto.operand_ids_size();
805       // TODO(b/118437727): Old form, make the check unconditional.
806       if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
807         auto expected_operands = 1 + operands(0)->shape().rank();
808         TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
809             << "DynamicSlice instruction should have " << expected_operands
810             << " operands, but has " << proto.operand_ids_size();
811       }
812       const auto& operand_vector = all_operands();
813       instruction = CreateDynamicSlice(
814           shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
815           slice_sizes);
816       break;
817     }
818     case HloOpcode::kDynamicUpdateSlice: {
819       TF_RET_CHECK(proto.operand_ids_size() >= 2)
820           << "DynamicUpdateSlice instruction should have at least 2 operands "
821              "but sees "
822           << proto.operand_ids_size();
823       // TODO(b/118437727): Old form, make the check unconditional.
824       if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
825         auto expected_operands = 2 + operands(0)->shape().rank();
826         TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
827             << "DynamicUpdateSlice instruction should have "
828             << expected_operands << " operands, but has "
829             << proto.operand_ids_size();
830       }
831       const auto& operand_vector = all_operands();
832       instruction =
833           CreateDynamicUpdateSlice(shape, operands(0), operands(1),
834                                    absl::MakeSpan(operand_vector).subspan(2));
835 
836       break;
837     }
838     case HloOpcode::kGather: {
839       TF_RET_CHECK(proto.has_gather_dimension_numbers())
840           << "Gather instruction should have GatherDimensionNumbers set.";
841       auto gather_dimension_numbers = std::make_unique<GatherDimensionNumbers>(
842           proto.gather_dimension_numbers());
843       std::vector<int64_t> gather_slice_sizes;
844       const auto& slice_sizes = proto.gather_slice_sizes();
845       gather_slice_sizes.reserve(slice_sizes.size());
846       for (int64_t bound : slice_sizes) {
847         gather_slice_sizes.push_back(bound);
848       }
849       instruction = CreateGather(shape, operands(0), operands(1),
850                                  *gather_dimension_numbers, gather_slice_sizes,
851                                  proto.indices_are_sorted());
852       break;
853     }
854     case HloOpcode::kScatter: {
855       TF_RET_CHECK(proto.has_scatter_dimension_numbers())
856           << "Scatter instruction should have ScatterDimensionNumbers set.";
857       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
858           << "Scatter instruction should have 1 called computation but sees "
859           << proto.called_computation_ids_size();
860       auto scatter_dimension_numbers =
861           std::make_unique<ScatterDimensionNumbers>(
862               proto.scatter_dimension_numbers());
863       auto operands = all_operands();
864       auto operand_span = absl::MakeConstSpan(operands);
865       auto input_count = operands.size() / 2;
866       instruction =
867           CreateScatter(shape, operand_span.first(input_count),
868                         operands[input_count], operand_span.last(input_count),
869                         computations(0), *scatter_dimension_numbers,
870                         proto.indices_are_sorted(), proto.unique_indices());
871       break;
872     }
873     case HloOpcode::kIota:
874       TF_RET_CHECK(proto.dimensions_size() == 1)
875           << "Iota instruction should have 1 dimension but sees "
876           << proto.dimensions_size();
877       instruction = CreateIota(shape, proto.dimensions(0));
878       break;
879     case HloOpcode::kDot: {
880       TF_RET_CHECK(proto.has_dot_dimension_numbers())
881           << "Dot instruction should have dot_dimension_numbers.";
882       PrecisionConfig precision_config = proto.precision_config();
883       precision_config.mutable_operand_precision()->Resize(
884           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
885       instruction = std::make_unique<HloDotInstruction>(
886           shape, operands(0), operands(1), proto.dot_dimension_numbers(),
887           precision_config);
888       break;
889     }
890     case HloOpcode::kDomain: {
891       std::shared_ptr<const HloSharding> entry_hlo_sharding;
892       std::shared_ptr<const HloSharding> exit_hlo_sharding;
893       if (proto.has_domain_entry_sharding()) {
894         TF_ASSIGN_OR_RETURN(
895             HloSharding sharding,
896             HloSharding::FromProto(proto.domain_entry_sharding()));
897         entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
898       }
899       if (proto.has_domain_exit_sharding()) {
900         TF_ASSIGN_OR_RETURN(
901             HloSharding sharding,
902             HloSharding::FromProto(proto.domain_exit_sharding()));
903         exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
904       }
905       instruction = std::make_unique<HloDomainInstruction>(
906           shape, operands(0),
907           std::make_unique<ShardingMetadata>(entry_hlo_sharding),
908           std::make_unique<ShardingMetadata>(exit_hlo_sharding));
909       break;
910     }
911     case HloOpcode::kGetDimensionSize:
912       TF_RET_CHECK(proto.dimensions_size() == 1);
913       instruction =
914           CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
915       break;
916     case HloOpcode::kSetDimensionSize:
917       TF_RET_CHECK(proto.dimensions_size() == 1);
918       instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
919                                            proto.dimensions(0));
920       break;
921     case HloOpcode::kReshape: {
922       int64_t inferred_dimension = -1;
923       if (!proto.dimensions().empty()) {
924         inferred_dimension = proto.dimensions()[0];
925       }
926       TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
927                    ShapeUtil::ElementsIn(shape) ==
928                        ShapeUtil::ElementsIn(operands(0)->shape()))
929           << "shape: " << ShapeUtil::HumanString(shape)
930           << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
931       instruction = CreateReshape(shape, operands(0), inferred_dimension);
932       break;
933     }
934     case HloOpcode::kDynamicReshape: {
935       TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
936                    ShapeUtil::ElementsIn(shape) ==
937                        ShapeUtil::ElementsIn(operands(0)->shape()))
938           << "shape: " << ShapeUtil::HumanString(shape)
939           << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
940       const auto& operand_vector = all_operands();
941       instruction = CreateDynamicReshape(
942           shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
943       break;
944     }
945     default: {
946       instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
947       for (const int64_t operand_id : proto.operand_ids()) {
948         instruction->AppendOperand(instruction_map.at(operand_id));
949       }
950       if (instruction->opcode() != HloOpcode::kFusion) {
951         if (instruction->opcode() == HloOpcode::kCall) {
952           TF_RET_CHECK(proto.called_computation_ids_size() == 1)
953               << "Call should have 1 called computation but has "
954               << proto.called_computation_ids_size();
955         }
956         if (instruction->opcode() == HloOpcode::kWhile) {
957           TF_RET_CHECK(proto.called_computation_ids_size() == 2)
958               << "While should have 2 called computation but has "
959               << proto.called_computation_ids_size();
960         }
961         for (const int64_t computation_id : proto.called_computation_ids()) {
962           instruction->called_computations_.push_back(
963               computation_map.at(computation_id));
964         }
965       }
966       TF_RET_CHECK(!proto.has_precision_config())
967           << instruction->opcode() << proto.DebugString();
968       TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
969       break;
970     }
971   }
972 
973   for (const int64_t predecessor_id : proto.control_predecessor_ids()) {
974     TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
975         << "No instruction with id " << predecessor_id;
976     TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
977                            ->AddControlDependencyTo(instruction.get()));
978   }
979 
980   TF_RET_CHECK(!proto.name().empty());
981   instruction->SetAndSanitizeName(proto.name());
982   instruction->metadata_ = proto.metadata();
983   instruction->backend_config_ = proto.backend_config();
984   instruction->outer_dimension_partitions_.assign(
985       proto.outer_dimension_partitions().begin(),
986       proto.outer_dimension_partitions().end());
987 
988   TF_RET_CHECK(proto.id() >= 0)
989       << "Instruction with negative id: " << proto.id();
990   TF_RET_CHECK(proto.id() <= INT_MAX)
991       << "Instruction with id > INT_MAX: " << proto.id();
992   instruction->unique_id_ = proto.id();
993 
994   if (proto.has_sharding()) {
995     TF_ASSIGN_OR_RETURN(const auto& sharding,
996                         HloSharding::FromProto(proto.sharding()));
997     instruction->set_sharding(sharding);
998   }
999 
1000   if (proto.has_frontend_attributes()) {
1001     instruction->set_frontend_attributes(proto.frontend_attributes());
1002   }
1003 
1004   return std::move(instruction);
1005 }
1006 
CreateParameter(int64_t parameter_number,const Shape & shape,const std::string & name)1007 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
1008     int64_t parameter_number, const Shape& shape, const std::string& name) {
1009   return std::make_unique<HloParameterInstruction>(parameter_number, shape,
1010                                                    name);
1011 }
1012 
CreateConstant(Literal literal)1013 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
1014     Literal literal) {
1015   return std::make_unique<HloConstantInstruction>(std::move(literal));
1016 }
1017 
CreateIota(const Shape & shape,int64_t iota_dimension)1018 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
1019     const Shape& shape, int64_t iota_dimension) {
1020   return std::make_unique<HloIotaInstruction>(shape, iota_dimension);
1021 }
1022 
1023 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64_t index)1024 HloInstruction::CreateGetTupleElement(const Shape& shape,
1025                                       HloInstruction* operand, int64_t index) {
1026   return std::make_unique<HloGetTupleElementInstruction>(shape, operand, index);
1027 }
1028 
1029 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(HloInstruction * operand,int64_t index)1030 HloInstruction::CreateGetTupleElement(HloInstruction* operand, int64_t index) {
1031   return std::make_unique<HloGetTupleElementInstruction>(
1032       operand->shape().tuple_shapes(index), operand, index);
1033 }
1034 
CreateRng(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1035 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
1036     const Shape& shape, RandomDistribution distribution,
1037     absl::Span<HloInstruction* const> parameters) {
1038   return std::make_unique<HloRngInstruction>(shape, distribution, parameters);
1039 }
1040 
1041 /* static */ std::unique_ptr<HloInstruction>
CreateRngGetAndUpdateState(const Shape & shape,int64_t delta)1042 HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64_t delta) {
1043   return std::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
1044 }
1045 
1046 /* static */ std::unique_ptr<HloInstruction>
CreateRngBitGenerator(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)1047 HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
1048                                       RandomAlgorithm algorithm) {
1049   return std::make_unique<HloRngBitGeneratorInstruction>(shape, state,
1050                                                          algorithm);
1051 }
1052 
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)1053 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
1054     const Shape& shape, HloOpcode opcode,
1055     absl::Span<HloInstruction* const> operands) {
1056   if (opcode == HloOpcode::kCopy) {
1057     // It is impossible to copy an opaque shape, we don't know how big it is.
1058     CHECK(!shape.IsOpaque());
1059   }
1060   auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
1061   for (auto operand : operands) {
1062     instruction->AppendOperand(operand);
1063   }
1064   return instruction;
1065 }
1066 
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)1067 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
1068     const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
1069   // Only certain opcodes are supported with CreateUnary: opcodes of unary
1070   // instructions with no auxiliary fields.
1071   switch (opcode) {
1072     case HloOpcode::kAbs:
1073     case HloOpcode::kAllGatherDone:
1074     case HloOpcode::kAllReduceDone:
1075     case HloOpcode::kRoundNearestAfz:
1076     case HloOpcode::kRoundNearestEven:
1077     case HloOpcode::kBitcast:
1078     case HloOpcode::kCeil:
1079     case HloOpcode::kCollectivePermuteDone:
1080     case HloOpcode::kCopy:
1081     case HloOpcode::kCopyDone:
1082     case HloOpcode::kCos:
1083     case HloOpcode::kOptimizationBarrier:
1084     case HloOpcode::kClz:
1085     case HloOpcode::kExp:
1086     case HloOpcode::kExpm1:
1087     case HloOpcode::kFloor:
1088     case HloOpcode::kImag:
1089     case HloOpcode::kIsFinite:
1090     case HloOpcode::kLog:
1091     case HloOpcode::kLog1p:
1092     case HloOpcode::kNot:
1093     case HloOpcode::kNegate:
1094     case HloOpcode::kPopulationCount:
1095     case HloOpcode::kReal:
1096     case HloOpcode::kRsqrt:
1097     case HloOpcode::kLogistic:
1098     case HloOpcode::kSign:
1099     case HloOpcode::kSin:
1100     case HloOpcode::kSqrt:
1101     case HloOpcode::kCbrt:
1102     case HloOpcode::kTanh:
1103       break;
1104     default:
1105       LOG(FATAL) << "Invalid unary instruction opcode "
1106                  << HloOpcodeString(opcode);
1107   }
1108   return CreateNary(shape, opcode, {operand});
1109 }
1110 
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)1111 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
1112     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1113     HloInstruction* rhs) {
1114   // Only certain opcodes are supported with CreateBinary: opcodes of binary
1115   // instructions with no auxiliary fields.
1116   switch (opcode) {
1117     case HloOpcode::kAdd:
1118     case HloOpcode::kAtan2:
1119     case HloOpcode::kDivide:
1120     case HloOpcode::kComplex:
1121     case HloOpcode::kMaximum:
1122     case HloOpcode::kMinimum:
1123     case HloOpcode::kMultiply:
1124     case HloOpcode::kPower:
1125     case HloOpcode::kRemainder:
1126     case HloOpcode::kSubtract:
1127     case HloOpcode::kAnd:
1128     case HloOpcode::kOr:
1129     case HloOpcode::kXor:
1130     case HloOpcode::kShiftLeft:
1131     case HloOpcode::kShiftRightArithmetic:
1132     case HloOpcode::kShiftRightLogical:
1133       break;
1134     default:
1135       LOG(FATAL) << "Invalid binary instruction opcode "
1136                  << HloOpcodeString(opcode);
1137   }
1138   return CreateNary(shape, opcode, {lhs, rhs});
1139 }
1140 
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)1141 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
1142     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1143     HloInstruction* rhs, HloInstruction* ehs) {
1144   // Only certain opcodes are supported with CreateTernary: opcodes of ternary
1145   // instructions with no auxiliary fields.
1146   switch (opcode) {
1147     case HloOpcode::kClamp:
1148     case HloOpcode::kSelect:
1149       break;
1150     default:
1151       LOG(FATAL) << "Invalid ternary instruction opcode "
1152                  << HloOpcodeString(opcode);
1153   }
1154   return CreateNary(shape, opcode, {lhs, rhs, ehs});
1155 }
1156 
CreateVariadic(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)1157 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
1158     const Shape& shape, HloOpcode opcode,
1159     absl::Span<HloInstruction* const> operands) {
1160   CHECK_EQ(HloOpcode::kTuple, opcode);
1161   return CreateNary(shape, opcode, operands);
1162 }
1163 
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1164 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
1165     const Shape& shape, absl::Span<HloInstruction* const> operands,
1166     HloComputation* map_computation) {
1167   return std::make_unique<HloMapInstruction>(shape, operands, map_computation);
1168 }
1169 
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1170 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
1171     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1172     int64_t feature_group_count, int64_t batch_group_count,
1173     const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
1174     const PrecisionConfig& precision_config) {
1175   return std::make_unique<HloConvolutionInstruction>(
1176       shape, lhs, rhs, feature_group_count, batch_group_count, window,
1177       dimension_numbers, precision_config);
1178 }
1179 
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64_t> fft_length)1180 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
1181     const Shape& shape, HloInstruction* operand, FftType fft_type,
1182     absl::Span<const int64_t> fft_length) {
1183   return std::make_unique<HloFftInstruction>(shape, operand, fft_type,
1184                                              fft_length);
1185 }
1186 
CreateAsyncStart(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)1187 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncStart(
1188     const Shape& shape, absl::Span<HloInstruction* const> operands,
1189     HloComputation* async_computation, std::optional<int64_t> async_group_id,
1190     absl::string_view async_execution_thread) {
1191   return std::make_unique<HloAsyncInstruction>(
1192       HloOpcode::kAsyncStart, shape, operands, async_computation,
1193       async_group_id, async_execution_thread);
1194 }
1195 
CreateAsyncUpdate(const Shape & shape,HloInstruction * operand,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)1196 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncUpdate(
1197     const Shape& shape, HloInstruction* operand,
1198     HloComputation* async_computation, std::optional<int64_t> async_group_id,
1199     absl::string_view async_execution_thread) {
1200   return std::make_unique<HloAsyncInstruction>(
1201       HloOpcode::kAsyncUpdate, shape, operand, async_computation,
1202       async_group_id, async_execution_thread);
1203 }
1204 
CreateAsyncDone(const Shape & shape,HloInstruction * operand,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)1205 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncDone(
1206     const Shape& shape, HloInstruction* operand,
1207     HloComputation* async_computation, std::optional<int64_t> async_group_id,
1208     absl::string_view async_execution_thread) {
1209   return std::make_unique<HloAsyncInstruction>(
1210       HloOpcode::kAsyncDone, shape, operand, async_computation, async_group_id,
1211       async_execution_thread);
1212 }
1213 
CreateCopyStart(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)1214 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
1215     const Shape& shape, HloInstruction* operand,
1216     bool is_cross_program_prefetch) {
1217   return std::make_unique<HloCopyStartInstruction>(shape, operand,
1218                                                    is_cross_program_prefetch);
1219 }
1220 
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,std::optional<Comparison::Type> type)1221 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
1222     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1223     ComparisonDirection direction, std::optional<Comparison::Type> type) {
1224   return std::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
1225                                                  type);
1226 }
1227 
1228 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)1229 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
1230                                       HloInstruction* b,
1231                                       const TriangularSolveOptions& options) {
1232   return std::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
1233 }
1234 
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)1235 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
1236     const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
1237   return std::make_unique<HloCholeskyInstruction>(shape, a, options);
1238 }
1239 
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1240 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
1241     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1242     const DotDimensionNumbers& dimension_numbers,
1243     const PrecisionConfig& precision_config) {
1244   return std::make_unique<HloDotInstruction>(shape, lhs, rhs, dimension_numbers,
1245                                              precision_config);
1246 }
1247 
1248 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1249 HloInstruction::CreateReducePrecision(const Shape& shape,
1250                                       HloInstruction* operand,
1251                                       const int exponent_bits,
1252                                       const int mantissa_bits) {
1253   return std::make_unique<HloReducePrecisionInstruction>(
1254       shape, operand, exponent_bits, mantissa_bits);
1255 }
1256 
CreateAllGather(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1257 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
1258     const Shape& shape, absl::Span<HloInstruction* const> operands,
1259     int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1260     bool constrain_layout, const std::optional<int64_t>& channel_id,
1261     bool use_global_device_ids) {
1262   return std::make_unique<HloAllGatherInstruction>(
1263       HloOpcode::kAllGather, shape, operands, all_gather_dimension,
1264       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1265 }
1266 
1267 /* static */ std::unique_ptr<HloInstruction>
CreateAllGatherStart(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1268 HloInstruction::CreateAllGatherStart(
1269     const Shape& shape, absl::Span<HloInstruction* const> operands,
1270     int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1271     bool constrain_layout, const std::optional<int64_t>& channel_id,
1272     bool use_global_device_ids) {
1273   return std::make_unique<HloAllGatherInstruction>(
1274       HloOpcode::kAllGatherStart, shape, operands, all_gather_dimension,
1275       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1276 }
1277 
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1278 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
1279     const Shape& shape, absl::Span<HloInstruction* const> operands,
1280     HloComputation* reduce_computation,
1281     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1282     const std::optional<int64_t>& channel_id, bool use_global_device_ids) {
1283   return std::make_unique<HloAllReduceInstruction>(
1284       HloOpcode::kAllReduce, shape, operands, reduce_computation,
1285       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1286 }
1287 
1288 /* static */ std::unique_ptr<HloInstruction>
CreateReduceScatter(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)1289 HloInstruction::CreateReduceScatter(
1290     const Shape& shape, absl::Span<HloInstruction* const> operands,
1291     HloComputation* reduce_computation,
1292     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1293     const std::optional<int64_t>& channel_id, bool use_global_device_ids,
1294     int64_t scatter_dimension) {
1295   return std::make_unique<HloReduceScatterInstruction>(
1296       shape, operands, reduce_computation, replica_groups, constrain_layout,
1297       channel_id, use_global_device_ids, scatter_dimension);
1298 }
1299 
1300 /* static */ std::unique_ptr<HloInstruction>
CreateAllReduceStart(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1301 HloInstruction::CreateAllReduceStart(
1302     const Shape& shape, absl::Span<HloInstruction* const> operands,
1303     HloComputation* reduce_computation,
1304     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1305     const std::optional<int64_t>& channel_id, bool use_global_device_ids) {
1306   return std::make_unique<HloAllReduceInstruction>(
1307       HloOpcode::kAllReduceStart, shape, operands, reduce_computation,
1308       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1309 }
1310 
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,const std::optional<int64_t> & split_dimension)1311 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
1312     const Shape& shape, absl::Span<HloInstruction* const> operands,
1313     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1314     const std::optional<int64_t>& channel_id,
1315     const std::optional<int64_t>& split_dimension) {
1316   return std::make_unique<HloAllToAllInstruction>(
1317       shape, operands, replica_groups, constrain_layout, channel_id,
1318       split_dimension);
1319 }
1320 
1321 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const std::optional<int64_t> & channel_id)1322 HloInstruction::CreateCollectivePermute(
1323     const Shape& shape, HloInstruction* operand,
1324     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1325     const std::optional<int64_t>& channel_id) {
1326   return std::make_unique<HloCollectivePermuteInstruction>(
1327       HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
1328       channel_id);
1329 }
1330 
1331 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const std::optional<int64_t> & channel_id)1332 HloInstruction::CreateCollectivePermute(
1333     const Shape& shape, HloInstruction* input, HloInstruction* output,
1334     HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1335     absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1336     absl::Span<const std::vector<int64_t>> slice_sizes,
1337     const std::optional<int64_t>& channel_id) {
1338   return std::make_unique<HloCollectivePermuteInstruction>(
1339       HloOpcode::kCollectivePermute, shape, input, output, input_start_indices,
1340       output_start_indices, source_target_pairs, slice_sizes, channel_id);
1341 }
1342 
1343 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const std::optional<int64_t> & channel_id)1344 HloInstruction::CreateCollectivePermuteStart(
1345     const Shape& shape, HloInstruction* operand,
1346     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1347     const std::optional<int64_t>& channel_id) {
1348   return std::make_unique<HloCollectivePermuteInstruction>(
1349       HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
1350       channel_id);
1351 }
1352 
1353 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const std::optional<int64_t> & channel_id)1354 HloInstruction::CreateCollectivePermuteStart(
1355     const Shape& shape, HloInstruction* input, HloInstruction* output,
1356     HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1357     absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1358     absl::Span<const std::vector<int64_t>> slice_sizes,
1359     const std::optional<int64_t>& channel_id) {
1360   return std::make_unique<HloCollectivePermuteInstruction>(
1361       HloOpcode::kCollectivePermuteStart, shape, input, output,
1362       input_start_indices, output_start_indices, source_target_pairs,
1363       slice_sizes, channel_id);
1364 }
1365 
CreateReplicaId(const Shape & shape)1366 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
1367     const Shape& shape) {
1368   CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1369       << "HloInstruction replica-id must have a shape of u32[], but "
1370       << shape.ToString() << " is specified";
1371   return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
1372 }
1373 
CreatePartitionId(const Shape & shape)1374 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
1375     const Shape& shape) {
1376   CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1377       << "HloInstruction partition-id must have a shape of u32[], but "
1378       << shape.ToString() << " is specified";
1379   return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
1380 }
1381 
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const std::string & config)1382 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
1383     const Shape& infeed_shape, HloInstruction* token_operand,
1384     const std::string& config) {
1385   return std::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
1386                                                 config);
1387 }
1388 
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1389 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
1390     const Shape& outfeed_shape, HloInstruction* operand,
1391     HloInstruction* token_operand, absl::string_view outfeed_config) {
1392   return std::make_unique<HloOutfeedInstruction>(outfeed_shape, operand,
1393                                                  token_operand, outfeed_config);
1394 }
1395 
CreateSend(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1396 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
1397     HloInstruction* operand, HloInstruction* token, int64_t channel_id,
1398     bool is_host_transfer) {
1399   return std::make_unique<HloSendInstruction>(operand, token, channel_id,
1400                                               is_host_transfer);
1401 }
1402 
CreateSendDone(HloInstruction * operand,bool is_host_transfer)1403 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
1404     HloInstruction* operand, bool is_host_transfer) {
1405   auto send_operand = DynCast<HloSendInstruction>(operand);
1406   CHECK(send_operand != nullptr)
1407       << "SendDone must take the context operand from Send";
1408   return std::make_unique<HloSendDoneInstruction>(send_operand,
1409                                                   is_host_transfer);
1410 }
1411 
CreateRecv(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1412 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
1413     const Shape& shape, HloInstruction* token, int64_t channel_id,
1414     bool is_host_transfer) {
1415   return std::make_unique<HloRecvInstruction>(shape, token, channel_id,
1416                                               is_host_transfer);
1417 }
1418 
CreateRecvDone(HloInstruction * operand,bool is_host_transfer)1419 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
1420     HloInstruction* operand, bool is_host_transfer) {
1421   auto recv_operand = DynCast<HloRecvInstruction>(operand);
1422   CHECK(recv_operand != nullptr)
1423       << "RecvDone must take the context operand from Recv";
1424   return std::make_unique<HloRecvDoneInstruction>(recv_operand,
1425                                                   is_host_transfer);
1426 }
1427 
CreateReverse(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1428 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
1429     const Shape& shape, HloInstruction* operand,
1430     absl::Span<const int64_t> dimensions) {
1431   return std::make_unique<HloReverseInstruction>(shape, operand, dimensions);
1432 }
1433 
CreateAfterAll(absl::Span<HloInstruction * const> operands)1434 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1435     absl::Span<HloInstruction* const> operands) {
1436   CHECK(!operands.empty());
1437   auto instruction = absl::WrapUnique(
1438       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1439   for (auto operand : operands) {
1440     instruction->AppendOperand(operand);
1441   }
1442   return instruction;
1443 }
1444 
CreateToken()1445 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
1446   return absl::WrapUnique(
1447       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1448 }
1449 
1450 /* static */ std::unique_ptr<HloInstruction>
CreateAddDependency(HloInstruction * data_operand,HloInstruction * token_operand)1451 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
1452                                     HloInstruction* token_operand) {
1453   auto instruction = absl::WrapUnique(
1454       new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
1455   instruction->AppendOperand(data_operand);
1456   instruction->AppendOperand(token_operand);
1457   return instruction;
1458 }
1459 
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)1460 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
1461     const Shape& shape, HloComputation* condition, HloComputation* body,
1462     HloInstruction* init) {
1463   auto instruction =
1464       absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
1465   instruction->AppendOperand(init);
1466   // Body comes before condition computation in the vector.
1467   instruction->called_computations_.push_back(body);
1468   instruction->called_computations_.push_back(condition);
1469   return instruction;
1470 }
1471 
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)1472 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1473     const Shape& shape, HloInstruction* pred,
1474     HloInstruction* true_computation_arg, HloComputation* true_computation,
1475     HloInstruction* false_computation_arg, HloComputation* false_computation) {
1476   auto instruction =
1477       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1478   instruction->AppendOperand(pred);
1479   instruction->AppendOperand(true_computation_arg);
1480   instruction->AppendOperand(false_computation_arg);
1481   // In called_computations_, the index of true_computation must be 0 and that
1482   // of false computation must be 1, as defined by kTrueComputationIndex and
1483   // kFalseComputationIndex.
1484   instruction->called_computations_.push_back(true_computation);
1485   instruction->called_computations_.push_back(false_computation);
1486   return instruction;
1487 }
1488 
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)1489 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1490     const Shape& shape, HloInstruction* branch_index,
1491     absl::Span<HloComputation* const> branch_computations,
1492     absl::Span<HloInstruction* const> branch_computation_args) {
1493   auto instruction =
1494       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1495   instruction->AppendOperand(branch_index);
1496   CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1497   for (int i = 0; i < branch_computations.size(); ++i) {
1498     instruction->called_computations_.push_back(branch_computations[i]);
1499     instruction->AppendOperand(branch_computation_args[i]);
1500   }
1501   return instruction;
1502 }
1503 
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)1504 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1505     const Shape& shape, HloInstruction* operand,
1506     absl::Span<const int64_t> start_indices,
1507     absl::Span<const int64_t> limit_indices,
1508     absl::Span<const int64_t> strides) {
1509   return std::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1510                                                limit_indices, strides);
1511 }
1512 
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64_t> slice_sizes)1513 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1514     const Shape& shape, HloInstruction* operand,
1515     absl::Span<HloInstruction* const> start_indices,
1516     absl::Span<const int64_t> slice_sizes) {
1517   return std::make_unique<HloDynamicSliceInstruction>(
1518       shape, operand, start_indices, slice_sizes);
1519 }
1520 
1521 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)1522 HloInstruction::CreateDynamicUpdateSlice(
1523     const Shape& shape, HloInstruction* operand, HloInstruction* update,
1524     absl::Span<HloInstruction* const> start_indices) {
1525   return std::make_unique<HloDynamicUpdateSliceInstruction>(
1526       shape, operand, update, start_indices);
1527 }
1528 
CreateConcatenate(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)1529 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1530     const Shape& shape, absl::Span<HloInstruction* const> operands,
1531     int64_t dimension) {
1532   return std::make_unique<HloConcatenateInstruction>(shape, operands,
1533                                                      dimension);
1534 }
1535 
CreateConvert(const Shape & shape,HloInstruction * operand)1536 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1537     const Shape& shape, HloInstruction* operand) {
1538   auto instruction =
1539       absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1540   instruction->AppendOperand(operand);
1541   return instruction;
1542 }
1543 
1544 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)1545 HloInstruction::CreateBitcastConvert(const Shape& shape,
1546                                      HloInstruction* operand) {
1547   auto instruction =
1548       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1549   instruction->AppendOperand(operand);
1550   return instruction;
1551 }
1552 
CreateBitcast(const Shape & shape,HloInstruction * operand)1553 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
1554     const Shape& shape, HloInstruction* operand) {
1555   auto instruction =
1556       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
1557   instruction->AppendOperand(operand);
1558   return instruction;
1559 }
1560 
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1561 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1562     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1563     absl::Span<const int64_t> dimensions_to_reduce,
1564     HloComputation* reduce_computation) {
1565   auto instruction = absl::WrapUnique(new HloReduceInstruction(
1566       shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1567   return std::move(instruction);
1568 }
1569 
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1570 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1571     const Shape& shape, absl::Span<HloInstruction* const> operands,
1572     absl::Span<HloInstruction* const> init_values,
1573     absl::Span<const int64_t> dimensions_to_reduce,
1574     HloComputation* reduce_computation) {
1575   std::vector<HloInstruction*> all_args;
1576   all_args.reserve(operands.size() * 2);
1577   all_args.insert(all_args.end(), operands.begin(), operands.end());
1578   all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1579   return std::make_unique<HloReduceInstruction>(
1580       shape, all_args, dimensions_to_reduce, reduce_computation);
1581 }
1582 
CreateReduce(const Shape & shape,HloInstruction * tuple_of_instructions,absl::Span<HloInstruction * const> init_values,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1583 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1584     const Shape& shape, HloInstruction* tuple_of_instructions,
1585     absl::Span<HloInstruction* const> init_values,
1586     absl::Span<const int64_t> dimensions_to_reduce,
1587     HloComputation* reduce_computation) {
1588   if (!tuple_of_instructions->shape().IsTuple()) {
1589     CHECK_EQ(init_values.size(), 1)
1590         << "The first input has to be a tuple, or the number of init values "
1591            "has to be one.";
1592     return CreateReduce(shape, tuple_of_instructions, init_values[0],
1593                         dimensions_to_reduce, reduce_computation);
1594   }
1595   absl::InlinedVector<HloInstruction*, 4> inputs;
1596   for (int idx = 0; idx < tuple_of_instructions->shape().tuple_shapes_size();
1597        idx++) {
1598     std::unique_ptr<HloInstruction> gte =
1599         HloInstruction::CreateGetTupleElement(tuple_of_instructions, idx);
1600     inputs.push_back(
1601         tuple_of_instructions->parent()->AddInstruction(std::move(gte)));
1602   }
1603   return CreateReduce(shape, inputs, init_values, dimensions_to_reduce,
1604                       reduce_computation);
1605 }
1606 
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1607 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1608     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1609     const Window& window, HloComputation* reduce_computation) {
1610   return std::make_unique<HloReduceWindowInstruction>(
1611       shape, operand, init_value, window, reduce_computation);
1612 }
1613 
CreateReduceWindow(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)1614 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1615     const Shape& shape, absl::Span<HloInstruction* const> operands,
1616     absl::Span<HloInstruction* const> init_values, const Window& window,
1617     HloComputation* reduce_computation) {
1618   return std::make_unique<HloReduceWindowInstruction>(
1619       shape, operands, init_values, window, reduce_computation);
1620 }
1621 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)1622 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1623                                         HloInstruction* operand,
1624                                         HloInstruction* scale,
1625                                         HloInstruction* offset, float epsilon,
1626                                         int64_t feature_index) {
1627   return std::make_unique<HloBatchNormTrainingInstruction>(
1628       shape, operand, scale, offset, epsilon, feature_index);
1629 }
1630 
1631 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)1632 HloInstruction::CreateBatchNormInference(
1633     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1634     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1635     float epsilon, int64_t feature_index) {
1636   return std::make_unique<HloBatchNormInferenceInstruction>(
1637       shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1638 }
1639 
1640 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)1641 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1642                                     HloInstruction* scale, HloInstruction* mean,
1643                                     HloInstruction* variance,
1644                                     HloInstruction* grad_output, float epsilon,
1645                                     int64_t feature_index) {
1646   return std::make_unique<HloBatchNormGradInstruction>(
1647       shape, operand, scale, mean, variance, grad_output, epsilon,
1648       feature_index);
1649 }
1650 
1651 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1652 HloInstruction::CreateSelectAndScatter(
1653     const Shape& shape, HloInstruction* operand, HloComputation* select,
1654     const Window& window, HloInstruction* source, HloInstruction* init_value,
1655     HloComputation* scatter) {
1656   return std::make_unique<HloSelectAndScatterInstruction>(
1657       shape, operand, select, window, source, init_value, scatter);
1658 }
1659 
CreateBroadcast(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> broadcast_dimensions)1660 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1661     const Shape& shape, HloInstruction* operand,
1662     absl::Span<const int64_t> broadcast_dimensions) {
1663   return std::make_unique<HloBroadcastInstruction>(shape, operand,
1664                                                    broadcast_dimensions);
1665 }
1666 
1667 /* static */ std::unique_ptr<HloInstruction>
CreateGetDimensionSize(const Shape & shape,HloInstruction * operand,int64_t dimension)1668 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1669                                        HloInstruction* operand,
1670                                        int64_t dimension) {
1671   return std::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1672                                                           dimension);
1673 }
1674 
1675 /* static */ std::unique_ptr<HloInstruction>
CreateSetDimensionSize(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)1676 HloInstruction::CreateSetDimensionSize(const Shape& shape,
1677                                        HloInstruction* operand,
1678                                        HloInstruction* val, int64_t dimension) {
1679   return std::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
1680                                                           dimension);
1681 }
1682 
1683 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1684 HloInstruction::CreateBroadcastSequence(
1685     const Shape& output_shape, HloInstruction* operand,
1686     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1687         adder) {
1688   CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1689         operand->shape().rank() == output_shape.rank());
1690   Shape broadcast_shape = ShapeUtil::ChangeElementType(
1691       output_shape, operand->shape().element_type());
1692   // Do explicit broadcast for scalar.
1693   if (ShapeUtil::IsScalar(operand->shape())) {
1694     auto broadcast =
1695         HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1696     broadcast->set_metadata(operand->metadata());
1697     if (operand->has_sharding()) {
1698       broadcast->set_sharding(operand->sharding());
1699     }
1700     broadcast->set_frontend_attributes(operand->frontend_attributes());
1701     return broadcast;
1702   }
1703   // Do explicit broadcast for degenerate broadcast.
1704   std::vector<int64_t> broadcast_dimensions;
1705   std::vector<int64_t> reshaped_dimensions;
1706   for (int i = 0; i < operand->shape().rank(); i++) {
1707     if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1708       broadcast_dimensions.push_back(i);
1709       reshaped_dimensions.push_back(operand->shape().dimensions(i));
1710     } else {
1711       CHECK_EQ(operand->shape().dimensions(i), 1)
1712           << "An explicit broadcast sequence requires the broadcasted "
1713              "dimensions to be trivial; operand: "
1714           << operand->ToString() << "; output_shape: " << output_shape;
1715     }
1716   }
1717   // Eliminate the size one dimensions.
1718   HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1719       ShapeUtil::MakeShape(operand->shape().element_type(),
1720                            reshaped_dimensions),
1721       operand));
1722   reshaped_operand->set_metadata(operand->metadata());
1723   if (operand->has_sharding()) {
1724     reshaped_operand->set_sharding(operand->sharding());
1725   }
1726   reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1727   // Broadcast 'reshape' up to the larger size.
1728   auto broadcast = HloInstruction::CreateBroadcast(
1729       broadcast_shape, reshaped_operand, broadcast_dimensions);
1730   broadcast->set_metadata(operand->metadata());
1731   if (operand->has_sharding()) {
1732     broadcast->set_sharding(operand->sharding());
1733   }
1734   broadcast->set_frontend_attributes(operand->frontend_attributes());
1735   return broadcast;
1736 }
1737 
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1738 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1739     const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1740     const PaddingConfig& padding_config) {
1741   return std::make_unique<HloPadInstruction>(shape, operand, padding_value,
1742                                              padding_config);
1743 }
1744 
CreateReshape(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1745 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1746     const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) {
1747   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1748            ShapeUtil::ElementsIn(operand->shape()))
1749       << "shape: " << ShapeUtil::HumanString(shape)
1750       << " operand: " << ShapeUtil::HumanString(operand->shape());
1751 
1752   return std::make_unique<HloReshapeInstruction>(shape, operand,
1753                                                  inferred_dimension);
1754 }
1755 
1756 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicReshape(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1757 HloInstruction::CreateDynamicReshape(
1758     const Shape& shape, HloInstruction* data_operand,
1759     absl::Span<HloInstruction* const> dim_sizes) {
1760   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1761            ShapeUtil::ElementsIn(data_operand[0].shape()))
1762       << "shape: " << ShapeUtil::HumanString(shape)
1763       << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
1764   CHECK_EQ(shape.rank(), dim_sizes.size());
1765   return std::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
1766                                                         dim_sizes);
1767 }
1768 
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1769 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1770     const Shape& shape, HloInstruction* operand,
1771     absl::Span<const int64_t> dimensions) {
1772   return std::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1773 }
1774 
CreateSort(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1775 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1776     const Shape& shape, int64_t dimension,
1777     absl::Span<HloInstruction* const> operands, HloComputation* compare,
1778     bool is_stable) {
1779   return std::make_unique<HloSortInstruction>(shape, dimension, operands,
1780                                               compare, is_stable);
1781 }
1782 
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1783 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1784     const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1785   return std::make_unique<HloFusionInstruction>(shape, fusion_kind, fused_root);
1786 }
1787 
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation,absl::string_view prefix)1788 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1789     const Shape& shape, FusionKind fusion_kind,
1790     absl::Span<HloInstruction* const> operands,
1791     HloComputation* fusion_computation, absl::string_view prefix) {
1792   return std::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1793                                                 fusion_computation, prefix);
1794 }
1795 
set_single_sharding(const HloSharding & sharding)1796 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1797   CHECK(!sharding.IsTuple()) << sharding;
1798   if (shape().IsTuple()) {
1799     set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1800   } else {
1801     set_sharding(sharding);
1802   }
1803 }
1804 
SetupDerivedInstruction(HloInstruction * derived_instruction) const1805 void HloInstruction::SetupDerivedInstruction(
1806     HloInstruction* derived_instruction) const {
1807   if (sharding_ != nullptr &&
1808       ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
1809     // Only copy sharding if the tuple tree shape of the two instruction is
1810     // compatible because copying it between differently shaped instructions
1811     // can produce invalid shardings.
1812     derived_instruction->set_sharding(*sharding_);
1813   } else {
1814     derived_instruction->clear_sharding();
1815   }
1816   derived_instruction->set_metadata(metadata_);
1817   derived_instruction->set_frontend_attributes(frontend_attributes_);
1818 }
1819 
IsRoot() const1820 bool HloInstruction::IsRoot() const {
1821   return parent_ != nullptr && this == parent_->root_instruction();
1822 }
1823 
HasSideEffectNoRecurse() const1824 bool HloInstruction::HasSideEffectNoRecurse() const {
1825   switch (opcode_) {
1826     case HloOpcode::kSend:
1827     case HloOpcode::kSendDone:
1828     case HloOpcode::kRecv:
1829     case HloOpcode::kRecvDone:
1830     case HloOpcode::kRng:
1831     case HloOpcode::kRngGetAndUpdateState:
1832     case HloOpcode::kInfeed:
1833     case HloOpcode::kOutfeed:
1834     case HloOpcode::kAllReduceStart:
1835     case HloOpcode::kAllReduceDone:
1836     case HloOpcode::kAllGatherStart:
1837     case HloOpcode::kAllGatherDone:
1838     case HloOpcode::kCollectivePermuteStart:
1839     case HloOpcode::kCollectivePermuteDone:
1840       return true;
1841     case HloOpcode::kAllReduce:
1842       return channel_id().has_value() ||
1843              Cast<HloAllReduceInstruction>(this)->constrain_layout();
1844     case HloOpcode::kAllToAll:
1845       return Cast<HloAllToAllInstruction>(this)->constrain_layout();
1846     case HloOpcode::kCustomCall:
1847       return Cast<HloCustomCallInstruction>(this)
1848           ->custom_call_has_side_effect();
1849     default:
1850       return false;
1851   }
1852 }
1853 
HasSideEffect() const1854 bool HloInstruction::HasSideEffect() const {
1855   if (HasSideEffectNoRecurse()) {
1856     return true;
1857   }
1858   // Check if any of the called computations has a side effect.
1859   for (const auto& computation : called_computations()) {
1860     if (computation->HasSideEffect()) {
1861       return true;
1862     }
1863   }
1864   return false;
1865 }
1866 
CreateCall(const Shape & shape,HloInstruction * called_computation_root)1867 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1868     const Shape& shape, HloInstruction* called_computation_root) {
1869   return std::make_unique<HloCallInstruction>(shape, called_computation_root);
1870 }
1871 
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1872 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1873     const Shape& shape, absl::Span<HloInstruction* const> operands,
1874     HloComputation* computation) {
1875   return std::make_unique<HloCallInstruction>(shape, operands, computation);
1876 }
1877 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)1878 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1879     const Shape& shape, absl::Span<HloInstruction* const> operands,
1880     absl::string_view custom_call_target, std::string opaque,
1881     CustomCallApiVersion api_version) {
1882   return std::make_unique<HloCustomCallInstruction>(
1883       shape, operands, custom_call_target, std::move(opaque), api_version);
1884 }
1885 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)1886 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1887     const Shape& shape, absl::Span<HloInstruction* const> operands,
1888     HloComputation* to_apply, absl::string_view custom_call_target,
1889     std::string opaque, CustomCallApiVersion api_version) {
1890   return std::make_unique<HloCustomCallInstruction>(
1891       shape, operands, to_apply, custom_call_target, std::move(opaque),
1892       api_version);
1893 }
1894 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)1895 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1896     const Shape& shape, absl::Span<HloInstruction* const> operands,
1897     absl::Span<HloComputation* const> called_computations,
1898     absl::string_view custom_call_target, std::string opaque,
1899     CustomCallApiVersion api_version) {
1900   return std::make_unique<HloCustomCallInstruction>(
1901       shape, operands, called_computations, custom_call_target,
1902       std::move(opaque), api_version);
1903 }
1904 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::Span<const Shape> operand_shapes_with_layout,std::string opaque,CustomCallApiVersion api_version)1905 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1906     const Shape& shape, absl::Span<HloInstruction* const> operands,
1907     absl::string_view custom_call_target,
1908     absl::Span<const Shape> operand_shapes_with_layout, std::string opaque,
1909     CustomCallApiVersion api_version) {
1910   return std::make_unique<HloCustomCallInstruction>(
1911       shape, operands, custom_call_target, std::move(opaque),
1912       operand_shapes_with_layout, api_version);
1913 }
1914 
CreateTuple(absl::Span<HloInstruction * const> elements)1915 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1916     absl::Span<HloInstruction* const> elements) {
1917   std::vector<const Shape*> element_shapes;
1918   element_shapes.reserve(elements.size());
1919   for (auto element : elements) {
1920     element_shapes.push_back(&element->shape());
1921   }
1922   Shape tuple_shape = ShapeUtil::MakeTupleShapeWithPtrs(element_shapes);
1923   return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1924 }
1925 
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64_t> slice_sizes,bool indices_are_sorted)1926 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1927     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1928     const GatherDimensionNumbers& gather_dim_numbers,
1929     absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
1930   return std::make_unique<HloGatherInstruction>(shape, operand, start_indices,
1931                                                 gather_dim_numbers, slice_sizes,
1932                                                 indices_are_sorted);
1933 }
1934 
CreateScatter(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1935 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1936     const Shape& shape, HloInstruction* operand,
1937     HloInstruction* scatter_indices, HloInstruction* updates,
1938     HloComputation* update_computation,
1939     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1940     bool unique_indices) {
1941   return absl::WrapUnique(new HloScatterInstruction(
1942       shape, {operand, scatter_indices, updates}, update_computation,
1943       scatter_dim_numbers, indices_are_sorted, unique_indices));
1944 }
1945 
CreateScatter(const Shape & shape,absl::Span<HloInstruction * const> operands,HloInstruction * scatter_indices,absl::Span<HloInstruction * const> updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1946 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1947     const Shape& shape, absl::Span<HloInstruction* const> operands,
1948     HloInstruction* scatter_indices, absl::Span<HloInstruction* const> updates,
1949     HloComputation* update_computation,
1950     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1951     bool unique_indices) {
1952   absl::InlinedVector<HloInstruction*, 3> args;
1953   args.reserve(operands.size() + updates.size() + 1);
1954   absl::c_copy(operands, std::back_inserter(args));
1955   args.push_back(scatter_indices);
1956   absl::c_copy(updates, std::back_inserter(args));
1957   return std::make_unique<HloScatterInstruction>(
1958       shape, args, update_computation, scatter_dim_numbers, indices_are_sorted,
1959       unique_indices);
1960 }
1961 
CreateDomain(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)1962 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1963     const Shape& shape, HloInstruction* operand,
1964     std::unique_ptr<DomainMetadata> operand_side_metadata,
1965     std::unique_ptr<DomainMetadata> user_side_metadata) {
1966   return std::make_unique<HloDomainInstruction>(
1967       shape, operand, std::move(operand_side_metadata),
1968       std::move(user_side_metadata));
1969 }
1970 
CloneWithNewOperands(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1971 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1972     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1973     HloCloneContext* context) const {
1974   VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
1975   VLOG(3) << "  new operands:";
1976   for (const HloInstruction* new_operand : new_operands) {
1977     VLOG(3) << "    %" << new_operand->name();
1978   }
1979 
1980   std::unique_ptr<HloInstruction> clone;
1981   // Explicitly call the factory for the instruction type. This is more robust
1982   // in the face of code changes than copying fields explicitly. This also
1983   // properly sets the user fields of the operands.
1984   switch (opcode_) {
1985     // Ops migrated to subclasses.
1986     // TODO(b/80131774): Remove this switch when migration is complete.
1987     case HloOpcode::kBatchNormTraining:
1988     case HloOpcode::kBatchNormInference:
1989     case HloOpcode::kBatchNormGrad:
1990     case HloOpcode::kFft:
1991     case HloOpcode::kCompare:
1992     case HloOpcode::kAsyncStart:
1993     case HloOpcode::kAsyncUpdate:
1994     case HloOpcode::kAsyncDone:
1995     case HloOpcode::kCopyStart:
1996     case HloOpcode::kSend:
1997     case HloOpcode::kSendDone:
1998     case HloOpcode::kRecv:
1999     case HloOpcode::kRecvDone:
2000     case HloOpcode::kReverse:
2001     case HloOpcode::kConcatenate:
2002     case HloOpcode::kReduce:
2003     case HloOpcode::kTranspose:
2004     case HloOpcode::kBroadcast:
2005     case HloOpcode::kReshape:
2006     case HloOpcode::kDynamicReshape:
2007     case HloOpcode::kMap:
2008     case HloOpcode::kSlice:
2009     case HloOpcode::kConstant:
2010     case HloOpcode::kFusion:
2011     case HloOpcode::kRng:
2012     case HloOpcode::kRngBitGenerator:
2013     case HloOpcode::kRngGetAndUpdateState:
2014     case HloOpcode::kParameter:
2015     case HloOpcode::kGetTupleElement:
2016     case HloOpcode::kReducePrecision:
2017     case HloOpcode::kAllGather:
2018     case HloOpcode::kAllGatherStart:
2019     case HloOpcode::kAllReduce:
2020     case HloOpcode::kReduceScatter:
2021     case HloOpcode::kAllReduceStart:
2022     case HloOpcode::kAllToAll:
2023     case HloOpcode::kCollectivePermute:
2024     case HloOpcode::kCollectivePermuteStart:
2025     case HloOpcode::kInfeed:
2026     case HloOpcode::kOutfeed:
2027     case HloOpcode::kConvolution:
2028     case HloOpcode::kCustomCall:
2029     case HloOpcode::kReduceWindow:
2030     case HloOpcode::kSelectAndScatter:
2031     case HloOpcode::kPad:
2032     case HloOpcode::kDynamicSlice:
2033     case HloOpcode::kSort:
2034     case HloOpcode::kGather:
2035     case HloOpcode::kScatter:
2036     case HloOpcode::kIota:
2037     case HloOpcode::kDot:
2038     case HloOpcode::kDomain:
2039     case HloOpcode::kGetDimensionSize:
2040     case HloOpcode::kSetDimensionSize:
2041     case HloOpcode::kTriangularSolve:
2042     case HloOpcode::kCholesky:
2043       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
2044       break;
2045     // Unary ops.
2046     case HloOpcode::kAbs:
2047     case HloOpcode::kAllGatherDone:
2048     case HloOpcode::kAllReduceDone:
2049     case HloOpcode::kRoundNearestAfz:
2050     case HloOpcode::kRoundNearestEven:
2051     case HloOpcode::kBitcast:
2052     case HloOpcode::kCeil:
2053     case HloOpcode::kClz:
2054     case HloOpcode::kCollectivePermuteDone:
2055     case HloOpcode::kCopy:
2056     case HloOpcode::kOptimizationBarrier:
2057     case HloOpcode::kCopyDone:
2058     case HloOpcode::kCos:
2059     case HloOpcode::kExp:
2060     case HloOpcode::kExpm1:
2061     case HloOpcode::kImag:
2062     case HloOpcode::kIsFinite:
2063     case HloOpcode::kFloor:
2064     case HloOpcode::kLog:
2065     case HloOpcode::kLog1p:
2066     case HloOpcode::kNot:
2067     case HloOpcode::kNegate:
2068     case HloOpcode::kPopulationCount:
2069     case HloOpcode::kReal:
2070     case HloOpcode::kRsqrt:
2071     case HloOpcode::kLogistic:
2072     case HloOpcode::kSign:
2073     case HloOpcode::kSin:
2074     case HloOpcode::kSqrt:
2075     case HloOpcode::kCbrt:
2076     case HloOpcode::kTanh:
2077       CHECK_EQ(new_operands.size(), 1);
2078       clone = CreateUnary(shape, opcode_, new_operands[0]);
2079       break;
2080     // Binary ops.
2081     case HloOpcode::kAdd:
2082     case HloOpcode::kAtan2:
2083     case HloOpcode::kComplex:
2084     case HloOpcode::kDivide:
2085     case HloOpcode::kMultiply:
2086     case HloOpcode::kSubtract:
2087     case HloOpcode::kMaximum:
2088     case HloOpcode::kMinimum:
2089     case HloOpcode::kPower:
2090     case HloOpcode::kRemainder:
2091     case HloOpcode::kAnd:
2092     case HloOpcode::kOr:
2093     case HloOpcode::kXor:
2094     case HloOpcode::kShiftLeft:
2095     case HloOpcode::kShiftRightArithmetic:
2096     case HloOpcode::kShiftRightLogical:
2097       CHECK_EQ(new_operands.size(), 2);
2098       clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
2099       break;
2100     // Ternary ops.
2101     case HloOpcode::kClamp:
2102     case HloOpcode::kSelect:
2103       CHECK_EQ(new_operands.size(), 3);
2104       clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
2105                             new_operands[2]);
2106       break;
2107     // Other supported ops.
2108     case HloOpcode::kCall:
2109       clone = CreateCall(shape, new_operands, to_apply());
2110       break;
2111     case HloOpcode::kConvert:
2112       CHECK_EQ(new_operands.size(), 1);
2113       clone = CreateConvert(shape, new_operands[0]);
2114       break;
2115     case HloOpcode::kBitcastConvert:
2116       CHECK_EQ(new_operands.size(), 1);
2117       clone = CreateBitcastConvert(shape, new_operands[0]);
2118       break;
2119     case HloOpcode::kDynamicUpdateSlice:
2120       clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
2121                                        new_operands.subspan(2));
2122       break;
2123     case HloOpcode::kTuple:
2124       clone = CreateTuple(new_operands);
2125       *clone->mutable_shape() = shape;
2126       break;
2127     case HloOpcode::kWhile:
2128       CHECK_EQ(new_operands.size(), 1);
2129       clone =
2130           CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
2131       break;
2132     case HloOpcode::kConditional:
2133       CHECK_EQ(new_operands.size(), branch_count() + 1);
2134       clone = CreateConditional(shape, new_operands[0],
2135                                 absl::MakeSpan(branch_computations()),
2136                                 new_operands.subspan(1));
2137       break;
2138     case HloOpcode::kAfterAll:
2139       if (new_operands.empty()) {
2140         clone = CreateToken();
2141       } else {
2142         clone = CreateAfterAll(new_operands);
2143       }
2144       break;
2145     case HloOpcode::kAddDependency:
2146       CHECK_EQ(new_operands.size(), 2);
2147       clone = CreateAddDependency(new_operands[0], new_operands[1]);
2148       break;
2149     case HloOpcode::kReplicaId:
2150       CHECK_EQ(new_operands.size(), 0);
2151       clone = CreateReplicaId(shape);
2152       break;
2153     case HloOpcode::kPartitionId:
2154       CHECK_EQ(new_operands.size(), 0);
2155       clone = CreatePartitionId(shape);
2156       break;
2157   }
2158   // SetupDerivedInstruction will setup the precision_config_ field.
2159   SetupDerivedInstruction(clone.get());
2160   clone->set_parent(parent_);
2161   clone->set_outer_dimension_partitions(outer_dimension_partitions_);
2162   clone->backend_config_ = backend_config_.Clone();
2163   // The new instruction's name will be uniquified when it's added to a
2164   // computation.
2165   clone->SetAndSanitizeName(name());
2166   if (context != nullptr) {
2167     context->MapInstruction(this, clone.get());
2168     clone->ReplaceCalledComputations([&](HloComputation* callee) {
2169       return callee->parent() != context->module()
2170                  ? context->module()->DeepCloneComputation(callee, context)
2171                  : callee;
2172     });
2173   }
2174   return clone;
2175 }
2176 
DetachFromOperandsAndUsers()2177 void HloInstruction::DetachFromOperandsAndUsers() {
2178   if (cleaned_up_) {
2179     return;
2180   }
2181   cleaned_up_ = true;
2182   // Detach from operands. An instruction may be repeated as an operand. To
2183   // avoid calling RemoveUser twice on the same operand, check before remove.
2184   for (int64_t operand_num = 0; operand_num < operand_count(); ++operand_num) {
2185     HloInstruction* operand = operands_[operand_num];
2186     if (operand == nullptr) {
2187       continue;
2188     }
2189     if (operand->user_map_.find(this) != operand->user_map_.end()) {
2190       operand->RemoveUser(this);
2191     }
2192     operands_[operand_num] = nullptr;
2193   }
2194 
2195   // Update users. Set `nullptr` to the corresponding operand slot for users.
2196   for (auto& user : this->users()) {
2197     for (int i = 0; i < user->operand_count(); ++i) {
2198       if (user->operands_[i] == this) {
2199         user->operands_[i] = nullptr;
2200       }
2201     }
2202   }
2203 }
2204 
CloneWithNewShape(const Shape & shape,const std::string & suffix,HloCloneContext * context) const2205 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
2206     const Shape& shape, const std::string& suffix,
2207     HloCloneContext* context) const {
2208   std::unique_ptr<HloInstruction> clone =
2209       CloneWithNewOperands(shape, operands_, context);
2210   if (suffix.empty()) {
2211     clone->name_ = name();
2212   } else {
2213     // If an instruction is cloned multiple times avoid names like
2214     // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
2215     // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
2216     // clone of foo.suffix2 is named foo.suffix3 and so on.
2217     const std::string dot_suffix = "." + suffix;
2218     size_t index = name().rfind(dot_suffix);
2219     if (index == std::string::npos) {
2220       // Existing name does not include ".suffix".
2221       clone->name_ = name() + dot_suffix;
2222     } else {
2223       // Existing name includes ".suffix". Determine if substring after
2224       // ".suffix" is numeric and should be replaced with an incremented number.
2225       std::string after_suffix = name().substr(index + dot_suffix.size());
2226       if (after_suffix.empty()) {
2227         // Existing name ends in ".suffix". New name should end in ".suffix2".
2228         clone->name_ = name() + "2";
2229       } else {
2230         // If names ends with .suffix[0-9]+ then replace with a suffix with the
2231         // numeric value incremented.
2232         int64_t numeric_suffix;
2233         if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
2234           clone->name_ =
2235               StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
2236         } else {
2237           // Substring after ".suffix" is non-numeric.
2238           clone->name_ = name() + dot_suffix;
2239         }
2240       }
2241     }
2242   }
2243   return clone;
2244 }
2245 
Clone(const std::string & suffix,HloCloneContext * context) const2246 std::unique_ptr<HloInstruction> HloInstruction::Clone(
2247     const std::string& suffix, HloCloneContext* context) const {
2248   std::unique_ptr<HloInstruction> clone =
2249       CloneWithNewShape(shape_, suffix, context);
2250   return clone;
2251 }
2252 
2253 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const2254 HloInstruction::LatestNonGteAncestorAndIndex() const {
2255   const HloInstruction* hlo = this;
2256   ShapeIndex index;
2257   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2258     index.push_back(hlo->tuple_index());
2259     hlo = hlo->operand(0);
2260   }
2261 
2262   // We built up index in the reverse order from what we want.
2263   std::reverse(index.begin(), index.end());
2264 
2265   return {hlo, index};
2266 }
2267 
LatestNonGteAncestor() const2268 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
2269   const HloInstruction* hlo = this;
2270   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2271     hlo = hlo->operand(0);
2272   }
2273   return hlo;
2274 }
2275 
operand(int64_t i) const2276 const HloInstruction* HloInstruction::operand(int64_t i) const {
2277   return operands_.at(i);
2278 }
2279 
mutable_operand(int64_t i)2280 HloInstruction* HloInstruction::mutable_operand(int64_t i) {
2281   CHECK(operands_[i] != nullptr);
2282   return operands_.at(i);
2283 }
2284 
operand_index(const HloInstruction * target) const2285 int64_t HloInstruction::operand_index(const HloInstruction* target) const {
2286   for (int64_t i = 0; i < operand_count(); ++i) {
2287     if (target == operand(i)) {
2288       return i;
2289     }
2290   }
2291   LOG(FATAL) << "target was not an operand: " << target->ToString();
2292 }
2293 
unique_operands() const2294 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
2295   InstructionVector unique;
2296   absl::flat_hash_set<const HloInstruction*> seen;
2297   for (HloInstruction* operand : operands()) {
2298     if (seen.insert(operand).second) {
2299       unique.push_back(operand);
2300     }
2301   }
2302   return unique;
2303 }
2304 
AddControlDependencyTo(HloInstruction * instruction)2305 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
2306   TF_RET_CHECK(instruction->parent() == parent());
2307   if (!absl::c_linear_search(control_successors_, instruction)) {
2308     control_successors_.push_back(instruction);
2309     TF_RET_CHECK(
2310         !absl::c_linear_search(instruction->control_predecessors_, this));
2311     instruction->control_predecessors_.push_back(this);
2312   }
2313   return OkStatus();
2314 }
2315 
RemoveControlDependencyTo(HloInstruction * instruction)2316 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
2317   TF_RET_CHECK(instruction->parent() == parent());
2318   TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
2319   TF_RETURN_IF_ERROR(
2320       EraseElementFromVector(&instruction->control_predecessors_, this));
2321   return OkStatus();
2322 }
2323 
DropAllControlDeps()2324 Status HloInstruction::DropAllControlDeps() {
2325   for (auto* ctrl_succ : control_successors_) {
2326     TF_RETURN_IF_ERROR(
2327         EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
2328   }
2329   for (auto* ctrl_pred : control_predecessors_) {
2330     TF_RETURN_IF_ERROR(
2331         EraseElementFromVector(&ctrl_pred->control_successors_, this));
2332   }
2333   control_successors_.clear();
2334   control_predecessors_.clear();
2335   return OkStatus();
2336 }
2337 
CopyAllControlDepsFrom(const HloInstruction * inst)2338 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
2339   for (auto* ctrl_pred : inst->control_predecessors()) {
2340     TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
2341   }
2342 
2343   for (auto* ctrl_succ : inst->control_successors()) {
2344     TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
2345   }
2346 
2347   return OkStatus();
2348 }
2349 
IdenticalInternal(const HloInstruction & other,const std::function<bool (const HloInstruction *,const HloInstruction *)> & eq_operands,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations,bool layout_sensitive,bool ignore_channel_id_values,bool ignore_commutative_operand_order) const2350 bool HloInstruction::IdenticalInternal(
2351     const HloInstruction& other,
2352     const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2353         eq_operands,
2354     const std::function<bool(const HloComputation*, const HloComputation*)>&
2355         eq_computations,
2356     bool layout_sensitive, bool ignore_channel_id_values,
2357     bool ignore_commutative_operand_order) const {
2358   // An instruction is always identical to itself.
2359   if (this == &other) {
2360     return true;
2361   }
2362 
2363   // Identical instruction must have the same opcode, shape, and identical
2364   // operands.
2365   if (opcode() != other.opcode()) {
2366     return false;
2367   }
2368   if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
2369                          : ShapeUtil::Compatible(shape(), other.shape()))) {
2370     return false;
2371   }
2372   if (operands().size() != other.operands().size()) {
2373     return false;
2374   }
2375 
2376   // Check that operands are equal.
2377   //
2378   // Use an explicit loop rather than ContainerEquals, because copying around
2379   // std::functions may be too expensive in some cases.
2380   if (ignore_commutative_operand_order &&
2381       HloOpcodeIsBinaryCommutative(opcode())) {
2382     CHECK_EQ(operand_count(), 2);
2383     if (!(eq_operands(operand(0), other.operand(0)) &&
2384           eq_operands(operand(1), other.operand(1))) &&
2385         !(eq_operands(operand(0), other.operand(1)) &&
2386           eq_operands(operand(1), other.operand(0)))) {
2387       return false;
2388     }
2389   } else {
2390     for (size_t i = 0; i < operands().size(); ++i) {
2391       if (!eq_operands(operand(i), other.operand(i))) {
2392         return false;
2393       }
2394     }
2395   }
2396 
2397   if (backend_config_ != other.backend_config_) {
2398     return false;
2399   }
2400 
2401   if (ignore_channel_id_values) {
2402     if (auto channel_inst = DynCast<HloChannelInstruction>(this)) {
2403       return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
2404           other, eq_computations);
2405     }
2406   }
2407   return IdenticalSlowPath(other, eq_computations);
2408 }
2409 
AppendOperand(HloInstruction * operand)2410 void HloInstruction::AppendOperand(HloInstruction* operand) {
2411   if (operand->parent() != nullptr) {
2412     DCHECK(!operand->parent()->IsMarkedAsDead(operand))
2413         << "Operand " << operand->name() << " is already marked dead";
2414   }
2415   operands_.push_back(operand);
2416   operand->AddUser(this);
2417 }
2418 
RemoveOperandsAtAscendingIndices(absl::Span<const int> ascending_indices)2419 void HloInstruction::RemoveOperandsAtAscendingIndices(
2420     absl::Span<const int> ascending_indices) {
2421   if (ascending_indices.empty()) {
2422     return;
2423   }
2424   int next_index = 0;
2425   int removed_count = 0;
2426   for (int to_remove : ascending_indices) {
2427     while (next_index < to_remove) {
2428       operands_[next_index - removed_count] = operands_[next_index];
2429       ++next_index;
2430     }
2431     CHECK_LT(to_remove, operands_.size());
2432     ++removed_count;
2433     ++next_index;
2434   }
2435   while (next_index < operands_.size()) {
2436     operands_[next_index - removed_count] = operands_[next_index];
2437     ++next_index;
2438   }
2439   CHECK_EQ(removed_count, ascending_indices.size());
2440   operands_.resize(operands_.size() - removed_count);
2441 }
2442 
AddUser(HloInstruction * user)2443 void HloInstruction::AddUser(HloInstruction* user) {
2444   if (!ContainsKey(user_map_, user)) {
2445     user_map_.emplace(user, users_.size());
2446     users_.push_back(user);
2447   }
2448 }
2449 
UserId(HloInstruction * user)2450 int64_t HloInstruction::UserId(HloInstruction* user) {
2451   auto result = user_map_.find(user);
2452   CHECK(result != user_map_.end());
2453   return result->second;
2454 }
2455 
HasConstantOperand() const2456 bool HloInstruction::HasConstantOperand() const {
2457   for (const HloInstruction* operand : operands_) {
2458     if (operand->IsConstant()) {
2459       return true;
2460     }
2461   }
2462   return false;
2463 }
2464 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2465 bool HloInstruction::IdenticalSlowPath(
2466     const HloInstruction& other,
2467     const std::function<bool(const HloComputation*, const HloComputation*)>&
2468         eq_computations) const {
2469   // Perform opcode specific checks.
2470   switch (opcode()) {
2471     // The result of these instructions only depend upon their opcode and
2472     // operands.
2473     case HloOpcode::kAbs:
2474     case HloOpcode::kAllGatherDone:
2475     case HloOpcode::kAllReduceDone:
2476     case HloOpcode::kAtan2:
2477     case HloOpcode::kAdd:
2478     case HloOpcode::kBitcast:
2479     case HloOpcode::kBitcastConvert:
2480     case HloOpcode::kCeil:
2481     case HloOpcode::kClamp:
2482     case HloOpcode::kClz:
2483     case HloOpcode::kCollectivePermuteDone:
2484     case HloOpcode::kComplex:
2485     case HloOpcode::kConvert:
2486     case HloOpcode::kCopy:
2487     case HloOpcode::kCopyStart:
2488     case HloOpcode::kCopyDone:
2489     case HloOpcode::kCos:
2490     case HloOpcode::kDivide:
2491     case HloOpcode::kDynamicUpdateSlice:
2492     case HloOpcode::kExp:
2493     case HloOpcode::kExpm1:
2494     case HloOpcode::kFloor:
2495     case HloOpcode::kImag:
2496     case HloOpcode::kIsFinite:
2497     case HloOpcode::kLog:
2498     case HloOpcode::kLog1p:
2499     case HloOpcode::kAnd:
2500     case HloOpcode::kNot:
2501     case HloOpcode::kOr:
2502     case HloOpcode::kXor:
2503     case HloOpcode::kMaximum:
2504     case HloOpcode::kMinimum:
2505     case HloOpcode::kMultiply:
2506     case HloOpcode::kNegate:
2507     case HloOpcode::kOptimizationBarrier:
2508     case HloOpcode::kPartitionId:
2509     case HloOpcode::kPopulationCount:
2510     case HloOpcode::kPower:
2511     case HloOpcode::kReal:
2512     case HloOpcode::kRemainder:
2513     case HloOpcode::kReshape:
2514     case HloOpcode::kDynamicReshape:
2515     case HloOpcode::kReplicaId:
2516     case HloOpcode::kRoundNearestAfz:
2517     case HloOpcode::kRoundNearestEven:
2518     case HloOpcode::kRsqrt:
2519     case HloOpcode::kSelect:
2520     case HloOpcode::kShiftLeft:
2521     case HloOpcode::kShiftRightArithmetic:
2522     case HloOpcode::kShiftRightLogical:
2523     case HloOpcode::kLogistic:
2524     case HloOpcode::kSign:
2525     case HloOpcode::kSin:
2526     case HloOpcode::kSqrt:
2527     case HloOpcode::kCbrt:
2528     case HloOpcode::kSubtract:
2529     case HloOpcode::kTanh:
2530     case HloOpcode::kTuple:
2531       return true;
2532 
2533     // This opcode has complex or special behavior so just return false.
2534     case HloOpcode::kAfterAll:
2535     case HloOpcode::kAddDependency:
2536       return false;
2537 
2538     // Remaining instructions with special values.
2539     case HloOpcode::kCall:
2540       return eq_computations(to_apply(), other.to_apply());
2541     case HloOpcode::kConditional:
2542       for (int j = 0; j < branch_count(); ++j) {
2543         if (!eq_computations(branch_computation(j),
2544                              other.branch_computation(j))) {
2545           return false;
2546         }
2547       }
2548       return true;
2549     case HloOpcode::kWhile:
2550       return (eq_computations(while_body(), other.while_body()) &&
2551               eq_computations(while_condition(), other.while_condition()));
2552 
2553     // Ops migrated to subclasses should never come to this line.
2554     // TODO(b/80131774): Remove this switch when migration is complete.
2555     case HloOpcode::kAsyncStart:
2556     case HloOpcode::kAsyncUpdate:
2557     case HloOpcode::kAsyncDone:
2558     case HloOpcode::kBatchNormTraining:
2559     case HloOpcode::kBatchNormInference:
2560     case HloOpcode::kBatchNormGrad:
2561     case HloOpcode::kFft:
2562     case HloOpcode::kCompare:
2563     case HloOpcode::kSend:
2564     case HloOpcode::kSendDone:
2565     case HloOpcode::kRecv:
2566     case HloOpcode::kRecvDone:
2567     case HloOpcode::kReverse:
2568     case HloOpcode::kConcatenate:
2569     case HloOpcode::kReduce:
2570     case HloOpcode::kSort:
2571     case HloOpcode::kTranspose:
2572     case HloOpcode::kBroadcast:
2573     case HloOpcode::kMap:
2574     case HloOpcode::kSlice:
2575     case HloOpcode::kConstant:
2576     case HloOpcode::kIota:
2577     case HloOpcode::kFusion:
2578     case HloOpcode::kRng:
2579     case HloOpcode::kRngBitGenerator:
2580     case HloOpcode::kRngGetAndUpdateState:
2581     case HloOpcode::kParameter:
2582     case HloOpcode::kGetTupleElement:
2583     case HloOpcode::kReducePrecision:
2584     case HloOpcode::kInfeed:
2585     case HloOpcode::kOutfeed:
2586     case HloOpcode::kAllGather:
2587     case HloOpcode::kAllGatherStart:
2588     case HloOpcode::kAllReduce:
2589     case HloOpcode::kReduceScatter:
2590     case HloOpcode::kAllReduceStart:
2591     case HloOpcode::kAllToAll:
2592     case HloOpcode::kCollectivePermute:
2593     case HloOpcode::kCollectivePermuteStart:
2594     case HloOpcode::kConvolution:
2595     case HloOpcode::kCustomCall:
2596     case HloOpcode::kReduceWindow:
2597     case HloOpcode::kSelectAndScatter:
2598     case HloOpcode::kPad:
2599     case HloOpcode::kDynamicSlice:
2600     case HloOpcode::kGather:
2601     case HloOpcode::kScatter:
2602     case HloOpcode::kDot:
2603     case HloOpcode::kDomain:
2604     case HloOpcode::kGetDimensionSize:
2605     case HloOpcode::kSetDimensionSize:
2606     case HloOpcode::kTriangularSolve:
2607     case HloOpcode::kCholesky:
2608       LOG(FATAL) << "Base class impl called for opcode with subclass: "
2609                  << opcode();
2610   }
2611   return false;
2612 }
2613 
RemoveUser(HloInstruction * user)2614 void HloInstruction::RemoveUser(HloInstruction* user) {
2615   auto map_it = user_map_.find(user);
2616   CHECK(map_it != user_map_.end());
2617 
2618   const int64_t index = map_it->second;
2619   CHECK_EQ(users_[index], user);
2620 
2621   // Move the last user into the position of the removed user.
2622   users_[index] = users_.back();
2623   user_map_[users_.back()] = index;
2624 
2625   // Remove the user from the map and drop the last slot from the vector what
2626   // have been moved to the position of the original user.
2627   user_map_.erase(map_it);
2628   users_.pop_back();
2629 }
2630 
ReplaceUseWith(HloInstruction * user,HloInstruction * new_producer)2631 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
2632                                       HloInstruction* new_producer) {
2633   TF_RET_CHECK(
2634       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2635       << "this shape: " << ShapeUtil::HumanString(shape())
2636       << ", replacement shape: "
2637       << ShapeUtil::HumanString(new_producer->shape());
2638   return ReplaceUseWithDifferentShape(user, new_producer);
2639 }
2640 
ReplaceUseWithDifferentShape(HloInstruction * user,HloInstruction * new_producer)2641 Status HloInstruction::ReplaceUseWithDifferentShape(
2642     HloInstruction* user, HloInstruction* new_producer) {
2643   VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
2644           << " with " << new_producer->name();
2645 
2646   RemoveUser(user);
2647 
2648   TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
2649   std::replace(user->operands_.begin(), user->operands_.end(), this,
2650                new_producer);
2651   new_producer->AddUser(user);
2652   // Custom fusions may not be able to handle deduplicated operands.
2653   if (user->opcode() == HloOpcode::kFusion) {
2654     TF_RETURN_IF_ERROR(
2655         Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2656   }
2657   return OkStatus();
2658 }
2659 
ReplaceUseWith(HloInstruction * user,int operand_number,HloInstruction * new_producer)2660 Status HloInstruction::ReplaceUseWith(HloInstruction* user, int operand_number,
2661                                       HloInstruction* new_producer) {
2662   TF_RET_CHECK(
2663       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2664       << "this shape: " << ShapeUtil::HumanString(shape())
2665       << ", replacement shape: "
2666       << ShapeUtil::HumanString(new_producer->shape());
2667   return ReplaceUseWithDifferentShape(user, operand_number, new_producer);
2668 }
2669 
ReplaceUseWithDifferentShape(HloInstruction * user,int operand_number,HloInstruction * new_producer)2670 Status HloInstruction::ReplaceUseWithDifferentShape(
2671     HloInstruction* user, int operand_number, HloInstruction* new_producer) {
2672   VLOG(3) << "Replacing operand " << operand_number << " of " << name()
2673           << " in " << user->name() << " with " << new_producer->name();
2674 
2675   if (absl::c_count(user->operands_, this) == 1) {
2676     RemoveUser(user);
2677   }
2678 
2679   TF_RET_CHECK(user->operand(operand_number) == this)
2680       << "Expected operand " << operand_number << " of " << user->ToString()
2681       << " to be equal to " << ToString();
2682   user->operands_[operand_number] = new_producer;
2683   new_producer->AddUser(user);
2684   return OkStatus();
2685 }
2686 
ReplaceOperandWith(int64_t operand_num,HloInstruction * new_operand)2687 Status HloInstruction::ReplaceOperandWith(int64_t operand_num,
2688                                           HloInstruction* new_operand) {
2689   auto old_operand = operand(operand_num);
2690   TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
2691                                                         new_operand->shape()))
2692       << old_operand->shape() << " is not compatible with "
2693       << new_operand->shape();
2694   return ReplaceOperandWithDifferentShape(operand_num, new_operand);
2695 }
2696 
ReplaceOperandWithDifferentShape(int64_t operand_num,HloInstruction * new_operand)2697 Status HloInstruction::ReplaceOperandWithDifferentShape(
2698     int64_t operand_num, HloInstruction* new_operand) {
2699   TF_RET_CHECK(operand_num >= 0);
2700   TF_RET_CHECK(operand_num < operand_count());
2701   HloInstruction* old_operand = mutable_operand(operand_num);
2702   if (old_operand == new_operand) {
2703     return OkStatus();
2704   }
2705 
2706   operands_[operand_num] = new_operand;
2707 
2708   VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
2709           << new_operand->name() << ", was " << old_operand->name();
2710 
2711   if (!absl::c_linear_search(operands_, old_operand)) {
2712     old_operand->RemoveUser(this);
2713   }
2714   new_operand->AddUser(this);
2715   return OkStatus();
2716 }
2717 
ReplaceUsesWith(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2718 Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
2719                                        HloInstruction* new_producer) {
2720   TF_RET_CHECK(
2721       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2722       << shape() << " is not compatible with " << new_producer->shape();
2723   return ReplaceAllUsesWithDifferentShape(users, new_producer);
2724 }
2725 
ReplaceAllUsesWithDifferentShape(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2726 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2727     absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
2728   for (HloInstruction* user : users) {
2729     TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
2730   }
2731 
2732   if (parent_ && parent_->root_instruction() == this) {
2733     parent_->set_root_instruction(new_producer,
2734                                   /*accept_different_shape=*/true);
2735   }
2736   return OkStatus();
2737 }
2738 
ReplaceAllUsesWith(HloInstruction * new_producer)2739 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
2740   TF_RET_CHECK(
2741       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2742       << shape() << " is not compatible with " << new_producer->shape();
2743   return ReplaceAllUsesWithDifferentShape(new_producer);
2744 }
2745 
ReplaceAllUsesWithDifferentShape(HloInstruction * new_producer)2746 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2747     HloInstruction* new_producer) {
2748   bool new_producer_is_user = false;
2749   for (HloInstruction* user : users()) {
2750     if (user == new_producer) {
2751       // It's possible that new_producer is a user of this instruction as might
2752       // be the case when replacing an instruction with a kCopy of itself. In
2753       // this case, don't do the replacement to avoid creating a cycle in the
2754       // graph. new_producer remains the only user of this instruction.
2755       new_producer_is_user = true;
2756     } else {
2757       std::replace(user->operands_.begin(), user->operands_.end(), this,
2758                    new_producer);
2759       new_producer->AddUser(user);
2760       if (user->opcode() == HloOpcode::kFusion) {
2761         TF_RETURN_IF_ERROR(
2762             Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2763       }
2764     }
2765   }
2766   users_.clear();
2767   user_map_.clear();
2768   if (new_producer_is_user) {
2769     AddUser(new_producer);
2770   }
2771   if (parent_ && parent_->root_instruction() == this) {
2772     parent_->set_root_instruction(new_producer,
2773                                   /*accept_different_shape=*/true);
2774   }
2775 
2776   return OkStatus();
2777 }
2778 
IsEffectiveBitcast() const2779 bool HloInstruction::IsEffectiveBitcast() const {
2780   return opcode_ == HloOpcode::kBitcast ||
2781          (opcode_ == HloOpcode::kTranspose &&
2782           ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
2783                                         dimensions()));
2784 }
2785 
to_apply() const2786 HloComputation* HloInstruction::to_apply() const {
2787   switch (opcode_) {
2788     case HloOpcode::kCall:
2789     case HloOpcode::kMap:
2790     case HloOpcode::kReduceWindow:
2791     case HloOpcode::kReduce:
2792     case HloOpcode::kAllReduce:
2793     case HloOpcode::kReduceScatter:
2794     case HloOpcode::kAllReduceStart:
2795     case HloOpcode::kScatter:
2796     case HloOpcode::kSort:
2797     case HloOpcode::kCustomCall:
2798       CHECK_EQ(called_computations_.size(), 1);
2799       return called_computations_[0];
2800     default:
2801       LOG(FATAL) << "Invalid opcode for to_apply(): "
2802                  << HloOpcodeString(opcode());
2803   }
2804 }
2805 
set_to_apply(HloComputation * computation)2806 void HloInstruction::set_to_apply(HloComputation* computation) {
2807   // Don't allow changing the computation for fused instructions so we don't
2808   // have to recompute called_instructions for the entire fusion instruction.
2809   CHECK(!IsFused());
2810   switch (opcode_) {
2811     case HloOpcode::kCall:
2812     case HloOpcode::kMap:
2813     case HloOpcode::kReduceWindow:
2814     case HloOpcode::kReduce:
2815     case HloOpcode::kAllReduce:
2816     case HloOpcode::kAllReduceStart:
2817     case HloOpcode::kScatter:
2818     case HloOpcode::kSort:
2819     case HloOpcode::kCustomCall:
2820       CHECK_EQ(called_computations_.size(), 1);
2821       called_computations_[0] = computation;
2822       break;
2823     default:
2824       LOG(FATAL) << "Invalid opcode for to_apply(): "
2825                  << HloOpcodeString(opcode());
2826   }
2827 }
2828 
while_condition() const2829 HloComputation* HloInstruction::while_condition() const {
2830   CHECK_EQ(HloOpcode::kWhile, opcode_);
2831   return called_computations_[kConditionComputationIndex];
2832 }
2833 
while_body() const2834 HloComputation* HloInstruction::while_body() const {
2835   CHECK_EQ(HloOpcode::kWhile, opcode_);
2836   return called_computations_[kBodyComputationIndex];
2837 }
2838 
set_while_condition(HloComputation * computation)2839 void HloInstruction::set_while_condition(HloComputation* computation) {
2840   // Don't allow changing the computation for fused instructions so we don't
2841   // have to recompute called_instructions for the entire fusion instruction.
2842   CHECK(!IsFused());
2843   CHECK_EQ(HloOpcode::kWhile, opcode_);
2844   called_computations_[kConditionComputationIndex] = computation;
2845 }
2846 
set_while_body(HloComputation * computation)2847 void HloInstruction::set_while_body(HloComputation* computation) {
2848   // Don't allow changing the computation for fused instructions so we don't
2849   // have to recompute called_instructions for the entire fusion instruction.
2850   CHECK(!IsFused());
2851   CHECK_EQ(HloOpcode::kWhile, opcode_);
2852   called_computations_[kBodyComputationIndex] = computation;
2853 }
2854 
while_init() const2855 HloInstruction* HloInstruction::while_init() const {
2856   CHECK_EQ(HloOpcode::kWhile, opcode_);
2857   return operands_[0];
2858 }
2859 
true_computation() const2860 HloComputation* HloInstruction::true_computation() const {
2861   CHECK_EQ(HloOpcode::kConditional, opcode_);
2862   CHECK_EQ(PRED, operand(0)->shape().element_type());
2863   return called_computations_[kTrueComputationIndex];
2864 }
2865 
false_computation() const2866 HloComputation* HloInstruction::false_computation() const {
2867   CHECK_EQ(HloOpcode::kConditional, opcode_);
2868   CHECK_EQ(PRED, operand(0)->shape().element_type());
2869   return called_computations_[kFalseComputationIndex];
2870 }
2871 
branch_computations() const2872 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2873     const {
2874   CHECK(HloOpcode::kConditional == opcode_);
2875   return called_computations_;
2876 }
2877 
branch_count() const2878 int HloInstruction::branch_count() const {
2879   CHECK(HloOpcode::kConditional == opcode_);
2880   return called_computations_.size();
2881 }
2882 
branch_computation(int b) const2883 HloComputation* HloInstruction::branch_computation(int b) const {
2884   CHECK(HloOpcode::kConditional == opcode_);
2885   CHECK_GE(b, 0);
2886   CHECK_LT(b, called_computations_.size());
2887   return called_computations_[b];
2888 }
2889 
set_branch_computation(int b,HloComputation * computation)2890 void HloInstruction::set_branch_computation(int b,
2891                                             HloComputation* computation) {
2892   // Don't allow changing the computation for fused instructions so we don't
2893   // have to recompute called_instructions for the entire fusion instruction.
2894   CHECK(!IsFused());
2895   CHECK_EQ(HloOpcode::kConditional, opcode_);
2896   called_computations_[b] = computation;
2897 }
2898 
SignatureString() const2899 std::string HloInstruction::SignatureString() const {
2900   std::string operands =
2901       StrJoin(operands_, ", ", [](std::string* out, HloInstruction* operand) {
2902         StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2903       });
2904   return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2905 }
2906 
// Returns `name` unchanged when `print_ids` is true. Otherwise returns the part
// of `name` that precedes its first '.', or all of `name` if it contains no
// '.'.
std::string PrintName(const std::string& name, bool print_ids) {
  if (print_ids) {
    return name;
  }
  // find() yields npos when there is no '.', which makes substr() return the
  // whole string.
  return name.substr(0, name.find('.'));
}
2915 
2916 namespace {
2917 
2918 using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2919 
PrintNameInternal(const std::string & name,const HloPrintOptions & options)2920 std::string PrintNameInternal(const std::string& name,
2921                               const HloPrintOptions& options) {
2922   return StrCat(options.print_percent() ? "%" : "",
2923                 PrintName(name, options.print_ids()));
2924 }
2925 
PrintCycle(const HloInstruction * child,DFSStack * dfs_stack)2926 void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
2927   // This set contains HloInstructions from the top of `DFSStack` that might
2928   // belong to the cycle, i.e. if  DFSStack :=[back,...,child,...,top], then
2929   // `subgraph` := {child,...,top}.
2930   absl::flat_hash_set<const HloInstruction*> subgraph;
2931   while (!dfs_stack->empty() && dfs_stack->back().second != child) {
2932     subgraph.insert(dfs_stack->back().second);
2933     dfs_stack->pop_back();
2934   }
2935   // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
2936   absl::flat_hash_set<const HloInstruction*> visited;
2937   absl::InlinedVector<const HloInstruction*, 16> dfs;
2938   dfs.push_back(child);
2939   while (!dfs.empty()) {
2940     bool found_next_instr = false;
2941     for (const auto& user : dfs.back()->users()) {
2942       if (user == child) {
2943         dfs.push_back(child);
2944         LOG(INFO) << "\n\nDirected cycle:\n  "
2945                   << absl::StrJoin(
2946                          dfs, "\n  ",
2947                          [](std::string* out, const HloInstruction* instr) {
2948                            out->append(instr->name());
2949                          });
2950         return;
2951       }
2952       if (!subgraph.contains(user) || visited.contains(user)) {
2953         continue;
2954       }
2955       visited.insert(user);
2956       dfs.push_back(user);
2957       found_next_instr = true;
2958     }
2959     if (!found_next_instr) {
2960       dfs.pop_back();
2961     }
2962   }
2963 }
2964 
2965 }  // namespace
2966 
ToString(const HloPrintOptions & options) const2967 std::string HloInstruction::ToString(const HloPrintOptions& options) const {
2968   CanonicalNameMap new_map;
2969   return ToStringWithCanonicalNameMap(options, &new_map);
2970 }
2971 
IsOpElementwise(HloOpcode opcode)2972 bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
2973   switch (opcode) {
2974     // Unary elementwise operations.
2975     case HloOpcode::kAbs:
2976     case HloOpcode::kRoundNearestAfz:
2977     case HloOpcode::kRoundNearestEven:
2978     case HloOpcode::kCeil:
2979     case HloOpcode::kClz:
2980     case HloOpcode::kConvert:
2981     case HloOpcode::kBitcastConvert:
2982     case HloOpcode::kCopy:
2983     case HloOpcode::kCos:
2984     case HloOpcode::kExp:
2985     case HloOpcode::kExpm1:
2986     case HloOpcode::kFloor:
2987     case HloOpcode::kImag:
2988     case HloOpcode::kIsFinite:
2989     case HloOpcode::kLog:
2990     case HloOpcode::kLog1p:
2991     case HloOpcode::kNot:
2992     case HloOpcode::kNegate:
2993     case HloOpcode::kPopulationCount:
2994     case HloOpcode::kReal:
2995     case HloOpcode::kReducePrecision:
2996     case HloOpcode::kRsqrt:
2997     case HloOpcode::kLogistic:
2998     case HloOpcode::kSign:
2999     case HloOpcode::kSin:
3000     case HloOpcode::kSqrt:
3001     case HloOpcode::kCbrt:
3002     case HloOpcode::kTanh:
3003       return true;
3004 
3005     // Binary elementwise operations, the same as in IsElementwiseBinary().
3006     case HloOpcode::kAdd:
3007     case HloOpcode::kAtan2:
3008     case HloOpcode::kCompare:
3009     case HloOpcode::kComplex:
3010     case HloOpcode::kDivide:
3011     case HloOpcode::kMaximum:
3012     case HloOpcode::kMinimum:
3013     case HloOpcode::kMultiply:
3014     case HloOpcode::kPower:
3015     case HloOpcode::kRemainder:
3016     case HloOpcode::kSubtract:
3017     case HloOpcode::kAnd:
3018     case HloOpcode::kOr:
3019     case HloOpcode::kXor:
3020     case HloOpcode::kShiftLeft:
3021     case HloOpcode::kShiftRightArithmetic:
3022     case HloOpcode::kShiftRightLogical:
3023       return true;
3024 
3025     // Ternary elementwise operations.
3026     case HloOpcode::kSelect:
3027     case HloOpcode::kClamp:
3028       return true;
3029 
3030     default:
3031       return false;
3032   }
3033 }
3034 
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const3035 bool HloInstruction::IsElementwiseImpl(
3036     const std::optional<int64_t>& operand_idx) const {
3037   if (opcode_ == HloOpcode::kDynamicUpdateSlice) {
3038     return operand_idx.has_value() && operand_idx.value() == 0;
3039   }
3040   if (opcode_ == HloOpcode::kBitcastConvert &&
3041       primitive_util::BitWidth(shape_.element_type()) !=
3042           primitive_util::BitWidth(operands_[0]->shape().element_type())) {
3043     return false;
3044   }
3045   return IsOpElementwise(opcode_);
3046 }
3047 
IsCrossModuleAllReduce() const3048 bool HloInstruction::IsCrossModuleAllReduce() const {
3049   return opcode() == HloOpcode::kAllReduce && channel_id();
3050 }
3051 
IsCrossReplicaAllReduce() const3052 bool HloInstruction::IsCrossReplicaAllReduce() const {
3053   return opcode() == HloOpcode::kAllReduce && !channel_id();
3054 }
3055 
ToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const3056 std::string HloInstruction::ToStringWithCanonicalNameMap(
3057     const HloPrintOptions& options,
3058     CanonicalNameMap* canonical_name_map) const {
3059   std::string result = "";
3060 
3061   // Logic to print the instruction name (e.g. "%foo = ").
3062   if (options.canonicalize_instruction_names()) {
3063     if (options.is_in_nested_computation()) {
3064       // If we are canonicalizing instruction names and this is a top-level
3065       // HloInstruction::ToString() call, don't print an instruction name.
3066       DCHECK(!options.print_percent());  // no need to call PrintNameInternal
3067       StrAppend(&result, canonical_name_map->LookupOrInsert(name()), " = ");
3068     }
3069   } else {
3070     StrAppend(&result, PrintNameInternal(name(), options), " = ");
3071   }
3072 
3073   if (options.print_result_shape()) {
3074     // Print shape.
3075     if (options.include_layout_in_shapes()) {
3076       StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
3077     } else {
3078       StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
3079     }
3080   }
3081 
3082   // Print opcode, operand(s).
3083   if (options.syntax_sugar_async_ops() && HloOpcodeIsAsync(opcode())) {
3084     std::string suffix = [&]() {
3085       switch (opcode()) {
3086         case HloOpcode::kAsyncStart:
3087           return "-start";
3088         case HloOpcode::kAsyncUpdate:
3089           return "-update";
3090         default:
3091           CHECK(opcode() == HloOpcode::kAsyncDone)
3092               << "Unexpected async opcode: " << HloOpcodeString(opcode());
3093           return "-done";
3094       }
3095     }();
3096     StrAppend(&result, HloOpcodeString(async_wrapped_opcode()), suffix);
3097   } else {
3098     StrAppend(&result, HloOpcodeString(opcode()));
3099   }
3100   StrAppend(&result, "(",
3101             OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
3102             ")");
3103 
3104   // Print additional attributes. If an instruction contains a subcomputation,
3105   // the subcomputation is also printed here.
3106   for (const std::string& extra : ExtraAttributesToString(options)) {
3107     StrAppend(&result, ", ", extra);
3108   }
3109 
3110   if (options.print_metadata() &&
3111       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
3112        !metadata_.source_file().empty())) {
3113     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
3114   }
3115   if (options.print_backend_config() && !backend_config_.empty()) {
3116     StrAppend(&result, ", backend_config=\"",
3117               CEscape(backend_config_.GetRawString()), "\"");
3118   }
3119   return result;
3120 }
3121 
OperandsToString(const HloPrintOptions & options) const3122 std::string HloInstruction::OperandsToString(
3123     const HloPrintOptions& options) const {
3124   CanonicalNameMap new_map;
3125   return OperandsToStringWithCanonicalNameMap(options, &new_map);
3126 }
3127 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const3128 std::string HloInstruction::OperandsToStringWithCanonicalNameMap(
3129     const HloPrintOptions& options,
3130     CanonicalNameMap* canonical_name_map) const {
3131   std::string operands;
3132   absl::Span<HloInstruction* const> slice(operands_);
3133   const int64_t kMaxOperandsToShowIfCompact = 4;
3134   if (options.compact_operands() &&
3135       slice.size() > kMaxOperandsToShowIfCompact) {
3136     slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
3137   }
3138   for (int64_t i = 0; i < slice.size(); ++i) {
3139     HloInstruction* operand = slice[i];
3140     if (i != 0) {
3141       StrAppend(&operands, ", ");
3142       if (options.print_operand_index_annotation_interval() != 0 &&
3143           i % options.print_operand_index_annotation_interval() == 0) {
3144         StrAppend(&operands, absl::StrFormat("/*index=%lld*/", i));
3145       }
3146     }
3147     // If operand is already been deleted, put `null` to the string output.
3148     if (operand == nullptr) {
3149       StrAppend(&operands, "null ");
3150       continue;
3151     }
3152     std::vector<std::string> str;
3153     if (options.print_operand_shape()) {
3154       if (options.include_layout_in_shapes()) {
3155         str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
3156       } else {
3157         str.push_back(ShapeUtil::HumanString(operand->shape()));
3158       }
3159     }
3160     if (options.canonicalize_instruction_names()) {
3161       if (options.is_in_nested_computation()) {
3162         // In a top-level HloInstruction::ToString() call, the operand name is
3163         // not part of the canonical string.
3164         DCHECK(!options.print_percent());  // no need to call PrintNameInternal
3165         str.push_back(canonical_name_map->LookupOrInsert(operand->name()));
3166       }
3167     } else if (options.print_operand_names()) {
3168       str.push_back(PrintNameInternal(operand->name(), options));
3169     }
3170     StrAppend(&operands, StrJoin(str, " "));
3171   }
3172   const int64_t remaining = operands_.size() - slice.size();
3173   if (slice.size() != operands_.size()) {
3174     StrAppend(&operands, ", ...(+", remaining, ")");
3175   }
3176   return operands;
3177 }
3178 
3179 namespace {
3180 
IsSequentialCall(HloOpcode opcode)3181 bool IsSequentialCall(HloOpcode opcode) {
3182   switch (opcode) {
3183     case HloOpcode::kCall:
3184     case HloOpcode::kConditional:
3185     case HloOpcode::kWhile:
3186       return true;
3187     default:
3188       return false;
3189   }
3190 }
3191 
3192 }  // namespace
3193 
ExtraAttributesToString(const HloPrintOptions & options) const3194 std::vector<std::string> HloInstruction::ExtraAttributesToString(
3195     const HloPrintOptions& options) const {
3196   std::vector<std::string> extra = options.print_extra_attributes()
3197                                        ? ExtraAttributesToStringImpl(options)
3198                                        : std::vector<std::string>();
3199 
3200   const auto subcomputation_mode = options.print_subcomputation_mode();
3201   if (subcomputation_mode ==
3202       HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
3203     if (opcode() == HloOpcode::kWhile) {
3204       extra.push_back(StrCat(
3205           "condition=", PrintNameInternal(while_condition()->name(), options)));
3206       extra.push_back(
3207           StrCat("body=", PrintNameInternal(while_body()->name(), options)));
3208     } else if (opcode() == HloOpcode::kSelectAndScatter) {
3209       extra.push_back(
3210           StrCat("select=", PrintNameInternal(select()->name(), options)));
3211       extra.push_back(
3212           StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
3213     } else if (opcode() == HloOpcode::kConditional) {
3214       if (operand(0)->shape().element_type() == PRED) {
3215         extra.push_back(
3216             StrCat("true_computation=",
3217                    PrintNameInternal(true_computation()->name(), options)));
3218         extra.push_back(
3219             StrCat("false_computation=",
3220                    PrintNameInternal(false_computation()->name(), options)));
3221       } else {
3222         extra.push_back(StrCat(
3223             "branch_computations={",
3224             StrJoin(branch_computations(), ", ",
3225                     [&](std::string* out, const HloComputation* computation) {
3226                       StrAppend(
3227                           out, PrintNameInternal(computation->name(), options));
3228                     }),
3229             "}"));
3230       }
3231     } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
3232                opcode() == HloOpcode::kReduceWindow ||
3233                opcode() == HloOpcode::kReduce ||
3234                opcode() == HloOpcode::kAllReduce ||
3235                opcode() == HloOpcode::kReduceScatter ||
3236                opcode() == HloOpcode::kAllReduceStart ||
3237                opcode() == HloOpcode::kScatter ||
3238                opcode() == HloOpcode::kSort) {
3239       if (!called_computations().empty()) {
3240         extra.push_back(StrCat("to_apply=",
3241                                PrintNameInternal(to_apply()->name(), options)));
3242       }
3243     } else if (opcode() == HloOpcode::kCustomCall) {
3244       if (!called_computations().empty()) {
3245         extra.push_back(StrCat(
3246             "called_computations={",
3247             StrJoin(called_computations(), ", ",
3248                     [&](std::string* out, const HloComputation* computation) {
3249                       StrAppend(
3250                           out, PrintNameInternal(computation->name(), options));
3251                     }),
3252             "}"));
3253       }
3254     } else if (HloOpcodeIsAsync(opcode())) {
3255       if (!options.syntax_sugar_async_ops()) {
3256         extra.push_back(StrCat(
3257             "calls=",
3258             PrintNameInternal(async_wrapped_computation()->name(), options)));
3259       }
3260     } else if (!called_computations().empty()) {
3261       extra.push_back(StrCat(
3262           "calls=",
3263           StrJoin(called_computations(), ", ",
3264                   [&](std::string* out, const HloComputation* computation) {
3265                     StrAppend(out,
3266                               PrintNameInternal(computation->name(), options));
3267                   })));
3268     }
3269   } else if ((subcomputation_mode ==
3270               HloPrintOptions::PrintSubcomputationMode::kFullBodies) ||
3271              (subcomputation_mode == HloPrintOptions::PrintSubcomputationMode::
3272                                          kNonSequentialBodies &&
3273               !IsSequentialCall(opcode()))) {
3274     HloPrintOptions new_options = options;
3275     new_options.set_is_in_nested_computation(true);
3276     switch (opcode()) {
3277       case HloOpcode::kWhile:
3278         extra.push_back(
3279             StrCat("condition=\n", while_condition()->ToString(new_options)));
3280         extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
3281         break;
3282       case HloOpcode::kSelectAndScatter:
3283         extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
3284         extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
3285         break;
3286       case HloOpcode::kConditional:
3287         if (operand(0)->shape().element_type() == PRED) {
3288           extra.push_back(StrCat("true_computation=\n",
3289                                  true_computation()->ToString(new_options)));
3290           extra.push_back(StrCat("false_computation=\n",
3291                                  false_computation()->ToString(new_options)));
3292         } else {
3293           extra.push_back(StrCat(
3294               "branch_computations={\n",
3295               StrJoin(branch_computations(), ",\n",
3296                       [&](std::string* out, const HloComputation* computation) {
3297                         StrAppend(out, computation->ToString(new_options));
3298                       }),
3299               "\n}"));
3300         }
3301         break;
3302       case HloOpcode::kCall:
3303       case HloOpcode::kMap:
3304       case HloOpcode::kReduceWindow:
3305       case HloOpcode::kReduce:
3306       case HloOpcode::kAllReduce:
3307       case HloOpcode::kAllReduceStart:
3308       case HloOpcode::kScatter:
3309       case HloOpcode::kSort:
3310         if (!called_computations().empty()) {
3311           extra.push_back(
3312               StrCat("to_apply=\n", to_apply()->ToString(new_options)));
3313         }
3314         break;
3315       default:
3316         if (!called_computations().empty()) {
3317           extra.push_back(StrCat(
3318               "calls=\n",
3319               StrJoin(called_computations(), ", ",
3320                       [&](std::string* out, const HloComputation* computation) {
3321                         StrAppend(out, computation->ToString(new_options));
3322                       })));
3323         }
3324         break;
3325     }
3326   }
3327 
3328   if (has_sharding()) {
3329     extra.push_back(
3330         StrCat("sharding=", sharding().ToString(options.print_metadata())));
3331   }
3332   if (!frontend_attributes_.map().empty()) {
3333     extra.push_back(StrCat("frontend_attributes=",
3334                            FrontendAttributesToString(frontend_attributes_)));
3335   }
3336   if (!outer_dimension_partitions_.empty()) {
3337     extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
3338                                     StrJoin(outer_dimension_partitions_, ",")));
3339   }
3340 
3341   if (options.print_control_dependencies() && !control_predecessors_.empty()) {
3342     extra.push_back(StrCat("control-predecessors={",
3343                            StrJoin(control_predecessors_, ", ",
3344                                    [&](std::string* out, HloInstruction* pre) {
3345                                      StrAppend(out, PrintNameInternal(
3346                                                         pre->name(), options));
3347                                    }),
3348                            "}"));
3349   }
3350 
3351   return extra;
3352 }
3353 
ToShortString() const3354 std::string HloInstruction::ToShortString() const {
3355   return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
3356                 StrJoin(operands_, ", ",
3357                         [](std::string* out, HloInstruction* operand) {
3358                           StrAppend(out, "%", operand->name());
3359                         }),
3360                 ")");
3361 }
3362 
ToProto() const3363 HloInstructionProto HloInstruction::ToProto() const {
3364   HloInstructionProto proto;
3365   CHECK(unique_id_ != -1)
3366       << "This instruction does not have a valid id. Please make sure the "
3367          "instruction is inside a module before dumping it.";
3368   proto.set_id(unique_id_);
3369   proto.set_name(name_);
3370   proto.set_opcode(HloOpcodeString(opcode_));
3371   *proto.mutable_shape() = shape_.ToProto();
3372   for (const HloInstruction* operand : operands_) {
3373     proto.add_operand_ids(operand->unique_id());
3374   }
3375   for (const HloInstruction* control : control_predecessors_) {
3376     proto.add_control_predecessor_ids(control->unique_id());
3377   }
3378 
3379   *proto.mutable_metadata() = metadata_;
3380   proto.set_backend_config(backend_config_.GetRawString());
3381   if (opcode() != HloOpcode::kFusion) {
3382     for (const HloComputation* computation : called_computations_) {
3383       proto.add_called_computation_ids(computation->unique_id());
3384     }
3385   }
3386 
3387   if (has_sharding()) {
3388     *proto.mutable_sharding() = sharding().ToProto();
3389   }
3390   if (!outer_dimension_partitions_.empty()) {
3391     for (const auto& idx : outer_dimension_partitions_) {
3392       proto.mutable_outer_dimension_partitions()->Add(idx);
3393     }
3394   }
3395 
3396   *proto.mutable_frontend_attributes() = frontend_attributes_;
3397 
3398   return proto;
3399 }
3400 
ToCategory() const3401 std::string HloInstruction::ToCategory() const {
3402   if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
3403       opcode() == HloOpcode::kReshape ||
3404       opcode() == HloOpcode::kDynamicReshape) {
3405     return "data formatting";
3406   }
3407 
3408   if (IsElementwise()) {
3409     return "non-fusion elementwise";
3410   }
3411 
3412   return HloOpcodeString(opcode());
3413 }
3414 
IsFused() const3415 bool HloInstruction::IsFused() const {
3416   return parent_ != nullptr && parent_->IsFusionComputation();
3417 }
3418 
IsCustomCall(absl::string_view target) const3419 bool HloInstruction::IsCustomCall(absl::string_view target) const {
3420   return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
3421 }
3422 
IsInputFusion() const3423 bool HloInstruction::IsInputFusion() const {
3424   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
3425 }
3426 
IsLoopFusion() const3427 bool HloInstruction::IsLoopFusion() const {
3428   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
3429 }
3430 
IsOutputFusion() const3431 bool HloInstruction::IsOutputFusion() const {
3432   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
3433 }
3434 
IsCustomFusion() const3435 bool HloInstruction::IsCustomFusion() const {
3436   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
3437 }
3438 
IsFusible() const3439 bool HloInstruction::IsFusible() const {
3440   // Some kinds of instructions don't make sense to fuse.
3441   switch (opcode_) {
3442     case HloOpcode::kDomain:
3443     case HloOpcode::kParameter:
3444     case HloOpcode::kWhile:
3445     case HloOpcode::kConditional:
3446     case HloOpcode::kCall:
3447       return false;
3448     // Fusions are always fusible.
3449     case HloOpcode::kFusion:
3450     // Side effecting reduce and reduce window would be invalid HLO.
3451     case HloOpcode::kMap:
3452     case HloOpcode::kReduce:
3453     case HloOpcode::kReduceWindow:
3454       return true;
3455     case HloOpcode::kRng:
3456       return user_count() <= 1;
3457     // Side effecting instructions cannot be fused.
3458     default:
3459       return !HasSideEffect();
3460   }
3461 }
3462 
// Constructs an instruction with the given opcode and shape.
//
// The new instruction starts with unique_id_ == -1 (the value ToProto()
// rejects as "no valid id"), is named after the opcode string, and is not
// marked as dead.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)),
      marked_as_dead_(false) {
  // Sanity-check the shape; a layout is allowed but not required.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
3471 
// Dispatches this instruction to the visitor handler that corresponds to its
// opcode (e.g. kAdd -> HandleAdd) and returns that handler's status.
//
// The switch has no default case; if no case matches, control falls through
// to the InternalError at the bottom of the function.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    // Note: kRoundNearestAfz maps to HandleRound, not to a handler named after
    // the opcode.
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kRoundNearestEven:
      return visitor->HandleRoundNearestEven(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kLogistic:
      return visitor->HandleLogistic(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllGather:
      return visitor->HandleAllGather(this);
    case HloOpcode::kAllGatherStart:
      return visitor->HandleAllGatherStart(this);
    case HloOpcode::kAllGatherDone:
      return visitor->HandleAllGatherDone(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kReduceScatter:
      return visitor->HandleReduceScatter(this);
    case HloOpcode::kAllReduceStart:
      return visitor->HandleAllReduceStart(this);
    case HloOpcode::kAllReduceDone:
      return visitor->HandleAllReduceDone(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kCollectivePermuteStart:
      return visitor->HandleCollectivePermuteStart(this);
    case HloOpcode::kCollectivePermuteDone:
      return visitor->HandleCollectivePermuteDone(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kPartitionId:
      return visitor->HandlePartitionId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kCbrt:
      return visitor->HandleCbrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kPopulationCount:
      return visitor->HandlePopulationCount(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kDynamicReshape:
      return visitor->HandleDynamicReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kRngBitGenerator:
      return visitor->HandleRngBitGenerator(this);
    case HloOpcode::kRngGetAndUpdateState:
      return visitor->HandleRngGetAndUpdateState(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kAsyncStart:
      return visitor->HandleAsyncStart(this);
    case HloOpcode::kAsyncUpdate:
      return visitor->HandleAsyncUpdate(this);
    case HloOpcode::kAsyncDone:
      return visitor->HandleAsyncDone(this);
    case HloOpcode::kCopyStart:
      return visitor->HandleCopyStart(this);
    case HloOpcode::kCopyDone:
      return visitor->HandleCopyDone(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kSetDimensionSize:
      return visitor->HandleSetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);
    case HloOpcode::kOptimizationBarrier:
      return visitor->HandleOptimizationBarrier(this);
  }
  // Reached only when opcode_ matched none of the cases above.
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}

// Explicit instantiations for the mutable and const visitor types.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
3717 
3718 // Push "child" onto the dfs_stack if not already visited.  Returns false if a
3719 // cycle was detected, and true otherwise.
3720 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)3721 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
3722                          HloInstruction* child) {
3723   CHECK(child != nullptr);
3724   const int id = child->unique_id();
3725   CHECK_GE(id, 0) << "instruction may not have a parent computation";
3726   switch (visitor->GetVisitState(id)) {
3727     case Visitor::kVisiting:
3728       return false;
3729 
3730     case Visitor::kVisited:
3731       // Nothing to do
3732       return true;
3733 
3734     case Visitor::kNotVisited:
3735       dfs_stack->push_back(std::make_pair(id, child));
3736       return true;
3737   }
3738 }
3739 
// Comparator over DFS stack entries (<unique id, instruction> pairs). Used by
// PostOrderDFS to order the children pushed for one instruction.
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;

// Runs an iterative post-order depth-first traversal starting at `root`.
//
// Each instruction reached is handled in two steps. The first time it is seen
// on top of the stack (state kNotVisited) it is marked kVisiting, left on the
// stack, and its operands -- plus its control predecessors unless
// `ignore_control_predecessors` is true -- are pushed on top of it. When it
// surfaces again (state kVisiting) it is popped and the visitor's
// Preprocess(), the instruction's Visit() and the visitor's Postprocess() are
// called in that order, with the state set to kVisited between Visit() and
// Postprocess().
//
// If `operand_order` is non-null, the children pushed for one instruction are
// sorted with it before being processed.
//
// Returns the first non-OK status from the visitor, FailedPrecondition if a
// cycle is detected, and OkStatus() otherwise.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  visitor->ReserveVisitStates(root->parent()->instruction_count());

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    // Already fully processed (e.g. pushed more than once as a shared
    // operand): drop the entry.
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO (id = " << current_id
              << ") as it was already visited.";
      continue;
    }

    // Second encounter: all children pushed earlier have been handled, so the
    // node itself can be visited now.
    if (visit_state == Visitor::kVisiting) {
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First encounter: mark as in progress and push the children on top of
    // this entry (the entry itself stays on the stack).
    visitor->SetVisitState(current_id, Visitor::kVisiting);

    // Remember where this node's children start so that only they are
    // sorted/reversed below.
    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!ABSL_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        PrintCycle(child, &dfs_stack);
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!ABSL_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          PrintCycle(child, &dfs_stack);
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString());
        }
      }
    }

    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return OkStatus();
}
3820 
3821 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)3822 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
3823                               bool call_finish_visit,
3824                               bool ignore_control_predecessors) {
3825   VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
3826   TF_RETURN_IF_ERROR(
3827       PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
3828   if (call_finish_visit) {
3829     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3830   }
3831   return OkStatus();
3832 }
3833 
3834 // Explicit instantiations.
3835 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
3836 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3837 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)3838 Status HloInstruction::AcceptWithOperandOrder(
3839     DfsHloVisitor* visitor, const CompareFunction& operand_order,
3840     bool call_finish_visit) {
3841   VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
3842   InternalCompareFunction func = [&operand_order](
3843                                      std::pair<int, const HloInstruction*> a,
3844                                      std::pair<int, const HloInstruction*> b) {
3845     // Call the client's comparison function on the actual HloInstruction*
3846     // objects (ignoring the internal ids we also have in our stack entries)
3847     return operand_order(a.second, b.second);
3848   };
3849   TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
3850                                   /*ignore_control_predecessors=*/false));
3851   if (call_finish_visit) {
3852     VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
3853     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3854     VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
3855   }
3856   VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
3857   return OkStatus();
3858 }
3859 
shape() const3860 const Shape& HloInstruction::shape() const { return shape_; }
3861 
OperandIndices(const HloInstruction * operand) const3862 absl::InlinedVector<int64_t, 4> HloInstruction::OperandIndices(
3863     const HloInstruction* operand) const {
3864   absl::InlinedVector<int64_t, 4> result;
3865   for (int64_t i = 0; i < operand_count(); ++i) {
3866     if (this->operand(i) == operand) {
3867       result.push_back(i);
3868     }
3869   }
3870   return result;
3871 }
3872 
IsElementwiseBinary() const3873 bool HloInstruction::IsElementwiseBinary() const {
3874   return IsElementwise() && operand_count() == 2;
3875 }
3876 
IsElementwise() const3877 bool HloInstruction::IsElementwise() const {
3878   return IsElementwiseImpl(std::nullopt);
3879 }
3880 
IsElementwiseOnOperand(int64_t operand_idx) const3881 bool HloInstruction::IsElementwiseOnOperand(int64_t operand_idx) const {
3882   return IsElementwiseImpl(operand_idx);
3883 }
3884 
3885 namespace {
3886 
3887 // Indicates how an instruction uses a value (such as an operand).
3888 //
3889 // Does it (a) not use it, (b) use it, or (c) use it multiple times?
3890 enum class UseKind { kReuse = 0, kUse = 1, kNoUse = 2 };
3891 
3892 // A helper class for memoized, recursive computation of HloOpcode::kFusion
3893 // in HloInstruction::OperandElementUse below.
3894 class FusionReusesParamElements {
3895  public:
Compute(int64_t i,const HloInstruction & hlo)3896   static UseKind Compute(int64_t i, const HloInstruction& hlo) {
3897     absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
3898     return ComputeInternal(i, hlo, &memoization_cache);
3899   }
3900 
3901  private:
3902   static UseKind ComputeInternal(
3903       int64_t outer_param_num, const HloInstruction& hlo,
3904       absl::flat_hash_map<const HloInstruction*, UseKind>* cache);
3905 };
3906 
3907 }  // namespace
3908 
3909 // Returns how this instruction uses elements of its operand at operand_num.
OperandElementUse(const HloInstruction & instr,int64_t operand_num)3910 static UseKind OperandElementUse(const HloInstruction& instr,
3911                                  int64_t operand_num) {
3912   switch (instr.opcode()) {
3913     case HloOpcode::kBitcast:
3914     case HloOpcode::kConcatenate:
3915     case HloOpcode::kReshape:
3916     case HloOpcode::kReverse:
3917     case HloOpcode::kSlice:
3918     case HloOpcode::kTranspose:
3919     case HloOpcode::kGather:
3920       return UseKind::kUse;
3921     case HloOpcode::kPad:
3922       // Pad reuses the padding value but not the padded array elements.
3923       return operand_num > 0 ? UseKind::kReuse : UseKind::kUse;
3924     case HloOpcode::kReduce:
3925       // Reduce reuses the init values but not the operand array elements.
3926       return operand_num >= Cast<HloReduceInstruction>(&instr)->input_count()
3927                  ? UseKind::kReuse
3928                  : UseKind::kUse;
3929     case HloOpcode::kFusion:
3930       // Uses the memoizing, recursive computation defined above.
3931       return FusionReusesParamElements::Compute(operand_num,
3932                                                 *instr.fused_expression_root());
3933     case HloOpcode::kDot:
3934       // Matrix-vector dots do not reuse the matrix operand.
3935       if (instr.shape().dimensions_size() <= 1) {
3936         if ((operand_num == 0 && instr.operand(1)->shape().rank() <= 1) ||
3937             (operand_num == 1 && instr.operand(0)->shape().rank() <= 1)) {
3938           return UseKind::kUse;
3939         }
3940       }
3941       return UseKind::kReuse;
3942     case HloOpcode::kDynamicUpdateSlice:
3943       // Dynamic-update-slice reuses only start_indices.
3944       if (operand_num == 0 || operand_num == 1) {
3945         return UseKind::kUse;
3946       }
3947       return UseKind::kReuse;
3948     default:
3949       return instr.IsElementwise() ? UseKind::kUse : UseKind::kReuse;
3950   }
3951 }
3952 
ComputeInternal(int64_t outer_param_num,const HloInstruction & hlo,absl::flat_hash_map<const HloInstruction *,UseKind> * cache)3953 UseKind FusionReusesParamElements::ComputeInternal(
3954     int64_t outer_param_num, const HloInstruction& hlo,
3955     absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
3956   if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
3957     if (hlo_param->parameter_number() == outer_param_num) {
3958       return UseKind::kUse;
3959     }
3960   }
3961 
3962   auto p = cache->emplace(&hlo, UseKind::kNoUse);
3963   auto value_it = p.first;
3964   const bool key_is_new = p.second;
3965 
3966   if (!key_is_new) {
3967     return value_it->second;
3968   }
3969 
3970   // Our dataflow graph has no loops, so we don't need the fixed point
3971   // computation.
3972   for (int64_t operand_num = 0; operand_num < hlo.operands().size();
3973        ++operand_num) {
3974     UseKind old_val = value_it->second;
3975 
3976     // Compute updated value.
3977     UseKind new_val = [&] {
3978       // How does the HLO use this operand.
3979       UseKind hlo_use = OperandElementUse(hlo, operand_num);
3980 
3981       // If the HLO does not use the outer operand, return previous value.
3982       if (hlo_use == UseKind::kNoUse) {
3983         return old_val;
3984       }
3985 
3986       UseKind operand_use =
3987           ComputeInternal(outer_param_num, *hlo.operand(operand_num), cache);
3988 
3989       // If the operand does not use the outer operand, return the previous
3990       // value.
3991       if (operand_use == UseKind::kNoUse) {
3992         return old_val;
3993       }
3994 
3995       // Meet operator on a lattice:
3996       //
3997       //   kReuse < kUse < kNoUse.
3998       return std::min({old_val, hlo_use, operand_use});
3999     }();
4000 
4001     value_it = cache->find(&hlo);
4002     value_it->second = new_val;
4003     // Fold() minimizes the UseKind value. If it is already minimum, we do not
4004     // have to check all the remaining operands.
4005     if (new_val == UseKind::kReuse) {
4006       break;
4007     }
4008   }
4009   return value_it->second;
4010 }
4011 
ReusesOperandElements(int64_t i) const4012 bool HloInstruction::ReusesOperandElements(int64_t i) const {
4013   return OperandElementUse(*this, i) == UseKind::kReuse;
4014 }
4015 
4016 std::optional<ShapeUtil::ShapeEqualityDescriptor>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const4017 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
4018   if (HloOpcode::kReshape != opcode_) {
4019     return std::nullopt;
4020   }
4021   return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
4022                                                       shape_);
4023 }
4024 
ToString(HloInstruction::FusionKind kind)4025 std::string ToString(HloInstruction::FusionKind kind) {
4026   switch (kind) {
4027     case HloInstruction::FusionKind::kLoop:
4028       return "kLoop";
4029     case HloInstruction::FusionKind::kInput:
4030       return "kInput";
4031     case HloInstruction::FusionKind::kOutput:
4032       return "kOutput";
4033     case HloInstruction::FusionKind::kCustom:
4034       return "kCustom";
4035   }
4036 }
4037 
StringToFusionKind(const std::string & kind_name)4038 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
4039     const std::string& kind_name) {
4040   if (kind_name == "kLoop") {
4041     return HloInstruction::FusionKind::kLoop;
4042   }
4043   if (kind_name == "kInput") {
4044     return HloInstruction::FusionKind::kInput;
4045   }
4046   if (kind_name == "kOutput") {
4047     return HloInstruction::FusionKind::kOutput;
4048   }
4049   if (kind_name == "kCustom") {
4050     return HloInstruction::FusionKind::kCustom;
4051   }
4052   return InvalidArgument("Unknown fusion kind: %s", kind_name);
4053 }
4054 
FrontendAttributesToString(const FrontendAttributes & frontend_attributes)4055 std::string FrontendAttributesToString(
4056     const FrontendAttributes& frontend_attributes) {
4057   std::vector<std::pair<std::string, std::string>> sorted_attributes(
4058       frontend_attributes.map().begin(), frontend_attributes.map().end());
4059   absl::c_sort(sorted_attributes);
4060   // Frontend attribute is a comma-separated list of attribute="value" pairs,
4061   // e.g., frontend_attributes={name="value_a",type="int32_t"}.
4062   const auto formatter = [](std::string* out,
4063                             const std::pair<std::string, std::string>& item) {
4064     absl::StrAppend(out, item.first, "=\"", item.second, "\"");
4065   };
4066   return absl::StrFormat("{%s}",
4067                          absl::StrJoin(sorted_attributes, ",", formatter));
4068 }
4069 
PaddingConfigToString(const PaddingConfig & padding)4070 std::string PaddingConfigToString(const PaddingConfig& padding) {
4071   bool has_interior_padding =
4072       absl::c_any_of(padding.dimensions(),
4073                      [](const PaddingConfig::PaddingConfigDimension& dim) {
4074                        return dim.interior_padding() != 0;
4075                      });
4076   return StrJoin(
4077       padding.dimensions(), "x",
4078       [&](std::string* out, const PaddingConfig::PaddingConfigDimension& dim) {
4079         StrAppend(
4080             out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
4081             has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
4082       });
4083 }
4084 
RandomDistributionToString(const RandomDistribution & distribution)4085 std::string RandomDistributionToString(const RandomDistribution& distribution) {
4086   return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
4087 }
RandomAlgorithmToString(const RandomAlgorithm & algorithm)4088 std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
4089   return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
4090 }
4091 
PrecisionToString(const PrecisionConfig::Precision & precision)4092 std::string PrecisionToString(const PrecisionConfig::Precision& precision) {
4093   return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
4094 }
4095 
CustomCallScheduleToString(const CustomCallSchedule & schedule)4096 static std::string CustomCallScheduleToString(
4097     const CustomCallSchedule& schedule) {
4098   return absl::AsciiStrToLower(CustomCallSchedule_Name(schedule));
4099 }
4100 
CustomCallApiVersionToString(const CustomCallApiVersion & schedule)4101 static std::string CustomCallApiVersionToString(
4102     const CustomCallApiVersion& schedule) {
4103   return absl::AsciiStrToLower(CustomCallApiVersion_Name(schedule));
4104 }
4105 
DotDimensionNumbersToString(const DotDimensionNumbers & dnums)4106 std::string DotDimensionNumbersToString(const DotDimensionNumbers& dnums) {
4107   std::vector<std::string> result;
4108   if (!dnums.lhs_batch_dimensions().empty()) {
4109     result.push_back(StrCat("lhs_batch_dims={",
4110                             StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
4111   }
4112   result.push_back(StrCat("lhs_contracting_dims={",
4113                           StrJoin(dnums.lhs_contracting_dimensions(), ","),
4114                           "}"));
4115 
4116   if (!dnums.rhs_batch_dimensions().empty()) {
4117     result.push_back(StrCat("rhs_batch_dims={",
4118                             StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
4119   }
4120   result.push_back(StrCat("rhs_contracting_dims={",
4121                           StrJoin(dnums.rhs_contracting_dimensions(), ","),
4122                           "}"));
4123 
4124   return StrJoin(result, ", ");
4125 }
4126 
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)4127 std::string ConvolutionDimensionNumbersToString(
4128     const ConvolutionDimensionNumbers& dnums) {
4129   auto len_required = [](int64_t a, int64_t b, absl::Span<const int64_t> cs) {
4130     return std::max({a, b, cs.empty() ? 0 : *absl::c_max_element(cs)}) + 1;
4131   };
4132 
4133   // lhs_dims[i] is the symbol of the logical dimension i for the lhs
4134   // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
4135   std::vector<std::string> lhs_dims(
4136       len_required(dnums.input_batch_dimension(),
4137                    dnums.input_feature_dimension(),
4138                    dnums.input_spatial_dimensions()),
4139       "?");
4140   lhs_dims[dnums.input_batch_dimension()] = 'b';
4141   lhs_dims[dnums.input_feature_dimension()] = 'f';
4142   for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
4143     lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
4144   }
4145 
4146   std::vector<std::string> rhs_dims(
4147       len_required(dnums.kernel_input_feature_dimension(),
4148                    dnums.kernel_output_feature_dimension(),
4149                    dnums.kernel_spatial_dimensions()),
4150       "?");
4151   rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
4152   rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
4153   for (int64_t i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
4154     rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
4155   }
4156 
4157   std::vector<std::string> output_dims(
4158       len_required(dnums.output_batch_dimension(),
4159                    dnums.output_feature_dimension(),
4160                    dnums.output_spatial_dimensions()),
4161       "?");
4162   output_dims[dnums.output_batch_dimension()] = 'b';
4163   output_dims[dnums.output_feature_dimension()] = 'f';
4164   for (int64_t i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
4165     output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
4166   }
4167 
4168   return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
4169                 StrJoin(output_dims, ""));
4170 }
4171 
ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups)4172 std::string ReplicaGroupsToString(
4173     absl::Span<const ReplicaGroup> replica_groups) {
4174   std::vector<std::string> replica_group_str;
4175   replica_group_str.reserve(replica_groups.size());
4176   for (const ReplicaGroup& group : replica_groups) {
4177     replica_group_str.push_back(
4178         StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
4179   }
4180   return StrCat("{", StrJoin(replica_group_str, ","), "}");
4181 }
4182 
StringToRandomAlgorithm(const std::string & name)4183 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const std::string& name) {
4184   static absl::flat_hash_map<std::string, RandomAlgorithm>* map = [] {
4185     static auto* map = new absl::flat_hash_map<std::string, RandomAlgorithm>;
4186     for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
4187       if (RandomAlgorithm_IsValid(i)) {
4188         auto value = static_cast<RandomAlgorithm>(i);
4189         (*map)[RandomAlgorithmToString(value)] = value;
4190       }
4191     }
4192     return map;
4193   }();
4194   auto found = map->find(absl::AsciiStrToLower(name));
4195   if (found == map->end()) {
4196     return InvalidArgument("Unknown algorithm");
4197   }
4198   return found->second;
4199 }
4200 
StringToRandomDistribution(const std::string & name)4201 StatusOr<RandomDistribution> StringToRandomDistribution(
4202     const std::string& name) {
4203   static absl::flat_hash_map<std::string, RandomDistribution>* map = [] {
4204     static auto* map = new absl::flat_hash_map<std::string, RandomDistribution>;
4205     for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
4206       if (RandomDistribution_IsValid(i)) {
4207         auto value = static_cast<RandomDistribution>(i);
4208         (*map)[RandomDistributionToString(value)] = value;
4209       }
4210     }
4211     return map;
4212   }();
4213   auto found = map->find(absl::AsciiStrToLower(name));
4214   if (found == map->end()) {
4215     return InvalidArgument("Unknown distribution");
4216   }
4217   return found->second;
4218 }
4219 
StringToPrecision(const std::string & name)4220 StatusOr<PrecisionConfig::Precision> StringToPrecision(
4221     const std::string& name) {
4222   static absl::flat_hash_map<std::string, PrecisionConfig::Precision>* map =
4223       [] {
4224         static auto* map =
4225             new absl::flat_hash_map<std::string, PrecisionConfig::Precision>;
4226         for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
4227           if (PrecisionConfig::Precision_IsValid(i)) {
4228             auto value = static_cast<PrecisionConfig::Precision>(i);
4229             (*map)[PrecisionToString(value)] = value;
4230           }
4231         }
4232         return map;
4233       }();
4234   auto found = map->find(absl::AsciiStrToLower(name));
4235   if (found == map->end()) {
4236     return InvalidArgument("Unknown distribution");
4237   }
4238   return found->second;
4239 }
4240 
StringToCustomCallSchedule(absl::string_view name)4241 StatusOr<CustomCallSchedule> StringToCustomCallSchedule(
4242     absl::string_view name) {
4243   static const absl::flat_hash_map<std::string, CustomCallSchedule>* map = [] {
4244     static auto* map = new absl::flat_hash_map<std::string, CustomCallSchedule>;
4245     for (int i = 0; i < CustomCallSchedule_ARRAYSIZE; i++) {
4246       if (CustomCallSchedule_IsValid(i)) {
4247         auto value = static_cast<CustomCallSchedule>(i);
4248         (*map)[CustomCallScheduleToString(value)] = value;
4249       }
4250     }
4251     return map;
4252   }();
4253   auto found = map->find(absl::AsciiStrToLower(name));
4254   if (found == map->end()) {
4255     return InvalidArgument("Unknown schedule");
4256   }
4257   return found->second;
4258 }
4259 
StringToCustomCallApiVersion(absl::string_view name)4260 StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
4261     absl::string_view name) {
4262   static const absl::flat_hash_map<std::string, CustomCallApiVersion>* map =
4263       [] {
4264         static auto* map =
4265             new absl::flat_hash_map<std::string, CustomCallApiVersion>;
4266         for (int i = 0; i < CustomCallApiVersion_ARRAYSIZE; i++) {
4267           if (CustomCallApiVersion_IsValid(i)) {
4268             auto value = static_cast<CustomCallApiVersion>(i);
4269             (*map)[CustomCallApiVersionToString(value)] = value;
4270           }
4271         }
4272         return map;
4273       }();
4274   auto found = map->find(absl::AsciiStrToLower(name));
4275   if (found == map->end()) {
4276     return InvalidArgument("Unknown API version");
4277   }
4278   return found->second;
4279 }
4280 
operator <<(std::ostream & os,HloInstruction::FusionKind kind)4281 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
4282   return os << ToString(kind);
4283 }
4284 
operator ()(const HloInstruction * const & lhs,const HloInstruction * const & rhs) const4285 bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
4286                                   const HloInstruction* const& rhs) const {
4287   if (rhs == nullptr) {
4288     // Nothing compares less than nullptr.
4289     return false;
4290   }
4291   if (lhs == nullptr) {
4292     return true;
4293   }
4294   auto lhs_module = lhs->GetModule();
4295   auto rhs_module = rhs->GetModule();
4296   CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
4297         (lhs_module != nullptr && rhs_module != nullptr));
4298   if (lhs_module != nullptr &&
4299       lhs_module->unique_id() != rhs_module->unique_id()) {
4300     return lhs_module->unique_id() < rhs_module->unique_id();
4301   }
4302   return lhs->unique_id() < rhs->unique_id();
4303 }
4304 
GetBackendConfigInternal(tensorflow::protobuf::Message * proto) const4305 Status HloInstruction::GetBackendConfigInternal(
4306     tensorflow::protobuf::Message* proto) const {
4307   proto->Clear();
4308 
4309   if (auto* proto_ptr = backend_config_.GetProtoPtr()) {
4310     if (proto_ptr->GetDescriptor() == proto->GetDescriptor()) {
4311       proto->CopyFrom(*proto_ptr);
4312       return OkStatus();
4313     }
4314   }
4315 
4316   auto& raw_string = raw_backend_config_string();
4317   // Empty string does not parse as valid JSON, but it's a valid backend config,
4318   // corresponding to the empty proto.
4319   if (raw_string.empty()) {
4320     return OkStatus();
4321   }
4322   TF_RETURN_IF_ERROR(tensorflow::HumanReadableJsonToProto(raw_string, proto));
4323   backend_config_.SetProto(*proto);
4324   return OkStatus();
4325 }
4326 
GetRawString() const4327 const std::string& HloInstruction::BackendConfigRep::GetRawString() const {
4328   if (proto_ && raw_string_.empty()) {
4329     raw_string_ = BackendConfigToRawString(*proto_).ValueOrDie();
4330   }
4331   return raw_string_;
4332 }
4333 
Clone() const4334 HloInstruction::BackendConfigRep HloInstruction::BackendConfigRep::Clone()
4335     const {
4336   // Prefer cloning protobuf, raw_string_ will be lazily generated if accessed.
4337   BackendConfigRep cloned;
4338   if (auto* proto = GetProtoPtr()) {
4339     cloned.SetProto(*proto);
4340   } else {
4341     cloned.raw_string_ = raw_string_;
4342   }
4343   return cloned;
4344 }
4345 
operator =(std::string raw_string)4346 HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=(
4347     std::string raw_string) {
4348   raw_string_ = std::move(raw_string);
4349   proto_.reset();
4350   return *this;
4351 }
4352 
operator =(const tensorflow::protobuf::Message & proto)4353 HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=(
4354     const tensorflow::protobuf::Message& proto) {
4355   SetProto(proto);
4356   raw_string_.clear();
4357   return *this;
4358 }
4359 
SetProto(const tensorflow::protobuf::Message & proto)4360 void HloInstruction::BackendConfigRep::SetProto(
4361     const tensorflow::protobuf::Message& proto) {
4362   proto_.reset(proto.New());
4363   proto_->CopyFrom(proto);
4364 }
4365 
operator ==(const BackendConfigRep & other) const4366 bool HloInstruction::BackendConfigRep::operator==(
4367     const BackendConfigRep& other) const {
4368   auto* proto_a = GetProtoPtr();
4369   auto* proto_b = other.GetProtoPtr();
4370   if (proto_a != nullptr && proto_b != nullptr) {
4371     using ::tensorflow::protobuf::util::MessageDifferencer;
4372     return MessageDifferencer::Equals(*proto_a, *proto_b);
4373   }
4374   // TODO(b/225956414): Consider canonicalizing raw string form.
4375   return GetRawString() == other.GetRawString();
4376 }
4377 
BackendConfigToRawString(const tensorflow::protobuf::Message & proto)4378 /* static */ StatusOr<std::string> HloInstruction::BackendConfigToRawString(
4379     const tensorflow::protobuf::Message& proto) {
4380   std::string ret;
4381   // Pass ignore_accuracy_loss = true because estimated_cycles field can be
4382   // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
4383   // INT64_MAX, JsonFormat will return an error status, although there is no
4384   // accuracy loss for int64_t.
4385   TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
4386       proto, &ret, /*ignore_accuracy_loss=*/true));
4387   return ret;
4388 }
4389 
precision_config() const4390 const PrecisionConfig& HloInstruction::precision_config() const {
4391   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4392     return convolution->precision_config();
4393   }
4394   if (auto* dot = DynCast<HloDotInstruction>(this)) {
4395     return dot->precision_config();
4396   }
4397 
4398   if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
4399     return custom_call->precision_config();
4400   }
4401   LOG(FATAL) << "Unimplemented method.";
4402 }
4403 
mutable_precision_config()4404 PrecisionConfig* HloInstruction::mutable_precision_config() {
4405   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4406     return convolution->mutable_precision_config();
4407   }
4408   if (auto* dot = DynCast<HloDotInstruction>(this)) {
4409     return dot->mutable_precision_config();
4410   }
4411   LOG(FATAL) << "Unimplemented method.";
4412 }
4413 
GetModule() const4414 HloModule* HloInstruction::GetModule() const {
4415   if (parent_) {
4416     return parent_->parent();
4417   }
4418   return nullptr;
4419 }
4420 
UniquifyName(NameUniquer * name_uniquer)4421 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
4422   std::string parent_str = parent() == nullptr ? "noparent" : parent()->name();
4423   name_ = name_uniquer->GetUniqueName(name_);
4424 }
4425 
set_outer_dimension_partitions(const std::vector<int64_t> & outer_dimension_partitions)4426 void HloInstruction::set_outer_dimension_partitions(
4427     const std::vector<int64_t>& outer_dimension_partitions) {
4428   outer_dimension_partitions_ = outer_dimension_partitions;
4429 }
4430 
SortInstructionUsersAndControlLists(const MappedPtrContainerSorter<HloInstruction>::MapPtrFn & map_fn,const HloInstruction & sorted_instruction)4431 void HloInstruction::SortInstructionUsersAndControlLists(
4432     const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
4433     const HloInstruction& sorted_instruction) {
4434   using Sorter = MappedPtrContainerSorter<HloInstruction>;
4435   auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
4436                              sorted_instruction.users_, users_);
4437   if (!status.ok()) {
4438     LOG(ERROR) << "Failed to sort instruction users for " << name() << "; "
4439                << status;
4440   }
4441   user_map_.clear();
4442   for (uint64_t i = 0; i < users_.size(); ++i) {
4443     user_map_[users_[i]] = i;
4444   }
4445   status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
4446                         sorted_instruction.control_predecessors_,
4447                         control_predecessors_);
4448   if (!status.ok()) {
4449     LOG(ERROR) << "Failed to sort instruction control predecessors for "
4450                << name() << "; " << status;
4451   }
4452   status =
4453       Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
4454                    sorted_instruction.control_successors_, control_successors_);
4455   if (!status.ok()) {
4456     LOG(ERROR) << "Failed to sort instruction control successors for " << name()
4457                << "; " << status;
4458   }
4459 }
4460 
4461 // TODO(b/80131774): Remove these temporary methods after transition.
feature_index() const4462 int64_t HloInstruction::feature_index() const {
4463   return Cast<HloBatchNormInstruction>(this)->feature_index();
4464 }
4465 
epsilon() const4466 float HloInstruction::epsilon() const {
4467   return Cast<HloBatchNormInstruction>(this)->epsilon();
4468 }
4469 
fft_type() const4470 FftType HloInstruction::fft_type() const {
4471   return Cast<HloFftInstruction>(this)->fft_type();
4472 }
4473 
fft_length() const4474 const std::vector<int64_t>& HloInstruction::fft_length() const {
4475   return Cast<HloFftInstruction>(this)->fft_length();
4476 }
4477 
concatenate_dimension() const4478 int64_t HloInstruction::concatenate_dimension() const {
4479   return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
4480 }
4481 
dimension() const4482 int64_t HloInstruction::dimension() const {
4483   if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
4484     return set_size->dimension();
4485   }
4486   return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
4487 }
4488 
inferred_dimension() const4489 int64_t HloInstruction::inferred_dimension() const {
4490   return Cast<HloReshapeInstruction>(this)->inferred_dimension();
4491 }
4492 
IsRank2Transpose() const4493 bool HloInstruction::IsRank2Transpose() const {
4494   auto transpose = DynCast<HloTransposeInstruction>(this);
4495   return transpose != nullptr && transpose->IsRank2Transpose();
4496 }
4497 
slice_starts(int64_t dimension) const4498 int64_t HloInstruction::slice_starts(int64_t dimension) const {
4499   return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
4500 }
4501 
slice_starts() const4502 const std::vector<int64_t>& HloInstruction::slice_starts() const {
4503   return Cast<HloSliceInstruction>(this)->slice_starts();
4504 }
4505 
mutable_slice_starts()4506 std::vector<int64_t>* HloInstruction::mutable_slice_starts() {
4507   return Cast<HloSliceInstruction>(this)->mutable_slice_starts();
4508 }
4509 
slice_limits(int64_t dimension) const4510 int64_t HloInstruction::slice_limits(int64_t dimension) const {
4511   return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
4512 }
4513 
slice_limits() const4514 const std::vector<int64_t>& HloInstruction::slice_limits() const {
4515   return Cast<HloSliceInstruction>(this)->slice_limits();
4516 }
4517 
mutable_slice_limits()4518 std::vector<int64_t>* HloInstruction::mutable_slice_limits() {
4519   return Cast<HloSliceInstruction>(this)->mutable_slice_limits();
4520 }
4521 
slice_strides(int64_t dimension) const4522 int64_t HloInstruction::slice_strides(int64_t dimension) const {
4523   return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
4524 }
4525 
slice_strides() const4526 const std::vector<int64_t>& HloInstruction::slice_strides() const {
4527   return Cast<HloSliceInstruction>(this)->slice_strides();
4528 }
4529 
mutable_slice_strides()4530 std::vector<int64_t>* HloInstruction::mutable_slice_strides() {
4531   return Cast<HloSliceInstruction>(this)->mutable_slice_strides();
4532 }
4533 
literal() const4534 const Literal& HloInstruction::literal() const {
4535   return Cast<HloConstantInstruction>(this)->literal();
4536 }
4537 
IsConstant() const4538 bool HloInstruction::IsConstant() const {
4539   return DynCast<HloConstantInstruction>(this) != nullptr;
4540 }
4541 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)4542 void HloInstruction::RelayoutConstant(const Layout& new_layout,
4543                                       const ShapeIndex& shape_index) {
4544   Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
4545 }
4546 
4547 // Delegates to HloCallableInstruction::AppendInstructionIntoCalledComputation.
AppendInstructionIntoCalledComputation(HloInstruction * instruction_to_append,bool add_output)4548 HloInstruction* HloInstruction::AppendInstructionIntoCalledComputation(
4549     HloInstruction* instruction_to_append, bool add_output) {
4550   return Cast<HloCallableInstruction>(this)
4551       ->AppendInstructionIntoCalledComputation(instruction_to_append,
4552                                                add_output);
4553 }
4554 
AddFusionOperand(HloInstruction * new_operand)4555 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
4556   return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
4557 }
4558 
4559 // Delegates to HloFusionInstruction::MergeFusionInstruction.
MergeFusionInstruction(HloInstruction * instruction_to_merge)4560 void HloInstruction::MergeFusionInstruction(
4561     HloInstruction* instruction_to_merge) {
4562   return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
4563       Cast<HloFusionInstruction>(instruction_to_merge));
4564 }
4565 
4566 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
MergeFusionInstructionIntoMultiOutput(HloInstruction * instruction_to_merge)4567 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
4568     HloInstruction* instruction_to_merge) {
4569   return Cast<HloFusionInstruction>(this)
4570       ->MergeFusionInstructionIntoMultiOutput(
4571           Cast<HloFusionInstruction>(instruction_to_merge));
4572 }
4573 
FuseInstruction(HloInstruction * instruction_to_fuse)4574 HloInstruction* HloInstruction::FuseInstruction(
4575     HloInstruction* instruction_to_fuse) {
4576   return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
4577 }
4578 
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)4579 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
4580     HloInstruction* instruction_to_fuse) {
4581   return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
4582       instruction_to_fuse);
4583 }
4584 
fused_instructions_computation() const4585 HloComputation* HloInstruction::fused_instructions_computation() const {
4586   return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
4587 }
4588 
fused_expression_root() const4589 HloInstruction* HloInstruction::fused_expression_root() const {
4590   return Cast<HloFusionInstruction>(this)->fused_expression_root();
4591 }
4592 
4593 const tensorflow::gtl::iterator_range<UnwrappingIterator<
4594     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const4595 HloInstruction::fused_instructions() const {
4596   return Cast<HloFusionInstruction>(this)->fused_instructions();
4597 }
4598 
4599 const tensorflow::gtl::iterator_range<
4600     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()4601 HloInstruction::fused_instructions() {
4602   return Cast<HloFusionInstruction>(this)->fused_instructions();
4603 }
4604 
fused_instruction_count() const4605 int64_t HloInstruction::fused_instruction_count() const {
4606   return Cast<HloFusionInstruction>(this)->fused_instruction_count();
4607 }
4608 
fused_parameter(int64_t parameter_number) const4609 HloInstruction* HloInstruction::fused_parameter(
4610     int64_t parameter_number) const {
4611   return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
4612 }
4613 
fused_parameters() const4614 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
4615   return Cast<HloFusionInstruction>(this)->fused_parameters();
4616 }
4617 
IsMultiOutputFusion() const4618 const bool HloInstruction::IsMultiOutputFusion() const {
4619   const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
4620   return fusion != nullptr && fusion->IsMultiOutputFusion();
4621 }
4622 
fusion_kind() const4623 HloInstruction::FusionKind HloInstruction::fusion_kind() const {
4624   return Cast<HloFusionInstruction>(this)->fusion_kind();
4625 }
4626 
set_fusion_kind(FusionKind kind)4627 void HloInstruction::set_fusion_kind(FusionKind kind) {
4628   return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
4629 }
4630 
random_distribution() const4631 RandomDistribution HloInstruction::random_distribution() const {
4632   return Cast<HloRngInstruction>(this)->random_distribution();
4633 }
4634 
parameter_number() const4635 int64_t HloInstruction::parameter_number() const {
4636   return Cast<HloParameterInstruction>(this)->parameter_number();
4637 }
4638 
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)4639 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4640     absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
4641   return Cast<HloParameterInstruction>(this)
4642       ->set_parameter_replicated_at_leaf_buffers(
4643           parameter_replicated_at_leaf_buffers);
4644 }
4645 
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)4646 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4647     const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
4648   return Cast<HloParameterInstruction>(this)
4649       ->set_parameter_replicated_at_leaf_buffers(
4650           parameter_replicated_at_leaf_buffers);
4651 }
4652 
4653 const std::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers() const4654 HloInstruction::parameter_replicated_at_leaf_buffers() const {
4655   return Cast<HloParameterInstruction>(this)
4656       ->parameter_replicated_at_leaf_buffers();
4657 }
4658 
tuple_index() const4659 int64_t HloInstruction::tuple_index() const {
4660   return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
4661 }
4662 
set_tuple_index(int64_t new_tuple_index)4663 void HloInstruction::set_tuple_index(int64_t new_tuple_index) {
4664   return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
4665       new_tuple_index);
4666 }
4667 
exponent_bits() const4668 int32_t HloInstruction::exponent_bits() const {
4669   return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
4670 }
4671 
mantissa_bits() const4672 int32_t HloInstruction::mantissa_bits() const {
4673   return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
4674 }
4675 
infeed_config() const4676 std::string HloInstruction::infeed_config() const {
4677   return Cast<HloInfeedInstruction>(this)->infeed_config();
4678 }
4679 
set_infeed_config(const std::string & config)4680 void HloInstruction::set_infeed_config(const std::string& config) {
4681   return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
4682 }
4683 
outfeed_shape() const4684 const Shape& HloInstruction::outfeed_shape() const {
4685   return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
4686 }
4687 
mutable_outfeed_shape()4688 Shape* HloInstruction::mutable_outfeed_shape() {
4689   return Cast<HloOutfeedInstruction>(this)->mutable_outfeed_shape();
4690 }
4691 
outfeed_config() const4692 const std::string& HloInstruction::outfeed_config() const {
4693   return Cast<HloOutfeedInstruction>(this)->outfeed_config();
4694 }
4695 
set_outfeed_config(const std::string & config)4696 void HloInstruction::set_outfeed_config(const std::string& config) {
4697   return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
4698 }
4699 
replica_groups() const4700 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
4701   return Cast<HloCollectiveInstruction>(this)->replica_groups();
4702 }
4703 
4704 const std::vector<std::pair<int64_t, int64_t>>&
source_target_pairs() const4705 HloInstruction::source_target_pairs() const {
4706   return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
4707 }
4708 
channel_id() const4709 std::optional<int64_t> HloInstruction::channel_id() const {
4710   return Cast<HloChannelInstruction>(this)->channel_id();
4711 }
4712 
set_channel_id(const std::optional<int64_t> & channel_id)4713 void HloInstruction::set_channel_id(const std::optional<int64_t>& channel_id) {
4714   return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
4715 }
4716 
4717 const ConvolutionDimensionNumbers&
convolution_dimension_numbers() const4718 HloInstruction::convolution_dimension_numbers() const {
4719   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4720     return convolution->convolution_dimension_numbers();
4721   }
4722   if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4723     return custom_call->convolution_dimension_numbers();
4724   }
4725   LOG(FATAL) << "Unimplemented method.";
4726 }
4727 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)4728 void HloInstruction::set_convolution_dimension_numbers(
4729     const ConvolutionDimensionNumbers& dnums) {
4730   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4731     convolution->set_convolution_dimension_numbers(dnums);
4732   } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4733     custom_call->set_convolution_dimension_numbers(dnums);
4734   } else {
4735     LOG(FATAL) << "Unimplemented method.";
4736   }
4737 }
4738 
feature_group_count() const4739 int64_t HloInstruction::feature_group_count() const {
4740   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4741     return convolution->feature_group_count();
4742   }
4743   return Cast<HloCustomCallInstruction>(this)->feature_group_count();
4744 }
4745 
set_feature_group_count(int64_t feature_group_count)4746 void HloInstruction::set_feature_group_count(int64_t feature_group_count) {
4747   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4748     return convolution->set_feature_group_count(feature_group_count);
4749   }
4750   Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
4751       feature_group_count);
4752 }
4753 
batch_group_count() const4754 int64_t HloInstruction::batch_group_count() const {
4755   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4756     return convolution->batch_group_count();
4757   }
4758   return Cast<HloCustomCallInstruction>(this)->batch_group_count();
4759 }
4760 
set_batch_group_count(int64_t batch_group_count)4761 void HloInstruction::set_batch_group_count(int64_t batch_group_count) {
4762   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4763     return convolution->set_batch_group_count(batch_group_count);
4764   }
4765   Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
4766       batch_group_count);
4767 }
4768 
select() const4769 HloComputation* HloInstruction::select() const {
4770   return Cast<HloSelectAndScatterInstruction>(this)->select();
4771 }
4772 
scatter() const4773 HloComputation* HloInstruction::scatter() const {
4774   return Cast<HloSelectAndScatterInstruction>(this)->scatter();
4775 }
4776 
set_select(HloComputation * computation)4777 void HloInstruction::set_select(HloComputation* computation) {
4778   return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
4779 }
4780 
set_scatter(HloComputation * computation)4781 void HloInstruction::set_scatter(HloComputation* computation) {
4782   return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
4783 }
4784 
custom_call_target() const4785 const std::string& HloInstruction::custom_call_target() const {
4786   return Cast<HloCustomCallInstruction>(this)->custom_call_target();
4787 }
set_custom_call_target(absl::string_view target)4788 void HloInstruction::set_custom_call_target(absl::string_view target) {
4789   Cast<HloCustomCallInstruction>(this)->set_custom_call_target(target);
4790 }
4791 
padding_config() const4792 const PaddingConfig& HloInstruction::padding_config() const {
4793   return Cast<HloPadInstruction>(this)->padding_config();
4794 }
4795 
padding_type() const4796 PaddingType HloInstruction::padding_type() const {
4797   return Cast<HloCustomCallInstruction>(this)->padding_type();
4798 }
4799 
mutable_padding_config()4800 PaddingConfig* HloInstruction::mutable_padding_config() {
4801   return Cast<HloPadInstruction>(this)->mutable_padding_config();
4802 }
4803 
slice_sizes(int64_t dimension) const4804 int64_t HloInstruction::slice_sizes(int64_t dimension) const {
4805   return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
4806 }
4807 
dynamic_slice_sizes() const4808 const std::vector<int64_t>& HloInstruction::dynamic_slice_sizes() const {
4809   return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
4810 }
4811 
4812 const std::vector<std::vector<int64_t>>&
dynamic_slice_sizes_list() const4813 HloInstruction::dynamic_slice_sizes_list() const {
4814   return Cast<HloCollectivePermuteInstruction>(this)
4815       ->dynamic_slice_sizes_list();
4816 }
4817 
gather_dimension_numbers() const4818 const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
4819   return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
4820 }
4821 
gather_slice_sizes() const4822 absl::Span<const int64_t> HloInstruction::gather_slice_sizes() const {
4823   return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
4824 }
4825 
scatter_dimension_numbers() const4826 const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
4827     const {
4828   return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
4829 }
4830 
dot_dimension_numbers() const4831 const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
4832   return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
4833 }
4834 
operand_side_metadata() const4835 const DomainMetadata& HloInstruction::operand_side_metadata() const {
4836   return Cast<HloDomainInstruction>(this)->operand_side_metadata();
4837 }
4838 
user_side_metadata() const4839 const DomainMetadata& HloInstruction::user_side_metadata() const {
4840   return Cast<HloDomainInstruction>(this)->user_side_metadata();
4841 }
4842 
IsAsynchronous() const4843 bool HloInstruction::IsAsynchronous() const {
4844   return opcode() == HloOpcode::kAsyncStart ||
4845          opcode() == HloOpcode::kAsyncUpdate ||
4846          opcode() == HloOpcode::kAsyncDone;
4847 }
4848 
async_wrapped_computation() const4849 HloComputation* HloInstruction::async_wrapped_computation() const {
4850   CHECK(IsAsynchronous());
4851   return called_computations()[0];
4852 }
4853 
async_wrapped_instruction() const4854 HloInstruction* HloInstruction::async_wrapped_instruction() const {
4855   return Cast<HloAsyncInstruction>(this)->async_wrapped_instruction();
4856 }
4857 
async_wrapped_opcode() const4858 HloOpcode HloInstruction::async_wrapped_opcode() const {
4859   return Cast<HloAsyncInstruction>(this)->async_wrapped_opcode();
4860 }
4861 
async_group_id() const4862 std::optional<int64_t> HloInstruction::async_group_id() const {
4863   return Cast<HloAsyncInstruction>(this)->async_group_id();
4864 }
4865 
set_async_group_id(std::optional<int64_t> async_group_id)4866 void HloInstruction::set_async_group_id(std::optional<int64_t> async_group_id) {
4867   Cast<HloAsyncInstruction>(this)->set_async_group_id(async_group_id);
4868 }
4869 
async_execution_thread() const4870 absl::string_view HloInstruction::async_execution_thread() const {
4871   return Cast<HloAsyncInstruction>(this)->async_execution_thread();
4872 }
4873 
set_async_execution_thread(absl::string_view async_execution_thread)4874 void HloInstruction::set_async_execution_thread(
4875     absl::string_view async_execution_thread) {
4876   Cast<HloAsyncInstruction>(this)->set_async_execution_thread(
4877       async_execution_thread);
4878 }
4879 
set_called_computations_execution_thread(absl::string_view async_execution_thread,bool skip_async_execution_thread_overwrite)4880 void HloInstruction::set_called_computations_execution_thread(
4881     absl::string_view async_execution_thread,
4882     bool skip_async_execution_thread_overwrite) {
4883   Cast<HloCallableInstruction>(this)->RecursivelySetComputationsThreadName(
4884       async_execution_thread, skip_async_execution_thread_overwrite);
4885 }
4886 
is_cross_program_prefetch() const4887 bool HloInstruction::is_cross_program_prefetch() const {
4888   return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
4889 }
4890 
comparison_direction() const4891 ComparisonDirection HloInstruction::comparison_direction() const {
4892   return Cast<HloCompareInstruction>(this)->direction();
4893 }
4894 
comparison_order() const4895 ComparisonOrder HloInstruction::comparison_order() const {
4896   return Cast<HloCompareInstruction>(this)->order();
4897 }
4898 
triangular_solve_options() const4899 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
4900   return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
4901 }
4902 
cholesky_options() const4903 const CholeskyOptions& HloInstruction::cholesky_options() const {
4904   return Cast<HloCholeskyInstruction>(this)->cholesky_options();
4905 }
4906 
4907 }  // namespace xla
4908