1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <algorithm>
19 #include <iostream>
20 #include <ostream>
21 #include <set>
22 #include <string>
23 #include <unordered_set>
24 #include <utility>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/ascii.h"
32 #include "absl/strings/escaping.h"
33 #include "absl/strings/numbers.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/strings/str_join.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/protobuf_util.h"
41 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
42 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
43 #include "tensorflow/compiler/xla/service/hlo_computation.h"
44 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
45 #include "tensorflow/compiler/xla/service/hlo_module.h"
46 #include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
47 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
48 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
49 #include "tensorflow/compiler/xla/service/name_uniquer.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/xla_data.pb.h"
55 #include "tensorflow/core/lib/core/errors.h"
56 #include "tensorflow/core/lib/gtl/map_util.h"
57 #include "tensorflow/core/platform/human_readable_json.h"
58 #include "tensorflow/core/platform/logging.h"
59
60 namespace xla {
61
62 using absl::CEscape;
63 using absl::StrAppend;
64 using absl::StrCat;
65 using absl::StrJoin;
66
67 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64,HloComputation * > & computation_map,bool prohibit_empty_literal)68 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
69 const HloInstructionProto& proto,
70 const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
71 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
72 bool prohibit_empty_literal) {
73 TF_RET_CHECK(!proto.opcode().empty());
74 HloOpcode opcode;
75 auto opcode_or = StringToHloOpcode(proto.opcode());
76 absl::optional<ComparisonDirection> comparison_direction;
77 if (opcode_or.ok()) {
78 opcode = opcode_or.ConsumeValueOrDie();
79 } else {
80 // Unknown opcode. Try auto-upgrading deprecated "less-than",
81 // "greater-than", etc opcodes, which are now rolled into the kCompare
82 // opcode.
83 if (proto.opcode() == "equal-to") {
84 comparison_direction = ComparisonDirection::kEq;
85 } else if (proto.opcode() == "not-equal-to") {
86 comparison_direction = ComparisonDirection::kNe;
87 } else if (proto.opcode() == "greater-than-or-equal-to") {
88 comparison_direction = ComparisonDirection::kGe;
89 } else if (proto.opcode() == "greater-than") {
90 comparison_direction = ComparisonDirection::kGt;
91 } else if (proto.opcode() == "less-than-or-equal-to") {
92 comparison_direction = ComparisonDirection::kLe;
93 } else if (proto.opcode() == "less-than") {
94 comparison_direction = ComparisonDirection::kLt;
95 }
96 if (comparison_direction) {
97 opcode = HloOpcode::kCompare;
98 } else {
99 return InvalidArgument("Unknown opcode: %s", proto.opcode());
100 }
101 }
102
103 TF_RET_CHECK(proto.has_shape());
104
105 std::unique_ptr<HloInstruction> instruction;
106 const auto operands = [&instruction_map, &proto](int index) {
107 return instruction_map.at(proto.operand_ids(index));
108 };
109 const auto all_operands = [&instruction_map, &proto]() {
110 std::vector<HloInstruction*> result(proto.operand_ids_size());
111 std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
112 result.begin(), [&instruction_map](int64_t operand_id) {
113 return instruction_map.at(operand_id);
114 });
115 return result;
116 };
117 const auto computations = [&computation_map, &proto](int index) {
118 return computation_map.at(proto.called_computation_ids(index));
119 };
120 const auto all_computations = [&computation_map, &proto]() {
121 std::vector<HloComputation*> result(proto.called_computation_ids_size());
122 std::transform(proto.called_computation_ids().begin(),
123 proto.called_computation_ids().end(), result.begin(),
124 [&computation_map](int64_t computation_id) {
125 return computation_map.at(computation_id);
126 });
127 return result;
128 };
129
130 TF_RET_CHECK(
131 absl::c_all_of(proto.operand_ids(),
132 [&](int64_t id) { return instruction_map.contains(id); }))
133 << proto.name() << " instruction contains invalid operand id(s)";
134
135 TF_RET_CHECK(
136 absl::c_all_of(proto.called_computation_ids(),
137 [&](int64_t id) { return computation_map.contains(id); }))
138 << proto.name() << " instruction references invalid computation id(s)";
139
140 Shape shape(proto.shape());
141 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
142
143 absl::optional<int> arity = HloOpcodeArity(opcode);
144 if (arity) {
145 TF_RET_CHECK(proto.operand_ids_size() == *arity)
146 << proto.opcode() << " instruction should have " << *arity
147 << " operands but sees " << proto.operand_ids_size();
148 }
149
150 switch (opcode) {
151 // Ops migrated to subclasses.
152 case HloOpcode::kBatchNormTraining:
153 instruction =
154 CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
155 proto.epsilon(), proto.feature_index());
156 break;
157 case HloOpcode::kBatchNormInference:
158 instruction = CreateBatchNormInference(
159 shape, operands(0), operands(1), operands(2), operands(3),
160 operands(4), proto.epsilon(), proto.feature_index());
161 break;
162 case HloOpcode::kBatchNormGrad:
163 instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
164 operands(2), operands(3), operands(4),
165 proto.epsilon(), proto.feature_index());
166 break;
167 case HloOpcode::kFft: {
168 std::vector<int64> fft_length(proto.fft_length().begin(),
169 proto.fft_length().end());
170 instruction = CreateFft(shape, operands(0), proto.fft_type(),
171 absl::Span<const int64>(fft_length));
172 break;
173 }
174 case HloOpcode::kCopyStart: {
175 instruction = CreateCopyStart(shape, operands(0),
176 proto.is_cross_program_prefetch());
177 break;
178 }
179 case HloOpcode::kCompare: {
180 // Auto-upgraded from deprecated opcode skips the following.
181 if (!comparison_direction) {
182 TF_ASSIGN_OR_RETURN(
183 comparison_direction,
184 StringToComparisonDirection(proto.comparison_direction()));
185 }
186 auto comparison_type_str = proto.comparison_type();
187 if (!comparison_type_str.empty()) {
188 // If a comparison type is specified, it *must* be valid.
189 TF_ASSIGN_OR_RETURN(auto comparison_type,
190 StringToComparisonType(comparison_type_str));
191 instruction = CreateCompare(shape, operands(0), operands(1),
192 *comparison_direction, comparison_type);
193 } else {
194 // Allow the specify of comparison type to be optional.
195 // The comparison type will be determined by the types of the operands.
196 instruction = CreateCompare(shape, operands(0), operands(1),
197 *comparison_direction);
198 }
199 break;
200 }
201 case HloOpcode::kTriangularSolve: {
202 instruction = CreateTriangularSolve(shape, operands(0), operands(1),
203 proto.triangular_solve_options());
204 break;
205 }
206 case HloOpcode::kCholesky: {
207 instruction =
208 CreateCholesky(shape, operands(0), proto.cholesky_options());
209 break;
210 }
211 case HloOpcode::kSend:
212 instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
213 proto.is_host_transfer());
214 break;
215 case HloOpcode::kSendDone:
216 instruction = CreateSendDone(operands(0), proto.is_host_transfer());
217 break;
218 case HloOpcode::kRecv:
219 instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
220 proto.channel_id(), proto.is_host_transfer());
221 break;
222 case HloOpcode::kRecvDone:
223 instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
224 break;
225 case HloOpcode::kReverse:
226 instruction = CreateReverse(shape, operands(0),
227 std::vector<int64>(proto.dimensions().begin(),
228 proto.dimensions().end()));
229 break;
230 case HloOpcode::kConcatenate:
231 TF_RET_CHECK(proto.dimensions_size() == 1)
232 << "Concatenate instruction should have 1 dimension but sees "
233 << proto.dimensions_size();
234 instruction =
235 CreateConcatenate(shape, all_operands(), proto.dimensions(0));
236 break;
237 case HloOpcode::kConditional: {
238 TF_RET_CHECK(proto.called_computation_ids_size() > 0)
239 << "conditional should have at least 1 called computation";
240 if (operands(0)->shape().element_type() == PRED) {
241 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
242 << "conditional should have exactly 2 called computations but got "
243 << proto.called_computation_ids_size();
244 }
245 TF_RET_CHECK(proto.operand_ids_size() ==
246 proto.called_computation_ids_size() + 1)
247 << "conditional should have one branch_index operand plus one "
248 "operand per called computation but got "
249 << proto.operand_ids_size() << " operands for "
250 << proto.called_computation_ids_size() << " branch computations";
251 auto cond_operands = all_operands();
252 instruction =
253 CreateConditional(shape, cond_operands[0], all_computations(),
254 absl::MakeSpan(cond_operands).subspan(1));
255 break;
256 }
257 case HloOpcode::kReduce:
258 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
259 << "Reduce instruction should have an even number of operands but "
260 "sees "
261 << proto.operand_ids_size();
262 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
263 << "Reduce instruction should have 1 called computation but sees "
264 << proto.called_computation_ids_size();
265 {
266 const auto reduce_operands = all_operands();
267 auto inputs = absl::MakeSpan(reduce_operands)
268 .subspan(0, reduce_operands.size() / 2);
269 auto init_values =
270 absl::MakeSpan(reduce_operands)
271 .subspan(reduce_operands.size() / 2, reduce_operands.size());
272 instruction =
273 CreateReduce(shape, inputs, init_values,
274 std::vector<int64>(proto.dimensions().begin(),
275 proto.dimensions().end()),
276 computations(0));
277 }
278 break;
279 case HloOpcode::kSort: {
280 TF_RET_CHECK(proto.operand_ids_size() >= 1)
281 << "Sort instruction should have at least 1 operand but has "
282 << proto.operand_ids_size();
283 TF_RET_CHECK(proto.dimensions().size() == 1)
284 << "Sort instruction should have 1 dimension";
285 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
286 << "Sort instruction should one called computation but sees "
287 << proto.called_computation_ids_size();
288 auto sort_operands = all_operands();
289 instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
290 computations(0), proto.is_stable());
291 break;
292 }
293 case HloOpcode::kTranspose:
294 instruction =
295 CreateTranspose(shape, operands(0),
296 std::vector<int64>(proto.dimensions().begin(),
297 proto.dimensions().end()));
298 break;
299 case HloOpcode::kBroadcast:
300 instruction =
301 CreateBroadcast(shape, operands(0),
302 std::vector<int64>(proto.dimensions().begin(),
303 proto.dimensions().end()));
304 break;
305 case HloOpcode::kMap:
306 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
307 << "Map instruction should have 1 called computation but sees "
308 << proto.called_computation_ids_size();
309 instruction = CreateMap(shape, all_operands(), computations(0));
310 break;
311 case HloOpcode::kSlice: {
312 std::vector<int64> slice_starts, slice_limits, slice_strides;
313 for (const HloInstructionProto::SliceDimensions& slice_dimensions :
314 proto.slice_dimensions()) {
315 slice_starts.push_back(slice_dimensions.start());
316 slice_limits.push_back(slice_dimensions.limit());
317 slice_strides.push_back(slice_dimensions.stride());
318 }
319 instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
320 slice_strides);
321 break;
322 }
323 case HloOpcode::kConstant: {
324 // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
325 if (proto.has_literal()) {
326 TF_ASSIGN_OR_RETURN(
327 auto literal,
328 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
329 instruction = CreateConstant(std::move(literal));
330 // Literal's shape may have no/different tiling info.
331 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
332 instruction->shape(), shape))
333 << instruction->shape().ToString(true) << " vs "
334 << shape.ToString(true);
335 *instruction->mutable_shape() = shape;
336 } else {
337 instruction = absl::make_unique<HloConstantInstruction>(shape);
338 }
339 break;
340 }
341 case HloOpcode::kTrace: {
342 TF_RET_CHECK(proto.has_literal());
343 TF_ASSIGN_OR_RETURN(
344 auto literal,
345 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
346 instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
347 break;
348 }
349 case HloOpcode::kFusion: {
350 // In the proto, fused computations are held exclusively within the
351 // HloInstructionProto and do not appear as an HloComputationProto within
352 // the HloModuleProto.
353 TF_RET_CHECK(!proto.fusion_kind().empty());
354 TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
355 StringToFusionKind(proto.fusion_kind()));
356
357 // Find the fused computation and set its fusion instruction.
358 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
359 << "Expect 1 called computation for fusion instruction but sees "
360 << proto.called_computation_ids_size();
361 const int64_t fusion_id = proto.called_computation_ids(0);
362 auto* fused_computation =
363 tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
364 TF_RET_CHECK(fused_computation != nullptr)
365 << "No fusion computation with id " << fusion_id;
366 instruction =
367 CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
368 break;
369 }
370 case HloOpcode::kRng:
371 instruction = CreateRng(shape, proto.distribution(), all_operands());
372 break;
373 case HloOpcode::kRngBitGenerator:
374 instruction =
375 CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
376 break;
377 case HloOpcode::kRngGetAndUpdateState:
378 instruction = CreateRngGetAndUpdateState(shape, proto.delta());
379 break;
380 case HloOpcode::kParameter:
381 instruction =
382 CreateParameter(proto.parameter_number(), shape, proto.name());
383 if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
384 instruction->set_parameter_replicated_at_leaf_buffers(
385 proto.parameter_replication().replicated_at_leaf_buffers());
386 }
387 break;
388 case HloOpcode::kGetTupleElement:
389 instruction =
390 CreateGetTupleElement(shape, operands(0), proto.tuple_index());
391 break;
392 case HloOpcode::kReducePrecision:
393 instruction = CreateReducePrecision(
394 shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
395 break;
396 case HloOpcode::kInfeed: {
397 TF_RET_CHECK(shape.IsTuple() &&
398 (ShapeUtil::TupleElementCount(shape) == 2))
399 << "Infeed should have a tuple shape with 2 operands, but has: "
400 << shape;
401 const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
402 instruction =
403 CreateInfeed(data_shape, operands(0), proto.infeed_config());
404 } break;
405 case HloOpcode::kOutfeed: {
406 Shape outfeed_shape(proto.outfeed_shape());
407 TF_RETURN_IF_ERROR(
408 ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
409 instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
410 proto.outfeed_config());
411 break;
412 }
413 case HloOpcode::kAllGather:
414 case HloOpcode::kAllGatherStart: {
415 absl::optional<int64> channel_id;
416 if (proto.channel_id() > 0) {
417 channel_id = proto.channel_id();
418 }
419
420 TF_RET_CHECK(proto.dimensions_size() == 1)
421 << "AllGather cannot have more than 1 all-gather dimensions";
422 int64_t all_gather_dimension = proto.dimensions(0);
423 if (opcode == HloOpcode::kAllGather) {
424 instruction = CreateAllGather(
425 shape, all_operands(), all_gather_dimension,
426 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
427 proto.replica_groups().end()),
428 proto.constrain_layout(), channel_id,
429 proto.use_global_device_ids());
430 } else {
431 instruction = CreateAllGatherStart(
432 shape, all_operands(), all_gather_dimension,
433 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
434 proto.replica_groups().end()),
435 proto.constrain_layout(), channel_id,
436 proto.use_global_device_ids());
437 }
438 break;
439 }
440 case HloOpcode::kAllReduce:
441 case HloOpcode::kAllReduceStart:
442 case HloOpcode::kReduceScatter: {
443 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
444 << "AllReduce should have 1 called computation but sees "
445 << proto.called_computation_ids_size();
446 TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
447 << "AllReduce cannot have both channel_id() and all_reduce_id()";
448 absl::optional<int64> channel_id;
449 if (proto.channel_id() > 0) {
450 channel_id = proto.channel_id();
451 }
452 if (proto.all_reduce_id() > 0) {
453 channel_id = proto.all_reduce_id();
454 }
455 std::vector<ReplicaGroup> replica_groups(proto.replica_groups().begin(),
456 proto.replica_groups().end());
457 if (opcode == HloOpcode::kAllReduce) {
458 instruction =
459 CreateAllReduce(shape, all_operands(), computations(0),
460 replica_groups, proto.constrain_layout(),
461 channel_id, proto.use_global_device_ids());
462 } else if (opcode == HloOpcode::kReduceScatter) {
463 TF_RET_CHECK(proto.dimensions_size() == 1)
464 << "ReduceScatter cannot have more than 1 scatter dimensions";
465 int64_t scatter_dimension = proto.dimensions(0);
466 instruction = CreateReduceScatter(
467 shape, all_operands(), computations(0), replica_groups,
468 proto.constrain_layout(), channel_id, proto.use_global_device_ids(),
469 scatter_dimension);
470 } else {
471 instruction =
472 CreateAllReduceStart(shape, all_operands(), computations(0),
473 replica_groups, proto.constrain_layout(),
474 channel_id, proto.use_global_device_ids());
475 }
476 break;
477 }
478 case HloOpcode::kAllToAll: {
479 absl::optional<int64> channel_id;
480 if (proto.channel_id() > 0) {
481 channel_id = proto.channel_id();
482 }
483 absl::optional<int64> split_dimension;
484 if (proto.dimensions_size() > 0) {
485 TF_RET_CHECK(proto.dimensions_size() == 1)
486 << "AllToAll cannot have more than 1 dimension (split dimension)";
487 TF_RET_CHECK(all_operands().size() == 1)
488 << "AllToAll must have a single operand when the split dimension "
489 "is specified";
490 split_dimension = proto.dimensions(0);
491 }
492 instruction = CreateAllToAll(
493 shape, all_operands(),
494 /*replica_groups=*/
495 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
496 proto.replica_groups().end()),
497 /*constrain_layout=*/proto.constrain_layout(),
498 /*channel_id=*/channel_id, split_dimension);
499 break;
500 }
501 case HloOpcode::kCollectivePermute:
502 case HloOpcode::kCollectivePermuteStart: {
503 TF_RET_CHECK(proto.operand_ids().size() == 1 ||
504 proto.operand_ids().size() == 4);
505 std::vector<std::pair<int64_t, int64_t>> source_target_pairs(
506 proto.source_target_pairs_size());
507 absl::optional<int64_t> channel_id;
508 if (proto.channel_id() > 0) {
509 channel_id = proto.channel_id();
510 }
511 for (int i = 0; i < source_target_pairs.size(); ++i) {
512 source_target_pairs[i].first = proto.source_target_pairs(i).source();
513 source_target_pairs[i].second = proto.source_target_pairs(i).target();
514 }
515 if (proto.dynamic_slice_sizes_size() == 0) {
516 if (opcode == HloOpcode::kCollectivePermute) {
517 instruction = CreateCollectivePermute(
518 shape, operands(0), source_target_pairs, channel_id);
519 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
520 instruction = CreateCollectivePermuteStart(
521 shape, operands(0), source_target_pairs, channel_id);
522 } else {
523 LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
524 << "but got " << HloOpcodeString(opcode);
525 }
526 } else {
527 std::vector<std::vector<int64_t>> slice_sizes;
528 HloInstruction* input = operands(0);
529 HloInstruction* input_start_indices = operands(2);
530 if (input->shape().IsTuple() &&
531 input->shape().tuple_shapes_size() > 1) {
532 slice_sizes.resize(input->shape().tuple_shapes_size());
533 } else {
534 slice_sizes.resize(1);
535 }
536 int proto_index = 0;
537 if (input->shape().IsTuple()) {
538 if (input_start_indices->shape()
539 .tuple_shapes(0)
540 .tuple_shapes(0)
541 .IsArray()) {
542 slice_sizes.resize(input->shape().tuple_shapes_size());
543 for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
544 slice_sizes[i].resize(
545 input->shape().tuple_shapes(i).dimensions_size());
546 for (int j = 0;
547 j < input->shape().tuple_shapes(i).dimensions_size(); ++j) {
548 CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
549 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
550 proto_index += 1;
551 }
552 }
553 } else {
554 slice_sizes.resize(
555 input->shape().tuple_shapes_size() *
556 ShapeUtil::TupleElementCount(
557 input_start_indices->shape().tuple_shapes(0)));
558 int slice_sizes_count = 0;
559 for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
560 for (int j = 0;
561 j < ShapeUtil::TupleElementCount(
562 input_start_indices->shape().tuple_shapes(i));
563 ++j) {
564 slice_sizes[slice_sizes_count].resize(
565 input->shape().tuple_shapes(i).rank());
566 for (int k = 0; k < input->shape().tuple_shapes(i).rank();
567 ++k) {
568 CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
569 slice_sizes[slice_sizes_count][k] =
570 proto.dynamic_slice_sizes(proto_index);
571 proto_index += 1;
572 }
573 slice_sizes_count += 1;
574 }
575 }
576 }
577 } else {
578 slice_sizes.resize(
579 ShapeUtil::TupleElementCount(input_start_indices->shape()));
580 if (input_start_indices->shape().tuple_shapes(0).IsTuple()) {
581 for (int i = 0;
582 i < ShapeUtil::TupleElementCount(input_start_indices->shape());
583 ++i) {
584 slice_sizes[i].resize(input->shape().dimensions_size());
585 for (int j = 0; j < input->shape().dimensions_size(); ++j) {
586 slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
587 proto_index += 1;
588 }
589 }
590 } else {
591 slice_sizes.resize(1);
592 slice_sizes[0].resize(input->shape().dimensions_size());
593 for (int j = 0; j < input->shape().dimensions_size(); ++j) {
594 slice_sizes[0][j] = proto.dynamic_slice_sizes(proto_index);
595 proto_index += 1;
596 }
597 }
598 }
599 if (opcode == HloOpcode::kCollectivePermute) {
600 instruction = CreateCollectivePermute(
601 shape, operands(0), operands(1), operands(2), operands(3),
602 source_target_pairs, slice_sizes, channel_id);
603 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
604 instruction = CreateCollectivePermuteStart(
605 shape, operands(0), operands(1), operands(2), operands(3),
606 source_target_pairs, slice_sizes, channel_id);
607 } else {
608 LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
609 << "but got " << HloOpcodeString(opcode);
610 }
611 }
612 break;
613 }
614 case HloOpcode::kReplicaId: {
615 instruction = CreateReplicaId(shape);
616 break;
617 }
618 case HloOpcode::kPartitionId: {
619 instruction = CreatePartitionId(shape);
620 break;
621 }
622 case HloOpcode::kConvolution: {
623 TF_RET_CHECK(proto.has_window());
624 TF_RET_CHECK(proto.has_convolution_dimension_numbers());
625 PrecisionConfig precision_config = proto.precision_config();
626 precision_config.mutable_operand_precision()->Resize(
627 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
628 instruction = CreateConvolve(
629 shape, operands(0), operands(1),
630 std::max<int64>(proto.feature_group_count(), 1),
631 std::max<int64>(proto.batch_group_count(), 1), proto.window(),
632 proto.convolution_dimension_numbers(), precision_config);
633 break;
634 }
635 case HloOpcode::kReduceWindow:
636 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
637 << "Reduce window should have an even number of operands but "
638 "sees "
639 << proto.operand_ids_size();
640 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
641 << "ReduceWindow should have 1 called computation but sees "
642 << proto.called_computation_ids_size();
643 {
644 const auto reduce_operands = all_operands();
645 auto inputs = absl::MakeSpan(reduce_operands)
646 .subspan(0, reduce_operands.size() / 2);
647 auto init_values =
648 absl::MakeSpan(reduce_operands)
649 .subspan(reduce_operands.size() / 2, reduce_operands.size());
650 instruction = CreateReduceWindow(shape, inputs, init_values,
651 proto.window(), computations(0));
652 }
653 break;
654 case HloOpcode::kSelectAndScatter:
655 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
656 << "SelectAndScatter should have 2 called computations but sees "
657 << proto.called_computation_ids_size();
658 instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
659 proto.window(), operands(1),
660 operands(2), computations(1));
661 break;
662 case HloOpcode::kCustomCall: {
663 if (proto.constrain_layout()) {
664 // A proto RepeatedPtrField cannot be converted to a Span (it is a
665 // vector of pointers essentially) so create a vector of shapes to pass
666 // in.
667 std::vector<Shape> operand_shapes;
668 for (const ShapeProto& shape_proto :
669 proto.operand_shapes_with_layout()) {
670 operand_shapes.emplace_back(shape_proto);
671 }
672 instruction =
673 CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
674 operand_shapes, proto.backend_config());
675 } else {
676 if (proto.called_computation_ids_size() == 1) {
677 instruction = CreateCustomCall(shape, all_operands(), computations(0),
678 proto.custom_call_target(),
679 proto.backend_config());
680 } else if (proto.called_computation_ids_size() > 1) {
681 instruction = CreateCustomCall(
682 shape, all_operands(), all_computations(),
683 proto.custom_call_target(), proto.backend_config());
684
685 } else {
686 instruction = CreateCustomCall(shape, all_operands(),
687 proto.custom_call_target(),
688 proto.backend_config());
689 }
690 }
691 auto custom_call_instr =
692 Cast<HloCustomCallInstruction>(instruction.get());
693 if (proto.has_window()) {
694 custom_call_instr->set_window(proto.window());
695 }
696 if (proto.has_literal()) {
697 TF_ASSIGN_OR_RETURN(
698 auto literal,
699 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
700 custom_call_instr->set_literal(std::move(literal));
701 }
702 if (proto.has_convolution_dimension_numbers()) {
703 custom_call_instr->set_convolution_dimension_numbers(
704 proto.convolution_dimension_numbers());
705 }
706 custom_call_instr->set_feature_group_count(
707 std::max(static_cast<int64>(proto.feature_group_count()), int64{1}));
708 custom_call_instr->set_batch_group_count(
709 std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
710 custom_call_instr->set_custom_call_has_side_effect(
711 proto.custom_call_has_side_effect());
712 custom_call_instr->set_padding_type(proto.padding_type());
713
714 PrecisionConfig precision_config = proto.precision_config();
715 precision_config.mutable_operand_precision()->Resize(
716 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
717 *custom_call_instr->mutable_precision_config() = precision_config;
718 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
719 output_to_operand_aliasing;
720 for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
721 output_to_operand_aliasing.emplace_back(
722 ShapeIndex(aliasing.output_shape_index().begin(),
723 aliasing.output_shape_index().end()),
724 std::pair<int64, ShapeIndex>{
725 aliasing.operand_index(),
726 ShapeIndex(aliasing.operand_shape_index().begin(),
727 aliasing.operand_shape_index().end())});
728 }
729 custom_call_instr->set_output_to_operand_aliasing(
730 std::move(output_to_operand_aliasing));
731 custom_call_instr->set_custom_call_schedule(proto.custom_call_schedule());
732 custom_call_instr->set_api_version(proto.custom_call_api_version());
733 break;
734 }
735 case HloOpcode::kPad:
736 TF_RET_CHECK(proto.has_padding_config());
737 instruction =
738 CreatePad(shape, operands(0), operands(1), proto.padding_config());
739 break;
740 case HloOpcode::kDynamicSlice: {
741 std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
742 absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
743 TF_RET_CHECK(proto.operand_ids_size() >= 1)
744 << "DynamicSlice instruction should have at least 1 operands but "
745 "sees "
746 << proto.operand_ids_size();
747 // TODO(b/118437727): Old form, make the check unconditional.
748 if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
749 auto expected_operands = 1 + operands(0)->shape().rank();
750 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
751 << "DynamicSlice instruction should have " << expected_operands
752 << " operands, but has " << proto.operand_ids_size();
753 }
754 const auto& operand_vector = all_operands();
755 instruction = CreateDynamicSlice(
756 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
757 slice_sizes);
758 break;
759 }
760 case HloOpcode::kDynamicUpdateSlice: {
761 TF_RET_CHECK(proto.operand_ids_size() >= 2)
762 << "DynamicUpdateSlice instruction should have at least 2 operands "
763 "but sees "
764 << proto.operand_ids_size();
765 // TODO(b/118437727): Old form, make the check unconditional.
766 if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
767 auto expected_operands = 2 + operands(0)->shape().rank();
768 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
769 << "DynamicUpdateSlice instruction should have "
770 << expected_operands << " operands, but has "
771 << proto.operand_ids_size();
772 }
773 const auto& operand_vector = all_operands();
774 instruction =
775 CreateDynamicUpdateSlice(shape, operands(0), operands(1),
776 absl::MakeSpan(operand_vector).subspan(2));
777
778 break;
779 }
780 case HloOpcode::kGather: {
781 TF_RET_CHECK(proto.has_gather_dimension_numbers())
782 << "Gather instruction should have GatherDimensionNumbers set.";
783 auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
784 proto.gather_dimension_numbers());
785 std::vector<int64> gather_slice_sizes;
786 for (int64_t bound : proto.gather_slice_sizes()) {
787 gather_slice_sizes.push_back(bound);
788 }
789 instruction = CreateGather(shape, operands(0), operands(1),
790 *gather_dimension_numbers, gather_slice_sizes,
791 proto.indices_are_sorted());
792 break;
793 }
794 case HloOpcode::kScatter: {
795 TF_RET_CHECK(proto.has_scatter_dimension_numbers())
796 << "Scatter instruction should have ScatterDimensionNumbers set.";
797 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
798 << "Scatter instruction should have 1 called computation but sees "
799 << proto.called_computation_ids_size();
800 auto scatter_dimension_numbers =
801 absl::make_unique<ScatterDimensionNumbers>(
802 proto.scatter_dimension_numbers());
803 instruction =
804 CreateScatter(shape, operands(0), operands(1), operands(2),
805 computations(0), *scatter_dimension_numbers,
806 proto.indices_are_sorted(), proto.unique_indices());
807 break;
808 }
809 case HloOpcode::kIota:
810 TF_RET_CHECK(proto.dimensions_size() == 1)
811 << "Iota instruction should have 1 dimension but sees "
812 << proto.dimensions_size();
813 instruction = CreateIota(shape, proto.dimensions(0));
814 break;
815 case HloOpcode::kDot: {
816 TF_RET_CHECK(proto.has_dot_dimension_numbers())
817 << "Dot instruction should have dot_dimension_numbers.";
818 PrecisionConfig precision_config = proto.precision_config();
819 precision_config.mutable_operand_precision()->Resize(
820 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
821 instruction = absl::make_unique<HloDotInstruction>(
822 shape, operands(0), operands(1), proto.dot_dimension_numbers(),
823 precision_config);
824 break;
825 }
826 case HloOpcode::kDomain: {
827 std::shared_ptr<const HloSharding> entry_hlo_sharding;
828 std::shared_ptr<const HloSharding> exit_hlo_sharding;
829 if (proto.has_domain_entry_sharding()) {
830 TF_ASSIGN_OR_RETURN(
831 HloSharding sharding,
832 HloSharding::FromProto(proto.domain_entry_sharding()));
833 entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
834 }
835 if (proto.has_domain_exit_sharding()) {
836 TF_ASSIGN_OR_RETURN(
837 HloSharding sharding,
838 HloSharding::FromProto(proto.domain_exit_sharding()));
839 exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
840 }
841 instruction = absl::make_unique<HloDomainInstruction>(
842 shape, operands(0),
843 absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
844 absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
845 break;
846 }
847 case HloOpcode::kGetDimensionSize:
848 TF_RET_CHECK(proto.dimensions_size() == 1);
849 instruction =
850 CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
851 break;
852 case HloOpcode::kSetDimensionSize:
853 TF_RET_CHECK(proto.dimensions_size() == 1);
854 instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
855 proto.dimensions(0));
856 break;
857 case HloOpcode::kReshape: {
858 int64_t inferred_dimension = -1;
859 if (!proto.dimensions().empty()) {
860 inferred_dimension = proto.dimensions()[0];
861 }
862 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
863 ShapeUtil::ElementsIn(shape) ==
864 ShapeUtil::ElementsIn(operands(0)->shape()))
865 << "shape: " << ShapeUtil::HumanString(shape)
866 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
867 instruction = CreateReshape(shape, operands(0), inferred_dimension);
868 break;
869 }
870 case HloOpcode::kDynamicReshape: {
871 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
872 ShapeUtil::ElementsIn(shape) ==
873 ShapeUtil::ElementsIn(operands(0)->shape()))
874 << "shape: " << ShapeUtil::HumanString(shape)
875 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
876 const auto& operand_vector = all_operands();
877 instruction = CreateDynamicReshape(
878 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
879 break;
880 }
881 default: {
882 instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
883 for (const int64_t operand_id : proto.operand_ids()) {
884 instruction->AppendOperand(instruction_map.at(operand_id));
885 }
886 if (instruction->opcode() != HloOpcode::kFusion) {
887 if (instruction->opcode() == HloOpcode::kCall) {
888 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
889 << "Call should have 1 called computation but has "
890 << proto.called_computation_ids_size();
891 }
892 for (const int64_t computation_id : proto.called_computation_ids()) {
893 instruction->called_computations_.push_back(
894 computation_map.at(computation_id));
895 }
896 }
897 TF_RET_CHECK(!proto.has_precision_config())
898 << instruction->opcode() << proto.DebugString();
899 TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
900 break;
901 }
902 }
903
904 for (const int64_t predecessor_id : proto.control_predecessor_ids()) {
905 TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
906 << "No instruction with id " << predecessor_id;
907 TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
908 ->AddControlDependencyTo(instruction.get()));
909 }
910
911 TF_RET_CHECK(!proto.name().empty());
912 instruction->SetAndSanitizeName(proto.name());
913 instruction->metadata_ = proto.metadata();
914 instruction->backend_config_ = proto.backend_config();
915 instruction->outer_dimension_partitions_.assign(
916 proto.outer_dimension_partitions().begin(),
917 proto.outer_dimension_partitions().end());
918
919 TF_RET_CHECK(proto.id() >= 0)
920 << "Instruction with negative id: " << proto.id();
921 TF_RET_CHECK(proto.id() <= INT_MAX)
922 << "Instruction with id > INT_MAX: " << proto.id();
923 instruction->unique_id_ = proto.id();
924
925 if (proto.has_sharding()) {
926 TF_ASSIGN_OR_RETURN(const auto& sharding,
927 HloSharding::FromProto(proto.sharding()));
928 instruction->set_sharding(sharding);
929 }
930
931 if (proto.has_frontend_attributes()) {
932 instruction->set_frontend_attributes(proto.frontend_attributes());
933 }
934
935 return std::move(instruction);
936 }
937
CreateParameter(int64_t parameter_number,const Shape & shape,const string & name)938 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
939 int64_t parameter_number, const Shape& shape, const string& name) {
940 return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
941 name);
942 }
943
CreateTrace(const string & tag,HloInstruction * operand)944 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
945 const string& tag, HloInstruction* operand) {
946 return absl::make_unique<HloTraceInstruction>(tag, operand);
947 }
948
CreateConstant(Literal literal)949 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
950 Literal literal) {
951 return absl::make_unique<HloConstantInstruction>(std::move(literal));
952 }
953
CreateIota(const Shape & shape,int64_t iota_dimension)954 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
955 const Shape& shape, int64_t iota_dimension) {
956 return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
957 }
958
959 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64_t index)960 HloInstruction::CreateGetTupleElement(const Shape& shape,
961 HloInstruction* operand, int64_t index) {
962 return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
963 index);
964 }
965
CreateRng(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)966 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
967 const Shape& shape, RandomDistribution distribution,
968 absl::Span<HloInstruction* const> parameters) {
969 return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
970 }
971
972 /* static */ std::unique_ptr<HloInstruction>
CreateRngGetAndUpdateState(const Shape & shape,int64_t delta)973 HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64_t delta) {
974 return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
975 }
976
977 /* static */ std::unique_ptr<HloInstruction>
CreateRngBitGenerator(const Shape & shape,HloInstruction * state,RandomAlgorithm algorithm)978 HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
979 RandomAlgorithm algorithm) {
980 return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state,
981 algorithm);
982 }
983
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)984 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
985 const Shape& shape, HloOpcode opcode,
986 absl::Span<HloInstruction* const> operands) {
987 if (opcode == HloOpcode::kCopy) {
988 // It is impossible to copy an opaque shape, we don't know how big it is.
989 CHECK(!shape.IsOpaque());
990 }
991 auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
992 for (auto operand : operands) {
993 instruction->AppendOperand(operand);
994 }
995 return instruction;
996 }
997
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)998 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
999 const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
1000 // Only certain opcodes are supported with CreateUnary: opcodes of unary
1001 // instructions with no auxiliary fields.
1002 switch (opcode) {
1003 case HloOpcode::kAbs:
1004 case HloOpcode::kAllGatherDone:
1005 case HloOpcode::kAllReduceDone:
1006 case HloOpcode::kRoundNearestAfz:
1007 case HloOpcode::kBitcast:
1008 case HloOpcode::kCeil:
1009 case HloOpcode::kCollectivePermuteDone:
1010 case HloOpcode::kCopy:
1011 case HloOpcode::kCopyDone:
1012 case HloOpcode::kCos:
1013 case HloOpcode::kClz:
1014 case HloOpcode::kExp:
1015 case HloOpcode::kExpm1:
1016 case HloOpcode::kFloor:
1017 case HloOpcode::kImag:
1018 case HloOpcode::kIsFinite:
1019 case HloOpcode::kLog:
1020 case HloOpcode::kLog1p:
1021 case HloOpcode::kNot:
1022 case HloOpcode::kNegate:
1023 case HloOpcode::kPopulationCount:
1024 case HloOpcode::kReal:
1025 case HloOpcode::kRsqrt:
1026 case HloOpcode::kLogistic:
1027 case HloOpcode::kSign:
1028 case HloOpcode::kSin:
1029 case HloOpcode::kSqrt:
1030 case HloOpcode::kCbrt:
1031 case HloOpcode::kTanh:
1032 break;
1033 default:
1034 LOG(FATAL) << "Invalid unary instruction opcode "
1035 << HloOpcodeString(opcode);
1036 }
1037 return CreateNary(shape, opcode, {operand});
1038 }
1039
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)1040 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
1041 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1042 HloInstruction* rhs) {
1043 // Only certain opcodes are supported with CreateBinary: opcodes of binary
1044 // instructions with no auxiliary fields.
1045 switch (opcode) {
1046 case HloOpcode::kAdd:
1047 case HloOpcode::kAtan2:
1048 case HloOpcode::kDivide:
1049 case HloOpcode::kComplex:
1050 case HloOpcode::kMaximum:
1051 case HloOpcode::kMinimum:
1052 case HloOpcode::kMultiply:
1053 case HloOpcode::kPower:
1054 case HloOpcode::kRemainder:
1055 case HloOpcode::kSubtract:
1056 case HloOpcode::kAnd:
1057 case HloOpcode::kOr:
1058 case HloOpcode::kXor:
1059 case HloOpcode::kShiftLeft:
1060 case HloOpcode::kShiftRightArithmetic:
1061 case HloOpcode::kShiftRightLogical:
1062 break;
1063 default:
1064 LOG(FATAL) << "Invalid binary instruction opcode "
1065 << HloOpcodeString(opcode);
1066 }
1067 return CreateNary(shape, opcode, {lhs, rhs});
1068 }
1069
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)1070 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
1071 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
1072 HloInstruction* rhs, HloInstruction* ehs) {
1073 // Only certain opcodes are supported with CreateTernary: opcodes of ternary
1074 // instructions with no auxiliary fields.
1075 switch (opcode) {
1076 case HloOpcode::kClamp:
1077 case HloOpcode::kSelect:
1078 case HloOpcode::kTupleSelect:
1079 break;
1080 default:
1081 LOG(FATAL) << "Invalid ternary instruction opcode "
1082 << HloOpcodeString(opcode);
1083 }
1084 return CreateNary(shape, opcode, {lhs, rhs, ehs});
1085 }
1086
CreateVariadic(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)1087 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
1088 const Shape& shape, HloOpcode opcode,
1089 absl::Span<HloInstruction* const> operands) {
1090 CHECK_EQ(HloOpcode::kTuple, opcode);
1091 return CreateNary(shape, opcode, operands);
1092 }
1093
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)1094 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
1095 const Shape& shape, absl::Span<HloInstruction* const> operands,
1096 HloComputation* map_computation) {
1097 return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
1098 }
1099
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64_t feature_group_count,int64_t batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1100 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
1101 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1102 int64_t feature_group_count, int64_t batch_group_count,
1103 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
1104 const PrecisionConfig& precision_config) {
1105 return absl::make_unique<HloConvolutionInstruction>(
1106 shape, lhs, rhs, feature_group_count, batch_group_count, window,
1107 dimension_numbers, precision_config);
1108 }
1109
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)1110 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
1111 const Shape& shape, HloInstruction* operand, FftType fft_type,
1112 absl::Span<const int64> fft_length) {
1113 return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
1114 fft_length);
1115 }
1116
CreateCopyStart(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)1117 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
1118 const Shape& shape, HloInstruction* operand,
1119 bool is_cross_program_prefetch) {
1120 return absl::make_unique<HloCopyStartInstruction>(shape, operand,
1121 is_cross_program_prefetch);
1122 }
1123
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)1124 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
1125 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1126 ComparisonDirection direction, absl::optional<Comparison::Type> type) {
1127 return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
1128 type);
1129 }
1130
1131 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)1132 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
1133 HloInstruction* b,
1134 const TriangularSolveOptions& options) {
1135 return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
1136 }
1137
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)1138 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
1139 const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
1140 return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
1141 }
1142
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1143 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
1144 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1145 const DotDimensionNumbers& dimension_numbers,
1146 const PrecisionConfig& precision_config) {
1147 return absl::make_unique<HloDotInstruction>(
1148 shape, lhs, rhs, dimension_numbers, precision_config);
1149 }
1150
1151 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1152 HloInstruction::CreateReducePrecision(const Shape& shape,
1153 HloInstruction* operand,
1154 const int exponent_bits,
1155 const int mantissa_bits) {
1156 return absl::make_unique<HloReducePrecisionInstruction>(
1157 shape, operand, exponent_bits, mantissa_bits);
1158 }
1159
CreateAllGather(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1160 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
1161 const Shape& shape, absl::Span<HloInstruction* const> operands,
1162 int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1163 bool constrain_layout, const absl::optional<int64>& channel_id,
1164 bool use_global_device_ids) {
1165 return absl::make_unique<HloAllGatherInstruction>(
1166 HloOpcode::kAllGather, shape, operands, all_gather_dimension,
1167 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1168 }
1169
1170 /* static */ std::unique_ptr<HloInstruction>
CreateAllGatherStart(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t all_gather_dimension,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1171 HloInstruction::CreateAllGatherStart(
1172 const Shape& shape, absl::Span<HloInstruction* const> operands,
1173 int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
1174 bool constrain_layout, const absl::optional<int64>& channel_id,
1175 bool use_global_device_ids) {
1176 return absl::make_unique<HloAllGatherInstruction>(
1177 HloOpcode::kAllGatherStart, shape, operands, all_gather_dimension,
1178 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1179 }
1180
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1181 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
1182 const Shape& shape, absl::Span<HloInstruction* const> operands,
1183 HloComputation* reduce_computation,
1184 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1185 const absl::optional<int64>& channel_id, bool use_global_device_ids) {
1186 return absl::make_unique<HloAllReduceInstruction>(
1187 HloOpcode::kAllReduce, shape, operands, reduce_computation,
1188 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1189 }
1190
1191 /* static */ std::unique_ptr<HloInstruction>
CreateReduceScatter(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids,int64_t scatter_dimension)1192 HloInstruction::CreateReduceScatter(
1193 const Shape& shape, absl::Span<HloInstruction* const> operands,
1194 HloComputation* reduce_computation,
1195 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1196 const absl::optional<int64>& channel_id, bool use_global_device_ids,
1197 int64_t scatter_dimension) {
1198 return absl::make_unique<HloReduceScatterInstruction>(
1199 shape, operands, reduce_computation, replica_groups, constrain_layout,
1200 channel_id, use_global_device_ids, scatter_dimension);
1201 }
1202
1203 /* static */ std::unique_ptr<HloInstruction>
CreateAllReduceStart(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1204 HloInstruction::CreateAllReduceStart(
1205 const Shape& shape, absl::Span<HloInstruction* const> operands,
1206 HloComputation* reduce_computation,
1207 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1208 const absl::optional<int64>& channel_id, bool use_global_device_ids) {
1209 return absl::make_unique<HloAllReduceInstruction>(
1210 HloOpcode::kAllReduceStart, shape, operands, reduce_computation,
1211 replica_groups, constrain_layout, channel_id, use_global_device_ids);
1212 }
1213
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<const ReplicaGroup> replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)1214 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
1215 const Shape& shape, absl::Span<HloInstruction* const> operands,
1216 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
1217 const absl::optional<int64>& channel_id,
1218 const absl::optional<int64>& split_dimension) {
1219 return absl::make_unique<HloAllToAllInstruction>(
1220 shape, operands, replica_groups, constrain_layout, channel_id,
1221 split_dimension);
1222 }
1223
1224 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const absl::optional<int64_t> & channel_id)1225 HloInstruction::CreateCollectivePermute(
1226 const Shape& shape, HloInstruction* operand,
1227 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1228 const absl::optional<int64_t>& channel_id) {
1229 return absl::make_unique<HloCollectivePermuteInstruction>(
1230 HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
1231 channel_id);
1232 }
1233
1234 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const absl::optional<int64_t> & channel_id)1235 HloInstruction::CreateCollectivePermute(
1236 const Shape& shape, HloInstruction* input, HloInstruction* output,
1237 HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1238 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1239 absl::Span<const std::vector<int64_t>> slice_sizes,
1240 const absl::optional<int64_t>& channel_id) {
1241 return absl::make_unique<HloCollectivePermuteInstruction>(
1242 HloOpcode::kCollectivePermute, shape, input, output, input_start_indices,
1243 output_start_indices, source_target_pairs, slice_sizes, channel_id);
1244 }
1245
1246 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,const absl::optional<int64_t> & channel_id)1247 HloInstruction::CreateCollectivePermuteStart(
1248 const Shape& shape, HloInstruction* operand,
1249 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1250 const absl::optional<int64_t>& channel_id) {
1251 return absl::make_unique<HloCollectivePermuteInstruction>(
1252 HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
1253 channel_id);
1254 }
1255
1256 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * input,HloInstruction * output,HloInstruction * input_start_indices,HloInstruction * output_start_indices,absl::Span<const std::pair<int64_t,int64_t>> source_target_pairs,absl::Span<const std::vector<int64_t>> slice_sizes,const absl::optional<int64_t> & channel_id)1257 HloInstruction::CreateCollectivePermuteStart(
1258 const Shape& shape, HloInstruction* input, HloInstruction* output,
1259 HloInstruction* input_start_indices, HloInstruction* output_start_indices,
1260 absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
1261 absl::Span<const std::vector<int64_t>> slice_sizes,
1262 const absl::optional<int64_t>& channel_id) {
1263 return absl::make_unique<HloCollectivePermuteInstruction>(
1264 HloOpcode::kCollectivePermuteStart, shape, input, output,
1265 input_start_indices, output_start_indices, source_target_pairs,
1266 slice_sizes, channel_id);
1267 }
1268
CreateReplicaId(const Shape & shape)1269 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
1270 const Shape& shape) {
1271 CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1272 << "HloInstruction replica-id must have a shape of u32[], but "
1273 << shape.ToString() << " is specified";
1274 return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
1275 }
1276
CreatePartitionId(const Shape & shape)1277 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
1278 const Shape& shape) {
1279 CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1280 << "HloInstruction partition-id must have a shape of u32[], but "
1281 << shape.ToString() << " is specified";
1282 return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
1283 }
1284
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1285 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
1286 const Shape& infeed_shape, HloInstruction* token_operand,
1287 const string& config) {
1288 return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
1289 config);
1290 }
1291
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1292 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
1293 const Shape& outfeed_shape, HloInstruction* operand,
1294 HloInstruction* token_operand, absl::string_view outfeed_config) {
1295 return absl::make_unique<HloOutfeedInstruction>(
1296 outfeed_shape, operand, token_operand, outfeed_config);
1297 }
1298
CreateSend(HloInstruction * operand,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1299 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
1300 HloInstruction* operand, HloInstruction* token, int64_t channel_id,
1301 bool is_host_transfer) {
1302 return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
1303 is_host_transfer);
1304 }
1305
CreateSendDone(HloInstruction * operand,bool is_host_transfer)1306 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
1307 HloInstruction* operand, bool is_host_transfer) {
1308 auto send_operand = DynCast<HloSendInstruction>(operand);
1309 CHECK(send_operand != nullptr)
1310 << "SendDone must take the context operand from Send";
1311 return absl::make_unique<HloSendDoneInstruction>(send_operand,
1312 is_host_transfer);
1313 }
1314
CreateRecv(const Shape & shape,HloInstruction * token,int64_t channel_id,bool is_host_transfer)1315 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
1316 const Shape& shape, HloInstruction* token, int64_t channel_id,
1317 bool is_host_transfer) {
1318 return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
1319 is_host_transfer);
1320 }
1321
CreateRecvDone(HloInstruction * operand,bool is_host_transfer)1322 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
1323 HloInstruction* operand, bool is_host_transfer) {
1324 auto recv_operand = DynCast<HloRecvInstruction>(operand);
1325 CHECK(recv_operand != nullptr)
1326 << "RecvDone must take the context operand from Recv";
1327 return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
1328 is_host_transfer);
1329 }
1330
CreateReverse(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1331 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
1332 const Shape& shape, HloInstruction* operand,
1333 absl::Span<const int64> dimensions) {
1334 return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
1335 }
1336
CreateAfterAll(absl::Span<HloInstruction * const> operands)1337 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1338 absl::Span<HloInstruction* const> operands) {
1339 CHECK(!operands.empty());
1340 auto instruction = absl::WrapUnique(
1341 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1342 for (auto operand : operands) {
1343 instruction->AppendOperand(operand);
1344 }
1345 return instruction;
1346 }
1347
CreateToken()1348 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
1349 return absl::WrapUnique(
1350 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1351 }
1352
1353 /* static */ std::unique_ptr<HloInstruction>
CreateAddDependency(HloInstruction * data_operand,HloInstruction * token_operand)1354 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
1355 HloInstruction* token_operand) {
1356 auto instruction = absl::WrapUnique(
1357 new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
1358 instruction->AppendOperand(data_operand);
1359 instruction->AppendOperand(token_operand);
1360 return instruction;
1361 }
1362
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)1363 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
1364 const Shape& shape, HloComputation* condition, HloComputation* body,
1365 HloInstruction* init) {
1366 auto instruction =
1367 absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
1368 instruction->AppendOperand(init);
1369 // Body comes before condition computation in the vector.
1370 instruction->called_computations_.push_back(body);
1371 instruction->called_computations_.push_back(condition);
1372 return instruction;
1373 }
1374
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)1375 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1376 const Shape& shape, HloInstruction* pred,
1377 HloInstruction* true_computation_arg, HloComputation* true_computation,
1378 HloInstruction* false_computation_arg, HloComputation* false_computation) {
1379 auto instruction =
1380 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1381 instruction->AppendOperand(pred);
1382 instruction->AppendOperand(true_computation_arg);
1383 instruction->AppendOperand(false_computation_arg);
1384 // In called_computations_, the index of true_computation must be 0 and that
1385 // of false computation must be 1, as defined by kTrueComputationIndex and
1386 // kFalseComputationIndex.
1387 instruction->called_computations_.push_back(true_computation);
1388 instruction->called_computations_.push_back(false_computation);
1389 return instruction;
1390 }
1391
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)1392 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1393 const Shape& shape, HloInstruction* branch_index,
1394 absl::Span<HloComputation* const> branch_computations,
1395 absl::Span<HloInstruction* const> branch_computation_args) {
1396 auto instruction =
1397 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1398 instruction->AppendOperand(branch_index);
1399 CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1400 for (int i = 0; i < branch_computations.size(); ++i) {
1401 instruction->called_computations_.push_back(branch_computations[i]);
1402 instruction->AppendOperand(branch_computation_args[i]);
1403 }
1404 return instruction;
1405 }
1406
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1407 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1408 const Shape& shape, HloInstruction* operand,
1409 absl::Span<const int64> start_indices,
1410 absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
1411 return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1412 limit_indices, strides);
1413 }
1414
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)1415 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1416 const Shape& shape, HloInstruction* operand,
1417 absl::Span<HloInstruction* const> start_indices,
1418 absl::Span<const int64> slice_sizes) {
1419 return absl::make_unique<HloDynamicSliceInstruction>(
1420 shape, operand, start_indices, slice_sizes);
1421 }
1422
1423 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)1424 HloInstruction::CreateDynamicUpdateSlice(
1425 const Shape& shape, HloInstruction* operand, HloInstruction* update,
1426 absl::Span<HloInstruction* const> start_indices) {
1427 return absl::make_unique<HloDynamicUpdateSliceInstruction>(
1428 shape, operand, update, start_indices);
1429 }
1430
CreateConcatenate(const Shape & shape,absl::Span<HloInstruction * const> operands,int64_t dimension)1431 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1432 const Shape& shape, absl::Span<HloInstruction* const> operands,
1433 int64_t dimension) {
1434 return absl::make_unique<HloConcatenateInstruction>(shape, operands,
1435 dimension);
1436 }
1437
CreateConvert(const Shape & shape,HloInstruction * operand)1438 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1439 const Shape& shape, HloInstruction* operand) {
1440 auto instruction =
1441 absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1442 instruction->AppendOperand(operand);
1443 return instruction;
1444 }
1445
1446 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)1447 HloInstruction::CreateBitcastConvert(const Shape& shape,
1448 HloInstruction* operand) {
1449 auto instruction =
1450 absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1451 instruction->AppendOperand(operand);
1452 return instruction;
1453 }
1454
CreateBitcast(const Shape & shape,HloInstruction * operand)1455 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
1456 const Shape& shape, HloInstruction* operand) {
1457 auto instruction =
1458 absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
1459 instruction->AppendOperand(operand);
1460 return instruction;
1461 }
1462
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1463 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1464 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1465 absl::Span<const int64> dimensions_to_reduce,
1466 HloComputation* reduce_computation) {
1467 auto instruction = absl::WrapUnique(new HloReduceInstruction(
1468 shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1469 return std::move(instruction);
1470 }
1471
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1472 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1473 const Shape& shape, absl::Span<HloInstruction* const> operands,
1474 absl::Span<HloInstruction* const> init_values,
1475 absl::Span<const int64> dimensions_to_reduce,
1476 HloComputation* reduce_computation) {
1477 std::vector<HloInstruction*> all_args;
1478 all_args.reserve(operands.size() * 2);
1479 all_args.insert(all_args.end(), operands.begin(), operands.end());
1480 all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1481 return absl::make_unique<HloReduceInstruction>(
1482 shape, all_args, dimensions_to_reduce, reduce_computation);
1483 }
1484
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1485 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1486 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1487 const Window& window, HloComputation* reduce_computation) {
1488 return absl::make_unique<HloReduceWindowInstruction>(
1489 shape, operand, init_value, window, reduce_computation);
1490 }
1491
CreateReduceWindow(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,const Window & window,HloComputation * reduce_computation)1492 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1493 const Shape& shape, absl::Span<HloInstruction* const> operands,
1494 absl::Span<HloInstruction* const> init_values, const Window& window,
1495 HloComputation* reduce_computation) {
1496 return absl::make_unique<HloReduceWindowInstruction>(
1497 shape, operands, init_values, window, reduce_computation);
1498 }
1499 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64_t feature_index)1500 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1501 HloInstruction* operand,
1502 HloInstruction* scale,
1503 HloInstruction* offset, float epsilon,
1504 int64_t feature_index) {
1505 return absl::make_unique<HloBatchNormTrainingInstruction>(
1506 shape, operand, scale, offset, epsilon, feature_index);
1507 }
1508
1509 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64_t feature_index)1510 HloInstruction::CreateBatchNormInference(
1511 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1512 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1513 float epsilon, int64_t feature_index) {
1514 return absl::make_unique<HloBatchNormInferenceInstruction>(
1515 shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1516 }
1517
1518 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64_t feature_index)1519 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1520 HloInstruction* scale, HloInstruction* mean,
1521 HloInstruction* variance,
1522 HloInstruction* grad_output, float epsilon,
1523 int64_t feature_index) {
1524 return absl::make_unique<HloBatchNormGradInstruction>(
1525 shape, operand, scale, mean, variance, grad_output, epsilon,
1526 feature_index);
1527 }
1528
1529 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1530 HloInstruction::CreateSelectAndScatter(
1531 const Shape& shape, HloInstruction* operand, HloComputation* select,
1532 const Window& window, HloInstruction* source, HloInstruction* init_value,
1533 HloComputation* scatter) {
1534 return absl::make_unique<HloSelectAndScatterInstruction>(
1535 shape, operand, select, window, source, init_value, scatter);
1536 }
1537
CreateBroadcast(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimensions)1538 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1539 const Shape& shape, HloInstruction* operand,
1540 absl::Span<const int64> broadcast_dimensions) {
1541 return absl::make_unique<HloBroadcastInstruction>(shape, operand,
1542 broadcast_dimensions);
1543 }
1544
1545 /* static */ std::unique_ptr<HloInstruction>
CreateGetDimensionSize(const Shape & shape,HloInstruction * operand,int64_t dimension)1546 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1547 HloInstruction* operand,
1548 int64_t dimension) {
1549 return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1550 dimension);
1551 }
1552
1553 /* static */ std::unique_ptr<HloInstruction>
CreateSetDimensionSize(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64_t dimension)1554 HloInstruction::CreateSetDimensionSize(const Shape& shape,
1555 HloInstruction* operand,
1556 HloInstruction* val, int64_t dimension) {
1557 return absl::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
1558 dimension);
1559 }
1560
1561 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1562 HloInstruction::CreateBroadcastSequence(
1563 const Shape& output_shape, HloInstruction* operand,
1564 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1565 adder) {
1566 CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1567 operand->shape().rank() == output_shape.rank());
1568 Shape broadcast_shape = ShapeUtil::ChangeElementType(
1569 output_shape, operand->shape().element_type());
1570 // Do explicit broadcast for scalar.
1571 if (ShapeUtil::IsScalar(operand->shape())) {
1572 auto broadcast =
1573 HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1574 broadcast->set_metadata(operand->metadata());
1575 if (operand->has_sharding()) {
1576 broadcast->set_sharding(operand->sharding());
1577 }
1578 broadcast->set_frontend_attributes(operand->frontend_attributes());
1579 return broadcast;
1580 }
1581 // Do explicit broadcast for degenerate broadcast.
1582 std::vector<int64> broadcast_dimensions;
1583 std::vector<int64> reshaped_dimensions;
1584 for (int i = 0; i < operand->shape().rank(); i++) {
1585 if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1586 broadcast_dimensions.push_back(i);
1587 reshaped_dimensions.push_back(operand->shape().dimensions(i));
1588 } else {
1589 CHECK_EQ(operand->shape().dimensions(i), 1)
1590 << "An explicit broadcast sequence requires the broadcasted "
1591 "dimensions to be trivial; operand: "
1592 << operand->ToString() << "; output_shape: " << output_shape;
1593 }
1594 }
1595 // Eliminate the size one dimensions.
1596 HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1597 ShapeUtil::MakeShape(operand->shape().element_type(),
1598 reshaped_dimensions),
1599 operand));
1600 reshaped_operand->set_metadata(operand->metadata());
1601 if (operand->has_sharding()) {
1602 reshaped_operand->set_sharding(operand->sharding());
1603 }
1604 reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1605 // Broadcast 'reshape' up to the larger size.
1606 auto broadcast = HloInstruction::CreateBroadcast(
1607 broadcast_shape, reshaped_operand, broadcast_dimensions);
1608 broadcast->set_metadata(operand->metadata());
1609 if (operand->has_sharding()) {
1610 broadcast->set_sharding(operand->sharding());
1611 }
1612 broadcast->set_frontend_attributes(operand->frontend_attributes());
1613 return broadcast;
1614 }
1615
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1616 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1617 const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1618 const PaddingConfig& padding_config) {
1619 return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
1620 padding_config);
1621 }
1622
CreateReshape(const Shape & shape,HloInstruction * operand,int64_t inferred_dimension)1623 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1624 const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) {
1625 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1626 ShapeUtil::ElementsIn(operand->shape()))
1627 << "shape: " << ShapeUtil::HumanString(shape)
1628 << " operand: " << ShapeUtil::HumanString(operand->shape());
1629
1630 return absl::make_unique<HloReshapeInstruction>(shape, operand,
1631 inferred_dimension);
1632 }
1633
1634 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicReshape(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1635 HloInstruction::CreateDynamicReshape(
1636 const Shape& shape, HloInstruction* data_operand,
1637 absl::Span<HloInstruction* const> dim_sizes) {
1638 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1639 ShapeUtil::ElementsIn(data_operand[0].shape()))
1640 << "shape: " << ShapeUtil::HumanString(shape)
1641 << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
1642 CHECK_EQ(shape.rank(), dim_sizes.size());
1643 return absl::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
1644 dim_sizes);
1645 }
1646
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1647 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1648 const Shape& shape, HloInstruction* operand,
1649 absl::Span<const int64> dimensions) {
1650 return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1651 }
1652
CreateSort(const Shape & shape,int64_t dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1653 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1654 const Shape& shape, int64_t dimension,
1655 absl::Span<HloInstruction* const> operands, HloComputation* compare,
1656 bool is_stable) {
1657 return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
1658 compare, is_stable);
1659 }
1660
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1661 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1662 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1663 return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
1664 fused_root);
1665 }
1666
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1667 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1668 const Shape& shape, FusionKind fusion_kind,
1669 absl::Span<HloInstruction* const> operands,
1670 HloComputation* fusion_computation) {
1671 return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1672 fusion_computation);
1673 }
1674
set_single_sharding(const HloSharding & sharding)1675 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1676 CHECK(!sharding.IsTuple()) << sharding;
1677 if (shape().IsTuple()) {
1678 set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1679 } else {
1680 set_sharding(sharding);
1681 }
1682 }
1683
SetupDerivedInstruction(HloInstruction * derived_instruction) const1684 void HloInstruction::SetupDerivedInstruction(
1685 HloInstruction* derived_instruction) const {
1686 if (sharding_ != nullptr &&
1687 ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
1688 // Only copy sharding if the tuple tree shape of the two instruction is
1689 // compatible because copying it between differently shaped instructions
1690 // can produce invalid shardings.
1691 derived_instruction->set_sharding(*sharding_);
1692 } else {
1693 derived_instruction->clear_sharding();
1694 }
1695 derived_instruction->set_metadata(metadata_);
1696 derived_instruction->set_frontend_attributes(frontend_attributes_);
1697 }
1698
HasSideEffectNoRecurse() const1699 bool HloInstruction::HasSideEffectNoRecurse() const {
1700 switch (opcode_) {
1701 case HloOpcode::kSend:
1702 case HloOpcode::kSendDone:
1703 case HloOpcode::kRecv:
1704 case HloOpcode::kRecvDone:
1705 case HloOpcode::kRng:
1706 case HloOpcode::kRngGetAndUpdateState:
1707 case HloOpcode::kInfeed:
1708 case HloOpcode::kOutfeed:
1709 case HloOpcode::kTrace:
1710 case HloOpcode::kAllReduceStart:
1711 case HloOpcode::kAllReduceDone:
1712 case HloOpcode::kAllGatherStart:
1713 case HloOpcode::kAllGatherDone:
1714 case HloOpcode::kCollectivePermuteStart:
1715 case HloOpcode::kCollectivePermuteDone:
1716 return true;
1717 case HloOpcode::kAllReduce:
1718 return channel_id().has_value() ||
1719 Cast<HloAllReduceInstruction>(this)->constrain_layout();
1720 case HloOpcode::kAllToAll:
1721 return Cast<HloAllToAllInstruction>(this)->constrain_layout();
1722 case HloOpcode::kCustomCall:
1723 return Cast<HloCustomCallInstruction>(this)
1724 ->custom_call_has_side_effect();
1725 default:
1726 return false;
1727 }
1728 }
1729
HasSideEffect() const1730 bool HloInstruction::HasSideEffect() const {
1731 if (HasSideEffectNoRecurse()) {
1732 return true;
1733 }
1734 // Check if any of the called computations has a side effect.
1735 for (const auto& computation : called_computations()) {
1736 if (computation->HasSideEffect()) {
1737 return true;
1738 }
1739 }
1740 return false;
1741 }
1742
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1743 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1744 const Shape& shape, absl::Span<HloInstruction* const> operands,
1745 HloComputation* computation) {
1746 std::unique_ptr<HloInstruction> instruction =
1747 absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1748 for (auto operand : operands) {
1749 instruction->AppendOperand(operand);
1750 }
1751 instruction->called_computations_.push_back(computation);
1752 return instruction;
1753 }
1754
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)1755 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1756 const Shape& shape, absl::Span<HloInstruction* const> operands,
1757 absl::string_view custom_call_target, string opaque,
1758 CustomCallApiVersion api_version) {
1759 return absl::make_unique<HloCustomCallInstruction>(
1760 shape, operands, custom_call_target, std::move(opaque), api_version);
1761 }
1762
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)1763 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1764 const Shape& shape, absl::Span<HloInstruction* const> operands,
1765 HloComputation* to_apply, absl::string_view custom_call_target,
1766 string opaque, CustomCallApiVersion api_version) {
1767 return absl::make_unique<HloCustomCallInstruction>(
1768 shape, operands, to_apply, custom_call_target, std::move(opaque),
1769 api_version);
1770 }
1771
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque,CustomCallApiVersion api_version)1772 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1773 const Shape& shape, absl::Span<HloInstruction* const> operands,
1774 absl::Span<HloComputation* const> called_computations,
1775 absl::string_view custom_call_target, string opaque,
1776 CustomCallApiVersion api_version) {
1777 return absl::make_unique<HloCustomCallInstruction>(
1778 shape, operands, called_computations, custom_call_target,
1779 std::move(opaque), api_version);
1780 }
1781
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::Span<const Shape> operand_shapes_with_layout,string opaque,CustomCallApiVersion api_version)1782 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1783 const Shape& shape, absl::Span<HloInstruction* const> operands,
1784 absl::string_view custom_call_target,
1785 absl::Span<const Shape> operand_shapes_with_layout, string opaque,
1786 CustomCallApiVersion api_version) {
1787 return absl::make_unique<HloCustomCallInstruction>(
1788 shape, operands, custom_call_target, std::move(opaque),
1789 operand_shapes_with_layout, api_version);
1790 }
1791
CreateTuple(absl::Span<HloInstruction * const> elements)1792 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1793 absl::Span<HloInstruction* const> elements) {
1794 std::vector<Shape> element_shapes;
1795 for (auto element : elements) {
1796 element_shapes.push_back(element->shape());
1797 }
1798 Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1799 return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1800 }
1801
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)1802 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1803 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1804 const GatherDimensionNumbers& gather_dim_numbers,
1805 absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
1806 return absl::make_unique<HloGatherInstruction>(
1807 shape, operand, start_indices, gather_dim_numbers, slice_sizes,
1808 indices_are_sorted);
1809 }
1810
CreateScatter(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1811 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1812 const Shape& shape, HloInstruction* operand,
1813 HloInstruction* scatter_indices, HloInstruction* updates,
1814 HloComputation* update_computation,
1815 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1816 bool unique_indices) {
1817 return absl::make_unique<HloScatterInstruction>(
1818 shape, operand, scatter_indices, updates, update_computation,
1819 scatter_dim_numbers, indices_are_sorted, unique_indices);
1820 }
1821
CreateDomain(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)1822 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1823 const Shape& shape, HloInstruction* operand,
1824 std::unique_ptr<DomainMetadata> operand_side_metadata,
1825 std::unique_ptr<DomainMetadata> user_side_metadata) {
1826 return absl::make_unique<HloDomainInstruction>(
1827 shape, operand, std::move(operand_side_metadata),
1828 std::move(user_side_metadata));
1829 }
1830
// Creates a copy of this instruction with result shape `shape` and operands
// `new_operands`.
//
// The clone is built by calling the opcode's factory (or, for subclassed ops,
// CloneWithNewOperandsImpl), then:
// - sharding, metadata and frontend attributes are copied via
//   SetupDerivedInstruction;
// - parent, outer dimension partitions and the raw backend config string are
//   copied from this instruction.
// If `context` is non-null, this instruction is mapped to the clone in
// `context`. Called computations whose parent is not context->module() are
// deep-cloned into that module.
//
// NOTE(review): the switch below has no `default`. An opcode that is not listed
// would leave `clone` null, and the SetupDerivedInstruction call after the
// switch would then dereference it. Keep the case list in sync with HloOpcode.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
    HloCloneContext* context) const {
  VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
  VLOG(3) << "  new operands:";
  for (const HloInstruction* new_operand : new_operands) {
    VLOG(3) << "    %" << new_operand->name();
  }

  std::unique_ptr<HloInstruction> clone;
  // Explicitly call the factory for the instruction type. This is more robust
  // in the face of code changes than copying fields explicitly. This also
  // properly sets the user fields of the operands.
  switch (opcode_) {
    // Ops migrated to subclasses.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kCopyStart:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kSort:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kIota:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      // Subclass-specific fields are copied by the subclass override.
      clone = CloneWithNewOperandsImpl(shape, new_operands, context);
      break;
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateUnary(shape, opcode_, new_operands[0]);
      break;
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
      break;
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      CHECK_EQ(new_operands.size(), 3);
      clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
                            new_operands[2]);
      break;
    // Other supported ops.
    case HloOpcode::kCall:
      clone = CreateCall(shape, new_operands, to_apply());
      break;
    case HloOpcode::kConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kBitcastConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateBitcastConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kDynamicUpdateSlice:
      // Operands: {operand, update, start_indices...}. The operand count is
      // not CHECKed here.
      clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
                                       new_operands.subspan(2));
      break;
    case HloOpcode::kTuple:
      clone = CreateTuple(new_operands);
      // CreateTuple derives its shape from the new operands; override it with
      // the requested shape.
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kWhile:
      CHECK_EQ(new_operands.size(), 1);
      clone =
          CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
      break;
    case HloOpcode::kConditional:
      // Operand 0 is the branch selector; one further operand per branch.
      CHECK_EQ(new_operands.size(), branch_count() + 1);
      clone = CreateConditional(shape, new_operands[0],
                                absl::MakeSpan(branch_computations()),
                                new_operands.subspan(1));
      break;
    case HloOpcode::kAfterAll:
      if (new_operands.empty()) {
        clone = CreateToken();
      } else {
        clone = CreateAfterAll(new_operands);
      }
      break;
    case HloOpcode::kAddDependency:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateAddDependency(new_operands[0], new_operands[1]);
      break;
    case HloOpcode::kReplicaId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreateReplicaId(shape);
      break;
    case HloOpcode::kPartitionId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreatePartitionId(shape);
      break;
  }
  // SetupDerivedInstruction copies sharding (when the shapes are compatible;
  // otherwise it clears it), metadata and frontend attributes onto the clone.
  SetupDerivedInstruction(clone.get());
  clone->set_parent(parent_);
  clone->set_outer_dimension_partitions(outer_dimension_partitions_);
  clone->set_raw_backend_config_string(backend_config_);
  if (context != nullptr) {
    context->MapInstruction(this, clone.get());
    // Computations owned by a different module are deep-cloned into the
    // context's module; computations already in that module are shared.
    clone->ReplaceCalledComputations([&](HloComputation* callee) {
      return callee->parent() != context->module()
                 ? context->module()->DeepCloneComputation(callee, context)
                 : callee;
    });
  }
  return clone;
}
2030
DetachFromOperandsAndUsers()2031 void HloInstruction::DetachFromOperandsAndUsers() {
2032 if (cleaned_up_) {
2033 return;
2034 }
2035 cleaned_up_ = true;
2036 // Detach from operands. An instruction may be repeated as an operand. To
2037 // avoid calling RemoveUser twice on the same operand, check before remove.
2038 for (int64_t operand_num = 0; operand_num < operand_count(); ++operand_num) {
2039 HloInstruction* operand = operands_[operand_num];
2040 if (operand == nullptr) {
2041 continue;
2042 }
2043 if (operand->user_map_.find(this) != operand->user_map_.end()) {
2044 operand->RemoveUser(this);
2045 }
2046 operands_[operand_num] = nullptr;
2047 }
2048
2049 // Update users. Set `nullptr` to the corresponding operand slot for users.
2050 for (auto& user : this->users()) {
2051 for (int i = 0; i < user->operand_count(); ++i) {
2052 if (user->operands_[i] == this) {
2053 user->operands_[i] = nullptr;
2054 }
2055 }
2056 }
2057 }
2058
CloneWithNewShape(const Shape & shape,const string & suffix,HloCloneContext * context) const2059 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
2060 const Shape& shape, const string& suffix, HloCloneContext* context) const {
2061 std::unique_ptr<HloInstruction> clone =
2062 CloneWithNewOperands(shape, operands_, context);
2063 if (suffix.empty()) {
2064 clone->name_ = name();
2065 } else {
2066 // If an instruction is cloned multiple times avoid names like
2067 // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
2068 // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
2069 // clone of foo.suffix2 is named foo.suffix3 and so on.
2070 const string dot_suffix = "." + suffix;
2071 size_t index = name().rfind(dot_suffix);
2072 if (index == string::npos) {
2073 // Existing name does not include ".suffix".
2074 clone->name_ = name() + dot_suffix;
2075 } else {
2076 // Existing name includes ".suffix". Determine if substring after
2077 // ".suffix" is numeric and should be replaced with an incremented number.
2078 string after_suffix = name().substr(index + dot_suffix.size());
2079 if (after_suffix.empty()) {
2080 // Existing name ends in ".suffix". New name should end in ".suffix2".
2081 clone->name_ = name() + "2";
2082 } else {
2083 // If names ends with .suffix[0-9]+ then replace with a suffix with the
2084 // numeric value incremented.
2085 int64_t numeric_suffix;
2086 if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
2087 clone->name_ =
2088 StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
2089 } else {
2090 // Substring after ".suffix" is non-numeric.
2091 clone->name_ = name() + dot_suffix;
2092 }
2093 }
2094 }
2095 }
2096 return clone;
2097 }
2098
Clone(const string & suffix,HloCloneContext * context) const2099 std::unique_ptr<HloInstruction> HloInstruction::Clone(
2100 const string& suffix, HloCloneContext* context) const {
2101 std::unique_ptr<HloInstruction> clone =
2102 CloneWithNewShape(shape_, suffix, context);
2103 return clone;
2104 }
2105
2106 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const2107 HloInstruction::LatestNonGteAncestorAndIndex() const {
2108 const HloInstruction* hlo = this;
2109 ShapeIndex index;
2110 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2111 index.push_back(hlo->tuple_index());
2112 hlo = hlo->operand(0);
2113 }
2114
2115 // We built up index in the reverse order from what we want.
2116 std::reverse(index.begin(), index.end());
2117
2118 return {hlo, index};
2119 }
2120
LatestNonGteAncestor() const2121 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
2122 const HloInstruction* hlo = this;
2123 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
2124 hlo = hlo->operand(0);
2125 }
2126 return hlo;
2127 }
2128
operand(int64_t i) const2129 const HloInstruction* HloInstruction::operand(int64_t i) const {
2130 return operands_.at(i);
2131 }
2132
mutable_operand(int64_t i)2133 HloInstruction* HloInstruction::mutable_operand(int64_t i) {
2134 CHECK(operands_[i] != nullptr);
2135 return operands_.at(i);
2136 }
2137
operand_index(const HloInstruction * target) const2138 int64 HloInstruction::operand_index(const HloInstruction* target) const {
2139 for (int64_t i = 0; i < operand_count(); ++i) {
2140 if (target == operand(i)) {
2141 return i;
2142 }
2143 }
2144 LOG(FATAL) << "target was not an operand: " << target->ToString();
2145 }
2146
unique_operands() const2147 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
2148 InstructionVector unique;
2149 absl::flat_hash_set<const HloInstruction*> seen;
2150 for (HloInstruction* operand : operands()) {
2151 if (seen.insert(operand).second) {
2152 unique.push_back(operand);
2153 }
2154 }
2155 return unique;
2156 }
2157
AddControlDependencyTo(HloInstruction * instruction)2158 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
2159 TF_RET_CHECK(instruction->parent() == parent());
2160 if (!absl::c_linear_search(control_successors_, instruction)) {
2161 control_successors_.push_back(instruction);
2162 TF_RET_CHECK(
2163 !absl::c_linear_search(instruction->control_predecessors_, this));
2164 instruction->control_predecessors_.push_back(this);
2165 }
2166 return Status::OK();
2167 }
2168
RemoveControlDependencyTo(HloInstruction * instruction)2169 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
2170 TF_RET_CHECK(instruction->parent() == parent());
2171 TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
2172 TF_RETURN_IF_ERROR(
2173 EraseElementFromVector(&instruction->control_predecessors_, this));
2174 return Status::OK();
2175 }
2176
DropAllControlDeps()2177 Status HloInstruction::DropAllControlDeps() {
2178 for (auto* ctrl_succ : control_successors_) {
2179 TF_RETURN_IF_ERROR(
2180 EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
2181 }
2182 for (auto* ctrl_pred : control_predecessors_) {
2183 TF_RETURN_IF_ERROR(
2184 EraseElementFromVector(&ctrl_pred->control_successors_, this));
2185 }
2186 control_successors_.clear();
2187 control_predecessors_.clear();
2188 return Status::OK();
2189 }
2190
CopyAllControlDepsFrom(const HloInstruction * inst)2191 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
2192 for (auto* ctrl_pred : inst->control_predecessors()) {
2193 TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
2194 }
2195
2196 for (auto* ctrl_succ : inst->control_successors()) {
2197 TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
2198 }
2199
2200 return Status::OK();
2201 }
2202
IdenticalInternal(const HloInstruction & other,const std::function<bool (const HloInstruction *,const HloInstruction *)> & eq_operands,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations,bool layout_sensitive,bool ignore_channel_id_values) const2203 bool HloInstruction::IdenticalInternal(
2204 const HloInstruction& other,
2205 const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2206 eq_operands,
2207 const std::function<bool(const HloComputation*, const HloComputation*)>&
2208 eq_computations,
2209 bool layout_sensitive, bool ignore_channel_id_values) const {
2210 // An instruction is always identical to itself.
2211 if (this == &other) {
2212 return true;
2213 }
2214
2215 // Identical instruction must have the same opcode, shape, and identical
2216 // operands.
2217 if (opcode() != other.opcode()) {
2218 return false;
2219 }
2220 if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
2221 : ShapeUtil::Compatible(shape(), other.shape()))) {
2222 return false;
2223 }
2224 if (operands().size() != other.operands().size()) {
2225 return false;
2226 }
2227
2228 // Use an explicit loop rather than ContainerEquals, because copying around
2229 // std::functions may be too expensive in some cases.
2230 for (size_t i = 0; i < operands().size(); ++i) {
2231 if (!eq_operands(operand(i), other.operand(i))) {
2232 return false;
2233 }
2234 }
2235
2236 if (backend_config_ != other.backend_config_) {
2237 return false;
2238 }
2239
2240 if (ignore_channel_id_values) {
2241 if (auto channel_inst = DynCast<HloChannelInstruction>(this)) {
2242 return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
2243 other, eq_computations);
2244 }
2245 }
2246 return IdenticalSlowPath(other, eq_computations);
2247 }
2248
AppendOperand(HloInstruction * operand)2249 void HloInstruction::AppendOperand(HloInstruction* operand) {
2250 if (operand->parent() != nullptr) {
2251 DCHECK(!operand->parent()->IsMarkedAsDead(operand))
2252 << "Operand " << operand->name() << " is already marked dead";
2253 }
2254 operands_.push_back(operand);
2255 operand->AddUser(this);
2256 }
2257
RemoveOperandsAtAscendingIndices(absl::Span<const int> ascending_indices)2258 void HloInstruction::RemoveOperandsAtAscendingIndices(
2259 absl::Span<const int> ascending_indices) {
2260 if (ascending_indices.empty()) {
2261 return;
2262 }
2263 int next_index = 0;
2264 int removed_count = 0;
2265 for (int to_remove : ascending_indices) {
2266 while (next_index < to_remove) {
2267 operands_[next_index - removed_count] = operands_[next_index];
2268 ++next_index;
2269 }
2270 CHECK_LT(to_remove, operands_.size());
2271 ++removed_count;
2272 ++next_index;
2273 }
2274 while (next_index < operands_.size()) {
2275 operands_[next_index - removed_count] = operands_[next_index];
2276 ++next_index;
2277 }
2278 CHECK_EQ(removed_count, ascending_indices.size());
2279 operands_.resize(operands_.size() - removed_count);
2280 }
2281
AddUser(HloInstruction * user)2282 void HloInstruction::AddUser(HloInstruction* user) {
2283 if (!ContainsKey(user_map_, user)) {
2284 user_map_.emplace(user, users_.size());
2285 users_.push_back(user);
2286 }
2287 }
2288
UserId(HloInstruction * user)2289 int64 HloInstruction::UserId(HloInstruction* user) {
2290 auto result = user_map_.find(user);
2291 CHECK(result != user_map_.end());
2292 return result->second;
2293 }
2294
HasConstantOperand() const2295 bool HloInstruction::HasConstantOperand() const {
2296 for (const HloInstruction* operand : operands_) {
2297 if (operand->IsConstant()) {
2298 return true;
2299 }
2300 }
2301 return false;
2302 }
2303
// Opcode-specific part of the identity test. When reached from
// IdenticalInternal(), opcode, shape, operand equality and backend_config_
// have already been compared.
//
// Returns:
//   * true for the "plain" opcodes listed first: nothing beyond what was
//     already compared distinguishes them;
//   * false for kAfterAll / kAddDependency, which are never considered
//     identical;
//   * for kCall / kConditional / kWhile, whether the called computations match
//     under `eq_computations`.
// Opcodes that have their own subclass must not reach this base
// implementation; doing so is fatal. Any opcode not named in the switch falls
// through to the final `return false`.
bool HloInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  // Perform opcode specific checks.
  switch (opcode()) {
    // The result of these instructions only depend upon their opcode and
    // operands.
    case HloOpcode::kAbs:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAtan2:
    case HloOpcode::kAdd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return true;

    // This opcode has complex or special behavior so just return false.
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
      return false;

    // Remaining instructions with special values.
    case HloOpcode::kCall:
      return eq_computations(to_apply(), other.to_apply());
    case HloOpcode::kConditional:
      // NOTE(review): only this instruction's branch count is consulted. This
      // presumably relies on the equal operand counts checked in
      // IdenticalInternal() implying equal branch counts — confirm.
      for (int j = 0; j < branch_count(); ++j) {
        if (!eq_computations(branch_computation(j),
                             other.branch_computation(j))) {
          return false;
        }
      }
      return true;
    case HloOpcode::kWhile:
      return (eq_computations(while_body(), other.while_body()) &&
              eq_computations(while_condition(), other.while_condition()));

    // Ops migrated to subclasses should never come to this line.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kIota:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }
  // Reached only for opcodes not named in the switch above.
  return false;
}
2449
HashOperand(const HloInstruction * hlo)2450 static uint64 HashOperand(const HloInstruction* hlo) {
2451 return ShapeUtil::Hash(hlo->shape());
2452 }
2453
Hash(const std::function<uint64 (const HloInstruction *)> & hash_operand) const2454 uint64 HloInstruction::Hash(
2455 const std::function<uint64(const HloInstruction*)>& hash_operand) const {
2456 using tensorflow::Hash64Combine;
2457
2458 uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
2459 hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));
2460
2461 if (!IsCrossModuleAllReduce()) {
2462 if (!operands().empty()) {
2463 for (size_t i = 0; i < operands().size(); ++i) {
2464 hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
2465 }
2466 }
2467 }
2468
2469 hash_value = Hash64Combine(hash_value, InnerHash());
2470 return hash_value;
2471 }
2472
Hash() const2473 uint64 HloInstruction::Hash() const {
2474 // Use HashOperand as an argument to prevent non-termination.
2475 return Hash(HashOperand);
2476 }
2477
InnerHash() const2478 uint64 HloInstruction::InnerHash() const { return 13; }
2479
// Unregisters `user` from this instruction in O(1) using swap-with-last
// removal. The order of users_ is therefore not preserved. CHECK-fails if
// `user` is not registered, or if user_map_ and users_ disagree about its
// position.
void HloInstruction::RemoveUser(HloInstruction* user) {
  auto map_it = user_map_.find(user);
  CHECK(map_it != user_map_.end());

  const int64_t index = map_it->second;
  CHECK_EQ(users_[index], user);

  // Move the last user into the position of the removed user.
  // users_.back() is expected to already be a key in user_map_ (AddUser
  // inserts into both containers), so operator[] below updates an existing
  // entry rather than inserting, and map_it stays valid. When `user` is
  // itself the last element, both statements are harmless self-updates.
  users_[index] = users_.back();
  user_map_[users_.back()] = index;

  // Remove the user from the map and drop the last slot from the vector what
  // have been moved to the position of the original user.
  user_map_.erase(map_it);
  users_.pop_back();
}
2496
// Makes `user` consume `new_producer` instead of this instruction.
// Returns an error unless the two shapes are compatible (floating-point
// precision differences are ignored). The rewiring itself is done by
// ReplaceUseWithDifferentShape().
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  return ReplaceUseWithDifferentShape(user, new_producer);
}
2506
ReplaceUseWithDifferentShape(HloInstruction * user,HloInstruction * new_producer)2507 Status HloInstruction::ReplaceUseWithDifferentShape(
2508 HloInstruction* user, HloInstruction* new_producer) {
2509 VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
2510 << " with " << new_producer->name();
2511
2512 RemoveUser(user);
2513
2514 TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
2515 std::replace(user->operands_.begin(), user->operands_.end(), this,
2516 new_producer);
2517 new_producer->AddUser(user);
2518 // Custom fusions may not be able to handle deduplicated operands.
2519 if (user->opcode() == HloOpcode::kFusion) {
2520 TF_RETURN_IF_ERROR(
2521 Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2522 }
2523 return Status::OK();
2524 }
2525
// Replaces operand `operand_num` with `new_operand`. Returns an error unless
// the old and new shapes are compatible (floating-point precision differences
// are ignored). The rewiring is done by ReplaceOperandWithDifferentShape().
Status HloInstruction::ReplaceOperandWith(int64_t operand_num,
                                          HloInstruction* new_operand) {
  auto old_operand = operand(operand_num);
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
                                                        new_operand->shape()))
      << old_operand->shape() << " is not compatible with "
      << new_operand->shape();
  return ReplaceOperandWithDifferentShape(operand_num, new_operand);
}
2535
ReplaceOperandWithDifferentShape(int64_t operand_num,HloInstruction * new_operand)2536 Status HloInstruction::ReplaceOperandWithDifferentShape(
2537 int64_t operand_num, HloInstruction* new_operand) {
2538 TF_RET_CHECK(operand_num >= 0);
2539 TF_RET_CHECK(operand_num < operand_count());
2540 HloInstruction* old_operand = mutable_operand(operand_num);
2541 if (old_operand == new_operand) {
2542 return Status::OK();
2543 }
2544
2545 operands_[operand_num] = new_operand;
2546
2547 VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
2548 << new_operand->name() << ", was " << old_operand->name();
2549
2550 if (!absl::c_linear_search(operands_, old_operand)) {
2551 old_operand->RemoveUser(this);
2552 }
2553 new_operand->AddUser(this);
2554 return Status::OK();
2555 }
2556
// Replaces this instruction with `new_producer` in each of `users`, and as
// root of the parent computation if applicable. Returns an error unless the
// shapes are compatible (floating-point precision differences are ignored).
Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
                                       HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(users, new_producer);
}
2564
ReplaceAllUsesWithDifferentShape(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2565 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2566 absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
2567 for (HloInstruction* user : users) {
2568 TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
2569 }
2570
2571 if (parent_ && parent_->root_instruction() == this) {
2572 parent_->set_root_instruction(new_producer,
2573 /*accept_different_shape=*/true);
2574 }
2575 return Status::OK();
2576 }
2577
// Replaces every use of this instruction (and its root status, if any) with
// `new_producer`. Returns an error unless the shapes are compatible
// (floating-point precision differences are ignored).
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(new_producer);
}
2584
// Redirects every user of this instruction to `new_producer`, without shape
// checks. If this instruction is the root of its parent computation,
// `new_producer` becomes the new root.
//
// If `new_producer` is itself a user of this instruction, it is left alone to
// avoid a cycle and remains this instruction's only user. Fusion users get
// their operands deduplicated after rewiring.
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
    HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // It's possible that new_producer is a user of this instruction as might
      // be the case when replacing an instruction with a kCopy of itself. In
      // this case, don't do the replacement to avoid creating a cycle in the
      // graph. new_producer remains the only user of this instruction.
      new_producer_is_user = true;
    } else {
      std::replace(user->operands_.begin(), user->operands_.end(), this,
                   new_producer);
      new_producer->AddUser(user);
      if (user->opcode() == HloOpcode::kFusion) {
        TF_RETURN_IF_ERROR(
            Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
      }
    }
  }
  // Our user list is rebuilt from scratch rather than edited per user via
  // RemoveUser().
  users_.clear();
  user_map_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }
  if (parent_ && parent_->root_instruction() == this) {
    parent_->set_root_instruction(new_producer,
                                  /*accept_different_shape=*/true);
  }

  return Status::OK();
}
2617
IsEffectiveBitcast() const2618 bool HloInstruction::IsEffectiveBitcast() const {
2619 return opcode_ == HloOpcode::kBitcast ||
2620 (opcode_ == HloOpcode::kTranspose &&
2621 ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
2622 dimensions()));
2623 }
2624
to_apply() const2625 HloComputation* HloInstruction::to_apply() const {
2626 switch (opcode_) {
2627 case HloOpcode::kCall:
2628 case HloOpcode::kMap:
2629 case HloOpcode::kReduceWindow:
2630 case HloOpcode::kReduce:
2631 case HloOpcode::kAllReduce:
2632 case HloOpcode::kReduceScatter:
2633 case HloOpcode::kAllReduceStart:
2634 case HloOpcode::kScatter:
2635 case HloOpcode::kSort:
2636 case HloOpcode::kCustomCall:
2637 CHECK_EQ(called_computations_.size(), 1);
2638 return called_computations_[0];
2639 default:
2640 LOG(FATAL) << "Invalid opcode for to_apply(): "
2641 << HloOpcodeString(opcode());
2642 }
2643 }
2644
set_to_apply(HloComputation * computation)2645 void HloInstruction::set_to_apply(HloComputation* computation) {
2646 // Don't allow changing the computation for fused instructions so we don't
2647 // have to recompute called_instructions for the entire fusion instruction.
2648 CHECK(!IsFused());
2649 switch (opcode_) {
2650 case HloOpcode::kCall:
2651 case HloOpcode::kMap:
2652 case HloOpcode::kReduceWindow:
2653 case HloOpcode::kReduce:
2654 case HloOpcode::kAllReduce:
2655 case HloOpcode::kAllReduceStart:
2656 case HloOpcode::kScatter:
2657 case HloOpcode::kSort:
2658 case HloOpcode::kCustomCall:
2659 CHECK_EQ(called_computations_.size(), 1);
2660 called_computations_[0] = computation;
2661 break;
2662 default:
2663 LOG(FATAL) << "Invalid opcode for to_apply(): "
2664 << HloOpcodeString(opcode());
2665 }
2666 }
2667
// Returns the condition computation of a kWhile; fatal for other opcodes.
HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kConditionComputationIndex];
}
2672
// Returns the body computation of a kWhile; fatal for other opcodes.
HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kBodyComputationIndex];
}
2677
// Replaces the condition computation of a (non-fused) kWhile.
void HloInstruction::set_while_condition(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kConditionComputationIndex] = computation;
}
2685
// Replaces the body computation of a (non-fused) kWhile.
void HloInstruction::set_while_body(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kBodyComputationIndex] = computation;
}
2693
// Returns the initial loop value of a kWhile, i.e. its operand 0.
HloInstruction* HloInstruction::while_init() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return operands_[0];
}
2698
// Returns the "true" branch of a predicated kConditional. Requires that
// operand 0 (the branch selector) has element type PRED.
HloComputation* HloInstruction::true_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kTrueComputationIndex];
}
2704
// Returns the "false" branch of a predicated kConditional. Requires that
// operand 0 (the branch selector) has element type PRED.
HloComputation* HloInstruction::false_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kFalseComputationIndex];
}
2710
branch_computations() const2711 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2712 const {
2713 CHECK(HloOpcode::kConditional == opcode_);
2714 return called_computations_;
2715 }
2716
branch_count() const2717 int HloInstruction::branch_count() const {
2718 CHECK(HloOpcode::kConditional == opcode_);
2719 return called_computations_.size();
2720 }
2721
branch_computation(int b) const2722 HloComputation* HloInstruction::branch_computation(int b) const {
2723 CHECK(HloOpcode::kConditional == opcode_);
2724 CHECK_GE(b, 0);
2725 CHECK_LT(b, called_computations_.size());
2726 return called_computations_[b];
2727 }
2728
set_branch_computation(int b,HloComputation * computation)2729 void HloInstruction::set_branch_computation(int b,
2730 HloComputation* computation) {
2731 // Don't allow changing the computation for fused instructions so we don't
2732 // have to recompute called_instructions for the entire fusion instruction.
2733 CHECK(!IsFused());
2734 CHECK_EQ(HloOpcode::kConditional, opcode_);
2735 called_computations_[b] = computation;
2736 }
2737
SignatureString() const2738 string HloInstruction::SignatureString() const {
2739 string operands =
2740 StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
2741 StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2742 });
2743 return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2744 }
2745
// Returns `name` unchanged when `print_ids` is set. Otherwise returns only the
// part before the first '.', or the whole name if it contains no '.'.
string PrintName(const string& name, bool print_ids) {
  if (print_ids) {
    return name;
  }
  const string::size_type first_dot = name.find('.');
  return first_dot == string::npos ? name : name.substr(0, first_dot);
}
2754
2755 namespace {
2756
2757 using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2758
PrintNameInternal(const string & name,const HloPrintOptions & options)2759 string PrintNameInternal(const string& name, const HloPrintOptions& options) {
2760 return StrCat(options.print_percent() ? "%" : "",
2761 PrintName(name, options.print_ids()));
2762 }
2763
PrintCycle(const HloInstruction * child,DFSStack * dfs_stack)2764 void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
2765 // This set contains HloInstructions from the top of `DFSStack` that might
2766 // belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
2767 // `subgraph` := {child,...,top}.
2768 absl::flat_hash_set<const HloInstruction*> subgraph;
2769 while (!dfs_stack->empty() && dfs_stack->back().second != child) {
2770 subgraph.insert(dfs_stack->back().second);
2771 dfs_stack->pop_back();
2772 }
2773 // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
2774 absl::flat_hash_set<const HloInstruction*> visited;
2775 absl::InlinedVector<const HloInstruction*, 16> dfs;
2776 dfs.push_back(child);
2777 while (!dfs.empty()) {
2778 bool found_next_instr = false;
2779 for (const auto& user : dfs.back()->users()) {
2780 if (user == child) {
2781 dfs.push_back(child);
2782 LOG(INFO) << "\n\nDirected cycle:\n "
2783 << absl::StrJoin(
2784 dfs, "\n ",
2785 [](std::string* out, const HloInstruction* instr) {
2786 out->append(instr->name());
2787 });
2788 return;
2789 }
2790 if (!subgraph.contains(user) || visited.contains(user)) {
2791 continue;
2792 }
2793 visited.insert(user);
2794 dfs.push_back(user);
2795 found_next_instr = true;
2796 }
2797 if (!found_next_instr) {
2798 dfs.pop_back();
2799 }
2800 }
2801 }
2802
2803 } // namespace
2804
ToString(const HloPrintOptions & options) const2805 string HloInstruction::ToString(const HloPrintOptions& options) const {
2806 CanonicalNameMap new_map;
2807 return ToStringWithCanonicalNameMap(options, &new_map);
2808 }
2809
// Returns true if `opcode` is an elementwise unary, binary or ternary
// operation. This is purely opcode-based; kDynamicUpdateSlice, which is
// elementwise only in operand 0, is special-cased separately in
// IsElementwiseImpl().
bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
  switch (opcode) {
    // Unary elementwise operations.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      return true;

    // Binary elementwise operations, the same as in IsElementwiseBinary().
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      return true;

    // Ternary elementwise operations.
    case HloOpcode::kSelect:
    case HloOpcode::kClamp:
      return true;

    // Everything else, including all non-elementwise opcodes.
    default:
      return false;
  }
}
2871
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const2872 bool HloInstruction::IsElementwiseImpl(
2873 const absl::optional<int64>& operand_idx) const {
2874 if (opcode_ == HloOpcode::kDynamicUpdateSlice) {
2875 return operand_idx.has_value() && operand_idx.value() == 0;
2876 }
2877 return IsOpElementwise(opcode_);
2878 }
2879
IsCrossModuleAllReduce() const2880 bool HloInstruction::IsCrossModuleAllReduce() const {
2881 return opcode() == HloOpcode::kAllReduce && channel_id();
2882 }
2883
IsCrossReplicaAllReduce() const2884 bool HloInstruction::IsCrossReplicaAllReduce() const {
2885 return opcode() == HloOpcode::kAllReduce && !channel_id();
2886 }
2887
// Renders this instruction as text, roughly:
//   [%name = ][shape ]opcode(operands)[, extra attrs][, metadata={...}]
//   [, backend_config="..."]
// When names are canonicalized, `canonical_name_map` supplies (and records)
// the canonical names used for this instruction and its operands.
string HloInstruction::ToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string result = "";

  // Logic to print the instruction name (e.g. "%foo = ").
  if (options.canonicalize_instruction_names()) {
    if (options.is_in_nested_computation()) {
      // When canonicalizing, print the (canonical) instruction name only when
      // this instruction is printed as part of a nested computation. A
      // top-level HloInstruction::ToString() call prints no instruction name.
      StrAppend(&result,
                PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
                                  options),
                " = ");
    }
  } else {
    StrAppend(&result, PrintNameInternal(name(), options), " = ");
  }

  if (options.print_result_shape()) {
    // Print shape.
    if (options.include_layout_in_shapes()) {
      StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
    } else {
      StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
    }
  }

  // Print opcode, operand(s).
  StrAppend(&result, HloOpcodeString(opcode()), "(",
            OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
            ")");

  // Print additional attributes. If an instruction contains a subcomputation,
  // the subcomputation is also printed here.
  for (const string& extra : ExtraAttributesToString(options)) {
    StrAppend(&result, ", ", extra);
  }

  // Metadata is printed only if at least one of op_type, op_name or
  // source_file is set.
  if (options.print_metadata() &&
      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
       !metadata_.source_file().empty())) {
    StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
  }
  // The backend config is C-escaped so it can be embedded in quotes.
  if (options.print_backend_config() && !backend_config_.empty()) {
    StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
  }
  return result;
}
2937
OperandsToString(const HloPrintOptions & options) const2938 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2939 CanonicalNameMap new_map;
2940 return OperandsToStringWithCanonicalNameMap(options, &new_map);
2941 }
2942
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2943 string HloInstruction::OperandsToStringWithCanonicalNameMap(
2944 const HloPrintOptions& options,
2945 CanonicalNameMap* canonical_name_map) const {
2946 string operands;
2947 absl::Span<HloInstruction* const> slice(operands_);
2948 const int64_t kMaxOperandsToShowIfCompact = 4;
2949 if (options.compact_operands() &&
2950 slice.size() > kMaxOperandsToShowIfCompact) {
2951 slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
2952 }
2953 for (int64_t i = 0; i < slice.size(); ++i) {
2954 HloInstruction* operand = slice[i];
2955 if (i != 0) {
2956 StrAppend(&operands, ", ");
2957 if (options.print_operand_index_annotation_interval() != 0 &&
2958 i % options.print_operand_index_annotation_interval() == 0) {
2959 StrAppend(&operands, absl::StrFormat("/*index=%lld*/", i));
2960 }
2961 }
2962 // If operand is already been deleted, put `null` to the string output.
2963 if (operand == nullptr) {
2964 StrAppend(&operands, "null ");
2965 continue;
2966 }
2967 std::vector<string> str;
2968 if (options.print_operand_shape()) {
2969 if (options.include_layout_in_shapes()) {
2970 str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
2971 } else {
2972 str.push_back(ShapeUtil::HumanString(operand->shape()));
2973 }
2974 }
2975
2976 // In a top-level HloInstruction::ToString() call, the operand name is not
2977 // part of the canonical string.
2978 if (options.canonicalize_instruction_names() &&
2979 options.is_in_nested_computation()) {
2980 str.push_back(PrintNameInternal(
2981 canonical_name_map->LookupOrInsert(operand->name()), options));
2982 } else if (options.print_operand_names()) {
2983 str.push_back(PrintNameInternal(operand->name(), options));
2984 }
2985 StrAppend(&operands, StrJoin(str, " "));
2986 }
2987 const int64_t remaining = operands_.size() - slice.size();
2988 if (slice.size() != operands_.size()) {
2989 StrAppend(&operands, ", ...(+", remaining, ")");
2990 }
2991 return operands;
2992 }
2993
2994 namespace {
2995
IsSequentialCall(HloOpcode opcode)2996 bool IsSequentialCall(HloOpcode opcode) {
2997 switch (opcode) {
2998 case HloOpcode::kCall:
2999 case HloOpcode::kConditional:
3000 case HloOpcode::kWhile:
3001 return true;
3002 default:
3003 return false;
3004 }
3005 }
3006
3007 } // namespace
3008
ExtraAttributesToString(const HloPrintOptions & options) const3009 std::vector<string> HloInstruction::ExtraAttributesToString(
3010 const HloPrintOptions& options) const {
3011 std::vector<string> extra = options.print_extra_attributes()
3012 ? ExtraAttributesToStringImpl(options)
3013 : std::vector<string>();
3014
3015 const auto subcomputation_mode = options.print_subcomputation_mode();
3016 if (subcomputation_mode ==
3017 HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
3018 if (opcode() == HloOpcode::kWhile) {
3019 extra.push_back(StrCat(
3020 "condition=", PrintNameInternal(while_condition()->name(), options)));
3021 extra.push_back(
3022 StrCat("body=", PrintNameInternal(while_body()->name(), options)));
3023 } else if (opcode() == HloOpcode::kSelectAndScatter) {
3024 extra.push_back(
3025 StrCat("select=", PrintNameInternal(select()->name(), options)));
3026 extra.push_back(
3027 StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
3028 } else if (opcode() == HloOpcode::kConditional) {
3029 if (operand(0)->shape().element_type() == PRED) {
3030 extra.push_back(
3031 StrCat("true_computation=",
3032 PrintNameInternal(true_computation()->name(), options)));
3033 extra.push_back(
3034 StrCat("false_computation=",
3035 PrintNameInternal(false_computation()->name(), options)));
3036 } else {
3037 extra.push_back(StrCat(
3038 "branch_computations={",
3039 StrJoin(branch_computations(), ", ",
3040 [&](string* out, const HloComputation* computation) {
3041 StrAppend(
3042 out, PrintNameInternal(computation->name(), options));
3043 }),
3044 "}"));
3045 }
3046 } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
3047 opcode() == HloOpcode::kReduceWindow ||
3048 opcode() == HloOpcode::kReduce ||
3049 opcode() == HloOpcode::kAllReduce ||
3050 opcode() == HloOpcode::kReduceScatter ||
3051 opcode() == HloOpcode::kAllReduceStart ||
3052 opcode() == HloOpcode::kScatter ||
3053 opcode() == HloOpcode::kSort) {
3054 extra.push_back(
3055 StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
3056 } else if (opcode() == HloOpcode::kCustomCall) {
3057 if (!called_computations().empty()) {
3058 extra.push_back(StrCat(
3059 "called_computations={",
3060 StrJoin(called_computations(), ", ",
3061 [&](string* out, const HloComputation* computation) {
3062 StrAppend(
3063 out, PrintNameInternal(computation->name(), options));
3064 }),
3065 "}"));
3066 }
3067 } else if (!called_computations().empty()) {
3068 extra.push_back(StrCat(
3069 "calls=",
3070 StrJoin(called_computations(), ", ",
3071 [&](string* out, const HloComputation* computation) {
3072 StrAppend(out,
3073 PrintNameInternal(computation->name(), options));
3074 })));
3075 }
3076 } else if ((subcomputation_mode ==
3077 HloPrintOptions::PrintSubcomputationMode::kFullBodies) ||
3078 (subcomputation_mode == HloPrintOptions::PrintSubcomputationMode::
3079 kNonSequentialBodies &&
3080 !IsSequentialCall(opcode()))) {
3081 HloPrintOptions new_options = options;
3082 new_options.set_is_in_nested_computation(true);
3083 switch (opcode()) {
3084 case HloOpcode::kWhile:
3085 extra.push_back(
3086 StrCat("condition=\n", while_condition()->ToString(new_options)));
3087 extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
3088 break;
3089 case HloOpcode::kSelectAndScatter:
3090 extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
3091 extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
3092 break;
3093 case HloOpcode::kConditional:
3094 if (operand(0)->shape().element_type() == PRED) {
3095 extra.push_back(StrCat("true_computation=\n",
3096 true_computation()->ToString(new_options)));
3097 extra.push_back(StrCat("false_computation=\n",
3098 false_computation()->ToString(new_options)));
3099 } else {
3100 extra.push_back(StrCat(
3101 "branch_computations={\n",
3102 StrJoin(branch_computations(), ",\n",
3103 [&](string* out, const HloComputation* computation) {
3104 StrAppend(out, computation->ToString(new_options));
3105 }),
3106 "\n}"));
3107 }
3108 break;
3109 case HloOpcode::kCall:
3110 case HloOpcode::kMap:
3111 case HloOpcode::kReduceWindow:
3112 case HloOpcode::kReduce:
3113 case HloOpcode::kAllReduce:
3114 case HloOpcode::kAllReduceStart:
3115 case HloOpcode::kScatter:
3116 case HloOpcode::kSort:
3117 extra.push_back(
3118 StrCat("to_apply=\n", to_apply()->ToString(new_options)));
3119 break;
3120 default:
3121 if (!called_computations().empty()) {
3122 extra.push_back(StrCat(
3123 "calls=\n",
3124 StrJoin(called_computations(), ", ",
3125 [&](string* out, const HloComputation* computation) {
3126 StrAppend(out, computation->ToString(new_options));
3127 })));
3128 }
3129 break;
3130 }
3131 }
3132
3133 if (has_sharding()) {
3134 extra.push_back(
3135 StrCat("sharding=", sharding().ToString(options.print_metadata())));
3136 }
3137 if (!frontend_attributes_.map().empty()) {
3138 extra.push_back(StrCat("frontend_attributes=",
3139 FrontendAttributesToString(frontend_attributes_)));
3140 }
3141 if (!outer_dimension_partitions_.empty()) {
3142 extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
3143 StrJoin(outer_dimension_partitions_, ",")));
3144 }
3145
3146 if (options.print_control_dependencies() && !control_predecessors_.empty()) {
3147 extra.push_back(StrCat("control-predecessors={",
3148 StrJoin(control_predecessors_, ", ",
3149 [&](string* out, HloInstruction* pre) {
3150 StrAppend(out, PrintNameInternal(
3151 pre->name(), options));
3152 }),
3153 "}"));
3154 }
3155
3156 return extra;
3157 }
3158
ToShortString() const3159 string HloInstruction::ToShortString() const {
3160 return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
3161 StrJoin(operands_, ", ",
3162 [](string* out, HloInstruction* operand) {
3163 StrAppend(out, "%", operand->name());
3164 }),
3165 ")");
3166 }
3167
ToProto() const3168 HloInstructionProto HloInstruction::ToProto() const {
3169 HloInstructionProto proto;
3170 CHECK(unique_id_ != -1)
3171 << "This instruction does not have a valid id. Please make sure the "
3172 "instruction is inside a module before dumping it.";
3173 proto.set_id(unique_id_);
3174 proto.set_name(name_);
3175 proto.set_opcode(HloOpcodeString(opcode_));
3176 *proto.mutable_shape() = shape_.ToProto();
3177 for (const HloInstruction* operand : operands_) {
3178 proto.add_operand_ids(operand->unique_id());
3179 }
3180 for (const HloInstruction* control : control_predecessors_) {
3181 proto.add_control_predecessor_ids(control->unique_id());
3182 }
3183
3184 *proto.mutable_metadata() = metadata_;
3185 proto.set_backend_config(backend_config_);
3186 if (opcode() != HloOpcode::kFusion) {
3187 for (const HloComputation* computation : called_computations_) {
3188 proto.add_called_computation_ids(computation->unique_id());
3189 }
3190 }
3191
3192 if (has_sharding()) {
3193 *proto.mutable_sharding() = sharding().ToProto();
3194 }
3195 if (!outer_dimension_partitions_.empty()) {
3196 for (const auto& idx : outer_dimension_partitions_) {
3197 proto.mutable_outer_dimension_partitions()->Add(idx);
3198 }
3199 }
3200
3201 *proto.mutable_frontend_attributes() = frontend_attributes_;
3202
3203 return proto;
3204 }
3205
ToCategory() const3206 string HloInstruction::ToCategory() const {
3207 if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
3208 opcode() == HloOpcode::kReshape ||
3209 opcode() == HloOpcode::kDynamicReshape) {
3210 return "data formatting";
3211 }
3212
3213 if (IsElementwise()) {
3214 return "non-fusion elementwise";
3215 }
3216
3217 return HloOpcodeString(opcode());
3218 }
3219
tracing() const3220 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
3221
set_tracing(HloInstruction * trace_instruction)3222 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
3223 trace_instruction_ = trace_instruction;
3224 }
3225
IsFused() const3226 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
3227
IsCustomCall(absl::string_view target) const3228 bool HloInstruction::IsCustomCall(absl::string_view target) const {
3229 return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
3230 }
3231
IsInputFusion() const3232 bool HloInstruction::IsInputFusion() const {
3233 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
3234 }
3235
IsLoopFusion() const3236 bool HloInstruction::IsLoopFusion() const {
3237 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
3238 }
3239
IsOutputFusion() const3240 bool HloInstruction::IsOutputFusion() const {
3241 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
3242 }
3243
IsCustomFusion() const3244 bool HloInstruction::IsCustomFusion() const {
3245 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
3246 }
3247
IsFusible() const3248 bool HloInstruction::IsFusible() const {
3249 // Instructions which are traced should not be fused.
3250 if (tracing()) {
3251 return false;
3252 }
3253 // Some kinds of instructions don't make sense to fuse.
3254 switch (opcode_) {
3255 case HloOpcode::kDomain:
3256 case HloOpcode::kParameter:
3257 case HloOpcode::kWhile:
3258 case HloOpcode::kConditional:
3259 case HloOpcode::kCall:
3260 return false;
3261 // Fusions are always fusible.
3262 case HloOpcode::kFusion:
3263 // Side effecting reduce and reduce window would be invalid HLO.
3264 case HloOpcode::kMap:
3265 case HloOpcode::kReduce:
3266 case HloOpcode::kReduceWindow:
3267 return true;
3268 case HloOpcode::kRng:
3269 return user_count() <= 1;
3270 // Side effecting instructions cannot be fused.
3271 default:
3272 return !HasSideEffect();
3273 }
3274 }
3275
// Constructs an instruction with the given opcode and result shape.
// unique_id_ starts at -1, meaning "no valid id yet"; ToProto() CHECK-fails
// on that value. The name defaults to the opcode's string form, and the
// instruction starts out not marked as dead.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)),
      marked_as_dead_(false) {
  // Shape validation runs through TF_DCHECK_OK, i.e. only in debug builds.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
3284
// Dispatches this instruction to the visitor handler that matches opcode_
// (e.g. kAdd -> HandleAdd) and returns that handler's status.
// kTrace has no handler and yields Status::OK(). An opcode value that matches
// no case falls through to an InternalError.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    // Note: the handler name differs from the opcode name here.
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kLogistic:
      return visitor->HandleLogistic(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kTupleSelect:
      return visitor->HandleTupleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllGather:
      return visitor->HandleAllGather(this);
    case HloOpcode::kAllGatherStart:
      return visitor->HandleAllGatherStart(this);
    case HloOpcode::kAllGatherDone:
      return visitor->HandleAllGatherDone(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kReduceScatter:
      return visitor->HandleReduceScatter(this);
    case HloOpcode::kAllReduceStart:
      return visitor->HandleAllReduceStart(this);
    case HloOpcode::kAllReduceDone:
      return visitor->HandleAllReduceDone(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kCollectivePermuteStart:
      return visitor->HandleCollectivePermuteStart(this);
    case HloOpcode::kCollectivePermuteDone:
      return visitor->HandleCollectivePermuteDone(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kPartitionId:
      return visitor->HandlePartitionId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kCbrt:
      return visitor->HandleCbrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kPopulationCount:
      return visitor->HandlePopulationCount(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kDynamicReshape:
      return visitor->HandleDynamicReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kRngBitGenerator:
      return visitor->HandleRngBitGenerator(this);
    case HloOpcode::kRngGetAndUpdateState:
      return visitor->HandleRngGetAndUpdateState(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kCopyStart:
      return visitor->HandleCopyStart(this);
    case HloOpcode::kCopyDone:
      return visitor->HandleCopyDone(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kSetDimensionSize:
      return visitor->HandleSetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
      return Status::OK();
  }
  // The switch has no default, so only an opcode value that matches no case
  // reaches this point.
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}

// Explicit instantiations (for DfsHloVisitor and ConstDfsHloVisitor).
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
3526
3527 // Push "child" onto the dfs_stack if not already visited. Returns false if a
3528 // cycle was detected, and true otherwise.
3529 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)3530 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
3531 HloInstruction* child) {
3532 CHECK(child != nullptr);
3533 const int id = child->unique_id();
3534 CHECK_GE(id, 0) << "instruction may not have a parent computation";
3535 switch (visitor->GetVisitState(id)) {
3536 case Visitor::kVisiting:
3537 return false;
3538
3539 case Visitor::kVisited:
3540 // Nothing to do
3541 return true;
3542
3543 case Visitor::kNotVisited:
3544 dfs_stack->push_back(std::make_pair(id, child));
3545 return true;
3546 }
3547 }
3548
// Comparator over DFS stack entries (<unique_id, instruction> pairs). Used to
// order the children pushed for one node.
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;
// Iterative post-order DFS rooted at `root`. Every reachable instruction that
// the visitor has not already marked kVisited is handled as follows, after
// its operands (and, unless `ignore_control_predecessors` is set, its control
// predecessors):
//   1. visitor->Preprocess(node)
//   2. node->Visit(visitor)       (the opcode-specific handler)
//   3. the node is marked kVisited
//   4. visitor->Postprocess(node)
// If `operand_order` is non-null, the children pushed for a node are sorted
// with it. They are then reversed on the stack so that the first child in
// that order is processed first.
// Returns FailedPrecondition if a cycle is detected; otherwise returns the
// first non-OK status from the visitor, or OK.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  // Pre-size the visit-state storage to the number of instructions in root's
  // computation.
  visitor->ReserveVisitStates(root->parent()->instruction_count());

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    // Already finished (e.g. reached through another user): just drop it.
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO (id = " << current_id
              << ") as it was already visited.";
      continue;
    }

    // Second time this node is at the top of the stack: all children pushed
    // above it have been processed, so visit the node itself now.
    if (visit_state == Visitor::kVisiting) {
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First time seeing this node. Mark it in progress (this is what lets
    // PushDFSChild detect cycles back to it), then push its unvisited
    // children above it. The node stays on the stack beneath them.
    visitor->SetVisitState(current_id, Visitor::kVisiting);

    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        PrintCycle(child, &dfs_stack);
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          PrintCycle(child, &dfs_stack);
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString());
        }
      }
    }

    // Only the entries pushed for this node are sorted.
    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
3629
3630 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)3631 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
3632 bool call_finish_visit,
3633 bool ignore_control_predecessors) {
3634 VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
3635 TF_RETURN_IF_ERROR(
3636 PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
3637 if (call_finish_visit) {
3638 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3639 }
3640 return Status::OK();
3641 }
3642
3643 // Explicit instantiations.
3644 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
3645 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3646
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)3647 Status HloInstruction::AcceptWithOperandOrder(
3648 DfsHloVisitor* visitor, const CompareFunction& operand_order,
3649 bool call_finish_visit) {
3650 VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
3651 InternalCompareFunction func = [&operand_order](
3652 std::pair<int, const HloInstruction*> a,
3653 std::pair<int, const HloInstruction*> b) {
3654 // Call the client's comparison function on the actual HloInstruction*
3655 // objects (ignoring the internal ids we also have in our stack entries)
3656 return operand_order(a.second, b.second);
3657 };
3658 TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
3659 /*ignore_control_predecessors=*/false));
3660 if (call_finish_visit) {
3661 VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
3662 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3663 VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
3664 }
3665 VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
3666 return Status::OK();
3667 }
3668
shape() const3669 const Shape& HloInstruction::shape() const { return shape_; }
3670
OperandIndices(const HloInstruction * operand) const3671 absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
3672 const HloInstruction* operand) const {
3673 absl::InlinedVector<int64, 4> result;
3674 for (int64_t i = 0; i < operand_count(); ++i) {
3675 if (this->operand(i) == operand) {
3676 result.push_back(i);
3677 }
3678 }
3679 return result;
3680 }
3681
IsElementwiseBinary() const3682 bool HloInstruction::IsElementwiseBinary() const {
3683 return IsElementwise() && operand_count() == 2;
3684 }
3685
IsElementwise() const3686 bool HloInstruction::IsElementwise() const {
3687 return IsElementwiseImpl(absl::nullopt);
3688 }
3689
IsElementwiseOnOperand(int64_t operand_idx) const3690 bool HloInstruction::IsElementwiseOnOperand(int64_t operand_idx) const {
3691 return IsElementwiseImpl(operand_idx);
3692 }
3693
// A helper class for memoized, recursive computation of HloOpcode::kFusion
// in HloInstruction::OperandElementUse below.
//
// Compute(i, root) walks the fused expression from `root` toward the fused
// parameters. It combines, along every path that reaches parameter number i,
// how each instruction uses its operand elements (OperandElementUse). The
// result summarizes how the fusion as a whole uses the elements of its i-th
// operand.
class HloInstruction::FusionReusesParamElements {
 public:
  using UseKind = HloInstruction::UseKind;

  // We could rather iterate backwards through fused_instructions_ here, as it
  // is in reverse postorder, and compute whether each fused instruction reuses
  // the value of this parameter, which would save stack space but not allow us
  // to finish early if we find a reuse.
  static UseKind Compute(int64_t i, const HloInstruction& hlo) {
    absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
    return ComputeInternal(i, hlo, &memoization_cache);
  }

 private:
  // Returns how `hlo` uses parameter number `i`, memoized in `cache`.
  // Parameter i itself is a plain kUse. Any other instruction starts at
  // kNoUse and folds in, per operand, the operand's own result combined with
  // this instruction's use of that operand.
  static UseKind ComputeInternal(
      int64_t i, const HloInstruction& hlo,
      absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
    if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
      if (hlo_param->parameter_number() == i) {
        return UseKind::kUse;
      }
    }

    // Insert a kNoUse placeholder. If the key already existed, the memoized
    // value is returned unchanged below.
    auto p = cache->emplace(&hlo, UseKind::kNoUse);
    auto value_it = p.first;
    const bool key_is_new = p.second;

    if (key_is_new) {
      for (int64_t j = 0; j < hlo.operands_.size(); ++j) {
        UseKind old_val = value_it->second;

        // The next operation invalidates iterators.
        // (The recursive call may insert into `cache`, which can rehash.)
        UseKind new_val =
            Fold(old_val,
                 FoldUseMandatory(hlo.OperandElementUse(j),
                                  ComputeInternal(i, *hlo.operand(j), cache)));

        // Re-acquire the iterator. We could work harder to do this only if
        // absolutely necessary, but this code is not hot enough to warrant
        // that.
        value_it = cache->find(&hlo);
        value_it->second = new_val;
        // Fold() minimizes the UseKind value. If it is already minimum, we can
        // break the loop early.
        if (new_val == UseKind::kReuse) {
          break;
        }
      }
    }
    return value_it->second;
  }

  // Combines two UseKinds.
  //
  // This is the min operation on the lattice
  //
  //   kReuse < kUse < kNoUse.
  //
  // Two kUses uses which have different permutations count as kReuse.
  static UseKind Fold(UseKind a, UseKind b) {
    // Without loss of generality, let `b` be the operation with the larger use
    // kind.
    if (b.kind < a.kind) {
      std::swap(a, b);
    }
    // If the kinds are different, return the smaller one, namely `a`.
    if (a.kind != b.kind) {
      return a;
    }
    // If the kinds are both kUse, check that they're the same permutation.
    if (a.kind == UseKind::kUse && b.kind == UseKind::kUse &&
        a.permutation_instr != b.permutation_instr) {
      return UseKind::kReuse;
    }
    return a;  // They're the same.
  }

  // Combines two UseKinds differently than Fold().
  //
  // This is the min operation on the lattice
  //
  //   kNoUse < kReuse < kUse.
  //
  // If `a` and `b` are both kUse and one has a non-null permutation
  // instruction, returns kUse with that permutation. OTOH if both have
  // different, non-null permutation instructions, returns kReuse.
  //
  // You can think of this sort of as a conjunction, whereas Fold is sort of a
  // disjunction. FoldUseMandatory() says "no use" if either input isn't used,
  // whereas Fold() would say "use".
  static UseKind FoldUseMandatory(UseKind a, UseKind b) {
    if (a.kind == UseKind::kNoUse || b.kind == UseKind::kNoUse) {
      return UseKind::kNoUse;
    }
    if (a.kind == UseKind::kReuse || b.kind == UseKind::kReuse) {
      return UseKind::kReuse;
    }
    // Both are kUse from here on; reconcile their permutation instructions.
    if (a.permutation_instr == b.permutation_instr) {
      return a;  // They're the same.
    }
    if (b.permutation_instr == nullptr) {
      return a;
    }
    if (a.permutation_instr == nullptr) {
      return b;
    }
    return UseKind::kReuse;
  }
};
3805
OperandElementUse(int64_t operand_num) const3806 HloInstruction::UseKind HloInstruction::OperandElementUse(
3807 int64_t operand_num) const {
3808 switch (opcode_) {
3809 case HloOpcode::kBitcast:
3810 // A bitcast that only adds or removes degenerate (i.e. size 1) dimensions
3811 // doesn't permute its elements, so it counts as a plain, non-permuting
3812 // use.
3813 return ShapeUtil::DropDegenerateDimensions(shape()) ==
3814 ShapeUtil::DropDegenerateDimensions(operand(0)->shape())
3815 ? UseKind::kUse
3816 : UseKind::Permuting(this);
3817 case HloOpcode::kConcatenate:
3818 case HloOpcode::kReshape:
3819 case HloOpcode::kReverse:
3820 case HloOpcode::kSlice:
3821 case HloOpcode::kTranspose:
3822 return UseKind::Permuting(this);
3823 case HloOpcode::kPad:
3824 // Pad reuses the padding value but not the padded array elements.
3825 return operand_num > 0 ? UseKind::kReuse : UseKind::Permuting(this);
3826 case HloOpcode::kReduce:
3827 // Reduce reuses the init values but not the operand array elements.
3828 return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
3829 ? UseKind::kReuse
3830 : UseKind::Permuting(this);
3831 case HloOpcode::kFusion:
3832 // Uses the memoizing, recursive computation defined above.
3833 return FusionReusesParamElements::Compute(operand_num,
3834 *fused_expression_root());
3835 case HloOpcode::kDot:
3836 // Matrix-vector dots do not reuse the matrix operand.
3837 if (shape().dimensions_size() <= 1) {
3838 if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
3839 (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
3840 return UseKind::kUse;
3841 }
3842 }
3843 return UseKind::kReuse;
3844 case HloOpcode::kDynamicUpdateSlice:
3845 // Dynamic-update-slice reuses only start_indices.
3846 if (operand_num == 0 || operand_num == 1) {
3847 return UseKind::kUse;
3848 }
3849 return UseKind::kReuse;
3850 case HloOpcode::kGather:
3851 // Gather reads its indices in a linear fashion, and it permutes the
3852 // vector it's gathering from.
3853 return operand_num == 0 ? UseKind::kUse : UseKind::Permuting(this);
3854 default:
3855 return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
3856 }
3857 }
3858
3859 std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const3860 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
3861 if (HloOpcode::kReshape != opcode_) {
3862 return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
3863 }
3864 return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
3865 shape_);
3866 }
3867
ToString(HloInstruction::FusionKind kind)3868 string ToString(HloInstruction::FusionKind kind) {
3869 switch (kind) {
3870 case HloInstruction::FusionKind::kLoop:
3871 return "kLoop";
3872 case HloInstruction::FusionKind::kInput:
3873 return "kInput";
3874 case HloInstruction::FusionKind::kOutput:
3875 return "kOutput";
3876 case HloInstruction::FusionKind::kCustom:
3877 return "kCustom";
3878 }
3879 }
3880
StringToFusionKind(const string & kind_name)3881 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
3882 const string& kind_name) {
3883 if (kind_name == "kLoop") {
3884 return HloInstruction::FusionKind::kLoop;
3885 }
3886 if (kind_name == "kInput") {
3887 return HloInstruction::FusionKind::kInput;
3888 }
3889 if (kind_name == "kOutput") {
3890 return HloInstruction::FusionKind::kOutput;
3891 }
3892 if (kind_name == "kCustom") {
3893 return HloInstruction::FusionKind::kCustom;
3894 }
3895 return InvalidArgument("Unknown fusion kind: %s", kind_name);
3896 }
3897
FrontendAttributesToString(const FrontendAttributes & frontend_attributes)3898 string FrontendAttributesToString(
3899 const FrontendAttributes& frontend_attributes) {
3900 std::vector<std::pair<string, string>> sorted_attributes(
3901 frontend_attributes.map().begin(), frontend_attributes.map().end());
3902 absl::c_sort(sorted_attributes);
3903 // Frontend attribute is a comma-separated list of attribute="value" pairs,
3904 // e.g., frontend_attributes={name="value_a",type="int32"}.
3905 const auto formatter = [](string* out,
3906 const std::pair<string, string>& item) {
3907 absl::StrAppend(out, item.first, "=\"", item.second, "\"");
3908 };
3909 return absl::StrFormat("{%s}",
3910 absl::StrJoin(sorted_attributes, ",", formatter));
3911 }
3912
PaddingConfigToString(const PaddingConfig & padding)3913 string PaddingConfigToString(const PaddingConfig& padding) {
3914 bool has_interior_padding =
3915 absl::c_any_of(padding.dimensions(),
3916 [](const PaddingConfig::PaddingConfigDimension& dim) {
3917 return dim.interior_padding() != 0;
3918 });
3919 return StrJoin(
3920 padding.dimensions(), "x",
3921 [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3922 StrAppend(
3923 out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3924 has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3925 });
3926 }
3927
RandomDistributionToString(const RandomDistribution & distribution)3928 string RandomDistributionToString(const RandomDistribution& distribution) {
3929 return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
3930 }
RandomAlgorithmToString(const RandomAlgorithm & algorithm)3931 string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
3932 return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
3933 }
3934
PrecisionToString(const PrecisionConfig::Precision & precision)3935 string PrecisionToString(const PrecisionConfig::Precision& precision) {
3936 return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
3937 }
3938
CustomCallScheduleToString(const CustomCallSchedule & schedule)3939 static string CustomCallScheduleToString(const CustomCallSchedule& schedule) {
3940 return absl::AsciiStrToLower(CustomCallSchedule_Name(schedule));
3941 }
3942
CustomCallApiVersionToString(const CustomCallApiVersion & schedule)3943 static string CustomCallApiVersionToString(
3944 const CustomCallApiVersion& schedule) {
3945 return absl::AsciiStrToLower(CustomCallApiVersion_Name(schedule));
3946 }
3947
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)3948 string ConvolutionDimensionNumbersToString(
3949 const ConvolutionDimensionNumbers& dnums) {
3950 auto len_required = [](int64_t a, int64_t b, absl::Span<const int64> cs) {
3951 return std::max({a, b, cs.empty() ? 0 : *absl::c_max_element(cs)}) + 1;
3952 };
3953
3954 // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3955 // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3956 std::vector<string> lhs_dims(len_required(dnums.input_batch_dimension(),
3957 dnums.input_feature_dimension(),
3958 dnums.input_spatial_dimensions()),
3959 "?");
3960 lhs_dims[dnums.input_batch_dimension()] = 'b';
3961 lhs_dims[dnums.input_feature_dimension()] = 'f';
3962 for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3963 lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3964 }
3965
3966 std::vector<string> rhs_dims(
3967 len_required(dnums.kernel_input_feature_dimension(),
3968 dnums.kernel_output_feature_dimension(),
3969 dnums.kernel_spatial_dimensions()),
3970 "?");
3971 rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3972 rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3973 for (int64_t i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3974 rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3975 }
3976
3977 std::vector<string> output_dims(
3978 len_required(dnums.output_batch_dimension(),
3979 dnums.output_feature_dimension(),
3980 dnums.output_spatial_dimensions()),
3981 "?");
3982 output_dims[dnums.output_batch_dimension()] = 'b';
3983 output_dims[dnums.output_feature_dimension()] = 'f';
3984 for (int64_t i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3985 output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3986 }
3987
3988 return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
3989 StrJoin(output_dims, ""));
3990 }
3991
ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups)3992 string ReplicaGroupsToString(absl::Span<const ReplicaGroup> replica_groups) {
3993 std::vector<string> replica_group_str;
3994 replica_group_str.reserve(replica_groups.size());
3995 for (const ReplicaGroup& group : replica_groups) {
3996 replica_group_str.push_back(
3997 StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
3998 }
3999 return StrCat("{", StrJoin(replica_group_str, ","), "}");
4000 }
4001
StringToRandomAlgorithm(const string & name)4002 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) {
4003 static std::unordered_map<string, RandomAlgorithm>* map = [] {
4004 static auto* map = new std::unordered_map<string, RandomAlgorithm>;
4005 for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
4006 if (RandomAlgorithm_IsValid(i)) {
4007 auto value = static_cast<RandomAlgorithm>(i);
4008 (*map)[RandomAlgorithmToString(value)] = value;
4009 }
4010 }
4011 return map;
4012 }();
4013 auto found = map->find(absl::AsciiStrToLower(name));
4014 if (found == map->end()) {
4015 return InvalidArgument("Unknown algorithm");
4016 }
4017 return found->second;
4018 }
4019
StringToRandomDistribution(const string & name)4020 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
4021 static std::unordered_map<string, RandomDistribution>* map = [] {
4022 static auto* map = new std::unordered_map<string, RandomDistribution>;
4023 for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
4024 if (RandomDistribution_IsValid(i)) {
4025 auto value = static_cast<RandomDistribution>(i);
4026 (*map)[RandomDistributionToString(value)] = value;
4027 }
4028 }
4029 return map;
4030 }();
4031 auto found = map->find(absl::AsciiStrToLower(name));
4032 if (found == map->end()) {
4033 return InvalidArgument("Unknown distribution");
4034 }
4035 return found->second;
4036 }
4037
StringToPrecision(const string & name)4038 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
4039 static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
4040 static auto* map =
4041 new std::unordered_map<string, PrecisionConfig::Precision>;
4042 for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
4043 if (PrecisionConfig::Precision_IsValid(i)) {
4044 auto value = static_cast<PrecisionConfig::Precision>(i);
4045 (*map)[PrecisionToString(value)] = value;
4046 }
4047 }
4048 return map;
4049 }();
4050 auto found = map->find(absl::AsciiStrToLower(name));
4051 if (found == map->end()) {
4052 return InvalidArgument("Unknown distribution");
4053 }
4054 return found->second;
4055 }
4056
StringToCustomCallSchedule(absl::string_view name)4057 StatusOr<CustomCallSchedule> StringToCustomCallSchedule(
4058 absl::string_view name) {
4059 static const absl::flat_hash_map<string, CustomCallSchedule>* map = [] {
4060 static auto* map = new absl::flat_hash_map<string, CustomCallSchedule>;
4061 for (int i = 0; i < CustomCallSchedule_ARRAYSIZE; i++) {
4062 if (CustomCallSchedule_IsValid(i)) {
4063 auto value = static_cast<CustomCallSchedule>(i);
4064 (*map)[CustomCallScheduleToString(value)] = value;
4065 }
4066 }
4067 return map;
4068 }();
4069 auto found = map->find(absl::AsciiStrToLower(name));
4070 if (found == map->end()) {
4071 return InvalidArgument("Unknown schedule");
4072 }
4073 return found->second;
4074 }
4075
StringToCustomCallApiVersion(absl::string_view name)4076 StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
4077 absl::string_view name) {
4078 static const absl::flat_hash_map<string, CustomCallApiVersion>* map = [] {
4079 static auto* map = new absl::flat_hash_map<string, CustomCallApiVersion>;
4080 for (int i = 0; i < CustomCallApiVersion_ARRAYSIZE; i++) {
4081 if (CustomCallApiVersion_IsValid(i)) {
4082 auto value = static_cast<CustomCallApiVersion>(i);
4083 (*map)[CustomCallApiVersionToString(value)] = value;
4084 }
4085 }
4086 return map;
4087 }();
4088 auto found = map->find(absl::AsciiStrToLower(name));
4089 if (found == map->end()) {
4090 return InvalidArgument("Unknown API version");
4091 }
4092 return found->second;
4093 }
4094
operator <<(std::ostream & os,HloInstruction::FusionKind kind)4095 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
4096 return os << ToString(kind);
4097 }
4098
operator ()(const HloInstruction * const & lhs,const HloInstruction * const & rhs) const4099 bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
4100 const HloInstruction* const& rhs) const {
4101 if (rhs == nullptr) {
4102 // Nothing compares less than nullptr.
4103 return false;
4104 }
4105 if (lhs == nullptr) {
4106 return true;
4107 }
4108 auto lhs_module = lhs->GetModule();
4109 auto rhs_module = rhs->GetModule();
4110 CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
4111 (lhs_module != nullptr && rhs_module != nullptr));
4112 if (lhs_module != nullptr &&
4113 lhs_module->unique_id() != rhs_module->unique_id()) {
4114 return lhs_module->unique_id() < rhs_module->unique_id();
4115 }
4116 return lhs->unique_id() < rhs->unique_id();
4117 }
4118
GetBackendConfigInternal(tensorflow::protobuf::Message * proto) const4119 Status HloInstruction::GetBackendConfigInternal(
4120 tensorflow::protobuf::Message* proto) const {
4121 proto->Clear();
4122
4123 // Empty string does not parse as valid JSON, but it's a valid backend config,
4124 // corresponding to the empty proto.
4125 if (backend_config_.empty()) {
4126 return Status::OK();
4127 }
4128 return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
4129 }
4130
set_backend_config(const tensorflow::protobuf::Message & proto)4131 Status HloInstruction::set_backend_config(
4132 const tensorflow::protobuf::Message& proto) {
4133 TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
4134 return Status::OK();
4135 }
4136
BackendConfigToRawString(const tensorflow::protobuf::Message & proto)4137 /* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
4138 const tensorflow::protobuf::Message& proto) {
4139 string ret;
4140 // Pass ignore_accuracy_loss = true because estimated_cycles field can be
4141 // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
4142 // INT64_MAX, JsonFormat will return an error status, although there is no
4143 // accuracy loss for int64.
4144 TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
4145 proto, &ret, /*ignore_accuracy_loss=*/true));
4146 return ret;
4147 }
4148
precision_config() const4149 const PrecisionConfig& HloInstruction::precision_config() const {
4150 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4151 return convolution->precision_config();
4152 }
4153 if (auto* dot = DynCast<HloDotInstruction>(this)) {
4154 return dot->precision_config();
4155 }
4156
4157 if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
4158 return custom_call->precision_config();
4159 }
4160 LOG(FATAL) << "Unimplemented method.";
4161 }
4162
mutable_precision_config()4163 PrecisionConfig* HloInstruction::mutable_precision_config() {
4164 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
4165 return convolution->mutable_precision_config();
4166 }
4167 if (auto* dot = DynCast<HloDotInstruction>(this)) {
4168 return dot->mutable_precision_config();
4169 }
4170 LOG(FATAL) << "Unimplemented method.";
4171 }
4172
GetModule() const4173 HloModule* HloInstruction::GetModule() const {
4174 if (parent_) {
4175 return parent_->parent();
4176 }
4177 return nullptr;
4178 }
4179
UniquifyName(NameUniquer * name_uniquer)4180 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
4181 string parent_str = parent() == nullptr ? "noparent" : parent()->name();
4182 name_ = name_uniquer->GetUniqueName(name_);
4183 }
4184
set_outer_dimension_partitions(const std::vector<int64> & outer_dimension_partitions)4185 void HloInstruction::set_outer_dimension_partitions(
4186 const std::vector<int64>& outer_dimension_partitions) {
4187 outer_dimension_partitions_ = outer_dimension_partitions;
4188 }
4189
4190 // TODO(b/80131774): Remove these temporary methods after transition.
feature_index() const4191 int64 HloInstruction::feature_index() const {
4192 return Cast<HloBatchNormInstruction>(this)->feature_index();
4193 }
4194
epsilon() const4195 float HloInstruction::epsilon() const {
4196 return Cast<HloBatchNormInstruction>(this)->epsilon();
4197 }
4198
fft_type() const4199 FftType HloInstruction::fft_type() const {
4200 return Cast<HloFftInstruction>(this)->fft_type();
4201 }
4202
fft_length() const4203 const std::vector<int64>& HloInstruction::fft_length() const {
4204 return Cast<HloFftInstruction>(this)->fft_length();
4205 }
4206
concatenate_dimension() const4207 int64 HloInstruction::concatenate_dimension() const {
4208 return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
4209 }
4210
dimension() const4211 int64 HloInstruction::dimension() const {
4212 if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
4213 return set_size->dimension();
4214 }
4215 return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
4216 }
4217
inferred_dimension() const4218 int64 HloInstruction::inferred_dimension() const {
4219 return Cast<HloReshapeInstruction>(this)->inferred_dimension();
4220 }
4221
IsRank2Transpose() const4222 bool HloInstruction::IsRank2Transpose() const {
4223 auto transpose = DynCast<HloTransposeInstruction>(this);
4224 return transpose != nullptr && transpose->IsRank2Transpose();
4225 }
4226
slice_starts(int64_t dimension) const4227 int64 HloInstruction::slice_starts(int64_t dimension) const {
4228 return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
4229 }
4230
slice_starts() const4231 const std::vector<int64>& HloInstruction::slice_starts() const {
4232 return Cast<HloSliceInstruction>(this)->slice_starts();
4233 }
4234
mutable_slice_starts()4235 std::vector<int64>* HloInstruction::mutable_slice_starts() {
4236 return Cast<HloSliceInstruction>(this)->mutable_slice_starts();
4237 }
4238
slice_limits(int64_t dimension) const4239 int64 HloInstruction::slice_limits(int64_t dimension) const {
4240 return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
4241 }
4242
slice_limits() const4243 const std::vector<int64>& HloInstruction::slice_limits() const {
4244 return Cast<HloSliceInstruction>(this)->slice_limits();
4245 }
4246
mutable_slice_limits()4247 std::vector<int64>* HloInstruction::mutable_slice_limits() {
4248 return Cast<HloSliceInstruction>(this)->mutable_slice_limits();
4249 }
4250
slice_strides(int64_t dimension) const4251 int64 HloInstruction::slice_strides(int64_t dimension) const {
4252 return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
4253 }
4254
slice_strides() const4255 const std::vector<int64>& HloInstruction::slice_strides() const {
4256 return Cast<HloSliceInstruction>(this)->slice_strides();
4257 }
4258
mutable_slice_strides()4259 std::vector<int64>* HloInstruction::mutable_slice_strides() {
4260 return Cast<HloSliceInstruction>(this)->mutable_slice_strides();
4261 }
4262
literal() const4263 const Literal& HloInstruction::literal() const {
4264 return Cast<HloConstantInstruction>(this)->literal();
4265 }
4266
IsConstant() const4267 bool HloInstruction::IsConstant() const {
4268 return DynCast<HloConstantInstruction>(this) != nullptr;
4269 }
4270
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)4271 void HloInstruction::RelayoutConstant(const Layout& new_layout,
4272 const ShapeIndex& shape_index) {
4273 Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
4274 }
4275
TracingTag() const4276 string HloInstruction::TracingTag() const {
4277 return Cast<HloTraceInstruction>(this)->TracingTag();
4278 }
4279
AddFusionOperand(HloInstruction * new_operand)4280 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
4281 return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
4282 }
4283
4284 // Delegates to HloFusionInstruction::MergeFusionInstruction.
MergeFusionInstruction(HloInstruction * instruction_to_merge)4285 void HloInstruction::MergeFusionInstruction(
4286 HloInstruction* instruction_to_merge) {
4287 return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
4288 Cast<HloFusionInstruction>(instruction_to_merge));
4289 }
4290
4291 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
MergeFusionInstructionIntoMultiOutput(HloInstruction * instruction_to_merge)4292 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
4293 HloInstruction* instruction_to_merge) {
4294 return Cast<HloFusionInstruction>(this)
4295 ->MergeFusionInstructionIntoMultiOutput(
4296 Cast<HloFusionInstruction>(instruction_to_merge));
4297 }
4298
FuseInstruction(HloInstruction * instruction_to_fuse)4299 HloInstruction* HloInstruction::FuseInstruction(
4300 HloInstruction* instruction_to_fuse) {
4301 return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
4302 }
4303
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)4304 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
4305 HloInstruction* instruction_to_fuse) {
4306 return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
4307 instruction_to_fuse);
4308 }
4309
fused_instructions_computation() const4310 HloComputation* HloInstruction::fused_instructions_computation() const {
4311 return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
4312 }
4313
fused_expression_root() const4314 HloInstruction* HloInstruction::fused_expression_root() const {
4315 return Cast<HloFusionInstruction>(this)->fused_expression_root();
4316 }
4317
4318 const tensorflow::gtl::iterator_range<UnwrappingIterator<
4319 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const4320 HloInstruction::fused_instructions() const {
4321 return Cast<HloFusionInstruction>(this)->fused_instructions();
4322 }
4323
4324 const tensorflow::gtl::iterator_range<
4325 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()4326 HloInstruction::fused_instructions() {
4327 return Cast<HloFusionInstruction>(this)->fused_instructions();
4328 }
4329
fused_instruction_count() const4330 int64 HloInstruction::fused_instruction_count() const {
4331 return Cast<HloFusionInstruction>(this)->fused_instruction_count();
4332 }
4333
fused_parameter(int64_t parameter_number) const4334 HloInstruction* HloInstruction::fused_parameter(
4335 int64_t parameter_number) const {
4336 return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
4337 }
4338
fused_parameters() const4339 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
4340 return Cast<HloFusionInstruction>(this)->fused_parameters();
4341 }
4342
IsMultiOutputFusion() const4343 const bool HloInstruction::IsMultiOutputFusion() const {
4344 const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
4345 return fusion != nullptr && fusion->IsMultiOutputFusion();
4346 }
4347
fusion_kind() const4348 HloInstruction::FusionKind HloInstruction::fusion_kind() const {
4349 return Cast<HloFusionInstruction>(this)->fusion_kind();
4350 }
4351
set_fusion_kind(FusionKind kind)4352 void HloInstruction::set_fusion_kind(FusionKind kind) {
4353 return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
4354 }
4355
random_distribution() const4356 RandomDistribution HloInstruction::random_distribution() const {
4357 return Cast<HloRngInstruction>(this)->random_distribution();
4358 }
4359
parameter_number() const4360 int64 HloInstruction::parameter_number() const {
4361 return Cast<HloParameterInstruction>(this)->parameter_number();
4362 }
4363
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)4364 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4365 absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
4366 return Cast<HloParameterInstruction>(this)
4367 ->set_parameter_replicated_at_leaf_buffers(
4368 parameter_replicated_at_leaf_buffers);
4369 }
4370
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)4371 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
4372 const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
4373 return Cast<HloParameterInstruction>(this)
4374 ->set_parameter_replicated_at_leaf_buffers(
4375 parameter_replicated_at_leaf_buffers);
4376 }
4377
4378 const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers() const4379 HloInstruction::parameter_replicated_at_leaf_buffers() const {
4380 return Cast<HloParameterInstruction>(this)
4381 ->parameter_replicated_at_leaf_buffers();
4382 }
4383
tuple_index() const4384 int64 HloInstruction::tuple_index() const {
4385 return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
4386 }
4387
set_tuple_index(int64_t new_tuple_index)4388 void HloInstruction::set_tuple_index(int64_t new_tuple_index) {
4389 return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
4390 new_tuple_index);
4391 }
4392
exponent_bits() const4393 int32 HloInstruction::exponent_bits() const {
4394 return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
4395 }
4396
mantissa_bits() const4397 int32 HloInstruction::mantissa_bits() const {
4398 return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
4399 }
4400
infeed_config() const4401 string HloInstruction::infeed_config() const {
4402 return Cast<HloInfeedInstruction>(this)->infeed_config();
4403 }
4404
set_infeed_config(const string & config)4405 void HloInstruction::set_infeed_config(const string& config) {
4406 return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
4407 }
4408
outfeed_shape() const4409 const Shape& HloInstruction::outfeed_shape() const {
4410 return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
4411 }
4412
mutable_outfeed_shape()4413 Shape* HloInstruction::mutable_outfeed_shape() {
4414 return Cast<HloOutfeedInstruction>(this)->mutable_outfeed_shape();
4415 }
4416
outfeed_config() const4417 const string& HloInstruction::outfeed_config() const {
4418 return Cast<HloOutfeedInstruction>(this)->outfeed_config();
4419 }
4420
set_outfeed_config(const string & config)4421 void HloInstruction::set_outfeed_config(const string& config) {
4422 return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
4423 }
4424
replica_groups() const4425 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
4426 return Cast<HloCollectiveInstruction>(this)->replica_groups();
4427 }
4428
4429 const std::vector<std::pair<int64, int64>>&
source_target_pairs() const4430 HloInstruction::source_target_pairs() const {
4431 return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
4432 }
4433
channel_id() const4434 absl::optional<int64> HloInstruction::channel_id() const {
4435 return Cast<HloChannelInstruction>(this)->channel_id();
4436 }
4437
set_channel_id(const absl::optional<int64> & channel_id)4438 void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
4439 return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
4440 }
4441
4442 const ConvolutionDimensionNumbers&
convolution_dimension_numbers() const4443 HloInstruction::convolution_dimension_numbers() const {
4444 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4445 return convolution->convolution_dimension_numbers();
4446 }
4447 if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4448 return custom_call->convolution_dimension_numbers();
4449 }
4450 LOG(FATAL) << "Unimplemented method.";
4451 }
4452
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)4453 void HloInstruction::set_convolution_dimension_numbers(
4454 const ConvolutionDimensionNumbers& dnums) {
4455 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4456 convolution->set_convolution_dimension_numbers(dnums);
4457 } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
4458 custom_call->set_convolution_dimension_numbers(dnums);
4459 } else {
4460 LOG(FATAL) << "Unimplemented method.";
4461 }
4462 }
4463
feature_group_count() const4464 int64 HloInstruction::feature_group_count() const {
4465 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4466 return convolution->feature_group_count();
4467 }
4468 return Cast<HloCustomCallInstruction>(this)->feature_group_count();
4469 }
4470
set_feature_group_count(int64_t feature_group_count)4471 void HloInstruction::set_feature_group_count(int64_t feature_group_count) {
4472 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4473 return convolution->set_feature_group_count(feature_group_count);
4474 }
4475 Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
4476 feature_group_count);
4477 }
4478
batch_group_count() const4479 int64 HloInstruction::batch_group_count() const {
4480 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4481 return convolution->batch_group_count();
4482 }
4483 return Cast<HloCustomCallInstruction>(this)->batch_group_count();
4484 }
4485
set_batch_group_count(int64_t batch_group_count)4486 void HloInstruction::set_batch_group_count(int64_t batch_group_count) {
4487 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
4488 return convolution->set_batch_group_count(batch_group_count);
4489 }
4490 Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
4491 batch_group_count);
4492 }
4493
select() const4494 HloComputation* HloInstruction::select() const {
4495 return Cast<HloSelectAndScatterInstruction>(this)->select();
4496 }
4497
scatter() const4498 HloComputation* HloInstruction::scatter() const {
4499 return Cast<HloSelectAndScatterInstruction>(this)->scatter();
4500 }
4501
set_select(HloComputation * computation)4502 void HloInstruction::set_select(HloComputation* computation) {
4503 return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
4504 }
4505
set_scatter(HloComputation * computation)4506 void HloInstruction::set_scatter(HloComputation* computation) {
4507 return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
4508 }
4509
custom_call_target() const4510 const string& HloInstruction::custom_call_target() const {
4511 return Cast<HloCustomCallInstruction>(this)->custom_call_target();
4512 }
4513
padding_config() const4514 const PaddingConfig& HloInstruction::padding_config() const {
4515 return Cast<HloPadInstruction>(this)->padding_config();
4516 }
4517
padding_type() const4518 PaddingType HloInstruction::padding_type() const {
4519 return Cast<HloCustomCallInstruction>(this)->padding_type();
4520 }
4521
mutable_padding_config()4522 PaddingConfig* HloInstruction::mutable_padding_config() {
4523 return Cast<HloPadInstruction>(this)->mutable_padding_config();
4524 }
4525
slice_sizes(int64_t dimension) const4526 int64 HloInstruction::slice_sizes(int64_t dimension) const {
4527 return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
4528 }
4529
dynamic_slice_sizes() const4530 const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
4531 return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
4532 }
4533
4534 const std::vector<std::vector<int64_t>>&
dynamic_slice_sizes_list() const4535 HloInstruction::dynamic_slice_sizes_list() const {
4536 return Cast<HloCollectivePermuteInstruction>(this)
4537 ->dynamic_slice_sizes_list();
4538 }
4539
gather_dimension_numbers() const4540 const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
4541 return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
4542 }
4543
gather_slice_sizes() const4544 absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
4545 return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
4546 }
4547
scatter_dimension_numbers() const4548 const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
4549 const {
4550 return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
4551 }
4552
dot_dimension_numbers() const4553 const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
4554 return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
4555 }
4556
operand_side_metadata() const4557 const DomainMetadata& HloInstruction::operand_side_metadata() const {
4558 return Cast<HloDomainInstruction>(this)->operand_side_metadata();
4559 }
4560
user_side_metadata() const4561 const DomainMetadata& HloInstruction::user_side_metadata() const {
4562 return Cast<HloDomainInstruction>(this)->user_side_metadata();
4563 }
4564
is_cross_program_prefetch() const4565 bool HloInstruction::is_cross_program_prefetch() const {
4566 return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
4567 }
4568
comparison_direction() const4569 ComparisonDirection HloInstruction::comparison_direction() const {
4570 return Cast<HloCompareInstruction>(this)->direction();
4571 }
4572
triangular_solve_options() const4573 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
4574 return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
4575 }
4576
cholesky_options() const4577 const CholeskyOptions& HloInstruction::cholesky_options() const {
4578 return Cast<HloCholeskyInstruction>(this)->cholesky_options();
4579 }
4580
4581 } // namespace xla
4582