1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <iostream>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <set>
25 #include <string>
26 #include <utility>
27
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/container/inlined_vector.h"
32 #include "absl/strings/ascii.h"
33 #include "absl/strings/escaping.h"
34 #include "absl/strings/numbers.h"
35 #include "absl/strings/str_cat.h"
36 #include "absl/strings/str_join.h"
37 #include "absl/strings/string_view.h"
38 #include "absl/types/span.h"
39 #include "tensorflow/compiler/xla/layout_util.h"
40 #include "tensorflow/compiler/xla/literal.h"
41 #include "tensorflow/compiler/xla/protobuf_util.h"
42 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
43 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_computation.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_module.h"
47 #include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
50 #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
51 #include "tensorflow/compiler/xla/service/name_uniquer.h"
52 #include "tensorflow/compiler/xla/shape_util.h"
53 #include "tensorflow/compiler/xla/status_macros.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/gtl/map_util.h"
59 #include "tensorflow/core/platform/errors.h"
60 #include "tensorflow/core/platform/human_readable_json.h"
61 #include "tensorflow/core/platform/logging.h"
62
63 namespace xla {
64
65 using absl::CEscape;
66 using absl::StrAppend;
67 using absl::StrCat;
68 using absl::StrJoin;
69
AddInstruction(std::unique_ptr<HloInstruction> derived_instruction)70 HloInstruction* HloInstruction::AddInstruction(
71 std::unique_ptr<HloInstruction> derived_instruction) {
72 HloInstruction* derived =
73 parent()->AddInstruction(std::move(derived_instruction));
74 const bool has_prior_sharding = derived->has_sharding();
75 SetupDerivedInstruction(derived);
76 if (!has_prior_sharding && (derived->opcode() == HloOpcode::kReshape ||
77 derived->opcode() == HloOpcode::kTranspose)) {
78 derived->clear_sharding();
79 }
80 return derived;
81 }
82
83 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64_t,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64_t,HloComputation * > & computation_map,bool prohibit_empty_literal)84 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
85 const HloInstructionProto& proto,
86 const absl::flat_hash_map<int64_t, HloInstruction*>& instruction_map,
87 const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
88 bool prohibit_empty_literal) {
89 TF_RET_CHECK(!proto.opcode().empty());
90 HloOpcode opcode;
91 auto opcode_or = StringToHloOpcode(proto.opcode());
92 std::optional<ComparisonDirection> comparison_direction;
93 if (opcode_or.ok()) {
94 opcode = std::move(opcode_or).value();
95 } else {
96 // Unknown opcode. Try auto-upgrading deprecated "less-than",
97 // "greater-than", etc opcodes, which are now rolled into the kCompare
98 // opcode.
99 if (proto.opcode() == "equal-to") {
100 comparison_direction = ComparisonDirection::kEq;
101 } else if (proto.opcode() == "not-equal-to") {
102 comparison_direction = ComparisonDirection::kNe;
103 } else if (proto.opcode() == "greater-than-or-equal-to") {
104 comparison_direction = ComparisonDirection::kGe;
105 } else if (proto.opcode() == "greater-than") {
106 comparison_direction = ComparisonDirection::kGt;
107 } else if (proto.opcode() == "less-than-or-equal-to") {
108 comparison_direction = ComparisonDirection::kLe;
109 } else if (proto.opcode() == "less-than") {
110 comparison_direction = ComparisonDirection::kLt;
111 }
112 if (comparison_direction) {
113 opcode = HloOpcode::kCompare;
114 } else {
115 return InvalidArgument("Unknown opcode: %s", proto.opcode());
116 }
117 }
118
119 TF_RET_CHECK(proto.has_shape());
120
121 std::unique_ptr<HloInstruction> instruction;
122 const auto operands = [&instruction_map, &proto](int index) {
123 return instruction_map.at(proto.operand_ids(index));
124 };
125 const auto all_operands = [&instruction_map, &proto]() {
126 std::vector<HloInstruction*> result(proto.operand_ids_size());
127 std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
128 result.begin(), [&instruction_map](int64_t operand_id) {
129 return instruction_map.at(operand_id);
130 });
131 return result;
132 };
133 const auto computations = [&computation_map, &proto](int index) {
134 return computation_map.at(proto.called_computation_ids(index));
135 };
136 const auto all_computations = [&computation_map, &proto]() {
137 std::vector<HloComputation*> result(proto.called_computation_ids_size());
138 std::transform(proto.called_computation_ids().begin(),
139 proto.called_computation_ids().end(), result.begin(),
140 [&computation_map](int64_t computation_id) {
141 return computation_map.at(computation_id);
142 });
143 return result;
144 };
145
146 TF_RET_CHECK(
147 absl::c_all_of(proto.operand_ids(),
148 [&](int64_t id) { return instruction_map.contains(id); }))
149 << proto.name() << " instruction contains invalid operand id(s)";
150
151 TF_RET_CHECK(
152 absl::c_all_of(proto.called_computation_ids(),
153 [&](int64_t id) { return computation_map.contains(id); }))
154 << proto.name() << " instruction references invalid computation id(s)";
155
156 Shape shape(proto.shape());
157 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
158
159 std::optional<int> arity = HloOpcodeArity(opcode);
160 if (arity) {
161 TF_RET_CHECK(proto.operand_ids_size() == *arity)
162 << proto.opcode() << " instruction should have " << *arity
163 << " operands but sees " << proto.operand_ids_size();
164 }
165
166 switch (opcode) {
167 // Ops migrated to subclasses.
168 case HloOpcode::kBatchNormTraining:
169 instruction =
170 CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
171 proto.epsilon(), proto.feature_index());
172 break;
173 case HloOpcode::kBatchNormInference:
174 instruction = CreateBatchNormInference(
175 shape, operands(0), operands(1), operands(2), operands(3),
176 operands(4), proto.epsilon(), proto.feature_index());
177 break;
178 case HloOpcode::kBatchNormGrad:
179 instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
180 operands(2), operands(3), operands(4),
181 proto.epsilon(), proto.feature_index());
182 break;
183 case HloOpcode::kFft: {
184 std::vector<int64_t> fft_length(proto.fft_length().begin(),
185 proto.fft_length().end());
186 instruction = CreateFft(shape, operands(0), proto.fft_type(),
187 absl::Span<const int64_t>(fft_length));
188 break;
189 }
190 case HloOpcode::kAsyncStart: {
191 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
192 << "Async start instruction should have 1 called computation but "
193 "sees "
194 << proto.called_computation_ids_size();
195 std::optional<int64_t> async_group_id;
196 if (proto.async_group_id() >= 0) {
197 async_group_id = proto.async_group_id();
198 }
199 instruction = CreateAsyncStart(shape, all_operands(), computations(0),
200 async_group_id,
201 proto.async_execution_thread().empty()
202 ? kMainExecutionThread
203 : proto.async_execution_thread());
204 break;
205 }
206 case HloOpcode::kAsyncUpdate: {
207 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
208 << "Async update instruction should have 1 called computation but "
209 "sees "
210 << proto.called_computation_ids_size();
211 std::optional<int64_t> async_group_id;
212 if (proto.async_group_id() >= 0) {
213 async_group_id = proto.async_group_id();
214 }
215 instruction =
216 CreateAsyncUpdate(shape, operands(0), computations(0), async_group_id,
217 proto.async_execution_thread().empty()
218 ? kMainExecutionThread
219 : proto.async_execution_thread());
220 break;
221 }
222 case HloOpcode::kAsyncDone: {
223 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
224 << "Async done instruction should have 1 called computation but sees "
225 << proto.called_computation_ids_size();
226 std::optional<int64_t> async_group_id;
227 if (proto.async_group_id() >= 0) {
228 async_group_id = proto.async_group_id();
229 }
230 instruction =
231 CreateAsyncDone(shape, operands(0), computations(0), async_group_id,
232 proto.async_execution_thread().empty()
233 ? kMainExecutionThread
234 : proto.async_execution_thread());
235 break;
236 }
237 case HloOpcode::kCopyStart: {
238 instruction = CreateCopyStart(shape, operands(0),
239 proto.is_cross_program_prefetch());
240 break;
241 }
242 case HloOpcode::kCompare: {
243 // Auto-upgraded from deprecated opcode skips the following.
244 if (!comparison_direction) {
245 TF_ASSIGN_OR_RETURN(
246 comparison_direction,
247 StringToComparisonDirection(proto.comparison_direction()));
248 }
249 auto comparison_type_str = proto.comparison_type();
250 if (!comparison_type_str.empty()) {
251 // If a comparison type is specified, it *must* be valid.
252 TF_ASSIGN_OR_RETURN(auto comparison_type,
253 StringToComparisonType(comparison_type_str));
254 instruction = CreateCompare(shape, operands(0), operands(1),
255 *comparison_direction, comparison_type);
256 } else {
257 // Allow the specify of comparison type to be optional.
258 // The comparison type will be determined by the types of the operands.
259 instruction = CreateCompare(shape, operands(0), operands(1),
260 *comparison_direction);
261 }
262 break;
263 }
264 case HloOpcode::kTriangularSolve: {
265 instruction = CreateTriangularSolve(shape, operands(0), operands(1),
266 proto.triangular_solve_options());
267 break;
268 }
269 case HloOpcode::kCholesky: {
270 instruction =
271 CreateCholesky(shape, operands(0), proto.cholesky_options());
272 break;
273 }
274 case HloOpcode::kSend:
275 instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
276 proto.is_host_transfer());
277 break;
278 case HloOpcode::kSendDone:
279 instruction = CreateSendDone(operands(0), proto.is_host_transfer());
280 break;
281 case HloOpcode::kRecv:
282 instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
283 proto.channel_id(), proto.is_host_transfer());
284 break;
285 case HloOpcode::kRecvDone:
286 instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
287 break;
288 case HloOpcode::kReverse:
289 instruction =
290 CreateReverse(shape, operands(0),
291 std::vector<int64_t>(proto.dimensions().begin(),
292 proto.dimensions().end()));
293 break;
294 case HloOpcode::kConcatenate:
295 TF_RET_CHECK(proto.dimensions_size() == 1)
296 << "Concatenate instruction should have 1 dimension but sees "
297 << proto.dimensions_size();
298 instruction =
299 CreateConcatenate(shape, all_operands(), proto.dimensions(0));
300 break;
301 case HloOpcode::kConditional: {
302 TF_RET_CHECK(proto.called_computation_ids_size() > 0)
303 << "conditional should have at least 1 called computation";
304 if (operands(0)->shape().element_type() == PRED) {
305 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
306 << "conditional should have exactly 2 called computations but got "
307 << proto.called_computation_ids_size();
308 }
309 TF_RET_CHECK(proto.operand_ids_size() ==
310 proto.called_computation_ids_size() + 1)
311 << "conditional should have one branch_index operand plus one "
312 "operand per called computation but got "
313 << proto.operand_ids_size() << " operands for "
314 << proto.called_computation_ids_size() << " branch computations";
315 auto cond_operands = all_operands();
316 instruction =
317 CreateConditional(shape, cond_operands[0], all_computations(),
318 absl::MakeSpan(cond_operands).subspan(1));
319 break;
320 }
321 case HloOpcode::kReduce:
322 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
323 << "Reduce instruction should have an even number of operands but "
324 "sees "
325 << proto.operand_ids_size();
326 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
327 << "Reduce instruction should have 1 called computation but sees "
328 << proto.called_computation_ids_size();
329 {
330 const auto reduce_operands = all_operands();
331 auto inputs = absl::MakeSpan(reduce_operands)
332 .subspan(0, reduce_operands.size() / 2);
333 auto init_values =
334 absl::MakeSpan(reduce_operands)
335 .subspan(reduce_operands.size() / 2, reduce_operands.size());
336 instruction =
337 CreateReduce(shape, inputs, init_values,
338 std::vector<int64_t>(proto.dimensions().begin(),
339 proto.dimensions().end()),
340 computations(0));
341 }
342 break;
343 case HloOpcode::kSort: {
344 TF_RET_CHECK(proto.operand_ids_size() >= 1)
345 << "Sort instruction should have at least 1 operand but has "
346 << proto.operand_ids_size();
347 TF_RET_CHECK(proto.dimensions().size() == 1)
348 << "Sort instruction should have 1 dimension";
349 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
350 << "Sort instruction should one called computation but sees "
351 << proto.called_computation_ids_size();
352 auto sort_operands = all_operands();
353 instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
354 computations(0), proto.is_stable());
355 break;
356 }
357 case HloOpcode::kTranspose:
358 instruction =
359 CreateTranspose(shape, operands(0),
360 std::vector<int64_t>(proto.dimensions().begin(),
361 proto.dimensions().end()));
362 break;
363 case HloOpcode::kBroadcast:
364 instruction =
365 CreateBroadcast(shape, operands(0),
366 std::vector<int64_t>(proto.dimensions().begin(),
367 proto.dimensions().end()));
368 break;
369 case HloOpcode::kMap:
370 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
371 << "Map instruction should have 1 called computation but sees "
372 << proto.called_computation_ids_size();
373 instruction = CreateMap(shape, all_operands(), computations(0));
374 break;
375 case HloOpcode::kSlice: {
376 std::vector<int64_t> slice_starts, slice_limits, slice_strides;
377 for (const HloInstructionProto::SliceDimensions& slice_dimensions :
378 proto.slice_dimensions()) {
379 slice_starts.push_back(slice_dimensions.start());
380 slice_limits.push_back(slice_dimensions.limit());
381 slice_strides.push_back(slice_dimensions.stride());
382 }
383 instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
384 slice_strides);
385 break;
386 }
387 case HloOpcode::kConstant: {
388 // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
389 if (proto.has_literal()) {
390 TF_ASSIGN_OR_RETURN(
391 auto literal,
392 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
393 instruction = CreateConstant(std::move(literal));
394 // Literal's shape may have no/different tiling info.
395 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
396 instruction->shape(), shape))
397 << instruction->shape().ToString(true) << " vs "
398 << shape.ToString(true);
399 *instruction->mutable_shape() = shape;
400 } else {
401 instruction = std::make_unique<HloConstantInstruction>(shape);
402 }
403 break;
404 }
405 case HloOpcode::kFusion: {
406 // In the proto, fused computations are held exclusively within the
407 // HloInstructionProto and do not appear as an HloComputationProto within
408 // the HloModuleProto.
409 TF_RET_CHECK(!proto.fusion_kind().empty());
410 TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
411 StringToFusionKind(proto.fusion_kind()));
412
413 // Find the fused computation and set its fusion instruction.
414 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
415 << "Expect 1 called computation for fusion instruction but sees "
416 << proto.called_computation_ids_size();
417 const int64_t fusion_id = proto.called_computation_ids(0);
418 auto* fused_computation =
419 tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
420 TF_RET_CHECK(fused_computation != nullptr)
421 << "No fusion computation with id " << fusion_id;
422 instruction =
423 CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
424 break;
425 }
426 case HloOpcode::kRng:
427 instruction = CreateRng(shape, proto.distribution(), all_operands());
428 break;
429 case HloOpcode::kRngBitGenerator:
430 instruction =
431 CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
432 break;
433 case HloOpcode::kRngGetAndUpdateState:
434 instruction = CreateRngGetAndUpdateState(shape, proto.delta());
435 break;
436 case HloOpcode::kParameter:
437 instruction =
438 CreateParameter(proto.parameter_number(), shape, proto.name());
439 if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
440 instruction->set_parameter_replicated_at_leaf_buffers(
441 proto.parameter_replication().replicated_at_leaf_buffers());
442 }
443 break;
444 case HloOpcode::kGetTupleElement:
445 instruction =
446 CreateGetTupleElement(shape, operands(0), proto.tuple_index());
447 break;
448 case HloOpcode::kReducePrecision:
449 instruction = CreateReducePrecision(
450 shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
451 break;
452 case HloOpcode::kInfeed: {
453 TF_RET_CHECK(shape.IsTuple() &&
454 (ShapeUtil::TupleElementCount(shape) == 2))
455 << "Infeed should have a tuple shape with 2 operands, but has: "
456 << shape;
457 const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
458 instruction =
459 CreateInfeed(data_shape, operands(0), proto.infeed_config());
460 } break;
461 case HloOpcode::kOutfeed: {
462 Shape outfeed_shape(proto.outfeed_shape());
463 TF_RETURN_IF_ERROR(
464 ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
465 instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
466 proto.outfeed_config());
467 break;
468 }
469 case HloOpcode::kAllGather:
470 case HloOpcode::kAllGatherStart: {
471 std::optional<int64_t> channel_id;
472 if (proto.channel_id() > 0) {
473 channel_id = proto.channel_id();
474 }
475
476 TF_RET_CHECK(proto.dimensions_size() == 1)
477 << "AllGather cannot have more than 1 all-gather dimensions";
478 int64_t all_gather_dimension = proto.dimensions(0);
479 if (opcode == HloOpcode::kAllGather) {
480 instruction = CreateAllGather(
481 shape, all_operands(), all_gather_dimension,
482 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
483 proto.replica_groups().end()),
484 proto.constrain_layout(), channel_id,
485 proto.use_global_device_ids());
486 } else {
487 instruction = CreateAllGatherStart(
488 shape, all_operands(), all_gather_dimension,
489 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
490 proto.replica_groups().end()),
491 proto.constrain_layout(), channel_id,
492 proto.use_global_device_ids());
493 }
494 break;
495 }
496 case HloOpcode::kAllReduce:
497 case HloOpcode::kAllReduceStart:
498 case HloOpcode::kReduceScatter: {
499 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
500 << "AllReduce should have 1 called computation but sees "
501 << proto.called_computation_ids_size();
502 TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
503 << "AllReduce cannot have both channel_id() and all_reduce_id()";
504 std::optional<int64_t> channel_id;
505 if (proto.channel_id() > 0) {
506 channel_id = proto.channel_id();
507 }
508 if (proto.all_reduce_id() > 0) {
509 channel_id = proto.all_reduce_id();
510 }
511 std::vector<ReplicaGroup> replica_groups(proto.replica_groups().begin(),
512 proto.replica_groups().end());
513 if (opcode == HloOpcode::kAllReduce) {
514 instruction =
515 CreateAllReduce(shape, all_operands(), computations(0),
516 replica_groups, proto.constrain_layout(),
517 channel_id, proto.use_global_device_ids());
518 } else if (opcode == HloOpcode::kReduceScatter) {
519 TF_RET_CHECK(proto.dimensions_size() == 1)
520 << "ReduceScatter cannot have more than 1 scatter dimensions";
521 int64_t scatter_dimension = proto.dimensions(0);
522 instruction = CreateReduceScatter(
523 shape, all_operands(), computations(0), replica_groups,
524 proto.constrain_layout(), channel_id, proto.use_global_device_ids(),
525 scatter_dimension);
526 } else {
527 instruction =
528 CreateAllReduceStart(shape, all_operands(), computations(0),
529 replica_groups, proto.constrain_layout(),
530 channel_id, proto.use_global_device_ids());
531 }
532 break;
533 }
534 case HloOpcode::kAllToAll: {
535 std::optional<int64_t> channel_id;
536 if (proto.channel_id() > 0) {
537 channel_id = proto.channel_id();
538 }
539 std::optional<int64_t> split_dimension;
540 if (proto.dimensions_size() > 0) {
541 TF_RET_CHECK(proto.dimensions_size() == 1)
542 << "AllToAll cannot have more than 1 dimension (split dimension)";
543 TF_RET_CHECK(all_operands().size() == 1)
544 << "AllToAll must have a single operand when the split dimension "
545 "is specified";
546 split_dimension = proto.dimensions(0);
547 }
548 instruction = CreateAllToAll(
549 shape, all_operands(),
550 /*replica_groups=*/
551 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
552 proto.replica_groups().end()),
553 /*constrain_layout=*/proto.constrain_layout(),
554 /*channel_id=*/channel_id, split_dimension);
555 break;
556 }
557 case HloOpcode::kCollectivePermute:
558 case HloOpcode::kCollectivePermuteStart: {
559 TF_RET_CHECK(proto.operand_ids().size() == 1 ||
560 proto.operand_ids().size() == 4);
561 std::vector<std::pair<int64_t, int64_t>> source_target_pairs(
562 proto.source_target_pairs_size());
563 std::optional<int64_t> channel_id;
564 if (proto.channel_id() > 0) {
565 channel_id = proto.channel_id();
566 }
567 for (int i = 0; i < source_target_pairs.size(); ++i) {
568 source_target_pairs[i].first = proto.source_target_pairs(i).source();
569 source_target_pairs[i].second = proto.source_target_pairs(i).target();
570 }
571 if (proto.dynamic_slice_sizes_size() == 0) {
572 if (opcode == HloOpcode::kCollectivePermute) {
573 instruction = CreateCollectivePermute(
574 shape, operands(0), source_target_pairs, channel_id);
575 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
576 instruction = CreateCollectivePermuteStart(
577 shape, operands(0), source_target_pairs, channel_id);
578 } else {
579 LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
580 << "but got " << HloOpcodeString(opcode);
581 }
582 } else {
583 std::vector<std::vector<int64_t>> slice_sizes;
584 HloInstruction* input = operands(0);
585 HloInstruction* input_start_indices = operands(2);
586 if (input->shape().IsTuple() &&
587 input->shape().tuple_shapes_size() > 1) {
588 slice_sizes.resize(input->shape().tuple_shapes_size());
589 } else {
590 slice_sizes.resize(1);
591 }
592 int proto_index = 0;
593 if (input->shape().IsTuple()) {
594 if (input_start_indices->shape()
595 .tuple_shapes(0)
596 .tuple_shapes(0)
597 .IsArray()) {
598 slice_sizes.resize(input->shape().tuple_shapes_size());
599 for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
600 slice_sizes[i].resize(
601 input->shape().tuple_shapes(i).dimensions_size());
602 for (int j = 0;
603 j < input->shape().tuple_shapes(i).dimensions_size(); ++j) {
604 CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
605 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
606 proto_index += 1;
607 }
608 }
609 } else {
610 slice_sizes.resize(
611 input->shape().tuple_shapes_size() *
612 ShapeUtil::TupleElementCount(
613 input_start_indices->shape().tuple_shapes(0)));
614 int slice_sizes_count = 0;
615 for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
616 for (int j = 0;
617 j < ShapeUtil::TupleElementCount(
618 input_start_indices->shape().tuple_shapes(i));
619 ++j) {
620 slice_sizes[slice_sizes_count].resize(
621 input->shape().tuple_shapes(i).rank());
622 for (int k = 0; k < input->shape().tuple_shapes(i).rank();
623 ++k) {
624 CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
625 slice_sizes[slice_sizes_count][k] =
626 proto.dynamic_slice_sizes(proto_index);
627 proto_index += 1;
628 }
629 slice_sizes_count += 1;
630 }
631 }
632 }
633 } else {
634 slice_sizes.resize(
635 ShapeUtil::TupleElementCount(input_start_indices->shape()));
636 if (input_start_indices->shape().tuple_shapes(0).IsTuple()) {
637 for (int i = 0;
638 i < ShapeUtil::TupleElementCount(input_start_indices->shape());
639 ++i) {
640 slice_sizes[i].resize(input->shape().dimensions_size());
641 for (int j = 0; j < input->shape().dimensions_size(); ++j) {
642 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
643 proto_index += 1;
644 }
645 }
646 } else {
647 slice_sizes.resize(1);
648 slice_sizes[0].resize(input->shape().dimensions_size());
649 for (int j = 0; j < input->shape().dimensions_size(); ++j) {
650 slice_sizes[0][j] = proto.dynamic_slice_sizes(proto_index);
651 proto_index += 1;
652 }
653 }
654 }
655 if (opcode == HloOpcode::kCollectivePermute) {
656 instruction = CreateCollectivePermute(
657 shape, operands(0), operands(1), operands(2), operands(3),
658 source_target_pairs, slice_sizes, channel_id);
659 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
660 instruction = CreateCollectivePermuteStart(
661 shape, operands(0), operands(1), operands(2), operands(3),
662 source_target_pairs, slice_sizes, channel_id);
663 } else {
664 LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
665 << "but got " << HloOpcodeString(opcode);
666 }
667 }
668 break;
669 }
670 case HloOpcode::kReplicaId: {
671 instruction = CreateReplicaId(shape);
672 break;
673 }
674 case HloOpcode::kPartitionId: {
675 instruction = CreatePartitionId(shape);
676 break;
677 }
678 case HloOpcode::kConvolution: {
679 TF_RET_CHECK(proto.has_window());
680 TF_RET_CHECK(proto.has_convolution_dimension_numbers());
681 PrecisionConfig precision_config = proto.precision_config();
682 precision_config.mutable_operand_precision()->Resize(
683 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
684 instruction = CreateConvolve(
685 shape, operands(0), operands(1),
686 std::max<int64_t>(proto.feature_group_count(), 1),
687 std::max<int64_t>(proto.batch_group_count(), 1), proto.window(),
688 proto.convolution_dimension_numbers(), precision_config);
689 break;
690 }
691 case HloOpcode::kReduceWindow:
692 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
693 << "Reduce window should have an even number of operands but "
694 "sees "
695 << proto.operand_ids_size();
696 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
697 << "ReduceWindow should have 1 called computation but sees "
698 << proto.called_computation_ids_size();
699 {
700 const auto reduce_operands = all_operands();
701 auto inputs = absl::MakeSpan(reduce_operands)
702 .subspan(0, reduce_operands.size() / 2);
703 auto init_values =
704 absl::MakeSpan(reduce_operands)
705 .subspan(reduce_operands.size() / 2, reduce_operands.size());
706 instruction = CreateReduceWindow(shape, inputs, init_values,
707 proto.window(), computations(0));
708 }
709 break;
710 case HloOpcode::kSelectAndScatter:
711 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
712 << "SelectAndScatter should have 2 called computations but sees "
713 << proto.called_computation_ids_size();
714 instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
715 proto.window(), operands(1),
716 operands(2), computations(1));
717 break;
718 case HloOpcode::kCustomCall: {
719 if (proto.constrain_layout()) {
720 // A proto RepeatedPtrField cannot be converted to a Span (it is a
721 // vector of pointers essentially) so create a vector of shapes to pass
722 // in.
723 std::vector<Shape> operand_shapes;
724 const auto& operand_shapes_with_layout =
725 proto.operand_shapes_with_layout();
726 operand_shapes.reserve(operand_shapes_with_layout.size());
727 for (const ShapeProto& shape_proto : operand_shapes_with_layout) {
728 operand_shapes.emplace_back(shape_proto);
729 }
730 instruction =
731 CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
732 operand_shapes, proto.backend_config());
733 } else {
734 if (proto.called_computation_ids_size() == 1) {
735 instruction = CreateCustomCall(shape, all_operands(), computations(0),
736 proto.custom_call_target(),
737 proto.backend_config());
738 } else if (proto.called_computation_ids_size() > 1) {
739 instruction = CreateCustomCall(
740 shape, all_operands(), all_computations(),
741 proto.custom_call_target(), proto.backend_config());
742
743 } else {
744 instruction = CreateCustomCall(shape, all_operands(),
745 proto.custom_call_target(),
746 proto.backend_config());
747 }
748 }
749 auto custom_call_instr =
750 Cast<HloCustomCallInstruction>(instruction.get());
751 if (proto.has_window()) {
752 custom_call_instr->set_window(proto.window());
753 }
754 if (proto.has_literal()) {
755 TF_ASSIGN_OR_RETURN(
756 auto literal,
757 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
758 custom_call_instr->set_literal(std::move(literal));
759 }
760 if (proto.has_convolution_dimension_numbers()) {
761 custom_call_instr->set_convolution_dimension_numbers(
762 proto.convolution_dimension_numbers());
763 }
764 custom_call_instr->set_feature_group_count(std::max(
765 static_cast<int64_t>(proto.feature_group_count()), int64_t{1}));
766 custom_call_instr->set_batch_group_count(std::max(
767 static_cast<int64_t>(proto.batch_group_count()), int64_t{1}));
768 custom_call_instr->set_custom_call_has_side_effect(
769 proto.custom_call_has_side_effect());
770 custom_call_instr->set_padding_type(proto.padding_type());
771
772 PrecisionConfig precision_config = proto.precision_config();
773 precision_config.mutable_operand_precision()->Resize(
774 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
775 *custom_call_instr->mutable_precision_config() = precision_config;
776 std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
777 output_to_operand_aliasing;
778 for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
779 output_to_operand_aliasing.emplace_back(
780 ShapeIndex(aliasing.output_shape_index().begin(),
781 aliasing.output_shape_index().end()),
782 std::pair<int64_t, ShapeIndex>{
783 aliasing.operand_index(),
784 ShapeIndex(aliasing.operand_shape_index().begin(),
785 aliasing.operand_shape_index().end())});
786 }
787 custom_call_instr->set_output_to_operand_aliasing(
788 std::move(output_to_operand_aliasing));
789 custom_call_instr->set_custom_call_schedule(proto.custom_call_schedule());
790 custom_call_instr->set_api_version(proto.custom_call_api_version());
791 break;
792 }
793 case HloOpcode::kPad:
794 TF_RET_CHECK(proto.has_padding_config());
795 instruction =
796 CreatePad(shape, operands(0), operands(1), proto.padding_config());
797 break;
798 case HloOpcode::kDynamicSlice: {
799 std::vector<int64_t> slice_sizes(proto.dynamic_slice_sizes_size());
800 absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
801 TF_RET_CHECK(proto.operand_ids_size() >= 1)
802 << "DynamicSlice instruction should have at least 1 operands but "
803 "sees "
804 << proto.operand_ids_size();
805 // TODO(b/118437727): Old form, make the check unconditional.
806 if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
807 auto expected_operands = 1 + operands(0)->shape().rank();
808 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
809 << "DynamicSlice instruction should have " << expected_operands
810 << " operands, but has " << proto.operand_ids_size();
811 }
812 const auto& operand_vector = all_operands();
813 instruction = CreateDynamicSlice(
814 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
815 slice_sizes);
816 break;
817 }
818 case HloOpcode::kDynamicUpdateSlice: {
819 TF_RET_CHECK(proto.operand_ids_size() >= 2)
820 << "DynamicUpdateSlice instruction should have at least 2 operands "
821 "but sees "
822 << proto.operand_ids_size();
823 // TODO(b/118437727): Old form, make the check unconditional.
824 if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
825 auto expected_operands = 2 + operands(0)->shape().rank();
826 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
827 << "DynamicUpdateSlice instruction should have "
828 << expected_operands << " operands, but has "
829 << proto.operand_ids_size();
830 }
831 const auto& operand_vector = all_operands();
832 instruction =
833 CreateDynamicUpdateSlice(shape, operands(0), operands(1),
834 absl::MakeSpan(operand_vector).subspan(2));
835
836 break;
837 }
838 case HloOpcode::kGather: {
839 TF_RET_CHECK(proto.has_gather_dimension_numbers())
840 << "Gather instruction should have GatherDimensionNumbers set.";
841 auto gather_dimension_numbers = std::make_unique<GatherDimensionNumbers>(
842 proto.gather_dimension_numbers());
843 std::vector<int64_t> gather_slice_sizes;
844 const auto& slice_sizes = proto.gather_slice_sizes();
845 gather_slice_sizes.reserve(slice_sizes.size());
846 for (int64_t bound : slice_sizes) {
847 gather_slice_sizes.push_back(bound);
848 }
849 instruction = CreateGather(shape, operands(0), operands(1),
850 *gather_dimension_numbers, gather_slice_sizes,
851 proto.indices_are_sorted());
852 break;
853 }
854 case HloOpcode::kScatter: {
855 TF_RET_CHECK(proto.has_scatter_dimension_numbers())
856 << "Scatter instruction should have ScatterDimensionNumbers set.";
857 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
858 << "Scatter instruction should have 1 called computation but sees "
859 << proto.called_computation_ids_size();
860 auto scatter_dimension_numbers =
861 std::make_unique<ScatterDimensionNumbers>(
862 proto.scatter_dimension_numbers());
863 auto operands = all_operands();
864 auto operand_span = absl::MakeConstSpan(operands);
865 auto input_count = operands.size() / 2;
866 instruction =
867 CreateScatter(shape, operand_span.first(input_count),
868 operands[input_count], operand_span.last(input_count),
869 computations(0), *scatter_dimension_numbers,
870 proto.indices_are_sorted(), proto.unique_indices());
871 break;
872 }
873 case HloOpcode::kIota:
874 TF_RET_CHECK(proto.dimensions_size() == 1)
875 << "Iota instruction should have 1 dimension but sees "
876 << proto.dimensions_size();
877 instruction = CreateIota(shape, proto.dimensions(0));
878 break;
879 case HloOpcode::kDot: {
880 TF_RET_CHECK(proto.has_dot_dimension_numbers())
881 << "Dot instruction should have dot_dimension_numbers.";
882 PrecisionConfig precision_config = proto.precision_config();
883 precision_config.mutable_operand_precision()->Resize(
884 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
885 instruction = std::make_unique<HloDotInstruction>(
886 shape, operands(0), operands(1), proto.dot_dimension_numbers(),
887 precision_config);
888 break;
889 }
890 case HloOpcode::kDomain: {
891 std::shared_ptr<const HloSharding> entry_hlo_sharding;
892 std::shared_ptr<const HloSharding> exit_hlo_sharding;
893 if (proto.has_domain_entry_sharding()) {
894 TF_ASSIGN_OR_RETURN(
895 HloSharding sharding,
896 HloSharding::FromProto(proto.domain_entry_sharding()));
897 entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
898 }
899 if (proto.has_domain_exit_sharding()) {
900 TF_ASSIGN_OR_RETURN(
901 HloSharding sharding,
902 HloSharding::FromProto(proto.domain_exit_sharding()));
903 exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
904 }
905 instruction = std::make_unique<HloDomainInstruction>(
906 shape, operands(0),
907 std::make_unique<ShardingMetadata>(entry_hlo_sharding),
908 std::make_unique<ShardingMetadata>(exit_hlo_sharding));
909 break;
910 }
911 case HloOpcode::kGetDimensionSize:
912 TF_RET_CHECK(proto.dimensions_size() == 1);
913 instruction =
914 CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
915 break;
916 case HloOpcode::kSetDimensionSize:
917 TF_RET_CHECK(proto.dimensions_size() == 1);
918 instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
919 proto.dimensions(0));
920 break;
921 case HloOpcode::kReshape: {
922 int64_t inferred_dimension = -1;
923 if (!proto.dimensions().empty()) {
924 inferred_dimension = proto.dimensions()[0];
925 }
926 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
927 ShapeUtil::ElementsIn(shape) ==
928 ShapeUtil::ElementsIn(operands(0)->shape()))
929 << "shape: " << ShapeUtil::HumanString(shape)
930 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
931 instruction = CreateReshape(shape, operands(0), inferred_dimension);
932 break;
933 }
934 case HloOpcode::kDynamicReshape: {
935 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
936 ShapeUtil::ElementsIn(shape) ==
937 ShapeUtil::ElementsIn(operands(0)->shape()))
938 << "shape: " << ShapeUtil::HumanString(shape)
939 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
940 const auto& operand_vector = all_operands();
941 instruction = CreateDynamicReshape(
942 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
943 break;
944 }
945 default: {
946 instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
947 for (const int64_t operand_id : proto.operand_ids()) {
948 instruction->AppendOperand(instruction_map.at(operand_id));
949 }
950 if (instruction->opcode() != HloOpcode::kFusion) {
951 if (instruction->opcode() == HloOpcode::kCall) {
952 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
953 << "Call should have 1 called computation but has "
954 << proto.called_computation_ids_size();
955 }
956 if (instruction->opcode() == HloOpcode::kWhile) {
957 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
958 << "While should have 2 called computation but has "
959 << proto.called_computation_ids_size();
960 }
961 for (const int64_t computation_id : proto.called_computation_ids()) {
962 instruction->called_computations_.push_back(
963 computation_map.at(computation_id));
964 }
965 }
966 TF_RET_CHECK(!proto.has_precision_config())
967 << instruction->opcode() << proto.DebugString();
968 TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
969 break;
970 }
971 }
972
973 for (const int64_t predecessor_id : proto.control_predecessor_ids()) {
974 TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
975 << "No instruction with id " << predecessor_id;
976 TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
977 ->AddControlDependencyTo(instruction.get()));
978 }
979
980 TF_RET_CHECK(!proto.name().empty());
981 instruction->SetAndSanitizeName(proto.name());
982 instruction->metadata_ = proto.metadata();
983 instruction->backend_config_ = proto.backend_config();
984 instruction->outer_dimension_partitions_.assign(
985 proto.outer_dimension_partitions().begin(),
986 proto.outer_dimension_partitions().end());
987
988 TF_RET_CHECK(proto.id() >= 0)
989 << "Instruction with negative id: " << proto.id();
990 TF_RET_CHECK(proto.id() <= INT_MAX)
991 << "Instruction with id > INT_MAX: " << proto.id();
992 instruction->unique_id_ = proto.id();
993
994 if (proto.has_sharding()) {
995 TF_ASSIGN_OR_RETURN(const auto& sharding,
996 HloSharding::FromProto(proto.sharding()));
997 instruction->set_sharding(sharding);
998 }
999
1000 if (proto.has_frontend_attributes()) {
1001 instruction->set_frontend_attributes(proto.frontend_attributes());
1002 }
1003
1004 return std::move(instruction);
1005 }
1006
CreateParameter(int64_t parameter_number,const Shape & shape,const std::string & name)1007 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
1008 int64_t parameter_number, const Shape& shape, const std::string& name) {
1009 return std::make_unique<HloParameterInstruction>(parameter_number, shape,
1010 name);
1011 }
1012
CreateConstant(Literal literal)1013 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
1014 Literal literal) {
1015 return std::make_unique<HloConstantInstruction>(std::move(literal));
1016 }
1017
CreateIota(const Shape & shape,int64_t iota_dimension)1018 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
1019 const Shape& shape, int64_t iota_dimension) {
1020 return std::make_unique<HloIotaInstruction>(shape, iota_dimension);
1021 }
1022
1023 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64_t index)1024 HloInstruction::CreateGetTupleElement(const Shape& shape,
1025 HloInstruction* operand, int64_t index) {
1026 return std::make_unique<HloGetTupleElementInstruction>(shape, operand, index);
1027 }
1028
1029 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(HloInstruction * operand,int64_t index)1030 HloInstruction::CreateGetTupleElement(HloInstruction* operand, int64_t index) {
1031 return std::make_unique<HloGetTupleElementInstruction>(
1032 operand->shape().tuple_shapes(index), operand, index);
1033 }
1034
CreateRng(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)1035 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
1036 const Shape& shape, RandomDistribution distribution,
1037 absl::Span<HloInstruction* const> parameters) {
1038 return std::make_unique<HloRngInstruction>(shape, distribution, parameters);
1039 }
1040
1041 /* static */ std::unique_ptr<HloInstruction>
CreateRngGetAndUpdateState(const Shape & shape,int64_t delta)1042 HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64_t delta) {
1043 return std::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
1044 }
1045
1046 /* static */ std::unique_ptr<HloInstruction>
CreateRngBitGenerator(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)1047 HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
1048 RandomAlgorithm algorithm) {
1049 return std::make_unique<HloRngBitGeneratorInstruction>(shape, state,
1050 algorithm);
1051 }
1052
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)1053 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
1054 const Shape& shape, HloOpcode opcode,
1055 absl::Span<HloInstruction* const> operands) {
1056 if (opcode == HloOpcode::kCopy) {
1057 // It is impossible to copy an opaque shape, we don't know how big it is.
1058 CHECK(!shape.IsOpaque());
1059 }
1060 auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
1061 for (auto operand : operands) {
1062 instruction->AppendOperand(operand);
1063 }
1064 return instruction;
1065 }
1066
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)1067 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
1068 const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
1069 // Only certain opcodes are supported with CreateUnary: opcodes of unary
1070 // instructions with no auxiliary fields.
1071 switch (opcode) {
1072 case HloOpcode::kAbs:
1073 case HloOpcode::kAllGatherDone:
1074 case HloOpcode::kAllReduceDone:
1075 case HloOpcode::kRoundNearestAfz:
1076 case HloOpcode::kRoundNearestEven:
1077 case HloOpcode::kBitcast:
1078 case HloOpcode::kCeil:
1079 case HloOpcode::kCollectivePermuteDone:
1080 case HloOpcode::kCopy:
1081 case HloOpcode::kCopyDone:
1082 case HloOpcode::kCos:
1083 case HloOpcode::kOptimizationBarrier:
1084 case HloOpcode::kClz:
1085 case HloOpcode::kExp:
1086 case HloOpcode::kExpm1:
1087 case HloOpcode::kFloor:
1088 case HloOpcode::kImag:
1089 case HloOpcode::kIsFinite:
1090 case HloOpcode::kLog:
1091 case HloOpcode::kLog1p:
1092 case HloOpcode::kNot:
1093 case HloOpcode::kNegate:
1094 case HloOpcode::kPopulationCount:
1095 case HloOpcode::kReal:
1096 case HloOpcode::kRsqrt:
1097 case HloOpcode::kLogistic:
1098 case HloOpcode::kSign:
1099 case HloOpcode::kSin:
1100 case HloOpcode::kSqrt:
1101 case HloOpcode::kCbrt:
1102 case HloOpcode::kTanh:
1103 break;
1104 default:
1105 LOG(FATAL) << "Invalid unary instruction opcode "
1106 << HloOpcodeString(opcode);
1107 }
1108 return CreateNary(shape, opcode, {operand});
1109 }
1110
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)1111 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
1112 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1113 HloInstruction* rhs) {
1114 // Only certain opcodes are supported with CreateBinary: opcodes of binary
1115 // instructions with no auxiliary fields.
1116 switch (opcode) {
1117 case HloOpcode::kAdd:
1118 case HloOpcode::kAtan2:
1119 case HloOpcode::kDivide:
1120 case HloOpcode::kComplex:
1121 case HloOpcode::kMaximum:
1122 case HloOpcode::kMinimum:
1123 case HloOpcode::kMultiply:
1124 case HloOpcode::kPower:
1125 case HloOpcode::kRemainder:
1126 case HloOpcode::kSubtract:
1127 case HloOpcode::kAnd:
1128 case HloOpcode::kOr:
1129 case HloOpcode::kXor:
1130 case HloOpcode::kShiftLeft:
1131 case HloOpcode::kShiftRightArithmetic:
1132 case HloOpcode::kShiftRightLogical:
1133 break;
1134 default:
1135 LOG(FATAL) << "Invalid binary instruction opcode "
1136 << HloOpcodeString(opcode);
1137 }
1138 return CreateNary(shape, opcode, {lhs, rhs});
1139 }
1140
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)1141 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
1142 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1143 HloInstruction* rhs, HloInstruction* ehs) {
1144 // Only certain opcodes are supported with CreateTernary: opcodes of ternary
1145 // instructions with no auxiliary fields.
1146 switch (opcode) {
1147 case HloOpcode::kClamp:
1148 case HloOpcode::kSelect:
1149 break;
1150 default:
1151 LOG(FATAL) << "Invalid ternary instruction opcode "
1152 << HloOpcodeString(opcode);
1153 }
1154 return CreateNary(shape, opcode, {lhs, rhs, ehs});
1155 }
1156
CreateVariadic(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)1157 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
1158 const Shape& shape, HloOpcode opcode,
1159 absl::Span<HloInstruction* const> operands) {
1160 CHECK_EQ(HloOpcode::kTuple, opcode);
1161 return CreateNary(shape, opcode, operands);
1162 }
1163
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1164 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
1165 const Shape& shape, absl::Span<HloInstruction* const> operands,
1166 HloComputation* map_computation) {
1167 return std::make_unique<HloMapInstruction>(shape, operands, map_computation);
1168 }
1169
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1170 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
1171 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1172 int64_t feature_group_count, int64_t batch_group_count,
1173 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
1174 const PrecisionConfig& precision_config) {
1175 return std::make_unique<HloConvolutionInstruction>(
1176 shape, lhs, rhs, feature_group_count, batch_group_count, window,
1177 dimension_numbers, precision_config);
1178 }
1179
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64_t> fft_length)1180 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
1181 const Shape& shape, HloInstruction* operand, FftType fft_type,
1182 absl::Span<const int64_t> fft_length) {
1183 return std::make_unique<HloFftInstruction>(shape, operand, fft_type,
1184 fft_length);
1185 }
1186
CreateAsyncStart(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)1187 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncStart(
1188 const Shape& shape, absl::Span<HloInstruction* const> operands,
1189 HloComputation* async_computation, std::optional<int64_t> async_group_id,
1190 absl::string_view async_execution_thread) {
1191 return std::make_unique<HloAsyncInstruction>(
1192 HloOpcode::kAsyncStart, shape, operands, async_computation,
1193 async_group_id, async_execution_thread);
1194 }
1195
CreateAsyncUpdate(const Shape & shape,HloInstruction * operand,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)1196 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncUpdate(
1197 const Shape& shape, HloInstruction* operand,
1198 HloComputation* async_computation, std::optional<int64_t> async_group_id,
1199 absl::string_view async_execution_thread) {
1200 return std::make_unique<HloAsyncInstruction>(
1201 HloOpcode::kAsyncUpdate, shape, operand, async_computation,
1202 async_group_id, async_execution_thread);
1203 }
1204
CreateAsyncDone(const Shape & shape,HloInstruction * operand,HloComputation * async_computation,std::optional<int64_t> async_group_id,absl::string_view async_execution_thread)1205 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncDone(
1206 const Shape& shape, HloInstruction* operand,
1207 HloComputation* async_computation, std::optional<int64_t> async_group_id,
1208 absl::string_view async_execution_thread) {
1209 return std::make_unique<HloAsyncInstruction>(
1210 HloOpcode::kAsyncDone, shape, operand, async_computation, async_group_id,
1211 async_execution_thread);
1212 }
1213
CreateCopyStart(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)1214 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
1215 const Shape& shape, HloInstruction* operand,
1216 bool is_cross_program_prefetch) {
1217 return std::make_unique<HloCopyStartInstruction>(shape, operand,
1218 is_cross_program_prefetch);
1219 }
1220
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,std::optional<Comparison::Type> type)1221 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
1222 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1223 ComparisonDirection direction, std::optional<Comparison::Type> type) {
1224 return std::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
1225 type);
1226 }
1227
1228 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)1229 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
1230 HloInstruction* b,
1231 const TriangularSolveOptions& options) {
1232 return std::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
1233 }
1234
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)1235 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
1236 const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
1237 return std::make_unique<HloCholeskyInstruction>(shape, a, options);
1238 }
1239
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1240 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
1241 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1242 const DotDimensionNumbers& dimension_numbers,
1243 const PrecisionConfig& precision_config) {
1244 return std::make_unique<HloDotInstruction>(shape, lhs, rhs, dimension_numbers,
1245 precision_config);
1246 }
1247
1248 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1249 HloInstruction::CreateReducePrecision(const Shape& shape,
1250 HloInstruction* operand,
1251 const int exponent_bits,
1252 const int mantissa_bits) {
1253 return std::make_unique<HloReducePrecisionInstruction>(
1254 shape, operand, exponent_bits, mantissa_bits);
1255 }
1256
CreateAllGather(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1257 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
1258 const Shape& shape, absl::Span<HloInstruction* const> operands,
1259 int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1260 bool constrain_layout, const std::optional<int64_t>& channel_id,
1261 bool use_global_device_ids) {
1262 return std::make_unique<HloAllGatherInstruction>(
1263 HloOpcode::kAllGather, shape, operands, all_gather_dimension,
1264 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1265 }
1266
1267 /* static */ std::unique_ptr<HloInstruction>
CreateAllGatherStart(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1268 HloInstruction::CreateAllGatherStart(
1269 const Shape& shape, absl::Span<HloInstruction* const> operands,
1270 int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1271 bool constrain_layout, const std::optional<int64_t>& channel_id,
1272 bool use_global_device_ids) {
1273 return std::make_unique<HloAllGatherInstruction>(
1274 HloOpcode::kAllGatherStart, shape, operands, all_gather_dimension,
1275 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1276 }
1277
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1278 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
1279 const Shape& shape, absl::Span<HloInstruction* const> operands,
1280 HloComputation* reduce_computation,
1281 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1282 const std::optional<int64_t>& channel_id, bool use_global_device_ids) {
1283 return std::make_unique<HloAllReduceInstruction>(
1284 HloOpcode::kAllReduce, shape, operands, reduce_computation,
1285 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1286 }
1287
1288 /* static */ std::unique_ptr<HloInstruction>
CreateReduceScatter(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)1289 HloInstruction::CreateReduceScatter(
1290 const Shape& shape, absl::Span<HloInstruction* const> operands,
1291 HloComputation* reduce_computation,
1292 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1293 const std::optional<int64_t>& channel_id, bool use_global_device_ids,
1294 int64_t scatter_dimension) {
1295 return std::make_unique<HloReduceScatterInstruction>(
1296 shape, operands, reduce_computation, replica_groups, constrain_layout,
1297 channel_id, use_global_device_ids, scatter_dimension);
1298 }
1299
1300 /* static */ std::unique_ptr<HloInstruction>
CreateAllReduceStart(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,bool use_global_device_ids)1301 HloInstruction::CreateAllReduceStart(
1302 const Shape& shape, absl::Span<HloInstruction* const> operands,
1303 HloComputation* reduce_computation,
1304 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1305 const std::optional<int64_t>& channel_id, bool use_global_device_ids) {
1306 return std::make_unique<HloAllReduceInstruction>(
1307 HloOpcode::kAllReduceStart, shape, operands, reduce_computation,
1308 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1309 }
1310
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const std::optional<int64_t> & channel_id,const std::optional<int64_t> & split_dimension)1311 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
1312 const Shape& shape, absl::Span<HloInstruction* const> operands,
1313 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1314 const std::optional<int64_t>& channel_id,
1315 const std::optional<int64_t>& split_dimension) {
1316 return std::make_unique<HloAllToAllInstruction>(
1317 shape, operands, replica_groups, constrain_layout, channel_id,
1318 split_dimension);
1319 }
1320
1321 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const std::optional<int64_t> & channel_id)1322 HloInstruction::CreateCollectivePermute(
1323 const Shape& shape, HloInstruction* operand,
1324 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1325 const std::optional<int64_t>& channel_id) {
1326 return std::make_unique<HloCollectivePermuteInstruction>(
1327 HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
1328 channel_id);
1329 }
1330
1331 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const std::optional<int64_t> & channel_id)1332 HloInstruction::CreateCollectivePermute(
1333 const Shape& shape, HloInstruction* input, HloInstruction* output,
1334 HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1335 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1336 absl::Span<const std::vector<int64_t>> slice_sizes,
1337 const std::optional<int64_t>& channel_id) {
1338 return std::make_unique<HloCollectivePermuteInstruction>(
1339 HloOpcode::kCollectivePermute, shape, input, output, input_start_indices,
1340 output_start_indices, source_target_pairs, slice_sizes, channel_id);
1341 }
1342
1343 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const std::optional<int64_t> & channel_id)1344 HloInstruction::CreateCollectivePermuteStart(
1345 const Shape& shape, HloInstruction* operand,
1346 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1347 const std::optional<int64_t>& channel_id) {
1348 return std::make_unique<HloCollectivePermuteInstruction>(
1349 HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
1350 channel_id);
1351 }
1352
1353 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const std::optional<int64_t> & channel_id)1354 HloInstruction::CreateCollectivePermuteStart(
1355 const Shape& shape, HloInstruction* input, HloInstruction* output,
1356 HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1357 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1358 absl::Span<const std::vector<int64_t>> slice_sizes,
1359 const std::optional<int64_t>& channel_id) {
1360 return std::make_unique<HloCollectivePermuteInstruction>(
1361 HloOpcode::kCollectivePermuteStart, shape, input, output,
1362 input_start_indices, output_start_indices, source_target_pairs,
1363 slice_sizes, channel_id);
1364 }
1365
CreateReplicaId(const Shape & shape)1366 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
1367 const Shape& shape) {
1368 CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1369 << "HloInstruction replica-id must have a shape of u32[], but "
1370 << shape.ToString() << " is specified";
1371 return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
1372 }
1373
CreatePartitionId(const Shape & shape)1374 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
1375 const Shape& shape) {
1376 CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1377 << "HloInstruction partition-id must have a shape of u32[], but "
1378 << shape.ToString() << " is specified";
1379 return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
1380 }
1381
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const std::string & config)1382 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
1383 const Shape& infeed_shape, HloInstruction* token_operand,
1384 const std::string& config) {
1385 return std::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
1386 config);
1387 }
1388
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1389 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
1390 const Shape& outfeed_shape, HloInstruction* operand,
1391 HloInstruction* token_operand, absl::string_view outfeed_config) {
1392 return std::make_unique<HloOutfeedInstruction>(outfeed_shape, operand,
1393 token_operand, outfeed_config);
1394 }
1395
CreateSend(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1396 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
1397 HloInstruction* operand, HloInstruction* token, int64_t channel_id,
1398 bool is_host_transfer) {
1399 return std::make_unique<HloSendInstruction>(operand, token, channel_id,
1400 is_host_transfer);
1401 }
1402
CreateSendDone(HloInstruction * operand,bool is_host_transfer)1403 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
1404 HloInstruction* operand, bool is_host_transfer) {
1405 auto send_operand = DynCast<HloSendInstruction>(operand);
1406 CHECK(send_operand != nullptr)
1407 << "SendDone must take the context operand from Send";
1408 return std::make_unique<HloSendDoneInstruction>(send_operand,
1409 is_host_transfer);
1410 }
1411
CreateRecv(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1412 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
1413 const Shape& shape, HloInstruction* token, int64_t channel_id,
1414 bool is_host_transfer) {
1415 return std::make_unique<HloRecvInstruction>(shape, token, channel_id,
1416 is_host_transfer);
1417 }
1418
CreateRecvDone(HloInstruction * operand,bool is_host_transfer)1419 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
1420 HloInstruction* operand, bool is_host_transfer) {
1421 auto recv_operand = DynCast<HloRecvInstruction>(operand);
1422 CHECK(recv_operand != nullptr)
1423 << "RecvDone must take the context operand from Recv";
1424 return std::make_unique<HloRecvDoneInstruction>(recv_operand,
1425 is_host_transfer);
1426 }
1427
CreateReverse(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1428 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
1429 const Shape& shape, HloInstruction* operand,
1430 absl::Span<const int64_t> dimensions) {
1431 return std::make_unique<HloReverseInstruction>(shape, operand, dimensions);
1432 }
1433
CreateAfterAll(absl::Span<HloInstruction * const> operands)1434 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1435 absl::Span<HloInstruction* const> operands) {
1436 CHECK(!operands.empty());
1437 auto instruction = absl::WrapUnique(
1438 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1439 for (auto operand : operands) {
1440 instruction->AppendOperand(operand);
1441 }
1442 return instruction;
1443 }
1444
CreateToken()1445 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
1446 return absl::WrapUnique(
1447 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1448 }
1449
1450 /* static */ std::unique_ptr<HloInstruction>
CreateAddDependency(HloInstruction * data_operand,HloInstruction * token_operand)1451 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
1452 HloInstruction* token_operand) {
1453 auto instruction = absl::WrapUnique(
1454 new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
1455 instruction->AppendOperand(data_operand);
1456 instruction->AppendOperand(token_operand);
1457 return instruction;
1458 }
1459
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)1460 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
1461 const Shape& shape, HloComputation* condition, HloComputation* body,
1462 HloInstruction* init) {
1463 auto instruction =
1464 absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
1465 instruction->AppendOperand(init);
1466 // Body comes before condition computation in the vector.
1467 instruction->called_computations_.push_back(body);
1468 instruction->called_computations_.push_back(condition);
1469 return instruction;
1470 }
1471
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)1472 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1473 const Shape& shape, HloInstruction* pred,
1474 HloInstruction* true_computation_arg, HloComputation* true_computation,
1475 HloInstruction* false_computation_arg, HloComputation* false_computation) {
1476 auto instruction =
1477 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1478 instruction->AppendOperand(pred);
1479 instruction->AppendOperand(true_computation_arg);
1480 instruction->AppendOperand(false_computation_arg);
1481 // In called_computations_, the index of true_computation must be 0 and that
1482 // of false computation must be 1, as defined by kTrueComputationIndex and
1483 // kFalseComputationIndex.
1484 instruction->called_computations_.push_back(true_computation);
1485 instruction->called_computations_.push_back(false_computation);
1486 return instruction;
1487 }
1488
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)1489 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1490 const Shape& shape, HloInstruction* branch_index,
1491 absl::Span<HloComputation* const> branch_computations,
1492 absl::Span<HloInstruction* const> branch_computation_args) {
1493 auto instruction =
1494 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1495 instruction->AppendOperand(branch_index);
1496 CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1497 for (int i = 0; i < branch_computations.size(); ++i) {
1498 instruction->called_computations_.push_back(branch_computations[i]);
1499 instruction->AppendOperand(branch_computation_args[i]);
1500 }
1501 return instruction;
1502 }
1503
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)1504 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1505 const Shape& shape, HloInstruction* operand,
1506 absl::Span<const int64_t> start_indices,
1507 absl::Span<const int64_t> limit_indices,
1508 absl::Span<const int64_t> strides) {
1509 return std::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1510 limit_indices, strides);
1511 }
1512
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64_t> slice_sizes)1513 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1514 const Shape& shape, HloInstruction* operand,
1515 absl::Span<HloInstruction* const> start_indices,
1516 absl::Span<const int64_t> slice_sizes) {
1517 return std::make_unique<HloDynamicSliceInstruction>(
1518 shape, operand, start_indices, slice_sizes);
1519 }
1520
1521 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)1522 HloInstruction::CreateDynamicUpdateSlice(
1523 const Shape& shape, HloInstruction* operand, HloInstruction* update,
1524 absl::Span<HloInstruction* const> start_indices) {
1525 return std::make_unique<HloDynamicUpdateSliceInstruction>(
1526 shape, operand, update, start_indices);
1527 }
1528
CreateConcatenate(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)1529 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1530 const Shape& shape, absl::Span<HloInstruction* const> operands,
1531 int64_t dimension) {
1532 return std::make_unique<HloConcatenateInstruction>(shape, operands,
1533 dimension);
1534 }
1535
CreateConvert(const Shape & shape,HloInstruction * operand)1536 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1537 const Shape& shape, HloInstruction* operand) {
1538 auto instruction =
1539 absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1540 instruction->AppendOperand(operand);
1541 return instruction;
1542 }
1543
1544 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)1545 HloInstruction::CreateBitcastConvert(const Shape& shape,
1546 HloInstruction* operand) {
1547 auto instruction =
1548 absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1549 instruction->AppendOperand(operand);
1550 return instruction;
1551 }
1552
CreateBitcast(const Shape & shape,HloInstruction * operand)1553 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
1554 const Shape& shape, HloInstruction* operand) {
1555 auto instruction =
1556 absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
1557 instruction->AppendOperand(operand);
1558 return instruction;
1559 }
1560
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1561 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1562 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1563 absl::Span<const int64_t> dimensions_to_reduce,
1564 HloComputation* reduce_computation) {
1565 auto instruction = absl::WrapUnique(new HloReduceInstruction(
1566 shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1567 return std::move(instruction);
1568 }
1569
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1570 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1571 const Shape& shape, absl::Span<HloInstruction* const> operands,
1572 absl::Span<HloInstruction* const> init_values,
1573 absl::Span<const int64_t> dimensions_to_reduce,
1574 HloComputation* reduce_computation) {
1575 std::vector<HloInstruction*> all_args;
1576 all_args.reserve(operands.size() * 2);
1577 all_args.insert(all_args.end(), operands.begin(), operands.end());
1578 all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1579 return std::make_unique<HloReduceInstruction>(
1580 shape, all_args, dimensions_to_reduce, reduce_computation);
1581 }
1582
CreateReduce(const Shape & shape,HloInstruction * tuple_of_instructions,absl::Span<HloInstruction * const> init_values,absl::Span<const int64_t> dimensions_to_reduce,HloComputation * reduce_computation)1583 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1584 const Shape& shape, HloInstruction* tuple_of_instructions,
1585 absl::Span<HloInstruction* const> init_values,
1586 absl::Span<const int64_t> dimensions_to_reduce,
1587 HloComputation* reduce_computation) {
1588 if (!tuple_of_instructions->shape().IsTuple()) {
1589 CHECK_EQ(init_values.size(), 1)
1590 << "The first input has to be a tuple, or the number of init values "
1591 "has to be one.";
1592 return CreateReduce(shape, tuple_of_instructions, init_values[0],
1593 dimensions_to_reduce, reduce_computation);
1594 }
1595 absl::InlinedVector<HloInstruction*, 4> inputs;
1596 for (int idx = 0; idx < tuple_of_instructions->shape().tuple_shapes_size();
1597 idx++) {
1598 std::unique_ptr<HloInstruction> gte =
1599 HloInstruction::CreateGetTupleElement(tuple_of_instructions, idx);
1600 inputs.push_back(
1601 tuple_of_instructions->parent()->AddInstruction(std::move(gte)));
1602 }
1603 return CreateReduce(shape, inputs, init_values, dimensions_to_reduce,
1604 reduce_computation);
1605 }
1606
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1607 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1608 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1609 const Window& window, HloComputation* reduce_computation) {
1610 return std::make_unique<HloReduceWindowInstruction>(
1611 shape, operand, init_value, window, reduce_computation);
1612 }
1613
CreateReduceWindow(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)1614 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1615 const Shape& shape, absl::Span<HloInstruction* const> operands,
1616 absl::Span<HloInstruction* const> init_values, const Window& window,
1617 HloComputation* reduce_computation) {
1618 return std::make_unique<HloReduceWindowInstruction>(
1619 shape, operands, init_values, window, reduce_computation);
1620 }
1621 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)1622 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1623 HloInstruction* operand,
1624 HloInstruction* scale,
1625 HloInstruction* offset, float epsilon,
1626 int64_t feature_index) {
1627 return std::make_unique<HloBatchNormTrainingInstruction>(
1628 shape, operand, scale, offset, epsilon, feature_index);
1629 }
1630
1631 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)1632 HloInstruction::CreateBatchNormInference(
1633 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1634 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1635 float epsilon, int64_t feature_index) {
1636 return std::make_unique<HloBatchNormInferenceInstruction>(
1637 shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1638 }
1639
1640 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)1641 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1642 HloInstruction* scale, HloInstruction* mean,
1643 HloInstruction* variance,
1644 HloInstruction* grad_output, float epsilon,
1645 int64_t feature_index) {
1646 return std::make_unique<HloBatchNormGradInstruction>(
1647 shape, operand, scale, mean, variance, grad_output, epsilon,
1648 feature_index);
1649 }
1650
1651 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1652 HloInstruction::CreateSelectAndScatter(
1653 const Shape& shape, HloInstruction* operand, HloComputation* select,
1654 const Window& window, HloInstruction* source, HloInstruction* init_value,
1655 HloComputation* scatter) {
1656 return std::make_unique<HloSelectAndScatterInstruction>(
1657 shape, operand, select, window, source, init_value, scatter);
1658 }
1659
CreateBroadcast(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> broadcast_dimensions)1660 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1661 const Shape& shape, HloInstruction* operand,
1662 absl::Span<const int64_t> broadcast_dimensions) {
1663 return std::make_unique<HloBroadcastInstruction>(shape, operand,
1664 broadcast_dimensions);
1665 }
1666
1667 /* static */ std::unique_ptr<HloInstruction>
CreateGetDimensionSize(const Shape & shape,HloInstruction * operand,int64_t dimension)1668 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1669 HloInstruction* operand,
1670 int64_t dimension) {
1671 return std::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1672 dimension);
1673 }
1674
1675 /* static */ std::unique_ptr<HloInstruction>
CreateSetDimensionSize(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)1676 HloInstruction::CreateSetDimensionSize(const Shape& shape,
1677 HloInstruction* operand,
1678 HloInstruction* val, int64_t dimension) {
1679 return std::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
1680 dimension);
1681 }
1682
1683 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1684 HloInstruction::CreateBroadcastSequence(
1685 const Shape& output_shape, HloInstruction* operand,
1686 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1687 adder) {
1688 CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1689 operand->shape().rank() == output_shape.rank());
1690 Shape broadcast_shape = ShapeUtil::ChangeElementType(
1691 output_shape, operand->shape().element_type());
1692 // Do explicit broadcast for scalar.
1693 if (ShapeUtil::IsScalar(operand->shape())) {
1694 auto broadcast =
1695 HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1696 broadcast->set_metadata(operand->metadata());
1697 if (operand->has_sharding()) {
1698 broadcast->set_sharding(operand->sharding());
1699 }
1700 broadcast->set_frontend_attributes(operand->frontend_attributes());
1701 return broadcast;
1702 }
1703 // Do explicit broadcast for degenerate broadcast.
1704 std::vector<int64_t> broadcast_dimensions;
1705 std::vector<int64_t> reshaped_dimensions;
1706 for (int i = 0; i < operand->shape().rank(); i++) {
1707 if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1708 broadcast_dimensions.push_back(i);
1709 reshaped_dimensions.push_back(operand->shape().dimensions(i));
1710 } else {
1711 CHECK_EQ(operand->shape().dimensions(i), 1)
1712 << "An explicit broadcast sequence requires the broadcasted "
1713 "dimensions to be trivial; operand: "
1714 << operand->ToString() << "; output_shape: " << output_shape;
1715 }
1716 }
1717 // Eliminate the size one dimensions.
1718 HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1719 ShapeUtil::MakeShape(operand->shape().element_type(),
1720 reshaped_dimensions),
1721 operand));
1722 reshaped_operand->set_metadata(operand->metadata());
1723 if (operand->has_sharding()) {
1724 reshaped_operand->set_sharding(operand->sharding());
1725 }
1726 reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1727 // Broadcast 'reshape' up to the larger size.
1728 auto broadcast = HloInstruction::CreateBroadcast(
1729 broadcast_shape, reshaped_operand, broadcast_dimensions);
1730 broadcast->set_metadata(operand->metadata());
1731 if (operand->has_sharding()) {
1732 broadcast->set_sharding(operand->sharding());
1733 }
1734 broadcast->set_frontend_attributes(operand->frontend_attributes());
1735 return broadcast;
1736 }
1737
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1738 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1739 const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1740 const PaddingConfig& padding_config) {
1741 return std::make_unique<HloPadInstruction>(shape, operand, padding_value,
1742 padding_config);
1743 }
1744
CreateReshape(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1745 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1746 const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) {
1747 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1748 ShapeUtil::ElementsIn(operand->shape()))
1749 << "shape: " << ShapeUtil::HumanString(shape)
1750 << " operand: " << ShapeUtil::HumanString(operand->shape());
1751
1752 return std::make_unique<HloReshapeInstruction>(shape, operand,
1753 inferred_dimension);
1754 }
1755
1756 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicReshape(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1757 HloInstruction::CreateDynamicReshape(
1758 const Shape& shape, HloInstruction* data_operand,
1759 absl::Span<HloInstruction* const> dim_sizes) {
1760 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1761 ShapeUtil::ElementsIn(data_operand[0].shape()))
1762 << "shape: " << ShapeUtil::HumanString(shape)
1763 << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
1764 CHECK_EQ(shape.rank(), dim_sizes.size());
1765 return std::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
1766 dim_sizes);
1767 }
1768
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64_t> dimensions)1769 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1770 const Shape& shape, HloInstruction* operand,
1771 absl::Span<const int64_t> dimensions) {
1772 return std::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1773 }
1774
CreateSort(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1775 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1776 const Shape& shape, int64_t dimension,
1777 absl::Span<HloInstruction* const> operands, HloComputation* compare,
1778 bool is_stable) {
1779 return std::make_unique<HloSortInstruction>(shape, dimension, operands,
1780 compare, is_stable);
1781 }
1782
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1783 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1784 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1785 return std::make_unique<HloFusionInstruction>(shape, fusion_kind, fused_root);
1786 }
1787
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation,absl::string_view prefix)1788 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1789 const Shape& shape, FusionKind fusion_kind,
1790 absl::Span<HloInstruction* const> operands,
1791 HloComputation* fusion_computation, absl::string_view prefix) {
1792 return std::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1793 fusion_computation, prefix);
1794 }
1795
set_single_sharding(const HloSharding & sharding)1796 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1797 CHECK(!sharding.IsTuple()) << sharding;
1798 if (shape().IsTuple()) {
1799 set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1800 } else {
1801 set_sharding(sharding);
1802 }
1803 }
1804
SetupDerivedInstruction(HloInstruction * derived_instruction) const1805 void HloInstruction::SetupDerivedInstruction(
1806 HloInstruction* derived_instruction) const {
1807 if (sharding_ != nullptr &&
1808 ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
1809 // Only copy sharding if the tuple tree shape of the two instruction is
1810 // compatible because copying it between differently shaped instructions
1811 // can produce invalid shardings.
1812 derived_instruction->set_sharding(*sharding_);
1813 } else {
1814 derived_instruction->clear_sharding();
1815 }
1816 derived_instruction->set_metadata(metadata_);
1817 derived_instruction->set_frontend_attributes(frontend_attributes_);
1818 }
1819
IsRoot() const1820 bool HloInstruction::IsRoot() const {
1821 return parent_ != nullptr && this == parent_->root_instruction();
1822 }
1823
HasSideEffectNoRecurse() const1824 bool HloInstruction::HasSideEffectNoRecurse() const {
1825 switch (opcode_) {
1826 case HloOpcode::kSend:
1827 case HloOpcode::kSendDone:
1828 case HloOpcode::kRecv:
1829 case HloOpcode::kRecvDone:
1830 case HloOpcode::kRng:
1831 case HloOpcode::kRngGetAndUpdateState:
1832 case HloOpcode::kInfeed:
1833 case HloOpcode::kOutfeed:
1834 case HloOpcode::kAllReduceStart:
1835 case HloOpcode::kAllReduceDone:
1836 case HloOpcode::kAllGatherStart:
1837 case HloOpcode::kAllGatherDone:
1838 case HloOpcode::kCollectivePermuteStart:
1839 case HloOpcode::kCollectivePermuteDone:
1840 return true;
1841 case HloOpcode::kAllReduce:
1842 return channel_id().has_value() ||
1843 Cast<HloAllReduceInstruction>(this)->constrain_layout();
1844 case HloOpcode::kAllToAll:
1845 return Cast<HloAllToAllInstruction>(this)->constrain_layout();
1846 case HloOpcode::kCustomCall:
1847 return Cast<HloCustomCallInstruction>(this)
1848 ->custom_call_has_side_effect();
1849 default:
1850 return false;
1851 }
1852 }
1853
HasSideEffect() const1854 bool HloInstruction::HasSideEffect() const {
1855 if (HasSideEffectNoRecurse()) {
1856 return true;
1857 }
1858 // Check if any of the called computations has a side effect.
1859 for (const auto& computation : called_computations()) {
1860 if (computation->HasSideEffect()) {
1861 return true;
1862 }
1863 }
1864 return false;
1865 }
1866
CreateCall(const Shape & shape,HloInstruction * called_computation_root)1867 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1868 const Shape& shape, HloInstruction* called_computation_root) {
1869 return std::make_unique<HloCallInstruction>(shape, called_computation_root);
1870 }
1871
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1872 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1873 const Shape& shape, absl::Span<HloInstruction* const> operands,
1874 HloComputation* computation) {
1875 return std::make_unique<HloCallInstruction>(shape, operands, computation);
1876 }
1877
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)1878 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1879 const Shape& shape, absl::Span<HloInstruction* const> operands,
1880 absl::string_view custom_call_target, std::string opaque,
1881 CustomCallApiVersion api_version) {
1882 return std::make_unique<HloCustomCallInstruction>(
1883 shape, operands, custom_call_target, std::move(opaque), api_version);
1884 }
1885
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)1886 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1887 const Shape& shape, absl::Span<HloInstruction* const> operands,
1888 HloComputation* to_apply, absl::string_view custom_call_target,
1889 std::string opaque, CustomCallApiVersion api_version) {
1890 return std::make_unique<HloCustomCallInstruction>(
1891 shape, operands, to_apply, custom_call_target, std::move(opaque),
1892 api_version);
1893 }
1894
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,std::string opaque,CustomCallApiVersion api_version)1895 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1896 const Shape& shape, absl::Span<HloInstruction* const> operands,
1897 absl::Span<HloComputation* const> called_computations,
1898 absl::string_view custom_call_target, std::string opaque,
1899 CustomCallApiVersion api_version) {
1900 return std::make_unique<HloCustomCallInstruction>(
1901 shape, operands, called_computations, custom_call_target,
1902 std::move(opaque), api_version);
1903 }
1904
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::Span<const Shape> operand_shapes_with_layout,std::string opaque,CustomCallApiVersion api_version)1905 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1906 const Shape& shape, absl::Span<HloInstruction* const> operands,
1907 absl::string_view custom_call_target,
1908 absl::Span<const Shape> operand_shapes_with_layout, std::string opaque,
1909 CustomCallApiVersion api_version) {
1910 return std::make_unique<HloCustomCallInstruction>(
1911 shape, operands, custom_call_target, std::move(opaque),
1912 operand_shapes_with_layout, api_version);
1913 }
1914
CreateTuple(absl::Span<HloInstruction * const> elements)1915 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1916 absl::Span<HloInstruction* const> elements) {
1917 std::vector<const Shape*> element_shapes;
1918 element_shapes.reserve(elements.size());
1919 for (auto element : elements) {
1920 element_shapes.push_back(&element->shape());
1921 }
1922 Shape tuple_shape = ShapeUtil::MakeTupleShapeWithPtrs(element_shapes);
1923 return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1924 }
1925
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64_t> slice_sizes,bool indices_are_sorted)1926 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1927 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1928 const GatherDimensionNumbers& gather_dim_numbers,
1929 absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
1930 return std::make_unique<HloGatherInstruction>(shape, operand, start_indices,
1931 gather_dim_numbers, slice_sizes,
1932 indices_are_sorted);
1933 }
1934
CreateScatter(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1935 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1936 const Shape& shape, HloInstruction* operand,
1937 HloInstruction* scatter_indices, HloInstruction* updates,
1938 HloComputation* update_computation,
1939 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1940 bool unique_indices) {
1941 return absl::WrapUnique(new HloScatterInstruction(
1942 shape, {operand, scatter_indices, updates}, update_computation,
1943 scatter_dim_numbers, indices_are_sorted, unique_indices));
1944 }
1945
CreateScatter(const Shape & shape,absl::Span<HloInstruction * const> operands,HloInstruction * scatter_indices,absl::Span<HloInstruction * const> updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1946 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1947 const Shape& shape, absl::Span<HloInstruction* const> operands,
1948 HloInstruction* scatter_indices, absl::Span<HloInstruction* const> updates,
1949 HloComputation* update_computation,
1950 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1951 bool unique_indices) {
1952 absl::InlinedVector<HloInstruction*, 3> args;
1953 args.reserve(operands.size() + updates.size() + 1);
1954 absl::c_copy(operands, std::back_inserter(args));
1955 args.push_back(scatter_indices);
1956 absl::c_copy(updates, std::back_inserter(args));
1957 return std::make_unique<HloScatterInstruction>(
1958 shape, args, update_computation, scatter_dim_numbers, indices_are_sorted,
1959 unique_indices);
1960 }
1961
CreateDomain(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)1962 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1963 const Shape& shape, HloInstruction* operand,
1964 std::unique_ptr<DomainMetadata> operand_side_metadata,
1965 std::unique_ptr<DomainMetadata> user_side_metadata) {
1966 return std::make_unique<HloDomainInstruction>(
1967 shape, operand, std::move(operand_side_metadata),
1968 std::move(user_side_metadata));
1969 }
1970
CloneWithNewOperands(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1971 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1972 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1973 HloCloneContext* context) const {
1974 VLOG(3) << "CloneWithNewOperands:\n " << ToString();
1975 VLOG(3) << " new operands:";
1976 for (const HloInstruction* new_operand : new_operands) {
1977 VLOG(3) << " %" << new_operand->name();
1978 }
1979
1980 std::unique_ptr<HloInstruction> clone;
1981 // Explicitly call the factory for the instruction type. This is more robust
1982 // in the face of code changes than copying fields explicitly. This also
1983 // properly sets the user fields of the operands.
1984 switch (opcode_) {
1985 // Ops migrated to subclasses.
1986 // TODO(b/80131774): Remove this switch when migration is complete.
1987 case HloOpcode::kBatchNormTraining:
1988 case HloOpcode::kBatchNormInference:
1989 case HloOpcode::kBatchNormGrad:
1990 case HloOpcode::kFft:
1991 case HloOpcode::kCompare:
1992 case HloOpcode::kAsyncStart:
1993 case HloOpcode::kAsyncUpdate:
1994 case HloOpcode::kAsyncDone:
1995 case HloOpcode::kCopyStart:
1996 case HloOpcode::kSend:
1997 case HloOpcode::kSendDone:
1998 case HloOpcode::kRecv:
1999 case HloOpcode::kRecvDone:
2000 case HloOpcode::kReverse:
2001 case HloOpcode::kConcatenate:
2002 case HloOpcode::kReduce:
2003 case HloOpcode::kTranspose:
2004 case HloOpcode::kBroadcast:
2005 case HloOpcode::kReshape:
2006 case HloOpcode::kDynamicReshape:
2007 case HloOpcode::kMap:
2008 case HloOpcode::kSlice:
2009 case HloOpcode::kConstant:
2010 case HloOpcode::kFusion:
2011 case HloOpcode::kRng:
2012 case HloOpcode::kRngBitGenerator:
2013 case HloOpcode::kRngGetAndUpdateState:
2014 case HloOpcode::kParameter:
2015 case HloOpcode::kGetTupleElement:
2016 case HloOpcode::kReducePrecision:
2017 case HloOpcode::kAllGather:
2018 case HloOpcode::kAllGatherStart:
2019 case HloOpcode::kAllReduce:
2020 case HloOpcode::kReduceScatter:
2021 case HloOpcode::kAllReduceStart:
2022 case HloOpcode::kAllToAll:
2023 case HloOpcode::kCollectivePermute:
2024 case HloOpcode::kCollectivePermuteStart:
2025 case HloOpcode::kInfeed:
2026 case HloOpcode::kOutfeed:
2027 case HloOpcode::kConvolution:
2028 case HloOpcode::kCustomCall:
2029 case HloOpcode::kReduceWindow:
2030 case HloOpcode::kSelectAndScatter:
2031 case HloOpcode::kPad:
2032 case HloOpcode::kDynamicSlice:
2033 case HloOpcode::kSort:
2034 case HloOpcode::kGather:
2035 case HloOpcode::kScatter:
2036 case HloOpcode::kIota:
2037 case HloOpcode::kDot:
2038 case HloOpcode::kDomain:
2039 case HloOpcode::kGetDimensionSize:
2040 case HloOpcode::kSetDimensionSize:
2041 case HloOpcode::kTriangularSolve:
2042 case HloOpcode::kCholesky:
2043 clone = CloneWithNewOperandsImpl(shape, new_operands, context);
2044 break;
2045 // Unary ops.
2046 case HloOpcode::kAbs:
2047 case HloOpcode::kAllGatherDone:
2048 case HloOpcode::kAllReduceDone:
2049 case HloOpcode::kRoundNearestAfz:
2050 case HloOpcode::kRoundNearestEven:
2051 case HloOpcode::kBitcast:
2052 case HloOpcode::kCeil:
2053 case HloOpcode::kClz:
2054 case HloOpcode::kCollectivePermuteDone:
2055 case HloOpcode::kCopy:
2056 case HloOpcode::kOptimizationBarrier:
2057 case HloOpcode::kCopyDone:
2058 case HloOpcode::kCos:
2059 case HloOpcode::kExp:
2060 case HloOpcode::kExpm1:
2061 case HloOpcode::kImag:
2062 case HloOpcode::kIsFinite:
2063 case HloOpcode::kFloor:
2064 case HloOpcode::kLog:
2065 case HloOpcode::kLog1p:
2066 case HloOpcode::kNot:
2067 case HloOpcode::kNegate:
2068 case HloOpcode::kPopulationCount:
2069 case HloOpcode::kReal:
2070 case HloOpcode::kRsqrt:
2071 case HloOpcode::kLogistic:
2072 case HloOpcode::kSign:
2073 case HloOpcode::kSin:
2074 case HloOpcode::kSqrt:
2075 case HloOpcode::kCbrt:
2076 case HloOpcode::kTanh:
2077 CHECK_EQ(new_operands.size(), 1);
2078 clone = CreateUnary(shape, opcode_, new_operands[0]);
2079 break;
2080 // Binary ops.
2081 case HloOpcode::kAdd:
2082 case HloOpcode::kAtan2:
2083 case HloOpcode::kComplex:
2084 case HloOpcode::kDivide:
2085 case HloOpcode::kMultiply:
2086 case HloOpcode::kSubtract:
2087 case HloOpcode::kMaximum:
2088 case HloOpcode::kMinimum:
2089 case HloOpcode::kPower:
2090 case HloOpcode::kRemainder:
2091 case HloOpcode::kAnd:
2092 case HloOpcode::kOr:
2093 case HloOpcode::kXor:
2094 case HloOpcode::kShiftLeft:
2095 case HloOpcode::kShiftRightArithmetic:
2096 case HloOpcode::kShiftRightLogical:
2097 CHECK_EQ(new_operands.size(), 2);
2098 clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
2099 break;
2100 // Ternary ops.
2101 case HloOpcode::kClamp:
2102 case HloOpcode::kSelect:
2103 CHECK_EQ(new_operands.size(), 3);
2104 clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
2105 new_operands[2]);
2106 break;
2107 // Other supported ops.
2108 case HloOpcode::kCall:
2109 clone = CreateCall(shape, new_operands, to_apply());
2110 break;
2111 case HloOpcode::kConvert:
2112 CHECK_EQ(new_operands.size(), 1);
2113 clone = CreateConvert(shape, new_operands[0]);
2114 break;
2115 case HloOpcode::kBitcastConvert:
2116 CHECK_EQ(new_operands.size(), 1);
2117 clone = CreateBitcastConvert(shape, new_operands[0]);
2118 break;
2119 case HloOpcode::kDynamicUpdateSlice:
2120 clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
2121 new_operands.subspan(2));
2122 break;
2123 case HloOpcode::kTuple:
2124 clone = CreateTuple(new_operands);
2125 *clone->mutable_shape() = shape;
2126 break;
2127 case HloOpcode::kWhile:
2128 CHECK_EQ(new_operands.size(), 1);
2129 clone =
2130 CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
2131 break;
2132 case HloOpcode::kConditional:
2133 CHECK_EQ(new_operands.size(), branch_count() + 1);
2134 clone = CreateConditional(shape, new_operands[0],
2135 absl::MakeSpan(branch_computations()),
2136 new_operands.subspan(1));
2137 break;
2138 case HloOpcode::kAfterAll:
2139 if (new_operands.empty()) {
2140 clone = CreateToken();
2141 } else {
2142 clone = CreateAfterAll(new_operands);
2143 }
2144 break;
2145 case HloOpcode::kAddDependency:
2146 CHECK_EQ(new_operands.size(), 2);
2147 clone = CreateAddDependency(new_operands[0], new_operands[1]);
2148 break;
2149 case HloOpcode::kReplicaId:
2150 CHECK_EQ(new_operands.size(), 0);
2151 clone = CreateReplicaId(shape);
2152 break;
2153 case HloOpcode::kPartitionId:
2154 CHECK_EQ(new_operands.size(), 0);
2155 clone = CreatePartitionId(shape);
2156 break;
2157 }
2158 // SetupDerivedInstruction will setup the precision_config_ field.
2159 SetupDerivedInstruction(clone.get());
2160 clone->set_parent(parent_);
2161 clone->set_outer_dimension_partitions(outer_dimension_partitions_);
2162 clone->backend_config_ = backend_config_.Clone();
2163 // The new instruction's name will be uniquified when it's added to a
2164 // computation.
2165 clone->SetAndSanitizeName(name());
2166 if (context != nullptr) {
2167 context->MapInstruction(this, clone.get());
2168 clone->ReplaceCalledComputations([&](HloComputation* callee) {
2169 return callee->parent() != context->module()
2170 ? context->module()->DeepCloneComputation(callee, context)
2171 : callee;
2172 });
2173 }
2174 return clone;
2175 }
2176
DetachFromOperandsAndUsers()2177 void HloInstruction::DetachFromOperandsAndUsers() {
2178 if (cleaned_up_) {
2179 return;
2180 }
2181 cleaned_up_ = true;
2182 // Detach from operands. An instruction may be repeated as an operand. To
2183 // avoid calling RemoveUser twice on the same operand, check before remove.
2184 for (int64_t operand_num = 0; operand_num < operand_count(); ++operand_num) {
2185 HloInstruction* operand = operands_[operand_num];
2186 if (operand == nullptr) {
2187 continue;
2188 }
2189 if (operand->user_map_.find(this) != operand->user_map_.end()) {
2190 operand->RemoveUser(this);
2191 }
2192 operands_[operand_num] = nullptr;
2193 }
2194
2195 // Update users. Set `nullptr` to the corresponding operand slot for users.
2196 for (auto& user : this->users()) {
2197 for (int i = 0; i < user->operand_count(); ++i) {
2198 if (user->operands_[i] == this) {
2199 user->operands_[i] = nullptr;
2200 }
2201 }
2202 }
2203 }
2204
CloneWithNewShape(const Shape & shape,const std::string & suffix,HloCloneContext * context) const2205 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
2206 const Shape& shape, const std::string& suffix,
2207 HloCloneContext* context) const {
2208 std::unique_ptr<HloInstruction> clone =
2209 CloneWithNewOperands(shape, operands_, context);
2210 if (suffix.empty()) {
2211 clone->name_ = name();
2212 } else {
2213 // If an instruction is cloned multiple times avoid names like
2214 // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
2215 // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
2216 // clone of foo.suffix2 is named foo.suffix3 and so on.
2217 const std::string dot_suffix = "." + suffix;
2218 size_t index = name().rfind(dot_suffix);
2219 if (index == std::string::npos) {
2220 // Existing name does not include ".suffix".
2221 clone->name_ = name() + dot_suffix;
2222 } else {
2223 // Existing name includes ".suffix". Determine if substring after
2224 // ".suffix" is numeric and should be replaced with an incremented number.
2225 std::string after_suffix = name().substr(index + dot_suffix.size());
2226 if (after_suffix.empty()) {
2227 // Existing name ends in ".suffix". New name should end in ".suffix2".
2228 clone->name_ = name() + "2";
2229 } else {
2230 // If names ends with .suffix[0-9]+ then replace with a suffix with the
2231 // numeric value incremented.
2232 int64_t numeric_suffix;
2233 if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
2234 clone->name_ =
2235 StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
2236 } else {
2237 // Substring after ".suffix" is non-numeric.
2238 clone->name_ = name() + dot_suffix;
2239 }
2240 }
2241 }
2242 }
2243 return clone;
2244 }
2245
Clone(const std::string & suffix,HloCloneContext * context) const2246 std::unique_ptr<HloInstruction> HloInstruction::Clone(
2247 const std::string& suffix, HloCloneContext* context) const {
2248 std::unique_ptr<HloInstruction> clone =
2249 CloneWithNewShape(shape_, suffix, context);
2250 return clone;
2251 }
2252
2253 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const2254 HloInstruction::LatestNonGteAncestorAndIndex() const {
2255 const HloInstruction* hlo = this;
2256 ShapeIndex index;
2257 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2258 index.push_back(hlo->tuple_index());
2259 hlo = hlo->operand(0);
2260 }
2261
2262 // We built up index in the reverse order from what we want.
2263 std::reverse(index.begin(), index.end());
2264
2265 return {hlo, index};
2266 }
2267
LatestNonGteAncestor() const2268 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
2269 const HloInstruction* hlo = this;
2270 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2271 hlo = hlo->operand(0);
2272 }
2273 return hlo;
2274 }
2275
operand(int64_t i) const2276 const HloInstruction* HloInstruction::operand(int64_t i) const {
2277 return operands_.at(i);
2278 }
2279
mutable_operand(int64_t i)2280 HloInstruction* HloInstruction::mutable_operand(int64_t i) {
2281 CHECK(operands_[i] != nullptr);
2282 return operands_.at(i);
2283 }
2284
operand_index(const HloInstruction * target) const2285 int64_t HloInstruction::operand_index(const HloInstruction* target) const {
2286 for (int64_t i = 0; i < operand_count(); ++i) {
2287 if (target == operand(i)) {
2288 return i;
2289 }
2290 }
2291 LOG(FATAL) << "target was not an operand: " << target->ToString();
2292 }
2293
unique_operands() const2294 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
2295 InstructionVector unique;
2296 absl::flat_hash_set<const HloInstruction*> seen;
2297 for (HloInstruction* operand : operands()) {
2298 if (seen.insert(operand).second) {
2299 unique.push_back(operand);
2300 }
2301 }
2302 return unique;
2303 }
2304
AddControlDependencyTo(HloInstruction * instruction)2305 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
2306 TF_RET_CHECK(instruction->parent() == parent());
2307 if (!absl::c_linear_search(control_successors_, instruction)) {
2308 control_successors_.push_back(instruction);
2309 TF_RET_CHECK(
2310 !absl::c_linear_search(instruction->control_predecessors_, this));
2311 instruction->control_predecessors_.push_back(this);
2312 }
2313 return OkStatus();
2314 }
2315
// Removes the control edge this -> `instruction`: erases `instruction` from
// this instruction's successors and `this` from its predecessors. Both must
// share a parent computation. Any error from EraseElementFromVector is
// propagated (presumably when the element is absent — confirm against its
// definition).
Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
  TF_RETURN_IF_ERROR(
      EraseElementFromVector(&instruction->control_predecessors_, this));
  return OkStatus();
}
2323
DropAllControlDeps()2324 Status HloInstruction::DropAllControlDeps() {
2325 for (auto* ctrl_succ : control_successors_) {
2326 TF_RETURN_IF_ERROR(
2327 EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
2328 }
2329 for (auto* ctrl_pred : control_predecessors_) {
2330 TF_RETURN_IF_ERROR(
2331 EraseElementFromVector(&ctrl_pred->control_successors_, this));
2332 }
2333 control_successors_.clear();
2334 control_predecessors_.clear();
2335 return OkStatus();
2336 }
2337
CopyAllControlDepsFrom(const HloInstruction * inst)2338 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
2339 for (auto* ctrl_pred : inst->control_predecessors()) {
2340 TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
2341 }
2342
2343 for (auto* ctrl_succ : inst->control_successors()) {
2344 TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
2345 }
2346
2347 return OkStatus();
2348 }
2349
IdenticalInternal(const HloInstruction & other,const std::function<bool (const HloInstruction *,const HloInstruction *)> & eq_operands,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations,bool layout_sensitive,bool ignore_channel_id_values,bool ignore_commutative_operand_order) const2350 bool HloInstruction::IdenticalInternal(
2351 const HloInstruction& other,
2352 const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2353 eq_operands,
2354 const std::function<bool(const HloComputation*, const HloComputation*)>&
2355 eq_computations,
2356 bool layout_sensitive, bool ignore_channel_id_values,
2357 bool ignore_commutative_operand_order) const {
2358 // An instruction is always identical to itself.
2359 if (this == &other) {
2360 return true;
2361 }
2362
2363 // Identical instruction must have the same opcode, shape, and identical
2364 // operands.
2365 if (opcode() != other.opcode()) {
2366 return false;
2367 }
2368 if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
2369 : ShapeUtil::Compatible(shape(), other.shape()))) {
2370 return false;
2371 }
2372 if (operands().size() != other.operands().size()) {
2373 return false;
2374 }
2375
2376 // Check that operands are equal.
2377 //
2378 // Use an explicit loop rather than ContainerEquals, because copying around
2379 // std::functions may be too expensive in some cases.
2380 if (ignore_commutative_operand_order &&
2381 HloOpcodeIsBinaryCommutative(opcode())) {
2382 CHECK_EQ(operand_count(), 2);
2383 if (!(eq_operands(operand(0), other.operand(0)) &&
2384 eq_operands(operand(1), other.operand(1))) &&
2385 !(eq_operands(operand(0), other.operand(1)) &&
2386 eq_operands(operand(1), other.operand(0)))) {
2387 return false;
2388 }
2389 } else {
2390 for (size_t i = 0; i < operands().size(); ++i) {
2391 if (!eq_operands(operand(i), other.operand(i))) {
2392 return false;
2393 }
2394 }
2395 }
2396
2397 if (backend_config_ != other.backend_config_) {
2398 return false;
2399 }
2400
2401 if (ignore_channel_id_values) {
2402 if (auto channel_inst = DynCast<HloChannelInstruction>(this)) {
2403 return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
2404 other, eq_computations);
2405 }
2406 }
2407 return IdenticalSlowPath(other, eq_computations);
2408 }
2409
AppendOperand(HloInstruction * operand)2410 void HloInstruction::AppendOperand(HloInstruction* operand) {
2411 if (operand->parent() != nullptr) {
2412 DCHECK(!operand->parent()->IsMarkedAsDead(operand))
2413 << "Operand " << operand->name() << " is already marked dead";
2414 }
2415 operands_.push_back(operand);
2416 operand->AddUser(this);
2417 }
2418
RemoveOperandsAtAscendingIndices(absl::Span<const int> ascending_indices)2419 void HloInstruction::RemoveOperandsAtAscendingIndices(
2420 absl::Span<const int> ascending_indices) {
2421 if (ascending_indices.empty()) {
2422 return;
2423 }
2424 int next_index = 0;
2425 int removed_count = 0;
2426 for (int to_remove : ascending_indices) {
2427 while (next_index < to_remove) {
2428 operands_[next_index - removed_count] = operands_[next_index];
2429 ++next_index;
2430 }
2431 CHECK_LT(to_remove, operands_.size());
2432 ++removed_count;
2433 ++next_index;
2434 }
2435 while (next_index < operands_.size()) {
2436 operands_[next_index - removed_count] = operands_[next_index];
2437 ++next_index;
2438 }
2439 CHECK_EQ(removed_count, ascending_indices.size());
2440 operands_.resize(operands_.size() - removed_count);
2441 }
2442
AddUser(HloInstruction * user)2443 void HloInstruction::AddUser(HloInstruction* user) {
2444 if (!ContainsKey(user_map_, user)) {
2445 user_map_.emplace(user, users_.size());
2446 users_.push_back(user);
2447 }
2448 }
2449
// Returns the position of `user` in users_, as recorded in user_map_.
// CHECK-fails if `user` is not a user of this instruction. Note that
// RemoveUser moves the last user into a freed slot, so this position can
// change when other users are removed.
int64_t HloInstruction::UserId(HloInstruction* user) {
  auto result = user_map_.find(user);
  CHECK(result != user_map_.end());
  return result->second;
}
2455
HasConstantOperand() const2456 bool HloInstruction::HasConstantOperand() const {
2457 for (const HloInstruction* operand : operands_) {
2458 if (operand->IsConstant()) {
2459 return true;
2460 }
2461 }
2462 return false;
2463 }
2464
// Opcode-specific part of the identity check used by IdenticalInternal, which
// has already compared opcode, shape, operands and backend config.
//
// Returns true for opcodes whose result depends only on those properties,
// false for kAfterAll/kAddDependency, and compares called computations for
// kCall/kConditional/kWhile. Opcodes implemented by subclasses are expected
// to override this method. Reaching the base implementation for one of those
// is a LOG(FATAL).
//
// NOTE(review): kCopyStart, kReshape and kDynamicReshape are listed in the
// "operand-only" group below, while CloneWithNewOperands routes them to
// subclasses. Confirm whether their subclasses override this method.
bool HloInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  // Perform opcode specific checks.
  switch (opcode()) {
    // The result of these instructions only depend upon their opcode and
    // operands.
    case HloOpcode::kAbs:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAtan2:
    case HloOpcode::kAdd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kOptimizationBarrier:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRoundNearestEven:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTuple:
      return true;

    // This opcode has complex or special behavior so just return false.
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
      return false;

    // Remaining instructions with special values.
    case HloOpcode::kCall:
      return eq_computations(to_apply(), other.to_apply());
    case HloOpcode::kConditional:
      // Every branch computation must match pairwise, by index.
      for (int j = 0; j < branch_count(); ++j) {
        if (!eq_computations(branch_computation(j),
                             other.branch_computation(j))) {
          return false;
        }
      }
      return true;
    case HloOpcode::kWhile:
      return (eq_computations(while_body(), other.while_body()) &&
              eq_computations(while_condition(), other.while_condition()));

    // Ops migrated to subclasses should never come to this line.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kIota:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }
  // Reached only for opcodes not listed in the switch.
  return false;
}
2613
// Unregisters `user` in O(1) with swap-and-pop. The last entry of users_ is
// moved into the removed user's slot, its index in user_map_ is updated, and
// the vector's tail is dropped. User order is therefore not preserved.
// CHECK-fails if `user` is not registered or the map/vector disagree.
void HloInstruction::RemoveUser(HloInstruction* user) {
  auto map_it = user_map_.find(user);
  CHECK(map_it != user_map_.end());

  const int64_t index = map_it->second;
  CHECK_EQ(users_[index], user);

  // Move the last user into the position of the removed user.
  users_[index] = users_.back();
  // users_.back() is already a key of user_map_, so this assignment updates
  // an existing entry rather than inserting one. (When `user` is itself the
  // last element, this rewrites its own entry, which is erased next.)
  user_map_[users_.back()] = index;

  // Remove the user from the map and drop the last slot from the vector what
  // have been moved to the position of the original user.
  user_map_.erase(map_it);
  users_.pop_back();
}
2630
// Replaces every occurrence of this instruction in `user`'s operands with
// `new_producer`. Returns an error, before modifying anything, if the shapes
// are not compatible ignoring floating-point precision. Otherwise delegates
// to ReplaceUseWithDifferentShape.
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  return ReplaceUseWithDifferentShape(user, new_producer);
}
2640
ReplaceUseWithDifferentShape(HloInstruction * user,HloInstruction * new_producer)2641 Status HloInstruction::ReplaceUseWithDifferentShape(
2642 HloInstruction* user, HloInstruction* new_producer) {
2643 VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
2644 << " with " << new_producer->name();
2645
2646 RemoveUser(user);
2647
2648 TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
2649 std::replace(user->operands_.begin(), user->operands_.end(), this,
2650 new_producer);
2651 new_producer->AddUser(user);
2652 // Custom fusions may not be able to handle deduplicated operands.
2653 if (user->opcode() == HloOpcode::kFusion) {
2654 TF_RETURN_IF_ERROR(
2655 Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2656 }
2657 return OkStatus();
2658 }
2659
// Replaces only operand slot `operand_number` of `user`, which must currently
// be this instruction, with `new_producer`. Returns an error, before modifying
// anything, if the shapes are not compatible ignoring floating-point
// precision.
Status HloInstruction::ReplaceUseWith(HloInstruction* user, int operand_number,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  return ReplaceUseWithDifferentShape(user, operand_number, new_producer);
}
2669
ReplaceUseWithDifferentShape(HloInstruction * user,int operand_number,HloInstruction * new_producer)2670 Status HloInstruction::ReplaceUseWithDifferentShape(
2671 HloInstruction* user, int operand_number, HloInstruction* new_producer) {
2672 VLOG(3) << "Replacing operand " << operand_number << " of " << name()
2673 << " in " << user->name() << " with " << new_producer->name();
2674
2675 if (absl::c_count(user->operands_, this) == 1) {
2676 RemoveUser(user);
2677 }
2678
2679 TF_RET_CHECK(user->operand(operand_number) == this)
2680 << "Expected operand " << operand_number << " of " << user->ToString()
2681 << " to be equal to " << ToString();
2682 user->operands_[operand_number] = new_producer;
2683 new_producer->AddUser(user);
2684 return OkStatus();
2685 }
2686
// Replaces operand `operand_num` with `new_operand` after checking that the
// old and new operand shapes are compatible ignoring floating-point
// precision. The rewiring itself is done by ReplaceOperandWithDifferentShape.
Status HloInstruction::ReplaceOperandWith(int64_t operand_num,
                                          HloInstruction* new_operand) {
  auto old_operand = operand(operand_num);
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
                                                        new_operand->shape()))
      << old_operand->shape() << " is not compatible with "
      << new_operand->shape();
  return ReplaceOperandWithDifferentShape(operand_num, new_operand);
}
2696
ReplaceOperandWithDifferentShape(int64_t operand_num,HloInstruction * new_operand)2697 Status HloInstruction::ReplaceOperandWithDifferentShape(
2698 int64_t operand_num, HloInstruction* new_operand) {
2699 TF_RET_CHECK(operand_num >= 0);
2700 TF_RET_CHECK(operand_num < operand_count());
2701 HloInstruction* old_operand = mutable_operand(operand_num);
2702 if (old_operand == new_operand) {
2703 return OkStatus();
2704 }
2705
2706 operands_[operand_num] = new_operand;
2707
2708 VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
2709 << new_operand->name() << ", was " << old_operand->name();
2710
2711 if (!absl::c_linear_search(operands_, old_operand)) {
2712 old_operand->RemoveUser(this);
2713 }
2714 new_operand->AddUser(this);
2715 return OkStatus();
2716 }
2717
// Redirects the listed `users` (and the computation root, if this is it) to
// `new_producer`. Returns an error, before modifying anything, if the shapes
// are not compatible ignoring floating-point precision.
Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
                                       HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(users, new_producer);
}
2725
ReplaceAllUsesWithDifferentShape(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2726 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2727 absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
2728 for (HloInstruction* user : users) {
2729 TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
2730 }
2731
2732 if (parent_ && parent_->root_instruction() == this) {
2733 parent_->set_root_instruction(new_producer,
2734 /*accept_different_shape=*/true);
2735 }
2736 return OkStatus();
2737 }
2738
// Redirects every user of this instruction (and the computation root, if this
// is it) to `new_producer`. Returns an error, before modifying anything, if
// the shapes are not compatible ignoring floating-point precision.
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(new_producer);
}
2745
ReplaceAllUsesWithDifferentShape(HloInstruction * new_producer)2746 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2747 HloInstruction* new_producer) {
2748 bool new_producer_is_user = false;
2749 for (HloInstruction* user : users()) {
2750 if (user == new_producer) {
2751 // It's possible that new_producer is a user of this instruction as might
2752 // be the case when replacing an instruction with a kCopy of itself. In
2753 // this case, don't do the replacement to avoid creating a cycle in the
2754 // graph. new_producer remains the only user of this instruction.
2755 new_producer_is_user = true;
2756 } else {
2757 std::replace(user->operands_.begin(), user->operands_.end(), this,
2758 new_producer);
2759 new_producer->AddUser(user);
2760 if (user->opcode() == HloOpcode::kFusion) {
2761 TF_RETURN_IF_ERROR(
2762 Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2763 }
2764 }
2765 }
2766 users_.clear();
2767 user_map_.clear();
2768 if (new_producer_is_user) {
2769 AddUser(new_producer);
2770 }
2771 if (parent_ && parent_->root_instruction() == this) {
2772 parent_->set_root_instruction(new_producer,
2773 /*accept_different_shape=*/true);
2774 }
2775
2776 return OkStatus();
2777 }
2778
IsEffectiveBitcast() const2779 bool HloInstruction::IsEffectiveBitcast() const {
2780 return opcode_ == HloOpcode::kBitcast ||
2781 (opcode_ == HloOpcode::kTranspose &&
2782 ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
2783 dimensions()));
2784 }
2785
to_apply() const2786 HloComputation* HloInstruction::to_apply() const {
2787 switch (opcode_) {
2788 case HloOpcode::kCall:
2789 case HloOpcode::kMap:
2790 case HloOpcode::kReduceWindow:
2791 case HloOpcode::kReduce:
2792 case HloOpcode::kAllReduce:
2793 case HloOpcode::kReduceScatter:
2794 case HloOpcode::kAllReduceStart:
2795 case HloOpcode::kScatter:
2796 case HloOpcode::kSort:
2797 case HloOpcode::kCustomCall:
2798 CHECK_EQ(called_computations_.size(), 1);
2799 return called_computations_[0];
2800 default:
2801 LOG(FATAL) << "Invalid opcode for to_apply(): "
2802 << HloOpcodeString(opcode());
2803 }
2804 }
2805
set_to_apply(HloComputation * computation)2806 void HloInstruction::set_to_apply(HloComputation* computation) {
2807 // Don't allow changing the computation for fused instructions so we don't
2808 // have to recompute called_instructions for the entire fusion instruction.
2809 CHECK(!IsFused());
2810 switch (opcode_) {
2811 case HloOpcode::kCall:
2812 case HloOpcode::kMap:
2813 case HloOpcode::kReduceWindow:
2814 case HloOpcode::kReduce:
2815 case HloOpcode::kAllReduce:
2816 case HloOpcode::kAllReduceStart:
2817 case HloOpcode::kScatter:
2818 case HloOpcode::kSort:
2819 case HloOpcode::kCustomCall:
2820 CHECK_EQ(called_computations_.size(), 1);
2821 called_computations_[0] = computation;
2822 break;
2823 default:
2824 LOG(FATAL) << "Invalid opcode for to_apply(): "
2825 << HloOpcodeString(opcode());
2826 }
2827 }
2828
// Returns the loop-condition computation of a kWhile; CHECK-fails for any
// other opcode.
HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kConditionComputationIndex];
}
2833
// Returns the loop-body computation of a kWhile; CHECK-fails for any other
// opcode.
HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kBodyComputationIndex];
}
2838
// Replaces the loop-condition computation of a non-fused kWhile.
void HloInstruction::set_while_condition(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kConditionComputationIndex] = computation;
}
2846
// Replaces the loop-body computation of a non-fused kWhile.
void HloInstruction::set_while_body(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kBodyComputationIndex] = computation;
}
2854
// Returns the initial loop value of a kWhile, which is its operand 0.
HloInstruction* HloInstruction::while_init() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return operands_[0];
}
2859
true_computation() const2860 HloComputation* HloInstruction::true_computation() const {
2861 CHECK_EQ(HloOpcode::kConditional, opcode_);
2862 CHECK_EQ(PRED, operand(0)->shape().element_type());
2863 return called_computations_[kTrueComputationIndex];
2864 }
2865
false_computation() const2866 HloComputation* HloInstruction::false_computation() const {
2867 CHECK_EQ(HloOpcode::kConditional, opcode_);
2868 CHECK_EQ(PRED, operand(0)->shape().element_type());
2869 return called_computations_[kFalseComputationIndex];
2870 }
2871
branch_computations() const2872 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2873 const {
2874 CHECK(HloOpcode::kConditional == opcode_);
2875 return called_computations_;
2876 }
2877
branch_count() const2878 int HloInstruction::branch_count() const {
2879 CHECK(HloOpcode::kConditional == opcode_);
2880 return called_computations_.size();
2881 }
2882
branch_computation(int b) const2883 HloComputation* HloInstruction::branch_computation(int b) const {
2884 CHECK(HloOpcode::kConditional == opcode_);
2885 CHECK_GE(b, 0);
2886 CHECK_LT(b, called_computations_.size());
2887 return called_computations_[b];
2888 }
2889
set_branch_computation(int b,HloComputation * computation)2890 void HloInstruction::set_branch_computation(int b,
2891 HloComputation* computation) {
2892 // Don't allow changing the computation for fused instructions so we don't
2893 // have to recompute called_instructions for the entire fusion instruction.
2894 CHECK(!IsFused());
2895 CHECK_EQ(HloOpcode::kConditional, opcode_);
2896 called_computations_[b] = computation;
2897 }
2898
SignatureString() const2899 std::string HloInstruction::SignatureString() const {
2900 std::string operands =
2901 StrJoin(operands_, ", ", [](std::string* out, HloInstruction* operand) {
2902 StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2903 });
2904 return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2905 }
2906
PrintName(const std::string & name,bool print_ids)2907 std::string PrintName(const std::string& name, bool print_ids) {
2908 if (print_ids) {
2909 return name;
2910 } else {
2911 auto dot_position = name.find_first_of('.');
2912 return name.substr(0, dot_position);
2913 }
2914 }
2915
2916 namespace {
2917
2918 using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2919
PrintNameInternal(const std::string & name,const HloPrintOptions & options)2920 std::string PrintNameInternal(const std::string& name,
2921 const HloPrintOptions& options) {
2922 return StrCat(options.print_percent() ? "%" : "",
2923 PrintName(name, options.print_ids()));
2924 }
2925
PrintCycle(const HloInstruction * child,DFSStack * dfs_stack)2926 void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
2927 // This set contains HloInstructions from the top of `DFSStack` that might
2928 // belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
2929 // `subgraph` := {child,...,top}.
2930 absl::flat_hash_set<const HloInstruction*> subgraph;
2931 while (!dfs_stack->empty() && dfs_stack->back().second != child) {
2932 subgraph.insert(dfs_stack->back().second);
2933 dfs_stack->pop_back();
2934 }
2935 // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
2936 absl::flat_hash_set<const HloInstruction*> visited;
2937 absl::InlinedVector<const HloInstruction*, 16> dfs;
2938 dfs.push_back(child);
2939 while (!dfs.empty()) {
2940 bool found_next_instr = false;
2941 for (const auto& user : dfs.back()->users()) {
2942 if (user == child) {
2943 dfs.push_back(child);
2944 LOG(INFO) << "\n\nDirected cycle:\n "
2945 << absl::StrJoin(
2946 dfs, "\n ",
2947 [](std::string* out, const HloInstruction* instr) {
2948 out->append(instr->name());
2949 });
2950 return;
2951 }
2952 if (!subgraph.contains(user) || visited.contains(user)) {
2953 continue;
2954 }
2955 visited.insert(user);
2956 dfs.push_back(user);
2957 found_next_instr = true;
2958 }
2959 if (!found_next_instr) {
2960 dfs.pop_back();
2961 }
2962 }
2963 }
2964
2965 } // namespace
2966
ToString(const HloPrintOptions & options) const2967 std::string HloInstruction::ToString(const HloPrintOptions& options) const {
2968 CanonicalNameMap new_map;
2969 return ToStringWithCanonicalNameMap(options, &new_map);
2970 }
2971
IsOpElementwise(HloOpcode opcode)2972 bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
2973 switch (opcode) {
2974 // Unary elementwise operations.
2975 case HloOpcode::kAbs:
2976 case HloOpcode::kRoundNearestAfz:
2977 case HloOpcode::kRoundNearestEven:
2978 case HloOpcode::kCeil:
2979 case HloOpcode::kClz:
2980 case HloOpcode::kConvert:
2981 case HloOpcode::kBitcastConvert:
2982 case HloOpcode::kCopy:
2983 case HloOpcode::kCos:
2984 case HloOpcode::kExp:
2985 case HloOpcode::kExpm1:
2986 case HloOpcode::kFloor:
2987 case HloOpcode::kImag:
2988 case HloOpcode::kIsFinite:
2989 case HloOpcode::kLog:
2990 case HloOpcode::kLog1p:
2991 case HloOpcode::kNot:
2992 case HloOpcode::kNegate:
2993 case HloOpcode::kPopulationCount:
2994 case HloOpcode::kReal:
2995 case HloOpcode::kReducePrecision:
2996 case HloOpcode::kRsqrt:
2997 case HloOpcode::kLogistic:
2998 case HloOpcode::kSign:
2999 case HloOpcode::kSin:
3000 case HloOpcode::kSqrt:
3001 case HloOpcode::kCbrt:
3002 case HloOpcode::kTanh:
3003 return true;
3004
3005 // Binary elementwise operations, the same as in IsElementwiseBinary().
3006 case HloOpcode::kAdd:
3007 case HloOpcode::kAtan2:
3008 case HloOpcode::kCompare:
3009 case HloOpcode::kComplex:
3010 case HloOpcode::kDivide:
3011 case HloOpcode::kMaximum:
3012 case HloOpcode::kMinimum:
3013 case HloOpcode::kMultiply:
3014 case HloOpcode::kPower:
3015 case HloOpcode::kRemainder:
3016 case HloOpcode::kSubtract:
3017 case HloOpcode::kAnd:
3018 case HloOpcode::kOr:
3019 case HloOpcode::kXor:
3020 case HloOpcode::kShiftLeft:
3021 case HloOpcode::kShiftRightArithmetic:
3022 case HloOpcode::kShiftRightLogical:
3023 return true;
3024
3025 // Ternary elementwise operations.
3026 case HloOpcode::kSelect:
3027 case HloOpcode::kClamp:
3028 return true;
3029
3030 default:
3031 return false;
3032 }
3033 }
3034
IsElementwiseImpl(const std::optional<int64_t> & operand_idx) const3035 bool HloInstruction::IsElementwiseImpl(
3036 const std::optional<int64_t>& operand_idx) const {
3037 if (opcode_ == HloOpcode::kDynamicUpdateSlice) {
3038 return operand_idx.has_value() && operand_idx.value() == 0;
3039 }
3040 if (opcode_ == HloOpcode::kBitcastConvert &&
3041 primitive_util::BitWidth(shape_.element_type()) !=
3042 primitive_util::BitWidth(operands_[0]->shape().element_type())) {
3043 return false;
3044 }
3045 return IsOpElementwise(opcode_);
3046 }
3047
IsCrossModuleAllReduce() const3048 bool HloInstruction::IsCrossModuleAllReduce() const {
3049 return opcode() == HloOpcode::kAllReduce && channel_id();
3050 }
3051
IsCrossReplicaAllReduce() const3052 bool HloInstruction::IsCrossReplicaAllReduce() const {
3053 return opcode() == HloOpcode::kAllReduce && !channel_id();
3054 }
3055
ToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const3056 std::string HloInstruction::ToStringWithCanonicalNameMap(
3057 const HloPrintOptions& options,
3058 CanonicalNameMap* canonical_name_map) const {
3059 std::string result = "";
3060
3061 // Logic to print the instruction name (e.g. "%foo = ").
3062 if (options.canonicalize_instruction_names()) {
3063 if (options.is_in_nested_computation()) {
3064 // If we are canonicalizing instruction names and this is a top-level
3065 // HloInstruction::ToString() call, don't print an instruction name.
3066 DCHECK(!options.print_percent()); // no need to call PrintNameInternal
3067 StrAppend(&result, canonical_name_map->LookupOrInsert(name()), " = ");
3068 }
3069 } else {
3070 StrAppend(&result, PrintNameInternal(name(), options), " = ");
3071 }
3072
3073 if (options.print_result_shape()) {
3074 // Print shape.
3075 if (options.include_layout_in_shapes()) {
3076 StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
3077 } else {
3078 StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
3079 }
3080 }
3081
3082 // Print opcode, operand(s).
3083 if (options.syntax_sugar_async_ops() && HloOpcodeIsAsync(opcode())) {
3084 std::string suffix = [&]() {
3085 switch (opcode()) {
3086 case HloOpcode::kAsyncStart:
3087 return "-start";
3088 case HloOpcode::kAsyncUpdate:
3089 return "-update";
3090 default:
3091 CHECK(opcode() == HloOpcode::kAsyncDone)
3092 << "Unexpected async opcode: " << HloOpcodeString(opcode());
3093 return "-done";
3094 }
3095 }();
3096 StrAppend(&result, HloOpcodeString(async_wrapped_opcode()), suffix);
3097 } else {
3098 StrAppend(&result, HloOpcodeString(opcode()));
3099 }
3100 StrAppend(&result, "(",
3101 OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
3102 ")");
3103
3104 // Print additional attributes. If an instruction contains a subcomputation,
3105 // the subcomputation is also printed here.
3106 for (const std::string& extra : ExtraAttributesToString(options)) {
3107 StrAppend(&result, ", ", extra);
3108 }
3109
3110 if (options.print_metadata() &&
3111 (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
3112 !metadata_.source_file().empty())) {
3113 StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
3114 }
3115 if (options.print_backend_config() && !backend_config_.empty()) {
3116 StrAppend(&result, ", backend_config=\"",
3117 CEscape(backend_config_.GetRawString()), "\"");
3118 }
3119 return result;
3120 }
3121
OperandsToString(const HloPrintOptions & options) const3122 std::string HloInstruction::OperandsToString(
3123 const HloPrintOptions& options) const {
3124 CanonicalNameMap new_map;
3125 return OperandsToStringWithCanonicalNameMap(options, &new_map);
3126 }
3127
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const3128 std::string HloInstruction::OperandsToStringWithCanonicalNameMap(
3129 const HloPrintOptions& options,
3130 CanonicalNameMap* canonical_name_map) const {
3131 std::string operands;
3132 absl::Span<HloInstruction* const> slice(operands_);
3133 const int64_t kMaxOperandsToShowIfCompact = 4;
3134 if (options.compact_operands() &&
3135 slice.size() > kMaxOperandsToShowIfCompact) {
3136 slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
3137 }
3138 for (int64_t i = 0; i < slice.size(); ++i) {
3139 HloInstruction* operand = slice[i];
3140 if (i != 0) {
3141 StrAppend(&operands, ", ");
3142 if (options.print_operand_index_annotation_interval() != 0 &&
3143 i % options.print_operand_index_annotation_interval() == 0) {
3144 StrAppend(&operands, absl::StrFormat("/*index=%lld*/", i));
3145 }
3146 }
3147 // If operand is already been deleted, put `null` to the string output.
3148 if (operand == nullptr) {
3149 StrAppend(&operands, "null ");
3150 continue;
3151 }
3152 std::vector<std::string> str;
3153 if (options.print_operand_shape()) {
3154 if (options.include_layout_in_shapes()) {
3155 str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
3156 } else {
3157 str.push_back(ShapeUtil::HumanString(operand->shape()));
3158 }
3159 }
3160 if (options.canonicalize_instruction_names()) {
3161 if (options.is_in_nested_computation()) {
3162 // In a top-level HloInstruction::ToString() call, the operand name is
3163 // not part of the canonical string.
3164 DCHECK(!options.print_percent()); // no need to call PrintNameInternal
3165 str.push_back(canonical_name_map->LookupOrInsert(operand->name()));
3166 }
3167 } else if (options.print_operand_names()) {
3168 str.push_back(PrintNameInternal(operand->name(), options));
3169 }
3170 StrAppend(&operands, StrJoin(str, " "));
3171 }
3172 const int64_t remaining = operands_.size() - slice.size();
3173 if (slice.size() != operands_.size()) {
3174 StrAppend(&operands, ", ...(+", remaining, ")");
3175 }
3176 return operands;
3177 }
3178
3179 namespace {
3180
IsSequentialCall(HloOpcode opcode)3181 bool IsSequentialCall(HloOpcode opcode) {
3182 switch (opcode) {
3183 case HloOpcode::kCall:
3184 case HloOpcode::kConditional:
3185 case HloOpcode::kWhile:
3186 return true;
3187 default:
3188 return false;
3189 }
3190 }
3191
3192 } // namespace
3193
ExtraAttributesToString(const HloPrintOptions & options) const3194 std::vector<std::string> HloInstruction::ExtraAttributesToString(
3195 const HloPrintOptions& options) const {
3196 std::vector<std::string> extra = options.print_extra_attributes()
3197 ? ExtraAttributesToStringImpl(options)
3198 : std::vector<std::string>();
3199
3200 const auto subcomputation_mode = options.print_subcomputation_mode();
3201 if (subcomputation_mode ==
3202 HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
3203 if (opcode() == HloOpcode::kWhile) {
3204 extra.push_back(StrCat(
3205 "condition=", PrintNameInternal(while_condition()->name(), options)));
3206 extra.push_back(
3207 StrCat("body=", PrintNameInternal(while_body()->name(), options)));
3208 } else if (opcode() == HloOpcode::kSelectAndScatter) {
3209 extra.push_back(
3210 StrCat("select=", PrintNameInternal(select()->name(), options)));
3211 extra.push_back(
3212 StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
3213 } else if (opcode() == HloOpcode::kConditional) {
3214 if (operand(0)->shape().element_type() == PRED) {
3215 extra.push_back(
3216 StrCat("true_computation=",
3217 PrintNameInternal(true_computation()->name(), options)));
3218 extra.push_back(
3219 StrCat("false_computation=",
3220 PrintNameInternal(false_computation()->name(), options)));
3221 } else {
3222 extra.push_back(StrCat(
3223 "branch_computations={",
3224 StrJoin(branch_computations(), ", ",
3225 [&](std::string* out, const HloComputation* computation) {
3226 StrAppend(
3227 out, PrintNameInternal(computation->name(), options));
3228 }),
3229 "}"));
3230 }
3231 } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
3232 opcode() == HloOpcode::kReduceWindow ||
3233 opcode() == HloOpcode::kReduce ||
3234 opcode() == HloOpcode::kAllReduce ||
3235 opcode() == HloOpcode::kReduceScatter ||
3236 opcode() == HloOpcode::kAllReduceStart ||
3237 opcode() == HloOpcode::kScatter ||
3238 opcode() == HloOpcode::kSort) {
3239 if (!called_computations().empty()) {
3240 extra.push_back(StrCat("to_apply=",
3241 PrintNameInternal(to_apply()->name(), options)));
3242 }
3243 } else if (opcode() == HloOpcode::kCustomCall) {
3244 if (!called_computations().empty()) {
3245 extra.push_back(StrCat(
3246 "called_computations={",
3247 StrJoin(called_computations(), ", ",
3248 [&](std::string* out, const HloComputation* computation) {
3249 StrAppend(
3250 out, PrintNameInternal(computation->name(), options));
3251 }),
3252 "}"));
3253 }
3254 } else if (HloOpcodeIsAsync(opcode())) {
3255 if (!options.syntax_sugar_async_ops()) {
3256 extra.push_back(StrCat(
3257 "calls=",
3258 PrintNameInternal(async_wrapped_computation()->name(), options)));
3259 }
3260 } else if (!called_computations().empty()) {
3261 extra.push_back(StrCat(
3262 "calls=",
3263 StrJoin(called_computations(), ", ",
3264 [&](std::string* out, const HloComputation* computation) {
3265 StrAppend(out,
3266 PrintNameInternal(computation->name(), options));
3267 })));
3268 }
3269 } else if ((subcomputation_mode ==
3270 HloPrintOptions::PrintSubcomputationMode::kFullBodies) ||
3271 (subcomputation_mode == HloPrintOptions::PrintSubcomputationMode::
3272 kNonSequentialBodies &&
3273 !IsSequentialCall(opcode()))) {
3274 HloPrintOptions new_options = options;
3275 new_options.set_is_in_nested_computation(true);
3276 switch (opcode()) {
3277 case HloOpcode::kWhile:
3278 extra.push_back(
3279 StrCat("condition=\n", while_condition()->ToString(new_options)));
3280 extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
3281 break;
3282 case HloOpcode::kSelectAndScatter:
3283 extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
3284 extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
3285 break;
3286 case HloOpcode::kConditional:
3287 if (operand(0)->shape().element_type() == PRED) {
3288 extra.push_back(StrCat("true_computation=\n",
3289 true_computation()->ToString(new_options)));
3290 extra.push_back(StrCat("false_computation=\n",
3291 false_computation()->ToString(new_options)));
3292 } else {
3293 extra.push_back(StrCat(
3294 "branch_computations={\n",
3295 StrJoin(branch_computations(), ",\n",
3296 [&](std::string* out, const HloComputation* computation) {
3297 StrAppend(out, computation->ToString(new_options));
3298 }),
3299 "\n}"));
3300 }
3301 break;
3302 case HloOpcode::kCall:
3303 case HloOpcode::kMap:
3304 case HloOpcode::kReduceWindow:
3305 case HloOpcode::kReduce:
3306 case HloOpcode::kAllReduce:
3307 case HloOpcode::kAllReduceStart:
3308 case HloOpcode::kScatter:
3309 case HloOpcode::kSort:
3310 if (!called_computations().empty()) {
3311 extra.push_back(
3312 StrCat("to_apply=\n", to_apply()->ToString(new_options)));
3313 }
3314 break;
3315 default:
3316 if (!called_computations().empty()) {
3317 extra.push_back(StrCat(
3318 "calls=\n",
3319 StrJoin(called_computations(), ", ",
3320 [&](std::string* out, const HloComputation* computation) {
3321 StrAppend(out, computation->ToString(new_options));
3322 })));
3323 }
3324 break;
3325 }
3326 }
3327
3328 if (has_sharding()) {
3329 extra.push_back(
3330 StrCat("sharding=", sharding().ToString(options.print_metadata())));
3331 }
3332 if (!frontend_attributes_.map().empty()) {
3333 extra.push_back(StrCat("frontend_attributes=",
3334 FrontendAttributesToString(frontend_attributes_)));
3335 }
3336 if (!outer_dimension_partitions_.empty()) {
3337 extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
3338 StrJoin(outer_dimension_partitions_, ",")));
3339 }
3340
3341 if (options.print_control_dependencies() && !control_predecessors_.empty()) {
3342 extra.push_back(StrCat("control-predecessors={",
3343 StrJoin(control_predecessors_, ", ",
3344 [&](std::string* out, HloInstruction* pre) {
3345 StrAppend(out, PrintNameInternal(
3346 pre->name(), options));
3347 }),
3348 "}"));
3349 }
3350
3351 return extra;
3352 }
3353
ToShortString() const3354 std::string HloInstruction::ToShortString() const {
3355 return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
3356 StrJoin(operands_, ", ",
3357 [](std::string* out, HloInstruction* operand) {
3358 StrAppend(out, "%", operand->name());
3359 }),
3360 ")");
3361 }
3362
ToProto() const3363 HloInstructionProto HloInstruction::ToProto() const {
3364 HloInstructionProto proto;
3365 CHECK(unique_id_ != -1)
3366 << "This instruction does not have a valid id. Please make sure the "
3367 "instruction is inside a module before dumping it.";
3368 proto.set_id(unique_id_);
3369 proto.set_name(name_);
3370 proto.set_opcode(HloOpcodeString(opcode_));
3371 *proto.mutable_shape() = shape_.ToProto();
3372 for (const HloInstruction* operand : operands_) {
3373 proto.add_operand_ids(operand->unique_id());
3374 }
3375 for (const HloInstruction* control : control_predecessors_) {
3376 proto.add_control_predecessor_ids(control->unique_id());
3377 }
3378
3379 *proto.mutable_metadata() = metadata_;
3380 proto.set_backend_config(backend_config_.GetRawString());
3381 if (opcode() != HloOpcode::kFusion) {
3382 for (const HloComputation* computation : called_computations_) {
3383 proto.add_called_computation_ids(computation->unique_id());
3384 }
3385 }
3386
3387 if (has_sharding()) {
3388 *proto.mutable_sharding() = sharding().ToProto();
3389 }
3390 if (!outer_dimension_partitions_.empty()) {
3391 for (const auto& idx : outer_dimension_partitions_) {
3392 proto.mutable_outer_dimension_partitions()->Add(idx);
3393 }
3394 }
3395
3396 *proto.mutable_frontend_attributes() = frontend_attributes_;
3397
3398 return proto;
3399 }
3400
ToCategory() const3401 std::string HloInstruction::ToCategory() const {
3402 if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
3403 opcode() == HloOpcode::kReshape ||
3404 opcode() == HloOpcode::kDynamicReshape) {
3405 return "data formatting";
3406 }
3407
3408 if (IsElementwise()) {
3409 return "non-fusion elementwise";
3410 }
3411
3412 return HloOpcodeString(opcode());
3413 }
3414
IsFused() const3415 bool HloInstruction::IsFused() const {
3416 return parent_ != nullptr && parent_->IsFusionComputation();
3417 }
3418
IsCustomCall(absl::string_view target) const3419 bool HloInstruction::IsCustomCall(absl::string_view target) const {
3420 return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
3421 }
3422
IsInputFusion() const3423 bool HloInstruction::IsInputFusion() const {
3424 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
3425 }
3426
IsLoopFusion() const3427 bool HloInstruction::IsLoopFusion() const {
3428 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
3429 }
3430
IsOutputFusion() const3431 bool HloInstruction::IsOutputFusion() const {
3432 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
3433 }
3434
IsCustomFusion() const3435 bool HloInstruction::IsCustomFusion() const {
3436 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
3437 }
3438
IsFusible() const3439 bool HloInstruction::IsFusible() const {
3440 // Some kinds of instructions don't make sense to fuse.
3441 switch (opcode_) {
3442 case HloOpcode::kDomain:
3443 case HloOpcode::kParameter:
3444 case HloOpcode::kWhile:
3445 case HloOpcode::kConditional:
3446 case HloOpcode::kCall:
3447 return false;
3448 // Fusions are always fusible.
3449 case HloOpcode::kFusion:
3450 // Side effecting reduce and reduce window would be invalid HLO.
3451 case HloOpcode::kMap:
3452 case HloOpcode::kReduce:
3453 case HloOpcode::kReduceWindow:
3454 return true;
3455 case HloOpcode::kRng:
3456 return user_count() <= 1;
3457 // Side effecting instructions cannot be fused.
3458 default:
3459 return !HasSideEffect();
3460 }
3461 }
3462
HloInstruction(HloOpcode opcode,const Shape & shape)3463 HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
3464 : unique_id_(-1),
3465 opcode_(opcode),
3466 shape_(shape),
3467 name_(HloOpcodeString(opcode)),
3468 marked_as_dead_(false) {
3469 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
3470 }
3471
// Dispatches this instruction to the visitor handler matching its opcode
// (e.g. kAdd -> HandleAdd) and returns that handler's status. Note that a few
// handler names differ from the opcode name (e.g. kRoundNearestAfz ->
// HandleRound).
// Returns an InternalError if opcode_ matches none of the cases below.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kRoundNearestEven:
      return visitor->HandleRoundNearestEven(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kLogistic:
      return visitor->HandleLogistic(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllGather:
      return visitor->HandleAllGather(this);
    case HloOpcode::kAllGatherStart:
      return visitor->HandleAllGatherStart(this);
    case HloOpcode::kAllGatherDone:
      return visitor->HandleAllGatherDone(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kReduceScatter:
      return visitor->HandleReduceScatter(this);
    case HloOpcode::kAllReduceStart:
      return visitor->HandleAllReduceStart(this);
    case HloOpcode::kAllReduceDone:
      return visitor->HandleAllReduceDone(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kCollectivePermuteStart:
      return visitor->HandleCollectivePermuteStart(this);
    case HloOpcode::kCollectivePermuteDone:
      return visitor->HandleCollectivePermuteDone(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kPartitionId:
      return visitor->HandlePartitionId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kCbrt:
      return visitor->HandleCbrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kPopulationCount:
      return visitor->HandlePopulationCount(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kDynamicReshape:
      return visitor->HandleDynamicReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kRngBitGenerator:
      return visitor->HandleRngBitGenerator(this);
    case HloOpcode::kRngGetAndUpdateState:
      return visitor->HandleRngGetAndUpdateState(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kAsyncStart:
      return visitor->HandleAsyncStart(this);
    case HloOpcode::kAsyncUpdate:
      return visitor->HandleAsyncUpdate(this);
    case HloOpcode::kAsyncDone:
      return visitor->HandleAsyncDone(this);
    case HloOpcode::kCopyStart:
      return visitor->HandleCopyStart(this);
    case HloOpcode::kCopyDone:
      return visitor->HandleCopyDone(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kSetDimensionSize:
      return visitor->HandleSetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);
    case HloOpcode::kOptimizationBarrier:
      return visitor->HandleOptimizationBarrier(this);
  }
  // The switch deliberately has no `default`, so compilers can warn about
  // unhandled opcodes. This is reached only if opcode_ holds a value not
  // listed above.
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}
3713
// Explicit instantiations.
// Visit() is a member template whose definition lives in this file, so it is
// instantiated here for DfsHloVisitor and ConstDfsHloVisitor; other
// translation units can then call it without seeing the definition.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
3717
3718 // Push "child" onto the dfs_stack if not already visited. Returns false if a
3719 // cycle was detected, and true otherwise.
3720 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)3721 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
3722 HloInstruction* child) {
3723 CHECK(child != nullptr);
3724 const int id = child->unique_id();
3725 CHECK_GE(id, 0) << "instruction may not have a parent computation";
3726 switch (visitor->GetVisitState(id)) {
3727 case Visitor::kVisiting:
3728 return false;
3729
3730 case Visitor::kVisited:
3731 // Nothing to do
3732 return true;
3733
3734 case Visitor::kNotVisited:
3735 dfs_stack->push_back(std::make_pair(id, child));
3736 return true;
3737 }
3738 }
3739
3740 using InternalCompareFunction =
3741 std::function<bool(std::pair<int, const HloInstruction*>,
3742 std::pair<int, const HloInstruction*>)>;
3743 template <typename Visitor>
PostOrderDFS(HloInstruction * root,Visitor * visitor,const InternalCompareFunction * operand_order,bool ignore_control_predecessors)3744 static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
3745 const InternalCompareFunction* operand_order,
3746 bool ignore_control_predecessors) {
3747 visitor->ReserveVisitStates(root->parent()->instruction_count());
3748
3749 // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
3750 //
3751 // We need to keep track of both the id and the instruction because
3752 // instructions can get deleted while they are on the stack, so we
3753 // can't always use the (potentially dead) instruction object to grab
3754 // its id.
3755 DFSStack dfs_stack;
3756 dfs_stack.emplace_back(root->unique_id(), root);
3757
3758 do {
3759 DCHECK(!dfs_stack.empty());
3760
3761 int current_id = dfs_stack.back().first;
3762 HloInstruction* current_node = dfs_stack.back().second;
3763 CHECK_GE(current_id, 0) << current_id << ": " << current_node
3764 << ": instruction may not have parent computation";
3765 typename Visitor::VisitState visit_state =
3766 visitor->GetVisitState(current_id);
3767 if (visit_state == Visitor::kVisited) {
3768 dfs_stack.pop_back();
3769 VLOG(3) << "Not visiting HLO (id = " << current_id
3770 << ") as it was already visited.";
3771 continue;
3772 }
3773
3774 if (visit_state == Visitor::kVisiting) {
3775 dfs_stack.pop_back();
3776
3777 TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
3778 VLOG(2) << "Visiting HLO %" << current_node->name();
3779 TF_RETURN_IF_ERROR(current_node->Visit(visitor));
3780 visitor->SetVisitState(current_id, Visitor::kVisited);
3781 TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
3782 continue;
3783 }
3784
3785 visitor->SetVisitState(current_id, Visitor::kVisiting);
3786
3787 const size_t old_dfs_stack_size = dfs_stack.size();
3788 for (HloInstruction* child : current_node->operands()) {
3789 if (!ABSL_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
3790 PrintCycle(child, &dfs_stack);
3791 return FailedPrecondition(
3792 "A cycle is detected while visiting instruction %s",
3793 current_node->ToString());
3794 }
3795 }
3796
3797 if (!ignore_control_predecessors) {
3798 for (HloInstruction* child : current_node->control_predecessors()) {
3799 if (!ABSL_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
3800 PrintCycle(child, &dfs_stack);
3801 return FailedPrecondition(
3802 "A cycle is detected while visiting instruction %s",
3803 current_node->ToString());
3804 }
3805 }
3806 }
3807
3808 if (operand_order != nullptr) {
3809 std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
3810 *operand_order);
3811 }
3812
3813 // This makes the traversal order the same as what you'd expect
3814 // out of a recursive algorithm.
3815 std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
3816 } while (!dfs_stack.empty());
3817
3818 return OkStatus();
3819 }
3820
3821 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)3822 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
3823 bool call_finish_visit,
3824 bool ignore_control_predecessors) {
3825 VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
3826 TF_RETURN_IF_ERROR(
3827 PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
3828 if (call_finish_visit) {
3829 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3830 }
3831 return OkStatus();
3832 }
3833
3834 // Explicit instantiations.
3835 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
3836 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3837
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)3838 Status HloInstruction::AcceptWithOperandOrder(
3839 DfsHloVisitor* visitor, const CompareFunction& operand_order,
3840 bool call_finish_visit) {
3841 VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
3842 InternalCompareFunction func = [&operand_order](
3843 std::pair<int, const HloInstruction*> a,
3844 std::pair<int, const HloInstruction*> b) {
3845 // Call the client's comparison function on the actual HloInstruction*
3846 // objects (ignoring the internal ids we also have in our stack entries)
3847 return operand_order(a.second, b.second);
3848 };
3849 TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
3850 /*ignore_control_predecessors=*/false));
3851 if (call_finish_visit) {
3852 VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
3853 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3854 VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
3855 }
3856 VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
3857 return OkStatus();
3858 }
3859
shape() const3860 const Shape& HloInstruction::shape() const { return shape_; }
3861
OperandIndices(const HloInstruction * operand) const3862 absl::InlinedVector<int64_t, 4> HloInstruction::OperandIndices(
3863 const HloInstruction* operand) const {
3864 absl::InlinedVector<int64_t, 4> result;
3865 for (int64_t i = 0; i < operand_count(); ++i) {
3866 if (this->operand(i) == operand) {
3867 result.push_back(i);
3868 }
3869 }
3870 return result;
3871 }
3872
IsElementwiseBinary() const3873 bool HloInstruction::IsElementwiseBinary() const {
3874 return IsElementwise() && operand_count() == 2;
3875 }
3876
IsElementwise() const3877 bool HloInstruction::IsElementwise() const {
3878 return IsElementwiseImpl(std::nullopt);
3879 }
3880
IsElementwiseOnOperand(int64_t operand_idx) const3881 bool HloInstruction::IsElementwiseOnOperand(int64_t operand_idx) const {
3882 return IsElementwiseImpl(operand_idx);
3883 }
3884
namespace {

// Indicates how an instruction uses a value (such as an operand).
//
// Does it (a) use it multiple times (kReuse), (b) use it (kUse), or (c) not
// use it at all (kNoUse)? The numeric order matters: ComputeInternal combines
// uses with std::min, treating the values as a lattice
// kReuse < kUse < kNoUse.
enum class UseKind { kReuse = 0, kUse = 1, kNoUse = 2 };

// Memoized, recursive analysis of how a fused computation uses the elements
// of one fusion parameter. Used for HloOpcode::kFusion by the file-local
// OperandElementUse function below.
class FusionReusesParamElements {
 public:
  // `i` is the fusion parameter number and `hlo` is the fused expression root
  // (see the kFusion case of OperandElementUse). A fresh memoization cache is
  // used for every call.
  static UseKind Compute(int64_t i, const HloInstruction& hlo) {
    absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
    return ComputeInternal(i, hlo, &memoization_cache);
  }

 private:
  // Recursive worker; `cache` maps already-analyzed instructions to results.
  static UseKind ComputeInternal(
      int64_t outer_param_num, const HloInstruction& hlo,
      absl::flat_hash_map<const HloInstruction*, UseKind>* cache);
};

}  // namespace
3908
3909 // Returns how this instruction uses elements of its operand at operand_num.
OperandElementUse(const HloInstruction & instr,int64_t operand_num)3910 static UseKind OperandElementUse(const HloInstruction& instr,
3911 int64_t operand_num) {
3912 switch (instr.opcode()) {
3913 case HloOpcode::kBitcast:
3914 case HloOpcode::kConcatenate:
3915 case HloOpcode::kReshape:
3916 case HloOpcode::kReverse:
3917 case HloOpcode::kSlice:
3918 case HloOpcode::kTranspose:
3919 case HloOpcode::kGather:
3920 return UseKind::kUse;
3921 case HloOpcode::kPad:
3922 // Pad reuses the padding value but not the padded array elements.
3923 return operand_num > 0 ? UseKind::kReuse : UseKind::kUse;
3924 case HloOpcode::kReduce:
3925 // Reduce reuses the init values but not the operand array elements.
3926 return operand_num >= Cast<HloReduceInstruction>(&instr)->input_count()
3927 ? UseKind::kReuse
3928 : UseKind::kUse;
3929 case HloOpcode::kFusion:
3930 // Uses the memoizing, recursive computation defined above.
3931 return FusionReusesParamElements::Compute(operand_num,
3932 *instr.fused_expression_root());
3933 case HloOpcode::kDot:
3934 // Matrix-vector dots do not reuse the matrix operand.
3935 if (instr.shape().dimensions_size() <= 1) {
3936 if ((operand_num == 0 && instr.operand(1)->shape().rank() <= 1) ||
3937 (operand_num == 1 && instr.operand(0)->shape().rank() <= 1)) {
3938 return UseKind::kUse;
3939 }
3940 }
3941 return UseKind::kReuse;
3942 case HloOpcode::kDynamicUpdateSlice:
3943 // Dynamic-update-slice reuses only start_indices.
3944 if (operand_num == 0 || operand_num == 1) {
3945 return UseKind::kUse;
3946 }
3947 return UseKind::kReuse;
3948 default:
3949 return instr.IsElementwise() ? UseKind::kUse : UseKind::kReuse;
3950 }
3951 }
3952
ComputeInternal(int64_t outer_param_num,const HloInstruction & hlo,absl::flat_hash_map<const HloInstruction *,UseKind> * cache)3953 UseKind FusionReusesParamElements::ComputeInternal(
3954 int64_t outer_param_num, const HloInstruction& hlo,
3955 absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
3956 if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
3957 if (hlo_param->parameter_number() == outer_param_num) {
3958 return UseKind::kUse;
3959 }
3960 }
3961
3962 auto p = cache->emplace(&hlo, UseKind::kNoUse);
3963 auto value_it = p.first;
3964 const bool key_is_new = p.second;
3965
3966 if (!key_is_new) {
3967 return value_it->second;
3968 }
3969
3970 // Our dataflow graph has no loops, so we don't need the fixed point
3971 // computation.
3972 for (int64_t operand_num = 0; operand_num < hlo.operands().size();
3973 ++operand_num) {
3974 UseKind old_val = value_it->second;
3975
3976 // Compute updated value.
3977 UseKind new_val = [&] {
3978 // How does the HLO use this operand.
3979 UseKind hlo_use = OperandElementUse(hlo, operand_num);
3980
3981 // If the HLO does not use the outer operand, return previous value.
3982 if (hlo_use == UseKind::kNoUse) {
3983 return old_val;
3984 }
3985
3986 UseKind operand_use =
3987 ComputeInternal(outer_param_num, *hlo.operand(operand_num), cache);
3988
3989 // If the operand does not use the outer operand, return the previous
3990 // value.
3991 if (operand_use == UseKind::kNoUse) {
3992 return old_val;
3993 }
3994
3995 // Meet operator on a lattice:
3996 //
3997 // kReuse < kUse < kNoUse.
3998 return std::min({old_val, hlo_use, operand_use});
3999 }();
4000
4001 value_it = cache->find(&hlo);
4002 value_it->second = new_val;
4003 // Fold() minimizes the UseKind value. If it is already minimum, we do not
4004 // have to check all the remaining operands.
4005 if (new_val == UseKind::kReuse) {
4006 break;
4007 }
4008 }
4009 return value_it->second;
4010 }
4011
ReusesOperandElements(int64_t i) const4012 bool HloInstruction::ReusesOperandElements(int64_t i) const {
4013 return OperandElementUse(*this, i) == UseKind::kReuse;
4014 }
4015
4016 std::optional<ShapeUtil::ShapeEqualityDescriptor>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const4017 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
4018 if (HloOpcode::kReshape != opcode_) {
4019 return std::nullopt;
4020 }
4021 return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
4022 shape_);
4023 }
4024
ToString(HloInstruction::FusionKind kind)4025 std::string ToString(HloInstruction::FusionKind kind) {
4026 switch (kind) {
4027 case HloInstruction::FusionKind::kLoop:
4028 return "kLoop";
4029 case HloInstruction::FusionKind::kInput:
4030 return "kInput";
4031 case HloInstruction::FusionKind::kOutput:
4032 return "kOutput";
4033 case HloInstruction::FusionKind::kCustom:
4034 return "kCustom";
4035 }
4036 }
4037
StringToFusionKind(const std::string & kind_name)4038 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
4039 const std::string& kind_name) {
4040 if (kind_name == "kLoop") {
4041 return HloInstruction::FusionKind::kLoop;
4042 }
4043 if (kind_name == "kInput") {
4044 return HloInstruction::FusionKind::kInput;
4045 }
4046 if (kind_name == "kOutput") {
4047 return HloInstruction::FusionKind::kOutput;
4048 }
4049 if (kind_name == "kCustom") {
4050 return HloInstruction::FusionKind::kCustom;
4051 }
4052 return InvalidArgument("Unknown fusion kind: %s", kind_name);
4053 }
4054
FrontendAttributesToString(const FrontendAttributes & frontend_attributes)4055 std::string FrontendAttributesToString(
4056 const FrontendAttributes& frontend_attributes) {
4057 std::vector<std::pair<std::string, std::string>> sorted_attributes(
4058 frontend_attributes.map().begin(), frontend_attributes.map().end());
4059 absl::c_sort(sorted_attributes);
4060 // Frontend attribute is a comma-separated list of attribute="value" pairs,
4061 // e.g., frontend_attributes={name="value_a",type="int32_t"}.
4062 const auto formatter = [](std::string* out,
4063 const std::pair<std::string, std::string>& item) {
4064 absl::StrAppend(out, item.first, "=\"", item.second, "\"");
4065 };
4066 return absl::StrFormat("{%s}",
4067 absl::StrJoin(sorted_attributes, ",", formatter));
4068 }
4069
PaddingConfigToString(const PaddingConfig & padding)4070 std::string PaddingConfigToString(const PaddingConfig& padding) {
4071 bool has_interior_padding =
4072 absl::c_any_of(padding.dimensions(),
4073 [](const PaddingConfig::PaddingConfigDimension& dim) {
4074 return dim.interior_padding() != 0;
4075 });
4076 return StrJoin(
4077 padding.dimensions(), "x",
4078 [&](std::string* out, const PaddingConfig::PaddingConfigDimension& dim) {
4079 StrAppend(
4080 out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
4081 has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
4082 });
4083 }
4084
RandomDistributionToString(const RandomDistribution & distribution)4085 std::string RandomDistributionToString(const RandomDistribution& distribution) {
4086 return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
4087 }
RandomAlgorithmToString(const RandomAlgorithm & algorithm)4088 std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
4089 return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
4090 }
4091
PrecisionToString(const PrecisionConfig::Precision & precision)4092 std::string PrecisionToString(const PrecisionConfig::Precision& precision) {
4093 return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
4094 }
4095
CustomCallScheduleToString(const CustomCallSchedule & schedule)4096 static std::string CustomCallScheduleToString(
4097 const CustomCallSchedule& schedule) {
4098 return absl::AsciiStrToLower(CustomCallSchedule_Name(schedule));
4099 }
4100
CustomCallApiVersionToString(const CustomCallApiVersion & schedule)4101 static std::string CustomCallApiVersionToString(
4102 const CustomCallApiVersion& schedule) {
4103 return absl::AsciiStrToLower(CustomCallApiVersion_Name(schedule));
4104 }
4105
DotDimensionNumbersToString(const DotDimensionNumbers & dnums)4106 std::string DotDimensionNumbersToString(const DotDimensionNumbers& dnums) {
4107 std::vector<std::string> result;
4108 if (!dnums.lhs_batch_dimensions().empty()) {
4109 result.push_back(StrCat("lhs_batch_dims={",
4110 StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
4111 }
4112 result.push_back(StrCat("lhs_contracting_dims={",
4113 StrJoin(dnums.lhs_contracting_dimensions(), ","),
4114 "}"));
4115
4116 if (!dnums.rhs_batch_dimensions().empty()) {
4117 result.push_back(StrCat("rhs_batch_dims={",
4118 StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
4119 }
4120 result.push_back(StrCat("rhs_contracting_dims={",
4121 StrJoin(dnums.rhs_contracting_dimensions(), ","),
4122 "}"));
4123
4124 return StrJoin(result, ", ");
4125 }
4126
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)4127 std::string ConvolutionDimensionNumbersToString(
4128 const ConvolutionDimensionNumbers& dnums) {
4129 auto len_required = [](int64_t a, int64_t b, absl::Span<const int64_t> cs) {
4130 return std::max({a, b, cs.empty() ? 0 : *absl::c_max_element(cs)}) + 1;
4131 };
4132
4133 // lhs_dims[i] is the symbol of the logical dimension i for the lhs
4134 // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
4135 std::vector<std::string> lhs_dims(
4136 len_required(dnums.input_batch_dimension(),
4137 dnums.input_feature_dimension(),
4138 dnums.input_spatial_dimensions()),
4139 "?");
4140 lhs_dims[dnums.input_batch_dimension()] = 'b';
4141 lhs_dims[dnums.input_feature_dimension()] = 'f';
4142 for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
4143 lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
4144 }
4145
4146 std::vector<std::string> rhs_dims(
4147 len_required(dnums.kernel_input_feature_dimension(),
4148 dnums.kernel_output_feature_dimension(),
4149 dnums.kernel_spatial_dimensions()),
4150 "?");
4151 rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
4152 rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
4153 for (int64_t i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
4154 rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
4155 }
4156
4157 std::vector<std::string> output_dims(
4158 len_required(dnums.output_batch_dimension(),
4159 dnums.output_feature_dimension(),
4160 dnums.output_spatial_dimensions()),
4161 "?");
4162 output_dims[dnums.output_batch_dimension()] = 'b';
4163 output_dims[dnums.output_feature_dimension()] = 'f';
4164 for (int64_t i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
4165 output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
4166 }
4167
4168 return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
4169 StrJoin(output_dims, ""));
4170 }
4171
ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups)4172 std::string ReplicaGroupsToString(
4173 absl::Span<const ReplicaGroup> replica_groups) {
4174 std::vector<std::string> replica_group_str;
4175 replica_group_str.reserve(replica_groups.size());
4176 for (const ReplicaGroup& group : replica_groups) {
4177 replica_group_str.push_back(
4178 StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
4179 }
4180 return StrCat("{", StrJoin(replica_group_str, ","), "}");
4181 }
4182
StringToRandomAlgorithm(const std::string & name)4183 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const std::string& name) {
4184 static absl::flat_hash_map<std::string, RandomAlgorithm>* map = [] {
4185 static auto* map = new absl::flat_hash_map<std::string, RandomAlgorithm>;
4186 for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
4187 if (RandomAlgorithm_IsValid(i)) {
4188 auto value = static_cast<RandomAlgorithm>(i);
4189 (*map)[RandomAlgorithmToString(value)] = value;
4190 }
4191 }
4192 return map;
4193 }();
4194 auto found = map->find(absl::AsciiStrToLower(name));
4195 if (found == map->end()) {
4196 return InvalidArgument("Unknown algorithm");
4197 }
4198 return found->second;
4199 }
4200
StringToRandomDistribution(const std::string & name)4201 StatusOr<RandomDistribution> StringToRandomDistribution(
4202 const std::string& name) {
4203 static absl::flat_hash_map<std::string, RandomDistribution>* map = [] {
4204 static auto* map = new absl::flat_hash_map<std::string, RandomDistribution>;
4205 for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
4206 if (RandomDistribution_IsValid(i)) {
4207 auto value = static_cast<RandomDistribution>(i);
4208 (*map)[RandomDistributionToString(value)] = value;
4209 }
4210 }
4211 return map;
4212 }();
4213 auto found = map->find(absl::AsciiStrToLower(name));
4214 if (found == map->end()) {
4215 return InvalidArgument("Unknown distribution");
4216 }
4217 return found->second;
4218 }
4219
StringToPrecision(const std::string & name)4220 StatusOr<PrecisionConfig::Precision> StringToPrecision(
4221 const std::string& name) {
4222 static absl::flat_hash_map<std::string, PrecisionConfig::Precision>* map =
4223 [] {
4224 static auto* map =
4225 new absl::flat_hash_map<std::string, PrecisionConfig::Precision>;
4226 for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
4227 if (PrecisionConfig::Precision_IsValid(i)) {
4228 auto value = static_cast<PrecisionConfig::Precision>(i);
4229 (*map)[PrecisionToString(value)] = value;
4230 }
4231 }
4232 return map;
4233 }();
4234 auto found = map->find(absl::AsciiStrToLower(name));
4235 if (found == map->end()) {
4236 return InvalidArgument("Unknown distribution");
4237 }
4238 return found->second;
4239 }
4240
StringToCustomCallSchedule(absl::string_view name)4241 StatusOr<CustomCallSchedule> StringToCustomCallSchedule(
4242 absl::string_view name) {
4243 static const absl::flat_hash_map<std::string, CustomCallSchedule>* map = [] {
4244 static auto* map = new absl::flat_hash_map<std::string, CustomCallSchedule>;
4245 for (int i = 0; i < CustomCallSchedule_ARRAYSIZE; i++) {
4246 if (CustomCallSchedule_IsValid(i)) {
4247 auto value = static_cast<CustomCallSchedule>(i);
4248 (*map)[CustomCallScheduleToString(value)] = value;
4249 }
4250 }
4251 return map;
4252 }();
4253 auto found = map->find(absl::AsciiStrToLower(name));
4254 if (found == map->end()) {
4255 return InvalidArgument("Unknown schedule");
4256 }
4257 return found->second;
4258 }
4259
StringToCustomCallApiVersion(absl::string_view name)4260 StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
4261 absl::string_view name) {
4262 static const absl::flat_hash_map<std::string, CustomCallApiVersion>* map =
4263 [] {
4264 static auto* map =
4265 new absl::flat_hash_map<std::string, CustomCallApiVersion>;
4266 for (int i = 0; i < CustomCallApiVersion_ARRAYSIZE; i++) {
4267 if (CustomCallApiVersion_IsValid(i)) {
4268 auto value = static_cast<CustomCallApiVersion>(i);
4269 (*map)[CustomCallApiVersionToString(value)] = value;
4270 }
4271 }
4272 return map;
4273 }();
4274 auto found = map->find(absl::AsciiStrToLower(name));
4275 if (found == map->end()) {
4276 return InvalidArgument("Unknown API version");
4277 }
4278 return found->second;
4279 }
4280
operator <<(std::ostream & os,HloInstruction::FusionKind kind)4281 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
4282 return os << ToString(kind);
4283 }
4284
operator ()(const HloInstruction * const & lhs,const HloInstruction * const & rhs) const4285 bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
4286 const HloInstruction* const& rhs) const {
4287 if (rhs == nullptr) {
4288 // Nothing compares less than nullptr.
4289 return false;
4290 }
4291 if (lhs == nullptr) {
4292 return true;
4293 }
4294 auto lhs_module = lhs->GetModule();
4295 auto rhs_module = rhs->GetModule();
4296 CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
4297 (lhs_module != nullptr && rhs_module != nullptr));
4298 if (lhs_module != nullptr &&
4299 lhs_module->unique_id() != rhs_module->unique_id()) {
4300 return lhs_module->unique_id() < rhs_module->unique_id();
4301 }
4302 return lhs->unique_id() < rhs->unique_id();
4303 }
4304
GetBackendConfigInternal(tensorflow::protobuf::Message * proto) const4305 Status HloInstruction::GetBackendConfigInternal(
4306 tensorflow::protobuf::Message* proto) const {
4307 proto->Clear();
4308
4309 if (auto* proto_ptr = backend_config_.GetProtoPtr()) {
4310 if (proto_ptr->GetDescriptor() == proto->GetDescriptor()) {
4311 proto->CopyFrom(*proto_ptr);
4312 return OkStatus();
4313 }
4314 }
4315
4316 auto& raw_string = raw_backend_config_string();
4317 // Empty string does not parse as valid JSON, but it's a valid backend config,
4318 // corresponding to the empty proto.
4319 if (raw_string.empty()) {
4320 return OkStatus();
4321 }
4322 TF_RETURN_IF_ERROR(tensorflow::HumanReadableJsonToProto(raw_string, proto));
4323 backend_config_.SetProto(*proto);
4324 return OkStatus();
4325 }
4326
GetRawString() const4327 const std::string& HloInstruction::BackendConfigRep::GetRawString() const {
4328 if (proto_ && raw_string_.empty()) {
4329 raw_string_ = BackendConfigToRawString(*proto_).ValueOrDie();
4330 }
4331 return raw_string_;
4332 }
4333
Clone() const4334 HloInstruction::BackendConfigRep HloInstruction::BackendConfigRep::Clone()
4335 const {
4336 // Prefer cloning protobuf, raw_string_ will be lazily generated if accessed.
4337 BackendConfigRep cloned;
4338 if (auto* proto = GetProtoPtr()) {
4339 cloned.SetProto(*proto);
4340 } else {
4341 cloned.raw_string_ = raw_string_;
4342 }
4343 return cloned;
4344 }
4345
operator =(std::string raw_string)4346 HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=(
4347 std::string raw_string) {
4348 raw_string_ = std::move(raw_string);
4349 proto_.reset();
4350 return *this;
4351 }
4352
operator =(const tensorflow::protobuf::Message & proto)4353 HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=(
4354 const tensorflow::protobuf::Message& proto) {
4355 SetProto(proto);
4356 raw_string_.clear();
4357 return *this;
4358 }
4359
SetProto(const tensorflow::protobuf::Message & proto)4360 void HloInstruction::BackendConfigRep::SetProto(
4361 const tensorflow::protobuf::Message& proto) {
4362 proto_.reset(proto.New());
4363 proto_->CopyFrom(proto);
4364 }
4365
operator ==(const BackendConfigRep & other) const4366 bool HloInstruction::BackendConfigRep::operator==(
4367 const BackendConfigRep& other) const {
4368 auto* proto_a = GetProtoPtr();
4369 auto* proto_b = other.GetProtoPtr();
4370 if (proto_a != nullptr && proto_b != nullptr) {
4371 using ::tensorflow::protobuf::util::MessageDifferencer;
4372 return MessageDifferencer::Equals(*proto_a, *proto_b);
4373 }
4374 // TODO(b/225956414): Consider canonicalizing raw string form.
4375 return GetRawString() == other.GetRawString();
4376 }
4377
BackendConfigToRawString(const tensorflow::protobuf::Message & proto)4378 /* static */ StatusOr<std::string> HloInstruction::BackendConfigToRawString(
4379 const tensorflow::protobuf::Message& proto) {
4380 std::string ret;
4381 // Pass ignore_accuracy_loss = true because estimated_cycles field can be
4382 // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
4383 // INT64_MAX, JsonFormat will return an error status, although there is no
4384 // accuracy loss for int64_t.
4385 TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
4386 proto, &ret, /*ignore_accuracy_loss=*/true));
4387 return ret;
4388 }
4389
precision_config() const4390 const PrecisionConfig& HloInstruction::precision_config() const {
4391 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4392 return convolution->precision_config();
4393 }
4394 if (auto* dot = DynCast<HloDotInstruction>(this)) {
4395 return dot->precision_config();
4396 }
4397
4398 if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
4399 return custom_call->precision_config();
4400 }
4401 LOG(FATAL) << "Unimplemented method.";
4402 }
4403
mutable_precision_config()4404 PrecisionConfig* HloInstruction::mutable_precision_config() {
4405 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4406 return convolution->mutable_precision_config();
4407 }
4408 if (auto* dot = DynCast<HloDotInstruction>(this)) {
4409 return dot->mutable_precision_config();
4410 }
4411 LOG(FATAL) << "Unimplemented method.";
4412 }
4413
GetModule() const4414 HloModule* HloInstruction::GetModule() const {
4415 if (parent_) {
4416 return parent_->parent();
4417 }
4418 return nullptr;
4419 }
4420
UniquifyName(NameUniquer * name_uniquer)4421 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
4422 std::string parent_str = parent() == nullptr ? "noparent" : parent()->name();
4423 name_ = name_uniquer->GetUniqueName(name_);
4424 }
4425
set_outer_dimension_partitions(const std::vector<int64_t> & outer_dimension_partitions)4426 void HloInstruction::set_outer_dimension_partitions(
4427 const std::vector<int64_t>& outer_dimension_partitions) {
4428 outer_dimension_partitions_ = outer_dimension_partitions;
4429 }
4430
SortInstructionUsersAndControlLists(const MappedPtrContainerSorter<HloInstruction>::MapPtrFn & map_fn,const HloInstruction & sorted_instruction)4431 void HloInstruction::SortInstructionUsersAndControlLists(
4432 const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
4433 const HloInstruction& sorted_instruction) {
4434 using Sorter = MappedPtrContainerSorter<HloInstruction>;
4435 auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
4436 sorted_instruction.users_, users_);
4437 if (!status.ok()) {
4438 LOG(ERROR) << "Failed to sort instruction users for " << name() << "; "
4439 << status;
4440 }
4441 user_map_.clear();
4442 for (uint64_t i = 0; i < users_.size(); ++i) {
4443 user_map_[users_[i]] = i;
4444 }
4445 status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
4446 sorted_instruction.control_predecessors_,
4447 control_predecessors_);
4448 if (!status.ok()) {
4449 LOG(ERROR) << "Failed to sort instruction control predecessors for "
4450 << name() << "; " << status;
4451 }
4452 status =
4453 Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
4454 sorted_instruction.control_successors_, control_successors_);
4455 if (!status.ok()) {
4456 LOG(ERROR) << "Failed to sort instruction control successors for " << name()
4457 << "; " << status;
4458 }
4459 }
4460
4461 // TODO(b/80131774): Remove these temporary methods after transition.
feature_index() const4462 int64_t HloInstruction::feature_index() const {
4463 return Cast<HloBatchNormInstruction>(this)->feature_index();
4464 }
4465
epsilon() const4466 float HloInstruction::epsilon() const {
4467 return Cast<HloBatchNormInstruction>(this)->epsilon();
4468 }
4469
fft_type() const4470 FftType HloInstruction::fft_type() const {
4471 return Cast<HloFftInstruction>(this)->fft_type();
4472 }
4473
fft_length() const4474 const std::vector<int64_t>& HloInstruction::fft_length() const {
4475 return Cast<HloFftInstruction>(this)->fft_length();
4476 }
4477
concatenate_dimension() const4478 int64_t HloInstruction::concatenate_dimension() const {
4479 return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
4480 }
4481
dimension() const4482 int64_t HloInstruction::dimension() const {
4483 if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
4484 return set_size->dimension();
4485 }
4486 return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
4487 }
4488
inferred_dimension() const4489 int64_t HloInstruction::inferred_dimension() const {
4490 return Cast<HloReshapeInstruction>(this)->inferred_dimension();
4491 }
4492
IsRank2Transpose() const4493 bool HloInstruction::IsRank2Transpose() const {
4494 auto transpose = DynCast<HloTransposeInstruction>(this);
4495 return transpose != nullptr && transpose->IsRank2Transpose();
4496 }
4497
slice_starts(int64_t dimension) const4498 int64_t HloInstruction::slice_starts(int64_t dimension) const {
4499 return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
4500 }
4501
slice_starts() const4502 const std::vector<int64_t>& HloInstruction::slice_starts() const {
4503 return Cast<HloSliceInstruction>(this)->slice_starts();
4504 }
4505
mutable_slice_starts()4506 std::vector<int64_t>* HloInstruction::mutable_slice_starts() {
4507 return Cast<HloSliceInstruction>(this)->mutable_slice_starts();
4508 }
4509
slice_limits(int64_t dimension) const4510 int64_t HloInstruction::slice_limits(int64_t dimension) const {
4511 return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
4512 }
4513
slice_limits() const4514 const std::vector<int64_t>& HloInstruction::slice_limits() const {
4515 return Cast<HloSliceInstruction>(this)->slice_limits();
4516 }
4517
mutable_slice_limits()4518 std::vector<int64_t>* HloInstruction::mutable_slice_limits() {
4519 return Cast<HloSliceInstruction>(this)->mutable_slice_limits();
4520 }
4521
slice_strides(int64_t dimension) const4522 int64_t HloInstruction::slice_strides(int64_t dimension) const {
4523 return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
4524 }
4525
slice_strides() const4526 const std::vector<int64_t>& HloInstruction::slice_strides() const {
4527 return Cast<HloSliceInstruction>(this)->slice_strides();
4528 }
4529
mutable_slice_strides()4530 std::vector<int64_t>* HloInstruction::mutable_slice_strides() {
4531 return Cast<HloSliceInstruction>(this)->mutable_slice_strides();
4532 }
4533
literal() const4534 const Literal& HloInstruction::literal() const {
4535 return Cast<HloConstantInstruction>(this)->literal();
4536 }
4537
IsConstant() const4538 bool HloInstruction::IsConstant() const {
4539 return DynCast<HloConstantInstruction>(this) != nullptr;
4540 }
4541
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)4542 void HloInstruction::RelayoutConstant(const Layout& new_layout,
4543 const ShapeIndex& shape_index) {
4544 Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
4545 }
4546
4547 // Delegates to HloCallableInstruction::AppendInstructionIntoCalledComputation.
AppendInstructionIntoCalledComputation(HloInstruction * instruction_to_append,bool add_output)4548 HloInstruction* HloInstruction::AppendInstructionIntoCalledComputation(
4549 HloInstruction* instruction_to_append, bool add_output) {
4550 return Cast<HloCallableInstruction>(this)
4551 ->AppendInstructionIntoCalledComputation(instruction_to_append,
4552 add_output);
4553 }
4554
AddFusionOperand(HloInstruction * new_operand)4555 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
4556 return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
4557 }
4558
4559 // Delegates to HloFusionInstruction::MergeFusionInstruction.
MergeFusionInstruction(HloInstruction * instruction_to_merge)4560 void HloInstruction::MergeFusionInstruction(
4561 HloInstruction* instruction_to_merge) {
4562 return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
4563 Cast<HloFusionInstruction>(instruction_to_merge));
4564 }
4565
4566 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
MergeFusionInstructionIntoMultiOutput(HloInstruction * instruction_to_merge)4567 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
4568 HloInstruction* instruction_to_merge) {
4569 return Cast<HloFusionInstruction>(this)
4570 ->MergeFusionInstructionIntoMultiOutput(
4571 Cast<HloFusionInstruction>(instruction_to_merge));
4572 }
4573
FuseInstruction(HloInstruction * instruction_to_fuse)4574 HloInstruction* HloInstruction::FuseInstruction(
4575 HloInstruction* instruction_to_fuse) {
4576 return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
4577 }
4578
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)4579 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
4580 HloInstruction* instruction_to_fuse) {
4581 return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
4582 instruction_to_fuse);
4583 }
4584
fused_instructions_computation() const4585 HloComputation* HloInstruction::fused_instructions_computation() const {
4586 return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
4587 }
4588
fused_expression_root() const4589 HloInstruction* HloInstruction::fused_expression_root() const {
4590 return Cast<HloFusionInstruction>(this)->fused_expression_root();
4591 }
4592
4593 const tensorflow::gtl::iterator_range<UnwrappingIterator<
4594 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const4595 HloInstruction::fused_instructions() const {
4596 return Cast<HloFusionInstruction>(this)->fused_instructions();
4597 }
4598
4599 const tensorflow::gtl::iterator_range<
4600 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()4601 HloInstruction::fused_instructions() {
4602 return Cast<HloFusionInstruction>(this)->fused_instructions();
4603 }
4604
fused_instruction_count() const4605 int64_t HloInstruction::fused_instruction_count() const {
4606 return Cast<HloFusionInstruction>(this)->fused_instruction_count();
4607 }
4608
fused_parameter(int64_t parameter_number) const4609 HloInstruction* HloInstruction::fused_parameter(
4610 int64_t parameter_number) const {
4611 return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
4612 }
4613
fused_parameters() const4614 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
4615 return Cast<HloFusionInstruction>(this)->fused_parameters();
4616 }
4617
IsMultiOutputFusion() const4618 const bool HloInstruction::IsMultiOutputFusion() const {
4619 const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
4620 return fusion != nullptr && fusion->IsMultiOutputFusion();
4621 }
4622
fusion_kind() const4623 HloInstruction::FusionKind HloInstruction::fusion_kind() const {
4624 return Cast<HloFusionInstruction>(this)->fusion_kind();
4625 }
4626
set_fusion_kind(FusionKind kind)4627 void HloInstruction::set_fusion_kind(FusionKind kind) {
4628 return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
4629 }
4630
random_distribution() const4631 RandomDistribution HloInstruction::random_distribution() const {
4632 return Cast<HloRngInstruction>(this)->random_distribution();
4633 }
4634
parameter_number() const4635 int64_t HloInstruction::parameter_number() const {
4636 return Cast<HloParameterInstruction>(this)->parameter_number();
4637 }
4638
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)4639 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4640 absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
4641 return Cast<HloParameterInstruction>(this)
4642 ->set_parameter_replicated_at_leaf_buffers(
4643 parameter_replicated_at_leaf_buffers);
4644 }
4645
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)4646 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4647 const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
4648 return Cast<HloParameterInstruction>(this)
4649 ->set_parameter_replicated_at_leaf_buffers(
4650 parameter_replicated_at_leaf_buffers);
4651 }
4652
4653 const std::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers() const4654 HloInstruction::parameter_replicated_at_leaf_buffers() const {
4655 return Cast<HloParameterInstruction>(this)
4656 ->parameter_replicated_at_leaf_buffers();
4657 }
4658
tuple_index() const4659 int64_t HloInstruction::tuple_index() const {
4660 return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
4661 }
4662
set_tuple_index(int64_t new_tuple_index)4663 void HloInstruction::set_tuple_index(int64_t new_tuple_index) {
4664 return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
4665 new_tuple_index);
4666 }
4667
exponent_bits() const4668 int32_t HloInstruction::exponent_bits() const {
4669 return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
4670 }
4671
mantissa_bits() const4672 int32_t HloInstruction::mantissa_bits() const {
4673 return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
4674 }
4675
infeed_config() const4676 std::string HloInstruction::infeed_config() const {
4677 return Cast<HloInfeedInstruction>(this)->infeed_config();
4678 }
4679
set_infeed_config(const std::string & config)4680 void HloInstruction::set_infeed_config(const std::string& config) {
4681 return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
4682 }
4683
outfeed_shape() const4684 const Shape& HloInstruction::outfeed_shape() const {
4685 return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
4686 }
4687
mutable_outfeed_shape()4688 Shape* HloInstruction::mutable_outfeed_shape() {
4689 return Cast<HloOutfeedInstruction>(this)->mutable_outfeed_shape();
4690 }
4691
outfeed_config() const4692 const std::string& HloInstruction::outfeed_config() const {
4693 return Cast<HloOutfeedInstruction>(this)->outfeed_config();
4694 }
4695
set_outfeed_config(const std::string & config)4696 void HloInstruction::set_outfeed_config(const std::string& config) {
4697 return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
4698 }
4699
replica_groups() const4700 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
4701 return Cast<HloCollectiveInstruction>(this)->replica_groups();
4702 }
4703
4704 const std::vector<std::pair<int64_t, int64_t>>&
source_target_pairs() const4705 HloInstruction::source_target_pairs() const {
4706 return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
4707 }
4708
channel_id() const4709 std::optional<int64_t> HloInstruction::channel_id() const {
4710 return Cast<HloChannelInstruction>(this)->channel_id();
4711 }
4712
set_channel_id(const std::optional<int64_t> & channel_id)4713 void HloInstruction::set_channel_id(const std::optional<int64_t>& channel_id) {
4714 return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
4715 }
4716
4717 const ConvolutionDimensionNumbers&
convolution_dimension_numbers() const4718 HloInstruction::convolution_dimension_numbers() const {
4719 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4720 return convolution->convolution_dimension_numbers();
4721 }
4722 if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4723 return custom_call->convolution_dimension_numbers();
4724 }
4725 LOG(FATAL) << "Unimplemented method.";
4726 }
4727
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)4728 void HloInstruction::set_convolution_dimension_numbers(
4729 const ConvolutionDimensionNumbers& dnums) {
4730 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4731 convolution->set_convolution_dimension_numbers(dnums);
4732 } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4733 custom_call->set_convolution_dimension_numbers(dnums);
4734 } else {
4735 LOG(FATAL) << "Unimplemented method.";
4736 }
4737 }
4738
feature_group_count() const4739 int64_t HloInstruction::feature_group_count() const {
4740 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4741 return convolution->feature_group_count();
4742 }
4743 return Cast<HloCustomCallInstruction>(this)->feature_group_count();
4744 }
4745
set_feature_group_count(int64_t feature_group_count)4746 void HloInstruction::set_feature_group_count(int64_t feature_group_count) {
4747 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4748 return convolution->set_feature_group_count(feature_group_count);
4749 }
4750 Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
4751 feature_group_count);
4752 }
4753
batch_group_count() const4754 int64_t HloInstruction::batch_group_count() const {
4755 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4756 return convolution->batch_group_count();
4757 }
4758 return Cast<HloCustomCallInstruction>(this)->batch_group_count();
4759 }
4760
set_batch_group_count(int64_t batch_group_count)4761 void HloInstruction::set_batch_group_count(int64_t batch_group_count) {
4762 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4763 return convolution->set_batch_group_count(batch_group_count);
4764 }
4765 Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
4766 batch_group_count);
4767 }
4768
select() const4769 HloComputation* HloInstruction::select() const {
4770 return Cast<HloSelectAndScatterInstruction>(this)->select();
4771 }
4772
scatter() const4773 HloComputation* HloInstruction::scatter() const {
4774 return Cast<HloSelectAndScatterInstruction>(this)->scatter();
4775 }
4776
set_select(HloComputation * computation)4777 void HloInstruction::set_select(HloComputation* computation) {
4778 return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
4779 }
4780
set_scatter(HloComputation * computation)4781 void HloInstruction::set_scatter(HloComputation* computation) {
4782 return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
4783 }
4784
custom_call_target() const4785 const std::string& HloInstruction::custom_call_target() const {
4786 return Cast<HloCustomCallInstruction>(this)->custom_call_target();
4787 }
set_custom_call_target(absl::string_view target)4788 void HloInstruction::set_custom_call_target(absl::string_view target) {
4789 Cast<HloCustomCallInstruction>(this)->set_custom_call_target(target);
4790 }
4791
padding_config() const4792 const PaddingConfig& HloInstruction::padding_config() const {
4793 return Cast<HloPadInstruction>(this)->padding_config();
4794 }
4795
padding_type() const4796 PaddingType HloInstruction::padding_type() const {
4797 return Cast<HloCustomCallInstruction>(this)->padding_type();
4798 }
4799
mutable_padding_config()4800 PaddingConfig* HloInstruction::mutable_padding_config() {
4801 return Cast<HloPadInstruction>(this)->mutable_padding_config();
4802 }
4803
slice_sizes(int64_t dimension) const4804 int64_t HloInstruction::slice_sizes(int64_t dimension) const {
4805 return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
4806 }
4807
dynamic_slice_sizes() const4808 const std::vector<int64_t>& HloInstruction::dynamic_slice_sizes() const {
4809 return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
4810 }
4811
4812 const std::vector<std::vector<int64_t>>&
dynamic_slice_sizes_list() const4813 HloInstruction::dynamic_slice_sizes_list() const {
4814 return Cast<HloCollectivePermuteInstruction>(this)
4815 ->dynamic_slice_sizes_list();
4816 }
4817
gather_dimension_numbers() const4818 const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
4819 return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
4820 }
4821
gather_slice_sizes() const4822 absl::Span<const int64_t> HloInstruction::gather_slice_sizes() const {
4823 return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
4824 }
4825
scatter_dimension_numbers() const4826 const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
4827 const {
4828 return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
4829 }
4830
dot_dimension_numbers() const4831 const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
4832 return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
4833 }
4834
operand_side_metadata() const4835 const DomainMetadata& HloInstruction::operand_side_metadata() const {
4836 return Cast<HloDomainInstruction>(this)->operand_side_metadata();
4837 }
4838
user_side_metadata() const4839 const DomainMetadata& HloInstruction::user_side_metadata() const {
4840 return Cast<HloDomainInstruction>(this)->user_side_metadata();
4841 }
4842
IsAsynchronous() const4843 bool HloInstruction::IsAsynchronous() const {
4844 return opcode() == HloOpcode::kAsyncStart ||
4845 opcode() == HloOpcode::kAsyncUpdate ||
4846 opcode() == HloOpcode::kAsyncDone;
4847 }
4848
async_wrapped_computation() const4849 HloComputation* HloInstruction::async_wrapped_computation() const {
4850 CHECK(IsAsynchronous());
4851 return called_computations()[0];
4852 }
4853
async_wrapped_instruction() const4854 HloInstruction* HloInstruction::async_wrapped_instruction() const {
4855 return Cast<HloAsyncInstruction>(this)->async_wrapped_instruction();
4856 }
4857
async_wrapped_opcode() const4858 HloOpcode HloInstruction::async_wrapped_opcode() const {
4859 return Cast<HloAsyncInstruction>(this)->async_wrapped_opcode();
4860 }
4861
async_group_id() const4862 std::optional<int64_t> HloInstruction::async_group_id() const {
4863 return Cast<HloAsyncInstruction>(this)->async_group_id();
4864 }
4865
set_async_group_id(std::optional<int64_t> async_group_id)4866 void HloInstruction::set_async_group_id(std::optional<int64_t> async_group_id) {
4867 Cast<HloAsyncInstruction>(this)->set_async_group_id(async_group_id);
4868 }
4869
async_execution_thread() const4870 absl::string_view HloInstruction::async_execution_thread() const {
4871 return Cast<HloAsyncInstruction>(this)->async_execution_thread();
4872 }
4873
set_async_execution_thread(absl::string_view async_execution_thread)4874 void HloInstruction::set_async_execution_thread(
4875 absl::string_view async_execution_thread) {
4876 Cast<HloAsyncInstruction>(this)->set_async_execution_thread(
4877 async_execution_thread);
4878 }
4879
set_called_computations_execution_thread(absl::string_view async_execution_thread,bool skip_async_execution_thread_overwrite)4880 void HloInstruction::set_called_computations_execution_thread(
4881 absl::string_view async_execution_thread,
4882 bool skip_async_execution_thread_overwrite) {
4883 Cast<HloCallableInstruction>(this)->RecursivelySetComputationsThreadName(
4884 async_execution_thread, skip_async_execution_thread_overwrite);
4885 }
4886
is_cross_program_prefetch() const4887 bool HloInstruction::is_cross_program_prefetch() const {
4888 return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
4889 }
4890
comparison_direction() const4891 ComparisonDirection HloInstruction::comparison_direction() const {
4892 return Cast<HloCompareInstruction>(this)->direction();
4893 }
4894
comparison_order() const4895 ComparisonOrder HloInstruction::comparison_order() const {
4896 return Cast<HloCompareInstruction>(this)->order();
4897 }
4898
triangular_solve_options() const4899 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
4900 return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
4901 }
4902
cholesky_options() const4903 const CholeskyOptions& HloInstruction::cholesky_options() const {
4904 return Cast<HloCholeskyInstruction>(this)->cholesky_options();
4905 }
4906
4907 } // namespace xla
4908