• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17 
18 #include <algorithm>
19 #include <iostream>
20 #include <ostream>
21 #include <set>
22 #include <string>
23 #include <unordered_set>
24 #include <utility>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/ascii.h"
32 #include "absl/strings/escaping.h"
33 #include "absl/strings/numbers.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/strings/str_join.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/protobuf_util.h"
41 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
42 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
43 #include "tensorflow/compiler/xla/service/hlo_computation.h"
44 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
45 #include "tensorflow/compiler/xla/service/hlo_module.h"
46 #include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
47 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
48 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
49 #include "tensorflow/compiler/xla/service/name_uniquer.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/xla_data.pb.h"
55 #include "tensorflow/core/lib/core/errors.h"
56 #include "tensorflow/core/lib/gtl/map_util.h"
57 #include "tensorflow/core/platform/human_readable_json.h"
58 #include "tensorflow/core/platform/logging.h"
59 
60 namespace xla {
61 
62 using absl::CEscape;
63 using absl::StrAppend;
64 using absl::StrCat;
65 using absl::StrJoin;
66 
67 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64,HloComputation * > & computation_map,bool prohibit_empty_literal)68 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
69     const HloInstructionProto& proto,
70     const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
71     const absl::flat_hash_map<int64, HloComputation*>& computation_map,
72     bool prohibit_empty_literal) {
73   TF_RET_CHECK(!proto.opcode().empty());
74   HloOpcode opcode;
75   auto opcode_or = StringToHloOpcode(proto.opcode());
76   absl::optional<ComparisonDirection> comparison_direction;
77   if (opcode_or.ok()) {
78     opcode = opcode_or.ConsumeValueOrDie();
79   } else {
80     // Unknown opcode. Try auto-upgrading deprecated "less-than",
81     // "greater-than", etc opcodes, which are now rolled into the kCompare
82     // opcode.
83     if (proto.opcode() == "equal-to") {
84       comparison_direction = ComparisonDirection::kEq;
85     } else if (proto.opcode() == "not-equal-to") {
86       comparison_direction = ComparisonDirection::kNe;
87     } else if (proto.opcode() == "greater-than-or-equal-to") {
88       comparison_direction = ComparisonDirection::kGe;
89     } else if (proto.opcode() == "greater-than") {
90       comparison_direction = ComparisonDirection::kGt;
91     } else if (proto.opcode() == "less-than-or-equal-to") {
92       comparison_direction = ComparisonDirection::kLe;
93     } else if (proto.opcode() == "less-than") {
94       comparison_direction = ComparisonDirection::kLt;
95     }
96     if (comparison_direction) {
97       opcode = HloOpcode::kCompare;
98     } else {
99       return InvalidArgument("Unknown opcode: %s", proto.opcode());
100     }
101   }
102 
103   TF_RET_CHECK(proto.has_shape());
104 
105   std::unique_ptr<HloInstruction> instruction;
106   const auto operands = [&instruction_map, &proto](int index) {
107     return instruction_map.at(proto.operand_ids(index));
108   };
109   const auto all_operands = [&instruction_map, &proto]() {
110     std::vector<HloInstruction*> result(proto.operand_ids_size());
111     std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
112                    result.begin(), [&instruction_map](int64_t operand_id) {
113                      return instruction_map.at(operand_id);
114                    });
115     return result;
116   };
117   const auto computations = [&computation_map, &proto](int index) {
118     return computation_map.at(proto.called_computation_ids(index));
119   };
120   const auto all_computations = [&computation_map, &proto]() {
121     std::vector<HloComputation*> result(proto.called_computation_ids_size());
122     std::transform(proto.called_computation_ids().begin(),
123                    proto.called_computation_ids().end(), result.begin(),
124                    [&computation_map](int64_t computation_id) {
125                      return computation_map.at(computation_id);
126                    });
127     return result;
128   };
129 
130   TF_RET_CHECK(
131       absl::c_all_of(proto.operand_ids(),
132                      [&](int64_t id) { return instruction_map.contains(id); }))
133       << proto.name() << " instruction contains invalid operand id(s)";
134 
135   TF_RET_CHECK(
136       absl::c_all_of(proto.called_computation_ids(),
137                      [&](int64_t id) { return computation_map.contains(id); }))
138       << proto.name() << " instruction references invalid computation id(s)";
139 
140   Shape shape(proto.shape());
141   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
142 
143   absl::optional<int> arity = HloOpcodeArity(opcode);
144   if (arity) {
145     TF_RET_CHECK(proto.operand_ids_size() == *arity)
146         << proto.opcode() << " instruction should have " << *arity
147         << " operands but sees " << proto.operand_ids_size();
148   }
149 
150   switch (opcode) {
151     // Ops migrated to subclasses.
152     case HloOpcode::kBatchNormTraining:
153       instruction =
154           CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
155                                   proto.epsilon(), proto.feature_index());
156       break;
157     case HloOpcode::kBatchNormInference:
158       instruction = CreateBatchNormInference(
159           shape, operands(0), operands(1), operands(2), operands(3),
160           operands(4), proto.epsilon(), proto.feature_index());
161       break;
162     case HloOpcode::kBatchNormGrad:
163       instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
164                                         operands(2), operands(3), operands(4),
165                                         proto.epsilon(), proto.feature_index());
166       break;
167     case HloOpcode::kFft: {
168       std::vector<int64> fft_length(proto.fft_length().begin(),
169                                     proto.fft_length().end());
170       instruction = CreateFft(shape, operands(0), proto.fft_type(),
171                               absl::Span<const int64>(fft_length));
172       break;
173     }
174     case HloOpcode::kCopyStart: {
175       instruction = CreateCopyStart(shape, operands(0),
176                                     proto.is_cross_program_prefetch());
177       break;
178     }
179     case HloOpcode::kCompare: {
180       // Auto-upgraded from deprecated opcode skips the following.
181       if (!comparison_direction) {
182         TF_ASSIGN_OR_RETURN(
183             comparison_direction,
184             StringToComparisonDirection(proto.comparison_direction()));
185       }
186       auto comparison_type_str = proto.comparison_type();
187       if (!comparison_type_str.empty()) {
188         // If a comparison type is specified, it *must* be valid.
189         TF_ASSIGN_OR_RETURN(auto comparison_type,
190                             StringToComparisonType(comparison_type_str));
191         instruction = CreateCompare(shape, operands(0), operands(1),
192                                     *comparison_direction, comparison_type);
193       } else {
194         // Specifying the comparison type is optional.
195         // The comparison type will be determined by the types of the operands.
196         instruction = CreateCompare(shape, operands(0), operands(1),
197                                     *comparison_direction);
198       }
199       break;
200     }
201     case HloOpcode::kTriangularSolve: {
202       instruction = CreateTriangularSolve(shape, operands(0), operands(1),
203                                           proto.triangular_solve_options());
204       break;
205     }
206     case HloOpcode::kCholesky: {
207       instruction =
208           CreateCholesky(shape, operands(0), proto.cholesky_options());
209       break;
210     }
211     case HloOpcode::kSend:
212       instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
213                                proto.is_host_transfer());
214       break;
215     case HloOpcode::kSendDone:
216       instruction = CreateSendDone(operands(0), proto.is_host_transfer());
217       break;
218     case HloOpcode::kRecv:
219       instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
220                                proto.channel_id(), proto.is_host_transfer());
221       break;
222     case HloOpcode::kRecvDone:
223       instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
224       break;
225     case HloOpcode::kReverse:
226       instruction = CreateReverse(shape, operands(0),
227                                   std::vector<int64>(proto.dimensions().begin(),
228                                                      proto.dimensions().end()));
229       break;
230     case HloOpcode::kConcatenate:
231       TF_RET_CHECK(proto.dimensions_size() == 1)
232           << "Concatenate instruction should have 1 dimension but sees "
233           << proto.dimensions_size();
234       instruction =
235           CreateConcatenate(shape, all_operands(), proto.dimensions(0));
236       break;
237     case HloOpcode::kConditional: {
238       TF_RET_CHECK(proto.called_computation_ids_size() > 0)
239           << "conditional should have at least 1 called computation";
240       if (operands(0)->shape().element_type() == PRED) {
241         TF_RET_CHECK(proto.called_computation_ids_size() == 2)
242             << "conditional should have exactly 2 called computations but got "
243             << proto.called_computation_ids_size();
244       }
245       TF_RET_CHECK(proto.operand_ids_size() ==
246                    proto.called_computation_ids_size() + 1)
247           << "conditional should have one branch_index operand plus one "
248              "operand per called computation but got "
249           << proto.operand_ids_size() << " operands for "
250           << proto.called_computation_ids_size() << " branch computations";
251       auto cond_operands = all_operands();
252       instruction =
253           CreateConditional(shape, cond_operands[0], all_computations(),
254                             absl::MakeSpan(cond_operands).subspan(1));
255       break;
256     }
257     case HloOpcode::kReduce:
258       TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
259           << "Reduce instruction should have an even number of operands but "
260              "sees "
261           << proto.operand_ids_size();
262       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
263           << "Reduce instruction should have 1 called computation but sees "
264           << proto.called_computation_ids_size();
265       {
266         const auto reduce_operands = all_operands();
267         auto inputs = absl::MakeSpan(reduce_operands)
268                           .subspan(0, reduce_operands.size() / 2);
269         auto init_values =
270             absl::MakeSpan(reduce_operands)
271                 .subspan(reduce_operands.size() / 2, reduce_operands.size());
272         instruction =
273             CreateReduce(shape, inputs, init_values,
274                          std::vector<int64>(proto.dimensions().begin(),
275                                             proto.dimensions().end()),
276                          computations(0));
277       }
278       break;
279     case HloOpcode::kSort: {
280       TF_RET_CHECK(proto.operand_ids_size() >= 1)
281           << "Sort instruction should have at least 1 operand but has "
282           << proto.operand_ids_size();
283       TF_RET_CHECK(proto.dimensions().size() == 1)
284           << "Sort instruction should have 1 dimension";
285       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
286           << "Sort instruction should one called computation but sees "
287           << proto.called_computation_ids_size();
288       auto sort_operands = all_operands();
289       instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
290                                computations(0), proto.is_stable());
291       break;
292     }
293     case HloOpcode::kTranspose:
294       instruction =
295           CreateTranspose(shape, operands(0),
296                           std::vector<int64>(proto.dimensions().begin(),
297                                              proto.dimensions().end()));
298       break;
299     case HloOpcode::kBroadcast:
300       instruction =
301           CreateBroadcast(shape, operands(0),
302                           std::vector<int64>(proto.dimensions().begin(),
303                                              proto.dimensions().end()));
304       break;
305     case HloOpcode::kMap:
306       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
307           << "Map instruction should have 1 called computation but sees "
308           << proto.called_computation_ids_size();
309       instruction = CreateMap(shape, all_operands(), computations(0));
310       break;
311     case HloOpcode::kSlice: {
312       std::vector<int64> slice_starts, slice_limits, slice_strides;
313       for (const HloInstructionProto::SliceDimensions& slice_dimensions :
314            proto.slice_dimensions()) {
315         slice_starts.push_back(slice_dimensions.start());
316         slice_limits.push_back(slice_dimensions.limit());
317         slice_strides.push_back(slice_dimensions.stride());
318       }
319       instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
320                                 slice_strides);
321       break;
322     }
323     case HloOpcode::kConstant: {
324       // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
325       if (proto.has_literal()) {
326         TF_ASSIGN_OR_RETURN(
327             auto literal,
328             Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
329         instruction = CreateConstant(std::move(literal));
330         // Literal's shape may have no/different tiling info.
331         TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
332             instruction->shape(), shape))
333             << instruction->shape().ToString(true) << " vs "
334             << shape.ToString(true);
335         *instruction->mutable_shape() = shape;
336       } else {
337         instruction = absl::make_unique<HloConstantInstruction>(shape);
338       }
339       break;
340     }
341     case HloOpcode::kTrace: {
342       TF_RET_CHECK(proto.has_literal());
343       TF_ASSIGN_OR_RETURN(
344           auto literal,
345           Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
346       instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
347       break;
348     }
349     case HloOpcode::kFusion: {
350       // In the proto, fused computations are held exclusively within the
351       // HloInstructionProto and do not appear as an HloComputationProto within
352       // the HloModuleProto.
353       TF_RET_CHECK(!proto.fusion_kind().empty());
354       TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
355                           StringToFusionKind(proto.fusion_kind()));
356 
357       // Find the fused computation and set its fusion instruction.
358       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
359           << "Expect 1 called computation for fusion instruction but sees "
360           << proto.called_computation_ids_size();
361       const int64_t fusion_id = proto.called_computation_ids(0);
362       auto* fused_computation =
363           tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
364       TF_RET_CHECK(fused_computation != nullptr)
365           << "No fusion computation with id " << fusion_id;
366       instruction =
367           CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
368       break;
369     }
370     case HloOpcode::kRng:
371       instruction = CreateRng(shape, proto.distribution(), all_operands());
372       break;
373     case HloOpcode::kRngBitGenerator:
374       instruction =
375           CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
376       break;
377     case HloOpcode::kRngGetAndUpdateState:
378       instruction = CreateRngGetAndUpdateState(shape, proto.delta());
379       break;
380     case HloOpcode::kParameter:
381       instruction =
382           CreateParameter(proto.parameter_number(), shape, proto.name());
383       if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
384         instruction->set_parameter_replicated_at_leaf_buffers(
385             proto.parameter_replication().replicated_at_leaf_buffers());
386       }
387       break;
388     case HloOpcode::kGetTupleElement:
389       instruction =
390           CreateGetTupleElement(shape, operands(0), proto.tuple_index());
391       break;
392     case HloOpcode::kReducePrecision:
393       instruction = CreateReducePrecision(
394           shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
395       break;
396     case HloOpcode::kInfeed: {
397       TF_RET_CHECK(shape.IsTuple() &&
398                    (ShapeUtil::TupleElementCount(shape) == 2))
399           << "Infeed should have a tuple shape with 2 operands, but has: "
400           << shape;
401       const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
402       instruction =
403           CreateInfeed(data_shape, operands(0), proto.infeed_config());
404     } break;
405     case HloOpcode::kOutfeed: {
406       Shape outfeed_shape(proto.outfeed_shape());
407       TF_RETURN_IF_ERROR(
408           ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
409       instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
410                                   proto.outfeed_config());
411       break;
412     }
413     case HloOpcode::kAllGather:
414     case HloOpcode::kAllGatherStart: {
415       absl::optional<int64> channel_id;
416       if (proto.channel_id() > 0) {
417         channel_id = proto.channel_id();
418       }
419 
420       TF_RET_CHECK(proto.dimensions_size() == 1)
421           << "AllGather cannot have more than 1 all-gather dimensions";
422       int64_t all_gather_dimension = proto.dimensions(0);
423       if (opcode == HloOpcode::kAllGather) {
424         instruction = CreateAllGather(
425             shape, all_operands(), all_gather_dimension,
426             std::vector<ReplicaGroup>(proto.replica_groups().begin(),
427                                       proto.replica_groups().end()),
428             proto.constrain_layout(), channel_id,
429             proto.use_global_device_ids());
430       } else {
431         instruction = CreateAllGatherStart(
432             shape, all_operands(), all_gather_dimension,
433             std::vector<ReplicaGroup>(proto.replica_groups().begin(),
434                                       proto.replica_groups().end()),
435             proto.constrain_layout(), channel_id,
436             proto.use_global_device_ids());
437       }
438       break;
439     }
440     case HloOpcode::kAllReduce:
441     case HloOpcode::kAllReduceStart:
442     case HloOpcode::kReduceScatter: {
443       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
444           << "AllReduce should have 1 called computation but sees "
445           << proto.called_computation_ids_size();
446       TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
447           << "AllReduce cannot have both channel_id() and all_reduce_id()";
448       absl::optional<int64> channel_id;
449       if (proto.channel_id() > 0) {
450         channel_id = proto.channel_id();
451       }
452       if (proto.all_reduce_id() > 0) {
453         channel_id = proto.all_reduce_id();
454       }
455       std::vector<ReplicaGroup> replica_groups(proto.replica_groups().begin(),
456                                                proto.replica_groups().end());
457       if (opcode == HloOpcode::kAllReduce) {
458         instruction =
459             CreateAllReduce(shape, all_operands(), computations(0),
460                             replica_groups, proto.constrain_layout(),
461                             channel_id, proto.use_global_device_ids());
462       } else if (opcode == HloOpcode::kReduceScatter) {
463         TF_RET_CHECK(proto.dimensions_size() == 1)
464             << "ReduceScatter cannot have more than 1 scatter dimensions";
465         int64_t scatter_dimension = proto.dimensions(0);
466         instruction = CreateReduceScatter(
467             shape, all_operands(), computations(0), replica_groups,
468             proto.constrain_layout(), channel_id, proto.use_global_device_ids(),
469             scatter_dimension);
470       } else {
471         instruction =
472             CreateAllReduceStart(shape, all_operands(), computations(0),
473                                  replica_groups, proto.constrain_layout(),
474                                  channel_id, proto.use_global_device_ids());
475       }
476       break;
477     }
478     case HloOpcode::kAllToAll: {
479       absl::optional<int64> channel_id;
480       if (proto.channel_id() > 0) {
481         channel_id = proto.channel_id();
482       }
483       absl::optional<int64> split_dimension;
484       if (proto.dimensions_size() > 0) {
485         TF_RET_CHECK(proto.dimensions_size() == 1)
486             << "AllToAll cannot have more than 1 dimension (split dimension)";
487         TF_RET_CHECK(all_operands().size() == 1)
488             << "AllToAll must have a single operand when the split dimension "
489                "is specified";
490         split_dimension = proto.dimensions(0);
491       }
492       instruction = CreateAllToAll(
493           shape, all_operands(),
494           /*replica_groups=*/
495           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
496                                     proto.replica_groups().end()),
497           /*constrain_layout=*/proto.constrain_layout(),
498           /*channel_id=*/channel_id, split_dimension);
499       break;
500     }
501     case HloOpcode::kCollectivePermute:
502     case HloOpcode::kCollectivePermuteStart: {
503       TF_RET_CHECK(proto.operand_ids().size() == 1 ||
504                    proto.operand_ids().size() == 4);
505       std::vector<std::pair<int64_t, int64_t>> source_target_pairs(
506           proto.source_target_pairs_size());
507       absl::optional<int64_t> channel_id;
508       if (proto.channel_id() > 0) {
509         channel_id = proto.channel_id();
510       }
511       for (int i = 0; i < source_target_pairs.size(); ++i) {
512         source_target_pairs[i].first = proto.source_target_pairs(i).source();
513         source_target_pairs[i].second = proto.source_target_pairs(i).target();
514       }
515       if (proto.dynamic_slice_sizes_size() == 0) {
516         if (opcode == HloOpcode::kCollectivePermute) {
517           instruction = CreateCollectivePermute(
518               shape, operands(0), source_target_pairs, channel_id);
519         } else if (opcode == HloOpcode::kCollectivePermuteStart) {
520           instruction = CreateCollectivePermuteStart(
521               shape, operands(0), source_target_pairs, channel_id);
522         } else {
523           LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
524                      << "but got " << HloOpcodeString(opcode);
525         }
526       } else {
527         std::vector<std::vector<int64_t>> slice_sizes;
528         HloInstruction* input = operands(0);
529         HloInstruction* input_start_indices = operands(2);
530         if (input->shape().IsTuple() &&
531             input->shape().tuple_shapes_size() > 1) {
532           slice_sizes.resize(input->shape().tuple_shapes_size());
533         } else {
534           slice_sizes.resize(1);
535         }
536         int proto_index = 0;
537         if (input->shape().IsTuple()) {
538           if (input_start_indices->shape()
539                   .tuple_shapes(0)
540                   .tuple_shapes(0)
541                   .IsArray()) {
542             slice_sizes.resize(input->shape().tuple_shapes_size());
543             for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
544               slice_sizes[i].resize(
545                   input->shape().tuple_shapes(i).dimensions_size());
546               for (int j = 0;
547                    j < input->shape().tuple_shapes(i).dimensions_size(); ++j) {
548                 CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
549                 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
550                 proto_index += 1;
551               }
552             }
553           } else {
554             slice_sizes.resize(
555                 input->shape().tuple_shapes_size() *
556                 ShapeUtil::TupleElementCount(
557                     input_start_indices->shape().tuple_shapes(0)));
558             int slice_sizes_count = 0;
559             for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
560               for (int j = 0;
561                    j < ShapeUtil::TupleElementCount(
562                            input_start_indices->shape().tuple_shapes(i));
563                    ++j) {
564                 slice_sizes[slice_sizes_count].resize(
565                     input->shape().tuple_shapes(i).rank());
566                 for (int k = 0; k < input->shape().tuple_shapes(i).rank();
567                      ++k) {
568                   CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
569                   slice_sizes[slice_sizes_count][k] =
570                       proto.dynamic_slice_sizes(proto_index);
571                   proto_index += 1;
572                 }
573                 slice_sizes_count += 1;
574               }
575             }
576           }
577         } else {
578           slice_sizes.resize(
579               ShapeUtil::TupleElementCount(input_start_indices->shape()));
580           if (input_start_indices->shape().tuple_shapes(0).IsTuple()) {
581             for (int i = 0;
582                  i < ShapeUtil::TupleElementCount(input_start_indices->shape());
583                  ++i) {
584               slice_sizes[i].resize(input->shape().dimensions_size());
585               for (int j = 0; j < input->shape().dimensions_size(); ++j) {
586                 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
587                 proto_index += 1;
588               }
589             }
590           } else {
591             slice_sizes.resize(1);
592             slice_sizes[0].resize(input->shape().dimensions_size());
593             for (int j = 0; j < input->shape().dimensions_size(); ++j) {
594               slice_sizes[0][j] = proto.dynamic_slice_sizes(proto_index);
595               proto_index += 1;
596             }
597           }
598         }
599         if (opcode == HloOpcode::kCollectivePermute) {
600           instruction = CreateCollectivePermute(
601               shape, operands(0), operands(1), operands(2), operands(3),
602               source_target_pairs, slice_sizes, channel_id);
603         } else if (opcode == HloOpcode::kCollectivePermuteStart) {
604           instruction = CreateCollectivePermuteStart(
605               shape, operands(0), operands(1), operands(2), operands(3),
606               source_target_pairs, slice_sizes, channel_id);
607         } else {
608           LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
609                      << "but got " << HloOpcodeString(opcode);
610         }
611       }
612       break;
613     }
614     case HloOpcode::kReplicaId: {
615       instruction = CreateReplicaId(shape);
616       break;
617     }
618     case HloOpcode::kPartitionId: {
619       instruction = CreatePartitionId(shape);
620       break;
621     }
622     case HloOpcode::kConvolution: {
623       TF_RET_CHECK(proto.has_window());
624       TF_RET_CHECK(proto.has_convolution_dimension_numbers());
625       PrecisionConfig precision_config = proto.precision_config();
626       precision_config.mutable_operand_precision()->Resize(
627           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
628       instruction = CreateConvolve(
629           shape, operands(0), operands(1),
630           std::max<int64>(proto.feature_group_count(), 1),
631           std::max<int64>(proto.batch_group_count(), 1), proto.window(),
632           proto.convolution_dimension_numbers(), precision_config);
633       break;
634     }
635     case HloOpcode::kReduceWindow:
636       TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
637           << "Reduce window should have an even number of operands but "
638              "sees "
639           << proto.operand_ids_size();
640       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
641           << "ReduceWindow should have 1 called computation but sees "
642           << proto.called_computation_ids_size();
643       {
644         const auto reduce_operands = all_operands();
645         auto inputs = absl::MakeSpan(reduce_operands)
646                           .subspan(0, reduce_operands.size() / 2);
647         auto init_values =
648             absl::MakeSpan(reduce_operands)
649                 .subspan(reduce_operands.size() / 2, reduce_operands.size());
650         instruction = CreateReduceWindow(shape, inputs, init_values,
651                                          proto.window(), computations(0));
652       }
653       break;
654     case HloOpcode::kSelectAndScatter:
655       TF_RET_CHECK(proto.called_computation_ids_size() == 2)
656           << "SelectAndScatter should have 2 called computations but sees "
657           << proto.called_computation_ids_size();
658       instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
659                                            proto.window(), operands(1),
660                                            operands(2), computations(1));
661       break;
662     case HloOpcode::kCustomCall: {
663       if (proto.constrain_layout()) {
664         // A proto RepeatedPtrField cannot be converted to a Span (it is a
665         // vector of pointers essentially) so create a vector of shapes to pass
666         // in.
667         std::vector<Shape> operand_shapes;
668         for (const ShapeProto& shape_proto :
669              proto.operand_shapes_with_layout()) {
670           operand_shapes.emplace_back(shape_proto);
671         }
672         instruction =
673             CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
674                              operand_shapes, proto.backend_config());
675       } else {
676         if (proto.called_computation_ids_size() == 1) {
677           instruction = CreateCustomCall(shape, all_operands(), computations(0),
678                                          proto.custom_call_target(),
679                                          proto.backend_config());
680         } else if (proto.called_computation_ids_size() > 1) {
681           instruction = CreateCustomCall(
682               shape, all_operands(), all_computations(),
683               proto.custom_call_target(), proto.backend_config());
684 
685         } else {
686           instruction = CreateCustomCall(shape, all_operands(),
687                                          proto.custom_call_target(),
688                                          proto.backend_config());
689         }
690       }
691       auto custom_call_instr =
692           Cast<HloCustomCallInstruction>(instruction.get());
693       if (proto.has_window()) {
694         custom_call_instr->set_window(proto.window());
695       }
696       if (proto.has_literal()) {
697         TF_ASSIGN_OR_RETURN(
698             auto literal,
699             Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
700         custom_call_instr->set_literal(std::move(literal));
701       }
702       if (proto.has_convolution_dimension_numbers()) {
703         custom_call_instr->set_convolution_dimension_numbers(
704             proto.convolution_dimension_numbers());
705       }
706       custom_call_instr->set_feature_group_count(
707           std::max(static_cast<int64>(proto.feature_group_count()), int64{1}));
708       custom_call_instr->set_batch_group_count(
709           std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
710       custom_call_instr->set_custom_call_has_side_effect(
711           proto.custom_call_has_side_effect());
712       custom_call_instr->set_padding_type(proto.padding_type());
713 
714       PrecisionConfig precision_config = proto.precision_config();
715       precision_config.mutable_operand_precision()->Resize(
716           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
717       *custom_call_instr->mutable_precision_config() = precision_config;
718       std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
719           output_to_operand_aliasing;
720       for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
721         output_to_operand_aliasing.emplace_back(
722             ShapeIndex(aliasing.output_shape_index().begin(),
723                        aliasing.output_shape_index().end()),
724             std::pair<int64, ShapeIndex>{
725                 aliasing.operand_index(),
726                 ShapeIndex(aliasing.operand_shape_index().begin(),
727                            aliasing.operand_shape_index().end())});
728       }
729       custom_call_instr->set_output_to_operand_aliasing(
730           std::move(output_to_operand_aliasing));
731       custom_call_instr->set_custom_call_schedule(proto.custom_call_schedule());
732       custom_call_instr->set_api_version(proto.custom_call_api_version());
733       break;
734     }
735     case HloOpcode::kPad:
736       TF_RET_CHECK(proto.has_padding_config());
737       instruction =
738           CreatePad(shape, operands(0), operands(1), proto.padding_config());
739       break;
740     case HloOpcode::kDynamicSlice: {
741       std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
742       absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
743       TF_RET_CHECK(proto.operand_ids_size() >= 1)
744           << "DynamicSlice instruction should have at least 1 operands but "
745              "sees "
746           << proto.operand_ids_size();
747       // TODO(b/118437727): Old form, make the check unconditional.
748       if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
749         auto expected_operands = 1 + operands(0)->shape().rank();
750         TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
751             << "DynamicSlice instruction should have " << expected_operands
752             << " operands, but has " << proto.operand_ids_size();
753       }
754       const auto& operand_vector = all_operands();
755       instruction = CreateDynamicSlice(
756           shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
757           slice_sizes);
758       break;
759     }
760     case HloOpcode::kDynamicUpdateSlice: {
761       TF_RET_CHECK(proto.operand_ids_size() >= 2)
762           << "DynamicUpdateSlice instruction should have at least 2 operands "
763              "but sees "
764           << proto.operand_ids_size();
765       // TODO(b/118437727): Old form, make the check unconditional.
766       if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
767         auto expected_operands = 2 + operands(0)->shape().rank();
768         TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
769             << "DynamicUpdateSlice instruction should have "
770             << expected_operands << " operands, but has "
771             << proto.operand_ids_size();
772       }
773       const auto& operand_vector = all_operands();
774       instruction =
775           CreateDynamicUpdateSlice(shape, operands(0), operands(1),
776                                    absl::MakeSpan(operand_vector).subspan(2));
777 
778       break;
779     }
780     case HloOpcode::kGather: {
781       TF_RET_CHECK(proto.has_gather_dimension_numbers())
782           << "Gather instruction should have GatherDimensionNumbers set.";
783       auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
784           proto.gather_dimension_numbers());
785       std::vector<int64> gather_slice_sizes;
786       for (int64_t bound : proto.gather_slice_sizes()) {
787         gather_slice_sizes.push_back(bound);
788       }
789       instruction = CreateGather(shape, operands(0), operands(1),
790                                  *gather_dimension_numbers, gather_slice_sizes,
791                                  proto.indices_are_sorted());
792       break;
793     }
794     case HloOpcode::kScatter: {
795       TF_RET_CHECK(proto.has_scatter_dimension_numbers())
796           << "Scatter instruction should have ScatterDimensionNumbers set.";
797       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
798           << "Scatter instruction should have 1 called computation but sees "
799           << proto.called_computation_ids_size();
800       auto scatter_dimension_numbers =
801           absl::make_unique<ScatterDimensionNumbers>(
802               proto.scatter_dimension_numbers());
803       instruction =
804           CreateScatter(shape, operands(0), operands(1), operands(2),
805                         computations(0), *scatter_dimension_numbers,
806                         proto.indices_are_sorted(), proto.unique_indices());
807       break;
808     }
809     case HloOpcode::kIota:
810       TF_RET_CHECK(proto.dimensions_size() == 1)
811           << "Iota instruction should have 1 dimension but sees "
812           << proto.dimensions_size();
813       instruction = CreateIota(shape, proto.dimensions(0));
814       break;
815     case HloOpcode::kDot: {
816       TF_RET_CHECK(proto.has_dot_dimension_numbers())
817           << "Dot instruction should have dot_dimension_numbers.";
818       PrecisionConfig precision_config = proto.precision_config();
819       precision_config.mutable_operand_precision()->Resize(
820           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
821       instruction = absl::make_unique<HloDotInstruction>(
822           shape, operands(0), operands(1), proto.dot_dimension_numbers(),
823           precision_config);
824       break;
825     }
826     case HloOpcode::kDomain: {
827       std::shared_ptr<const HloSharding> entry_hlo_sharding;
828       std::shared_ptr<const HloSharding> exit_hlo_sharding;
829       if (proto.has_domain_entry_sharding()) {
830         TF_ASSIGN_OR_RETURN(
831             HloSharding sharding,
832             HloSharding::FromProto(proto.domain_entry_sharding()));
833         entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
834       }
835       if (proto.has_domain_exit_sharding()) {
836         TF_ASSIGN_OR_RETURN(
837             HloSharding sharding,
838             HloSharding::FromProto(proto.domain_exit_sharding()));
839         exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
840       }
841       instruction = absl::make_unique<HloDomainInstruction>(
842           shape, operands(0),
843           absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
844           absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
845       break;
846     }
847     case HloOpcode::kGetDimensionSize:
848       TF_RET_CHECK(proto.dimensions_size() == 1);
849       instruction =
850           CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
851       break;
852     case HloOpcode::kSetDimensionSize:
853       TF_RET_CHECK(proto.dimensions_size() == 1);
854       instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
855                                            proto.dimensions(0));
856       break;
857     case HloOpcode::kReshape: {
858       int64_t inferred_dimension = -1;
859       if (!proto.dimensions().empty()) {
860         inferred_dimension = proto.dimensions()[0];
861       }
862       TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
863                    ShapeUtil::ElementsIn(shape) ==
864                        ShapeUtil::ElementsIn(operands(0)->shape()))
865           << "shape: " << ShapeUtil::HumanString(shape)
866           << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
867       instruction = CreateReshape(shape, operands(0), inferred_dimension);
868       break;
869     }
870     case HloOpcode::kDynamicReshape: {
871       TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
872                    ShapeUtil::ElementsIn(shape) ==
873                        ShapeUtil::ElementsIn(operands(0)->shape()))
874           << "shape: " << ShapeUtil::HumanString(shape)
875           << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
876       const auto& operand_vector = all_operands();
877       instruction = CreateDynamicReshape(
878           shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
879       break;
880     }
881     default: {
882       instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
883       for (const int64_t operand_id : proto.operand_ids()) {
884         instruction->AppendOperand(instruction_map.at(operand_id));
885       }
886       if (instruction->opcode() != HloOpcode::kFusion) {
887         if (instruction->opcode() == HloOpcode::kCall) {
888           TF_RET_CHECK(proto.called_computation_ids_size() == 1)
889               << "Call should have 1 called computation but has "
890               << proto.called_computation_ids_size();
891         }
892         for (const int64_t computation_id : proto.called_computation_ids()) {
893           instruction->called_computations_.push_back(
894               computation_map.at(computation_id));
895         }
896       }
897       TF_RET_CHECK(!proto.has_precision_config())
898           << instruction->opcode() << proto.DebugString();
899       TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
900       break;
901     }
902   }
903 
904   for (const int64_t predecessor_id : proto.control_predecessor_ids()) {
905     TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
906         << "No instruction with id " << predecessor_id;
907     TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
908                            ->AddControlDependencyTo(instruction.get()));
909   }
910 
911   TF_RET_CHECK(!proto.name().empty());
912   instruction->SetAndSanitizeName(proto.name());
913   instruction->metadata_ = proto.metadata();
914   instruction->backend_config_ = proto.backend_config();
915   instruction->outer_dimension_partitions_.assign(
916       proto.outer_dimension_partitions().begin(),
917       proto.outer_dimension_partitions().end());
918 
919   TF_RET_CHECK(proto.id() >= 0)
920       << "Instruction with negative id: " << proto.id();
921   TF_RET_CHECK(proto.id() <= INT_MAX)
922       << "Instruction with id > INT_MAX: " << proto.id();
923   instruction->unique_id_ = proto.id();
924 
925   if (proto.has_sharding()) {
926     TF_ASSIGN_OR_RETURN(const auto& sharding,
927                         HloSharding::FromProto(proto.sharding()));
928     instruction->set_sharding(sharding);
929   }
930 
931   if (proto.has_frontend_attributes()) {
932     instruction->set_frontend_attributes(proto.frontend_attributes());
933   }
934 
935   return std::move(instruction);
936 }
937 
CreateParameter(int64_t parameter_number,const Shape & shape,const string & name)938 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
939     int64_t parameter_number, const Shape& shape, const string& name) {
940   return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
941                                                     name);
942 }
943 
CreateTrace(const string & tag,HloInstruction * operand)944 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
945     const string& tag, HloInstruction* operand) {
946   return absl::make_unique<HloTraceInstruction>(tag, operand);
947 }
948 
CreateConstant(Literal literal)949 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
950     Literal literal) {
951   return absl::make_unique<HloConstantInstruction>(std::move(literal));
952 }
953 
CreateIota(const Shape & shape,int64_t iota_dimension)954 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
955     const Shape& shape, int64_t iota_dimension) {
956   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
957 }
958 
959 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64_t index)960 HloInstruction::CreateGetTupleElement(const Shape& shape,
961                                       HloInstruction* operand, int64_t index) {
962   return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
963                                                           index);
964 }
965 
CreateRng(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)966 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
967     const Shape& shape, RandomDistribution distribution,
968     absl::Span<HloInstruction* const> parameters) {
969   return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
970 }
971 
972 /* static */ std::unique_ptr<HloInstruction>
CreateRngGetAndUpdateState(const Shape & shape,int64_t delta)973 HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64_t delta) {
974   return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
975 }
976 
977 /* static */ std::unique_ptr<HloInstruction>
CreateRngBitGenerator(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)978 HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
979                                       RandomAlgorithm algorithm) {
980   return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state,
981                                                           algorithm);
982 }
983 
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)984 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
985     const Shape& shape, HloOpcode opcode,
986     absl::Span<HloInstruction* const> operands) {
987   if (opcode == HloOpcode::kCopy) {
988     // It is impossible to copy an opaque shape, we don't know how big it is.
989     CHECK(!shape.IsOpaque());
990   }
991   auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
992   for (auto operand : operands) {
993     instruction->AppendOperand(operand);
994   }
995   return instruction;
996 }
997 
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)998 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
999     const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
1000   // Only certain opcodes are supported with CreateUnary: opcodes of unary
1001   // instructions with no auxiliary fields.
1002   switch (opcode) {
1003     case HloOpcode::kAbs:
1004     case HloOpcode::kAllGatherDone:
1005     case HloOpcode::kAllReduceDone:
1006     case HloOpcode::kRoundNearestAfz:
1007     case HloOpcode::kBitcast:
1008     case HloOpcode::kCeil:
1009     case HloOpcode::kCollectivePermuteDone:
1010     case HloOpcode::kCopy:
1011     case HloOpcode::kCopyDone:
1012     case HloOpcode::kCos:
1013     case HloOpcode::kClz:
1014     case HloOpcode::kExp:
1015     case HloOpcode::kExpm1:
1016     case HloOpcode::kFloor:
1017     case HloOpcode::kImag:
1018     case HloOpcode::kIsFinite:
1019     case HloOpcode::kLog:
1020     case HloOpcode::kLog1p:
1021     case HloOpcode::kNot:
1022     case HloOpcode::kNegate:
1023     case HloOpcode::kPopulationCount:
1024     case HloOpcode::kReal:
1025     case HloOpcode::kRsqrt:
1026     case HloOpcode::kLogistic:
1027     case HloOpcode::kSign:
1028     case HloOpcode::kSin:
1029     case HloOpcode::kSqrt:
1030     case HloOpcode::kCbrt:
1031     case HloOpcode::kTanh:
1032       break;
1033     default:
1034       LOG(FATAL) << "Invalid unary instruction opcode "
1035                  << HloOpcodeString(opcode);
1036   }
1037   return CreateNary(shape, opcode, {operand});
1038 }
1039 
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)1040 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
1041     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1042     HloInstruction* rhs) {
1043   // Only certain opcodes are supported with CreateBinary: opcodes of binary
1044   // instructions with no auxiliary fields.
1045   switch (opcode) {
1046     case HloOpcode::kAdd:
1047     case HloOpcode::kAtan2:
1048     case HloOpcode::kDivide:
1049     case HloOpcode::kComplex:
1050     case HloOpcode::kMaximum:
1051     case HloOpcode::kMinimum:
1052     case HloOpcode::kMultiply:
1053     case HloOpcode::kPower:
1054     case HloOpcode::kRemainder:
1055     case HloOpcode::kSubtract:
1056     case HloOpcode::kAnd:
1057     case HloOpcode::kOr:
1058     case HloOpcode::kXor:
1059     case HloOpcode::kShiftLeft:
1060     case HloOpcode::kShiftRightArithmetic:
1061     case HloOpcode::kShiftRightLogical:
1062       break;
1063     default:
1064       LOG(FATAL) << "Invalid binary instruction opcode "
1065                  << HloOpcodeString(opcode);
1066   }
1067   return CreateNary(shape, opcode, {lhs, rhs});
1068 }
1069 
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)1070 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
1071     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1072     HloInstruction* rhs, HloInstruction* ehs) {
1073   // Only certain opcodes are supported with CreateTernary: opcodes of ternary
1074   // instructions with no auxiliary fields.
1075   switch (opcode) {
1076     case HloOpcode::kClamp:
1077     case HloOpcode::kSelect:
1078     case HloOpcode::kTupleSelect:
1079       break;
1080     default:
1081       LOG(FATAL) << "Invalid ternary instruction opcode "
1082                  << HloOpcodeString(opcode);
1083   }
1084   return CreateNary(shape, opcode, {lhs, rhs, ehs});
1085 }
1086 
CreateVariadic(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)1087 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
1088     const Shape& shape, HloOpcode opcode,
1089     absl::Span<HloInstruction* const> operands) {
1090   CHECK_EQ(HloOpcode::kTuple, opcode);
1091   return CreateNary(shape, opcode, operands);
1092 }
1093 
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1094 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
1095     const Shape& shape, absl::Span<HloInstruction* const> operands,
1096     HloComputation* map_computation) {
1097   return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
1098 }
1099 
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1100 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
1101     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1102     int64_t feature_group_count, int64_t batch_group_count,
1103     const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
1104     const PrecisionConfig& precision_config) {
1105   return absl::make_unique<HloConvolutionInstruction>(
1106       shape, lhs, rhs, feature_group_count, batch_group_count, window,
1107       dimension_numbers, precision_config);
1108 }
1109 
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)1110 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
1111     const Shape& shape, HloInstruction* operand, FftType fft_type,
1112     absl::Span<const int64> fft_length) {
1113   return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
1114                                               fft_length);
1115 }
1116 
CreateCopyStart(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)1117 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
1118     const Shape& shape, HloInstruction* operand,
1119     bool is_cross_program_prefetch) {
1120   return absl::make_unique<HloCopyStartInstruction>(shape, operand,
1121                                                     is_cross_program_prefetch);
1122 }
1123 
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)1124 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
1125     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1126     ComparisonDirection direction, absl::optional<Comparison::Type> type) {
1127   return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
1128                                                   type);
1129 }
1130 
1131 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)1132 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
1133                                       HloInstruction* b,
1134                                       const TriangularSolveOptions& options) {
1135   return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
1136 }
1137 
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)1138 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
1139     const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
1140   return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
1141 }
1142 
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1143 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
1144     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1145     const DotDimensionNumbers& dimension_numbers,
1146     const PrecisionConfig& precision_config) {
1147   return absl::make_unique<HloDotInstruction>(
1148       shape, lhs, rhs, dimension_numbers, precision_config);
1149 }
1150 
1151 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1152 HloInstruction::CreateReducePrecision(const Shape& shape,
1153                                       HloInstruction* operand,
1154                                       const int exponent_bits,
1155                                       const int mantissa_bits) {
1156   return absl::make_unique<HloReducePrecisionInstruction>(
1157       shape, operand, exponent_bits, mantissa_bits);
1158 }
1159 
CreateAllGather(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1160 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
1161     const Shape& shape, absl::Span<HloInstruction* const> operands,
1162     int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1163     bool constrain_layout, const absl::optional<int64>& channel_id,
1164     bool use_global_device_ids) {
1165   return absl::make_unique<HloAllGatherInstruction>(
1166       HloOpcode::kAllGather, shape, operands, all_gather_dimension,
1167       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1168 }
1169 
1170 /* static */ std::unique_ptr<HloInstruction>
CreateAllGatherStart(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1171 HloInstruction::CreateAllGatherStart(
1172     const Shape& shape, absl::Span<HloInstruction* const> operands,
1173     int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1174     bool constrain_layout, const absl::optional<int64>& channel_id,
1175     bool use_global_device_ids) {
1176   return absl::make_unique<HloAllGatherInstruction>(
1177       HloOpcode::kAllGatherStart, shape, operands, all_gather_dimension,
1178       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1179 }
1180 
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1181 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
1182     const Shape& shape, absl::Span<HloInstruction* const> operands,
1183     HloComputation* reduce_computation,
1184     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1185     const absl::optional<int64>& channel_id, bool use_global_device_ids) {
1186   return absl::make_unique<HloAllReduceInstruction>(
1187       HloOpcode::kAllReduce, shape, operands, reduce_computation,
1188       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1189 }
1190 
1191 /* static */ std::unique_ptr<HloInstruction>
CreateReduceScatter(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)1192 HloInstruction::CreateReduceScatter(
1193     const Shape& shape, absl::Span<HloInstruction* const> operands,
1194     HloComputation* reduce_computation,
1195     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1196     const absl::optional<int64>& channel_id, bool use_global_device_ids,
1197     int64_t scatter_dimension) {
1198   return absl::make_unique<HloReduceScatterInstruction>(
1199       shape, operands, reduce_computation, replica_groups, constrain_layout,
1200       channel_id, use_global_device_ids, scatter_dimension);
1201 }
1202 
1203 /* static */ std::unique_ptr<HloInstruction>
CreateAllReduceStart(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1204 HloInstruction::CreateAllReduceStart(
1205     const Shape& shape, absl::Span<HloInstruction* const> operands,
1206     HloComputation* reduce_computation,
1207     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1208     const absl::optional<int64>& channel_id, bool use_global_device_ids) {
1209   return absl::make_unique<HloAllReduceInstruction>(
1210       HloOpcode::kAllReduceStart, shape, operands, reduce_computation,
1211       replica_groups, constrain_layout, channel_id, use_global_device_ids);
1212 }
1213 
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)1214 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
1215     const Shape& shape, absl::Span<HloInstruction* const> operands,
1216     absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1217     const absl::optional<int64>& channel_id,
1218     const absl::optional<int64>& split_dimension) {
1219   return absl::make_unique<HloAllToAllInstruction>(
1220       shape, operands, replica_groups, constrain_layout, channel_id,
1221       split_dimension);
1222 }
1223 
1224 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const absl::optional<int64_t> & channel_id)1225 HloInstruction::CreateCollectivePermute(
1226     const Shape& shape, HloInstruction* operand,
1227     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1228     const absl::optional<int64_t>& channel_id) {
1229   return absl::make_unique<HloCollectivePermuteInstruction>(
1230       HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
1231       channel_id);
1232 }
1233 
1234 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const absl::optional<int64_t> & channel_id)1235 HloInstruction::CreateCollectivePermute(
1236     const Shape& shape, HloInstruction* input, HloInstruction* output,
1237     HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1238     absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1239     absl::Span<const std::vector<int64_t>> slice_sizes,
1240     const absl::optional<int64_t>& channel_id) {
1241   return absl::make_unique<HloCollectivePermuteInstruction>(
1242       HloOpcode::kCollectivePermute, shape, input, output, input_start_indices,
1243       output_start_indices, source_target_pairs, slice_sizes, channel_id);
1244 }
1245 
1246 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const absl::optional<int64_t> & channel_id)1247 HloInstruction::CreateCollectivePermuteStart(
1248     const Shape& shape, HloInstruction* operand,
1249     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1250     const absl::optional<int64_t>& channel_id) {
1251   return absl::make_unique<HloCollectivePermuteInstruction>(
1252       HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
1253       channel_id);
1254 }
1255 
1256 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const absl::optional<int64_t> & channel_id)1257 HloInstruction::CreateCollectivePermuteStart(
1258     const Shape& shape, HloInstruction* input, HloInstruction* output,
1259     HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1260     absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1261     absl::Span<const std::vector<int64_t>> slice_sizes,
1262     const absl::optional<int64_t>& channel_id) {
1263   return absl::make_unique<HloCollectivePermuteInstruction>(
1264       HloOpcode::kCollectivePermuteStart, shape, input, output,
1265       input_start_indices, output_start_indices, source_target_pairs,
1266       slice_sizes, channel_id);
1267 }
1268 
CreateReplicaId(const Shape & shape)1269 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
1270     const Shape& shape) {
1271   CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1272       << "HloInstruction replica-id must have a shape of u32[], but "
1273       << shape.ToString() << " is specified";
1274   return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
1275 }
1276 
CreatePartitionId(const Shape & shape)1277 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
1278     const Shape& shape) {
1279   CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1280       << "HloInstruction partition-id must have a shape of u32[], but "
1281       << shape.ToString() << " is specified";
1282   return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
1283 }
1284 
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1285 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
1286     const Shape& infeed_shape, HloInstruction* token_operand,
1287     const string& config) {
1288   return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
1289                                                  config);
1290 }
1291 
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1292 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
1293     const Shape& outfeed_shape, HloInstruction* operand,
1294     HloInstruction* token_operand, absl::string_view outfeed_config) {
1295   return absl::make_unique<HloOutfeedInstruction>(
1296       outfeed_shape, operand, token_operand, outfeed_config);
1297 }
1298 
CreateSend(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1299 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
1300     HloInstruction* operand, HloInstruction* token, int64_t channel_id,
1301     bool is_host_transfer) {
1302   return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
1303                                                is_host_transfer);
1304 }
1305 
CreateSendDone(HloInstruction * operand,bool is_host_transfer)1306 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
1307     HloInstruction* operand, bool is_host_transfer) {
1308   auto send_operand = DynCast<HloSendInstruction>(operand);
1309   CHECK(send_operand != nullptr)
1310       << "SendDone must take the context operand from Send";
1311   return absl::make_unique<HloSendDoneInstruction>(send_operand,
1312                                                    is_host_transfer);
1313 }
1314 
CreateRecv(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1315 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
1316     const Shape& shape, HloInstruction* token, int64_t channel_id,
1317     bool is_host_transfer) {
1318   return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
1319                                                is_host_transfer);
1320 }
1321 
CreateRecvDone(HloInstruction * operand,bool is_host_transfer)1322 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
1323     HloInstruction* operand, bool is_host_transfer) {
1324   auto recv_operand = DynCast<HloRecvInstruction>(operand);
1325   CHECK(recv_operand != nullptr)
1326       << "RecvDone must take the context operand from Recv";
1327   return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
1328                                                    is_host_transfer);
1329 }
1330 
CreateReverse(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1331 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
1332     const Shape& shape, HloInstruction* operand,
1333     absl::Span<const int64> dimensions) {
1334   return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
1335 }
1336 
CreateAfterAll(absl::Span<HloInstruction * const> operands)1337 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1338     absl::Span<HloInstruction* const> operands) {
1339   CHECK(!operands.empty());
1340   auto instruction = absl::WrapUnique(
1341       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1342   for (auto operand : operands) {
1343     instruction->AppendOperand(operand);
1344   }
1345   return instruction;
1346 }
1347 
CreateToken()1348 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
1349   return absl::WrapUnique(
1350       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1351 }
1352 
1353 /* static */ std::unique_ptr<HloInstruction>
CreateAddDependency(HloInstruction * data_operand,HloInstruction * token_operand)1354 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
1355                                     HloInstruction* token_operand) {
1356   auto instruction = absl::WrapUnique(
1357       new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
1358   instruction->AppendOperand(data_operand);
1359   instruction->AppendOperand(token_operand);
1360   return instruction;
1361 }
1362 
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)1363 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
1364     const Shape& shape, HloComputation* condition, HloComputation* body,
1365     HloInstruction* init) {
1366   auto instruction =
1367       absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
1368   instruction->AppendOperand(init);
1369   // Body comes before condition computation in the vector.
1370   instruction->called_computations_.push_back(body);
1371   instruction->called_computations_.push_back(condition);
1372   return instruction;
1373 }
1374 
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)1375 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1376     const Shape& shape, HloInstruction* pred,
1377     HloInstruction* true_computation_arg, HloComputation* true_computation,
1378     HloInstruction* false_computation_arg, HloComputation* false_computation) {
1379   auto instruction =
1380       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1381   instruction->AppendOperand(pred);
1382   instruction->AppendOperand(true_computation_arg);
1383   instruction->AppendOperand(false_computation_arg);
1384   // In called_computations_, the index of true_computation must be 0 and that
1385   // of false computation must be 1, as defined by kTrueComputationIndex and
1386   // kFalseComputationIndex.
1387   instruction->called_computations_.push_back(true_computation);
1388   instruction->called_computations_.push_back(false_computation);
1389   return instruction;
1390 }
1391 
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)1392 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1393     const Shape& shape, HloInstruction* branch_index,
1394     absl::Span<HloComputation* const> branch_computations,
1395     absl::Span<HloInstruction* const> branch_computation_args) {
1396   auto instruction =
1397       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1398   instruction->AppendOperand(branch_index);
1399   CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1400   for (int i = 0; i < branch_computations.size(); ++i) {
1401     instruction->called_computations_.push_back(branch_computations[i]);
1402     instruction->AppendOperand(branch_computation_args[i]);
1403   }
1404   return instruction;
1405 }
1406 
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1407 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1408     const Shape& shape, HloInstruction* operand,
1409     absl::Span<const int64> start_indices,
1410     absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
1411   return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1412                                                 limit_indices, strides);
1413 }
1414 
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)1415 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1416     const Shape& shape, HloInstruction* operand,
1417     absl::Span<HloInstruction* const> start_indices,
1418     absl::Span<const int64> slice_sizes) {
1419   return absl::make_unique<HloDynamicSliceInstruction>(
1420       shape, operand, start_indices, slice_sizes);
1421 }
1422 
1423 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)1424 HloInstruction::CreateDynamicUpdateSlice(
1425     const Shape& shape, HloInstruction* operand, HloInstruction* update,
1426     absl::Span<HloInstruction* const> start_indices) {
1427   return absl::make_unique<HloDynamicUpdateSliceInstruction>(
1428       shape, operand, update, start_indices);
1429 }
1430 
CreateConcatenate(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)1431 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1432     const Shape& shape, absl::Span<HloInstruction* const> operands,
1433     int64_t dimension) {
1434   return absl::make_unique<HloConcatenateInstruction>(shape, operands,
1435                                                       dimension);
1436 }
1437 
CreateConvert(const Shape & shape,HloInstruction * operand)1438 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1439     const Shape& shape, HloInstruction* operand) {
1440   auto instruction =
1441       absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1442   instruction->AppendOperand(operand);
1443   return instruction;
1444 }
1445 
1446 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)1447 HloInstruction::CreateBitcastConvert(const Shape& shape,
1448                                      HloInstruction* operand) {
1449   auto instruction =
1450       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1451   instruction->AppendOperand(operand);
1452   return instruction;
1453 }
1454 
CreateBitcast(const Shape & shape,HloInstruction * operand)1455 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
1456     const Shape& shape, HloInstruction* operand) {
1457   auto instruction =
1458       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
1459   instruction->AppendOperand(operand);
1460   return instruction;
1461 }
1462 
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1463 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1464     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1465     absl::Span<const int64> dimensions_to_reduce,
1466     HloComputation* reduce_computation) {
1467   auto instruction = absl::WrapUnique(new HloReduceInstruction(
1468       shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1469   return std::move(instruction);
1470 }
1471 
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1472 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1473     const Shape& shape, absl::Span<HloInstruction* const> operands,
1474     absl::Span<HloInstruction* const> init_values,
1475     absl::Span<const int64> dimensions_to_reduce,
1476     HloComputation* reduce_computation) {
1477   std::vector<HloInstruction*> all_args;
1478   all_args.reserve(operands.size() * 2);
1479   all_args.insert(all_args.end(), operands.begin(), operands.end());
1480   all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1481   return absl::make_unique<HloReduceInstruction>(
1482       shape, all_args, dimensions_to_reduce, reduce_computation);
1483 }
1484 
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1485 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1486     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1487     const Window& window, HloComputation* reduce_computation) {
1488   return absl::make_unique<HloReduceWindowInstruction>(
1489       shape, operand, init_value, window, reduce_computation);
1490 }
1491 
CreateReduceWindow(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)1492 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1493     const Shape& shape, absl::Span<HloInstruction* const> operands,
1494     absl::Span<HloInstruction* const> init_values, const Window& window,
1495     HloComputation* reduce_computation) {
1496   return absl::make_unique<HloReduceWindowInstruction>(
1497       shape, operands, init_values, window, reduce_computation);
1498 }
1499 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)1500 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1501                                         HloInstruction* operand,
1502                                         HloInstruction* scale,
1503                                         HloInstruction* offset, float epsilon,
1504                                         int64_t feature_index) {
1505   return absl::make_unique<HloBatchNormTrainingInstruction>(
1506       shape, operand, scale, offset, epsilon, feature_index);
1507 }
1508 
1509 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)1510 HloInstruction::CreateBatchNormInference(
1511     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1512     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1513     float epsilon, int64_t feature_index) {
1514   return absl::make_unique<HloBatchNormInferenceInstruction>(
1515       shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1516 }
1517 
1518 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)1519 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1520                                     HloInstruction* scale, HloInstruction* mean,
1521                                     HloInstruction* variance,
1522                                     HloInstruction* grad_output, float epsilon,
1523                                     int64_t feature_index) {
1524   return absl::make_unique<HloBatchNormGradInstruction>(
1525       shape, operand, scale, mean, variance, grad_output, epsilon,
1526       feature_index);
1527 }
1528 
1529 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1530 HloInstruction::CreateSelectAndScatter(
1531     const Shape& shape, HloInstruction* operand, HloComputation* select,
1532     const Window& window, HloInstruction* source, HloInstruction* init_value,
1533     HloComputation* scatter) {
1534   return absl::make_unique<HloSelectAndScatterInstruction>(
1535       shape, operand, select, window, source, init_value, scatter);
1536 }
1537 
CreateBroadcast(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimensions)1538 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1539     const Shape& shape, HloInstruction* operand,
1540     absl::Span<const int64> broadcast_dimensions) {
1541   return absl::make_unique<HloBroadcastInstruction>(shape, operand,
1542                                                     broadcast_dimensions);
1543 }
1544 
1545 /* static */ std::unique_ptr<HloInstruction>
CreateGetDimensionSize(const Shape & shape,HloInstruction * operand,int64_t dimension)1546 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1547                                        HloInstruction* operand,
1548                                        int64_t dimension) {
1549   return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1550                                                            dimension);
1551 }
1552 
1553 /* static */ std::unique_ptr<HloInstruction>
CreateSetDimensionSize(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)1554 HloInstruction::CreateSetDimensionSize(const Shape& shape,
1555                                        HloInstruction* operand,
1556                                        HloInstruction* val, int64_t dimension) {
1557   return absl::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
1558                                                            dimension);
1559 }
1560 
1561 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1562 HloInstruction::CreateBroadcastSequence(
1563     const Shape& output_shape, HloInstruction* operand,
1564     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1565         adder) {
1566   CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1567         operand->shape().rank() == output_shape.rank());
1568   Shape broadcast_shape = ShapeUtil::ChangeElementType(
1569       output_shape, operand->shape().element_type());
1570   // Do explicit broadcast for scalar.
1571   if (ShapeUtil::IsScalar(operand->shape())) {
1572     auto broadcast =
1573         HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1574     broadcast->set_metadata(operand->metadata());
1575     if (operand->has_sharding()) {
1576       broadcast->set_sharding(operand->sharding());
1577     }
1578     broadcast->set_frontend_attributes(operand->frontend_attributes());
1579     return broadcast;
1580   }
1581   // Do explicit broadcast for degenerate broadcast.
1582   std::vector<int64> broadcast_dimensions;
1583   std::vector<int64> reshaped_dimensions;
1584   for (int i = 0; i < operand->shape().rank(); i++) {
1585     if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1586       broadcast_dimensions.push_back(i);
1587       reshaped_dimensions.push_back(operand->shape().dimensions(i));
1588     } else {
1589       CHECK_EQ(operand->shape().dimensions(i), 1)
1590           << "An explicit broadcast sequence requires the broadcasted "
1591              "dimensions to be trivial; operand: "
1592           << operand->ToString() << "; output_shape: " << output_shape;
1593     }
1594   }
1595   // Eliminate the size one dimensions.
1596   HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1597       ShapeUtil::MakeShape(operand->shape().element_type(),
1598                            reshaped_dimensions),
1599       operand));
1600   reshaped_operand->set_metadata(operand->metadata());
1601   if (operand->has_sharding()) {
1602     reshaped_operand->set_sharding(operand->sharding());
1603   }
1604   reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1605   // Broadcast 'reshape' up to the larger size.
1606   auto broadcast = HloInstruction::CreateBroadcast(
1607       broadcast_shape, reshaped_operand, broadcast_dimensions);
1608   broadcast->set_metadata(operand->metadata());
1609   if (operand->has_sharding()) {
1610     broadcast->set_sharding(operand->sharding());
1611   }
1612   broadcast->set_frontend_attributes(operand->frontend_attributes());
1613   return broadcast;
1614 }
1615 
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1616 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1617     const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1618     const PaddingConfig& padding_config) {
1619   return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
1620                                               padding_config);
1621 }
1622 
CreateReshape(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1623 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1624     const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) {
1625   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1626            ShapeUtil::ElementsIn(operand->shape()))
1627       << "shape: " << ShapeUtil::HumanString(shape)
1628       << " operand: " << ShapeUtil::HumanString(operand->shape());
1629 
1630   return absl::make_unique<HloReshapeInstruction>(shape, operand,
1631                                                   inferred_dimension);
1632 }
1633 
1634 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicReshape(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1635 HloInstruction::CreateDynamicReshape(
1636     const Shape& shape, HloInstruction* data_operand,
1637     absl::Span<HloInstruction* const> dim_sizes) {
1638   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1639            ShapeUtil::ElementsIn(data_operand[0].shape()))
1640       << "shape: " << ShapeUtil::HumanString(shape)
1641       << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
1642   CHECK_EQ(shape.rank(), dim_sizes.size());
1643   return absl::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
1644                                                          dim_sizes);
1645 }
1646 
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1647 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1648     const Shape& shape, HloInstruction* operand,
1649     absl::Span<const int64> dimensions) {
1650   return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1651 }
1652 
CreateSort(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1653 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1654     const Shape& shape, int64_t dimension,
1655     absl::Span<HloInstruction* const> operands, HloComputation* compare,
1656     bool is_stable) {
1657   return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
1658                                                compare, is_stable);
1659 }
1660 
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1661 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1662     const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1663   return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
1664                                                  fused_root);
1665 }
1666 
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1667 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1668     const Shape& shape, FusionKind fusion_kind,
1669     absl::Span<HloInstruction* const> operands,
1670     HloComputation* fusion_computation) {
1671   return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1672                                                  fusion_computation);
1673 }
1674 
set_single_sharding(const HloSharding & sharding)1675 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1676   CHECK(!sharding.IsTuple()) << sharding;
1677   if (shape().IsTuple()) {
1678     set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1679   } else {
1680     set_sharding(sharding);
1681   }
1682 }
1683 
SetupDerivedInstruction(HloInstruction * derived_instruction) const1684 void HloInstruction::SetupDerivedInstruction(
1685     HloInstruction* derived_instruction) const {
1686   if (sharding_ != nullptr &&
1687       ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
1688     // Only copy sharding if the tuple tree shape of the two instruction is
1689     // compatible because copying it between differently shaped instructions
1690     // can produce invalid shardings.
1691     derived_instruction->set_sharding(*sharding_);
1692   } else {
1693     derived_instruction->clear_sharding();
1694   }
1695   derived_instruction->set_metadata(metadata_);
1696   derived_instruction->set_frontend_attributes(frontend_attributes_);
1697 }
1698 
HasSideEffectNoRecurse() const1699 bool HloInstruction::HasSideEffectNoRecurse() const {
1700   switch (opcode_) {
1701     case HloOpcode::kSend:
1702     case HloOpcode::kSendDone:
1703     case HloOpcode::kRecv:
1704     case HloOpcode::kRecvDone:
1705     case HloOpcode::kRng:
1706     case HloOpcode::kRngGetAndUpdateState:
1707     case HloOpcode::kInfeed:
1708     case HloOpcode::kOutfeed:
1709     case HloOpcode::kTrace:
1710     case HloOpcode::kAllReduceStart:
1711     case HloOpcode::kAllReduceDone:
1712     case HloOpcode::kAllGatherStart:
1713     case HloOpcode::kAllGatherDone:
1714     case HloOpcode::kCollectivePermuteStart:
1715     case HloOpcode::kCollectivePermuteDone:
1716       return true;
1717     case HloOpcode::kAllReduce:
1718       return channel_id().has_value() ||
1719              Cast<HloAllReduceInstruction>(this)->constrain_layout();
1720     case HloOpcode::kAllToAll:
1721       return Cast<HloAllToAllInstruction>(this)->constrain_layout();
1722     case HloOpcode::kCustomCall:
1723       return Cast<HloCustomCallInstruction>(this)
1724           ->custom_call_has_side_effect();
1725     default:
1726       return false;
1727   }
1728 }
1729 
HasSideEffect() const1730 bool HloInstruction::HasSideEffect() const {
1731   if (HasSideEffectNoRecurse()) {
1732     return true;
1733   }
1734   // Check if any of the called computations has a side effect.
1735   for (const auto& computation : called_computations()) {
1736     if (computation->HasSideEffect()) {
1737       return true;
1738     }
1739   }
1740   return false;
1741 }
1742 
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1743 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1744     const Shape& shape, absl::Span<HloInstruction* const> operands,
1745     HloComputation* computation) {
1746   std::unique_ptr<HloInstruction> instruction =
1747       absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1748   for (auto operand : operands) {
1749     instruction->AppendOperand(operand);
1750   }
1751   instruction->called_computations_.push_back(computation);
1752   return instruction;
1753 }
1754 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)1755 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1756     const Shape& shape, absl::Span<HloInstruction* const> operands,
1757     absl::string_view custom_call_target, string opaque,
1758     CustomCallApiVersion api_version) {
1759   return absl::make_unique<HloCustomCallInstruction>(
1760       shape, operands, custom_call_target, std::move(opaque), api_version);
1761 }
1762 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)1763 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1764     const Shape& shape, absl::Span<HloInstruction* const> operands,
1765     HloComputation* to_apply, absl::string_view custom_call_target,
1766     string opaque, CustomCallApiVersion api_version) {
1767   return absl::make_unique<HloCustomCallInstruction>(
1768       shape, operands, to_apply, custom_call_target, std::move(opaque),
1769       api_version);
1770 }
1771 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)1772 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1773     const Shape& shape, absl::Span<HloInstruction* const> operands,
1774     absl::Span<HloComputation* const> called_computations,
1775     absl::string_view custom_call_target, string opaque,
1776     CustomCallApiVersion api_version) {
1777   return absl::make_unique<HloCustomCallInstruction>(
1778       shape, operands, called_computations, custom_call_target,
1779       std::move(opaque), api_version);
1780 }
1781 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::Span<const Shape> operand_shapes_with_layout,string opaque,CustomCallApiVersion api_version)1782 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1783     const Shape& shape, absl::Span<HloInstruction* const> operands,
1784     absl::string_view custom_call_target,
1785     absl::Span<const Shape> operand_shapes_with_layout, string opaque,
1786     CustomCallApiVersion api_version) {
1787   return absl::make_unique<HloCustomCallInstruction>(
1788       shape, operands, custom_call_target, std::move(opaque),
1789       operand_shapes_with_layout, api_version);
1790 }
1791 
CreateTuple(absl::Span<HloInstruction * const> elements)1792 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1793     absl::Span<HloInstruction* const> elements) {
1794   std::vector<Shape> element_shapes;
1795   for (auto element : elements) {
1796     element_shapes.push_back(element->shape());
1797   }
1798   Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1799   return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1800 }
1801 
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)1802 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1803     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1804     const GatherDimensionNumbers& gather_dim_numbers,
1805     absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
1806   return absl::make_unique<HloGatherInstruction>(
1807       shape, operand, start_indices, gather_dim_numbers, slice_sizes,
1808       indices_are_sorted);
1809 }
1810 
CreateScatter(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1811 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1812     const Shape& shape, HloInstruction* operand,
1813     HloInstruction* scatter_indices, HloInstruction* updates,
1814     HloComputation* update_computation,
1815     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1816     bool unique_indices) {
1817   return absl::make_unique<HloScatterInstruction>(
1818       shape, operand, scatter_indices, updates, update_computation,
1819       scatter_dim_numbers, indices_are_sorted, unique_indices);
1820 }
1821 
CreateDomain(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)1822 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1823     const Shape& shape, HloInstruction* operand,
1824     std::unique_ptr<DomainMetadata> operand_side_metadata,
1825     std::unique_ptr<DomainMetadata> user_side_metadata) {
1826   return absl::make_unique<HloDomainInstruction>(
1827       shape, operand, std::move(operand_side_metadata),
1828       std::move(user_side_metadata));
1829 }
1830 
CloneWithNewOperands(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1831 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1832     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1833     HloCloneContext* context) const {
1834   VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
1835   VLOG(3) << "  new operands:";
1836   for (const HloInstruction* new_operand : new_operands) {
1837     VLOG(3) << "    %" << new_operand->name();
1838   }
1839 
1840   std::unique_ptr<HloInstruction> clone;
1841   // Explicitly call the factory for the instruction type. This is more robust
1842   // in the face of code changes than copying fields explicitly. This also
1843   // properly sets the user fields of the operands.
1844   switch (opcode_) {
1845     // Ops migrated to subclasses.
1846     // TODO(b/80131774): Remove this switch when migration is complete.
1847     case HloOpcode::kBatchNormTraining:
1848     case HloOpcode::kBatchNormInference:
1849     case HloOpcode::kBatchNormGrad:
1850     case HloOpcode::kFft:
1851     case HloOpcode::kCompare:
1852     case HloOpcode::kCopyStart:
1853     case HloOpcode::kSend:
1854     case HloOpcode::kSendDone:
1855     case HloOpcode::kRecv:
1856     case HloOpcode::kRecvDone:
1857     case HloOpcode::kReverse:
1858     case HloOpcode::kConcatenate:
1859     case HloOpcode::kReduce:
1860     case HloOpcode::kTranspose:
1861     case HloOpcode::kBroadcast:
1862     case HloOpcode::kReshape:
1863     case HloOpcode::kDynamicReshape:
1864     case HloOpcode::kMap:
1865     case HloOpcode::kSlice:
1866     case HloOpcode::kConstant:
1867     case HloOpcode::kTrace:
1868     case HloOpcode::kFusion:
1869     case HloOpcode::kRng:
1870     case HloOpcode::kRngBitGenerator:
1871     case HloOpcode::kRngGetAndUpdateState:
1872     case HloOpcode::kParameter:
1873     case HloOpcode::kGetTupleElement:
1874     case HloOpcode::kReducePrecision:
1875     case HloOpcode::kAllGather:
1876     case HloOpcode::kAllGatherStart:
1877     case HloOpcode::kAllReduce:
1878     case HloOpcode::kReduceScatter:
1879     case HloOpcode::kAllReduceStart:
1880     case HloOpcode::kAllToAll:
1881     case HloOpcode::kCollectivePermute:
1882     case HloOpcode::kCollectivePermuteStart:
1883     case HloOpcode::kInfeed:
1884     case HloOpcode::kOutfeed:
1885     case HloOpcode::kConvolution:
1886     case HloOpcode::kCustomCall:
1887     case HloOpcode::kReduceWindow:
1888     case HloOpcode::kSelectAndScatter:
1889     case HloOpcode::kPad:
1890     case HloOpcode::kDynamicSlice:
1891     case HloOpcode::kSort:
1892     case HloOpcode::kGather:
1893     case HloOpcode::kScatter:
1894     case HloOpcode::kIota:
1895     case HloOpcode::kDot:
1896     case HloOpcode::kDomain:
1897     case HloOpcode::kGetDimensionSize:
1898     case HloOpcode::kSetDimensionSize:
1899     case HloOpcode::kTriangularSolve:
1900     case HloOpcode::kCholesky:
1901       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
1902       break;
1903     // Unary ops.
1904     case HloOpcode::kAbs:
1905     case HloOpcode::kAllGatherDone:
1906     case HloOpcode::kAllReduceDone:
1907     case HloOpcode::kRoundNearestAfz:
1908     case HloOpcode::kBitcast:
1909     case HloOpcode::kCeil:
1910     case HloOpcode::kClz:
1911     case HloOpcode::kCollectivePermuteDone:
1912     case HloOpcode::kCopy:
1913     case HloOpcode::kCopyDone:
1914     case HloOpcode::kCos:
1915     case HloOpcode::kExp:
1916     case HloOpcode::kExpm1:
1917     case HloOpcode::kImag:
1918     case HloOpcode::kIsFinite:
1919     case HloOpcode::kFloor:
1920     case HloOpcode::kLog:
1921     case HloOpcode::kLog1p:
1922     case HloOpcode::kNot:
1923     case HloOpcode::kNegate:
1924     case HloOpcode::kPopulationCount:
1925     case HloOpcode::kReal:
1926     case HloOpcode::kRsqrt:
1927     case HloOpcode::kLogistic:
1928     case HloOpcode::kSign:
1929     case HloOpcode::kSin:
1930     case HloOpcode::kSqrt:
1931     case HloOpcode::kCbrt:
1932     case HloOpcode::kTanh:
1933       CHECK_EQ(new_operands.size(), 1);
1934       clone = CreateUnary(shape, opcode_, new_operands[0]);
1935       break;
1936     // Binary ops.
1937     case HloOpcode::kAdd:
1938     case HloOpcode::kAtan2:
1939     case HloOpcode::kComplex:
1940     case HloOpcode::kDivide:
1941     case HloOpcode::kMultiply:
1942     case HloOpcode::kSubtract:
1943     case HloOpcode::kMaximum:
1944     case HloOpcode::kMinimum:
1945     case HloOpcode::kPower:
1946     case HloOpcode::kRemainder:
1947     case HloOpcode::kAnd:
1948     case HloOpcode::kOr:
1949     case HloOpcode::kXor:
1950     case HloOpcode::kShiftLeft:
1951     case HloOpcode::kShiftRightArithmetic:
1952     case HloOpcode::kShiftRightLogical:
1953       CHECK_EQ(new_operands.size(), 2);
1954       clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
1955       break;
1956     // Ternary ops.
1957     case HloOpcode::kClamp:
1958     case HloOpcode::kSelect:
1959     case HloOpcode::kTupleSelect:
1960       CHECK_EQ(new_operands.size(), 3);
1961       clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
1962                             new_operands[2]);
1963       break;
1964     // Other supported ops.
1965     case HloOpcode::kCall:
1966       clone = CreateCall(shape, new_operands, to_apply());
1967       break;
1968     case HloOpcode::kConvert:
1969       CHECK_EQ(new_operands.size(), 1);
1970       clone = CreateConvert(shape, new_operands[0]);
1971       break;
1972     case HloOpcode::kBitcastConvert:
1973       CHECK_EQ(new_operands.size(), 1);
1974       clone = CreateBitcastConvert(shape, new_operands[0]);
1975       break;
1976     case HloOpcode::kDynamicUpdateSlice:
1977       clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
1978                                        new_operands.subspan(2));
1979       break;
1980     case HloOpcode::kTuple:
1981       clone = CreateTuple(new_operands);
1982       *clone->mutable_shape() = shape;
1983       break;
1984     case HloOpcode::kWhile:
1985       CHECK_EQ(new_operands.size(), 1);
1986       clone =
1987           CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
1988       break;
1989     case HloOpcode::kConditional:
1990       CHECK_EQ(new_operands.size(), branch_count() + 1);
1991       clone = CreateConditional(shape, new_operands[0],
1992                                 absl::MakeSpan(branch_computations()),
1993                                 new_operands.subspan(1));
1994       break;
1995     case HloOpcode::kAfterAll:
1996       if (new_operands.empty()) {
1997         clone = CreateToken();
1998       } else {
1999         clone = CreateAfterAll(new_operands);
2000       }
2001       break;
2002     case HloOpcode::kAddDependency:
2003       CHECK_EQ(new_operands.size(), 2);
2004       clone = CreateAddDependency(new_operands[0], new_operands[1]);
2005       break;
2006     case HloOpcode::kReplicaId:
2007       CHECK_EQ(new_operands.size(), 0);
2008       clone = CreateReplicaId(shape);
2009       break;
2010     case HloOpcode::kPartitionId:
2011       CHECK_EQ(new_operands.size(), 0);
2012       clone = CreatePartitionId(shape);
2013       break;
2014   }
2015   // SetupDerivedInstruction will setup the precision_config_ field.
2016   SetupDerivedInstruction(clone.get());
2017   clone->set_parent(parent_);
2018   clone->set_outer_dimension_partitions(outer_dimension_partitions_);
2019   clone->set_raw_backend_config_string(backend_config_);
2020   if (context != nullptr) {
2021     context->MapInstruction(this, clone.get());
2022     clone->ReplaceCalledComputations([&](HloComputation* callee) {
2023       return callee->parent() != context->module()
2024                  ? context->module()->DeepCloneComputation(callee, context)
2025                  : callee;
2026     });
2027   }
2028   return clone;
2029 }
2030 
DetachFromOperandsAndUsers()2031 void HloInstruction::DetachFromOperandsAndUsers() {
2032   if (cleaned_up_) {
2033     return;
2034   }
2035   cleaned_up_ = true;
2036   // Detach from operands. An instruction may be repeated as an operand. To
2037   // avoid calling RemoveUser twice on the same operand, check before remove.
2038   for (int64_t operand_num = 0; operand_num < operand_count(); ++operand_num) {
2039     HloInstruction* operand = operands_[operand_num];
2040     if (operand == nullptr) {
2041       continue;
2042     }
2043     if (operand->user_map_.find(this) != operand->user_map_.end()) {
2044       operand->RemoveUser(this);
2045     }
2046     operands_[operand_num] = nullptr;
2047   }
2048 
2049   // Update users. Set `nullptr` to the corresponding operand slot for users.
2050   for (auto& user : this->users()) {
2051     for (int i = 0; i < user->operand_count(); ++i) {
2052       if (user->operands_[i] == this) {
2053         user->operands_[i] = nullptr;
2054       }
2055     }
2056   }
2057 }
2058 
CloneWithNewShape(const Shape & shape,const string & suffix,HloCloneContext * context) const2059 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
2060     const Shape& shape, const string& suffix, HloCloneContext* context) const {
2061   std::unique_ptr<HloInstruction> clone =
2062       CloneWithNewOperands(shape, operands_, context);
2063   if (suffix.empty()) {
2064     clone->name_ = name();
2065   } else {
2066     // If an instruction is cloned multiple times avoid names like
2067     // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
2068     // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
2069     // clone of foo.suffix2 is named foo.suffix3 and so on.
2070     const string dot_suffix = "." + suffix;
2071     size_t index = name().rfind(dot_suffix);
2072     if (index == string::npos) {
2073       // Existing name does not include ".suffix".
2074       clone->name_ = name() + dot_suffix;
2075     } else {
2076       // Existing name includes ".suffix". Determine if substring after
2077       // ".suffix" is numeric and should be replaced with an incremented number.
2078       string after_suffix = name().substr(index + dot_suffix.size());
2079       if (after_suffix.empty()) {
2080         // Existing name ends in ".suffix". New name should end in ".suffix2".
2081         clone->name_ = name() + "2";
2082       } else {
2083         // If names ends with .suffix[0-9]+ then replace with a suffix with the
2084         // numeric value incremented.
2085         int64_t numeric_suffix;
2086         if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
2087           clone->name_ =
2088               StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
2089         } else {
2090           // Substring after ".suffix" is non-numeric.
2091           clone->name_ = name() + dot_suffix;
2092         }
2093       }
2094     }
2095   }
2096   return clone;
2097 }
2098 
Clone(const string & suffix,HloCloneContext * context) const2099 std::unique_ptr<HloInstruction> HloInstruction::Clone(
2100     const string& suffix, HloCloneContext* context) const {
2101   std::unique_ptr<HloInstruction> clone =
2102       CloneWithNewShape(shape_, suffix, context);
2103   return clone;
2104 }
2105 
2106 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const2107 HloInstruction::LatestNonGteAncestorAndIndex() const {
2108   const HloInstruction* hlo = this;
2109   ShapeIndex index;
2110   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2111     index.push_back(hlo->tuple_index());
2112     hlo = hlo->operand(0);
2113   }
2114 
2115   // We built up index in the reverse order from what we want.
2116   std::reverse(index.begin(), index.end());
2117 
2118   return {hlo, index};
2119 }
2120 
LatestNonGteAncestor() const2121 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
2122   const HloInstruction* hlo = this;
2123   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2124     hlo = hlo->operand(0);
2125   }
2126   return hlo;
2127 }
2128 
operand(int64_t i) const2129 const HloInstruction* HloInstruction::operand(int64_t i) const {
2130   return operands_.at(i);
2131 }
2132 
mutable_operand(int64_t i)2133 HloInstruction* HloInstruction::mutable_operand(int64_t i) {
2134   CHECK(operands_[i] != nullptr);
2135   return operands_.at(i);
2136 }
2137 
operand_index(const HloInstruction * target) const2138 int64 HloInstruction::operand_index(const HloInstruction* target) const {
2139   for (int64_t i = 0; i < operand_count(); ++i) {
2140     if (target == operand(i)) {
2141       return i;
2142     }
2143   }
2144   LOG(FATAL) << "target was not an operand: " << target->ToString();
2145 }
2146 
unique_operands() const2147 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
2148   InstructionVector unique;
2149   absl::flat_hash_set<const HloInstruction*> seen;
2150   for (HloInstruction* operand : operands()) {
2151     if (seen.insert(operand).second) {
2152       unique.push_back(operand);
2153     }
2154   }
2155   return unique;
2156 }
2157 
AddControlDependencyTo(HloInstruction * instruction)2158 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
2159   TF_RET_CHECK(instruction->parent() == parent());
2160   if (!absl::c_linear_search(control_successors_, instruction)) {
2161     control_successors_.push_back(instruction);
2162     TF_RET_CHECK(
2163         !absl::c_linear_search(instruction->control_predecessors_, this));
2164     instruction->control_predecessors_.push_back(this);
2165   }
2166   return Status::OK();
2167 }
2168 
RemoveControlDependencyTo(HloInstruction * instruction)2169 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
2170   TF_RET_CHECK(instruction->parent() == parent());
2171   TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
2172   TF_RETURN_IF_ERROR(
2173       EraseElementFromVector(&instruction->control_predecessors_, this));
2174   return Status::OK();
2175 }
2176 
DropAllControlDeps()2177 Status HloInstruction::DropAllControlDeps() {
2178   for (auto* ctrl_succ : control_successors_) {
2179     TF_RETURN_IF_ERROR(
2180         EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
2181   }
2182   for (auto* ctrl_pred : control_predecessors_) {
2183     TF_RETURN_IF_ERROR(
2184         EraseElementFromVector(&ctrl_pred->control_successors_, this));
2185   }
2186   control_successors_.clear();
2187   control_predecessors_.clear();
2188   return Status::OK();
2189 }
2190 
CopyAllControlDepsFrom(const HloInstruction * inst)2191 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
2192   for (auto* ctrl_pred : inst->control_predecessors()) {
2193     TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
2194   }
2195 
2196   for (auto* ctrl_succ : inst->control_successors()) {
2197     TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
2198   }
2199 
2200   return Status::OK();
2201 }
2202 
IdenticalInternal(const HloInstruction & other,const std::function<bool (const HloInstruction *,const HloInstruction *)> & eq_operands,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations,bool layout_sensitive,bool ignore_channel_id_values) const2203 bool HloInstruction::IdenticalInternal(
2204     const HloInstruction& other,
2205     const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2206         eq_operands,
2207     const std::function<bool(const HloComputation*, const HloComputation*)>&
2208         eq_computations,
2209     bool layout_sensitive, bool ignore_channel_id_values) const {
2210   // An instruction is always identical to itself.
2211   if (this == &other) {
2212     return true;
2213   }
2214 
2215   // Identical instruction must have the same opcode, shape, and identical
2216   // operands.
2217   if (opcode() != other.opcode()) {
2218     return false;
2219   }
2220   if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
2221                          : ShapeUtil::Compatible(shape(), other.shape()))) {
2222     return false;
2223   }
2224   if (operands().size() != other.operands().size()) {
2225     return false;
2226   }
2227 
2228   // Use an explicit loop rather than ContainerEquals, because copying around
2229   // std::functions may be too expensive in some cases.
2230   for (size_t i = 0; i < operands().size(); ++i) {
2231     if (!eq_operands(operand(i), other.operand(i))) {
2232       return false;
2233     }
2234   }
2235 
2236   if (backend_config_ != other.backend_config_) {
2237     return false;
2238   }
2239 
2240   if (ignore_channel_id_values) {
2241     if (auto channel_inst = DynCast<HloChannelInstruction>(this)) {
2242       return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
2243           other, eq_computations);
2244     }
2245   }
2246   return IdenticalSlowPath(other, eq_computations);
2247 }
2248 
AppendOperand(HloInstruction * operand)2249 void HloInstruction::AppendOperand(HloInstruction* operand) {
2250   if (operand->parent() != nullptr) {
2251     DCHECK(!operand->parent()->IsMarkedAsDead(operand))
2252         << "Operand " << operand->name() << " is already marked dead";
2253   }
2254   operands_.push_back(operand);
2255   operand->AddUser(this);
2256 }
2257 
RemoveOperandsAtAscendingIndices(absl::Span<const int> ascending_indices)2258 void HloInstruction::RemoveOperandsAtAscendingIndices(
2259     absl::Span<const int> ascending_indices) {
2260   if (ascending_indices.empty()) {
2261     return;
2262   }
2263   int next_index = 0;
2264   int removed_count = 0;
2265   for (int to_remove : ascending_indices) {
2266     while (next_index < to_remove) {
2267       operands_[next_index - removed_count] = operands_[next_index];
2268       ++next_index;
2269     }
2270     CHECK_LT(to_remove, operands_.size());
2271     ++removed_count;
2272     ++next_index;
2273   }
2274   while (next_index < operands_.size()) {
2275     operands_[next_index - removed_count] = operands_[next_index];
2276     ++next_index;
2277   }
2278   CHECK_EQ(removed_count, ascending_indices.size());
2279   operands_.resize(operands_.size() - removed_count);
2280 }
2281 
AddUser(HloInstruction * user)2282 void HloInstruction::AddUser(HloInstruction* user) {
2283   if (!ContainsKey(user_map_, user)) {
2284     user_map_.emplace(user, users_.size());
2285     users_.push_back(user);
2286   }
2287 }
2288 
UserId(HloInstruction * user)2289 int64 HloInstruction::UserId(HloInstruction* user) {
2290   auto result = user_map_.find(user);
2291   CHECK(result != user_map_.end());
2292   return result->second;
2293 }
2294 
HasConstantOperand() const2295 bool HloInstruction::HasConstantOperand() const {
2296   for (const HloInstruction* operand : operands_) {
2297     if (operand->IsConstant()) {
2298       return true;
2299     }
2300   }
2301   return false;
2302 }
2303 
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations) const2304 bool HloInstruction::IdenticalSlowPath(
2305     const HloInstruction& other,
2306     const std::function<bool(const HloComputation*, const HloComputation*)>&
2307         eq_computations) const {
2308   // Perform opcode specific checks.
2309   switch (opcode()) {
2310     // The result of these instructions only depend upon their opcode and
2311     // operands.
2312     case HloOpcode::kAbs:
2313     case HloOpcode::kAllGatherDone:
2314     case HloOpcode::kAllReduceDone:
2315     case HloOpcode::kAtan2:
2316     case HloOpcode::kAdd:
2317     case HloOpcode::kBitcast:
2318     case HloOpcode::kBitcastConvert:
2319     case HloOpcode::kCeil:
2320     case HloOpcode::kClamp:
2321     case HloOpcode::kClz:
2322     case HloOpcode::kCollectivePermuteDone:
2323     case HloOpcode::kComplex:
2324     case HloOpcode::kConvert:
2325     case HloOpcode::kCopy:
2326     case HloOpcode::kCopyStart:
2327     case HloOpcode::kCopyDone:
2328     case HloOpcode::kCos:
2329     case HloOpcode::kDivide:
2330     case HloOpcode::kDynamicUpdateSlice:
2331     case HloOpcode::kExp:
2332     case HloOpcode::kExpm1:
2333     case HloOpcode::kFloor:
2334     case HloOpcode::kImag:
2335     case HloOpcode::kIsFinite:
2336     case HloOpcode::kLog:
2337     case HloOpcode::kLog1p:
2338     case HloOpcode::kAnd:
2339     case HloOpcode::kNot:
2340     case HloOpcode::kOr:
2341     case HloOpcode::kXor:
2342     case HloOpcode::kMaximum:
2343     case HloOpcode::kMinimum:
2344     case HloOpcode::kMultiply:
2345     case HloOpcode::kNegate:
2346     case HloOpcode::kPartitionId:
2347     case HloOpcode::kPopulationCount:
2348     case HloOpcode::kPower:
2349     case HloOpcode::kReal:
2350     case HloOpcode::kRemainder:
2351     case HloOpcode::kReshape:
2352     case HloOpcode::kDynamicReshape:
2353     case HloOpcode::kReplicaId:
2354     case HloOpcode::kRoundNearestAfz:
2355     case HloOpcode::kRsqrt:
2356     case HloOpcode::kSelect:
2357     case HloOpcode::kShiftLeft:
2358     case HloOpcode::kShiftRightArithmetic:
2359     case HloOpcode::kShiftRightLogical:
2360     case HloOpcode::kLogistic:
2361     case HloOpcode::kSign:
2362     case HloOpcode::kSin:
2363     case HloOpcode::kSqrt:
2364     case HloOpcode::kCbrt:
2365     case HloOpcode::kSubtract:
2366     case HloOpcode::kTanh:
2367     case HloOpcode::kTuple:
2368     case HloOpcode::kTupleSelect:
2369       return true;
2370 
2371     // This opcode has complex or special behavior so just return false.
2372     case HloOpcode::kAfterAll:
2373     case HloOpcode::kAddDependency:
2374       return false;
2375 
2376     // Remaining instructions with special values.
2377     case HloOpcode::kCall:
2378       return eq_computations(to_apply(), other.to_apply());
2379     case HloOpcode::kConditional:
2380       for (int j = 0; j < branch_count(); ++j) {
2381         if (!eq_computations(branch_computation(j),
2382                              other.branch_computation(j))) {
2383           return false;
2384         }
2385       }
2386       return true;
2387     case HloOpcode::kWhile:
2388       return (eq_computations(while_body(), other.while_body()) &&
2389               eq_computations(while_condition(), other.while_condition()));
2390 
2391     // Ops migrated to subclasses should never come to this line.
2392     // TODO(b/80131774): Remove this switch when migration is complete.
2393     case HloOpcode::kBatchNormTraining:
2394     case HloOpcode::kBatchNormInference:
2395     case HloOpcode::kBatchNormGrad:
2396     case HloOpcode::kFft:
2397     case HloOpcode::kCompare:
2398     case HloOpcode::kSend:
2399     case HloOpcode::kSendDone:
2400     case HloOpcode::kRecv:
2401     case HloOpcode::kRecvDone:
2402     case HloOpcode::kReverse:
2403     case HloOpcode::kConcatenate:
2404     case HloOpcode::kReduce:
2405     case HloOpcode::kSort:
2406     case HloOpcode::kTranspose:
2407     case HloOpcode::kBroadcast:
2408     case HloOpcode::kMap:
2409     case HloOpcode::kSlice:
2410     case HloOpcode::kConstant:
2411     case HloOpcode::kIota:
2412     case HloOpcode::kTrace:
2413     case HloOpcode::kFusion:
2414     case HloOpcode::kRng:
2415     case HloOpcode::kRngBitGenerator:
2416     case HloOpcode::kRngGetAndUpdateState:
2417     case HloOpcode::kParameter:
2418     case HloOpcode::kGetTupleElement:
2419     case HloOpcode::kReducePrecision:
2420     case HloOpcode::kInfeed:
2421     case HloOpcode::kOutfeed:
2422     case HloOpcode::kAllGather:
2423     case HloOpcode::kAllGatherStart:
2424     case HloOpcode::kAllReduce:
2425     case HloOpcode::kReduceScatter:
2426     case HloOpcode::kAllReduceStart:
2427     case HloOpcode::kAllToAll:
2428     case HloOpcode::kCollectivePermute:
2429     case HloOpcode::kCollectivePermuteStart:
2430     case HloOpcode::kConvolution:
2431     case HloOpcode::kCustomCall:
2432     case HloOpcode::kReduceWindow:
2433     case HloOpcode::kSelectAndScatter:
2434     case HloOpcode::kPad:
2435     case HloOpcode::kDynamicSlice:
2436     case HloOpcode::kGather:
2437     case HloOpcode::kScatter:
2438     case HloOpcode::kDot:
2439     case HloOpcode::kDomain:
2440     case HloOpcode::kGetDimensionSize:
2441     case HloOpcode::kSetDimensionSize:
2442     case HloOpcode::kTriangularSolve:
2443     case HloOpcode::kCholesky:
2444       LOG(FATAL) << "Base class impl called for opcode with subclass: "
2445                  << opcode();
2446   }
2447   return false;
2448 }
2449 
HashOperand(const HloInstruction * hlo)2450 static uint64 HashOperand(const HloInstruction* hlo) {
2451   return ShapeUtil::Hash(hlo->shape());
2452 }
2453 
Hash(const std::function<uint64 (const HloInstruction *)> & hash_operand) const2454 uint64 HloInstruction::Hash(
2455     const std::function<uint64(const HloInstruction*)>& hash_operand) const {
2456   using tensorflow::Hash64Combine;
2457 
2458   uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
2459   hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));
2460 
2461   if (!IsCrossModuleAllReduce()) {
2462     if (!operands().empty()) {
2463       for (size_t i = 0; i < operands().size(); ++i) {
2464         hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
2465       }
2466     }
2467   }
2468 
2469   hash_value = Hash64Combine(hash_value, InnerHash());
2470   return hash_value;
2471 }
2472 
Hash() const2473 uint64 HloInstruction::Hash() const {
2474   // Use HashOperand as an argument to prevent non-termination.
2475   return Hash(HashOperand);
2476 }
2477 
InnerHash() const2478 uint64 HloInstruction::InnerHash() const { return 13; }
2479 
RemoveUser(HloInstruction * user)2480 void HloInstruction::RemoveUser(HloInstruction* user) {
2481   auto map_it = user_map_.find(user);
2482   CHECK(map_it != user_map_.end());
2483 
2484   const int64_t index = map_it->second;
2485   CHECK_EQ(users_[index], user);
2486 
2487   // Move the last user into the position of the removed user.
2488   users_[index] = users_.back();
2489   user_map_[users_.back()] = index;
2490 
2491   // Remove the user from the map and drop the last slot from the vector what
2492   // have been moved to the position of the original user.
2493   user_map_.erase(map_it);
2494   users_.pop_back();
2495 }
2496 
ReplaceUseWith(HloInstruction * user,HloInstruction * new_producer)2497 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
2498                                       HloInstruction* new_producer) {
2499   TF_RET_CHECK(
2500       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2501       << "this shape: " << ShapeUtil::HumanString(shape())
2502       << ", replacement shape: "
2503       << ShapeUtil::HumanString(new_producer->shape());
2504   return ReplaceUseWithDifferentShape(user, new_producer);
2505 }
2506 
ReplaceUseWithDifferentShape(HloInstruction * user,HloInstruction * new_producer)2507 Status HloInstruction::ReplaceUseWithDifferentShape(
2508     HloInstruction* user, HloInstruction* new_producer) {
2509   VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
2510           << " with " << new_producer->name();
2511 
2512   RemoveUser(user);
2513 
2514   TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
2515   std::replace(user->operands_.begin(), user->operands_.end(), this,
2516                new_producer);
2517   new_producer->AddUser(user);
2518   // Custom fusions may not be able to handle deduplicated operands.
2519   if (user->opcode() == HloOpcode::kFusion) {
2520     TF_RETURN_IF_ERROR(
2521         Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2522   }
2523   return Status::OK();
2524 }
2525 
ReplaceOperandWith(int64_t operand_num,HloInstruction * new_operand)2526 Status HloInstruction::ReplaceOperandWith(int64_t operand_num,
2527                                           HloInstruction* new_operand) {
2528   auto old_operand = operand(operand_num);
2529   TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
2530                                                         new_operand->shape()))
2531       << old_operand->shape() << " is not compatible with "
2532       << new_operand->shape();
2533   return ReplaceOperandWithDifferentShape(operand_num, new_operand);
2534 }
2535 
ReplaceOperandWithDifferentShape(int64_t operand_num,HloInstruction * new_operand)2536 Status HloInstruction::ReplaceOperandWithDifferentShape(
2537     int64_t operand_num, HloInstruction* new_operand) {
2538   TF_RET_CHECK(operand_num >= 0);
2539   TF_RET_CHECK(operand_num < operand_count());
2540   HloInstruction* old_operand = mutable_operand(operand_num);
2541   if (old_operand == new_operand) {
2542     return Status::OK();
2543   }
2544 
2545   operands_[operand_num] = new_operand;
2546 
2547   VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
2548           << new_operand->name() << ", was " << old_operand->name();
2549 
2550   if (!absl::c_linear_search(operands_, old_operand)) {
2551     old_operand->RemoveUser(this);
2552   }
2553   new_operand->AddUser(this);
2554   return Status::OK();
2555 }
2556 
ReplaceUsesWith(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2557 Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
2558                                        HloInstruction* new_producer) {
2559   TF_RET_CHECK(
2560       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2561       << shape() << " is not compatible with " << new_producer->shape();
2562   return ReplaceAllUsesWithDifferentShape(users, new_producer);
2563 }
2564 
ReplaceAllUsesWithDifferentShape(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2565 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2566     absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
2567   for (HloInstruction* user : users) {
2568     TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
2569   }
2570 
2571   if (parent_ && parent_->root_instruction() == this) {
2572     parent_->set_root_instruction(new_producer,
2573                                   /*accept_different_shape=*/true);
2574   }
2575   return Status::OK();
2576 }
2577 
ReplaceAllUsesWith(HloInstruction * new_producer)2578 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
2579   TF_RET_CHECK(
2580       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2581       << shape() << " is not compatible with " << new_producer->shape();
2582   return ReplaceAllUsesWithDifferentShape(new_producer);
2583 }
2584 
ReplaceAllUsesWithDifferentShape(HloInstruction * new_producer)2585 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2586     HloInstruction* new_producer) {
2587   bool new_producer_is_user = false;
2588   for (HloInstruction* user : users()) {
2589     if (user == new_producer) {
2590       // It's possible that new_producer is a user of this instruction as might
2591       // be the case when replacing an instruction with a kCopy of itself. In
2592       // this case, don't do the replacement to avoid creating a cycle in the
2593       // graph. new_producer remains the only user of this instruction.
2594       new_producer_is_user = true;
2595     } else {
2596       std::replace(user->operands_.begin(), user->operands_.end(), this,
2597                    new_producer);
2598       new_producer->AddUser(user);
2599       if (user->opcode() == HloOpcode::kFusion) {
2600         TF_RETURN_IF_ERROR(
2601             Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2602       }
2603     }
2604   }
2605   users_.clear();
2606   user_map_.clear();
2607   if (new_producer_is_user) {
2608     AddUser(new_producer);
2609   }
2610   if (parent_ && parent_->root_instruction() == this) {
2611     parent_->set_root_instruction(new_producer,
2612                                   /*accept_different_shape=*/true);
2613   }
2614 
2615   return Status::OK();
2616 }
2617 
IsEffectiveBitcast() const2618 bool HloInstruction::IsEffectiveBitcast() const {
2619   return opcode_ == HloOpcode::kBitcast ||
2620          (opcode_ == HloOpcode::kTranspose &&
2621           ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
2622                                         dimensions()));
2623 }
2624 
to_apply() const2625 HloComputation* HloInstruction::to_apply() const {
2626   switch (opcode_) {
2627     case HloOpcode::kCall:
2628     case HloOpcode::kMap:
2629     case HloOpcode::kReduceWindow:
2630     case HloOpcode::kReduce:
2631     case HloOpcode::kAllReduce:
2632     case HloOpcode::kReduceScatter:
2633     case HloOpcode::kAllReduceStart:
2634     case HloOpcode::kScatter:
2635     case HloOpcode::kSort:
2636     case HloOpcode::kCustomCall:
2637       CHECK_EQ(called_computations_.size(), 1);
2638       return called_computations_[0];
2639     default:
2640       LOG(FATAL) << "Invalid opcode for to_apply(): "
2641                  << HloOpcodeString(opcode());
2642   }
2643 }
2644 
set_to_apply(HloComputation * computation)2645 void HloInstruction::set_to_apply(HloComputation* computation) {
2646   // Don't allow changing the computation for fused instructions so we don't
2647   // have to recompute called_instructions for the entire fusion instruction.
2648   CHECK(!IsFused());
2649   switch (opcode_) {
2650     case HloOpcode::kCall:
2651     case HloOpcode::kMap:
2652     case HloOpcode::kReduceWindow:
2653     case HloOpcode::kReduce:
2654     case HloOpcode::kAllReduce:
2655     case HloOpcode::kAllReduceStart:
2656     case HloOpcode::kScatter:
2657     case HloOpcode::kSort:
2658     case HloOpcode::kCustomCall:
2659       CHECK_EQ(called_computations_.size(), 1);
2660       called_computations_[0] = computation;
2661       break;
2662     default:
2663       LOG(FATAL) << "Invalid opcode for to_apply(): "
2664                  << HloOpcodeString(opcode());
2665   }
2666 }
2667 
while_condition() const2668 HloComputation* HloInstruction::while_condition() const {
2669   CHECK_EQ(HloOpcode::kWhile, opcode_);
2670   return called_computations_[kConditionComputationIndex];
2671 }
2672 
while_body() const2673 HloComputation* HloInstruction::while_body() const {
2674   CHECK_EQ(HloOpcode::kWhile, opcode_);
2675   return called_computations_[kBodyComputationIndex];
2676 }
2677 
set_while_condition(HloComputation * computation)2678 void HloInstruction::set_while_condition(HloComputation* computation) {
2679   // Don't allow changing the computation for fused instructions so we don't
2680   // have to recompute called_instructions for the entire fusion instruction.
2681   CHECK(!IsFused());
2682   CHECK_EQ(HloOpcode::kWhile, opcode_);
2683   called_computations_[kConditionComputationIndex] = computation;
2684 }
2685 
set_while_body(HloComputation * computation)2686 void HloInstruction::set_while_body(HloComputation* computation) {
2687   // Don't allow changing the computation for fused instructions so we don't
2688   // have to recompute called_instructions for the entire fusion instruction.
2689   CHECK(!IsFused());
2690   CHECK_EQ(HloOpcode::kWhile, opcode_);
2691   called_computations_[kBodyComputationIndex] = computation;
2692 }
2693 
while_init() const2694 HloInstruction* HloInstruction::while_init() const {
2695   CHECK_EQ(HloOpcode::kWhile, opcode_);
2696   return operands_[0];
2697 }
2698 
true_computation() const2699 HloComputation* HloInstruction::true_computation() const {
2700   CHECK_EQ(HloOpcode::kConditional, opcode_);
2701   CHECK_EQ(PRED, operand(0)->shape().element_type());
2702   return called_computations_[kTrueComputationIndex];
2703 }
2704 
false_computation() const2705 HloComputation* HloInstruction::false_computation() const {
2706   CHECK_EQ(HloOpcode::kConditional, opcode_);
2707   CHECK_EQ(PRED, operand(0)->shape().element_type());
2708   return called_computations_[kFalseComputationIndex];
2709 }
2710 
branch_computations() const2711 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2712     const {
2713   CHECK(HloOpcode::kConditional == opcode_);
2714   return called_computations_;
2715 }
2716 
branch_count() const2717 int HloInstruction::branch_count() const {
2718   CHECK(HloOpcode::kConditional == opcode_);
2719   return called_computations_.size();
2720 }
2721 
branch_computation(int b) const2722 HloComputation* HloInstruction::branch_computation(int b) const {
2723   CHECK(HloOpcode::kConditional == opcode_);
2724   CHECK_GE(b, 0);
2725   CHECK_LT(b, called_computations_.size());
2726   return called_computations_[b];
2727 }
2728 
set_branch_computation(int b,HloComputation * computation)2729 void HloInstruction::set_branch_computation(int b,
2730                                             HloComputation* computation) {
2731   // Don't allow changing the computation for fused instructions so we don't
2732   // have to recompute called_instructions for the entire fusion instruction.
2733   CHECK(!IsFused());
2734   CHECK_EQ(HloOpcode::kConditional, opcode_);
2735   called_computations_[b] = computation;
2736 }
2737 
SignatureString() const2738 string HloInstruction::SignatureString() const {
2739   string operands =
2740       StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
2741         StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2742       });
2743   return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2744 }
2745 
// Returns `name` unchanged when `print_ids` is true. Otherwise returns the
// part of `name` before its first '.', or all of `name` if it has no '.'.
string PrintName(const string& name, bool print_ids) {
  if (print_ids) {
    return name;
  }
  // find() yields npos when there is no '.', and substr(0, npos) is the whole
  // string.
  return name.substr(0, name.find('.'));
}
2754 
2755 namespace {
2756 
2757 using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2758 
// Applies PrintName() using options.print_ids() and prefixes the result with
// "%" when options.print_percent() is set.
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
  const char* const prefix = options.print_percent() ? "%" : "";
  return StrCat(prefix, PrintName(name, options.print_ids()));
}
2763 
PrintCycle(const HloInstruction * child,DFSStack * dfs_stack)2764 void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
2765   // This set contains HloInstructions from the top of `DFSStack` that might
2766   // belong to the cycle, i.e. if  DFSStack :=[back,...,child,...,top], then
2767   // `subgraph` := {child,...,top}.
2768   absl::flat_hash_set<const HloInstruction*> subgraph;
2769   while (!dfs_stack->empty() && dfs_stack->back().second != child) {
2770     subgraph.insert(dfs_stack->back().second);
2771     dfs_stack->pop_back();
2772   }
2773   // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
2774   absl::flat_hash_set<const HloInstruction*> visited;
2775   absl::InlinedVector<const HloInstruction*, 16> dfs;
2776   dfs.push_back(child);
2777   while (!dfs.empty()) {
2778     bool found_next_instr = false;
2779     for (const auto& user : dfs.back()->users()) {
2780       if (user == child) {
2781         dfs.push_back(child);
2782         LOG(INFO) << "\n\nDirected cycle:\n  "
2783                   << absl::StrJoin(
2784                          dfs, "\n  ",
2785                          [](std::string* out, const HloInstruction* instr) {
2786                            out->append(instr->name());
2787                          });
2788         return;
2789       }
2790       if (!subgraph.contains(user) || visited.contains(user)) {
2791         continue;
2792       }
2793       visited.insert(user);
2794       dfs.push_back(user);
2795       found_next_instr = true;
2796     }
2797     if (!found_next_instr) {
2798       dfs.pop_back();
2799     }
2800   }
2801 }
2802 
2803 }  // namespace
2804 
ToString(const HloPrintOptions & options) const2805 string HloInstruction::ToString(const HloPrintOptions& options) const {
2806   CanonicalNameMap new_map;
2807   return ToStringWithCanonicalNameMap(options, &new_map);
2808 }
2809 
IsOpElementwise(HloOpcode opcode)2810 bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
2811   switch (opcode) {
2812     // Unary elementwise operations.
2813     case HloOpcode::kAbs:
2814     case HloOpcode::kRoundNearestAfz:
2815     case HloOpcode::kCeil:
2816     case HloOpcode::kClz:
2817     case HloOpcode::kConvert:
2818     case HloOpcode::kBitcastConvert:
2819     case HloOpcode::kCopy:
2820     case HloOpcode::kCos:
2821     case HloOpcode::kExp:
2822     case HloOpcode::kExpm1:
2823     case HloOpcode::kFloor:
2824     case HloOpcode::kImag:
2825     case HloOpcode::kIsFinite:
2826     case HloOpcode::kLog:
2827     case HloOpcode::kLog1p:
2828     case HloOpcode::kNot:
2829     case HloOpcode::kNegate:
2830     case HloOpcode::kPopulationCount:
2831     case HloOpcode::kReal:
2832     case HloOpcode::kReducePrecision:
2833     case HloOpcode::kRsqrt:
2834     case HloOpcode::kLogistic:
2835     case HloOpcode::kSign:
2836     case HloOpcode::kSin:
2837     case HloOpcode::kSqrt:
2838     case HloOpcode::kCbrt:
2839     case HloOpcode::kTanh:
2840       return true;
2841 
2842     // Binary elementwise operations, the same as in IsElementwiseBinary().
2843     case HloOpcode::kAdd:
2844     case HloOpcode::kAtan2:
2845     case HloOpcode::kCompare:
2846     case HloOpcode::kComplex:
2847     case HloOpcode::kDivide:
2848     case HloOpcode::kMaximum:
2849     case HloOpcode::kMinimum:
2850     case HloOpcode::kMultiply:
2851     case HloOpcode::kPower:
2852     case HloOpcode::kRemainder:
2853     case HloOpcode::kSubtract:
2854     case HloOpcode::kAnd:
2855     case HloOpcode::kOr:
2856     case HloOpcode::kXor:
2857     case HloOpcode::kShiftLeft:
2858     case HloOpcode::kShiftRightArithmetic:
2859     case HloOpcode::kShiftRightLogical:
2860       return true;
2861 
2862     // Ternary elementwise operations.
2863     case HloOpcode::kSelect:
2864     case HloOpcode::kClamp:
2865       return true;
2866 
2867     default:
2868       return false;
2869   }
2870 }
2871 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const2872 bool HloInstruction::IsElementwiseImpl(
2873     const absl::optional<int64>& operand_idx) const {
2874   if (opcode_ == HloOpcode::kDynamicUpdateSlice) {
2875     return operand_idx.has_value() && operand_idx.value() == 0;
2876   }
2877   return IsOpElementwise(opcode_);
2878 }
2879 
IsCrossModuleAllReduce() const2880 bool HloInstruction::IsCrossModuleAllReduce() const {
2881   return opcode() == HloOpcode::kAllReduce && channel_id();
2882 }
2883 
IsCrossReplicaAllReduce() const2884 bool HloInstruction::IsCrossReplicaAllReduce() const {
2885   return opcode() == HloOpcode::kAllReduce && !channel_id();
2886 }
2887 
ToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2888 string HloInstruction::ToStringWithCanonicalNameMap(
2889     const HloPrintOptions& options,
2890     CanonicalNameMap* canonical_name_map) const {
2891   string result = "";
2892 
2893   // Logic to print the instruction name (e.g. "%foo = ").
2894   if (options.canonicalize_instruction_names()) {
2895     if (options.is_in_nested_computation()) {
2896       // If we are canonicalizing instruction names and this is a top-level
2897       // HloInstruction::ToString() call, don't print an instruction name.
2898       StrAppend(&result,
2899                 PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
2900                                   options),
2901                 " = ");
2902     }
2903   } else {
2904     StrAppend(&result, PrintNameInternal(name(), options), " = ");
2905   }
2906 
2907   if (options.print_result_shape()) {
2908     // Print shape.
2909     if (options.include_layout_in_shapes()) {
2910       StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
2911     } else {
2912       StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
2913     }
2914   }
2915 
2916   // Print opcode, operand(s).
2917   StrAppend(&result, HloOpcodeString(opcode()), "(",
2918             OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
2919             ")");
2920 
2921   // Print additional attributes. If an instruction contains a subcomputation,
2922   // the subcomputation is also printed here.
2923   for (const string& extra : ExtraAttributesToString(options)) {
2924     StrAppend(&result, ", ", extra);
2925   }
2926 
2927   if (options.print_metadata() &&
2928       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
2929        !metadata_.source_file().empty())) {
2930     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
2931   }
2932   if (options.print_backend_config() && !backend_config_.empty()) {
2933     StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
2934   }
2935   return result;
2936 }
2937 
OperandsToString(const HloPrintOptions & options) const2938 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2939   CanonicalNameMap new_map;
2940   return OperandsToStringWithCanonicalNameMap(options, &new_map);
2941 }
2942 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2943 string HloInstruction::OperandsToStringWithCanonicalNameMap(
2944     const HloPrintOptions& options,
2945     CanonicalNameMap* canonical_name_map) const {
2946   string operands;
2947   absl::Span<HloInstruction* const> slice(operands_);
2948   const int64_t kMaxOperandsToShowIfCompact = 4;
2949   if (options.compact_operands() &&
2950       slice.size() > kMaxOperandsToShowIfCompact) {
2951     slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
2952   }
2953   for (int64_t i = 0; i < slice.size(); ++i) {
2954     HloInstruction* operand = slice[i];
2955     if (i != 0) {
2956       StrAppend(&operands, ", ");
2957       if (options.print_operand_index_annotation_interval() != 0 &&
2958           i % options.print_operand_index_annotation_interval() == 0) {
2959         StrAppend(&operands, absl::StrFormat("/*index=%lld*/", i));
2960       }
2961     }
2962     // If operand is already been deleted, put `null` to the string output.
2963     if (operand == nullptr) {
2964       StrAppend(&operands, "null ");
2965       continue;
2966     }
2967     std::vector<string> str;
2968     if (options.print_operand_shape()) {
2969       if (options.include_layout_in_shapes()) {
2970         str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
2971       } else {
2972         str.push_back(ShapeUtil::HumanString(operand->shape()));
2973       }
2974     }
2975 
2976     // In a top-level HloInstruction::ToString() call, the operand name is not
2977     // part of the canonical string.
2978     if (options.canonicalize_instruction_names() &&
2979         options.is_in_nested_computation()) {
2980       str.push_back(PrintNameInternal(
2981           canonical_name_map->LookupOrInsert(operand->name()), options));
2982     } else if (options.print_operand_names()) {
2983       str.push_back(PrintNameInternal(operand->name(), options));
2984     }
2985     StrAppend(&operands, StrJoin(str, " "));
2986   }
2987   const int64_t remaining = operands_.size() - slice.size();
2988   if (slice.size() != operands_.size()) {
2989     StrAppend(&operands, ", ...(+", remaining, ")");
2990   }
2991   return operands;
2992 }
2993 
2994 namespace {
2995 
IsSequentialCall(HloOpcode opcode)2996 bool IsSequentialCall(HloOpcode opcode) {
2997   switch (opcode) {
2998     case HloOpcode::kCall:
2999     case HloOpcode::kConditional:
3000     case HloOpcode::kWhile:
3001       return true;
3002     default:
3003       return false;
3004   }
3005 }
3006 
3007 }  // namespace
3008 
ExtraAttributesToString(const HloPrintOptions & options) const3009 std::vector<string> HloInstruction::ExtraAttributesToString(
3010     const HloPrintOptions& options) const {
3011   std::vector<string> extra = options.print_extra_attributes()
3012                                   ? ExtraAttributesToStringImpl(options)
3013                                   : std::vector<string>();
3014 
3015   const auto subcomputation_mode = options.print_subcomputation_mode();
3016   if (subcomputation_mode ==
3017       HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
3018     if (opcode() == HloOpcode::kWhile) {
3019       extra.push_back(StrCat(
3020           "condition=", PrintNameInternal(while_condition()->name(), options)));
3021       extra.push_back(
3022           StrCat("body=", PrintNameInternal(while_body()->name(), options)));
3023     } else if (opcode() == HloOpcode::kSelectAndScatter) {
3024       extra.push_back(
3025           StrCat("select=", PrintNameInternal(select()->name(), options)));
3026       extra.push_back(
3027           StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
3028     } else if (opcode() == HloOpcode::kConditional) {
3029       if (operand(0)->shape().element_type() == PRED) {
3030         extra.push_back(
3031             StrCat("true_computation=",
3032                    PrintNameInternal(true_computation()->name(), options)));
3033         extra.push_back(
3034             StrCat("false_computation=",
3035                    PrintNameInternal(false_computation()->name(), options)));
3036       } else {
3037         extra.push_back(StrCat(
3038             "branch_computations={",
3039             StrJoin(branch_computations(), ", ",
3040                     [&](string* out, const HloComputation* computation) {
3041                       StrAppend(
3042                           out, PrintNameInternal(computation->name(), options));
3043                     }),
3044             "}"));
3045       }
3046     } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
3047                opcode() == HloOpcode::kReduceWindow ||
3048                opcode() == HloOpcode::kReduce ||
3049                opcode() == HloOpcode::kAllReduce ||
3050                opcode() == HloOpcode::kReduceScatter ||
3051                opcode() == HloOpcode::kAllReduceStart ||
3052                opcode() == HloOpcode::kScatter ||
3053                opcode() == HloOpcode::kSort) {
3054       extra.push_back(
3055           StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
3056     } else if (opcode() == HloOpcode::kCustomCall) {
3057       if (!called_computations().empty()) {
3058         extra.push_back(StrCat(
3059             "called_computations={",
3060             StrJoin(called_computations(), ", ",
3061                     [&](string* out, const HloComputation* computation) {
3062                       StrAppend(
3063                           out, PrintNameInternal(computation->name(), options));
3064                     }),
3065             "}"));
3066       }
3067     } else if (!called_computations().empty()) {
3068       extra.push_back(StrCat(
3069           "calls=",
3070           StrJoin(called_computations(), ", ",
3071                   [&](string* out, const HloComputation* computation) {
3072                     StrAppend(out,
3073                               PrintNameInternal(computation->name(), options));
3074                   })));
3075     }
3076   } else if ((subcomputation_mode ==
3077               HloPrintOptions::PrintSubcomputationMode::kFullBodies) ||
3078              (subcomputation_mode == HloPrintOptions::PrintSubcomputationMode::
3079                                          kNonSequentialBodies &&
3080               !IsSequentialCall(opcode()))) {
3081     HloPrintOptions new_options = options;
3082     new_options.set_is_in_nested_computation(true);
3083     switch (opcode()) {
3084       case HloOpcode::kWhile:
3085         extra.push_back(
3086             StrCat("condition=\n", while_condition()->ToString(new_options)));
3087         extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
3088         break;
3089       case HloOpcode::kSelectAndScatter:
3090         extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
3091         extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
3092         break;
3093       case HloOpcode::kConditional:
3094         if (operand(0)->shape().element_type() == PRED) {
3095           extra.push_back(StrCat("true_computation=\n",
3096                                  true_computation()->ToString(new_options)));
3097           extra.push_back(StrCat("false_computation=\n",
3098                                  false_computation()->ToString(new_options)));
3099         } else {
3100           extra.push_back(StrCat(
3101               "branch_computations={\n",
3102               StrJoin(branch_computations(), ",\n",
3103                       [&](string* out, const HloComputation* computation) {
3104                         StrAppend(out, computation->ToString(new_options));
3105                       }),
3106               "\n}"));
3107         }
3108         break;
3109       case HloOpcode::kCall:
3110       case HloOpcode::kMap:
3111       case HloOpcode::kReduceWindow:
3112       case HloOpcode::kReduce:
3113       case HloOpcode::kAllReduce:
3114       case HloOpcode::kAllReduceStart:
3115       case HloOpcode::kScatter:
3116       case HloOpcode::kSort:
3117         extra.push_back(
3118             StrCat("to_apply=\n", to_apply()->ToString(new_options)));
3119         break;
3120       default:
3121         if (!called_computations().empty()) {
3122           extra.push_back(StrCat(
3123               "calls=\n",
3124               StrJoin(called_computations(), ", ",
3125                       [&](string* out, const HloComputation* computation) {
3126                         StrAppend(out, computation->ToString(new_options));
3127                       })));
3128         }
3129         break;
3130     }
3131   }
3132 
3133   if (has_sharding()) {
3134     extra.push_back(
3135         StrCat("sharding=", sharding().ToString(options.print_metadata())));
3136   }
3137   if (!frontend_attributes_.map().empty()) {
3138     extra.push_back(StrCat("frontend_attributes=",
3139                            FrontendAttributesToString(frontend_attributes_)));
3140   }
3141   if (!outer_dimension_partitions_.empty()) {
3142     extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
3143                                     StrJoin(outer_dimension_partitions_, ",")));
3144   }
3145 
3146   if (options.print_control_dependencies() && !control_predecessors_.empty()) {
3147     extra.push_back(StrCat("control-predecessors={",
3148                            StrJoin(control_predecessors_, ", ",
3149                                    [&](string* out, HloInstruction* pre) {
3150                                      StrAppend(out, PrintNameInternal(
3151                                                         pre->name(), options));
3152                                    }),
3153                            "}"));
3154   }
3155 
3156   return extra;
3157 }
3158 
ToShortString() const3159 string HloInstruction::ToShortString() const {
3160   return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
3161                 StrJoin(operands_, ", ",
3162                         [](string* out, HloInstruction* operand) {
3163                           StrAppend(out, "%", operand->name());
3164                         }),
3165                 ")");
3166 }
3167 
ToProto() const3168 HloInstructionProto HloInstruction::ToProto() const {
3169   HloInstructionProto proto;
3170   CHECK(unique_id_ != -1)
3171       << "This instruction does not have a valid id. Please make sure the "
3172          "instruction is inside a module before dumping it.";
3173   proto.set_id(unique_id_);
3174   proto.set_name(name_);
3175   proto.set_opcode(HloOpcodeString(opcode_));
3176   *proto.mutable_shape() = shape_.ToProto();
3177   for (const HloInstruction* operand : operands_) {
3178     proto.add_operand_ids(operand->unique_id());
3179   }
3180   for (const HloInstruction* control : control_predecessors_) {
3181     proto.add_control_predecessor_ids(control->unique_id());
3182   }
3183 
3184   *proto.mutable_metadata() = metadata_;
3185   proto.set_backend_config(backend_config_);
3186   if (opcode() != HloOpcode::kFusion) {
3187     for (const HloComputation* computation : called_computations_) {
3188       proto.add_called_computation_ids(computation->unique_id());
3189     }
3190   }
3191 
3192   if (has_sharding()) {
3193     *proto.mutable_sharding() = sharding().ToProto();
3194   }
3195   if (!outer_dimension_partitions_.empty()) {
3196     for (const auto& idx : outer_dimension_partitions_) {
3197       proto.mutable_outer_dimension_partitions()->Add(idx);
3198     }
3199   }
3200 
3201   *proto.mutable_frontend_attributes() = frontend_attributes_;
3202 
3203   return proto;
3204 }
3205 
ToCategory() const3206 string HloInstruction::ToCategory() const {
3207   if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
3208       opcode() == HloOpcode::kReshape ||
3209       opcode() == HloOpcode::kDynamicReshape) {
3210     return "data formatting";
3211   }
3212 
3213   if (IsElementwise()) {
3214     return "non-fusion elementwise";
3215   }
3216 
3217   return HloOpcodeString(opcode());
3218 }
3219 
tracing() const3220 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
3221 
set_tracing(HloInstruction * trace_instruction)3222 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
3223   trace_instruction_ = trace_instruction;
3224 }
3225 
IsFused() const3226 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
3227 
IsCustomCall(absl::string_view target) const3228 bool HloInstruction::IsCustomCall(absl::string_view target) const {
3229   return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
3230 }
3231 
IsInputFusion() const3232 bool HloInstruction::IsInputFusion() const {
3233   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
3234 }
3235 
IsLoopFusion() const3236 bool HloInstruction::IsLoopFusion() const {
3237   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
3238 }
3239 
IsOutputFusion() const3240 bool HloInstruction::IsOutputFusion() const {
3241   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
3242 }
3243 
IsCustomFusion() const3244 bool HloInstruction::IsCustomFusion() const {
3245   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
3246 }
3247 
IsFusible() const3248 bool HloInstruction::IsFusible() const {
3249   // Instructions which are traced should not be fused.
3250   if (tracing()) {
3251     return false;
3252   }
3253   // Some kinds of instructions don't make sense to fuse.
3254   switch (opcode_) {
3255     case HloOpcode::kDomain:
3256     case HloOpcode::kParameter:
3257     case HloOpcode::kWhile:
3258     case HloOpcode::kConditional:
3259     case HloOpcode::kCall:
3260       return false;
3261     // Fusions are always fusible.
3262     case HloOpcode::kFusion:
3263     // Side effecting reduce and reduce window would be invalid HLO.
3264     case HloOpcode::kMap:
3265     case HloOpcode::kReduce:
3266     case HloOpcode::kReduceWindow:
3267       return true;
3268     case HloOpcode::kRng:
3269       return user_count() <= 1;
3270     // Side effecting instructions cannot be fused.
3271     default:
3272       return !HasSideEffect();
3273   }
3274 }
3275 
HloInstruction(HloOpcode opcode,const Shape & shape)3276 HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
3277     : unique_id_(-1),
3278       opcode_(opcode),
3279       shape_(shape),
3280       name_(HloOpcodeString(opcode)),
3281       marked_as_dead_(false) {
3282   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
3283 }
3284 
// Dispatches on opcode_ to the matching Handle* method of `visitor`, passing
// `this`, and returns the Status produced by that handler.
//
// kTrace is the only opcode listed here that invokes no handler; it returns
// Status::OK() directly. If control leaves the switch without returning (an
// opcode value with no case), an InternalError naming the opcode is returned.
3285 template <typename HloInstructionPtr>
Visit(DfsHloVisitorBase<HloInstructionPtr> * visitor)3286 Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
3287   switch (opcode_) {
3288     case HloOpcode::kAbs:
3289       return visitor->HandleAbs(this);
3290     case HloOpcode::kAtan2:
3291       return visitor->HandleAtan2(this);
// Note: the handler name differs from the opcode name for this case.
3292     case HloOpcode::kRoundNearestAfz:
3293       return visitor->HandleRound(this);
3294     case HloOpcode::kBatchNormTraining:
3295       return visitor->HandleBatchNormTraining(this);
3296     case HloOpcode::kBatchNormInference:
3297       return visitor->HandleBatchNormInference(this);
3298     case HloOpcode::kBatchNormGrad:
3299       return visitor->HandleBatchNormGrad(this);
3300     case HloOpcode::kLogistic:
3301       return visitor->HandleLogistic(this);
3302     case HloOpcode::kSign:
3303       return visitor->HandleSign(this);
3304     case HloOpcode::kConstant:
3305       return visitor->HandleConstant(this);
3306     case HloOpcode::kGetTupleElement:
3307       return visitor->HandleGetTupleElement(this);
3308     case HloOpcode::kParameter:
3309       return visitor->HandleParameter(this);
3310     case HloOpcode::kCompare:
3311       return visitor->HandleCompare(this);
3312     case HloOpcode::kComplex:
3313       return visitor->HandleComplex(this);
3314     case HloOpcode::kAdd:
3315       return visitor->HandleAdd(this);
3316     case HloOpcode::kDivide:
3317       return visitor->HandleDivide(this);
3318     case HloOpcode::kSubtract:
3319       return visitor->HandleSubtract(this);
3320     case HloOpcode::kMaximum:
3321       return visitor->HandleMaximum(this);
3322     case HloOpcode::kMinimum:
3323       return visitor->HandleMinimum(this);
3324     case HloOpcode::kAnd:
3325       return visitor->HandleAnd(this);
3326     case HloOpcode::kOr:
3327       return visitor->HandleOr(this);
3328     case HloOpcode::kXor:
3329       return visitor->HandleXor(this);
3330     case HloOpcode::kShiftLeft:
3331       return visitor->HandleShiftLeft(this);
3332     case HloOpcode::kShiftRightArithmetic:
3333       return visitor->HandleShiftRightArithmetic(this);
3334     case HloOpcode::kShiftRightLogical:
3335       return visitor->HandleShiftRightLogical(this);
3336     case HloOpcode::kConcatenate:
3337       return visitor->HandleConcatenate(this);
3338     case HloOpcode::kConvert:
3339       return visitor->HandleConvert(this);
3340     case HloOpcode::kBitcastConvert:
3341       return visitor->HandleBitcastConvert(this);
3342     case HloOpcode::kCopy:
3343       return visitor->HandleCopy(this);
3344     case HloOpcode::kMultiply:
3345       return visitor->HandleMultiply(this);
3346     case HloOpcode::kDot:
3347       return visitor->HandleDot(this);
3348     case HloOpcode::kPower:
3349       return visitor->HandlePower(this);
3350     case HloOpcode::kRemainder:
3351       return visitor->HandleRemainder(this);
3352     case HloOpcode::kSelect:
3353       return visitor->HandleSelect(this);
3354     case HloOpcode::kTupleSelect:
3355       return visitor->HandleTupleSelect(this);
3356     case HloOpcode::kConvolution:
3357       return visitor->HandleConvolution(this);
3358     case HloOpcode::kFft:
3359       return visitor->HandleFft(this);
3360     case HloOpcode::kAllGather:
3361       return visitor->HandleAllGather(this);
3362     case HloOpcode::kAllGatherStart:
3363       return visitor->HandleAllGatherStart(this);
3364     case HloOpcode::kAllGatherDone:
3365       return visitor->HandleAllGatherDone(this);
3366     case HloOpcode::kAllReduce:
3367       return visitor->HandleAllReduce(this);
3368     case HloOpcode::kReduceScatter:
3369       return visitor->HandleReduceScatter(this);
3370     case HloOpcode::kAllReduceStart:
3371       return visitor->HandleAllReduceStart(this);
3372     case HloOpcode::kAllReduceDone:
3373       return visitor->HandleAllReduceDone(this);
3374     case HloOpcode::kAllToAll:
3375       return visitor->HandleAllToAll(this);
3376     case HloOpcode::kCollectivePermute:
3377       return visitor->HandleCollectivePermute(this);
3378     case HloOpcode::kCollectivePermuteStart:
3379       return visitor->HandleCollectivePermuteStart(this);
3380     case HloOpcode::kCollectivePermuteDone:
3381       return visitor->HandleCollectivePermuteDone(this);
3382     case HloOpcode::kReplicaId:
3383       return visitor->HandleReplicaId(this);
3384     case HloOpcode::kPartitionId:
3385       return visitor->HandlePartitionId(this);
3386     case HloOpcode::kTuple:
3387       return visitor->HandleTuple(this);
3388     case HloOpcode::kMap:
3389       return visitor->HandleMap(this);
3390     case HloOpcode::kClamp:
3391       return visitor->HandleClamp(this);
3392     case HloOpcode::kReduce:
3393       return visitor->HandleReduce(this);
3394     case HloOpcode::kReduceWindow:
3395       return visitor->HandleReduceWindow(this);
3396     case HloOpcode::kSelectAndScatter:
3397       return visitor->HandleSelectAndScatter(this);
3398     case HloOpcode::kNegate:
3399       return visitor->HandleNegate(this);
3400     case HloOpcode::kExp:
3401       return visitor->HandleExp(this);
3402     case HloOpcode::kExpm1:
3403       return visitor->HandleExpm1(this);
3404     case HloOpcode::kFloor:
3405       return visitor->HandleFloor(this);
3406     case HloOpcode::kCeil:
3407       return visitor->HandleCeil(this);
3408     case HloOpcode::kClz:
3409       return visitor->HandleClz(this);
3410     case HloOpcode::kLog:
3411       return visitor->HandleLog(this);
3412     case HloOpcode::kLog1p:
3413       return visitor->HandleLog1p(this);
3414     case HloOpcode::kTanh:
3415       return visitor->HandleTanh(this);
3416     case HloOpcode::kCos:
3417       return visitor->HandleCos(this);
3418     case HloOpcode::kSin:
3419       return visitor->HandleSin(this);
3420     case HloOpcode::kSqrt:
3421       return visitor->HandleSqrt(this);
3422     case HloOpcode::kCbrt:
3423       return visitor->HandleCbrt(this);
3424     case HloOpcode::kRsqrt:
3425       return visitor->HandleRsqrt(this);
3426     case HloOpcode::kReal:
3427       return visitor->HandleReal(this);
3428     case HloOpcode::kImag:
3429       return visitor->HandleImag(this);
3430     case HloOpcode::kIsFinite:
3431       return visitor->HandleIsFinite(this);
3432     case HloOpcode::kNot:
3433       return visitor->HandleNot(this);
3434     case HloOpcode::kPopulationCount:
3435       return visitor->HandlePopulationCount(this);
3436     case HloOpcode::kBitcast:
3437       return visitor->HandleBitcast(this);
3438     case HloOpcode::kBroadcast:
3439       return visitor->HandleBroadcast(this);
3440     case HloOpcode::kPad:
3441       return visitor->HandlePad(this);
3442     case HloOpcode::kReshape:
3443       return visitor->HandleReshape(this);
3444     case HloOpcode::kDynamicReshape:
3445       return visitor->HandleDynamicReshape(this);
3446     case HloOpcode::kTranspose:
3447       return visitor->HandleTranspose(this);
3448     case HloOpcode::kReverse:
3449       return visitor->HandleReverse(this);
3450     case HloOpcode::kReducePrecision:
3451       return visitor->HandleReducePrecision(this);
3452     case HloOpcode::kSlice:
3453       return visitor->HandleSlice(this);
3454     case HloOpcode::kDynamicSlice:
3455       return visitor->HandleDynamicSlice(this);
3456     case HloOpcode::kDynamicUpdateSlice:
3457       return visitor->HandleDynamicUpdateSlice(this);
3458     case HloOpcode::kSort:
3459       return visitor->HandleSort(this);
3460     case HloOpcode::kInfeed:
3461       return visitor->HandleInfeed(this);
3462     case HloOpcode::kOutfeed:
3463       return visitor->HandleOutfeed(this);
3464     case HloOpcode::kRng:
3465       return visitor->HandleRng(this);
3466     case HloOpcode::kRngBitGenerator:
3467       return visitor->HandleRngBitGenerator(this);
3468     case HloOpcode::kRngGetAndUpdateState:
3469       return visitor->HandleRngGetAndUpdateState(this);
3470     case HloOpcode::kWhile:
3471       return visitor->HandleWhile(this);
3472     case HloOpcode::kFusion:
3473       return visitor->HandleFusion(this);
3474     case HloOpcode::kCall:
3475       return visitor->HandleCall(this);
3476     case HloOpcode::kConditional:
3477       return visitor->HandleConditional(this);
3478     case HloOpcode::kCustomCall:
3479       return visitor->HandleCustomCall(this);
3480     case HloOpcode::kCopyStart:
3481       return visitor->HandleCopyStart(this);
3482     case HloOpcode::kCopyDone:
3483       return visitor->HandleCopyDone(this);
3484     case HloOpcode::kRecv:
3485       return visitor->HandleRecv(this);
3486     case HloOpcode::kRecvDone:
3487       return visitor->HandleRecvDone(this);
3488     case HloOpcode::kSend:
3489       return visitor->HandleSend(this);
3490     case HloOpcode::kSendDone:
3491       return visitor->HandleSendDone(this);
3492     case HloOpcode::kGather:
3493       return visitor->HandleGather(this);
3494     case HloOpcode::kScatter:
3495       return visitor->HandleScatter(this);
3496     case HloOpcode::kDomain:
3497       return visitor->HandleDomain(this);
3498     case HloOpcode::kAfterAll:
3499       return visitor->HandleAfterAll(this);
3500     case HloOpcode::kAddDependency:
3501       return visitor->HandleAddDependency(this);
3502     case HloOpcode::kIota:
3503       return visitor->HandleIota(this);
3504     case HloOpcode::kGetDimensionSize:
3505       return visitor->HandleGetDimensionSize(this);
3506     case HloOpcode::kSetDimensionSize:
3507       return visitor->HandleSetDimensionSize(this);
3508     case HloOpcode::kTriangularSolve:
3509       return visitor->HandleTriangularSolve(this);
3510     case HloOpcode::kCholesky:
3511       return visitor->HandleCholesky(this);
3512 
3513     // These opcodes are not handled here.
3514     case HloOpcode::kTrace:
3515       return Status::OK();
3516   }
// Reached only when no case above returned.
3517   return InternalError(
3518       "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
3519       "please file a bug for XLA.",
3520       HloOpcodeString(opcode_));
3521 }
3522 
3523 // Explicit instantiations.
// One instantiation per visitor pointer type: DfsHloVisitor and
// ConstDfsHloVisitor.
3524 template Status HloInstruction::Visit(DfsHloVisitor* visitor);
3525 template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
3526 
3527 // Push "child" onto the dfs_stack if not already visited.  Returns false if a
3528 // cycle was detected, and true otherwise.
3529 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)3530 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
3531                          HloInstruction* child) {
3532   CHECK(child != nullptr);
3533   const int id = child->unique_id();
3534   CHECK_GE(id, 0) << "instruction may not have a parent computation";
3535   switch (visitor->GetVisitState(id)) {
3536     case Visitor::kVisiting:
3537       return false;
3538 
3539     case Visitor::kVisited:
3540       // Nothing to do
3541       return true;
3542 
3543     case Visitor::kNotVisited:
3544       dfs_stack->push_back(std::make_pair(id, child));
3545       return true;
3546   }
3547 }
3548 
3549 using InternalCompareFunction =
3550     std::function<bool(std::pair<int, const HloInstruction*>,
3551                        std::pair<int, const HloInstruction*>)>;
3552 template <typename Visitor>
PostOrderDFS(HloInstruction * root,Visitor * visitor,const InternalCompareFunction * operand_order,bool ignore_control_predecessors)3553 static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
3554                            const InternalCompareFunction* operand_order,
3555                            bool ignore_control_predecessors) {
3556   visitor->ReserveVisitStates(root->parent()->instruction_count());
3557 
3558   // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
3559   //
3560   // We need to keep track of both the id and the instruction because
3561   // instructions can get deleted while they are on the stack, so we
3562   // can't always use the (potentially dead) instruction object to grab
3563   // its id.
3564   DFSStack dfs_stack;
3565   dfs_stack.emplace_back(root->unique_id(), root);
3566 
3567   do {
3568     DCHECK(!dfs_stack.empty());
3569 
3570     int current_id = dfs_stack.back().first;
3571     HloInstruction* current_node = dfs_stack.back().second;
3572     CHECK_GE(current_id, 0) << current_id << ": " << current_node
3573                             << ": instruction may not have parent computation";
3574     typename Visitor::VisitState visit_state =
3575         visitor->GetVisitState(current_id);
3576     if (visit_state == Visitor::kVisited) {
3577       dfs_stack.pop_back();
3578       VLOG(3) << "Not visiting HLO (id = " << current_id
3579               << ") as it was already visited.";
3580       continue;
3581     }
3582 
3583     if (visit_state == Visitor::kVisiting) {
3584       dfs_stack.pop_back();
3585 
3586       TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
3587       VLOG(2) << "Visiting HLO %" << current_node->name();
3588       TF_RETURN_IF_ERROR(current_node->Visit(visitor));
3589       visitor->SetVisitState(current_id, Visitor::kVisited);
3590       TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
3591       continue;
3592     }
3593 
3594     visitor->SetVisitState(current_id, Visitor::kVisiting);
3595 
3596     const size_t old_dfs_stack_size = dfs_stack.size();
3597     for (HloInstruction* child : current_node->operands()) {
3598       if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
3599         PrintCycle(child, &dfs_stack);
3600         return FailedPrecondition(
3601             "A cycle is detected while visiting instruction %s",
3602             current_node->ToString());
3603       }
3604     }
3605 
3606     if (!ignore_control_predecessors) {
3607       for (HloInstruction* child : current_node->control_predecessors()) {
3608         if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
3609           PrintCycle(child, &dfs_stack);
3610           return FailedPrecondition(
3611               "A cycle is detected while visiting instruction %s",
3612               current_node->ToString());
3613         }
3614       }
3615     }
3616 
3617     if (operand_order != nullptr) {
3618       std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
3619                 *operand_order);
3620     }
3621 
3622     // This makes the traversal order the same as what you'd expect
3623     // out of a recursive algorithm.
3624     std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
3625   } while (!dfs_stack.empty());
3626 
3627   return Status::OK();
3628 }
3629 
3630 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)3631 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
3632                               bool call_finish_visit,
3633                               bool ignore_control_predecessors) {
3634   VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
3635   TF_RETURN_IF_ERROR(
3636       PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
3637   if (call_finish_visit) {
3638     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3639   }
3640   return Status::OK();
3641 }
3642 
3643 // Explicit instantiations.
3644 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
3645 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3646 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)3647 Status HloInstruction::AcceptWithOperandOrder(
3648     DfsHloVisitor* visitor, const CompareFunction& operand_order,
3649     bool call_finish_visit) {
3650   VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
3651   InternalCompareFunction func = [&operand_order](
3652                                      std::pair<int, const HloInstruction*> a,
3653                                      std::pair<int, const HloInstruction*> b) {
3654     // Call the client's comparison function on the actual HloInstruction*
3655     // objects (ignoring the internal ids we also have in our stack entries)
3656     return operand_order(a.second, b.second);
3657   };
3658   TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
3659                                   /*ignore_control_predecessors=*/false));
3660   if (call_finish_visit) {
3661     VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
3662     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3663     VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
3664   }
3665   VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
3666   return Status::OK();
3667 }
3668 
shape() const3669 const Shape& HloInstruction::shape() const { return shape_; }
3670 
OperandIndices(const HloInstruction * operand) const3671 absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
3672     const HloInstruction* operand) const {
3673   absl::InlinedVector<int64, 4> result;
3674   for (int64_t i = 0; i < operand_count(); ++i) {
3675     if (this->operand(i) == operand) {
3676       result.push_back(i);
3677     }
3678   }
3679   return result;
3680 }
3681 
IsElementwiseBinary() const3682 bool HloInstruction::IsElementwiseBinary() const {
3683   return IsElementwise() && operand_count() == 2;
3684 }
3685 
IsElementwise() const3686 bool HloInstruction::IsElementwise() const {
3687   return IsElementwiseImpl(absl::nullopt);
3688 }
3689 
IsElementwiseOnOperand(int64_t operand_idx) const3690 bool HloInstruction::IsElementwiseOnOperand(int64_t operand_idx) const {
3691   return IsElementwiseImpl(operand_idx);
3692 }
3693 
3694 // A helper class for memoized, recursive computation of HloOpcode::kFusion
3695 // in HloInstruction::OperandElementUse below.
3696 class HloInstruction::FusionReusesParamElements {
3697  public:
3698   using UseKind = HloInstruction::UseKind;
3699 
3700   // We could rather iterate backwards through fused_instructions_ here, as it
3701   // is in reverse postorder, and compute whether each fused instruction reuses
3702   // the value of this parameter, which would save stack space but not allow us
3703   // to finish early if we find a reuse.
Compute(int64_t i,const HloInstruction & hlo)3704   static UseKind Compute(int64_t i, const HloInstruction& hlo) {
3705     absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
3706     return ComputeInternal(i, hlo, &memoization_cache);
3707   }
3708 
3709  private:
ComputeInternal(int64_t i,const HloInstruction & hlo,absl::flat_hash_map<const HloInstruction *,UseKind> * cache)3710   static UseKind ComputeInternal(
3711       int64_t i, const HloInstruction& hlo,
3712       absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
3713     if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
3714       if (hlo_param->parameter_number() == i) {
3715         return UseKind::kUse;
3716       }
3717     }
3718 
3719     auto p = cache->emplace(&hlo, UseKind::kNoUse);
3720     auto value_it = p.first;
3721     const bool key_is_new = p.second;
3722 
3723     if (key_is_new) {
3724       for (int64_t j = 0; j < hlo.operands_.size(); ++j) {
3725         UseKind old_val = value_it->second;
3726 
3727         // The next operation invalidates iterators.
3728         UseKind new_val =
3729             Fold(old_val,
3730                  FoldUseMandatory(hlo.OperandElementUse(j),
3731                                   ComputeInternal(i, *hlo.operand(j), cache)));
3732 
3733         // Re-acquire the iterator. We could work harder to do this only if
3734         // absolutely necessary, but this code is not hot enough to warrant
3735         // that.
3736         value_it = cache->find(&hlo);
3737         value_it->second = new_val;
3738         // Fold() minimizes the UseKind value. If it is already minimum, we can
3739         // break the loop early.
3740         if (new_val == UseKind::kReuse) {
3741           break;
3742         }
3743       }
3744     }
3745     return value_it->second;
3746   }
3747 
3748   // Combines two UseKinds.
3749   //
3750   // This is the min operation on the lattice
3751   //
3752   //   kReuse < kUse < kNoUse.
3753   //
3754   // Two kUses uses which have different permutations count as kReuse.
Fold(UseKind a,UseKind b)3755   static UseKind Fold(UseKind a, UseKind b) {
3756     // Without loss of generality, let `b` be the operation with the larger use
3757     // kind.
3758     if (b.kind < a.kind) {
3759       std::swap(a, b);
3760     }
3761     // If the kinds are different, return the smaller one, namely `a`.
3762     if (a.kind != b.kind) {
3763       return a;
3764     }
3765     // If the kinds are both kUse, check that they're the same permutation.
3766     if (a.kind == UseKind::kUse && b.kind == UseKind::kUse &&
3767         a.permutation_instr != b.permutation_instr) {
3768       return UseKind::kReuse;
3769     }
3770     return a;  // They're the same.
3771   }
3772 
3773   // Combines two UseKinds differently than Fold().
3774   //
3775   // This is the min operation on the lattice
3776   //
3777   //   kNoUse < kReuse < kUse.
3778   //
3779   // If `a` and `b` are both kUse and one has a non-null permutation
3780   // instruction, returns kUse with that permutation.  OTOH if both have
3781   // different, non-null permutation instructions, returns kReuse.
3782   //
3783   // You can think of this sort of as a conjunction, whereas Fold is sort of a
3784   // disjunction.  FoldUseMandatory() says "no use" if either input isn't used,
3785   // whereas Fold() would say "use".
FoldUseMandatory(UseKind a,UseKind b)3786   static UseKind FoldUseMandatory(UseKind a, UseKind b) {
3787     if (a.kind == UseKind::kNoUse || b.kind == UseKind::kNoUse) {
3788       return UseKind::kNoUse;
3789     }
3790     if (a.kind == UseKind::kReuse || b.kind == UseKind::kReuse) {
3791       return UseKind::kReuse;
3792     }
3793     if (a.permutation_instr == b.permutation_instr) {
3794       return a;  // They're the same.
3795     }
3796     if (b.permutation_instr == nullptr) {
3797       return a;
3798     }
3799     if (a.permutation_instr == nullptr) {
3800       return b;
3801     }
3802     return UseKind::kReuse;
3803   }
3804 };
3805 
OperandElementUse(int64_t operand_num) const3806 HloInstruction::UseKind HloInstruction::OperandElementUse(
3807     int64_t operand_num) const {
3808   switch (opcode_) {
3809     case HloOpcode::kBitcast:
3810       // A bitcast that only adds or removes degenerate (i.e. size 1) dimensions
3811       // doesn't permute its elements, so it counts as a plain, non-permuting
3812       // use.
3813       return ShapeUtil::DropDegenerateDimensions(shape()) ==
3814                      ShapeUtil::DropDegenerateDimensions(operand(0)->shape())
3815                  ? UseKind::kUse
3816                  : UseKind::Permuting(this);
3817     case HloOpcode::kConcatenate:
3818     case HloOpcode::kReshape:
3819     case HloOpcode::kReverse:
3820     case HloOpcode::kSlice:
3821     case HloOpcode::kTranspose:
3822       return UseKind::Permuting(this);
3823     case HloOpcode::kPad:
3824       // Pad reuses the padding value but not the padded array elements.
3825       return operand_num > 0 ? UseKind::kReuse : UseKind::Permuting(this);
3826     case HloOpcode::kReduce:
3827       // Reduce reuses the init values but not the operand array elements.
3828       return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
3829                  ? UseKind::kReuse
3830                  : UseKind::Permuting(this);
3831     case HloOpcode::kFusion:
3832       // Uses the memoizing, recursive computation defined above.
3833       return FusionReusesParamElements::Compute(operand_num,
3834                                                 *fused_expression_root());
3835     case HloOpcode::kDot:
3836       // Matrix-vector dots do not reuse the matrix operand.
3837       if (shape().dimensions_size() <= 1) {
3838         if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
3839             (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
3840           return UseKind::kUse;
3841         }
3842       }
3843       return UseKind::kReuse;
3844     case HloOpcode::kDynamicUpdateSlice:
3845       // Dynamic-update-slice reuses only start_indices.
3846       if (operand_num == 0 || operand_num == 1) {
3847         return UseKind::kUse;
3848       }
3849       return UseKind::kReuse;
3850     case HloOpcode::kGather:
3851       // Gather reads its indices in a linear fashion, and it permutes the
3852       // vector it's gathering from.
3853       return operand_num == 0 ? UseKind::kUse : UseKind::Permuting(this);
3854     default:
3855       return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
3856   }
3857 }
3858 
3859 std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const3860 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
3861   if (HloOpcode::kReshape != opcode_) {
3862     return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
3863   }
3864   return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
3865                                                       shape_);
3866 }
3867 
ToString(HloInstruction::FusionKind kind)3868 string ToString(HloInstruction::FusionKind kind) {
3869   switch (kind) {
3870     case HloInstruction::FusionKind::kLoop:
3871       return "kLoop";
3872     case HloInstruction::FusionKind::kInput:
3873       return "kInput";
3874     case HloInstruction::FusionKind::kOutput:
3875       return "kOutput";
3876     case HloInstruction::FusionKind::kCustom:
3877       return "kCustom";
3878   }
3879 }
3880 
StringToFusionKind(const string & kind_name)3881 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
3882     const string& kind_name) {
3883   if (kind_name == "kLoop") {
3884     return HloInstruction::FusionKind::kLoop;
3885   }
3886   if (kind_name == "kInput") {
3887     return HloInstruction::FusionKind::kInput;
3888   }
3889   if (kind_name == "kOutput") {
3890     return HloInstruction::FusionKind::kOutput;
3891   }
3892   if (kind_name == "kCustom") {
3893     return HloInstruction::FusionKind::kCustom;
3894   }
3895   return InvalidArgument("Unknown fusion kind: %s", kind_name);
3896 }
3897 
FrontendAttributesToString(const FrontendAttributes & frontend_attributes)3898 string FrontendAttributesToString(
3899     const FrontendAttributes& frontend_attributes) {
3900   std::vector<std::pair<string, string>> sorted_attributes(
3901       frontend_attributes.map().begin(), frontend_attributes.map().end());
3902   absl::c_sort(sorted_attributes);
3903   // Frontend attribute is a comma-separated list of attribute="value" pairs,
3904   // e.g., frontend_attributes={name="value_a",type="int32"}.
3905   const auto formatter = [](string* out,
3906                             const std::pair<string, string>& item) {
3907     absl::StrAppend(out, item.first, "=\"", item.second, "\"");
3908   };
3909   return absl::StrFormat("{%s}",
3910                          absl::StrJoin(sorted_attributes, ",", formatter));
3911 }
3912 
PaddingConfigToString(const PaddingConfig & padding)3913 string PaddingConfigToString(const PaddingConfig& padding) {
3914   bool has_interior_padding =
3915       absl::c_any_of(padding.dimensions(),
3916                      [](const PaddingConfig::PaddingConfigDimension& dim) {
3917                        return dim.interior_padding() != 0;
3918                      });
3919   return StrJoin(
3920       padding.dimensions(), "x",
3921       [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3922         StrAppend(
3923             out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3924             has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3925       });
3926 }
3927 
RandomDistributionToString(const RandomDistribution & distribution)3928 string RandomDistributionToString(const RandomDistribution& distribution) {
3929   return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
3930 }
RandomAlgorithmToString(const RandomAlgorithm & algorithm)3931 string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
3932   return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
3933 }
3934 
PrecisionToString(const PrecisionConfig::Precision & precision)3935 string PrecisionToString(const PrecisionConfig::Precision& precision) {
3936   return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
3937 }
3938 
CustomCallScheduleToString(const CustomCallSchedule & schedule)3939 static string CustomCallScheduleToString(const CustomCallSchedule& schedule) {
3940   return absl::AsciiStrToLower(CustomCallSchedule_Name(schedule));
3941 }
3942 
CustomCallApiVersionToString(const CustomCallApiVersion & schedule)3943 static string CustomCallApiVersionToString(
3944     const CustomCallApiVersion& schedule) {
3945   return absl::AsciiStrToLower(CustomCallApiVersion_Name(schedule));
3946 }
3947 
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)3948 string ConvolutionDimensionNumbersToString(
3949     const ConvolutionDimensionNumbers& dnums) {
3950   auto len_required = [](int64_t a, int64_t b, absl::Span<const int64> cs) {
3951     return std::max({a, b, cs.empty() ? 0 : *absl::c_max_element(cs)}) + 1;
3952   };
3953 
3954   // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3955   // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3956   std::vector<string> lhs_dims(len_required(dnums.input_batch_dimension(),
3957                                             dnums.input_feature_dimension(),
3958                                             dnums.input_spatial_dimensions()),
3959                                "?");
3960   lhs_dims[dnums.input_batch_dimension()] = 'b';
3961   lhs_dims[dnums.input_feature_dimension()] = 'f';
3962   for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3963     lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3964   }
3965 
3966   std::vector<string> rhs_dims(
3967       len_required(dnums.kernel_input_feature_dimension(),
3968                    dnums.kernel_output_feature_dimension(),
3969                    dnums.kernel_spatial_dimensions()),
3970       "?");
3971   rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3972   rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3973   for (int64_t i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3974     rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3975   }
3976 
3977   std::vector<string> output_dims(
3978       len_required(dnums.output_batch_dimension(),
3979                    dnums.output_feature_dimension(),
3980                    dnums.output_spatial_dimensions()),
3981       "?");
3982   output_dims[dnums.output_batch_dimension()] = 'b';
3983   output_dims[dnums.output_feature_dimension()] = 'f';
3984   for (int64_t i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3985     output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3986   }
3987 
3988   return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
3989                 StrJoin(output_dims, ""));
3990 }
3991 
ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups)3992 string ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups) {
3993   std::vector<string> replica_group_str;
3994   replica_group_str.reserve(replica_groups.size());
3995   for (const ReplicaGroup& group : replica_groups) {
3996     replica_group_str.push_back(
3997         StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
3998   }
3999   return StrCat("{", StrJoin(replica_group_str, ","), "}");
4000 }
4001 
StringToRandomAlgorithm(const string & name)4002 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) {
4003   static std::unordered_map<string, RandomAlgorithm>* map = [] {
4004     static auto* map = new std::unordered_map<string, RandomAlgorithm>;
4005     for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
4006       if (RandomAlgorithm_IsValid(i)) {
4007         auto value = static_cast<RandomAlgorithm>(i);
4008         (*map)[RandomAlgorithmToString(value)] = value;
4009       }
4010     }
4011     return map;
4012   }();
4013   auto found = map->find(absl::AsciiStrToLower(name));
4014   if (found == map->end()) {
4015     return InvalidArgument("Unknown algorithm");
4016   }
4017   return found->second;
4018 }
4019 
StringToRandomDistribution(const string & name)4020 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
4021   static std::unordered_map<string, RandomDistribution>* map = [] {
4022     static auto* map = new std::unordered_map<string, RandomDistribution>;
4023     for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
4024       if (RandomDistribution_IsValid(i)) {
4025         auto value = static_cast<RandomDistribution>(i);
4026         (*map)[RandomDistributionToString(value)] = value;
4027       }
4028     }
4029     return map;
4030   }();
4031   auto found = map->find(absl::AsciiStrToLower(name));
4032   if (found == map->end()) {
4033     return InvalidArgument("Unknown distribution");
4034   }
4035   return found->second;
4036 }
4037 
StringToPrecision(const string & name)4038 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
4039   static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
4040     static auto* map =
4041         new std::unordered_map<string, PrecisionConfig::Precision>;
4042     for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
4043       if (PrecisionConfig::Precision_IsValid(i)) {
4044         auto value = static_cast<PrecisionConfig::Precision>(i);
4045         (*map)[PrecisionToString(value)] = value;
4046       }
4047     }
4048     return map;
4049   }();
4050   auto found = map->find(absl::AsciiStrToLower(name));
4051   if (found == map->end()) {
4052     return InvalidArgument("Unknown distribution");
4053   }
4054   return found->second;
4055 }
4056 
StringToCustomCallSchedule(absl::string_view name)4057 StatusOr<CustomCallSchedule> StringToCustomCallSchedule(
4058     absl::string_view name) {
4059   static const absl::flat_hash_map<string, CustomCallSchedule>* map = [] {
4060     static auto* map = new absl::flat_hash_map<string, CustomCallSchedule>;
4061     for (int i = 0; i < CustomCallSchedule_ARRAYSIZE; i++) {
4062       if (CustomCallSchedule_IsValid(i)) {
4063         auto value = static_cast<CustomCallSchedule>(i);
4064         (*map)[CustomCallScheduleToString(value)] = value;
4065       }
4066     }
4067     return map;
4068   }();
4069   auto found = map->find(absl::AsciiStrToLower(name));
4070   if (found == map->end()) {
4071     return InvalidArgument("Unknown schedule");
4072   }
4073   return found->second;
4074 }
4075 
StringToCustomCallApiVersion(absl::string_view name)4076 StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
4077     absl::string_view name) {
4078   static const absl::flat_hash_map<string, CustomCallApiVersion>* map = [] {
4079     static auto* map = new absl::flat_hash_map<string, CustomCallApiVersion>;
4080     for (int i = 0; i < CustomCallApiVersion_ARRAYSIZE; i++) {
4081       if (CustomCallApiVersion_IsValid(i)) {
4082         auto value = static_cast<CustomCallApiVersion>(i);
4083         (*map)[CustomCallApiVersionToString(value)] = value;
4084       }
4085     }
4086     return map;
4087   }();
4088   auto found = map->find(absl::AsciiStrToLower(name));
4089   if (found == map->end()) {
4090     return InvalidArgument("Unknown API version");
4091   }
4092   return found->second;
4093 }
4094 
operator <<(std::ostream & os,HloInstruction::FusionKind kind)4095 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
4096   return os << ToString(kind);
4097 }
4098 
operator ()(const HloInstruction * const & lhs,const HloInstruction * const & rhs) const4099 bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
4100                                   const HloInstruction* const& rhs) const {
4101   if (rhs == nullptr) {
4102     // Nothing compares less than nullptr.
4103     return false;
4104   }
4105   if (lhs == nullptr) {
4106     return true;
4107   }
4108   auto lhs_module = lhs->GetModule();
4109   auto rhs_module = rhs->GetModule();
4110   CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
4111         (lhs_module != nullptr && rhs_module != nullptr));
4112   if (lhs_module != nullptr &&
4113       lhs_module->unique_id() != rhs_module->unique_id()) {
4114     return lhs_module->unique_id() < rhs_module->unique_id();
4115   }
4116   return lhs->unique_id() < rhs->unique_id();
4117 }
4118 
GetBackendConfigInternal(tensorflow::protobuf::Message * proto) const4119 Status HloInstruction::GetBackendConfigInternal(
4120     tensorflow::protobuf::Message* proto) const {
4121   proto->Clear();
4122 
4123   // Empty string does not parse as valid JSON, but it's a valid backend config,
4124   // corresponding to the empty proto.
4125   if (backend_config_.empty()) {
4126     return Status::OK();
4127   }
4128   return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
4129 }
4130 
set_backend_config(const tensorflow::protobuf::Message & proto)4131 Status HloInstruction::set_backend_config(
4132     const tensorflow::protobuf::Message& proto) {
4133   TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
4134   return Status::OK();
4135 }
4136 
BackendConfigToRawString(const tensorflow::protobuf::Message & proto)4137 /* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
4138     const tensorflow::protobuf::Message& proto) {
4139   string ret;
4140   // Pass ignore_accuracy_loss = true because estimated_cycles field can be
4141   // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
4142   // INT64_MAX, JsonFormat will return an error status, although there is no
4143   // accuracy loss for int64.
4144   TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
4145       proto, &ret, /*ignore_accuracy_loss=*/true));
4146   return ret;
4147 }
4148 
precision_config() const4149 const PrecisionConfig& HloInstruction::precision_config() const {
4150   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4151     return convolution->precision_config();
4152   }
4153   if (auto* dot = DynCast<HloDotInstruction>(this)) {
4154     return dot->precision_config();
4155   }
4156 
4157   if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
4158     return custom_call->precision_config();
4159   }
4160   LOG(FATAL) << "Unimplemented method.";
4161 }
4162 
mutable_precision_config()4163 PrecisionConfig* HloInstruction::mutable_precision_config() {
4164   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4165     return convolution->mutable_precision_config();
4166   }
4167   if (auto* dot = DynCast<HloDotInstruction>(this)) {
4168     return dot->mutable_precision_config();
4169   }
4170   LOG(FATAL) << "Unimplemented method.";
4171 }
4172 
GetModule() const4173 HloModule* HloInstruction::GetModule() const {
4174   if (parent_) {
4175     return parent_->parent();
4176   }
4177   return nullptr;
4178 }
4179 
UniquifyName(NameUniquer * name_uniquer)4180 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
4181   string parent_str = parent() == nullptr ? "noparent" : parent()->name();
4182   name_ = name_uniquer->GetUniqueName(name_);
4183 }
4184 
set_outer_dimension_partitions(const std::vector<int64> & outer_dimension_partitions)4185 void HloInstruction::set_outer_dimension_partitions(
4186     const std::vector<int64>& outer_dimension_partitions) {
4187   outer_dimension_partitions_ = outer_dimension_partitions;
4188 }
4189 
4190 // TODO(b/80131774): Remove these temporary methods after transition.
feature_index() const4191 int64 HloInstruction::feature_index() const {
4192   return Cast<HloBatchNormInstruction>(this)->feature_index();
4193 }
4194 
epsilon() const4195 float HloInstruction::epsilon() const {
4196   return Cast<HloBatchNormInstruction>(this)->epsilon();
4197 }
4198 
fft_type() const4199 FftType HloInstruction::fft_type() const {
4200   return Cast<HloFftInstruction>(this)->fft_type();
4201 }
4202 
fft_length() const4203 const std::vector<int64>& HloInstruction::fft_length() const {
4204   return Cast<HloFftInstruction>(this)->fft_length();
4205 }
4206 
concatenate_dimension() const4207 int64 HloInstruction::concatenate_dimension() const {
4208   return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
4209 }
4210 
dimension() const4211 int64 HloInstruction::dimension() const {
4212   if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
4213     return set_size->dimension();
4214   }
4215   return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
4216 }
4217 
inferred_dimension() const4218 int64 HloInstruction::inferred_dimension() const {
4219   return Cast<HloReshapeInstruction>(this)->inferred_dimension();
4220 }
4221 
IsRank2Transpose() const4222 bool HloInstruction::IsRank2Transpose() const {
4223   auto transpose = DynCast<HloTransposeInstruction>(this);
4224   return transpose != nullptr && transpose->IsRank2Transpose();
4225 }
4226 
slice_starts(int64_t dimension) const4227 int64 HloInstruction::slice_starts(int64_t dimension) const {
4228   return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
4229 }
4230 
slice_starts() const4231 const std::vector<int64>& HloInstruction::slice_starts() const {
4232   return Cast<HloSliceInstruction>(this)->slice_starts();
4233 }
4234 
mutable_slice_starts()4235 std::vector<int64>* HloInstruction::mutable_slice_starts() {
4236   return Cast<HloSliceInstruction>(this)->mutable_slice_starts();
4237 }
4238 
slice_limits(int64_t dimension) const4239 int64 HloInstruction::slice_limits(int64_t dimension) const {
4240   return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
4241 }
4242 
slice_limits() const4243 const std::vector<int64>& HloInstruction::slice_limits() const {
4244   return Cast<HloSliceInstruction>(this)->slice_limits();
4245 }
4246 
mutable_slice_limits()4247 std::vector<int64>* HloInstruction::mutable_slice_limits() {
4248   return Cast<HloSliceInstruction>(this)->mutable_slice_limits();
4249 }
4250 
slice_strides(int64_t dimension) const4251 int64 HloInstruction::slice_strides(int64_t dimension) const {
4252   return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
4253 }
4254 
slice_strides() const4255 const std::vector<int64>& HloInstruction::slice_strides() const {
4256   return Cast<HloSliceInstruction>(this)->slice_strides();
4257 }
4258 
mutable_slice_strides()4259 std::vector<int64>* HloInstruction::mutable_slice_strides() {
4260   return Cast<HloSliceInstruction>(this)->mutable_slice_strides();
4261 }
4262 
literal() const4263 const Literal& HloInstruction::literal() const {
4264   return Cast<HloConstantInstruction>(this)->literal();
4265 }
4266 
IsConstant() const4267 bool HloInstruction::IsConstant() const {
4268   return DynCast<HloConstantInstruction>(this) != nullptr;
4269 }
4270 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)4271 void HloInstruction::RelayoutConstant(const Layout& new_layout,
4272                                       const ShapeIndex& shape_index) {
4273   Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
4274 }
4275 
TracingTag() const4276 string HloInstruction::TracingTag() const {
4277   return Cast<HloTraceInstruction>(this)->TracingTag();
4278 }
4279 
AddFusionOperand(HloInstruction * new_operand)4280 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
4281   return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
4282 }
4283 
4284 // Delegates to HloFusionInstruction::MergeFusionInstruction.
MergeFusionInstruction(HloInstruction * instruction_to_merge)4285 void HloInstruction::MergeFusionInstruction(
4286     HloInstruction* instruction_to_merge) {
4287   return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
4288       Cast<HloFusionInstruction>(instruction_to_merge));
4289 }
4290 
4291 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
MergeFusionInstructionIntoMultiOutput(HloInstruction * instruction_to_merge)4292 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
4293     HloInstruction* instruction_to_merge) {
4294   return Cast<HloFusionInstruction>(this)
4295       ->MergeFusionInstructionIntoMultiOutput(
4296           Cast<HloFusionInstruction>(instruction_to_merge));
4297 }
4298 
FuseInstruction(HloInstruction * instruction_to_fuse)4299 HloInstruction* HloInstruction::FuseInstruction(
4300     HloInstruction* instruction_to_fuse) {
4301   return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
4302 }
4303 
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)4304 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
4305     HloInstruction* instruction_to_fuse) {
4306   return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
4307       instruction_to_fuse);
4308 }
4309 
fused_instructions_computation() const4310 HloComputation* HloInstruction::fused_instructions_computation() const {
4311   return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
4312 }
4313 
fused_expression_root() const4314 HloInstruction* HloInstruction::fused_expression_root() const {
4315   return Cast<HloFusionInstruction>(this)->fused_expression_root();
4316 }
4317 
4318 const tensorflow::gtl::iterator_range<UnwrappingIterator<
4319     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const4320 HloInstruction::fused_instructions() const {
4321   return Cast<HloFusionInstruction>(this)->fused_instructions();
4322 }
4323 
4324 const tensorflow::gtl::iterator_range<
4325     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()4326 HloInstruction::fused_instructions() {
4327   return Cast<HloFusionInstruction>(this)->fused_instructions();
4328 }
4329 
fused_instruction_count() const4330 int64 HloInstruction::fused_instruction_count() const {
4331   return Cast<HloFusionInstruction>(this)->fused_instruction_count();
4332 }
4333 
fused_parameter(int64_t parameter_number) const4334 HloInstruction* HloInstruction::fused_parameter(
4335     int64_t parameter_number) const {
4336   return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
4337 }
4338 
fused_parameters() const4339 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
4340   return Cast<HloFusionInstruction>(this)->fused_parameters();
4341 }
4342 
IsMultiOutputFusion() const4343 const bool HloInstruction::IsMultiOutputFusion() const {
4344   const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
4345   return fusion != nullptr && fusion->IsMultiOutputFusion();
4346 }
4347 
fusion_kind() const4348 HloInstruction::FusionKind HloInstruction::fusion_kind() const {
4349   return Cast<HloFusionInstruction>(this)->fusion_kind();
4350 }
4351 
set_fusion_kind(FusionKind kind)4352 void HloInstruction::set_fusion_kind(FusionKind kind) {
4353   return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
4354 }
4355 
random_distribution() const4356 RandomDistribution HloInstruction::random_distribution() const {
4357   return Cast<HloRngInstruction>(this)->random_distribution();
4358 }
4359 
parameter_number() const4360 int64 HloInstruction::parameter_number() const {
4361   return Cast<HloParameterInstruction>(this)->parameter_number();
4362 }
4363 
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)4364 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4365     absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
4366   return Cast<HloParameterInstruction>(this)
4367       ->set_parameter_replicated_at_leaf_buffers(
4368           parameter_replicated_at_leaf_buffers);
4369 }
4370 
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)4371 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4372     const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
4373   return Cast<HloParameterInstruction>(this)
4374       ->set_parameter_replicated_at_leaf_buffers(
4375           parameter_replicated_at_leaf_buffers);
4376 }
4377 
4378 const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers() const4379 HloInstruction::parameter_replicated_at_leaf_buffers() const {
4380   return Cast<HloParameterInstruction>(this)
4381       ->parameter_replicated_at_leaf_buffers();
4382 }
4383 
tuple_index() const4384 int64 HloInstruction::tuple_index() const {
4385   return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
4386 }
4387 
set_tuple_index(int64_t new_tuple_index)4388 void HloInstruction::set_tuple_index(int64_t new_tuple_index) {
4389   return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
4390       new_tuple_index);
4391 }
4392 
exponent_bits() const4393 int32 HloInstruction::exponent_bits() const {
4394   return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
4395 }
4396 
mantissa_bits() const4397 int32 HloInstruction::mantissa_bits() const {
4398   return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
4399 }
4400 
infeed_config() const4401 string HloInstruction::infeed_config() const {
4402   return Cast<HloInfeedInstruction>(this)->infeed_config();
4403 }
4404 
set_infeed_config(const string & config)4405 void HloInstruction::set_infeed_config(const string& config) {
4406   return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
4407 }
4408 
outfeed_shape() const4409 const Shape& HloInstruction::outfeed_shape() const {
4410   return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
4411 }
4412 
mutable_outfeed_shape()4413 Shape* HloInstruction::mutable_outfeed_shape() {
4414   return Cast<HloOutfeedInstruction>(this)->mutable_outfeed_shape();
4415 }
4416 
outfeed_config() const4417 const string& HloInstruction::outfeed_config() const {
4418   return Cast<HloOutfeedInstruction>(this)->outfeed_config();
4419 }
4420 
set_outfeed_config(const string & config)4421 void HloInstruction::set_outfeed_config(const string& config) {
4422   return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
4423 }
4424 
replica_groups() const4425 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
4426   return Cast<HloCollectiveInstruction>(this)->replica_groups();
4427 }
4428 
4429 const std::vector<std::pair<int64, int64>>&
source_target_pairs() const4430 HloInstruction::source_target_pairs() const {
4431   return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
4432 }
4433 
channel_id() const4434 absl::optional<int64> HloInstruction::channel_id() const {
4435   return Cast<HloChannelInstruction>(this)->channel_id();
4436 }
4437 
set_channel_id(const absl::optional<int64> & channel_id)4438 void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
4439   return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
4440 }
4441 
4442 const ConvolutionDimensionNumbers&
convolution_dimension_numbers() const4443 HloInstruction::convolution_dimension_numbers() const {
4444   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4445     return convolution->convolution_dimension_numbers();
4446   }
4447   if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4448     return custom_call->convolution_dimension_numbers();
4449   }
4450   LOG(FATAL) << "Unimplemented method.";
4451 }
4452 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)4453 void HloInstruction::set_convolution_dimension_numbers(
4454     const ConvolutionDimensionNumbers& dnums) {
4455   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4456     convolution->set_convolution_dimension_numbers(dnums);
4457   } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4458     custom_call->set_convolution_dimension_numbers(dnums);
4459   } else {
4460     LOG(FATAL) << "Unimplemented method.";
4461   }
4462 }
4463 
feature_group_count() const4464 int64 HloInstruction::feature_group_count() const {
4465   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4466     return convolution->feature_group_count();
4467   }
4468   return Cast<HloCustomCallInstruction>(this)->feature_group_count();
4469 }
4470 
set_feature_group_count(int64_t feature_group_count)4471 void HloInstruction::set_feature_group_count(int64_t feature_group_count) {
4472   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4473     return convolution->set_feature_group_count(feature_group_count);
4474   }
4475   Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
4476       feature_group_count);
4477 }
4478 
batch_group_count() const4479 int64 HloInstruction::batch_group_count() const {
4480   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4481     return convolution->batch_group_count();
4482   }
4483   return Cast<HloCustomCallInstruction>(this)->batch_group_count();
4484 }
4485 
set_batch_group_count(int64_t batch_group_count)4486 void HloInstruction::set_batch_group_count(int64_t batch_group_count) {
4487   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4488     return convolution->set_batch_group_count(batch_group_count);
4489   }
4490   Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
4491       batch_group_count);
4492 }
4493 
select() const4494 HloComputation* HloInstruction::select() const {
4495   return Cast<HloSelectAndScatterInstruction>(this)->select();
4496 }
4497 
scatter() const4498 HloComputation* HloInstruction::scatter() const {
4499   return Cast<HloSelectAndScatterInstruction>(this)->scatter();
4500 }
4501 
set_select(HloComputation * computation)4502 void HloInstruction::set_select(HloComputation* computation) {
4503   return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
4504 }
4505 
set_scatter(HloComputation * computation)4506 void HloInstruction::set_scatter(HloComputation* computation) {
4507   return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
4508 }
4509 
custom_call_target() const4510 const string& HloInstruction::custom_call_target() const {
4511   return Cast<HloCustomCallInstruction>(this)->custom_call_target();
4512 }
4513 
padding_config() const4514 const PaddingConfig& HloInstruction::padding_config() const {
4515   return Cast<HloPadInstruction>(this)->padding_config();
4516 }
4517 
padding_type() const4518 PaddingType HloInstruction::padding_type() const {
4519   return Cast<HloCustomCallInstruction>(this)->padding_type();
4520 }
4521 
mutable_padding_config()4522 PaddingConfig* HloInstruction::mutable_padding_config() {
4523   return Cast<HloPadInstruction>(this)->mutable_padding_config();
4524 }
4525 
slice_sizes(int64_t dimension) const4526 int64 HloInstruction::slice_sizes(int64_t dimension) const {
4527   return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
4528 }
4529 
dynamic_slice_sizes() const4530 const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
4531   return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
4532 }
4533 
4534 const std::vector<std::vector<int64_t>>&
dynamic_slice_sizes_list() const4535 HloInstruction::dynamic_slice_sizes_list() const {
4536   return Cast<HloCollectivePermuteInstruction>(this)
4537       ->dynamic_slice_sizes_list();
4538 }
4539 
gather_dimension_numbers() const4540 const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
4541   return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
4542 }
4543 
gather_slice_sizes() const4544 absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
4545   return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
4546 }
4547 
scatter_dimension_numbers() const4548 const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
4549     const {
4550   return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
4551 }
4552 
dot_dimension_numbers() const4553 const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
4554   return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
4555 }
4556 
operand_side_metadata() const4557 const DomainMetadata& HloInstruction::operand_side_metadata() const {
4558   return Cast<HloDomainInstruction>(this)->operand_side_metadata();
4559 }
4560 
user_side_metadata() const4561 const DomainMetadata& HloInstruction::user_side_metadata() const {
4562   return Cast<HloDomainInstruction>(this)->user_side_metadata();
4563 }
4564 
is_cross_program_prefetch() const4565 bool HloInstruction::is_cross_program_prefetch() const {
4566   return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
4567 }
4568 
comparison_direction() const4569 ComparisonDirection HloInstruction::comparison_direction() const {
4570   return Cast<HloCompareInstruction>(this)->direction();
4571 }
4572 
triangular_solve_options() const4573 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
4574   return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
4575 }
4576 
cholesky_options() const4577 const CholeskyOptions& HloInstruction::cholesky_options() const {
4578   return Cast<HloCholeskyInstruction>(this)->cholesky_options();
4579 }
4580 
4581 }  // namespace xla
4582