/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

#include <algorithm>
#include <ostream>
#include <set>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using absl::CEscape;
using absl::StrAppend;
using absl::StrCat;
using absl::StrJoin;

/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
    const HloInstructionProto& proto,
    const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  TF_RET_CHECK(!proto.opcode().empty());
  HloOpcode opcode;
  auto opcode_or = StringToHloOpcode(proto.opcode());
  absl::optional<ComparisonDirection> comparison_direction;
  if (opcode_or.ok()) {
    opcode = opcode_or.ConsumeValueOrDie();
  } else {
    // Unknown opcode. Try auto-upgrading deprecated "less-than",
    // "greater-than", etc opcodes, which are now rolled into the kCompare
    // opcode.
    if (proto.opcode() == "equal-to") {
      comparison_direction = ComparisonDirection::kEq;
    } else if (proto.opcode() == "not-equal-to") {
      comparison_direction = ComparisonDirection::kNe;
    } else if (proto.opcode() == "greater-than-or-equal-to") {
      comparison_direction = ComparisonDirection::kGe;
    } else if (proto.opcode() == "greater-than") {
      comparison_direction = ComparisonDirection::kGt;
    } else if (proto.opcode() == "less-than-or-equal-to") {
      comparison_direction = ComparisonDirection::kLe;
    } else if (proto.opcode() == "less-than") {
      comparison_direction = ComparisonDirection::kLt;
    }
    if (comparison_direction) {
      opcode = HloOpcode::kCompare;
    } else {
      return InvalidArgument("Unknown opcode: %s", proto.opcode());
    }
  }

  TF_RET_CHECK(proto.has_shape());

  std::unique_ptr<HloInstruction> instruction;
  const auto operands = [&instruction_map, &proto](int index) {
    return instruction_map.at(proto.operand_ids(index));
  };
  const auto all_operands = [&instruction_map, &proto]() {
    std::vector<HloInstruction*> result(proto.operand_ids_size());
    std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
                   result.begin(), [&instruction_map](int64 operand_id) {
                     return instruction_map.at(operand_id);
                   });
    return result;
  };
  const auto computations = [&computation_map, &proto](int index) {
    return computation_map.at(proto.called_computation_ids(index));
  };
  const auto all_computations = [&computation_map, &proto]() {
    std::vector<HloComputation*> result(proto.called_computation_ids_size());
    std::transform(proto.called_computation_ids().begin(),
                   proto.called_computation_ids().end(), result.begin(),
                   [&computation_map](int64 computation_id) {
                     return computation_map.at(computation_id);
                   });
    return result;
  };

  TF_RET_CHECK(
      absl::c_all_of(proto.operand_ids(),
                     [&](int64 id) { return instruction_map.contains(id); }))
      << proto.name() << " instruction contains invalid operand id(s)";

  TF_RET_CHECK(
      absl::c_all_of(proto.called_computation_ids(),
                     [&](int64 id) { return computation_map.contains(id); }))
      << proto.name() << " instruction references invalid computation id(s)";

  Shape shape(proto.shape());
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  absl::optional<int> arity = HloOpcodeArity(opcode);
  if (arity) {
    TF_RET_CHECK(proto.operand_ids_size() == *arity)
        << proto.opcode() << " instruction should have " << *arity
        << " operands but sees " << proto.operand_ids_size();
  }

  switch (opcode) {
    // Ops migrated to subclasses.
    case HloOpcode::kBatchNormTraining:
      instruction =
          CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
                                  proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormInference:
      instruction = CreateBatchNormInference(
          shape, operands(0), operands(1), operands(2), operands(3),
          operands(4), proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormGrad:
      instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
                                        operands(2), operands(3), operands(4),
                                        proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kFft: {
      std::vector<int64> fft_length(proto.fft_length().begin(),
                                    proto.fft_length().end());
      instruction = CreateFft(shape, operands(0), proto.fft_type(),
                              absl::Span<const int64>(fft_length));
      break;
    }
    case HloOpcode::kCompare: {
      // Auto-upgraded from deprecated opcode skips the following.
      if (!comparison_direction) {
        TF_ASSIGN_OR_RETURN(
            comparison_direction,
            StringToComparisonDirection(proto.comparison_direction()));
      }
      instruction =
          CreateCompare(shape, operands(0), operands(1), *comparison_direction);
      break;
    }
    case HloOpcode::kTriangularSolve: {
      instruction = CreateTriangularSolve(shape, operands(0), operands(1),
                                          proto.triangular_solve_options());
      break;
    }
    case HloOpcode::kCholesky: {
      instruction =
          CreateCholesky(shape, operands(0), proto.cholesky_options());
      break;
    }
    case HloOpcode::kSend:
      instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
                               proto.is_host_transfer());
      break;
    case HloOpcode::kSendDone:
      instruction = CreateSendDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kRecv:
      instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
                               proto.channel_id(), proto.is_host_transfer());
      break;
    case HloOpcode::kRecvDone:
      instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kReverse:
      instruction = CreateReverse(shape, operands(0),
                                  std::vector<int64>(proto.dimensions().begin(),
                                                     proto.dimensions().end()));
      break;
    case HloOpcode::kConcatenate:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Concatenate instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction =
          CreateConcatenate(shape, all_operands(), proto.dimensions(0));
      break;
    case HloOpcode::kConditional: {
      TF_RET_CHECK(proto.called_computation_ids_size() > 0)
          << "conditional should have at least 1 called computation";
      if (operands(0)->shape().element_type() == PRED) {
        TF_RET_CHECK(proto.called_computation_ids_size() == 2)
            << "conditional should have exactly 2 called computations but got "
            << proto.called_computation_ids_size();
      }
      TF_RET_CHECK(proto.operand_ids_size() ==
                   proto.called_computation_ids_size() + 1)
          << "conditional should have one branch_index operand plus one "
             "operand per called computation but got "
          << proto.operand_ids_size() << " operands for "
          << proto.called_computation_ids_size() << " branch computations";
      auto cond_operands = all_operands();
      instruction =
          CreateConditional(shape, cond_operands[0], all_computations(),
                            absl::MakeSpan(cond_operands).subspan(1));
      break;
    }
    case HloOpcode::kReduce:
      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
          << "Reduce instruction should have an even number of operands but "
             "sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Reduce instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      {
        const auto reduce_operands = all_operands();
        auto inputs = absl::MakeSpan(reduce_operands)
                          .subspan(0, reduce_operands.size() / 2);
        auto init_values =
            absl::MakeSpan(reduce_operands)
                .subspan(reduce_operands.size() / 2, reduce_operands.size());
        instruction =
            CreateReduce(shape, inputs, init_values,
                         std::vector<int64>(proto.dimensions().begin(),
                                            proto.dimensions().end()),
                         computations(0));
      }
      break;
    case HloOpcode::kSort: {
      TF_RET_CHECK(proto.operand_ids_size() >= 1)
          << "Sort instruction should have at least 1 operand but has "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.dimensions().size() == 1)
          << "Sort instruction should have 1 dimension";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Sort instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto sort_operands = all_operands();
      instruction = CreateSort(shape, proto.dimensions(0), sort_operands,
                               computations(0), proto.is_stable());
      break;
    }
    case HloOpcode::kTranspose:
      instruction =
          CreateTranspose(shape, operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kBroadcast:
      instruction =
          CreateBroadcast(shape, operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kMap:
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Map instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateMap(shape, all_operands(), computations(0));
      break;
    case HloOpcode::kSlice: {
      std::vector<int64> slice_starts, slice_limits, slice_strides;
      for (const HloInstructionProto::SliceDimensions& slice_dimensions :
           proto.slice_dimensions()) {
        slice_starts.push_back(slice_dimensions.start());
        slice_limits.push_back(slice_dimensions.limit());
        slice_strides.push_back(slice_dimensions.stride());
      }
      instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
                                slice_strides);
      break;
    }
    case HloOpcode::kConstant: {
      // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
      if (proto.has_literal()) {
        TF_ASSIGN_OR_RETURN(
            auto literal,
            Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
        instruction = CreateConstant(std::move(literal));
        // Literal's shape may have no/different tiling info.
        TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
            instruction->shape(), shape));
        *instruction->mutable_shape() = shape;
      } else {
        instruction = absl::make_unique<HloConstantInstruction>(shape);
      }
      break;
    }
    case HloOpcode::kTrace: {
      TF_RET_CHECK(proto.has_literal());
      TF_ASSIGN_OR_RETURN(
          auto literal,
          Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
      instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
      break;
    }
    case HloOpcode::kFusion: {
      // In the proto, fused computations are held exclusively within the
      // HloInstructionProto and do not appear as an HloComputationProto within
      // the HloModuleProto.
      TF_RET_CHECK(!proto.fusion_kind().empty());
      TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
                          StringToFusionKind(proto.fusion_kind()));

      // Find the fused computation and set its fusion instruction.
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Fusion instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      const int64 fusion_id = proto.called_computation_ids(0);
      auto* fused_computation =
          tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
      TF_RET_CHECK(fused_computation != nullptr)
          << "No fusion computation with id " << fusion_id;
      instruction =
          CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
      break;
    }
    case HloOpcode::kRng:
      instruction = CreateRng(shape, proto.distribution(), all_operands());
      break;
    case HloOpcode::kRngGetAndUpdateState:
      instruction = CreateRngGetAndUpdateState(shape, proto.delta());
      break;
    case HloOpcode::kParameter:
      instruction =
          CreateParameter(proto.parameter_number(), shape, proto.name());
      if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
        instruction->set_parameter_replicated_at_leaf_buffers(
            proto.parameter_replication().replicated_at_leaf_buffers());
      }
      break;
    case HloOpcode::kGetTupleElement:
      instruction =
          CreateGetTupleElement(shape, operands(0), proto.tuple_index());
      break;
    case HloOpcode::kReducePrecision:
      instruction = CreateReducePrecision(
          shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
      break;
    case HloOpcode::kInfeed: {
      TF_RET_CHECK(shape.IsTuple() &&
                   (ShapeUtil::TupleElementCount(shape) == 2))
          << "Infeed should have a tuple shape with 2 elements, but has: "
          << shape;
      const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
      instruction =
          CreateInfeed(data_shape, operands(0), proto.infeed_config());
    } break;
    case HloOpcode::kOutfeed: {
      Shape outfeed_shape(proto.outfeed_shape());
      TF_RETURN_IF_ERROR(
          ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
      instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
                                  proto.outfeed_config());
      break;
    }
    case HloOpcode::kAllReduce: {
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "AllReduce should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
          << "AllReduce cannot have both channel_id() and all_reduce_id()";
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }
      if (proto.all_reduce_id() > 0) {
        channel_id = proto.all_reduce_id();
      }
      instruction = CreateAllReduce(
          shape, all_operands(), computations(0),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          /*constrain_layout=*/proto.constrain_layout(),
          /*channel_id=*/channel_id);
      break;
    }
    case HloOpcode::kAllToAll: {
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }
      absl::optional<int64> split_dimension;
      if (proto.dimensions_size() > 0) {
        TF_RET_CHECK(proto.dimensions_size() == 1)
            << "AllToAll cannot have more than 1 dimension (split dimension)";
        TF_RET_CHECK(all_operands().size() == 1)
            << "AllToAll must have a single operand when the split dimension "
               "is specified";
        split_dimension = proto.dimensions(0);
      }
      instruction = CreateAllToAll(
          shape, all_operands(),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          /*channel_id=*/channel_id, split_dimension);
      break;
    }
    case HloOpcode::kCollectivePermute: {
      std::vector<std::pair<int64, int64>> source_target_pairs(
          proto.source_target_pairs_size());
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }
      for (int i = 0; i < source_target_pairs.size(); i++) {
        source_target_pairs[i].first = proto.source_target_pairs(i).source();
        source_target_pairs[i].second = proto.source_target_pairs(i).target();
      }
      instruction = CreateCollectivePermute(shape, operands(0),
                                            source_target_pairs, channel_id);
      break;
    }
    case HloOpcode::kReplicaId: {
      instruction = CreateReplicaId();
      break;
    }
    case HloOpcode::kPartitionId: {
      instruction = CreatePartitionId();
      break;
    }
    case HloOpcode::kConvolution: {
      TF_RET_CHECK(proto.has_window());
      TF_RET_CHECK(proto.has_convolution_dimension_numbers());
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = CreateConvolve(
          shape, operands(0), operands(1),
          std::max<int64>(proto.feature_group_count(), 1),
          std::max<int64>(proto.batch_group_count(), 1), proto.window(),
          proto.convolution_dimension_numbers(), precision_config);
      break;
    }
    case HloOpcode::kReduceWindow:
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "ReduceWindow should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateReduceWindow(shape, operands(0), operands(1),
                                       proto.window(), computations(0));
      break;
    case HloOpcode::kSelectAndScatter:
      TF_RET_CHECK(proto.called_computation_ids_size() == 2)
          << "SelectAndScatter should have 2 called computations but sees "
          << proto.called_computation_ids_size();
      instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
                                           proto.window(), operands(1),
                                           operands(2), computations(1));
      break;
    case HloOpcode::kCustomCall: {
      if (proto.constrain_layout()) {
        // A proto RepeatedPtrField cannot be converted to a Span (it is a
        // vector of pointers essentially) so create a vector of shapes to pass
        // in.
        std::vector<Shape> operand_shapes;
        for (const ShapeProto& shape_proto :
             proto.operand_shapes_with_layout()) {
          operand_shapes.emplace_back(shape_proto);
        }
        instruction =
            CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                             operand_shapes, proto.backend_config());
      } else {
        instruction =
            CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                             proto.backend_config());
      }
      auto custom_call_instr =
          Cast<HloCustomCallInstruction>(instruction.get());
      if (proto.has_window()) {
        custom_call_instr->set_window(proto.window());
      }
      if (proto.has_convolution_dimension_numbers()) {
        custom_call_instr->set_convolution_dimension_numbers(
            proto.convolution_dimension_numbers());
      }
      custom_call_instr->set_feature_group_count(
          std::max(static_cast<int64>(proto.feature_group_count()), int64{1}));
      custom_call_instr->set_batch_group_count(
          std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
      custom_call_instr->set_custom_call_has_side_effect(
          proto.custom_call_has_side_effect());
      break;
    }
    case HloOpcode::kPad:
      TF_RET_CHECK(proto.has_padding_config());
      instruction =
          CreatePad(shape, operands(0), operands(1), proto.padding_config());
      break;
    case HloOpcode::kDynamicSlice: {
      std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
      absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
      TF_RET_CHECK(proto.operand_ids_size() >= 1)
          << "DynamicSlice instruction should have at least 1 operand but "
             "sees "
          << proto.operand_ids_size();
      // TODO(b/118437727): Old form, make the check unconditional.
      if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
        auto expected_operands = 1 + operands(0)->shape().rank();
        TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
            << "DynamicSlice instruction should have " << expected_operands
            << " operands, but has " << proto.operand_ids_size();
      }
      const auto& operand_vector = all_operands();
      instruction = CreateDynamicSlice(
          shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
          slice_sizes);
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      TF_RET_CHECK(proto.operand_ids_size() >= 2)
          << "DynamicUpdateSlice instruction should have at least 2 operands "
             "but sees "
          << proto.operand_ids_size();
      // TODO(b/118437727): Old form, make the check unconditional.
      if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
        auto expected_operands = 2 + operands(0)->shape().rank();
        TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
            << "DynamicUpdateSlice instruction should have "
            << expected_operands << " operands, but has "
            << proto.operand_ids_size();
      }
      const auto& operand_vector = all_operands();
      instruction =
          CreateDynamicUpdateSlice(shape, operands(0), operands(1),
                                   absl::MakeSpan(operand_vector).subspan(2));
      break;
    }
    case HloOpcode::kGather: {
      TF_RET_CHECK(proto.has_gather_dimension_numbers())
          << "Gather instruction should have GatherDimensionNumbers set.";
      auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
          proto.gather_dimension_numbers());
      std::vector<int64> gather_slice_sizes;
      for (int64 bound : proto.gather_slice_sizes()) {
        gather_slice_sizes.push_back(bound);
      }
      instruction = CreateGather(shape, operands(0), operands(1),
                                 *gather_dimension_numbers, gather_slice_sizes,
                                 proto.indices_are_sorted());
      break;
    }
    case HloOpcode::kScatter: {
      TF_RET_CHECK(proto.has_scatter_dimension_numbers())
          << "Scatter instruction should have ScatterDimensionNumbers set.";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Scatter instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto scatter_dimension_numbers =
          absl::make_unique<ScatterDimensionNumbers>(
              proto.scatter_dimension_numbers());
      instruction =
          CreateScatter(shape, operands(0), operands(1), operands(2),
                        computations(0), *scatter_dimension_numbers,
                        proto.indices_are_sorted(), proto.unique_indices());
      break;
    }
    case HloOpcode::kIota:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Iota instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction = CreateIota(shape, proto.dimensions(0));
      break;
    case HloOpcode::kDot: {
      TF_RET_CHECK(proto.has_dot_dimension_numbers())
          << "Dot instruction should have dot_dimension_numbers.";
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = absl::make_unique<HloDotInstruction>(
          shape, operands(0), operands(1), proto.dot_dimension_numbers(),
          precision_config);
      break;
    }
    case HloOpcode::kDomain: {
      std::shared_ptr<const HloSharding> entry_hlo_sharding;
      std::shared_ptr<const HloSharding> exit_hlo_sharding;
      if (proto.has_domain_entry_sharding()) {
        TF_ASSIGN_OR_RETURN(
            HloSharding sharding,
            HloSharding::FromProto(proto.domain_entry_sharding()));
        entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
      }
      if (proto.has_domain_exit_sharding()) {
        TF_ASSIGN_OR_RETURN(
            HloSharding sharding,
            HloSharding::FromProto(proto.domain_exit_sharding()));
        exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
      }
      instruction = absl::make_unique<HloDomainInstruction>(
          shape, operands(0),
          absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
          absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
      break;
    }
    case HloOpcode::kGetDimensionSize:
      TF_RET_CHECK(proto.dimensions_size() == 1);
      instruction =
          CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
      break;
    case HloOpcode::kSetDimensionSize:
      TF_RET_CHECK(proto.dimensions_size() == 1);
      instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
                                           proto.dimensions(0));
      break;
    case HloOpcode::kReshape: {
      int64 inferred_dimension = -1;
      if (!proto.dimensions().empty()) {
        inferred_dimension = proto.dimensions()[0];
      }
      TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
                   ShapeUtil::ElementsIn(shape) ==
                       ShapeUtil::ElementsIn(operands(0)->shape()))
          << "shape: " << ShapeUtil::HumanString(shape)
          << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
      instruction = CreateReshape(shape, operands(0), inferred_dimension);
      break;
    }
    default: {
      instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
      for (const int64 operand_id : proto.operand_ids()) {
        instruction->AppendOperand(instruction_map.at(operand_id));
      }
      if (instruction->opcode() != HloOpcode::kFusion) {
        if (instruction->opcode() == HloOpcode::kCall) {
          TF_RET_CHECK(proto.called_computation_ids_size() == 1)
              << "Call should have 1 called computation but has "
              << proto.called_computation_ids_size();
        }
        for (const int64 computation_id : proto.called_computation_ids()) {
          instruction->called_computations_.push_back(
              computation_map.at(computation_id));
        }
      }
      TF_RET_CHECK(!proto.has_precision_config())
          << instruction->opcode() << proto.DebugString();
      TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
      break;
    }
  }

  for (const int64 predecessor_id : proto.control_predecessor_ids()) {
    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
        << "No instruction with id " << predecessor_id;
    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
                           ->AddControlDependencyTo(instruction.get()));
  }

  TF_RET_CHECK(!proto.name().empty());
  instruction->SetAndSanitizeName(proto.name());
  instruction->metadata_ = proto.metadata();
  instruction->backend_config_ = proto.backend_config();
  instruction->outer_dimension_partitions_.assign(
      proto.outer_dimension_partitions().begin(),
      proto.outer_dimension_partitions().end());

  TF_RET_CHECK(proto.id() >= 0)
      << "Instruction with negative id: " << proto.id();
  TF_RET_CHECK(proto.id() <= INT_MAX)
      << "Instruction with id > INT_MAX: " << proto.id();
  instruction->unique_id_ = proto.id();

  if (proto.has_sharding()) {
    TF_ASSIGN_OR_RETURN(const auto& sharding,
                        HloSharding::FromProto(proto.sharding()));
    instruction->set_sharding(sharding);
  }

  if (proto.has_frontend_attributes()) {
    instruction->set_frontend_attributes(proto.frontend_attributes());
  }

  return std::move(instruction);
}
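
// Usage sketch (illustrative only; `proto`, `instruction_map`, and
// `computation_map` are assumed to be populated by the caller, e.g. while
// deserializing an HloModuleProto, with every referenced operand and called
// computation already materialized):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloInstruction> instr,
//       HloInstruction::CreateFromProto(proto, instruction_map,
//                                       computation_map,
//                                       /*prohibit_empty_literal=*/true));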

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
                                                    name);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
  return absl::make_unique<HloTraceInstruction>(tag, operand);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    Literal literal) {
  return absl::make_unique<HloConstantInstruction>(std::move(literal));
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
    const Shape& shape, int64 iota_dimension) {
  return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64 index) {
  return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
                                                          index);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    absl::Span<HloInstruction* const> parameters) {
  return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) {
  return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  if (opcode == HloOpcode::kCopy) {
    // It is impossible to copy an opaque shape, we don't know how big it is.
    CHECK(!shape.IsOpaque());
  }
  auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
  for (auto operand : operands) {
    instruction->AppendOperand(operand);
  }
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
  // Only certain opcodes are supported with CreateUnary: opcodes of unary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kClz:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      break;
    default:
      LOG(FATAL) << "Invalid unary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {operand});
}
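
// Usage sketch (illustrative only; `operand` is an assumed existing f32[4]
// instruction owned by some computation):
//
//   std::unique_ptr<HloInstruction> negate = HloInstruction::CreateUnary(
//       ShapeUtil::MakeShape(F32, {4}), HloOpcode::kNegate, operand);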

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs) {
  // Only certain opcodes are supported with CreateBinary: opcodes of binary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kDivide:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      break;
    default:
      LOG(FATAL) << "Invalid binary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs});
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs, HloInstruction* ehs) {
  // Only certain opcodes are supported with CreateTernary: opcodes of ternary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      break;
    default:
      LOG(FATAL) << "Invalid ternary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs, ehs});
}
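
// Usage sketch (illustrative only; `pred`, `on_true`, and `on_false` are
// assumed existing instructions, with `pred` of element type PRED):
//
//   std::unique_ptr<HloInstruction> select = HloInstruction::CreateTernary(
//       on_true->shape(), HloOpcode::kSelect, pred, on_true, on_false);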

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  CHECK_EQ(HloOpcode::kTuple, opcode);
  return CreateNary(shape, opcode, operands);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* map_computation) {
  return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    int64 feature_group_count, int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  return absl::make_unique<HloConvolutionInstruction>(
      shape, lhs, rhs, feature_group_count, batch_group_count, window,
      dimension_numbers, precision_config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
    const Shape& shape, HloInstruction* operand, FftType fft_type,
    absl::Span<const int64> fft_length) {
  return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
                                              fft_length);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    ComparisonDirection direction) {
  return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
                                      HloInstruction* b,
                                      const TriangularSolveOptions& options) {
  return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
    const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
  return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  return absl::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dimension_numbers, precision_config);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
                                      HloInstruction* operand,
                                      const int exponent_bits,
                                      const int mantissa_bits) {
  return absl::make_unique<HloReducePrecisionInstruction>(
      shape, operand, exponent_bits, mantissa_bits);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* reduce_computation,
    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
    const absl::optional<int64>& channel_id) {
  return absl::make_unique<HloAllReduceInstruction>(
      shape, operands, reduce_computation, replica_groups, constrain_layout,
      channel_id);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    const std::vector<ReplicaGroup>& replica_groups,
    const absl::optional<int64>& channel_id,
    const absl::optional<int64>& split_dimension) {
  return absl::make_unique<HloAllToAllInstruction>(
      shape, operands, replica_groups, channel_id, split_dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermute(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs,
    const absl::optional<int64>& channel_id) {
  return absl::make_unique<HloCollectivePermuteInstruction>(
      shape, operand, source_target_pairs, channel_id);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
  return absl::WrapUnique(
      new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreatePartitionId() {
  return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId,
                                             ShapeUtil::MakeShape(U32, {})));
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    const Shape& infeed_shape, HloInstruction* token_operand,
    const string& config) {
  return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
                                                 config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    const Shape& outfeed_shape, HloInstruction* operand,
    HloInstruction* token_operand, absl::string_view outfeed_config) {
  return absl::make_unique<HloOutfeedInstruction>(
      outfeed_shape, operand, token_operand, outfeed_config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
                                               is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto send_operand = DynCast<HloSendInstruction>(operand);
  CHECK(send_operand != nullptr)
      << "SendDone must take the context operand from Send";
  return absl::make_unique<HloSendDoneInstruction>(send_operand,
                                                   is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
                                               is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto recv_operand = DynCast<HloRecvInstruction>(operand);
  CHECK(recv_operand != nullptr)
      << "RecvDone must take the context operand from Recv";
  return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
                                                   is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
    absl::Span<HloInstruction* const> operands) {
  CHECK(!operands.empty());
  auto instruction = absl::WrapUnique(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
  for (auto operand : operands) {
    instruction->AppendOperand(operand);
  }
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
  return absl::WrapUnique(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateAddDependency(HloInstruction* data_operand,
                                    HloInstruction* token_operand) {
  auto instruction = absl::WrapUnique(
      new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
  instruction->AppendOperand(data_operand);
  instruction->AppendOperand(token_operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
    const Shape& shape, HloComputation* condition, HloComputation* body,
    HloInstruction* init) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
  instruction->AppendOperand(init);
  // Body comes before condition computation in the vector.
  instruction->called_computations_.push_back(body);
  instruction->called_computations_.push_back(condition);
  return instruction;
}
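
// Usage sketch (illustrative only; `condition` is an assumed computation
// returning a pred[] scalar and `body` maps the loop-state shape to itself):
//
//   std::unique_ptr<HloInstruction> loop =
//       HloInstruction::CreateWhile(init->shape(), condition, body, init);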

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    const Shape& shape, HloInstruction* pred,
    HloInstruction* true_computation_arg, HloComputation* true_computation,
    HloInstruction* false_computation_arg, HloComputation* false_computation) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
  instruction->AppendOperand(pred);
  instruction->AppendOperand(true_computation_arg);
  instruction->AppendOperand(false_computation_arg);
  // In called_computations_, the index of true_computation must be 0 and that
  // of false computation must be 1, as defined by kTrueComputationIndex and
  // kFalseComputationIndex.
  instruction->called_computations_.push_back(true_computation);
  instruction->called_computations_.push_back(false_computation);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    const Shape& shape, HloInstruction* branch_index,
    absl::Span<HloComputation* const> branch_computations,
    absl::Span<HloInstruction* const> branch_computation_args) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
  instruction->AppendOperand(branch_index);
  CHECK_EQ(branch_computations.size(), branch_computation_args.size());
  for (int i = 0; i < branch_computations.size(); ++i) {
    instruction->called_computations_.push_back(branch_computations[i]);
    instruction->AppendOperand(branch_computation_args[i]);
  }
  return instruction;
}
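
// Usage sketch (illustrative only; `branch_index` is an assumed s32[] scalar
// instruction and the two spans are parallel, one argument per branch
// computation, matching the kConditional decoding in CreateFromProto):
//
//   std::unique_ptr<HloInstruction> conditional =
//       HloInstruction::CreateConditional(
//           result_shape, branch_index,
//           /*branch_computations=*/{branch0, branch1, branch2},
//           /*branch_computation_args=*/{arg0, arg1, arg2});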
1070 
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1071 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1072     const Shape& shape, HloInstruction* operand,
1073     absl::Span<const int64> start_indices,
1074     absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
1075   return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1076                                                 limit_indices, strides);
1077 }
1078 
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)1079 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1080     const Shape& shape, HloInstruction* operand,
1081     absl::Span<HloInstruction* const> start_indices,
1082     absl::Span<const int64> slice_sizes) {
1083   return absl::make_unique<HloDynamicSliceInstruction>(
1084       shape, operand, start_indices, slice_sizes);
1085 }

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicUpdateSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* update,
    absl::Span<HloInstruction* const> start_indices) {
  return absl::make_unique<HloDynamicUpdateSliceInstruction>(
      shape, operand, update, start_indices);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64 dimension) {
  return absl::make_unique<HloConcatenateInstruction>(shape, operands,
                                                      dimension);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
    const Shape& shape, HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
  instruction->AppendOperand(operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
                                     HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
  instruction->AppendOperand(operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
    const Shape& shape, HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
  instruction->AppendOperand(operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    absl::Span<const int64> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  auto instruction = absl::WrapUnique(new HloReduceInstruction(
      shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
  return std::move(instruction);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values,
    absl::Span<const int64> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  std::vector<HloInstruction*> all_args;
  all_args.reserve(operands.size() * 2);
  all_args.insert(all_args.end(), operands.begin(), operands.end());
  all_args.insert(all_args.end(), init_values.begin(), init_values.end());
  return absl::make_unique<HloReduceInstruction>(
      shape, all_args, dimensions_to_reduce, reduce_computation);
}
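
// The variadic form packs its 2*n operands as {op_0, ..., op_{n-1},
// init_0, ..., init_{n-1}}; e.g. a two-array reduce is {op0, op1, init0,
// init1}. The reduce_computation correspondingly takes 2*n parameters and
// returns an n-element tuple.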

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation) {
  return absl::make_unique<HloReduceWindowInstruction>(
      shape, operand, init_value, window, reduce_computation);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
                                        HloInstruction* scale,
                                        HloInstruction* offset, float epsilon,
                                        int64 feature_index) {
  return absl::make_unique<HloBatchNormTrainingInstruction>(
      shape, operand, scale, offset, epsilon, feature_index);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64 feature_index) {
  return absl::make_unique<HloBatchNormInferenceInstruction>(
      shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                    HloInstruction* scale, HloInstruction* mean,
                                    HloInstruction* variance,
                                    HloInstruction* grad_output, float epsilon,
                                    int64 feature_index) {
  return absl::make_unique<HloBatchNormGradInstruction>(
      shape, operand, scale, mean, variance, grad_output, epsilon,
      feature_index);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter) {
  return absl::make_unique<HloSelectAndScatterInstruction>(
      shape, operand, select, window, source, init_value, scatter);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> broadcast_dimensions) {
  return absl::make_unique<HloBroadcastInstruction>(shape, operand,
                                                    broadcast_dimensions);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       int64 dimension) {
  return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
                                                           dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       HloInstruction* val, int64 dimension) {
  return absl::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
                                                           dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBroadcastSequence(
    const Shape& output_shape, HloInstruction* operand,
    const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
        adder) {
  CHECK(ShapeUtil::IsScalar(operand->shape()) ||
        operand->shape().rank() == output_shape.rank());
  Shape broadcast_shape = ShapeUtil::ChangeElementType(
      output_shape, operand->shape().element_type());
  // Do an explicit broadcast for a scalar.
  if (ShapeUtil::IsScalar(operand->shape())) {
    auto broadcast =
        HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
    broadcast->set_metadata(operand->metadata());
    if (operand->has_sharding()) {
      broadcast->set_sharding(operand->sharding());
    }
    broadcast->set_frontend_attributes(operand->frontend_attributes());
    return broadcast;
  }
  // Do an explicit broadcast for a degenerate broadcast (the broadcasted
  // dimensions all have size one).
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand->shape().rank(); i++) {
    if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand->shape().dimensions(i));
    } else {
      CHECK_EQ(operand->shape().dimensions(i), 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand: "
          << operand->ToString() << "; output_shape: " << output_shape;
    }
  }
  // Eliminate the size-one dimensions.
  HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(operand->shape().element_type(),
                           reshaped_dimensions),
      operand));
  reshaped_operand->set_metadata(operand->metadata());
  if (operand->has_sharding()) {
    reshaped_operand->set_sharding(operand->sharding());
  }
  reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
  // Broadcast the reshaped operand up to the larger size.
  auto broadcast = HloInstruction::CreateBroadcast(
      broadcast_shape, reshaped_operand, broadcast_dimensions);
  broadcast->set_metadata(operand->metadata());
  if (operand->has_sharding()) {
    broadcast->set_sharding(operand->sharding());
  }
  broadcast->set_frontend_attributes(operand->frontend_attributes());
  return broadcast;
}
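
// Worked example: broadcasting a f32[1,4] operand to f32[3,4]. Dimension 1
// matches (size 4) and dimension 0 is degenerate (size 1), so the sequence
// is a reshape f32[1,4] -> f32[4] followed by a broadcast f32[4] -> f32[3,4]
// with broadcast_dimensions = {1}. The `adder` callback is what actually
// inserts each newly created instruction into a computation and returns the
// raw pointer.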

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
    const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
    const PaddingConfig& padding_config) {
  return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
                                              padding_config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
    const Shape& shape, HloInstruction* operand, int64 inferred_dimension) {
  CHECK_EQ(ShapeUtil::ElementsIn(shape),
           ShapeUtil::ElementsIn(operand->shape()))
      << "shape: " << ShapeUtil::HumanString(shape)
      << " operand: " << ShapeUtil::HumanString(operand->shape());

  return absl::make_unique<HloReshapeInstruction>(shape, operand,
                                                  inferred_dimension);
}
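
// A reshape only rearranges elements, which is what the CHECK above
// enforces: e.g. f32[2,6] may be reshaped to f32[3,4] or f32[12] (twelve
// elements either way), but not to f32[2,5].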

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
    const Shape& shape, int64 dimension,
    absl::Span<HloInstruction* const> operands, HloComputation* compare,
    bool is_stable) {
  return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
                                               compare, is_stable);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
  return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
                                                 fused_root);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind,
    absl::Span<HloInstruction* const> operands,
    HloComputation* fusion_computation) {
  return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
                                                 fusion_computation);
}

void HloInstruction::set_single_sharding(const HloSharding& sharding) {
  CHECK(!sharding.IsTuple()) << sharding;
  if (shape().IsTuple()) {
    set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
  } else {
    set_sharding(sharding);
  }
}

void HloInstruction::SetupDerivedInstruction(
    HloInstruction* derived_instruction) const {
  if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
                                  shape_, derived_instruction->shape())) {
    // Only copy sharding if the shapes of the two instructions are
    // compatible, because copying it between differently shaped instructions
    // can produce invalid shardings.
    derived_instruction->set_sharding(*sharding_);
  } else {
    derived_instruction->clear_sharding();
  }
  derived_instruction->set_metadata(metadata_);
  derived_instruction->set_frontend_attributes(frontend_attributes_);
}

bool HloInstruction::HasSideEffectNoRecurse() const {
  switch (opcode_) {
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kTrace:
      return true;
    case HloOpcode::kAllReduce:
      return channel_id().has_value() ||
             Cast<HloAllReduceInstruction>(this)->constrain_layout();
    case HloOpcode::kCustomCall:
      return Cast<HloCustomCallInstruction>(this)
          ->custom_call_has_side_effect();
    default:
      return false;
  }
}

bool HloInstruction::HasSideEffect() const {
  if (HasSideEffectNoRecurse()) {
    return true;
  }
  // Check if any of the called computations has a side effect.
  for (const auto& computation : called_computations()) {
    if (computation->HasSideEffect()) {
      return true;
    }
  }
  return false;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* computation) {
  std::unique_ptr<HloInstruction> instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
  for (auto operand : operands) {
    instruction->AppendOperand(operand);
  }
  instruction->called_computations_.push_back(computation);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::string_view custom_call_target, string opaque) {
  return absl::make_unique<HloCustomCallInstruction>(
      shape, operands, custom_call_target, std::move(opaque));
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::string_view custom_call_target,
    absl::Span<const Shape> operand_shapes_with_layout, string opaque) {
  return absl::make_unique<HloCustomCallInstruction>(
      shape, operands, custom_call_target, std::move(opaque),
      operand_shapes_with_layout);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
    absl::Span<HloInstruction* const> elements) {
  std::vector<Shape> element_shapes;
  for (auto element : elements) {
    element_shapes.push_back(element->shape());
  }
  Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
  return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
    const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    const GatherDimensionNumbers& gather_dim_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  return absl::make_unique<HloGatherInstruction>(
      shape, operand, start_indices, gather_dim_numbers, slice_sizes,
      indices_are_sorted);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
    const Shape& shape, HloInstruction* operand,
    HloInstruction* scatter_indices, HloInstruction* updates,
    HloComputation* update_computation,
    const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
    bool unique_indices) {
  return absl::make_unique<HloScatterInstruction>(
      shape, operand, scatter_indices, updates, update_computation,
      scatter_dim_numbers, indices_are_sorted, unique_indices);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
    const Shape& shape, HloInstruction* operand,
    std::unique_ptr<DomainMetadata> operand_side_metadata,
    std::unique_ptr<DomainMetadata> user_side_metadata) {
  return absl::make_unique<HloDomainInstruction>(
      shape, operand, std::move(operand_side_metadata),
      std::move(user_side_metadata));
}

std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
    HloCloneContext* context) const {
  VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
  VLOG(3) << "  new operands:";
  for (const HloInstruction* new_operand : new_operands) {
    VLOG(3) << "    %" << new_operand->name();
  }

  std::unique_ptr<HloInstruction> clone;
  // Explicitly call the factory for the instruction type. This is more robust
  // in the face of code changes than copying fields explicitly. This also
  // properly sets the user fields of the operands.
  switch (opcode_) {
    // Ops migrated to subclasses.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReshape:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kSort:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kIota:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      clone = CloneWithNewOperandsImpl(shape, new_operands, context);
      break;
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateUnary(shape, opcode_, new_operands[0]);
      break;
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
      break;
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      CHECK_EQ(new_operands.size(), 3);
      clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
                            new_operands[2]);
      break;
    // Other supported ops.
    case HloOpcode::kCall:
      clone = CreateCall(shape, new_operands, to_apply());
      break;
    case HloOpcode::kConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kBitcastConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateBitcastConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kDynamicUpdateSlice:
      clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
                                       new_operands.subspan(2));
      break;
    case HloOpcode::kTuple:
      clone = CreateTuple(new_operands);
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kWhile:
      CHECK_EQ(new_operands.size(), 1);
      clone =
          CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
      break;
    case HloOpcode::kConditional:
      CHECK_EQ(new_operands.size(), branch_count() + 1);
      clone = CreateConditional(shape, new_operands[0],
                                absl::MakeSpan(branch_computations()),
                                new_operands.subspan(1));
      break;
    case HloOpcode::kAfterAll:
      if (new_operands.empty()) {
        clone = CreateToken();
      } else {
        clone = CreateAfterAll(new_operands);
      }
      break;
    case HloOpcode::kAddDependency:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateAddDependency(new_operands[0], new_operands[1]);
      break;
    case HloOpcode::kReplicaId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreateReplicaId();
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kPartitionId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreatePartitionId();
      *clone->mutable_shape() = shape;
      break;
  }
  // SetupDerivedInstruction will set up the precision_config_ field.
  SetupDerivedInstruction(clone.get());
  clone->set_parent(parent_);
  clone->set_outer_dimension_partitions(outer_dimension_partitions_);
  clone->set_raw_backend_config_string(backend_config_);
  if (context != nullptr) {
    context->MapInstruction(this, clone.get());
    clone->ReplaceCalledComputations([&](HloComputation* callee) {
      return callee->parent() != context->module()
                 ? context->module()->DeepCloneComputation(callee, context)
                 : callee;
    });
  }
  return clone;
}

HloInstruction::~HloInstruction() {
  // Detach from operands. An instruction may be repeated as an operand. To
  // avoid calling RemoveUser twice on the same operand, check before removing.
  for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
    HloInstruction* operand = operands_[operand_num];
    if (operand == nullptr) {
      continue;
    }
    if (operand->user_map_.find(this) != operand->user_map_.end()) {
      operand->RemoveUser(this);
    }
    operands_[operand_num] = nullptr;
  }

  // Update users: set each user's corresponding operand slot to nullptr.
  for (auto& user : this->users()) {
    for (int i = 0; i < user->operand_count(); ++i) {
      if (user->operands_[i] == this) {
        user->operands_[i] = nullptr;
      }
    }
  }
}

std::unique_ptr<HloInstruction> HloInstruction::Clone(
    const string& suffix, HloCloneContext* context) const {
  std::unique_ptr<HloInstruction> clone =
      CloneWithNewOperands(shape_, operands_, context);
  if (suffix.empty()) {
    clone->name_ = name();
  } else {
    // If an instruction is cloned multiple times, avoid names like
    // foo.suffix.suffix.suffix. Instead of repeating the suffix, add a numeric
    // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
    // clone of foo.suffix2 is named foo.suffix3, and so on.
    const string dot_suffix = "." + suffix;
    size_t index = name().rfind(dot_suffix);
    if (index == string::npos) {
      // Existing name does not include ".suffix".
      clone->name_ = name() + dot_suffix;
    } else {
      // Existing name includes ".suffix". Determine if the substring after
      // ".suffix" is numeric and should be replaced with an incremented number.
      string after_suffix = name().substr(index + dot_suffix.size());
      if (after_suffix.empty()) {
        // Existing name ends in ".suffix". New name should end in ".suffix2".
        clone->name_ = name() + "2";
      } else {
        // If the name ends with .suffix[0-9]+, replace the numeric value with
        // its increment.
        int64 numeric_suffix;
        if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
          clone->name_ =
              StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
        } else {
          // Substring after ".suffix" is non-numeric.
          clone->name_ = name() + dot_suffix;
        }
      }
    }
  }
  return clone;
}
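
// Naming example (from the logic above): cloning %foo with suffix "clone"
// yields %foo.clone; cloning %foo.clone yields %foo.clone2, then
// %foo.clone3, and so on, rather than %foo.clone.clone.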

std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
  const HloInstruction* hlo = this;
  ShapeIndex index;
  while (hlo->opcode() == HloOpcode::kGetTupleElement) {
    index.push_back(hlo->tuple_index());
    hlo = hlo->operand(0);
  }

  // We built up the index in the reverse order from what we want.
  std::reverse(index.begin(), index.end());

  return {hlo, index};
}
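
// Example (hypothetical instructions): with %gte0 = get-tuple-element(%t),
// index=1 and %gte1 = get-tuple-element(%gte0), index=0, calling
// LatestNonGteAncestorAndIndex() on %gte1 returns {%t, ShapeIndex{1, 0}}:
// the chain of tuple indices from the ancestor down to this value.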

const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
  const HloInstruction* hlo = this;
  while (hlo->opcode() == HloOpcode::kGetTupleElement) {
    hlo = hlo->operand(0);
  }
  return hlo;
}

const HloInstruction* HloInstruction::operand(int64 i) const {
  return operands_.at(i);
}

HloInstruction* HloInstruction::mutable_operand(int64 i) {
  CHECK(operands_[i] != nullptr);
  return operands_.at(i);
}

int64 HloInstruction::operand_index(const HloInstruction* target) const {
  for (int64 i = 0; i < operand_count(); ++i) {
    if (target == operand(i)) {
      return i;
    }
  }
  LOG(FATAL) << "target was not an operand: " << target->ToString();
}

HloInstruction::InstructionVector HloInstruction::unique_operands() const {
  InstructionVector unique;
  absl::flat_hash_set<const HloInstruction*> seen;
  for (HloInstruction* operand : operands()) {
    if (seen.insert(operand).second) {
      unique.push_back(operand);
    }
  }
  return unique;
}

Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  if (!absl::c_linear_search(control_successors_, instruction)) {
    control_successors_.push_back(instruction);
    TF_RET_CHECK(
        !absl::c_linear_search(instruction->control_predecessors_, this));
    instruction->control_predecessors_.push_back(this);
  }
  return Status::OK();
}

Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
  TF_RETURN_IF_ERROR(
      EraseElementFromVector(&instruction->control_predecessors_, this));
  return Status::OK();
}

Status HloInstruction::DropAllControlDeps() {
  for (auto* ctrl_succ : control_successors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
  }
  for (auto* ctrl_pred : control_predecessors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&ctrl_pred->control_successors_, this));
  }
  control_successors_.clear();
  control_predecessors_.clear();
  return Status::OK();
}

Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
  for (auto* ctrl_pred : inst->control_predecessors()) {
    TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
  }

  for (auto* ctrl_succ : inst->control_successors()) {
    TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
  }

  return Status::OK();
}

void HloInstruction::AppendOperand(HloInstruction* operand) {
  operands_.push_back(operand);
  operand->AddUser(this);
}

void HloInstruction::RemoveOperandsAtAscendingIndices(
    absl::Span<const int> ascending_indices) {
  if (ascending_indices.empty()) {
    return;
  }
  int next_index = 0;
  int removed_count = 0;
  for (int to_remove : ascending_indices) {
    while (next_index < to_remove) {
      operands_[next_index - removed_count] = operands_[next_index];
      ++next_index;
    }
    CHECK_LT(to_remove, operands_.size());
    ++removed_count;
    ++next_index;
  }
  while (next_index < operands_.size()) {
    operands_[next_index - removed_count] = operands_[next_index];
    ++next_index;
  }
  CHECK_EQ(removed_count, ascending_indices.size());
  operands_.resize(operands_.size() - removed_count);
}
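
// Worked example: with operands {a, b, c, d, e} and ascending_indices
// {1, 3}, the pass keeps a, skips b, shifts c into slot 1, skips d, shifts
// e into slot 2, then resizes to {a, c, e}. A single left-to-right pass,
// O(n) no matter how many indices are removed.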

void HloInstruction::AddUser(HloInstruction* user) {
  if (!ContainsKey(user_map_, user)) {
    user_map_.emplace(user, users_.size());
    users_.push_back(user);
  }
}

int64 HloInstruction::UserId(HloInstruction* user) {
  auto result = user_map_.find(user);
  CHECK(result != user_map_.end());
  return result->second;
}

bool HloInstruction::HasConstantOperand() const {
  for (const HloInstruction* operand : operands_) {
    if (operand->IsConstant()) {
      return true;
    }
  }
  return false;
}

bool HloInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  // Perform opcode-specific checks.
  switch (opcode()) {
    // The result of these instructions depends only upon their opcode and
    // operands.
    case HloOpcode::kAbs:
    case HloOpcode::kAtan2:
    case HloOpcode::kAdd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return true;

    // These opcodes have complex or special behavior, so just return false.
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
      return false;

    // Remaining instructions with special values.
    case HloOpcode::kCall:
      return eq_computations(to_apply(), other.to_apply());
    case HloOpcode::kConditional:
      for (int j = 0; j < branch_count(); ++j) {
        if (!eq_computations(branch_computation(j),
                             other.branch_computation(j))) {
          return false;
        }
      }
      return true;
    case HloOpcode::kWhile:
      return (eq_computations(while_body(), other.while_body()) &&
              eq_computations(while_condition(), other.while_condition()));

    // Ops migrated to subclasses should never come to this line.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kIota:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }
  return false;
}

static uint64 HashOperand(const HloInstruction* hlo) {
  return ShapeUtil::Hash(hlo->shape());
}

uint64 HloInstruction::Hash(
    const std::function<uint64(const HloInstruction*)>& hash_operand) const {
  using tensorflow::Hash64Combine;

  uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
  hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));

  if (!IsCrossModuleAllReduce()) {
    if (!operands().empty()) {
      for (size_t i = 0; i < operands().size(); ++i) {
        hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
      }
    }
  }

  hash_value = Hash64Combine(hash_value, InnerHash());
  return hash_value;
}

uint64 HloInstruction::Hash() const {
  // Use HashOperand as an argument to prevent non-termination.
  return Hash(HashOperand);
}
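
// Note: the default operand hasher (HashOperand above) folds in only each
// operand's *shape*, not the operand's own Hash(). Hashing operands
// recursively could revisit the same nodes over and over on deep or
// diamond-shaped graphs; shape-only hashing keeps Hash() cheap at the cost
// of more collisions, which is acceptable for a hash.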

uint64 HloInstruction::InnerHash() const { return 13; }

void HloInstruction::RemoveUser(HloInstruction* user) {
  auto map_it = user_map_.find(user);
  CHECK(map_it != user_map_.end());

  const int64 index = map_it->second;
  CHECK_EQ(users_[index], user);

  // Move the last user into the position of the removed user.
  users_[index] = users_.back();
  user_map_[users_.back()] = index;

  // Remove the user from the map and drop the last vector slot, whose
  // contents have been moved into the position of the removed user.
  user_map_.erase(map_it);
  users_.pop_back();
}
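
// Classic swap-with-last removal: removing b from users_ = {a, b, c}
// overwrites b's slot with c (and fixes c's index in user_map_), giving
// {a, c}. Removal is O(1) at the cost of not preserving user order.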

Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  return ReplaceUseWithDifferentShape(user, new_producer);
}

Status HloInstruction::ReplaceUseWithDifferentShape(
    HloInstruction* user, HloInstruction* new_producer) {
  VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
          << " with " << new_producer->name();

  RemoveUser(user);

  // `this` must appear among the user's operands at least once; a `>= 0`
  // count check would be vacuously true.
  TF_RET_CHECK(absl::c_count(user->operands_, this) >= 1);
  std::replace(user->operands_.begin(), user->operands_.end(), this,
               new_producer);
  new_producer->AddUser(user);
  // Custom fusions may not be able to handle deduplicated operands.
  if (user->opcode() == HloOpcode::kFusion) {
    TF_RETURN_IF_ERROR(
        Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
  }
  return Status::OK();
}

Status HloInstruction::ReplaceOperandWith(int64 operand_num,
                                          HloInstruction* new_operand) {
  auto old_operand = operand(operand_num);
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
                                                        new_operand->shape()))
      << old_operand->shape() << " is not compatible with "
      << new_operand->shape();
  return ReplaceOperandWithDifferentShape(operand_num, new_operand);
}

Status HloInstruction::ReplaceOperandWithDifferentShape(
    int64 operand_num, HloInstruction* new_operand) {
  TF_RET_CHECK(operand_num >= 0);
  TF_RET_CHECK(operand_num < operand_count());
  HloInstruction* old_operand = mutable_operand(operand_num);
  if (old_operand == new_operand) {
    return Status::OK();
  }

  operands_[operand_num] = new_operand;

  VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
          << new_operand->name() << ", was " << old_operand->name();

  if (!absl::c_linear_search(operands_, old_operand)) {
    old_operand->RemoveUser(this);
  }
  new_operand->AddUser(this);
  return Status::OK();
}

Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(new_producer);
}

Status HloInstruction::ReplaceAllUsesWithDifferentShape(
    HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // It's possible that new_producer is a user of this instruction as might
      // be the case when replacing an instruction with a kCopy of itself. In
      // this case, don't do the replacement to avoid creating a cycle in the
      // graph. new_producer remains the only user of this instruction.
      new_producer_is_user = true;
    } else {
      std::replace(user->operands_.begin(), user->operands_.end(), this,
                   new_producer);
      new_producer->AddUser(user);
      if (user->opcode() == HloOpcode::kFusion) {
        TF_RETURN_IF_ERROR(
            Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
      }
    }
  }
  users_.clear();
  user_map_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }
  if (parent_ && parent_->root_instruction() == this) {
    parent_->set_root_instruction(new_producer,
                                  /*accept_different_shape=*/true);
  }

  return Status::OK();
}

HloComputation* HloInstruction::to_apply() const {
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kAllReduce:
    case HloOpcode::kScatter:
    case HloOpcode::kSort:
      CHECK_EQ(called_computations_.size(), 1);
      return called_computations_[0];
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
}

void HloInstruction::set_to_apply(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kAllReduce:
    case HloOpcode::kScatter:
    case HloOpcode::kSort:
      CHECK_EQ(called_computations_.size(), 1);
      called_computations_[0] = computation;
      break;
    default:
      LOG(FATAL) << "Invalid opcode for set_to_apply(): "
                 << HloOpcodeString(opcode());
  }
}

HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kConditionComputationIndex];
}

HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kBodyComputationIndex];
}

void HloInstruction::set_while_condition(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kConditionComputationIndex] = computation;
}

void HloInstruction::set_while_body(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kBodyComputationIndex] = computation;
}

HloInstruction* HloInstruction::while_init() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return operands_[0];
}

HloComputation* HloInstruction::true_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kTrueComputationIndex];
}

HloComputation* HloInstruction::false_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kFalseComputationIndex];
}

const std::vector<HloComputation*>& HloInstruction::branch_computations()
    const {
  CHECK(HloOpcode::kConditional == opcode_);
  return called_computations_;
}

int HloInstruction::branch_count() const {
  CHECK(HloOpcode::kConditional == opcode_);
  return called_computations_.size();
}

HloComputation* HloInstruction::branch_computation(int b) const {
  CHECK(HloOpcode::kConditional == opcode_);
  CHECK_GE(b, 0);
  CHECK_LT(b, called_computations_.size());
  return called_computations_[b];
}

void HloInstruction::set_branch_computation(int b,
                                            HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  called_computations_[b] = computation;
}

string HloInstruction::SignatureString() const {
  string operands =
      StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
        StrAppend(out, ShapeUtil::HumanString(operand->shape()));
      });
  return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
}

string PrintName(const string& name, bool print_ids) {
  if (print_ids) {
    return name;
  } else {
    auto dot_position = name.find_first_of(".");
    return name.substr(0, dot_position);
  }
}
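
// Example: PrintName("add.42", /*print_ids=*/true) returns "add.42", while
// PrintName("add.42", /*print_ids=*/false) strips everything from the first
// '.' on and returns "add".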

namespace {

using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;

string PrintNameInternal(const string& name, const HloPrintOptions& options) {
  return StrCat(options.print_percent() ? "%" : "",
                PrintName(name, options.print_ids()));
}

void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
  // This set contains HloInstructions from the top of `DFSStack` that might
  // belong to the cycle, i.e. if DFSStack := [back, ..., child, ..., top],
  // then `subgraph` := {child, ..., top}.
  absl::flat_hash_set<const HloInstruction*> subgraph;
  while (!dfs_stack->empty() && dfs_stack->back().second != child) {
    subgraph.insert(dfs_stack->back().second);
    dfs_stack->pop_back();
  }
  // Start a DFS at `child` and find a cycle with all nodes in `subgraph`.
  absl::flat_hash_set<const HloInstruction*> visited;
  absl::InlinedVector<const HloInstruction*, 16> dfs;
  dfs.push_back(child);
  while (!dfs.empty()) {
    bool found_next_instr = false;
    for (const auto& user : dfs.back()->users()) {
      if (user == child) {
        dfs.push_back(child);
        LOG(INFO) << "\n\nDirected cycle:\n  "
                  << absl::StrJoin(
                         dfs, "\n  ",
                         [](std::string* out, const HloInstruction* instr) {
                           out->append(instr->name());
                         });
        return;
      }
      if (!subgraph.contains(user) || visited.contains(user)) {
        continue;
      }
      visited.insert(user);
      dfs.push_back(user);
      found_next_instr = true;
    }
    if (!found_next_instr) {
      dfs.pop_back();
    }
  }
}

}  // namespace

string HloInstruction::ToString(const HloPrintOptions& options) const {
  CanonicalNameMap new_map;
  return ToStringWithCanonicalNameMap(options, &new_map);
}

bool HloInstruction::IsElementwiseImpl(
    const absl::optional<int64>& operand_idx) const {
  switch (opcode_) {
    // Unary elementwise operations.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      CHECK_EQ(1, operand_count());
      return true;

    // Binary elementwise operations, the same as in IsElementwiseBinary().
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(2, operand_count());
      return true;

    // Ternary elementwise operations.
    case HloOpcode::kSelect:
    case HloOpcode::kClamp:
      return true;

    case HloOpcode::kDynamicUpdateSlice:
      return operand_idx.has_value() && operand_idx.value() == 0;

    default:
      return false;
  }
}
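
// The kDynamicUpdateSlice case is the subtle one: the op is treated as
// elementwise only with respect to operand 0 (each output element depends
// on at most the corresponding element of the input being updated), not
// with respect to the update or the start indices, so callers must ask
// about a specific operand index.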

bool HloInstruction::IsCrossModuleAllReduce() const {
  return opcode() == HloOpcode::kAllReduce && channel_id();
}

bool HloInstruction::IsCrossReplicaAllReduce() const {
  return opcode() == HloOpcode::kAllReduce && !channel_id();
}

ToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2383 string HloInstruction::ToStringWithCanonicalNameMap(
2384     const HloPrintOptions& options,
2385     CanonicalNameMap* canonical_name_map) const {
2386   string result = "";
2387 
2388   // Logic to print the instruction name (e.g. "%foo = ").
2389   if (options.canonicalize_instruction_names()) {
2390     if (options.is_in_nested_computation()) {
2391       // If we are canonicalizing instruction names and this is a top-level
2392       // HloInstruction::ToString() call, don't print an instruction name.
2393       StrAppend(&result,
2394                 PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
2395                                   options),
2396                 " = ");
2397     }
2398   } else {
2399     StrAppend(&result, PrintNameInternal(name(), options), " = ");
2400   }
2401 
2402   // Print shape.
2403   if (options.include_layout_in_shapes()) {
2404     StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()));
2405   } else {
2406     StrAppend(&result, ShapeUtil::HumanString(shape()));
2407   }
2408 
2409   // Print opcode, operand(s).
2410   StrAppend(&result, " ", HloOpcodeString(opcode()), "(",
2411             OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
2412             ")");
2413 
2414   // Print additional attributes. If an instruction contains a subcomputation,
2415   // the subcomputation is also printed here.
2416   for (const string& extra : ExtraAttributesToString(options)) {
2417     StrAppend(&result, ", ", extra);
2418   }
2419 
2420   if (options.print_metadata() &&
2421       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
2422        !metadata_.source_file().empty())) {
2423     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
2424   }
2425   if (options.print_backend_config() && !backend_config_.empty()) {
2426     StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
2427   }
2428   return result;
2429 }
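// Illustration (hypothetical output): with default options an instruction
// prints roughly as
//   %add.1 = f32[4]{0} add(f32[4]{0} %x, f32[4]{0} %y), metadata={...}
// i.e. name, shape, opcode, parenthesized operands, then any extra
// attributes, metadata, and backend config, in that order.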

string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
  CanonicalNameMap new_map;
  return OperandsToStringWithCanonicalNameMap(options, &new_map);
}

string HloInstruction::OperandsToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string operands;
  absl::Span<HloInstruction* const> slice(operands_);
  const int64 kMaxOperandsToShowIfCompact = 4;
  if (options.compact_operands() &&
      slice.size() > kMaxOperandsToShowIfCompact) {
    slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
  }
  operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
    // If the operand has already been deleted, emit `null` in the string
    // output.
    if (operand == nullptr) {
      StrAppend(out, "null ");
      return;
    }
    std::vector<string> str;
    if (options.print_operand_shape()) {
      if (options.include_layout_in_shapes()) {
        str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
      } else {
        str.push_back(ShapeUtil::HumanString(operand->shape()));
      }
    }

    // In a top-level HloInstruction::ToString() call, the operand name is not
    // part of the canonical string.
    if (options.canonicalize_instruction_names() &&
        options.is_in_nested_computation()) {
      str.push_back(PrintNameInternal(
          canonical_name_map->LookupOrInsert(operand->name()), options));
    } else if (options.print_operand_names()) {
      str.push_back(PrintNameInternal(operand->name(), options));
    }
    StrAppend(out, StrJoin(str, " "));
  });
  const int64 remaining = operands_.size() - slice.size();
  if (slice.size() != operands_.size()) {
    StrAppend(&operands, ", ...(+", remaining, ")");
  }
  return operands;
}
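// Illustration: with compact_operands() set and six operands, only the first
// kMaxOperandsToShowIfCompact (4) are printed and the rest are elided, e.g.
// (hypothetical)
//   f32[] %a, f32[] %b, f32[] %c, f32[] %d, ...(+2)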

std::vector<string> HloInstruction::ExtraAttributesToString(
    const HloPrintOptions& options) const {
  std::vector<string> extra = ExtraAttributesToStringImpl(options);

  if (options.print_subcomputation_mode() ==
      HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
    if (opcode() == HloOpcode::kWhile) {
      extra.push_back(StrCat(
          "condition=", PrintNameInternal(while_condition()->name(), options)));
      extra.push_back(
          StrCat("body=", PrintNameInternal(while_body()->name(), options)));
    } else if (opcode() == HloOpcode::kSelectAndScatter) {
      extra.push_back(
          StrCat("select=", PrintNameInternal(select()->name(), options)));
      extra.push_back(
          StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
    } else if (opcode() == HloOpcode::kConditional) {
      if (operand(0)->shape().element_type() == PRED) {
        extra.push_back(
            StrCat("true_computation=",
                   PrintNameInternal(true_computation()->name(), options)));
        extra.push_back(
            StrCat("false_computation=",
                   PrintNameInternal(false_computation()->name(), options)));
      } else {
        extra.push_back(StrCat(
            "branch_computations={",
            StrJoin(branch_computations(), ", ",
                    [&](string* out, const HloComputation* computation) {
                      StrAppend(
                          out, PrintNameInternal(computation->name(), options));
                    }),
            "}"));
      }
    } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
               opcode() == HloOpcode::kReduceWindow ||
               opcode() == HloOpcode::kReduce ||
               opcode() == HloOpcode::kAllReduce ||
               opcode() == HloOpcode::kScatter ||
               opcode() == HloOpcode::kSort) {
      extra.push_back(
          StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
    } else if (!called_computations().empty()) {
      extra.push_back(StrCat(
          "calls=",
          StrJoin(called_computations(), ", ",
                  [&](string* out, const HloComputation* computation) {
                    StrAppend(out,
                              PrintNameInternal(computation->name(), options));
                  })));
    }
  } else if (options.print_subcomputation_mode() ==
             HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
    HloPrintOptions new_options = options;
    new_options.set_is_in_nested_computation(true);
    switch (opcode()) {
      case HloOpcode::kWhile:
        extra.push_back(
            StrCat("condition=\n", while_condition()->ToString(new_options)));
        extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
        break;
      case HloOpcode::kSelectAndScatter:
        extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
        extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
        break;
      case HloOpcode::kConditional:
        if (operand(0)->shape().element_type() == PRED) {
          extra.push_back(StrCat("true_computation=\n",
                                 true_computation()->ToString(new_options)));
          extra.push_back(StrCat("false_computation=\n",
                                 false_computation()->ToString(new_options)));
        } else {
          extra.push_back(StrCat(
              "branch_computations={\n",
              StrJoin(branch_computations(), ",\n",
                      [&](string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      }),
              "\n}"));
        }
        break;
      case HloOpcode::kCall:
      case HloOpcode::kMap:
      case HloOpcode::kReduceWindow:
      case HloOpcode::kReduce:
      case HloOpcode::kAllReduce:
      case HloOpcode::kScatter:
      case HloOpcode::kSort:
        extra.push_back(
            StrCat("to_apply=\n", to_apply()->ToString(new_options)));
        break;
      default:
        if (!called_computations().empty()) {
          extra.push_back(StrCat(
              "calls=\n",
              StrJoin(called_computations(), ", ",
                      [&](string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      })));
        }
        break;
    }
  }

  if (has_sharding()) {
    extra.push_back(StrCat("sharding=", sharding().ToString()));
  }
  if (!frontend_attributes_.map().empty()) {
    extra.push_back(StrCat("frontend_attributes=",
                           FrontendAttributesToString(frontend_attributes_)));
  }
  if (!outer_dimension_partitions_.empty()) {
    extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
                                    StrJoin(outer_dimension_partitions_, ",")));
  }

  if (options.print_control_dependencies() && !control_predecessors_.empty()) {
    extra.push_back(StrCat("control-predecessors={",
                           StrJoin(control_predecessors_, ", ",
                                   [&](string* out, HloInstruction* pre) {
                                     StrAppend(out, PrintNameInternal(
                                                        pre->name(), options));
                                   }),
                           "}"));
  }

  return extra;
}

string HloInstruction::ToShortString() const {
  return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
                StrJoin(operands_, ", ",
                        [](string* out, HloInstruction* operand) {
                          StrAppend(out, "%", operand->name());
                        }),
                ")");
}
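// Illustration (hypothetical): ToShortString() omits shapes and attributes,
// e.g.
//   %add.1 = add(%x, %y)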

HloInstructionProto HloInstruction::ToProto() const {
  HloInstructionProto proto;
  CHECK(unique_id_ != -1)
      << "This instruction does not have a valid id. Please make sure the "
         "instruction is inside a module before dumping it.";
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_opcode(HloOpcodeString(opcode_));
  *proto.mutable_shape() = shape_.ToProto();
  for (const HloInstruction* operand : operands_) {
    proto.add_operand_ids(operand->unique_id());
  }
  for (const HloInstruction* control : control_predecessors_) {
    proto.add_control_predecessor_ids(control->unique_id());
  }

  *proto.mutable_metadata() = metadata_;
  proto.set_backend_config(backend_config_);
  if (opcode() != HloOpcode::kFusion) {
    for (const HloComputation* computation : called_computations_) {
      proto.add_called_computation_ids(computation->unique_id());
    }
  }

  if (has_sharding()) {
    *proto.mutable_sharding() = sharding().ToProto();
  }
  if (!outer_dimension_partitions_.empty()) {
    for (const auto& idx : outer_dimension_partitions_) {
      proto.mutable_outer_dimension_partitions()->Add(idx);
    }
  }

  *proto.mutable_frontend_attributes() = frontend_attributes_;

  return proto;
}
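// Illustration (hypothetical): the proto references operands and called
// computations by unique id rather than by pointer, which is why the CHECK
// above requires the instruction to live inside a module:
//   HloInstructionProto p = instr->ToProto();
//   int64 first_operand_id = p.operand_ids(0);  // an id, not a pointer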

string HloInstruction::ToCategory() const {
  if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
      opcode() == HloOpcode::kReshape) {
    return "data formatting";
  }

  if (IsElementwise()) {
    return "non-fusion elementwise";
  }

  return HloOpcodeString(opcode());
}

HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }

void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
  trace_instruction_ = trace_instruction;
}

bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }

bool HloInstruction::IsInputFusion() const {
  return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
}

bool HloInstruction::IsLoopFusion() const {
  return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
}

bool HloInstruction::IsOutputFusion() const {
  return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
}

bool HloInstruction::IsCustomFusion() const {
  return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
}

bool HloInstruction::IsFusible() const {
  // Instructions which are traced should not be fused.
  if (tracing()) {
    return false;
  }
  // Some kinds of instructions don't make sense to fuse.
  switch (opcode_) {
    case HloOpcode::kDomain:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
    case HloOpcode::kConditional:
    case HloOpcode::kCall:
      return false;
    // Fusions are always fusible.
    case HloOpcode::kFusion:
    // Side-effecting map, reduce and reduce-window would be invalid HLO, so
    // these are fusible as well.
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
      return true;
    // Side-effecting instructions cannot be fused.
    default:
      return !HasSideEffect();
  }
}

HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)) {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}

template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kTupleSelect:
      return visitor->HandleTupleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kPartitionId:
      return visitor->HandlePartitionId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kPopulationCount:
      return visitor->HandlePopulationCount(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kRngGetAndUpdateState:
      return visitor->HandleRngGetAndUpdateState(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kCopyStart:
      return visitor->HandleCopyStart(this);
    case HloOpcode::kCopyDone:
      return visitor->HandleCopyDone(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kSetDimensionSize:
      return visitor->HandleSetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
      return Status::OK();
  }
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}

// Explicit instantiations.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);

// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
template <typename Visitor>
inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
                         HloInstruction* child) {
  CHECK(child != nullptr);
  const int id = child->unique_id();
  CHECK_GE(id, 0) << "instruction may not have a parent computation";
  switch (visitor->GetVisitState(id)) {
    case Visitor::kVisiting:
      return false;

    case Visitor::kVisited:
      // Nothing to do.
      return true;

    case Visitor::kNotVisited:
      dfs_stack->push_back(std::make_pair(id, child));
      return true;
  }
}
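// Illustration: returning false on kVisiting is the cycle check. A node stays
// in state kVisiting while its operand subtree is being expanded, so reaching
// it again along operand/control edges means a back edge, e.g. (hypothetical)
//   a -> b -> c -> a   // pushing `a` from `c` finds kVisiting and fails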

using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  // Calculating the instruction count within a module can be expensive on
  // large models, so only do it if the visit state is empty. This helps when
  // the same visitor is reused across many computations of a single module.
  if (visitor->VisitStateCapacity() == 0) {
    visitor->ReserveVisitStates(root->GetModule()->instruction_count());
  }

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO (id = " << current_id
              << ") as it was already visited.";
      continue;
    }

    if (visit_state == Visitor::kVisiting) {
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    visitor->SetVisitState(current_id, Visitor::kVisiting);

    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        PrintCycle(child, &dfs_stack);
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          PrintCycle(child, &dfs_stack);
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString());
        }
      }
    }

    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
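// Illustration (a minimal sketch of the same two-phase pattern, with
// hypothetical names; not part of this file): each node is expanded on its
// first visit (kNotVisited -> kVisiting) and emitted in post order on the
// second (kVisiting -> kVisited):
//
//   push(root);
//   while (!stack.empty()) {
//     Node* n = stack.top();
//     if (state[n] == kVisiting) {        // children done: emit in post order
//       stack.pop(); emit(n); state[n] = kVisited;
//     } else if (state[n] == kVisited) {  // reached via another path
//       stack.pop();
//     } else {                            // first time: expand children
//       state[n] = kVisiting;
//       for (Node* c : n->children) push(c);
//     }
//   }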

template <typename HloInstructionPtr>
Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                              bool call_finish_visit,
                              bool ignore_control_predecessors) {
  VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
  TF_RETURN_IF_ERROR(
      PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
  if (call_finish_visit) {
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
  }
  return Status::OK();
}

// Explicit instantiations.
template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);

Status HloInstruction::AcceptWithOperandOrder(
    DfsHloVisitor* visitor, const CompareFunction& operand_order,
    bool call_finish_visit) {
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
  InternalCompareFunction func = [&operand_order](
                                     std::pair<int, const HloInstruction*> a,
                                     std::pair<int, const HloInstruction*> b) {
    // Call the client's comparison function on the actual HloInstruction*
    // objects (ignoring the internal ids we also have in our stack entries).
    return operand_order(a.second, b.second);
  };
  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
                                  /*ignore_control_predecessors=*/false));
  if (call_finish_visit) {
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
  }
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
  return Status::OK();
}
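// Illustration (hypothetical usage): a caller that wants operands visited in
// name order could pass
//   auto by_name = [](const HloInstruction* a, const HloInstruction* b) {
//     return a->name() < b->name();
//   };
//   TF_RETURN_IF_ERROR(instr->AcceptWithOperandOrder(
//       &visitor, by_name, /*call_finish_visit=*/true));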

const Shape& HloInstruction::shape() const { return shape_; }

absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
    const HloInstruction* operand) const {
  absl::InlinedVector<int64, 4> result;
  for (int64 i = 0; i < operand_count(); ++i) {
    if (this->operand(i) == operand) {
      result.push_back(i);
    }
  }
  return result;
}

bool HloInstruction::IsElementwiseBinary() const {
  return IsElementwise() && operand_count() == 2;
}

bool HloInstruction::IsElementwise() const {
  return IsElementwiseImpl(absl::nullopt);
}

bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
  return IsElementwiseImpl(operand_idx);
}

// A helper class for memoized, recursive computation of HloOpcode::kFusion
// in HloInstruction::OperandElementUse below.
class HloInstruction::FusionReusesParamElements {
 public:
  using UseKind = HloInstruction::UseKind;

  // We could instead iterate backwards through fused_instructions_ here, as
  // it is in reverse postorder, and compute whether each fused instruction
  // reuses the value of this parameter, which would save stack space but
  // would not allow us to finish early upon finding a reuse.
  static UseKind Compute(int64 i, const HloInstruction& hlo) {
    absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
    return ComputeInternal(i, hlo, &memoization_cache);
  }

 private:
  static UseKind ComputeInternal(
      int64 i, const HloInstruction& hlo,
      absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
    if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
      if (hlo_param->parameter_number() == i) {
        return UseKind::kUse;
      }
    }

    auto p = cache->emplace(&hlo, UseKind::kNoUse);
    auto value_it = p.first;
    const bool key_is_new = p.second;

    if (key_is_new) {
      for (int64 j = 0; j < hlo.operands_.size(); ++j) {
        UseKind old_val = value_it->second;

        // The next operation invalidates iterators.
        UseKind new_val =
            Fold(old_val,
                 FoldUseMandatory(hlo.OperandElementUse(j),
                                  ComputeInternal(i, *hlo.operand(j), cache)));

        // Re-acquire the iterator. We could work harder to do this only if
        // absolutely necessary, but this code is not hot enough to warrant
        // that.
        value_it = cache->find(&hlo);
        value_it->second = new_val;
      }
    }
    return value_it->second;
  }

  // Combines two UseKinds.
  //
  // This is the min operation on the lattice
  //
  //   kReuse < kUse < kNoUse.
  //
  // Two kUse values with different permutations combine to kReuse.
  static UseKind Fold(UseKind a, UseKind b) {
    // Without loss of generality, let `b` be the operation with the larger
    // use kind.
    if (b.kind < a.kind) {
      std::swap(a, b);
    }
    // If the kinds are different, return the smaller one, namely `a`.
    if (a.kind != b.kind) {
      return a;
    }
    // If the kinds are both kUse, check that they're the same permutation.
    if (a.kind == UseKind::kUse && b.kind == UseKind::kUse &&
        a.permutation_instr != b.permutation_instr) {
      return UseKind::kReuse;
    }
    return a;  // They're the same.
  }

  // Combines two UseKinds differently than Fold().
  //
  // This is the min operation on the lattice
  //
  //   kNoUse < kReuse < kUse.
  //
  // If `a` and `b` are both kUse and one has a non-null permutation
  // instruction, returns kUse with that permutation.  OTOH if both have
  // different, non-null permutation instructions, returns kReuse.
  //
  // You can think of this sort of as a conjunction, whereas Fold is sort of a
  // disjunction.  FoldUseMandatory() says "no use" if either input isn't
  // used, whereas Fold() would say "use".
  static UseKind FoldUseMandatory(UseKind a, UseKind b) {
    if (a.kind == UseKind::kNoUse || b.kind == UseKind::kNoUse) {
      return UseKind::kNoUse;
    }
    if (a.kind == UseKind::kReuse || b.kind == UseKind::kReuse) {
      return UseKind::kReuse;
    }
    if (a.permutation_instr == b.permutation_instr) {
      return a;  // They're the same.
    }
    if (b.permutation_instr == nullptr) {
      return a;
    }
    if (a.permutation_instr == nullptr) {
      return b;
    }
    return UseKind::kReuse;
  }
};
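// Illustration (worked example over the two lattices above):
//   Fold(kNoUse, kUse)             == kUse    (min over kReuse < kUse < kNoUse)
//   FoldUseMandatory(kNoUse, kUse) == kNoUse  (min over kNoUse < kReuse < kUse)
// so Fold() acts like a disjunction across sibling operands, while
// FoldUseMandatory() acts like a conjunction along an operand chain.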

HloInstruction::UseKind HloInstruction::OperandElementUse(
    int64 operand_num) const {
  switch (opcode_) {
    case HloOpcode::kBitcast:
      // A bitcast that only adds or removes degenerate (i.e. size 1)
      // dimensions doesn't permute its elements, so it counts as a plain,
      // non-permuting use.
      return ShapeUtil::DropDegenerateDimensions(shape()) ==
                     ShapeUtil::DropDegenerateDimensions(operand(0)->shape())
                 ? UseKind::kUse
                 : UseKind::Permuting(this);
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTranspose:
      return UseKind::Permuting(this);
    case HloOpcode::kPad:
      // Pad reuses the padding value but not the padded array elements.
      return operand_num > 0 ? UseKind::kReuse : UseKind::Permuting(this);
    case HloOpcode::kReduce:
      // Reduce reuses the init values but not the operand array elements.
      return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
                 ? UseKind::kReuse
                 : UseKind::Permuting(this);
    case HloOpcode::kFusion:
      // Uses the memoizing, recursive computation defined above.
      return FusionReusesParamElements::Compute(operand_num,
                                                *fused_expression_root());
    case HloOpcode::kDot:
      // Matrix-vector dots do not reuse the matrix operand.
      if (shape().dimensions_size() <= 1) {
        if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
            (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
          return UseKind::kUse;
        }
      }
      return UseKind::kReuse;
    case HloOpcode::kDynamicUpdateSlice:
      // Dynamic-update-slice reuses only start_indices.
      if (operand_num == 0 || operand_num == 1) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
    case HloOpcode::kGather:
      // Gather reads its indices in a linear fashion, and it permutes the
      // vector it's gathering from.
      return operand_num == 0 ? UseKind::kUse : UseKind::Permuting(this);
    default:
      return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
  }
}

std::tuple<bool, std::vector<int64>, std::vector<int64>>
HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
  if (HloOpcode::kReshape != opcode_) {
    return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
  }
  return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
                                                      shape_);
}

string ToString(HloInstruction::FusionKind kind) {
  switch (kind) {
    case HloInstruction::FusionKind::kLoop:
      return "kLoop";
    case HloInstruction::FusionKind::kInput:
      return "kInput";
    case HloInstruction::FusionKind::kOutput:
      return "kOutput";
    case HloInstruction::FusionKind::kCustom:
      return "kCustom";
  }
}

StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name) {
  if (kind_name == "kLoop") {
    return HloInstruction::FusionKind::kLoop;
  }
  if (kind_name == "kInput") {
    return HloInstruction::FusionKind::kInput;
  }
  if (kind_name == "kOutput") {
    return HloInstruction::FusionKind::kOutput;
  }
  if (kind_name == "kCustom") {
    return HloInstruction::FusionKind::kCustom;
  }
  return InvalidArgument("Unknown fusion kind: %s", kind_name);
}

string FrontendAttributesToString(
    const FrontendAttributes& frontend_attributes) {
  std::vector<std::pair<string, string>> sorted_attributes(
      frontend_attributes.map().begin(), frontend_attributes.map().end());
  absl::c_sort(sorted_attributes);
  return absl::StrFormat(
      "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
}

string PaddingConfigToString(const PaddingConfig& padding) {
  bool has_interior_padding =
      absl::c_any_of(padding.dimensions(),
                     [](const PaddingConfig::PaddingConfigDimension& dim) {
                       return dim.interior_padding() != 0;
                     });
  return StrJoin(
      padding.dimensions(), "x",
      [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
        StrAppend(
            out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
            has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
      });
}
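// Illustration (hypothetical): a 2D config with low/high edge padding of 1/2
// in dimension 0, 0/0 in dimension 1, and no interior padding anywhere prints
// as
//   1_2x0_0
// The "_interior" component is appended to every dimension only when some
// dimension has non-zero interior padding.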

string OpMetadataToString(const OpMetadata& metadata) {
  std::vector<string> result;
  if (!metadata.op_type().empty()) {
    result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
  }
  if (!metadata.op_name().empty()) {
    result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
  }
  if (!metadata.source_file().empty()) {
    result.push_back(
        StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
  }
  if (metadata.source_line() != 0) {
    result.push_back(StrCat("source_line=", metadata.source_line()));
  }
  return StrJoin(result, " ");
}

string RandomDistributionToString(const RandomDistribution& distribution) {
  return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
}

string PrecisionToString(const PrecisionConfig::Precision& precision) {
  return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
}

string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums) {
  // lhs_dims[i] is the symbol of the logical dimension i for the lhs
  // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
  std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
  lhs_dims[dnums.input_batch_dimension()] = 'b';
  lhs_dims[dnums.input_feature_dimension()] = 'f';
  for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
    lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
  }

  std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
  rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
  rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
  for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
    rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
  }

  std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
  output_dims[dnums.output_batch_dimension()] = 'b';
  output_dims[dnums.output_feature_dimension()] = 'f';
  for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
    output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
  }

  return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
                StrJoin(output_dims, ""));
}
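// Illustration (hypothetical): for an NHWC input and output (batch=0,
// feature=3, spatial={1,2}) with an HWIO kernel (spatial={0,1}, input
// feature=2, output feature=3), this prints
//   b01f_01io->b01f
// where 'b' marks the batch dimension, 'f' the feature dimension, digits the
// spatial dimensions, and 'i'/'o' the kernel input/output feature dimensions.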

string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) {
  std::vector<string> replica_group_str;
  replica_group_str.reserve(replica_groups.size());
  for (const ReplicaGroup& group : replica_groups) {
    replica_group_str.push_back(
        StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
  }
  return StrCat("{", StrJoin(replica_group_str, ","), "}");
}
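// Illustration (hypothetical): two groups pairing replicas 0,1 and 2,3 print
// as
//   {{0,1},{2,3}}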

StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
  static std::unordered_map<string, RandomDistribution>* map = [] {
    static auto* map = new std::unordered_map<string, RandomDistribution>;
    for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
      if (RandomDistribution_IsValid(i)) {
        auto value = static_cast<RandomDistribution>(i);
        (*map)[RandomDistributionToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    return InvalidArgument("Unknown distribution");
  }
  return found->second;
}

StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
  static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
    static auto* map =
        new std::unordered_map<string, PrecisionConfig::Precision>;
    for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
      if (PrecisionConfig::Precision_IsValid(i)) {
        auto value = static_cast<PrecisionConfig::Precision>(i);
        (*map)[PrecisionToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    return InvalidArgument("Unknown precision");
  }
  return found->second;
}

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
  return os << ToString(kind);
}

bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
                                  const HloInstruction* const& rhs) const {
  if (rhs == nullptr) {
    // Nothing compares less than nullptr.
    return false;
  }
  if (lhs == nullptr) {
    return true;
  }
  auto lhs_module = lhs->GetModule();
  auto rhs_module = rhs->GetModule();
  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
        (lhs_module != nullptr && rhs_module != nullptr));
  if (lhs_module != nullptr &&
      lhs_module->unique_id() != rhs_module->unique_id()) {
    return lhs_module->unique_id() < rhs_module->unique_id();
  }
  return lhs->unique_id() < rhs->unique_id();
}

bool HloInstruction::CouldBeBitcast() const {
  switch (opcode_) {
    case HloOpcode::kTranspose:
      return true;
    case HloOpcode::kReshape:
      return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
    default:
      return false;
  }
}

Status HloInstruction::GetBackendConfigInternal(
    tensorflow::protobuf::Message* proto) const {
  proto->Clear();

  // The empty string does not parse as valid JSON, but it is a valid backend
  // config, corresponding to the empty proto.
  if (backend_config_.empty()) {
    return Status::OK();
  }
  return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
}

Status HloInstruction::set_backend_config(
    const tensorflow::protobuf::Message& proto) {
  TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
  return Status::OK();
}

/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
    const tensorflow::protobuf::Message& proto) {
  string ret;
  // Pass ignore_accuracy_loss = true because the estimated_cycles field can
  // be INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
  // INT64_MAX, JsonFormat will return an error status, although there is no
  // accuracy loss for int64.
  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
      proto, &ret, /*ignore_accuracy_loss=*/true));
  return ret;
}
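// Illustration (hypothetical; `MyBackendConfig` is an assumed backend-defined
// proto message, and backend_config<T>() is the templated accessor declared
// in the header): configs round-trip through human-readable JSON, so
//   MyBackendConfig cfg;
//   TF_RETURN_IF_ERROR(instr->set_backend_config(cfg));
//   TF_ASSIGN_OR_RETURN(MyBackendConfig readback,
//                       instr->backend_config<MyBackendConfig>());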
3505 
precision_config() const3506 const PrecisionConfig& HloInstruction::precision_config() const {
3507   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3508     return convolution->precision_config();
3509   }
3510   if (auto* dot = DynCast<HloDotInstruction>(this)) {
3511     return dot->precision_config();
3512   }
3513   LOG(FATAL) << "Unimplemented method.";
3514 }
3515 
mutable_precision_config()3516 PrecisionConfig* HloInstruction::mutable_precision_config() {
3517   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3518     return convolution->mutable_precision_config();
3519   }
3520   if (auto* dot = DynCast<HloDotInstruction>(this)) {
3521     return dot->mutable_precision_config();
3522   }
3523   LOG(FATAL) << "Unimplemented method.";
3524 }
3525 
GetModule() const3526 HloModule* HloInstruction::GetModule() const {
3527   if (parent_) {
3528     return parent_->parent();
3529   }
3530   return nullptr;
3531 }
3532 
UniquifyName(NameUniquer * name_uniquer)3533 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
3534   string parent_str = parent() == nullptr ? "noparent" : parent()->name();
3535   name_ = name_uniquer->GetUniqueName(name_);
3536 }
3537 
set_outer_dimension_partitions(const std::vector<int64> & outer_dimension_partitions)3538 void HloInstruction::set_outer_dimension_partitions(
3539     const std::vector<int64>& outer_dimension_partitions) {
3540   outer_dimension_partitions_ = outer_dimension_partitions;
3541 }
3542 
3543 // TODO(b/80131774): Remove these temporary methods after transition.
feature_index() const3544 int64 HloInstruction::feature_index() const {
3545   return Cast<HloBatchNormInstruction>(this)->feature_index();
3546 }
3547 
epsilon() const3548 float HloInstruction::epsilon() const {
3549   return Cast<HloBatchNormInstruction>(this)->epsilon();
3550 }
3551 
fft_type() const3552 FftType HloInstruction::fft_type() const {
3553   return Cast<HloFftInstruction>(this)->fft_type();
3554 }
3555 
fft_length() const3556 const std::vector<int64>& HloInstruction::fft_length() const {
3557   return Cast<HloFftInstruction>(this)->fft_length();
3558 }
3559 
concatenate_dimension() const3560 int64 HloInstruction::concatenate_dimension() const {
3561   return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
3562 }
3563 
dimension() const3564 int64 HloInstruction::dimension() const {
3565   if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
3566     return set_size->dimension();
3567   }
3568   return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
3569 }
3570 
inferred_dimension() const3571 int64 HloInstruction::inferred_dimension() const {
3572   return Cast<HloReshapeInstruction>(this)->inferred_dimension();
3573 }
3574 
IsRank2Transpose() const3575 bool HloInstruction::IsRank2Transpose() const {
3576   auto transpose = DynCast<HloTransposeInstruction>(this);
3577   return transpose != nullptr && transpose->IsRank2Transpose();
3578 }
3579 
slice_starts(int64 dimension) const3580 int64 HloInstruction::slice_starts(int64 dimension) const {
3581   return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
3582 }
3583 
slice_starts() const3584 const std::vector<int64>& HloInstruction::slice_starts() const {
3585   return Cast<HloSliceInstruction>(this)->slice_starts();
3586 }
3587 
slice_limits(int64 dimension) const3588 int64 HloInstruction::slice_limits(int64 dimension) const {
3589   return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
3590 }
3591 
slice_limits() const3592 const std::vector<int64>& HloInstruction::slice_limits() const {
3593   return Cast<HloSliceInstruction>(this)->slice_limits();
3594 }
3595 
slice_strides(int64 dimension) const3596 int64 HloInstruction::slice_strides(int64 dimension) const {
3597   return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
3598 }
3599 
slice_strides() const3600 const std::vector<int64>& HloInstruction::slice_strides() const {
3601   return Cast<HloSliceInstruction>(this)->slice_strides();
3602 }
3603 
literal() const3604 const Literal& HloInstruction::literal() const {
3605   return Cast<HloConstantInstruction>(this)->literal();
3606 }
3607 
IsConstant() const3608 bool HloInstruction::IsConstant() const {
3609   return DynCast<HloConstantInstruction>(this) != nullptr;
3610 }
3611 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)3612 void HloInstruction::RelayoutConstant(const Layout& new_layout,
3613                                       const ShapeIndex& shape_index) {
3614   Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
3615 }
3616 
TracingTag() const3617 string HloInstruction::TracingTag() const {
3618   return Cast<HloTraceInstruction>(this)->TracingTag();
3619 }
3620 
AddFusionOperand(HloInstruction * new_operand)3621 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
3622   return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
3623 }
3624 
3625 // Delegates to HloFusionInstruction::MergeFusionInstruction.
MergeFusionInstruction(HloInstruction * instruction_to_merge)3626 void HloInstruction::MergeFusionInstruction(
3627     HloInstruction* instruction_to_merge) {
3628   return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
3629       Cast<HloFusionInstruction>(instruction_to_merge));
3630 }
3631 
3632 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
MergeFusionInstructionIntoMultiOutput(HloInstruction * instruction_to_merge)3633 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
3634     HloInstruction* instruction_to_merge) {
3635   return Cast<HloFusionInstruction>(this)
3636       ->MergeFusionInstructionIntoMultiOutput(
3637           Cast<HloFusionInstruction>(instruction_to_merge));
3638 }
3639 
FuseInstruction(HloInstruction * instruction_to_fuse)3640 HloInstruction* HloInstruction::FuseInstruction(
3641     HloInstruction* instruction_to_fuse) {
3642   return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
3643 }
3644 
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)3645 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
3646     HloInstruction* instruction_to_fuse) {
3647   return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
3648       instruction_to_fuse);
3649 }

HloComputation* HloInstruction::fused_instructions_computation() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
}

HloInstruction* HloInstruction::fused_expression_root() const {
  return Cast<HloFusionInstruction>(this)->fused_expression_root();
}

const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

int64 HloInstruction::fused_instruction_count() const {
  return Cast<HloFusionInstruction>(this)->fused_instruction_count();
}

HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
}

const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  return Cast<HloFusionInstruction>(this)->fused_parameters();
}

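// Uses DynCast rather than Cast so it can be called on any instruction:
// non-fusion instructions simply report false instead of CHECK-failing.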
bool HloInstruction::IsMultiOutputFusion() const {
  const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
  return fusion != nullptr && fusion->IsMultiOutputFusion();
}

HloInstruction::FusionKind HloInstruction::fusion_kind() const {
  return Cast<HloFusionInstruction>(this)->fusion_kind();
}

void HloInstruction::set_fusion_kind(FusionKind kind) {
  return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
}

RandomDistribution HloInstruction::random_distribution() const {
  return Cast<HloRngInstruction>(this)->random_distribution();
}

int64 HloInstruction::parameter_number() const {
  return Cast<HloParameterInstruction>(this)->parameter_number();
}

void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

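// Overload for std::vector<bool>: the packed vector<bool> specialization has
// no contiguous bool storage, so it cannot convert to absl::Span<const bool>
// and needs its own entry point.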
void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

const absl::optional<std::vector<bool>>&
HloInstruction::parameter_replicated_at_leaf_buffers() const {
  return Cast<HloParameterInstruction>(this)
      ->parameter_replicated_at_leaf_buffers();
}

int64 HloInstruction::tuple_index() const {
  return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
}

void HloInstruction::set_tuple_index(int64 new_tuple_index) {
  return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
      new_tuple_index);
}

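// exponent_bits/mantissa_bits describe the reduced floating-point format that
// a kReducePrecision instruction emulates.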
int32 HloInstruction::exponent_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
}

int32 HloInstruction::mantissa_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
}

string HloInstruction::infeed_config() const {
  return Cast<HloInfeedInstruction>(this)->infeed_config();
}

void HloInstruction::set_infeed_config(const string& config) {
  return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
}

const Shape& HloInstruction::outfeed_shape() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
}

const string& HloInstruction::outfeed_config() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}

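// replica_groups() is shared by the collective ops (e.g. all-reduce and
// all-to-all) that derive from HloCollectiveInstruction.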
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  return Cast<HloCollectiveInstruction>(this)->replica_groups();
}

const std::vector<std::pair<int64, int64>>&
HloInstruction::source_target_pairs() const {
  return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}

absl::optional<int64> HloInstruction::channel_id() const {
  return Cast<HloChannelInstruction>(this)->channel_id();
}

void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
  return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
}

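// Convolution attributes can live on either a kConvolution instruction or a
// backend-specific custom call that represents a convolution, so the
// accessors below dispatch on the concrete type with DynCast and CHECK-fail
// (LOG(FATAL)) for any other instruction.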
const ConvolutionDimensionNumbers&
HloInstruction::convolution_dimension_numbers() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->convolution_dimension_numbers();
  }
  if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->convolution_dimension_numbers();
  }
  LOG(FATAL) << "Unimplemented method.";
}

void HloInstruction::set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    convolution->set_convolution_dimension_numbers(dnums);
  } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    custom_call->set_convolution_dimension_numbers(dnums);
  } else {
    LOG(FATAL) << "Unimplemented method.";
  }
}

int64 HloInstruction::feature_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->feature_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->feature_group_count();
}

void HloInstruction::set_feature_group_count(int64 feature_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_feature_group_count(feature_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
      feature_group_count);
}

int64 HloInstruction::batch_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->batch_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->batch_group_count();
}

void HloInstruction::set_batch_group_count(int64 batch_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_batch_group_count(batch_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
      batch_group_count);
}
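
// A guarded read, sketched for a hypothetical caller: only kConvolution is
// unconditionally safe here, since a custom call may CHECK-fail if its
// convolution attributes were never populated.
//
//   if (instr->opcode() == HloOpcode::kConvolution) {
//     const ConvolutionDimensionNumbers& dnums =
//         instr->convolution_dimension_numbers();
//   }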

HloComputation* HloInstruction::select() const {
  return Cast<HloSelectAndScatterInstruction>(this)->select();
}

HloComputation* HloInstruction::scatter() const {
  return Cast<HloSelectAndScatterInstruction>(this)->scatter();
}

void HloInstruction::set_select(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
}

void HloInstruction::set_scatter(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
}

const string& HloInstruction::custom_call_target() const {
  return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}

const PaddingConfig& HloInstruction::padding_config() const {
  return Cast<HloPadInstruction>(this)->padding_config();
}

int64 HloInstruction::slice_sizes(int64 dimension) const {
  return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
}

const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
  return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
}

const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
  return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}

absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
  return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}

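// Note: this applies to kScatter instructions; it is unrelated to the
// scatter() accessor above, which returns a kSelectAndScatter computation.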
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
    const {
  return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}

const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
  return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
}

const DomainMetadata& HloInstruction::operand_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->operand_side_metadata();
}

const DomainMetadata& HloInstruction::user_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->user_side_metadata();
}

ComparisonDirection HloInstruction::comparison_direction() const {
  return Cast<HloCompareInstruction>(this)->direction();
}

const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
  return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
}

const CholeskyOptions& HloInstruction::cholesky_options() const {
  return Cast<HloCholeskyInstruction>(this)->cholesky_options();
}

}  // namespace xla